mirror of
https://github.com/go-i2p/go-i2p-bt.git
synced 2025-07-12 19:04:49 -04:00
first commit
This commit is contained in:
40
.gitignore
vendored
Normal file
40
.gitignore
vendored
Normal file
@ -0,0 +1,40 @@
|
||||
# Compiled Object files, Static and Dynamic libs (Shared Objects)
|
||||
*.o
|
||||
*.a
|
||||
*.so
|
||||
|
||||
# Folders
|
||||
_obj
|
||||
_test
|
||||
|
||||
# Architecture specific extensions/prefixes
|
||||
*.[568vq]
|
||||
[568vq].out
|
||||
|
||||
*.cgo1.go
|
||||
*.cgo2.c
|
||||
_cgo_defun.c
|
||||
_cgo_gotypes.go
|
||||
_cgo_export.*
|
||||
|
||||
_testmain.go
|
||||
|
||||
*.exe
|
||||
*.test
|
||||
*.prof
|
||||
|
||||
# Intellij IDE
|
||||
.idea/
|
||||
*.iml
|
||||
|
||||
# test
|
||||
test_*
|
||||
|
||||
# log
|
||||
*.log
|
||||
|
||||
# Mac
|
||||
.DS_Store
|
||||
|
||||
# VS Code
|
||||
.vscode/
|
10
.travis.yml
Normal file
10
.travis.yml
Normal file
@ -0,0 +1,10 @@
|
||||
language: go
|
||||
go:
|
||||
- 1.11.x
|
||||
- 1.12.x
|
||||
- 1.13.x
|
||||
- 1.14.x
|
||||
env:
|
||||
- GO111MODULE=on
|
||||
script:
|
||||
- go test -race ./...
|
202
LICENSE
Normal file
202
LICENSE
Normal file
@ -0,0 +1,202 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "{}"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright {yyyy} {name of copyright owner}
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
41
README.md
Normal file
41
README.md
Normal file
@ -0,0 +1,41 @@
|
||||
# BT - Another Implementation Based On Golang [](https://travis-ci.org/xgfone/bt) [](https://pkg.go.dev/github.com/xgfone/bt) [](https://raw.githubusercontent.com/xgfone/bt/master/LICENSE)
|
||||
|
||||
A pure golang implementation of [BitTorrent](http://bittorrent.org/beps/bep_0000.html) library, which is inspired by [dht](https://github.com/shiyanhui/dht) and [torrent](https://github.com/anacrolix/torrent).
|
||||
|
||||
|
||||
## Features
|
||||
|
||||
- Support IPv4/IPv6.
|
||||
- Pure Go implementation.
|
||||
- Multi-BEPs implementation. [See below](#the-implemented-specifications)
|
||||
- Only library without any denpendencies. For the command tools, see [bttools](https://github.com/xgfone/bttools).
|
||||
|
||||
|
||||
## Install
|
||||
```shell
|
||||
$ go get github.com/xgfone/bt
|
||||
```
|
||||
|
||||
|
||||
## Example
|
||||
See [godoc](https://pkg.go.dev/github.com/xgfone/bt) or [bttools](https://github.com/xgfone/bttools).
|
||||
|
||||
|
||||
## The Implemented Specifications
|
||||
- [x] [**BEP 03:** The BitTorrent Protocol Specification](http://bittorrent.org/beps/bep_0003.html)
|
||||
- [x] [**BEP 05:** DHT Protocol](http://bittorrent.org/beps/bep_0005.html)
|
||||
- [x] [**BEP 06:** Fast Extension](http://bittorrent.org/beps/bep_0006.html)
|
||||
- [x] [**BEP 07:** IPv6 Tracker Extension](http://bittorrent.org/beps/bep_0007.html)
|
||||
- [x] [**BEP 09:** Extension for Peers to Send Metadata Files](http://bittorrent.org/beps/bep_0009.html)
|
||||
- [x] [**BEP 10:** Extension Protocol](http://bittorrent.org/beps/bep_0010.html)
|
||||
- [ ] [**BEP 11:** Peer Exchange (PEX)](http://bittorrent.org/beps/bep_0011.html)
|
||||
- [x] [**BEP 12:** Multitracker Metadata Extension](http://bittorrent.org/beps/bep_0012.html)
|
||||
- [x] [**BEP 15:** UDP Tracker Protocol for BitTorrent](http://bittorrent.org/beps/bep_0015.html)
|
||||
- [x] [**BEP 19:** WebSeed - HTTP/FTP Seeding (GetRight style)](http://bittorrent.org/beps/bep_0019.html) (Only `url-list` in metainfo)
|
||||
- [x] [**BEP 23:** Tracker Returns Compact Peer Lists](http://bittorrent.org/beps/bep_0023.html)
|
||||
- [x] [**BEP 32:** IPv6 extension for DHT](http://bittorrent.org/beps/bep_0032.html)
|
||||
- [ ] [**BEP 33:** DHT scrape](http://bittorrent.org/beps/bep_0033.html)
|
||||
- [x] [**BEP 41:** UDP Tracker Protocol Extensions](http://bittorrent.org/beps/bep_0041.html)
|
||||
- [x] [**BEP 43:** Read-only DHT Nodes](http://bittorrent.org/beps/bep_0043.html)
|
||||
- [ ] [**BEP 44:** Storing arbitrary data in the DHT](http://bittorrent.org/beps/bep_0044.html)
|
||||
- [x] [**BEP 48:** Tracker Protocol Extension: Scrape](http://bittorrent.org/beps/bep_0048.html)
|
8
bencode/AUTHORS
Normal file
8
bencode/AUTHORS
Normal file
@ -0,0 +1,8 @@
|
||||
Jeff Wendling <leterip@gmail.com>
|
||||
Liam Edwards-Playne <liamz.co>
|
||||
Casey Bodley <cbodley@gmail.com>
|
||||
Conrad Pankoff <deoxxa@fknsrs.biz>
|
||||
Cenk Alti <cenkalti@gmail.com>
|
||||
Jan Winkelmann <j-winkelmann@tuhh.de>
|
||||
Patrick Mézard <patrick@mezard.eu>
|
||||
Glen De Cauwsemaecker <contact@glendc.com>
|
19
bencode/LICENSE
Normal file
19
bencode/LICENSE
Normal file
@ -0,0 +1,19 @@
|
||||
Copyright (c) 2013 The Authors (see AUTHORS)
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
614
bencode/decode.go
Normal file
614
bencode/decode.go
Normal file
@ -0,0 +1,614 @@
|
||||
package bencode
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
reflectByteSliceType = reflect.TypeOf([]byte(nil))
|
||||
reflectStringType = reflect.TypeOf("")
|
||||
)
|
||||
|
||||
// Unmarshaler is the interface implemented by types that can unmarshal
|
||||
// a bencode description of themselves.
|
||||
// The input can be assumed to be a valid encoding of a bencode value.
|
||||
// UnmarshalBencode must copy the bencode data if it wishes to retain the data after returning.
|
||||
type Unmarshaler interface {
|
||||
UnmarshalBencode([]byte) error
|
||||
}
|
||||
|
||||
// A Decoder reads and decodes bencoded data from an input stream.
|
||||
type Decoder struct {
|
||||
r *bufio.Reader
|
||||
raw bool
|
||||
buf []byte
|
||||
n int
|
||||
failUnordered bool
|
||||
}
|
||||
|
||||
// SetFailOnUnorderedKeys will cause the decoder to fail when encountering
|
||||
// unordered keys. The default is to not fail.
|
||||
func (d *Decoder) SetFailOnUnorderedKeys(fail bool) {
|
||||
d.failUnordered = fail
|
||||
}
|
||||
|
||||
// BytesParsed returns the number of bytes that have actually been parsed
|
||||
func (d *Decoder) BytesParsed() int {
|
||||
return d.n
|
||||
}
|
||||
|
||||
// read also writes into the buffer when d.raw is set.
|
||||
func (d *Decoder) read(p []byte) (n int, err error) {
|
||||
n, err = d.r.Read(p)
|
||||
if d.raw {
|
||||
d.buf = append(d.buf, p[:n]...)
|
||||
}
|
||||
d.n += n
|
||||
return
|
||||
}
|
||||
|
||||
// readBytes also writes into the buffer when d.raw is set.
|
||||
func (d *Decoder) readBytes(delim byte) (line []byte, err error) {
|
||||
line, err = d.r.ReadBytes(delim)
|
||||
if d.raw {
|
||||
d.buf = append(d.buf, line...)
|
||||
}
|
||||
d.n += len(line)
|
||||
return
|
||||
}
|
||||
|
||||
// readByte also writes into the buffer when d.raw is set.
|
||||
func (d *Decoder) readByte() (b byte, err error) {
|
||||
b, err = d.r.ReadByte()
|
||||
if d.raw {
|
||||
d.buf = append(d.buf, b)
|
||||
}
|
||||
d.n++
|
||||
return
|
||||
}
|
||||
|
||||
// readFull also writes into the buffer when d.raw is set.
|
||||
func (d *Decoder) readFull(p []byte) (n int, err error) {
|
||||
n, err = io.ReadFull(d.r, p)
|
||||
if d.raw {
|
||||
d.buf = append(d.buf, p[:n]...)
|
||||
}
|
||||
d.n += n
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Decoder) peekByte() (b byte, err error) {
|
||||
ch, err := d.r.Peek(1)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
b = ch[0]
|
||||
return
|
||||
}
|
||||
|
||||
// NewDecoder returns a new decoder that reads from r
|
||||
func NewDecoder(r io.Reader) *Decoder {
|
||||
return &Decoder{r: bufio.NewReader(r)}
|
||||
}
|
||||
|
||||
// Decode reads the bencoded value from its input and stores it in the value pointed to by val.
|
||||
// Decode allocates maps/slices as necessary with the following additional rules:
|
||||
// To decode a bencoded value into a nil interface value, the type stored in the interface value is one of:
|
||||
// int64 for bencoded integers
|
||||
// string for bencoded strings
|
||||
// []interface{} for bencoded lists
|
||||
// map[string]interface{} for bencoded dicts
|
||||
// To unmarshal bencode into a value implementing the Unmarshaler interface,
|
||||
// Unmarshal calls that value's UnmarshalBencode method.
|
||||
// Otherwise, if the value implements encoding.TextUnmarshaler
|
||||
// and the input is a bencode string, Unmarshal calls that value's
|
||||
// UnmarshalText method with the decoded form of the string.
|
||||
func (d *Decoder) Decode(val interface{}) error {
|
||||
rv := reflect.ValueOf(val)
|
||||
if rv.Kind() != reflect.Ptr || rv.IsNil() {
|
||||
return errors.New("Unwritable type passed into decode")
|
||||
}
|
||||
|
||||
return d.decodeInto(rv)
|
||||
}
|
||||
|
||||
// DecodeString reads the data in the string and stores it into the value pointed to by val.
|
||||
// Read the docs for Decode for more information.
|
||||
func DecodeString(in string, val interface{}) error {
|
||||
buf := strings.NewReader(in)
|
||||
d := NewDecoder(buf)
|
||||
return d.Decode(val)
|
||||
}
|
||||
|
||||
// DecodeBytes reads the data in b and stores it into the value pointed to by val.
|
||||
// Read the docs for Decode for more information.
|
||||
func DecodeBytes(b []byte, val interface{}) error {
|
||||
r := bytes.NewReader(b)
|
||||
d := NewDecoder(r)
|
||||
return d.Decode(val)
|
||||
}
|
||||
|
||||
func indirect(v reflect.Value, alloc bool) reflect.Value {
|
||||
for {
|
||||
switch v.Kind() {
|
||||
case reflect.Interface:
|
||||
if v.IsNil() {
|
||||
if !alloc {
|
||||
return reflect.Value{}
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
case reflect.Ptr:
|
||||
if v.IsNil() {
|
||||
if !alloc {
|
||||
return reflect.Value{}
|
||||
}
|
||||
v.Set(reflect.New(v.Type().Elem()))
|
||||
}
|
||||
|
||||
default:
|
||||
return v
|
||||
}
|
||||
|
||||
v = v.Elem()
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Decoder) decodeInto(val reflect.Value) (err error) {
|
||||
var v reflect.Value
|
||||
if d.raw {
|
||||
v = val
|
||||
} else {
|
||||
var (
|
||||
unmarshaler Unmarshaler
|
||||
textUnmarshaler encoding.TextUnmarshaler
|
||||
)
|
||||
unmarshaler, textUnmarshaler, v = d.indirect(val)
|
||||
|
||||
// if we're decoding into an Unmarshaler,
|
||||
// we pass on the next bencode value to this value instead,
|
||||
// so it can decide what to do with it.
|
||||
if unmarshaler != nil {
|
||||
var x RawMessage
|
||||
if err := d.decodeInto(reflect.ValueOf(&x)); err != nil {
|
||||
return err
|
||||
}
|
||||
return unmarshaler.UnmarshalBencode([]byte(x))
|
||||
}
|
||||
|
||||
// if we're decoding into an TextUnmarshaler,
|
||||
// we'll assume that the bencode value is a string,
|
||||
// we decode it as such and pass the result onto the unmarshaler.
|
||||
if textUnmarshaler != nil {
|
||||
var b []byte
|
||||
ref := reflect.ValueOf(&b)
|
||||
if err := d.decodeString(reflect.Indirect(ref)); err != nil {
|
||||
return err
|
||||
}
|
||||
return textUnmarshaler.UnmarshalText(b)
|
||||
}
|
||||
|
||||
// if we're decoding into a RawMessage set raw to true for the rest of
|
||||
// the call stack, and switch out the value with an interface{}.
|
||||
if _, ok := v.Interface().(RawMessage); ok {
|
||||
v = reflect.Value{} // explicitly make v invalid
|
||||
|
||||
// set d.raw for the lifetime of this function call, and set the raw
|
||||
// message when the function is exiting.
|
||||
d.buf = d.buf[:0]
|
||||
d.raw = true
|
||||
defer func() {
|
||||
d.raw = false
|
||||
v := indirect(val, true)
|
||||
v.SetBytes(append([]byte(nil), d.buf...))
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
next, err := d.peekByte()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch next {
|
||||
case 'i':
|
||||
err = d.decodeInt(v)
|
||||
case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
|
||||
err = d.decodeString(v)
|
||||
case 'l':
|
||||
err = d.decodeList(v)
|
||||
case 'd':
|
||||
err = d.decodeDict(v)
|
||||
default:
|
||||
err = errors.New("Invalid input")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Decoder) decodeInt(v reflect.Value) error {
|
||||
// we need to read an i, some digits, and an e.
|
||||
ch, err := d.readByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ch != 'i' {
|
||||
panic("got not an i when peek returned an i")
|
||||
}
|
||||
|
||||
line, err := d.readBytes('e')
|
||||
if err != nil || d.raw {
|
||||
return err
|
||||
}
|
||||
|
||||
digits := string(line[:len(line)-1])
|
||||
|
||||
switch v.Kind() {
|
||||
default:
|
||||
return fmt.Errorf("Cannot store int64 into %s", v.Type())
|
||||
case reflect.Interface:
|
||||
n, err := strconv.ParseInt(digits, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v.Set(reflect.ValueOf(n))
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
n, err := strconv.ParseInt(digits, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v.SetInt(n)
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
n, err := strconv.ParseUint(digits, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v.SetUint(n)
|
||||
case reflect.Bool:
|
||||
n, err := strconv.ParseUint(digits, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v.SetBool(n != 0)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Decoder) decodeString(v reflect.Value) error {
|
||||
// read until a colon to get the number of digits to read after
|
||||
line, err := d.readBytes(':')
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// parse it into an int for making a slice
|
||||
l32, err := strconv.ParseInt(string(line[:len(line)-1]), 10, 32)
|
||||
l := int(l32)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if l < 0 {
|
||||
return fmt.Errorf("invalid negative string length: %d", l)
|
||||
}
|
||||
|
||||
// read exactly l bytes out and make our string
|
||||
buf := make([]byte, l)
|
||||
_, err = d.readFull(buf)
|
||||
if err != nil || d.raw {
|
||||
return err
|
||||
}
|
||||
|
||||
switch v.Kind() {
|
||||
default:
|
||||
return fmt.Errorf("Cannot store string into %s", v.Type())
|
||||
case reflect.Slice:
|
||||
if v.Type() != reflectByteSliceType {
|
||||
return fmt.Errorf("Cannot store string into %s", v.Type())
|
||||
}
|
||||
v.SetBytes(buf)
|
||||
case reflect.String:
|
||||
v.SetString(string(buf))
|
||||
case reflect.Interface:
|
||||
v.Set(reflect.ValueOf(string(buf)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Decoder) decodeList(v reflect.Value) error {
|
||||
if !d.raw {
|
||||
// if we have an interface, just put a []interface{} in it!
|
||||
if v.Kind() == reflect.Interface {
|
||||
var x []interface{}
|
||||
defer func(p reflect.Value) { p.Set(v) }(v)
|
||||
v = reflect.ValueOf(&x).Elem()
|
||||
}
|
||||
|
||||
if v.Kind() != reflect.Array && v.Kind() != reflect.Slice {
|
||||
return fmt.Errorf("Cant store a []interface{} into %s", v.Type())
|
||||
}
|
||||
}
|
||||
|
||||
// read out the l that prefixes the list
|
||||
ch, err := d.readByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ch != 'l' {
|
||||
panic("got something other than a list head after a peek")
|
||||
}
|
||||
|
||||
// if we're decoding in raw mode,
|
||||
// we only want to read into the buffer,
|
||||
// without actually parsing any values
|
||||
if d.raw {
|
||||
var ch byte
|
||||
for {
|
||||
// peek for the end token and read it out
|
||||
ch, err = d.peekByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ch == 'e' {
|
||||
_, err = d.readByte() // consume the end
|
||||
return err
|
||||
}
|
||||
|
||||
// decode the next value
|
||||
err = d.decodeInto(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; ; i++ {
|
||||
// peek for the end token and read it out
|
||||
ch, err := d.peekByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch ch {
|
||||
case 'e':
|
||||
_, err := d.readByte() // consume the end
|
||||
return err
|
||||
}
|
||||
|
||||
// grow it if required
|
||||
if i >= v.Cap() && v.IsValid() {
|
||||
newcap := v.Cap() + v.Cap()/2
|
||||
if newcap < 4 {
|
||||
newcap = 4
|
||||
}
|
||||
newv := reflect.MakeSlice(v.Type(), v.Len(), newcap)
|
||||
reflect.Copy(newv, v)
|
||||
v.Set(newv)
|
||||
}
|
||||
|
||||
// reslice into cap (its a slice now since it had to have grown)
|
||||
if i >= v.Len() && v.IsValid() {
|
||||
v.SetLen(i + 1)
|
||||
}
|
||||
|
||||
// decode a value into the index
|
||||
if err := d.decodeInto(v.Index(i)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Decoder) decodeDict(v reflect.Value) error {
|
||||
// if we have an interface{}, just put a map[string]interface{} in it!
|
||||
if !d.raw && v.Kind() == reflect.Interface {
|
||||
var x map[string]interface{}
|
||||
defer func(p reflect.Value) { p.Set(v) }(v)
|
||||
v = reflect.ValueOf(&x).Elem()
|
||||
}
|
||||
|
||||
// consume the head token
|
||||
ch, err := d.readByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ch != 'd' {
|
||||
panic("got an incorrect token when it was checked already")
|
||||
}
|
||||
|
||||
if d.raw {
|
||||
// if we're decoding in raw mode,
|
||||
// we only want to read into the buffer,
|
||||
// without actually parsing any values
|
||||
for {
|
||||
// peek the next value type
|
||||
ch, err := d.peekByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ch == 'e' {
|
||||
_, err = d.readByte() // consume the end token
|
||||
return err
|
||||
}
|
||||
|
||||
err = d.decodeString(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = d.decodeInto(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check for correct type
|
||||
var (
|
||||
mapElem reflect.Value
|
||||
isMap bool
|
||||
vals map[string]reflect.Value
|
||||
)
|
||||
|
||||
switch v.Kind() {
|
||||
case reflect.Map:
|
||||
t := v.Type()
|
||||
if t.Key() != reflectStringType {
|
||||
return fmt.Errorf("Can't store a map[string]interface{} into %s", v.Type())
|
||||
}
|
||||
if v.IsNil() {
|
||||
v.Set(reflect.MakeMap(t))
|
||||
}
|
||||
|
||||
isMap = true
|
||||
mapElem = reflect.New(t.Elem()).Elem()
|
||||
case reflect.Struct:
|
||||
vals = make(map[string]reflect.Value)
|
||||
setStructValues(vals, v)
|
||||
default:
|
||||
return fmt.Errorf("Can't store a map[string]interface{} into %s", v.Type())
|
||||
}
|
||||
|
||||
var (
|
||||
lastKey string
|
||||
first bool = true
|
||||
)
|
||||
|
||||
for {
|
||||
var subv reflect.Value
|
||||
|
||||
// peek the next value type
|
||||
ch, err := d.peekByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ch == 'e' {
|
||||
_, err = d.readByte() // consume the end token
|
||||
return err
|
||||
}
|
||||
|
||||
// peek the next value we're suppsed to read
|
||||
var key string
|
||||
if err := d.decodeString(reflect.ValueOf(&key).Elem()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// check for unordered keys
|
||||
if !first && d.failUnordered && lastKey > key {
|
||||
return fmt.Errorf("unordered dictionary: %q appears before %q",
|
||||
lastKey, key)
|
||||
}
|
||||
lastKey, first = key, false
|
||||
|
||||
if isMap {
|
||||
mapElem.Set(reflect.Zero(v.Type().Elem()))
|
||||
subv = mapElem
|
||||
} else {
|
||||
subv = vals[key]
|
||||
}
|
||||
|
||||
if !subv.IsValid() {
|
||||
// if it's invalid, grab but ignore the next value
|
||||
var x interface{}
|
||||
err := d.decodeInto(reflect.ValueOf(&x).Elem())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// subv now contains what we load into
|
||||
if err := d.decodeInto(subv); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if isMap {
|
||||
v.SetMapIndex(reflect.ValueOf(key), subv)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// indirect walks down v allocating pointers as needed,
|
||||
// until it gets to a non-pointer.
|
||||
// if it encounters an (Text)Unmarshaler, indirect stops and returns that.
|
||||
func (d *Decoder) indirect(v reflect.Value) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) {
|
||||
// If v is a named type and is addressable,
|
||||
// start with its address, so that if the type has pointer methods,
|
||||
// we find them.
|
||||
if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() {
|
||||
v = v.Addr()
|
||||
}
|
||||
for {
|
||||
// Load value from interface, but only if the result will be
|
||||
// usefully addressable.
|
||||
if v.Kind() == reflect.Interface && !v.IsNil() {
|
||||
e := v.Elem()
|
||||
if e.Kind() == reflect.Ptr && !e.IsNil() {
|
||||
v = e
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if v.Kind() != reflect.Ptr || v.IsNil() {
|
||||
break
|
||||
}
|
||||
|
||||
vi := v.Interface()
|
||||
if u, ok := vi.(Unmarshaler); ok {
|
||||
return u, nil, reflect.Value{}
|
||||
}
|
||||
if u, ok := vi.(encoding.TextUnmarshaler); ok {
|
||||
return nil, u, reflect.Value{}
|
||||
}
|
||||
|
||||
v = v.Elem()
|
||||
}
|
||||
return nil, nil, indirect(v, true)
|
||||
}
|
||||
|
||||
func setStructValues(m map[string]reflect.Value, v reflect.Value) {
|
||||
t := v.Type()
|
||||
if t.Kind() != reflect.Struct {
|
||||
return
|
||||
}
|
||||
|
||||
// do embedded fields first
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
f := t.Field(i)
|
||||
if f.PkgPath != "" {
|
||||
continue
|
||||
}
|
||||
v := v.FieldByIndex(f.Index)
|
||||
if f.Anonymous && f.Tag == "" {
|
||||
setStructValues(m, v)
|
||||
}
|
||||
}
|
||||
|
||||
// overwrite embedded struct tags and names
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
f := t.Field(i)
|
||||
if f.PkgPath != "" {
|
||||
continue
|
||||
}
|
||||
v := v.FieldByIndex(f.Index)
|
||||
name, _ := parseTag(f.Tag.Get("bencode"))
|
||||
if name == "" {
|
||||
if f.Anonymous {
|
||||
// it's a struct and its fields have already been added to the map
|
||||
continue
|
||||
}
|
||||
name = f.Name
|
||||
}
|
||||
if isValidTag(name) {
|
||||
m[name] = v
|
||||
}
|
||||
}
|
||||
}
|
400
bencode/decode_test.go
Normal file
400
bencode/decode_test.go
Normal file
@ -0,0 +1,400 @@
|
||||
package bencode
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDecode(t *testing.T) {
|
||||
type testCase struct {
|
||||
in string
|
||||
val interface{}
|
||||
expect interface{}
|
||||
err bool
|
||||
unorderedFail bool
|
||||
}
|
||||
|
||||
type dT struct {
|
||||
X string
|
||||
Y int
|
||||
Z string `bencode:"zff"`
|
||||
}
|
||||
|
||||
type Embedded struct {
|
||||
B string
|
||||
}
|
||||
|
||||
type issue22 struct {
|
||||
X string `bencode:"x"`
|
||||
Time myTimeType `bencode:"t"`
|
||||
Foo myBoolType `bencode:"f"`
|
||||
Bar myStringType `bencode:"b"`
|
||||
Slice mySliceType `bencode:"s"`
|
||||
Y string `bencode:"y"`
|
||||
}
|
||||
|
||||
type issue26 struct {
|
||||
X string `bencode:"x"`
|
||||
Foo myBoolTextType `bencode:"f"`
|
||||
Bar myTextStringType `bencode:"b"`
|
||||
Slice myTextSliceType `bencode:"s"`
|
||||
Y string `bencode:"y"`
|
||||
}
|
||||
|
||||
type issue22WithErrorChild struct {
|
||||
Name string `bencode:"n"`
|
||||
Error errorMarshalType `bencode:"e"`
|
||||
}
|
||||
|
||||
type issue26WithErrorChild struct {
|
||||
Name string `bencode:"n"`
|
||||
Error errorTextMarshalType `bencode:"e"`
|
||||
}
|
||||
|
||||
type discardNonFieldDef struct {
|
||||
B string
|
||||
D string
|
||||
}
|
||||
|
||||
type twoDefsForSameKey struct {
|
||||
A string
|
||||
A2 string `bencode:"A"`
|
||||
A3 string `bencode:"A"`
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
var decodeCases = []testCase{
|
||||
// integers
|
||||
{`i5e`, new(int), int(5), false, false},
|
||||
{`i-10e`, new(int), int(-10), false, false},
|
||||
{`i8e`, new(uint), uint(8), false, false},
|
||||
{`i8e`, new(uint8), uint8(8), false, false},
|
||||
{`i8e`, new(uint16), uint16(8), false, false},
|
||||
{`i8e`, new(uint32), uint32(8), false, false},
|
||||
{`i8e`, new(uint64), uint64(8), false, false},
|
||||
{`i8e`, new(int), int(8), false, false},
|
||||
{`i8e`, new(int8), int8(8), false, false},
|
||||
{`i8e`, new(int16), int16(8), false, false},
|
||||
{`i8e`, new(int32), int32(8), false, false},
|
||||
{`i8e`, new(int64), int64(8), false, false},
|
||||
{`i0e`, new(*int), new(int), false, false},
|
||||
{`i-2e`, new(uint), nil, true, false},
|
||||
|
||||
// bools
|
||||
{`i1e`, new(bool), true, false, false},
|
||||
{`i0e`, new(bool), false, false, false},
|
||||
{`i0e`, new(*bool), new(bool), false, false},
|
||||
{`i8e`, new(bool), true, false, false},
|
||||
|
||||
// strings
|
||||
{`3:foo`, new(string), "foo", false, false},
|
||||
{`4:foob`, new(string), "foob", false, false},
|
||||
{`0:`, new(*string), new(string), false, false},
|
||||
{`6:short`, new(string), nil, true, false},
|
||||
|
||||
// lists
|
||||
{`l3:foo3:bare`, new([]string), []string{"foo", "bar"}, false, false},
|
||||
{`li15ei20ee`, new([]int), []int{15, 20}, false, false},
|
||||
{`ld3:fooi0eed3:bari1eee`, new([]map[string]int), []map[string]int{
|
||||
{"foo": 0},
|
||||
{"bar": 1},
|
||||
}, false, false},
|
||||
|
||||
// dicts
|
||||
|
||||
{`d3:foo3:bar4:foob3:fooe`, new(map[string]string), map[string]string{
|
||||
"foo": "bar",
|
||||
"foob": "foo",
|
||||
}, false, false},
|
||||
{`d1:X3:foo1:Yi10e3:zff3:bare`, new(dT), dT{"foo", 10, "bar"}, false, false},
|
||||
|
||||
// encoding/json takes, if set, the tag as name and doesn't falls back to the
|
||||
// struct field's name.
|
||||
{`d1:X3:foo1:Yi10e1:Z3:bare`, new(dT), dT{"foo", 10, ""}, false, false},
|
||||
|
||||
{`d1:X3:foo1:Yi10e1:h3:bare`, new(dT), dT{"foo", 10, ""}, false, false},
|
||||
{`d3:fooli0ei1ee3:barli2ei3eee`, new(map[string][]int), map[string][]int{
|
||||
"foo": []int{0, 1},
|
||||
"bar": []int{2, 3},
|
||||
}, false, false},
|
||||
{`de`, new(map[string]string), map[string]string{}, false, false},
|
||||
|
||||
// into interfaces
|
||||
{`i5e`, new(interface{}), int64(5), false, false},
|
||||
{`li5ee`, new(interface{}), []interface{}{int64(5)}, false, false},
|
||||
{`5:hello`, new(interface{}), "hello", false, false},
|
||||
{`d5:helloi5ee`, new(interface{}), map[string]interface{}{"hello": int64(5)}, false, false},
|
||||
|
||||
// into values whose type support the Unmarshaler interface
|
||||
{`1:y`, new(myTimeType), nil, true, false},
|
||||
{fmt.Sprintf("i%de", now.Unix()), new(myTimeType), myTimeType{time.Unix(now.Unix(), 0)}, false, false},
|
||||
{`1:y`, new(myBoolType), myBoolType(true), false, false},
|
||||
{`i42e`, new(myBoolType), nil, true, false},
|
||||
{`1:n`, new(myBoolType), myBoolType(false), false, false},
|
||||
{`1:n`, new(errorMarshalType), nil, true, false},
|
||||
{`li102ei111ei111ee`, new(myStringType), myStringType("foo"), false, false},
|
||||
{`i42e`, new(myStringType), nil, true, false},
|
||||
{`d1:ai1e3:foo3:bare`, new(mySliceType), mySliceType{"a", int64(1), "foo", "bar"}, false, false},
|
||||
{`i42e`, new(mySliceType), nil, true, false},
|
||||
|
||||
// into values who have a child which type supports the Unmarshaler interface
|
||||
{
|
||||
fmt.Sprintf(`d1:b3:foo1:f1:y1:sd1:f3:foo1:ai42ee1:ti%de1:x1:x1:y1:ye`, now.Unix()),
|
||||
new(issue22),
|
||||
issue22{
|
||||
X: "x",
|
||||
Time: myTimeType{time.Unix(now.Unix(), 0)},
|
||||
Foo: myBoolType(true),
|
||||
Bar: myStringType("foo"),
|
||||
Slice: mySliceType{"a", int64(42), "f", "foo"},
|
||||
Y: "y",
|
||||
},
|
||||
false,
|
||||
false,
|
||||
},
|
||||
{
|
||||
`d1:ei42e1:n3:fooe`,
|
||||
new(issue22WithErrorChild),
|
||||
nil,
|
||||
true,
|
||||
false,
|
||||
},
|
||||
|
||||
// into values whose type support the TextUnmarshaler interface
|
||||
{`1:y`, new(myBoolTextType), myBoolTextType(true), false, false},
|
||||
{`1:n`, new(myBoolTextType), myBoolTextType(false), false, false},
|
||||
{`i42e`, new(myBoolTextType), nil, true, false},
|
||||
{`1:n`, new(errorTextMarshalType), nil, true, false},
|
||||
{`7:foo_bar`, new(myTextStringType), myTextStringType("bar"), false, false},
|
||||
{`i42e`, new(myTextStringType), nil, true, false},
|
||||
{`7:a,b,c,d`, new(myTextSliceType), myTextSliceType{"a", "b", "c", "d"}, false, false},
|
||||
{`i42e`, new(myTextSliceType), nil, true, false},
|
||||
|
||||
// into values who have a child which type supports the TextUnmarshaler interface
|
||||
{
|
||||
`d1:b7:foo_bar1:f1:y1:s5:1,2,31:x1:x1:y1:ye`,
|
||||
new(issue26),
|
||||
issue26{
|
||||
X: "x",
|
||||
Foo: myBoolTextType(true),
|
||||
Bar: myTextStringType("bar"),
|
||||
Slice: myTextSliceType{"1", "2", "3"},
|
||||
Y: "y",
|
||||
},
|
||||
false,
|
||||
false,
|
||||
},
|
||||
{
|
||||
`d1:ei42e1:n3:fooe`,
|
||||
new(issue26WithErrorChild),
|
||||
nil,
|
||||
true,
|
||||
false,
|
||||
},
|
||||
|
||||
// malformed
|
||||
{`i53:foo`, new(interface{}), nil, true, false},
|
||||
{`6:foo`, new(interface{}), nil, true, false},
|
||||
{`di5ei2ee`, new(interface{}), nil, true, false},
|
||||
{`d3:fooe`, new(interface{}), nil, true, false},
|
||||
{`l3:foo3:bar`, new(interface{}), nil, true, false},
|
||||
{`d-1:`, new(interface{}), nil, true, false},
|
||||
|
||||
// embedded structs
|
||||
{`d1:A3:foo1:B3:bare`, new(struct {
|
||||
A string
|
||||
Embedded
|
||||
}), struct {
|
||||
A string
|
||||
Embedded
|
||||
}{"foo", Embedded{"bar"}}, false, false},
|
||||
|
||||
// Embedded structs with a valid tag are encoded as a definition
|
||||
{`d1:B3:bar6:nestedd1:B3:fooee`, new(struct {
|
||||
Embedded `bencode:"nested"`
|
||||
}), struct {
|
||||
Embedded `bencode:"nested"`
|
||||
}{Embedded{"foo"}}, false, false},
|
||||
|
||||
// Don't fail when reading keys missing from the struct
|
||||
{"d1:A7:discard1:B4:take1:C7:discard1:D4:takee",
|
||||
new(discardNonFieldDef),
|
||||
discardNonFieldDef{"take", "take"},
|
||||
false,
|
||||
false,
|
||||
},
|
||||
|
||||
// Don't fail when reading the same key twice
|
||||
{"d1:A1:a1:A1:b1:A1:c1:A1:de", new(twoDefsForSameKey),
|
||||
twoDefsForSameKey{"", "", "d"}, false, false},
|
||||
|
||||
// Empty struct
|
||||
{"de", new(struct{}), struct{}{}, false, false},
|
||||
|
||||
// Fail on unordered dictionaries
|
||||
{"d1:Yi10e1:X1:a3:zff1:ce", new(dT), dT{}, true, true},
|
||||
{"d3:zff1:c1:Yi10e1:X1:ae", new(dT), dT{}, true, true},
|
||||
}
|
||||
|
||||
for i, tt := range decodeCases {
|
||||
dec := NewDecoder(strings.NewReader(tt.in))
|
||||
dec.SetFailOnUnorderedKeys(tt.unorderedFail)
|
||||
err := dec.Decode(tt.val)
|
||||
if !tt.err && err != nil {
|
||||
t.Errorf("#%d (%v): Unexpected err: %v", i, tt.in, err)
|
||||
continue
|
||||
}
|
||||
if tt.err && err == nil {
|
||||
t.Errorf("#%d (%v): Expected err is nil", i, tt.in)
|
||||
continue
|
||||
}
|
||||
v := reflect.ValueOf(tt.val).Elem().Interface()
|
||||
if !reflect.DeepEqual(v, tt.expect) && !tt.err {
|
||||
t.Errorf("#%d (%v): Val: %#v != %#v", i, tt.in, v, tt.expect)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRawDecode(t *testing.T) {
|
||||
type testCase struct {
|
||||
in string
|
||||
expect []byte
|
||||
err bool
|
||||
}
|
||||
|
||||
var rawDecodeCases = []testCase{
|
||||
{`i5e`, []byte(`i5e`), false},
|
||||
{`5:hello`, []byte(`5:hello`), false},
|
||||
{`li5ei10e5:helloe`, []byte(`li5ei10e5:helloe`), false},
|
||||
{`llleee`, []byte(`llleee`), false},
|
||||
{`li5eli5eli5eeee`, []byte(`li5eli5eli5eeee`), false},
|
||||
{`d5:helloi5ee`, []byte(`d5:helloi5ee`), false},
|
||||
}
|
||||
|
||||
for i, tt := range rawDecodeCases {
|
||||
var x RawMessage
|
||||
err := DecodeString(tt.in, &x)
|
||||
if !tt.err && err != nil {
|
||||
t.Errorf("#%d: Unexpected err: %v", i, err)
|
||||
continue
|
||||
}
|
||||
if tt.err && err == nil {
|
||||
t.Errorf("#%d: Expected err is nil", i)
|
||||
continue
|
||||
}
|
||||
if !reflect.DeepEqual(x, RawMessage(tt.expect)) && !tt.err {
|
||||
t.Errorf("#%d: Val: %#v != %#v", i, x, tt.expect)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type myStringType string
|
||||
|
||||
// UnmarshalBencode implements Unmarshaler.UnmarshalBencode
|
||||
func (mst *myStringType) UnmarshalBencode(b []byte) error {
|
||||
var raw []byte
|
||||
err := DecodeBytes(b, &raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*mst = myStringType(raw)
|
||||
return nil
|
||||
}
|
||||
|
||||
type mySliceType []interface{}
|
||||
|
||||
// UnmarshalBencode implements Unmarshaler.UnmarshalBencode
|
||||
func (mst *mySliceType) UnmarshalBencode(b []byte) error {
|
||||
m := make(map[string]interface{})
|
||||
err := DecodeBytes(b, &m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
raw := make([]interface{}, 0, len(m)*2)
|
||||
for _, key := range keys {
|
||||
raw = append(raw, key, m[key])
|
||||
}
|
||||
|
||||
*mst = mySliceType(raw)
|
||||
return nil
|
||||
}
|
||||
|
||||
type myTextStringType string
|
||||
|
||||
// UnmarshalText implements TextUnmarshaler.UnmarshalText
|
||||
func (mst *myTextStringType) UnmarshalText(b []byte) error {
|
||||
*mst = myTextStringType(bytes.TrimPrefix(b, []byte("foo_")))
|
||||
return nil
|
||||
}
|
||||
|
||||
type myTextSliceType []string
|
||||
|
||||
// UnmarshalText implements TextUnmarshaler.UnmarshalText
|
||||
func (mst *myTextSliceType) UnmarshalText(b []byte) error {
|
||||
raw := string(b)
|
||||
*mst = strings.Split(raw, ",")
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestNestedRawDecode(t *testing.T) {
|
||||
type testCase struct {
|
||||
in string
|
||||
val interface{}
|
||||
expect interface{}
|
||||
err bool
|
||||
}
|
||||
|
||||
type message struct {
|
||||
Key string
|
||||
Val int
|
||||
Raw RawMessage
|
||||
}
|
||||
|
||||
var cases = []testCase{
|
||||
{`li5e5:hellod1:a1:beli5eee`, new([]RawMessage), []RawMessage{
|
||||
RawMessage(`i5e`),
|
||||
RawMessage(`5:hello`),
|
||||
RawMessage(`d1:a1:be`),
|
||||
RawMessage(`li5ee`),
|
||||
}, false},
|
||||
{`d1:a1:b1:c1:de`, new(map[string]RawMessage), map[string]RawMessage{
|
||||
"a": RawMessage(`1:b`),
|
||||
"c": RawMessage(`1:d`),
|
||||
}, false},
|
||||
{`d3:Key5:hello3:Rawldedei5e1:ae3:Vali10ee`, new(message), message{
|
||||
Key: "hello",
|
||||
Val: 10,
|
||||
Raw: RawMessage(`ldedei5e1:ae`),
|
||||
}, false},
|
||||
}
|
||||
|
||||
for i, tt := range cases {
|
||||
err := DecodeString(tt.in, tt.val)
|
||||
if !tt.err && err != nil {
|
||||
t.Errorf("#%d: Unexpected err: %v", i, err)
|
||||
continue
|
||||
}
|
||||
if tt.err && err == nil {
|
||||
t.Errorf("#%d: Expected err is nil", i)
|
||||
continue
|
||||
}
|
||||
v := reflect.ValueOf(tt.val).Elem().Interface()
|
||||
if !reflect.DeepEqual(v, tt.expect) && !tt.err {
|
||||
t.Errorf("#%d: Val:\n%#v !=\n%#v", i, v, tt.expect)
|
||||
}
|
||||
}
|
||||
}
|
6
bencode/doc.go
Normal file
6
bencode/doc.go
Normal file
@ -0,0 +1,6 @@
|
||||
// Package bencode implements encoding and decoding of bencoded objects,
|
||||
// which has a similar API to the encoding/json package and many other
|
||||
// serialization formats.
|
||||
//
|
||||
// Notice: the package is moved and modified from github.com/zeebo/bencode@v1.0.0
|
||||
package bencode
|
322
bencode/encode.go
Normal file
322
bencode/encode.go
Normal file
@ -0,0 +1,322 @@
|
||||
package bencode
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"sort"
|
||||
)
|
||||
|
||||
type sortValues []reflect.Value
|
||||
|
||||
func (p sortValues) Len() int { return len(p) }
|
||||
func (p sortValues) Less(i, j int) bool { return p[i].String() < p[j].String() }
|
||||
func (p sortValues) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
|
||||
|
||||
// Marshaler is the interface implemented by types
|
||||
// that can marshal themselves into valid bencode.
|
||||
type Marshaler interface {
|
||||
MarshalBencode() ([]byte, error)
|
||||
}
|
||||
|
||||
// An Encoder writes bencoded objects to an output stream.
|
||||
type Encoder struct {
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
// NewEncoder returns a new encoder that writes to w.
|
||||
func NewEncoder(w io.Writer) *Encoder {
|
||||
return &Encoder{w}
|
||||
}
|
||||
|
||||
// Encode writes the bencoded data of val to its output stream.
|
||||
// If an encountered value implements the Marshaler interface,
|
||||
// its MarshalBencode method is called to produce the bencode output for this value.
|
||||
// If no MarshalBencode method is present but the value implements encoding.TextMarshaler instead,
|
||||
// its MarshalText method is called, which encodes the result as a bencode string.
|
||||
// See the documentation for Decode about the conversion of Go values to
|
||||
// bencoded data.
|
||||
func (e *Encoder) Encode(val interface{}) error {
|
||||
return encodeValue(e.w, reflect.ValueOf(val))
|
||||
}
|
||||
|
||||
// EncodeString returns the bencoded data of val as a string.
|
||||
func EncodeString(val interface{}) (string, error) {
|
||||
buf := new(bytes.Buffer)
|
||||
e := NewEncoder(buf)
|
||||
if err := e.Encode(val); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// EncodeBytes returns the bencoded data of val as a slice of bytes.
|
||||
func EncodeBytes(val interface{}) ([]byte, error) {
|
||||
buf := new(bytes.Buffer)
|
||||
e := NewEncoder(buf)
|
||||
if err := e.Encode(val); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func isNilValue(v reflect.Value) bool {
|
||||
return (v.Kind() == reflect.Interface || v.Kind() == reflect.Ptr) &&
|
||||
v.IsNil()
|
||||
}
|
||||
|
||||
func encodeValue(w io.Writer, val reflect.Value) error {
|
||||
marshaler, textMarshaler, v := indirectEncodeValue(val)
|
||||
|
||||
// marshal a type using the Marshaler type
|
||||
// if it implements that interface.
|
||||
if marshaler != nil {
|
||||
bytes, err := marshaler.MarshalBencode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = w.Write(bytes)
|
||||
return err
|
||||
}
|
||||
|
||||
// marshal a type using the TextMarshaler type
|
||||
// if it implements that interface.
|
||||
if textMarshaler != nil {
|
||||
bytes, err := textMarshaler.MarshalText()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = fmt.Fprintf(w, "%d:%s", len(bytes), bytes)
|
||||
return err
|
||||
}
|
||||
|
||||
// if indirection returns us an invalid value that means there was a nil
|
||||
// pointer in the path somewhere.
|
||||
if !v.IsValid() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// send in a raw message if we have that type
|
||||
if rm, ok := v.Interface().(RawMessage); ok {
|
||||
_, err := io.Copy(w, bytes.NewReader(rm))
|
||||
return err
|
||||
}
|
||||
|
||||
switch v.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
_, err := fmt.Fprintf(w, "i%de", v.Int())
|
||||
return err
|
||||
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
_, err := fmt.Fprintf(w, "i%de", v.Uint())
|
||||
return err
|
||||
|
||||
case reflect.Bool:
|
||||
i := 0
|
||||
if v.Bool() {
|
||||
i = 1
|
||||
}
|
||||
_, err := fmt.Fprintf(w, "i%de", i)
|
||||
return err
|
||||
|
||||
case reflect.String:
|
||||
_, err := fmt.Fprintf(w, "%d:%s", len(v.String()), v.String())
|
||||
return err
|
||||
|
||||
case reflect.Slice, reflect.Array:
|
||||
// handle byte slices like strings
|
||||
if byteSlice, ok := val.Interface().([]byte); ok {
|
||||
_, err := fmt.Fprintf(w, "%d:", len(byteSlice))
|
||||
|
||||
if err == nil {
|
||||
_, err = w.Write(byteSlice)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprint(w, "l"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
if err := encodeValue(w, v.Index(i)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
_, err := fmt.Fprint(w, "e")
|
||||
return err
|
||||
|
||||
case reflect.Map:
|
||||
if _, err := fmt.Fprint(w, "d"); err != nil {
|
||||
return err
|
||||
}
|
||||
var (
|
||||
keys sortValues = v.MapKeys()
|
||||
mval reflect.Value
|
||||
)
|
||||
sort.Sort(keys)
|
||||
for i := range keys {
|
||||
mval = v.MapIndex(keys[i])
|
||||
if isNilValue(mval) {
|
||||
continue
|
||||
}
|
||||
if err := encodeValue(w, keys[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := encodeValue(w, mval); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err := fmt.Fprint(w, "e")
|
||||
return err
|
||||
|
||||
case reflect.Struct:
|
||||
if _, err := fmt.Fprint(w, "d"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// add embedded structs to the dictionary
|
||||
dict := make(dictionary, 0, v.NumField())
|
||||
dict, err := readStruct(dict, v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// sort the dictionary by keys
|
||||
sort.Sort(dict)
|
||||
|
||||
// encode the dictionary in order
|
||||
for _, def := range dict {
|
||||
// encode the key
|
||||
err := encodeValue(w, reflect.ValueOf(def.key))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// encode the value
|
||||
err = encodeValue(w, def.value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
_, err = fmt.Fprint(w, "e")
|
||||
return err
|
||||
}
|
||||
|
||||
return fmt.Errorf("Can't encode type: %s", v.Type())
|
||||
}
|
||||
|
||||
// indirectEncodeValue walks down v allocating pointers as needed,
|
||||
// until it gets to a non-pointer.
|
||||
// if it encounters an (Text)Marshaler, indirect stops and returns that.
|
||||
func indirectEncodeValue(v reflect.Value) (Marshaler, encoding.TextMarshaler, reflect.Value) {
|
||||
// If v is a named type and is addressable,
|
||||
// start with its address, so that if the type has pointer methods,
|
||||
// we find them.
|
||||
if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() {
|
||||
v = v.Addr()
|
||||
}
|
||||
for {
|
||||
if v.Kind() == reflect.Ptr && v.IsNil() {
|
||||
break
|
||||
}
|
||||
|
||||
vi := v.Interface()
|
||||
if m, ok := vi.(Marshaler); ok {
|
||||
return m, nil, reflect.Value{}
|
||||
}
|
||||
if m, ok := vi.(encoding.TextMarshaler); ok {
|
||||
return nil, m, reflect.Value{}
|
||||
}
|
||||
|
||||
if v.Kind() != reflect.Ptr {
|
||||
break
|
||||
}
|
||||
|
||||
v = v.Elem()
|
||||
}
|
||||
return nil, nil, indirect(v, false)
|
||||
}
|
||||
|
||||
type definition struct {
|
||||
key string
|
||||
value reflect.Value
|
||||
}
|
||||
|
||||
type dictionary []definition
|
||||
|
||||
func (d dictionary) Len() int { return len(d) }
|
||||
func (d dictionary) Less(i, j int) bool { return d[i].key < d[j].key }
|
||||
func (d dictionary) Swap(i, j int) { d[i], d[j] = d[j], d[i] }
|
||||
|
||||
func readStruct(dict dictionary, v reflect.Value) (dictionary, error) {
|
||||
t := v.Type()
|
||||
var (
|
||||
fieldValue reflect.Value
|
||||
rkey string
|
||||
)
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
key := t.Field(i)
|
||||
rkey = key.Name
|
||||
fieldValue = v.FieldByIndex(key.Index)
|
||||
|
||||
// filter out unexported values etc.
|
||||
if !fieldValue.CanInterface() {
|
||||
continue
|
||||
}
|
||||
|
||||
// filter out nil pointer values
|
||||
if isNilValue(fieldValue) {
|
||||
continue
|
||||
}
|
||||
|
||||
// * Near identical to usage in JSON except with key 'bencode'
|
||||
//
|
||||
// * Struct values encode as BEncode dictionaries. Each exported
|
||||
// struct field becomes a set in the dictionary unless
|
||||
// - the field's tag is "-", or
|
||||
// - the field is empty and its tag specifies the "omitempty"
|
||||
// option.
|
||||
//
|
||||
// * The default key string is the struct field name but can be
|
||||
// specified in the struct field's tag value. The "bencode"
|
||||
// key in struct field's tag value is the key name, followed
|
||||
// by an optional comma and options.
|
||||
tagValue := key.Tag.Get("bencode")
|
||||
if tagValue != "" {
|
||||
// Keys with '-' are omit from output
|
||||
if tagValue == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
name, options := parseTag(tagValue)
|
||||
// Keys with 'omitempty' are omitted if the field is empty
|
||||
if options.Contains("omitempty") && isEmptyValue(fieldValue) {
|
||||
continue
|
||||
}
|
||||
|
||||
// All other values are treated as the key string
|
||||
if isValidTag(name) {
|
||||
rkey = name
|
||||
}
|
||||
}
|
||||
|
||||
if key.Anonymous && key.Type.Kind() == reflect.Struct && tagValue == "" {
|
||||
var err error
|
||||
dict, err = readStruct(dict, fieldValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
dict = append(dict, definition{rkey, fieldValue})
|
||||
}
|
||||
}
|
||||
return dict, nil
|
||||
}
|
109
bencode/encode_decode_test.go
Normal file
109
bencode/encode_decode_test.go
Normal file
@ -0,0 +1,109 @@
|
||||
package bencode
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
type myBoolType bool
|
||||
|
||||
// MarshalBencode implements Marshaler.MarshalBencode
|
||||
func (mbt myBoolType) MarshalBencode() ([]byte, error) {
|
||||
var c string
|
||||
if mbt {
|
||||
c = "y"
|
||||
} else {
|
||||
c = "n"
|
||||
}
|
||||
|
||||
return EncodeBytes(c)
|
||||
}
|
||||
|
||||
// UnmarshalBencode implements Unmarshaler.UnmarshalBencode
|
||||
func (mbt *myBoolType) UnmarshalBencode(b []byte) error {
|
||||
var str string
|
||||
err := DecodeBytes(b, &str)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch str {
|
||||
case "y":
|
||||
*mbt = true
|
||||
case "n":
|
||||
*mbt = false
|
||||
default:
|
||||
err = errors.New("invalid myBoolType")
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
type myBoolTextType bool
|
||||
|
||||
// MarshalText implements TextMarshaler.MarshalText
|
||||
func (mbt myBoolTextType) MarshalText() ([]byte, error) {
|
||||
if mbt {
|
||||
return []byte("y"), nil
|
||||
}
|
||||
|
||||
return []byte("n"), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements TextUnmarshaler.UnmarshalText
|
||||
func (mbt *myBoolTextType) UnmarshalText(b []byte) error {
|
||||
switch string(b) {
|
||||
case "y":
|
||||
*mbt = true
|
||||
case "n":
|
||||
*mbt = false
|
||||
default:
|
||||
return errors.New("invalid myBoolType")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type myTimeType struct {
|
||||
time.Time
|
||||
}
|
||||
|
||||
// MarshalBencode implements Marshaler.MarshalBencode
|
||||
func (mtt myTimeType) MarshalBencode() ([]byte, error) {
|
||||
return EncodeBytes(mtt.Time.Unix())
|
||||
}
|
||||
|
||||
// UnmarshalBencode implements Unmarshaler.UnmarshalBencode
|
||||
func (mtt *myTimeType) UnmarshalBencode(b []byte) error {
|
||||
var epoch int64
|
||||
err := DecodeBytes(b, &epoch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mtt.Time = time.Unix(epoch, 0)
|
||||
return nil
|
||||
}
|
||||
|
||||
type errorMarshalType struct{}
|
||||
|
||||
// MarshalBencode implements Marshaler.MarshalBencode
|
||||
func (emt errorMarshalType) MarshalBencode() ([]byte, error) {
|
||||
return nil, errors.New("oops")
|
||||
}
|
||||
|
||||
// UnmarshalBencode implements Unmarshaler.UnmarshalBencode
|
||||
func (emt errorMarshalType) UnmarshalBencode([]byte) error {
|
||||
return errors.New("oops")
|
||||
}
|
||||
|
||||
type errorTextMarshalType struct{}
|
||||
|
||||
// MarshalText implements TextMarshaler.MarshalText
|
||||
func (emt errorTextMarshalType) MarshalText() ([]byte, error) {
|
||||
return nil, errors.New("oops")
|
||||
}
|
||||
|
||||
// UnmarshalText implements TextUnmarshaler.UnmarshalText
|
||||
func (emt errorTextMarshalType) UnmarshalText([]byte) error {
|
||||
return errors.New("oops")
|
||||
}
|
354
bencode/encode_test.go
Normal file
354
bencode/encode_test.go
Normal file
@ -0,0 +1,354 @@
|
||||
package bencode
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestEncode(t *testing.T) {
|
||||
type encodeTestCase struct {
|
||||
in interface{}
|
||||
out string
|
||||
err bool
|
||||
}
|
||||
|
||||
type eT struct {
|
||||
A string
|
||||
X string `bencode:"D"`
|
||||
Y string `bencode:"B"`
|
||||
Z string `bencode:"C"`
|
||||
}
|
||||
|
||||
type sortProblem struct {
|
||||
A string
|
||||
B string `bencode:","`
|
||||
}
|
||||
|
||||
type issue18Sub struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
type issue18 struct {
|
||||
T *issue18Sub
|
||||
}
|
||||
|
||||
type Embedded struct {
|
||||
B string
|
||||
}
|
||||
|
||||
type issue22 struct {
|
||||
Time myTimeType `bencode:"t"`
|
||||
Foo myBoolType `bencode:"f"`
|
||||
}
|
||||
|
||||
type issue22WithErrorChild struct {
|
||||
Name string `bencode:"n"`
|
||||
Error errorMarshalType `bencode:"e"`
|
||||
}
|
||||
|
||||
type issue26 struct {
|
||||
Answer int64 `bencode:"a"`
|
||||
Foo myBoolTextType `bencode:"f"`
|
||||
Name string `bencode:"n"`
|
||||
}
|
||||
|
||||
type issue26WithErrorChild struct {
|
||||
Name string `bencode:"n"`
|
||||
Error errorTextMarshalType `bencode:"e"`
|
||||
}
|
||||
|
||||
type issue28 struct {
|
||||
X string `bencode:"x"`
|
||||
Time *myTimeType `bencode:"t"`
|
||||
Foo myBoolPtrType `bencode:"f"`
|
||||
Y string `bencode:"y"`
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
var encodeCases = []encodeTestCase{
|
||||
// integers
|
||||
{10, `i10e`, false},
|
||||
{-10, `i-10e`, false},
|
||||
{0, `i0e`, false},
|
||||
{int(10), `i10e`, false},
|
||||
{int8(10), `i10e`, false},
|
||||
{int16(10), `i10e`, false},
|
||||
{int32(10), `i10e`, false},
|
||||
{int64(10), `i10e`, false},
|
||||
{uint(10), `i10e`, false},
|
||||
{uint8(10), `i10e`, false},
|
||||
{uint16(10), `i10e`, false},
|
||||
{uint32(10), `i10e`, false},
|
||||
{uint64(10), `i10e`, false},
|
||||
{(*int)(nil), ``, false},
|
||||
|
||||
// ptr-to-integer
|
||||
{func() *int {
|
||||
i := 42
|
||||
return &i
|
||||
}(), `i42e`, false},
|
||||
|
||||
// strings
|
||||
{"foo", `3:foo`, false},
|
||||
{"barbb", `5:barbb`, false},
|
||||
{"", `0:`, false},
|
||||
{(*string)(nil), ``, false},
|
||||
|
||||
// ptr-to-string
|
||||
{func() *string {
|
||||
str := "foo"
|
||||
return &str
|
||||
}(), `3:foo`, false},
|
||||
|
||||
// lists
|
||||
{[]interface{}{"foo", 20}, `l3:fooi20ee`, false},
|
||||
{[]interface{}{90, 20}, `li90ei20ee`, false},
|
||||
{[]interface{}{[]interface{}{"foo", "bar"}, 20}, `ll3:foo3:barei20ee`, false},
|
||||
{[]map[string]int{
|
||||
{"a": 0, "b": 1},
|
||||
{"c": 2, "d": 3},
|
||||
}, `ld1:ai0e1:bi1eed1:ci2e1:di3eee`, false},
|
||||
{[][]byte{
|
||||
[]byte{'0', '2', '4', '6', '8'},
|
||||
[]byte{'a', 'c', 'e'},
|
||||
}, `l5:024683:acee`, false},
|
||||
{(*[]interface{})(nil), ``, false},
|
||||
|
||||
// boolean
|
||||
{true, "i1e", false},
|
||||
{false, "i0e", false},
|
||||
{(*bool)(nil), ``, false},
|
||||
|
||||
// dicts
|
||||
{map[string]interface{}{
|
||||
"a": "foo",
|
||||
"c": "bar",
|
||||
"b": "tes",
|
||||
}, `d1:a3:foo1:b3:tes1:c3:bare`, false},
|
||||
{eT{"foo", "bar", "far", "boo"}, `d1:A3:foo1:B3:far1:C3:boo1:D3:bare`, false},
|
||||
{map[string][]int{
|
||||
"a": {0, 1},
|
||||
"b": {2, 3},
|
||||
}, `d1:ali0ei1ee1:bli2ei3eee`, false},
|
||||
{struct{ A, b int }{1, 2}, "d1:Ai1ee", false},
|
||||
{(*struct{ A int })(nil), ``, false},
|
||||
|
||||
// raw
|
||||
{RawMessage(`i5e`), `i5e`, false},
|
||||
{[]RawMessage{
|
||||
RawMessage(`i5e`),
|
||||
RawMessage(`5:hello`),
|
||||
RawMessage(`ldededee`),
|
||||
}, `li5e5:helloldededeee`, false},
|
||||
{map[string]RawMessage{
|
||||
"a": RawMessage(`i5e`),
|
||||
"b": RawMessage(`5:hello`),
|
||||
"c": RawMessage(`ldededee`),
|
||||
}, `d1:ai5e1:b5:hello1:cldededeee`, false},
|
||||
|
||||
// problem sorting
|
||||
{sortProblem{A: "foo", B: "bar"}, `d1:A3:foo1:B3:bare`, false},
|
||||
|
||||
// nil values dropped from maps and structs
|
||||
{map[string]*int{"a": nil}, `de`, false},
|
||||
{struct{ A *int }{nil}, `de`, false},
|
||||
{issue18{}, `de`, false},
|
||||
{map[string]interface{}{"a": nil}, `de`, false},
|
||||
{struct{ A interface{} }{nil}, `de`, false},
|
||||
|
||||
// embedded structs
|
||||
{struct {
|
||||
A string
|
||||
Embedded
|
||||
}{"foo", Embedded{"bar"}}, `d1:A3:foo1:B3:bare`, false},
|
||||
{struct {
|
||||
A string
|
||||
Embedded `bencode:"C"`
|
||||
}{"foo", Embedded{"bar"}}, `d1:A3:foo1:Cd1:B3:baree`, false},
|
||||
|
||||
// embedded structs order issue #20
|
||||
{struct {
|
||||
Embedded
|
||||
A string
|
||||
}{Embedded{"bar"}, "foo"}, `d1:A3:foo1:B3:bare`, false},
|
||||
|
||||
// types which implement the Marshal interface will
|
||||
// be marshalled using this interface
|
||||
{myBoolType(true), `1:y`, false},
|
||||
{myBoolType(false), `1:n`, false},
|
||||
{myTimeType{now}, fmt.Sprintf("i%de", now.Unix()), false},
|
||||
{errorMarshalType{}, "", true},
|
||||
|
||||
// pointers to types which implement the Marshal interface will
|
||||
// be marshalled using this interface
|
||||
{func() *myBoolType {
|
||||
b := myBoolType(true)
|
||||
return &b
|
||||
}(), `1:y`, false},
|
||||
{func() *myTimeType {
|
||||
t := myTimeType{now}
|
||||
return &t
|
||||
}(), fmt.Sprintf("i%de", now.Unix()), false},
|
||||
{func() *errorMarshalType {
|
||||
e := errorMarshalType{}
|
||||
return &e
|
||||
}(), "", true},
|
||||
|
||||
// nil-pointers to types which implement the Marshal interface will be ignored
|
||||
{(*myBoolType)(nil), "", false},
|
||||
{(*myTimeType)(nil), "", false},
|
||||
{(*errorMarshalType)(nil), "", false},
|
||||
|
||||
// ptr-types which implements the Marshal interface will
|
||||
// be marshalled using this interface
|
||||
{func() *myBoolPtrType {
|
||||
b := myBoolPtrType(true)
|
||||
return &b
|
||||
}(), `1:y`, false},
|
||||
{func() *myBoolPtrType {
|
||||
b := myBoolPtrType(false)
|
||||
return &b
|
||||
}(), `1:n`, false},
|
||||
{(*myBoolPtrType)(nil), ``, false},
|
||||
|
||||
// structures can also have children which support
|
||||
// the Marshal interface
|
||||
{
|
||||
issue22{Time: myTimeType{now}, Foo: myBoolType(true)},
|
||||
fmt.Sprintf("d1:f1:y1:ti%dee", now.Unix()),
|
||||
false,
|
||||
},
|
||||
{ // an error will be returned if a child can't be marshalled
|
||||
issue22WithErrorChild{Name: "Foo", Error: errorMarshalType{}},
|
||||
"", true,
|
||||
},
|
||||
// structures passed by reference which have children that support
|
||||
// the (Text)Marshal interface (by value or by reference),
|
||||
// will be marshaled using that interface
|
||||
{
|
||||
&issue22{Time: myTimeType{now}, Foo: myBoolType(true)},
|
||||
fmt.Sprintf("d1:f1:y1:ti%dee", now.Unix()),
|
||||
false,
|
||||
},
|
||||
{ // an error will be returned if a child can't be marshalled
|
||||
&issue22WithErrorChild{Name: "Foo", Error: errorMarshalType{}},
|
||||
"", true,
|
||||
},
|
||||
|
||||
// types which implement the TextMarshal interface will
|
||||
// be marshalled into a bencode string value using this interface
|
||||
{myBoolTextType(true), `1:y`, false},
|
||||
{myBoolTextType(false), `1:n`, false},
|
||||
{errorTextMarshalType{}, "", true},
|
||||
|
||||
// structures can also have children which support
|
||||
// the TextMarshal interface
|
||||
{
|
||||
issue26{Answer: 42, Foo: myBoolTextType(true), Name: "Nova"},
|
||||
`d1:ai42e1:f1:y1:n4:Novae`,
|
||||
false,
|
||||
},
|
||||
{ // an error will be returned if a child TextMarshaler returns an error
|
||||
issue26WithErrorChild{Name: "Foo", Error: errorTextMarshalType{}},
|
||||
"", true,
|
||||
},
|
||||
|
||||
// ptr types which are used as value types,
|
||||
// but which ptr version implement the Marshaler/TextMarshaler interface,
|
||||
// will still get marshalling using this interface, when possible
|
||||
{
|
||||
&issue28{X: "x", Time: &myTimeType{now}, Foo: myBoolPtrType(true), Y: "y"},
|
||||
fmt.Sprintf(`d1:f1:y1:ti%de1:x1:x1:y1:ye`, now.Unix()),
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range encodeCases {
|
||||
t.Logf("%d: %#v", i, tt.in)
|
||||
data, err := EncodeString(tt.in)
|
||||
if !tt.err && err != nil {
|
||||
t.Errorf("#%d: Unexpected err: %v", i, err)
|
||||
continue
|
||||
}
|
||||
if tt.err && err == nil {
|
||||
t.Errorf("#%d: Expected err is nil", i)
|
||||
continue
|
||||
}
|
||||
if tt.out != data {
|
||||
t.Errorf("#%d: Val: %q != %q", i, data, tt.out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type myBoolPtrType bool
|
||||
|
||||
// MarshalBencode implements Marshaler.MarshalBencode
|
||||
func (mbt *myBoolPtrType) MarshalBencode() ([]byte, error) {
|
||||
var c string
|
||||
if *mbt {
|
||||
c = "y"
|
||||
} else {
|
||||
c = "n"
|
||||
}
|
||||
|
||||
return EncodeBytes(c)
|
||||
}
|
||||
|
||||
// UnmarshalBencode implements Unmarshaler.UnmarshalBencode
|
||||
func (mbt *myBoolPtrType) UnmarshalBencode(b []byte) error {
|
||||
var str string
|
||||
err := DecodeBytes(b, &str)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch str {
|
||||
case "y":
|
||||
*mbt = true
|
||||
case "n":
|
||||
*mbt = false
|
||||
default:
|
||||
err = errors.New("invalid myBoolType")
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func TestEncodeOmit(t *testing.T) {
|
||||
type encodeTestCase struct {
|
||||
in interface{}
|
||||
out string
|
||||
err bool
|
||||
}
|
||||
|
||||
type eT struct {
|
||||
A string `bencode:",omitempty"`
|
||||
B int `bencode:",omitempty"`
|
||||
C *int `bencode:",omitempty"`
|
||||
}
|
||||
|
||||
var encodeCases = []encodeTestCase{
|
||||
{eT{}, `de`, false},
|
||||
{eT{A: "a"}, `d1:A1:ae`, false},
|
||||
{eT{B: 5}, `d1:Bi5ee`, false},
|
||||
{eT{C: new(int)}, `d1:Ci0ee`, false},
|
||||
}
|
||||
|
||||
for i, tt := range encodeCases {
|
||||
data, err := EncodeString(tt.in)
|
||||
if !tt.err && err != nil {
|
||||
t.Errorf("#%d: Unexpected err: %v", i, err)
|
||||
continue
|
||||
}
|
||||
if tt.err && err == nil {
|
||||
t.Errorf("#%d: Expected err is nil", i)
|
||||
continue
|
||||
}
|
||||
if tt.out != data {
|
||||
t.Errorf("#%d: Val: %q != %q", i, data, tt.out)
|
||||
}
|
||||
}
|
||||
}
|
67
bencode/example_test.go
Normal file
67
bencode/example_test.go
Normal file
@ -0,0 +1,67 @@
|
||||
package bencode
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
var (
|
||||
data string
|
||||
r io.Reader
|
||||
w io.Writer
|
||||
)
|
||||
|
||||
func ExampleDecodeString() {
|
||||
var torrent interface{}
|
||||
if err := DecodeString(data, &torrent); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleEncodeString() {
|
||||
var torrent interface{}
|
||||
data, err := EncodeString(torrent)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Println(data)
|
||||
}
|
||||
|
||||
func ExampleDecodeBytes() {
|
||||
var torrent interface{}
|
||||
if err := DecodeBytes([]byte(data), &torrent); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleEncodeBytes() {
|
||||
var torrent interface{}
|
||||
data, err := EncodeBytes(torrent)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Println(data)
|
||||
}
|
||||
|
||||
func ExampleEncoder_Encode() {
|
||||
var x struct {
|
||||
Foo string
|
||||
Bar []string `bencode:"name"`
|
||||
}
|
||||
|
||||
enc := NewEncoder(w)
|
||||
if err := enc.Encode(x); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleDecoder_Decode() {
|
||||
dec := NewDecoder(r)
|
||||
var torrent struct {
|
||||
Announce string
|
||||
List [][]string `bencode:"announce-list"`
|
||||
}
|
||||
if err := dec.Decode(&torrent); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
5
bencode/raw.go
Normal file
5
bencode/raw.go
Normal file
@ -0,0 +1,5 @@
|
||||
package bencode
|
||||
|
||||
// RawMessage is a special type that will store the raw bencode data when
|
||||
// encoding or decoding.
|
||||
type RawMessage []byte
|
76
bencode/tag.go
Normal file
76
bencode/tag.go
Normal file
@ -0,0 +1,76 @@
|
||||
package bencode
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// tagOptions is the string following a comma in a struct field's "bencode"
|
||||
// tag, or the empty string. It does not include the leading comma.
|
||||
type tagOptions string
|
||||
|
||||
// parseTag splits a struct field's tag into its name and
|
||||
// comma-separated options.
|
||||
func parseTag(tag string) (string, tagOptions) {
|
||||
if idx := strings.Index(tag, ","); idx != -1 {
|
||||
return tag[:idx], tagOptions(tag[idx+1:])
|
||||
}
|
||||
return tag, tagOptions("")
|
||||
}
|
||||
|
||||
// Contains returns whether checks that a comma-separated list of options
|
||||
// contains a particular substr flag. substr must be surrounded by a
|
||||
// string boundary or commas.
|
||||
func (options tagOptions) Contains(optionName string) bool {
|
||||
s := string(options)
|
||||
for s != "" {
|
||||
var next string
|
||||
i := strings.Index(s, ",")
|
||||
if i >= 0 {
|
||||
s, next = s[:i], s[i+1:]
|
||||
}
|
||||
if s == optionName {
|
||||
return true
|
||||
}
|
||||
s = next
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isValidTag(key string) bool {
|
||||
if key == "" {
|
||||
return false
|
||||
}
|
||||
for _, c := range key {
|
||||
if c != ' ' && c != '$' && c != '-' && c != '_' && c != '.' && !unicode.IsLetter(c) && !unicode.IsDigit(c) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func matchName(key string) func(string) bool {
|
||||
return func(s string) bool {
|
||||
return strings.ToLower(key) == strings.ToLower(s)
|
||||
}
|
||||
}
|
||||
|
||||
func isEmptyValue(v reflect.Value) bool {
|
||||
switch v.Kind() {
|
||||
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
|
||||
return v.Len() == 0
|
||||
case reflect.Bool:
|
||||
return !v.Bool()
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return v.Int() == 0
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
return v.Uint() == 0
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return v.Float() == 0
|
||||
case reflect.Interface, reflect.Ptr:
|
||||
return v.IsNil()
|
||||
default:
|
||||
return v.IsZero()
|
||||
}
|
||||
}
|
181
dht/blacklist.go
Normal file
181
dht/blacklist.go
Normal file
@ -0,0 +1,181 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package dht
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Blacklist is used to manage the ip blacklist.
|
||||
//
|
||||
// Notice: The implementation should clear the address existed for long time.
|
||||
type Blacklist interface {
|
||||
// In reports whether the address, ip and port, is in the blacklist.
|
||||
In(ip string, port int) bool
|
||||
|
||||
// If port is equal to 0, it should ignore port and only use ip when matching.
|
||||
Add(ip string, port int)
|
||||
|
||||
// If port is equal to 0, it should delete the address by only the ip.
|
||||
Del(ip string, port int)
|
||||
|
||||
// Close is used to notice the implementation to release the underlying
|
||||
// resource.
|
||||
Close()
|
||||
}
|
||||
|
||||
type noopBlacklist struct{}
|
||||
|
||||
func (nbl noopBlacklist) In(ip string, port int) bool { return false }
|
||||
func (nbl noopBlacklist) Add(ip string, port int) {}
|
||||
func (nbl noopBlacklist) Del(ip string, port int) {}
|
||||
func (nbl noopBlacklist) Close() {}
|
||||
|
||||
// NewNoopBlacklist returns a no-op Blacklist.
|
||||
func NewNoopBlacklist() Blacklist { return noopBlacklist{} }
|
||||
|
||||
// DebugBlacklist returns a new Blacklist to log the information as debug.
|
||||
func DebugBlacklist(bl Blacklist, logf func(string, ...interface{})) Blacklist {
|
||||
return logBlacklist{Blacklist: bl, logf: logf}
|
||||
}
|
||||
|
||||
type logBlacklist struct {
|
||||
Blacklist
|
||||
logf func(string, ...interface{})
|
||||
}
|
||||
|
||||
func (dbl logBlacklist) Add(ip string, port int) {
|
||||
dbl.logf("add the blacklist: ip=%s, port=%d", ip, port)
|
||||
dbl.Blacklist.Add(ip, port)
|
||||
}
|
||||
|
||||
func (dbl logBlacklist) Del(ip string, port int) {
|
||||
dbl.logf("delete the blacklist: ip=%s, port=%d", ip, port)
|
||||
dbl.Blacklist.Del(ip, port)
|
||||
}
|
||||
|
||||
/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
|
||||
// NewMemoryBlacklist returns a blacklst implementation based on memory.
|
||||
//
|
||||
// if maxnum is equal to 0, no limit.
|
||||
func NewMemoryBlacklist(maxnum int, duration time.Duration) Blacklist {
|
||||
bl := &blacklist{
|
||||
num: maxnum,
|
||||
ips: make(map[string]*wrappedPort, 128),
|
||||
exit: make(chan struct{}),
|
||||
}
|
||||
go bl.loop(duration)
|
||||
return bl
|
||||
}
|
||||
|
||||
type wrappedPort struct {
|
||||
Time time.Time
|
||||
Enable bool
|
||||
Ports map[int]struct{}
|
||||
}
|
||||
|
||||
type blacklist struct {
|
||||
exit chan struct{}
|
||||
lock sync.RWMutex
|
||||
ips map[string]*wrappedPort
|
||||
num int
|
||||
}
|
||||
|
||||
func (bl *blacklist) loop(interval time.Duration) {
|
||||
tick := time.NewTicker(interval)
|
||||
defer tick.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-bl.exit:
|
||||
return
|
||||
case now := <-tick.C:
|
||||
bl.lock.Lock()
|
||||
for ip, wp := range bl.ips {
|
||||
if now.Sub(wp.Time) > interval {
|
||||
delete(bl.ips, ip)
|
||||
}
|
||||
}
|
||||
bl.lock.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (bl *blacklist) Close() {
|
||||
select {
|
||||
case <-bl.exit:
|
||||
default:
|
||||
close(bl.exit)
|
||||
}
|
||||
}
|
||||
|
||||
// In reports whether the address, ip and port, is in the blacklist.
|
||||
func (bl *blacklist) In(ip string, port int) (yes bool) {
|
||||
bl.lock.RLock()
|
||||
if wp, ok := bl.ips[ip]; ok {
|
||||
if wp.Enable {
|
||||
_, yes = wp.Ports[port]
|
||||
} else {
|
||||
yes = true
|
||||
}
|
||||
}
|
||||
bl.lock.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
func (bl *blacklist) Add(ip string, port int) {
|
||||
bl.lock.Lock()
|
||||
wp, ok := bl.ips[ip]
|
||||
if !ok {
|
||||
if bl.num > 0 && len(bl.ips) >= bl.num {
|
||||
bl.lock.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
wp = &wrappedPort{Enable: true}
|
||||
bl.ips[ip] = wp
|
||||
}
|
||||
|
||||
if port < 1 {
|
||||
wp.Enable = false
|
||||
wp.Ports = nil
|
||||
} else if wp.Ports == nil {
|
||||
wp.Ports = map[int]struct{}{port: struct{}{}}
|
||||
} else {
|
||||
wp.Ports[port] = struct{}{}
|
||||
}
|
||||
|
||||
wp.Time = time.Now()
|
||||
bl.lock.Unlock()
|
||||
}
|
||||
|
||||
func (bl *blacklist) Del(ip string, port int) {
|
||||
bl.lock.Lock()
|
||||
if wp, ok := bl.ips[ip]; ok {
|
||||
if port < 1 {
|
||||
delete(bl.ips, ip)
|
||||
} else if wp.Enable {
|
||||
switch len(wp.Ports) {
|
||||
case 0, 1:
|
||||
delete(bl.ips, ip)
|
||||
default:
|
||||
delete(wp.Ports, port)
|
||||
wp.Time = time.Now()
|
||||
}
|
||||
}
|
||||
}
|
||||
bl.lock.Unlock()
|
||||
}
|
80
dht/blacklist_test.go
Normal file
80
dht/blacklist_test.go
Normal file
@ -0,0 +1,80 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package dht
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (bl *blacklist) portsLen() (n int) {
|
||||
bl.lock.RLock()
|
||||
for _, wp := range bl.ips {
|
||||
n += len(wp.Ports)
|
||||
}
|
||||
bl.lock.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
func (bl *blacklist) getIPs() []string {
|
||||
bl.lock.RLock()
|
||||
ips := make([]string, 0, len(bl.ips))
|
||||
for ip := range bl.ips {
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
bl.lock.RUnlock()
|
||||
return ips
|
||||
}
|
||||
|
||||
func TestMemoryBlacklist(t *testing.T) {
|
||||
bl := NewMemoryBlacklist(3, time.Second).(*blacklist)
|
||||
defer bl.Close()
|
||||
|
||||
bl.Add("1.1.1.1", 123)
|
||||
bl.Add("1.1.1.1", 456)
|
||||
bl.Add("1.1.1.1", 789)
|
||||
bl.Add("2.2.2.2", 111)
|
||||
bl.Add("3.3.3.3", 0)
|
||||
bl.Add("4.4.4.4", 222)
|
||||
|
||||
ips := bl.getIPs()
|
||||
if len(ips) != 3 {
|
||||
t.Error(ips)
|
||||
} else {
|
||||
for _, ip := range ips {
|
||||
switch ip {
|
||||
case "1.1.1.1", "2.2.2.2", "3.3.3.3":
|
||||
default:
|
||||
t.Error(ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if n := bl.portsLen(); n != 4 {
|
||||
t.Errorf("expect port num 4, but got %d", n)
|
||||
}
|
||||
|
||||
if bl.In("1.1.1.1", 111) || !bl.In("1.1.1.1", 123) {
|
||||
t.Fail()
|
||||
}
|
||||
if !bl.In("3.3.3.3", 111) || bl.In("4.4.4.4", 222) {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
bl.Del("3.3.3.3", 0)
|
||||
if bl.In("3.3.3.3", 111) {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
864
dht/dht_server.go
Normal file
864
dht/dht_server.go
Normal file
@ -0,0 +1,864 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package dht implements the DHT Protocol. And you can use it to build or join
|
||||
// the DHT swarm network.
|
||||
package dht
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/xgfone/bt/bencode"
|
||||
"github.com/xgfone/bt/krpc"
|
||||
"github.com/xgfone/bt/metainfo"
|
||||
)
|
||||
|
||||
const (
|
||||
queryMethodPing = "ping"
|
||||
queryMethodFindNode = "find_node"
|
||||
queryMethodGetPeers = "get_peers"
|
||||
queryMethodAnnouncePeer = "announce_peer"
|
||||
)
|
||||
|
||||
var errUnsupportedIPProtocol = fmt.Errorf("unsupported ip protocol")
|
||||
|
||||
// Predefine some ip protocol stacks.
|
||||
const (
|
||||
IPv4Protocol IPProtocolStack = 4
|
||||
IPv6Protocol IPProtocolStack = 6
|
||||
)
|
||||
|
||||
// IPProtocolStack represents the ip protocol stack, such as IPv4 or IPv6
|
||||
type IPProtocolStack uint8
|
||||
|
||||
// Result is used to pass the response result to the callback function.
|
||||
type Result struct {
|
||||
// Addr is the address of the peer where the request is sent to.
|
||||
//
|
||||
// Notice: it may be nil for "get_peers" request.
|
||||
Addr *net.UDPAddr
|
||||
|
||||
// For Error
|
||||
Code int // 0 represents the success.
|
||||
Reason string // Reason indicates the reason why the request failed when Code > 0.
|
||||
Timeout bool // Timeout indicates whether the response is timeout.
|
||||
|
||||
// The list of the address of the peers returned by GetPeers.
|
||||
Peers []metainfo.Address
|
||||
}
|
||||
|
||||
// Config is used to configure the DHT server.
|
||||
type Config struct {
|
||||
// K is the size of the bucket of the routing table.
|
||||
//
|
||||
// The default is 8.
|
||||
K int
|
||||
|
||||
// ID is the id of the current DHT server node.
|
||||
//
|
||||
// The default is generated randomly
|
||||
ID metainfo.Hash
|
||||
|
||||
// IPProtocols is used to specify the supported IP Protocol Stack.
|
||||
//
|
||||
// If not given, it will detect the address family of the server listens
|
||||
// automatically and set the supported ip protocol stack by the address
|
||||
// family. If fails, the default is []IPProtocolStack{IPv4Protocol}.
|
||||
// So, if the server is listening on the empty address, such as ":6881",
|
||||
// and only supports IPv6, you should specify it to "IPv6Protocol".
|
||||
IPProtocols []IPProtocolStack
|
||||
|
||||
// ReadOnly indicates whether the current node is read-only.
|
||||
//
|
||||
// If true, the DHT server will enter in the read-only mode.
|
||||
//
|
||||
// The default is false.
|
||||
//
|
||||
// BEP 43
|
||||
ReadOnly bool
|
||||
|
||||
// MsgSize is the maximum size of the DHT message.
|
||||
//
|
||||
// The default is 4096.
|
||||
MsgSize int
|
||||
|
||||
// SearchDepth is used to control the depth to send the "get_peers" or
|
||||
// "find_node" query recursively to get the peers storing the torrent
|
||||
// infohash.
|
||||
//
|
||||
// The default depth is 8.
|
||||
SearchDepth int
|
||||
|
||||
// RespTimeout is the response timeout, that's, the response is valid
|
||||
// only before the timeout reaches.
|
||||
//
|
||||
// The default is "10s".
|
||||
RespTimeout time.Duration
|
||||
|
||||
// RoutingTableStorage is used to store the nodes in the routing table.
|
||||
//
|
||||
// The default is nil.
|
||||
RoutingTableStorage RoutingTableStorage
|
||||
|
||||
// PeerManager is used to manage the peers on torrent infohash,
|
||||
// which is called when receiving the "get_peers" query.
|
||||
//
|
||||
// The default uses the inner token-peer manager.
|
||||
PeerManager PeerManager
|
||||
|
||||
// Blacklist is used to manage the ip blacklist.
|
||||
//
|
||||
// The default is NewMemoryBlacklist(1024, time.Hour*24*7).
|
||||
Blacklist Blacklist
|
||||
|
||||
// ErrorLog is used to log the error.
|
||||
//
|
||||
// The default is log.Printf.
|
||||
ErrorLog func(format string, args ...interface{})
|
||||
|
||||
// OnSearch is called when someone searches the torrent infohash,
|
||||
// that's, the "get_peers" query.
|
||||
//
|
||||
// The default callback does noting.
|
||||
OnSearch func(infohash string, ip net.IP, port uint16)
|
||||
|
||||
// OnTorrent is called when someone has the torrent infohash
|
||||
// or someone has just downloaded the torrent infohash,
|
||||
// that's, the "get_peers" response or "announce_peer" query.
|
||||
//
|
||||
// The default callback does noting.
|
||||
OnTorrent func(infohash string, ip net.IP, port uint16)
|
||||
|
||||
// HandleInMessage is used to intercept the incoming DHT message.
|
||||
// For example, you can debug the message as the log.
|
||||
//
|
||||
// Return true if going on handling by the default. Or return false.
|
||||
//
|
||||
// The default is nil.
|
||||
HandleInMessage func(*net.UDPAddr, *krpc.Message) bool
|
||||
|
||||
// HandleOutMessage is used to intercept the outgoing DHT message.
|
||||
// For example, you can debug the message as the log.
|
||||
//
|
||||
// Return (false, nil) if going on handling by the default.
|
||||
//
|
||||
// The default is nil.
|
||||
HandleOutMessage func(*net.UDPAddr, *krpc.Message) (wrote bool, err error)
|
||||
}
|
||||
|
||||
func (c Config) in(*net.UDPAddr, *krpc.Message) bool { return true }
|
||||
func (c Config) out(*net.UDPAddr, *krpc.Message) (bool, error) { return false, nil }
|
||||
|
||||
func (c *Config) set(conf ...Config) {
|
||||
if len(conf) > 0 {
|
||||
*c = conf[0]
|
||||
}
|
||||
|
||||
if c.K <= 0 {
|
||||
c.K = 8
|
||||
}
|
||||
if c.ID.IsZero() {
|
||||
c.ID = metainfo.NewRandomHash()
|
||||
}
|
||||
if c.MsgSize <= 0 {
|
||||
c.MsgSize = 4096
|
||||
}
|
||||
if c.ErrorLog == nil {
|
||||
c.ErrorLog = log.Printf
|
||||
}
|
||||
if c.SearchDepth < 1 {
|
||||
c.SearchDepth = 8
|
||||
}
|
||||
if c.RoutingTableStorage == nil {
|
||||
c.RoutingTableStorage = noopStorage{}
|
||||
}
|
||||
if c.Blacklist == nil {
|
||||
c.Blacklist = NewMemoryBlacklist(1024, time.Hour*24*7)
|
||||
}
|
||||
if c.RespTimeout == 0 {
|
||||
c.RespTimeout = time.Second * 10
|
||||
}
|
||||
if c.OnSearch == nil {
|
||||
c.OnSearch = func(string, net.IP, uint16) {}
|
||||
}
|
||||
if c.OnTorrent == nil {
|
||||
c.OnTorrent = func(string, net.IP, uint16) {}
|
||||
}
|
||||
if c.HandleInMessage == nil {
|
||||
c.HandleInMessage = c.in
|
||||
}
|
||||
if c.HandleOutMessage == nil {
|
||||
c.HandleOutMessage = c.out
|
||||
}
|
||||
}
|
||||
|
||||
// Server is a DHT server.
|
||||
type Server struct {
|
||||
conf Config
|
||||
exit chan struct{}
|
||||
conn net.PacketConn
|
||||
once sync.Once
|
||||
|
||||
ipv4 bool
|
||||
ipv6 bool
|
||||
want []krpc.Want
|
||||
|
||||
peerManager PeerManager
|
||||
// routingTable *routingTable
|
||||
routingTable4 *routingTable
|
||||
routingTable6 *routingTable
|
||||
tokenManager *tokenManager
|
||||
tokenPeerManager *tokenPeerManager
|
||||
transactionManager *transactionManager
|
||||
}
|
||||
|
||||
// NewServer returns a new DHT server.
|
||||
func NewServer(conn net.PacketConn, config ...Config) *Server {
|
||||
var conf Config
|
||||
conf.set(config...)
|
||||
|
||||
if len(conf.IPProtocols) == 0 {
|
||||
host, _, err := net.SplitHostPort(conn.LocalAddr().String())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
} else if ip := net.ParseIP(host); ipIsZero(ip) || ip.To4() != nil {
|
||||
conf.IPProtocols = []IPProtocolStack{IPv4Protocol}
|
||||
} else {
|
||||
conf.IPProtocols = []IPProtocolStack{IPv6Protocol}
|
||||
}
|
||||
}
|
||||
|
||||
var ipv4, ipv6 bool
|
||||
var want []krpc.Want
|
||||
for _, ip := range conf.IPProtocols {
|
||||
switch ip {
|
||||
case IPv4Protocol:
|
||||
ipv4 = true
|
||||
want = append(want, krpc.WantNodes)
|
||||
case IPv6Protocol:
|
||||
ipv6 = true
|
||||
want = append(want, krpc.WantNodes6)
|
||||
}
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
ipv4: ipv4,
|
||||
ipv6: ipv6,
|
||||
want: want,
|
||||
conn: conn,
|
||||
conf: conf,
|
||||
exit: make(chan struct{}),
|
||||
peerManager: conf.PeerManager,
|
||||
tokenManager: newTokenManager(),
|
||||
tokenPeerManager: newTokenPeerManager(),
|
||||
transactionManager: newTransactionManager(),
|
||||
}
|
||||
|
||||
s.routingTable4 = newRoutingTable(s, false)
|
||||
s.routingTable6 = newRoutingTable(s, true)
|
||||
if s.peerManager == nil {
|
||||
s.peerManager = s.tokenPeerManager
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// ID returns the ID of the DHT server node.
|
||||
func (s *Server) ID() metainfo.Hash { return s.conf.ID }
|
||||
|
||||
// Bootstrap initializes the routing table at first.
|
||||
//
|
||||
// Notice: If the routing table has had some nodes, it does noting.
|
||||
func (s *Server) Bootstrap(addrs []string) {
|
||||
if (s.ipv4 && s.routingTable4.Len() == 0) ||
|
||||
(s.ipv6 && s.routingTable6.Len() == 0) {
|
||||
for _, addr := range addrs {
|
||||
as, err := metainfo.NewAddressesFromString(addr)
|
||||
if err != nil {
|
||||
s.conf.ErrorLog(err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
for _, a := range as {
|
||||
if err = s.FindNode(a.UDPAddr(), s.conf.ID); err != nil {
|
||||
s.conf.ErrorLog(`fail to bootstrap '%s': %s`, a.String(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Node4Num returns the number of the ipv4 nodes in the routing table.
|
||||
func (s *Server) Node4Num() int { return s.routingTable4.Len() }
|
||||
|
||||
// Node6Num returns the number of the ipv6 nodes in the routing table.
|
||||
func (s *Server) Node6Num() int { return s.routingTable6.Len() }
|
||||
|
||||
// AddNode adds the node into the routing table.
|
||||
//
|
||||
// The returned value:
|
||||
// NodeAdded: The node is added successfully.
|
||||
// NodeNotAdded: The node is not added and is discarded.
|
||||
// NodeExistAndUpdated: The node has existed, and its status has been updated.
|
||||
// NodeExistAndChanged: The node has existed, but the address is inconsistent.
|
||||
// The current node will be discarded.
|
||||
//
|
||||
func (s *Server) AddNode(node krpc.Node) int {
|
||||
// For IPv6
|
||||
if isIPv6(node.Addr.IP) {
|
||||
if s.ipv6 {
|
||||
return s.routingTable6.AddNode(node)
|
||||
}
|
||||
return NodeNotAdded
|
||||
}
|
||||
|
||||
// For IPv4
|
||||
if s.ipv4 {
|
||||
return s.routingTable4.AddNode(node)
|
||||
}
|
||||
|
||||
return NodeNotAdded
|
||||
}
|
||||
|
||||
func (s *Server) addNode(a *net.UDPAddr, id metainfo.Hash, ro bool) (r int) {
|
||||
if ro { // BEP 43
|
||||
return NodeNotAdded
|
||||
}
|
||||
|
||||
if r = s.AddNode(krpc.NewNodeByUDPAddr(id, a)); r == NodeExistAndChanged {
|
||||
s.conf.Blacklist.Add(a.IP.String(), a.Port)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Server) addNode2(node krpc.Node, ro bool) int {
|
||||
if ro { // BEP 43
|
||||
return NodeNotAdded
|
||||
}
|
||||
return s.AddNode(node)
|
||||
}
|
||||
|
||||
func (s *Server) stop() {
|
||||
close(s.exit)
|
||||
s.routingTable4.Stop()
|
||||
s.routingTable6.Stop()
|
||||
s.tokenManager.Stop()
|
||||
s.tokenPeerManager.Stop()
|
||||
s.transactionManager.Stop()
|
||||
s.conf.Blacklist.Close()
|
||||
}
|
||||
|
||||
// Close stops the DHT server.
|
||||
func (s *Server) Close() { s.once.Do(s.stop) }
|
||||
|
||||
// Sync is used to synchronize the routing table to the underlying storage.
|
||||
func (s *Server) Sync() {
|
||||
s.routingTable4.Sync()
|
||||
s.routingTable6.Sync()
|
||||
}
|
||||
|
||||
// Run starts the DHT server.
|
||||
func (s *Server) Run() {
|
||||
go s.routingTable4.Start(time.Minute * 5)
|
||||
go s.routingTable6.Start(time.Minute * 5)
|
||||
go s.tokenManager.Start(time.Minute * 10)
|
||||
go s.tokenPeerManager.Start(time.Hour * 24)
|
||||
go s.transactionManager.Start(s, s.conf.RespTimeout)
|
||||
|
||||
buf := make([]byte, s.conf.MsgSize)
|
||||
for {
|
||||
n, raddr, err := s.conn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
s.conf.ErrorLog("fail to read the dht message: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
s.handlePacket(raddr.(*net.UDPAddr), buf[:n])
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) isDisabled(raddr *net.UDPAddr) bool {
|
||||
if isIPv6(raddr.IP) {
|
||||
if !s.ipv6 {
|
||||
return true
|
||||
}
|
||||
} else if !s.ipv4 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HandlePacket handles the incoming DHT message.
|
||||
func (s *Server) handlePacket(raddr *net.UDPAddr, data []byte) {
|
||||
if s.isDisabled(raddr) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check whether the raddr is in the ip blacklist. If yes, discard it.
|
||||
if s.conf.Blacklist.In(raddr.IP.String(), raddr.Port) {
|
||||
return
|
||||
}
|
||||
|
||||
var msg krpc.Message
|
||||
if err := bencode.DecodeBytes(data, &msg); err != nil {
|
||||
s.conf.ErrorLog("decode krpc message error: %s", err)
|
||||
return
|
||||
} else if msg.T == "" {
|
||||
s.conf.ErrorLog("no transaction id from '%s'", raddr)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: Should we use a task pool??
|
||||
go s.handleMessage(raddr, msg)
|
||||
}
|
||||
|
||||
func (s *Server) handleMessage(raddr *net.UDPAddr, m krpc.Message) {
|
||||
if !s.conf.HandleInMessage(raddr, &m) {
|
||||
return
|
||||
}
|
||||
|
||||
switch m.Y {
|
||||
case "q":
|
||||
if !m.A.ID.IsZero() {
|
||||
r := s.addNode(raddr, m.A.ID, m.RO)
|
||||
if r != NodeExistAndChanged && !s.conf.ReadOnly { // BEP 43
|
||||
s.handleQuery(raddr, m)
|
||||
}
|
||||
}
|
||||
case "r":
|
||||
if !m.R.ID.IsZero() {
|
||||
if s.addNode(raddr, m.R.ID, m.RO) == NodeExistAndChanged {
|
||||
return
|
||||
}
|
||||
|
||||
if t := s.transactionManager.PopTransaction(m.T, raddr); t != nil {
|
||||
t.OnResponse(t, raddr, m)
|
||||
}
|
||||
}
|
||||
case "e":
|
||||
if t := s.transactionManager.PopTransaction(m.T, raddr); t != nil {
|
||||
t.OnError(t, m.E.Code, m.E.Reason)
|
||||
}
|
||||
default:
|
||||
s.conf.ErrorLog("unknown dht message type '%s'", m.Y)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleQuery(raddr *net.UDPAddr, m krpc.Message) {
|
||||
switch m.Q {
|
||||
case queryMethodPing:
|
||||
s.reply(raddr, m.T, krpc.ResponseResult{})
|
||||
case queryMethodFindNode: // See BEP 32
|
||||
var r krpc.ResponseResult
|
||||
n4 := m.A.ContainsWant(krpc.WantNodes)
|
||||
n6 := m.A.ContainsWant(krpc.WantNodes6)
|
||||
if !n4 && !n6 {
|
||||
if isIPv6(raddr.IP) {
|
||||
r.Nodes6 = s.routingTable6.Closest(m.A.InfoHash, s.conf.K)
|
||||
} else {
|
||||
r.Nodes = s.routingTable4.Closest(m.A.InfoHash, s.conf.K)
|
||||
}
|
||||
} else {
|
||||
if n4 {
|
||||
r.Nodes = s.routingTable4.Closest(m.A.InfoHash, s.conf.K)
|
||||
}
|
||||
if n6 {
|
||||
r.Nodes6 = s.routingTable6.Closest(m.A.InfoHash, s.conf.K)
|
||||
}
|
||||
}
|
||||
s.reply(raddr, m.T, r)
|
||||
case queryMethodGetPeers: // See BEP 32
|
||||
n4 := m.A.ContainsWant(krpc.WantNodes)
|
||||
n6 := m.A.ContainsWant(krpc.WantNodes6)
|
||||
|
||||
// Get the ipv4/ipv6 peers storing the torrent infohash.
|
||||
var r krpc.ResponseResult
|
||||
if !n4 && !n6 {
|
||||
r.Values = s.peerManager.GetPeers(m.A.InfoHash, s.conf.K, isIPv6(raddr.IP))
|
||||
} else {
|
||||
if n4 {
|
||||
r.Values = s.peerManager.GetPeers(m.A.InfoHash, s.conf.K, false)
|
||||
}
|
||||
|
||||
if n6 {
|
||||
values := s.peerManager.GetPeers(m.A.InfoHash, s.conf.K, true)
|
||||
if len(r.Values) == 0 {
|
||||
r.Values = values
|
||||
} else {
|
||||
r.Values = append(r.Values, values...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No Peers, and return the closest other nodes.
|
||||
if len(r.Values) == 0 {
|
||||
if !n4 && !n6 {
|
||||
if isIPv6(raddr.IP) {
|
||||
r.Nodes6 = s.routingTable6.Closest(m.A.InfoHash, s.conf.K)
|
||||
} else {
|
||||
r.Nodes = s.routingTable4.Closest(m.A.InfoHash, s.conf.K)
|
||||
}
|
||||
} else {
|
||||
if n4 {
|
||||
r.Nodes = s.routingTable4.Closest(m.A.InfoHash, s.conf.K)
|
||||
}
|
||||
if n6 {
|
||||
r.Nodes6 = s.routingTable6.Closest(m.A.InfoHash, s.conf.K)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
r.Token = s.tokenManager.Token(raddr)
|
||||
s.reply(raddr, m.T, r)
|
||||
s.conf.OnSearch(m.A.InfoHash.HexString(), raddr.IP, uint16(raddr.Port))
|
||||
case queryMethodAnnouncePeer:
|
||||
if s.tokenManager.Check(raddr, m.A.Token) {
|
||||
return
|
||||
}
|
||||
s.reply(raddr, m.T, krpc.ResponseResult{})
|
||||
s.conf.OnTorrent(m.A.InfoHash.HexString(), raddr.IP, m.A.GetPort(raddr.Port))
|
||||
default:
|
||||
s.sendError(raddr, m.T, "unknown query method", krpc.ErrorCodeMethodUnknown)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) send(raddr *net.UDPAddr, m krpc.Message) (wrote bool, err error) {
|
||||
// // TODO: Should we check the ip blacklist??
|
||||
// if s.conf.Blacklist.In(raddr.IP.String(), raddr.Port) {
|
||||
// return
|
||||
// }
|
||||
|
||||
m.RO = s.conf.ReadOnly // BEP 43
|
||||
if wrote, err = s.conf.HandleOutMessage(raddr, &m); !wrote && err == nil {
|
||||
wrote, err = s._send(raddr, m)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Server) _send(raddr *net.UDPAddr, m krpc.Message) (wrote bool, err error) {
|
||||
if m.T == "" || m.Y == "" {
|
||||
panic(`DHT message "t" or "y" must not be empty`)
|
||||
}
|
||||
|
||||
buf := bytes.NewBuffer(nil)
|
||||
buf.Grow(128)
|
||||
if err = bencode.NewEncoder(buf).Encode(m); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
n, err := s.conn.WriteTo(buf.Bytes(), raddr)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error writing %d bytes to %s: %s", buf.Len(), raddr, err)
|
||||
s.conf.Blacklist.Add(raddr.IP.String(), 0)
|
||||
return
|
||||
}
|
||||
|
||||
wrote = true
|
||||
if n != buf.Len() {
|
||||
err = io.ErrShortWrite
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Server) sendError(raddr *net.UDPAddr, tid, reason string, code int) {
|
||||
if _, err := s.send(raddr, krpc.NewErrorMsg(tid, code, reason)); err != nil {
|
||||
s.conf.ErrorLog("error replying to %s: %s", raddr.String(), err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) reply(raddr *net.UDPAddr, tid string, r krpc.ResponseResult) {
|
||||
r.ID = s.conf.ID
|
||||
if _, err := s.send(raddr, krpc.NewResponseMsg(tid, r)); err != nil {
|
||||
s.conf.ErrorLog("error replying to %s: %s", raddr.String(), err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) request(t *transaction) (err error) {
|
||||
if s.isDisabled(t.Addr) {
|
||||
return errUnsupportedIPProtocol
|
||||
}
|
||||
|
||||
t.Arg.ID = s.conf.ID
|
||||
if t.ID == "" {
|
||||
t.ID = s.transactionManager.GetTransactionID()
|
||||
}
|
||||
if _, err = s.send(t.Addr, krpc.NewQueryMsg(t.ID, t.Query, t.Arg)); err != nil {
|
||||
s.conf.ErrorLog("error replying to %s: %s", t.Addr.String(), err.Error())
|
||||
} else {
|
||||
s.transactionManager.AddTransaction(t)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Server) onError(t *transaction, code int, reason string) {
|
||||
s.conf.ErrorLog("got an error to ping '%s': code=%d, reason=%s",
|
||||
t.Addr.String(), code, reason)
|
||||
t.Done(Result{Code: code, Reason: reason})
|
||||
}
|
||||
|
||||
func (s *Server) onTimeout(t *transaction) {
|
||||
// TODO: Should we use a task pool??
|
||||
t.Done(Result{Timeout: true})
|
||||
s.conf.ErrorLog("transaction '%s' timeout: query=%s, raddr=%s",
|
||||
t.ID, t.Query, t.Addr.String())
|
||||
}
|
||||
|
||||
func (s *Server) onPingResp(t *transaction, a *net.UDPAddr, m krpc.Message) {
|
||||
t.Done(Result{})
|
||||
}
|
||||
|
||||
func (s *Server) onGetPeersResp(t *transaction, a *net.UDPAddr, m krpc.Message) {
|
||||
// Store the response node with the token.
|
||||
if m.R.Token != "" {
|
||||
s.tokenPeerManager.Set(m.R.ID, a, m.R.Token)
|
||||
}
|
||||
|
||||
// Get the peers.
|
||||
if len(m.R.Values) > 0 {
|
||||
t.Done(Result{Peers: m.R.Values})
|
||||
for _, addr := range m.R.Values {
|
||||
s.conf.OnTorrent(t.Arg.InfoHash.HexString(), addr.IP, addr.Port)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Search the torrent infohash recursively.
|
||||
t.Depth--
|
||||
if t.Depth < 1 {
|
||||
t.Done(Result{})
|
||||
return
|
||||
}
|
||||
|
||||
var found bool
|
||||
ids := t.Visited
|
||||
nodes := make([]krpc.Node, 0, len(m.R.Nodes)+len(m.R.Nodes6))
|
||||
for _, node := range m.R.Nodes {
|
||||
if node.ID == t.Arg.InfoHash {
|
||||
found = true
|
||||
}
|
||||
if ids.Contains(node.ID) {
|
||||
continue
|
||||
}
|
||||
if s.addNode2(node, m.RO) == NodeAdded {
|
||||
nodes = append(nodes, node)
|
||||
ids = append(ids, node.ID)
|
||||
}
|
||||
}
|
||||
for _, node := range m.R.Nodes6 {
|
||||
if node.ID == t.Arg.InfoHash {
|
||||
found = true
|
||||
}
|
||||
if ids.Contains(node.ID) {
|
||||
continue
|
||||
}
|
||||
if s.addNode2(node, m.RO) == NodeAdded {
|
||||
nodes = append(nodes, node)
|
||||
ids = append(ids, node.ID)
|
||||
}
|
||||
}
|
||||
|
||||
if found || len(nodes) == 0 {
|
||||
t.Done(Result{})
|
||||
return
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
s.getPeers(t.Arg.InfoHash, node.Addr, t.Depth, ids, t.Callback)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) getPeers(info metainfo.Hash, addr metainfo.Address, depth int,
|
||||
ids metainfo.Hashes, cb ...func(Result)) {
|
||||
arg := krpc.QueryArg{InfoHash: info, Wants: s.want}
|
||||
t := newTransaction(s, addr.UDPAddr(), queryMethodGetPeers, arg, cb...)
|
||||
t.OnResponse = s.onGetPeersResp
|
||||
t.Depth = depth
|
||||
t.Visited = ids
|
||||
if err := s.request(t); err != nil {
|
||||
s.conf.ErrorLog("fail to send query message to '%s': %s", addr.String(), err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ping sends a PING query to addr, and the callback function cb will be called
|
||||
// when the response or error is returned, or it's timeout.
|
||||
func (s *Server) Ping(addr *net.UDPAddr, cb ...func(Result)) (err error) {
|
||||
t := newTransaction(s, addr, queryMethodPing, krpc.QueryArg{}, cb...)
|
||||
t.OnResponse = s.onPingResp
|
||||
return s.request(t)
|
||||
}
|
||||
|
||||
// GetPeers searches the peer storing the torrent by the infohash of the torrent,
|
||||
// which will search it recursively until some peers are returned or it reaches
|
||||
// the maximun depth, that's, ServerConfig.SearchDepth.
|
||||
//
|
||||
// If cb is given, it will be called when some peers are returned.
|
||||
// Notice: it may be called for many times.
|
||||
func (s *Server) GetPeers(infohash metainfo.Hash, cb ...func(Result)) {
|
||||
if infohash.IsZero() {
|
||||
panic("the infohash of the torrent is ZERO")
|
||||
}
|
||||
|
||||
var nodes []krpc.Node
|
||||
if s.ipv4 {
|
||||
nodes = s.routingTable4.Closest(infohash, s.conf.K)
|
||||
}
|
||||
if s.ipv6 {
|
||||
nodes = append(nodes, s.routingTable6.Closest(infohash, s.conf.K)...)
|
||||
}
|
||||
|
||||
if len(nodes) == 0 {
|
||||
if len(cb) != 0 && cb[0] != nil {
|
||||
cb[0](Result{})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
ids := make(metainfo.Hashes, len(nodes))
|
||||
for i, node := range nodes {
|
||||
ids[i] = node.ID
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
s.getPeers(infohash, node.Addr, s.conf.SearchDepth, ids, cb...)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// AnnouncePeer announces the torrent infohash to the K closest nodes,
|
||||
// and returns the nodes to which it sends the announce_peer query.
|
||||
func (s *Server) AnnouncePeer(infohash metainfo.Hash, port uint16, impliedPort bool) []krpc.Node {
|
||||
if infohash.IsZero() {
|
||||
panic("the infohash of the torrent is ZERO")
|
||||
}
|
||||
|
||||
var nodes []krpc.Node
|
||||
if s.ipv4 {
|
||||
nodes = s.routingTable4.Closest(infohash, s.conf.K)
|
||||
}
|
||||
if s.ipv6 {
|
||||
nodes = append(nodes, s.routingTable6.Closest(infohash, s.conf.K)...)
|
||||
}
|
||||
|
||||
sentNodes := make([]krpc.Node, 0, len(nodes))
|
||||
for _, node := range nodes {
|
||||
addr := node.Addr.UDPAddr()
|
||||
token := s.tokenPeerManager.Get(infohash, addr)
|
||||
if token == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
arg := krpc.QueryArg{ImpliedPort: impliedPort, InfoHash: infohash, Port: port, Token: token}
|
||||
t := newTransaction(s, addr, queryMethodAnnouncePeer, arg)
|
||||
if err := s.request(t); err != nil {
|
||||
s.conf.ErrorLog("fail to send query message to '%s': %s", addr.String(), err)
|
||||
} else {
|
||||
sentNodes = append(sentNodes, node)
|
||||
}
|
||||
}
|
||||
|
||||
return sentNodes
|
||||
}
|
||||
|
||||
// FindNode sends the "find_node" query to the addr to find the target node.
|
||||
//
|
||||
// Notice: In general, it's used to bootstrap the routing table.
|
||||
func (s *Server) FindNode(addr *net.UDPAddr, target metainfo.Hash) error {
|
||||
if target.IsZero() {
|
||||
panic("the target is ZERO")
|
||||
}
|
||||
|
||||
return s.findNode(target, addr, s.conf.SearchDepth, nil)
|
||||
}
|
||||
|
||||
func (s *Server) findNode(target metainfo.Hash, addr *net.UDPAddr, depth int,
|
||||
ids metainfo.Hashes) error {
|
||||
arg := krpc.QueryArg{Target: target, Wants: s.want}
|
||||
t := newTransaction(s, addr, queryMethodFindNode, arg)
|
||||
t.OnResponse = s.onFindNodeResp
|
||||
t.Visited = ids
|
||||
return s.request(t)
|
||||
}
|
||||
|
||||
func (s *Server) onFindNodeResp(t *transaction, a *net.UDPAddr, m krpc.Message) {
|
||||
// Search the target node recursively.
|
||||
t.Depth--
|
||||
if t.Depth < 1 {
|
||||
return
|
||||
}
|
||||
|
||||
var found bool
|
||||
ids := t.Visited
|
||||
nodes := make([]krpc.Node, 0, len(m.R.Nodes)+len(m.R.Nodes6))
|
||||
for _, node := range m.R.Nodes {
|
||||
if node.ID == t.Arg.Target {
|
||||
found = true
|
||||
}
|
||||
if ids.Contains(node.ID) {
|
||||
continue
|
||||
}
|
||||
if s.addNode2(node, m.RO) == NodeAdded {
|
||||
nodes = append(nodes, node)
|
||||
ids = append(ids, node.ID)
|
||||
}
|
||||
}
|
||||
for _, node := range m.R.Nodes6 {
|
||||
if node.ID == t.Arg.Target {
|
||||
found = true
|
||||
}
|
||||
if ids.Contains(node.ID) {
|
||||
continue
|
||||
}
|
||||
if s.addNode2(node, m.RO) == NodeAdded {
|
||||
nodes = append(nodes, node)
|
||||
ids = append(ids, node.ID)
|
||||
}
|
||||
}
|
||||
|
||||
if found || len(nodes) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
err := s.findNode(t.Arg.Target, node.Addr.UDPAddr(), t.Depth, ids)
|
||||
if err != nil {
|
||||
s.conf.ErrorLog(`fail to send "find_node" query to '%s': %s`,
|
||||
node.Addr.String(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isIPv6(ip net.IP) bool {
|
||||
if ip.To4() == nil {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func ipIsZero(ip net.IP) bool {
|
||||
for _, b := range ip {
|
||||
if b != 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
182
dht/dht_server_test.go
Normal file
182
dht/dht_server_test.go
Normal file
@ -0,0 +1,182 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package dht
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/xgfone/bt/metainfo"
|
||||
)
|
||||
|
||||
type testPeerManager struct {
|
||||
lock sync.RWMutex
|
||||
peers map[metainfo.Hash][]metainfo.Address
|
||||
}
|
||||
|
||||
func newTestPeerManager() *testPeerManager {
|
||||
return &testPeerManager{peers: make(map[metainfo.Hash][]metainfo.Address)}
|
||||
}
|
||||
|
||||
func (pm *testPeerManager) AddPeer(infohash metainfo.Hash, addr metainfo.Address) {
|
||||
pm.lock.Lock()
|
||||
var exist bool
|
||||
for _, orig := range pm.peers[infohash] {
|
||||
if orig.Equal(addr) {
|
||||
exist = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !exist {
|
||||
pm.peers[infohash] = append(pm.peers[infohash], addr)
|
||||
}
|
||||
pm.lock.Unlock()
|
||||
}
|
||||
|
||||
func (pm *testPeerManager) GetPeers(infohash metainfo.Hash, maxnum int,
|
||||
ipv6 bool) (addrs []metainfo.Address) {
|
||||
// We only supports IPv4, so ignore the ipv6 argument.
|
||||
pm.lock.RLock()
|
||||
_addrs := pm.peers[infohash]
|
||||
if _len := len(_addrs); _len > 0 {
|
||||
if _len > maxnum {
|
||||
_len = maxnum
|
||||
}
|
||||
addrs = _addrs[:_len]
|
||||
}
|
||||
pm.lock.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
func onSearch(infohash string, ip net.IP, port uint16) {
|
||||
addr := net.JoinHostPort(ip.String(), strconv.FormatUint(uint64(port), 10))
|
||||
fmt.Printf("%s is searching %s\n", addr, infohash)
|
||||
}
|
||||
|
||||
func onTorrent(infohash string, ip net.IP, port uint16) {
|
||||
addr := net.JoinHostPort(ip.String(), strconv.FormatUint(uint64(port), 10))
|
||||
fmt.Printf("%s has downloaded %s\n", addr, infohash)
|
||||
}
|
||||
|
||||
func newDHTServer(id metainfo.Hash, addr string, pm PeerManager) (s *Server, err error) {
|
||||
conn, err := net.ListenPacket("udp", addr)
|
||||
if err == nil {
|
||||
c := Config{ID: id, PeerManager: pm, OnSearch: onSearch, OnTorrent: onTorrent}
|
||||
s = NewServer(conn, c)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func ExampleServer() {
|
||||
// For test, we predefine some node ids and infohash.
|
||||
id1 := metainfo.NewRandomHash()
|
||||
id2 := metainfo.NewRandomHash()
|
||||
id3 := metainfo.NewRandomHash()
|
||||
infohash := metainfo.Hash{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
|
||||
11, 12, 13, 14, 15, 16, 17, 18, 19, 20}
|
||||
|
||||
// Create first DHT server
|
||||
pm := newTestPeerManager()
|
||||
server1, err := newDHTServer(id1, ":9001", pm)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
defer server1.Close()
|
||||
|
||||
// Create second DHT server
|
||||
server2, err := newDHTServer(id2, ":9002", nil)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
defer server2.Close()
|
||||
|
||||
// Create third DHT server
|
||||
server3, err := newDHTServer(id3, ":9003", nil)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
defer server3.Close()
|
||||
|
||||
// Start the three DHT servers
|
||||
go server1.Run()
|
||||
go server2.Run()
|
||||
go server3.Run()
|
||||
|
||||
// Wait that the DHT servers start.
|
||||
time.Sleep(time.Second)
|
||||
|
||||
// Bootstrap the routing table to create a DHT network with the three servers
|
||||
server1.Bootstrap([]string{"127.0.0.1:9002", "127.0.0.1:9003"})
|
||||
server2.Bootstrap([]string{"127.0.0.1:9001", "127.0.0.1:9003"})
|
||||
server3.Bootstrap([]string{"127.0.0.1:9001", "127.0.0.1:9002"})
|
||||
|
||||
// Wait that the DHT servers learn the routing table
|
||||
time.Sleep(time.Second)
|
||||
|
||||
fmt.Println("Server1:", server1.Node4Num())
|
||||
fmt.Println("Server2:", server2.Node4Num())
|
||||
fmt.Println("Server3:", server3.Node4Num())
|
||||
|
||||
server1.GetPeers(infohash, func(r Result) {
|
||||
if len(r.Peers) == 0 {
|
||||
fmt.Printf("no peers for %s\n", infohash)
|
||||
} else {
|
||||
for _, peer := range r.Peers {
|
||||
fmt.Printf("%s: %s\n", infohash, peer.String())
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Wait that the last get_peers ends.
|
||||
time.Sleep(time.Second)
|
||||
|
||||
// Add the peer to let the DHT server1 has the peer.
|
||||
pm.AddPeer(infohash, metainfo.NewAddress(net.ParseIP("127.0.0.1"), 9001))
|
||||
|
||||
// Search the torrent infohash again, but from DHT server2,
|
||||
// which will search the DHT server1 recursively.
|
||||
server2.GetPeers(infohash, func(r Result) {
|
||||
if len(r.Peers) == 0 {
|
||||
fmt.Printf("no peers for %s\n", infohash)
|
||||
} else {
|
||||
for _, peer := range r.Peers {
|
||||
fmt.Printf("%s: %s\n", infohash, peer.String())
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Wait that the recursive call ends.
|
||||
time.Sleep(time.Second)
|
||||
|
||||
// Unordered output:
|
||||
// Server1: 2
|
||||
// Server2: 2
|
||||
// Server3: 2
|
||||
// 127.0.0.1:9001 is searching 0102030405060708090a0b0c0d0e0f1011121314
|
||||
// 127.0.0.1:9001 is searching 0102030405060708090a0b0c0d0e0f1011121314
|
||||
// no peers for 0102030405060708090a0b0c0d0e0f1011121314
|
||||
// no peers for 0102030405060708090a0b0c0d0e0f1011121314
|
||||
// 127.0.0.1:9002 is searching 0102030405060708090a0b0c0d0e0f1011121314
|
||||
// 127.0.0.1:9002 is searching 0102030405060708090a0b0c0d0e0f1011121314
|
||||
// no peers for 0102030405060708090a0b0c0d0e0f1011121314
|
||||
// 0102030405060708090a0b0c0d0e0f1011121314: 127.0.0.1:9001
|
||||
// 127.0.0.1:9001 has downloaded 0102030405060708090a0b0c0d0e0f1011121314
|
||||
}
|
139
dht/peer_manager.go
Normal file
139
dht/peer_manager.go
Normal file
@ -0,0 +1,139 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package dht
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/xgfone/bt/metainfo"
|
||||
)
|
||||
|
||||
// PeerManager is used to manage the peers.
|
||||
type PeerManager interface {
|
||||
// If ipv6 is true, only return ipv6 addresses. Or return ipv4 addresses.
|
||||
GetPeers(infohash metainfo.Hash, maxnum int, ipv6 bool) []metainfo.Address
|
||||
}
|
||||
|
||||
type peer struct {
|
||||
ID metainfo.Hash
|
||||
IP net.IP
|
||||
Port uint16
|
||||
Token string
|
||||
Time time.Time
|
||||
}
|
||||
|
||||
type tokenPeerManager struct {
|
||||
lock sync.RWMutex
|
||||
exit chan struct{}
|
||||
peers map[metainfo.Hash]map[string]peer
|
||||
}
|
||||
|
||||
func newTokenPeerManager() *tokenPeerManager {
|
||||
return &tokenPeerManager{
|
||||
exit: make(chan struct{}),
|
||||
peers: make(map[metainfo.Hash]map[string]peer, 128),
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the token-peer manager.
|
||||
func (tpm *tokenPeerManager) Start(interval time.Duration) {
|
||||
tick := time.NewTicker(interval)
|
||||
defer tick.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-tpm.exit:
|
||||
return
|
||||
case now := <-tick.C:
|
||||
tpm.lock.Lock()
|
||||
for id, peers := range tpm.peers {
|
||||
for addr, peer := range peers {
|
||||
if now.Sub(peer.Time) > interval {
|
||||
delete(peers, addr)
|
||||
}
|
||||
}
|
||||
|
||||
if len(peers) == 0 {
|
||||
delete(tpm.peers, id)
|
||||
}
|
||||
}
|
||||
tpm.lock.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (tpm *tokenPeerManager) Set(id metainfo.Hash, addr *net.UDPAddr, token string) {
|
||||
addrkey := addr.String()
|
||||
tpm.lock.Lock()
|
||||
peers, ok := tpm.peers[id]
|
||||
if !ok {
|
||||
peers = make(map[string]peer, 4)
|
||||
tpm.peers[id] = peers
|
||||
}
|
||||
peers[addrkey] = peer{
|
||||
ID: id,
|
||||
IP: addr.IP,
|
||||
Port: uint16(addr.Port),
|
||||
Token: token,
|
||||
Time: time.Now(),
|
||||
}
|
||||
tpm.lock.Unlock()
|
||||
}
|
||||
|
||||
func (tpm *tokenPeerManager) Get(id metainfo.Hash, addr *net.UDPAddr) (token string) {
|
||||
addrkey := addr.String()
|
||||
tpm.lock.RLock()
|
||||
if peers, ok := tpm.peers[id]; ok {
|
||||
if peer, ok := peers[addrkey]; ok {
|
||||
token = peer.Token
|
||||
}
|
||||
}
|
||||
tpm.lock.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
func (tpm *tokenPeerManager) Stop() {
|
||||
select {
|
||||
case <-tpm.exit:
|
||||
default:
|
||||
close(tpm.exit)
|
||||
}
|
||||
}
|
||||
|
||||
func (tpm *tokenPeerManager) GetPeers(infohash metainfo.Hash, maxnum int,
|
||||
ipv6 bool) (addrs []metainfo.Address) {
|
||||
addrs = make([]metainfo.Address, 0, maxnum)
|
||||
tpm.lock.RLock()
|
||||
if peers, ok := tpm.peers[infohash]; ok {
|
||||
for _, peer := range peers {
|
||||
if maxnum < 1 {
|
||||
break
|
||||
}
|
||||
|
||||
if ipv6 { // For IPv6
|
||||
if isIPv6(peer.IP) {
|
||||
maxnum--
|
||||
addrs = append(addrs, metainfo.NewAddress(peer.IP, peer.Port))
|
||||
}
|
||||
} else if !isIPv6(peer.IP) { // For IPv4
|
||||
maxnum--
|
||||
addrs = append(addrs, metainfo.NewAddress(peer.IP, peer.Port))
|
||||
}
|
||||
}
|
||||
}
|
||||
tpm.lock.RUnlock()
|
||||
return
|
||||
}
|
439
dht/routing_table.go
Normal file
439
dht/routing_table.go
Normal file
@ -0,0 +1,439 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package dht
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/xgfone/bt/krpc"
|
||||
"github.com/xgfone/bt/metainfo"
|
||||
)
|
||||
|
||||
const bktlen = 160
|
||||
|
||||
const (
|
||||
badTimeout = time.Minute * 20
|
||||
dubiousTimeout = time.Minute * 15
|
||||
)
|
||||
|
||||
type routingTable struct {
|
||||
k int
|
||||
s *Server
|
||||
ipv6 bool
|
||||
sync chan struct{}
|
||||
exit chan struct{}
|
||||
root metainfo.Hash
|
||||
lock sync.RWMutex
|
||||
bkts []*bucket
|
||||
}
|
||||
|
||||
func newRoutingTable(s *Server, ipv6 bool) *routingTable {
|
||||
rt := &routingTable{
|
||||
k: s.conf.K,
|
||||
s: s,
|
||||
ipv6: ipv6,
|
||||
root: s.conf.ID,
|
||||
sync: make(chan struct{}),
|
||||
exit: make(chan struct{}),
|
||||
bkts: make([]*bucket, bktlen),
|
||||
}
|
||||
|
||||
for i := range rt.bkts {
|
||||
rt.bkts[i] = &bucket{table: rt}
|
||||
}
|
||||
|
||||
// Load all the nodes from the storage.
|
||||
nodes, err := s.conf.RoutingTableStorage.Load(s.conf.ID, ipv6)
|
||||
if err != nil {
|
||||
s.conf.ErrorLog("fail to load routing table(ipv6=%v): %s", ipv6, err)
|
||||
} else {
|
||||
now := time.Now()
|
||||
for _, node := range nodes {
|
||||
if now.Sub(node.LastChanged) < dubiousTimeout {
|
||||
rt.addNode(node.Node, node.LastChanged)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return rt
|
||||
}
|
||||
|
||||
// Sync dumps the information of the routing table to the underlying storage.
|
||||
func (rt *routingTable) Sync() {
|
||||
select {
|
||||
case rt.sync <- struct{}{}:
|
||||
case <-rt.exit:
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the routing table.
|
||||
func (rt *routingTable) Start(interval time.Duration) {
|
||||
tick := time.NewTicker(interval)
|
||||
defer tick.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-rt.exit:
|
||||
rt.dump()
|
||||
return
|
||||
case <-rt.sync:
|
||||
rt.dump()
|
||||
case now := <-tick.C:
|
||||
rt.lock.Lock()
|
||||
rt.checkAllBuckets(now)
|
||||
rt.lock.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Dump stores all the nodes into the underlying storage.
|
||||
func (rt *routingTable) dump() {
|
||||
nodes := make([]RoutingTableNode, 0, 128)
|
||||
rt.lock.RLock()
|
||||
for _, bkt := range rt.bkts {
|
||||
for _, n := range bkt.Nodes {
|
||||
nodes = append(nodes, RoutingTableNode{
|
||||
Node: krpc.Node{ID: n.ID, Addr: n.Addr},
|
||||
LastChanged: n.LastChanged,
|
||||
})
|
||||
}
|
||||
}
|
||||
rt.lock.RUnlock()
|
||||
|
||||
if len(nodes) > 0 {
|
||||
err := rt.s.conf.RoutingTableStorage.Dump(rt.root, nodes, rt.ipv6)
|
||||
if err != nil {
|
||||
rt.s.conf.ErrorLog("fail to dump nodes in routing table(ipv6=%v): %s", rt.ipv6, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rt *routingTable) checkAllBuckets(now time.Time) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
rt.s.conf.ErrorLog("panic: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
for _, bkt := range rt.bkts {
|
||||
bkt.CheckAllNodes(now)
|
||||
}
|
||||
}
|
||||
|
||||
func (rt *routingTable) Len() (n int) {
|
||||
rt.lock.RLock()
|
||||
for _, bkt := range rt.bkts {
|
||||
n += len(bkt.Nodes)
|
||||
}
|
||||
rt.lock.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Stop stops the routing table.
|
||||
func (rt *routingTable) Stop() {
|
||||
select {
|
||||
case <-rt.exit:
|
||||
default:
|
||||
close(rt.exit)
|
||||
}
|
||||
}
|
||||
|
||||
// AddNode adds the node into the routing table.
|
||||
//
|
||||
// The returned value:
|
||||
// NodeAdded: The node is added successfully.
|
||||
// NodeNotAdded: The node is not added and is discarded.
|
||||
// NodeExistAndUpdated: The node has existed, and its status has been updated.
|
||||
// NodeExistAndChanged: The node has existed, but the address is inconsistent.
|
||||
// The current node will be discarded.
|
||||
//
|
||||
func (rt *routingTable) AddNode(n krpc.Node) (r int) {
|
||||
if n.ID == rt.root { // Don't add itself.
|
||||
return NodeNotAdded
|
||||
}
|
||||
return rt.addNode(n, time.Now())
|
||||
}
|
||||
|
||||
func (rt *routingTable) addNode(n krpc.Node, now time.Time) (r int) {
|
||||
bktid := bucketid(rt.root, n.ID)
|
||||
rt.lock.Lock()
|
||||
r = rt.bkts[bktid].AddNode(n, now)
|
||||
rt.lock.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Closest returns the maxnum number of the nodes which are the closest to target.
|
||||
func (rt *routingTable) Closest(target metainfo.Hash, maxnum int) (nodes []krpc.Node) {
|
||||
close := nodesByDistance{target: target, maxnum: maxnum}
|
||||
rt.lock.RLock()
|
||||
for _, b := range rt.bkts {
|
||||
for _, n := range b.Nodes {
|
||||
close.push(n.Node)
|
||||
}
|
||||
}
|
||||
rt.lock.RUnlock()
|
||||
return close.nodes
|
||||
}
|
||||
|
||||
type nodesByDistance struct {
|
||||
maxnum int
|
||||
target metainfo.Hash
|
||||
nodes []krpc.Node
|
||||
}
|
||||
|
||||
func (ns *nodesByDistance) push(node krpc.Node) {
|
||||
ix := sort.Search(len(ns.nodes), func(i int) bool {
|
||||
return distcmp(ns.target, ns.nodes[i].ID, node.ID) > 0
|
||||
})
|
||||
|
||||
if len(ns.nodes) < ns.maxnum {
|
||||
ns.nodes = append(ns.nodes, node)
|
||||
}
|
||||
|
||||
// slide existing nodes down to make room.
|
||||
// this will overwrite the entry we just appended.
|
||||
if ix != len(ns.nodes) {
|
||||
copy(ns.nodes[ix+1:], ns.nodes[ix:])
|
||||
ns.nodes[ix] = node
|
||||
}
|
||||
}
|
||||
|
||||
// distcmp compares the distances a->target and b->target.
|
||||
// Returns -1 if a is closer to target, 1 if b is closer to target
|
||||
// and 0 if they are equal.
|
||||
func distcmp(target, a, b metainfo.Hash) int {
|
||||
for i := range target {
|
||||
da := a[i] ^ target[i]
|
||||
db := b[i] ^ target[i]
|
||||
if da > db {
|
||||
return 1
|
||||
} else if da < db {
|
||||
return -1
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
/// Bucket
|
||||
|
||||
// Predefine some values returned by the method AddNode of routing table.
|
||||
const (
|
||||
NodeAdded = iota
|
||||
NodeNotAdded
|
||||
NodeExistAndUpdated
|
||||
NodeExistAndChanged
|
||||
)
|
||||
|
||||
type bucket struct {
|
||||
table *routingTable
|
||||
Nodes []*wrappedNode
|
||||
LastChanged time.Time
|
||||
}
|
||||
|
||||
func (b *bucket) emit(f func(metainfo.Hash, RoutingTableNode) error, n *wrappedNode) {
|
||||
node := RoutingTableNode{Node: n.Node, LastChanged: n.LastChanged}
|
||||
f(b.table.root, node)
|
||||
}
|
||||
|
||||
func (b *bucket) UpdateLastChangedTime(now time.Time) {
|
||||
b.LastChanged = now
|
||||
}
|
||||
|
||||
func (b *bucket) AddNode(n krpc.Node, now time.Time) (status int) {
|
||||
// Update the old one.
|
||||
for _, orig := range b.Nodes {
|
||||
if orig.ID == n.ID {
|
||||
if orig.Addr.Equal(n.Addr) {
|
||||
orig.UpdateLastChangedTime(now)
|
||||
status = NodeExistAndUpdated
|
||||
return
|
||||
}
|
||||
|
||||
// // TODO: Should we replace the old one??
|
||||
// b.UpdateLastChangedTime(now)
|
||||
// copy(b.Nodes[i:], b.Nodes[i+1:])
|
||||
// b.Nodes[len(b.Nodes)-1] = newWrappedNode(b, n, now)
|
||||
// status = nodeExistAndChanged
|
||||
return NodeExistAndChanged
|
||||
}
|
||||
}
|
||||
|
||||
// The bucket is not full and append it.
|
||||
if _len := len(b.Nodes); _len < b.table.k {
|
||||
b.UpdateLastChangedTime(now)
|
||||
b.Nodes = append(b.Nodes, newWrappedNode(b, n, now))
|
||||
status = NodeAdded
|
||||
return
|
||||
}
|
||||
|
||||
// The bucket is full and remove the bad node, or discard the current node.
|
||||
for i, node := range b.Nodes {
|
||||
if node.IsBad(now) {
|
||||
b.UpdateLastChangedTime(now)
|
||||
copy(b.Nodes[i:], b.Nodes[i+1:])
|
||||
b.Nodes = append(b.Nodes, newWrappedNode(b, n, now))
|
||||
status = NodeAdded
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// It will discard the current node and return false.
|
||||
status = NodeNotAdded
|
||||
return
|
||||
}
|
||||
|
||||
func (b *bucket) delNodeByIndex(index int) {
|
||||
b.Nodes = append(b.Nodes[:index], b.Nodes[index+1:]...)
|
||||
}
|
||||
|
||||
func (b *bucket) CheckAllNodes(now time.Time) {
|
||||
_len := len(b.Nodes)
|
||||
if _len == 0 || now.Sub(b.LastChanged) < dubiousTimeout {
|
||||
return
|
||||
}
|
||||
|
||||
// Check all the dubious nodes.
|
||||
indexes := make([]int, 0, len(b.Nodes))
|
||||
for i, node := range b.Nodes {
|
||||
switch status := node.Status(now); status {
|
||||
case nodeStatusGood:
|
||||
case nodeStatusDubious:
|
||||
// Try to send the PING query to the dubious node to check whether it is alive.
|
||||
if err := b.table.s.Ping(node.Node.Addr.UDPAddr()); err != nil {
|
||||
b.table.s.conf.ErrorLog("fail to ping '%s': %s", node.Node.String(), err)
|
||||
}
|
||||
case nodeStatusBad:
|
||||
// Remove the bad node
|
||||
indexes = append(indexes, i)
|
||||
default:
|
||||
panic(status)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove all the bad nodes.
|
||||
for _, index := range indexes {
|
||||
b.delNodeByIndex(index)
|
||||
}
|
||||
}
|
||||
|
||||
func bucketid(ownerid, nid metainfo.Hash) int {
|
||||
if ownerid == nid {
|
||||
return 0
|
||||
}
|
||||
|
||||
var i int
|
||||
var bite byte
|
||||
var bitDiff int
|
||||
var v byte
|
||||
for i, bite = range ownerid {
|
||||
v = bite ^ nid[i]
|
||||
switch {
|
||||
case v&0x80 > 0:
|
||||
bitDiff = 7
|
||||
goto calc
|
||||
case v&0x40 > 0:
|
||||
bitDiff = 6
|
||||
goto calc
|
||||
case v&0x20 > 0:
|
||||
bitDiff = 5
|
||||
goto calc
|
||||
case v&0x10 > 0:
|
||||
bitDiff = 4
|
||||
goto calc
|
||||
case v&0x08 > 0:
|
||||
bitDiff = 3
|
||||
goto calc
|
||||
case v&0x04 > 0:
|
||||
bitDiff = 2
|
||||
goto calc
|
||||
case v&0x02 > 0:
|
||||
bitDiff = 1
|
||||
goto calc
|
||||
case v&0x01 > 0:
|
||||
bitDiff = 0
|
||||
goto calc
|
||||
}
|
||||
}
|
||||
|
||||
calc:
|
||||
return i*8 + (8 - bitDiff)
|
||||
}
|
||||
|
||||
/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
/// Node
|
||||
|
||||
const (
|
||||
nodeStatusGood = iota
|
||||
nodeStatusDubious
|
||||
nodeStatusBad
|
||||
)
|
||||
|
||||
type wrappedNode struct {
|
||||
krpc.Node
|
||||
LastChanged time.Time
|
||||
bkt *bucket
|
||||
}
|
||||
|
||||
func newWrappedNode(bkt *bucket, n krpc.Node, now time.Time) *wrappedNode {
|
||||
return &wrappedNode{bkt: bkt, Node: n, LastChanged: now}
|
||||
}
|
||||
|
||||
func (n *wrappedNode) UpdateLastChangedTime(now ...time.Time) {
|
||||
var _now time.Time
|
||||
if len(now) > 0 {
|
||||
_now = now[0]
|
||||
} else {
|
||||
_now = time.Now()
|
||||
}
|
||||
|
||||
n.LastChanged = _now
|
||||
n.bkt.UpdateLastChangedTime(_now)
|
||||
return
|
||||
}
|
||||
|
||||
func (n *wrappedNode) IsBad(now ...time.Time) bool {
|
||||
var _now time.Time
|
||||
if len(now) > 0 {
|
||||
_now = now[0]
|
||||
} else {
|
||||
_now = time.Now()
|
||||
}
|
||||
return _now.Sub(n.LastChanged) > badTimeout
|
||||
}
|
||||
|
||||
func (n *wrappedNode) IsDubious(now ...time.Time) bool {
|
||||
var _now time.Time
|
||||
if len(now) > 0 {
|
||||
_now = now[0]
|
||||
} else {
|
||||
_now = time.Now()
|
||||
}
|
||||
return _now.Sub(n.LastChanged) > dubiousTimeout
|
||||
}
|
||||
|
||||
func (n *wrappedNode) Status(now time.Time) int {
|
||||
duration := now.Sub(n.LastChanged)
|
||||
switch {
|
||||
case duration > badTimeout:
|
||||
return nodeStatusBad
|
||||
case duration > dubiousTimeout:
|
||||
return nodeStatusDubious
|
||||
default:
|
||||
return nodeStatusGood
|
||||
}
|
||||
}
|
42
dht/routing_table_storage.go
Normal file
42
dht/routing_table_storage.go
Normal file
@ -0,0 +1,42 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package dht
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/xgfone/bt/krpc"
|
||||
"github.com/xgfone/bt/metainfo"
|
||||
)
|
||||
|
||||
// RoutingTableNode represents the node with last changed time in the routing table.
|
||||
type RoutingTableNode struct {
|
||||
Node krpc.Node
|
||||
LastChanged time.Time
|
||||
}
|
||||
|
||||
// RoutingTableStorage is used to store the nodes in the routing table.
|
||||
type RoutingTableStorage interface {
|
||||
Load(ownid metainfo.Hash, ipv6 bool) (nodes []RoutingTableNode, err error)
|
||||
Dump(ownid metainfo.Hash, nodes []RoutingTableNode, ipv6 bool) (err error)
|
||||
}
|
||||
|
||||
// NewNoopRoutingTableStorage returns a no-op RoutingTableStorage.
|
||||
func NewNoopRoutingTableStorage() RoutingTableStorage { return noopStorage{} }
|
||||
|
||||
type noopStorage struct{}
|
||||
|
||||
func (s noopStorage) Load(metainfo.Hash, bool) (nodes []RoutingTableNode, err error) { return }
|
||||
func (s noopStorage) Dump(metainfo.Hash, []RoutingTableNode, bool) (err error) { return }
|
107
dht/token_manager.go
Normal file
107
dht/token_manager.go
Normal file
@ -0,0 +1,107 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package dht
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/xgfone/bt/utils"
|
||||
)
|
||||
|
||||
// TokenManager is used to manage and validate the token.
|
||||
//
|
||||
// TODO: Should we allocate the different token for each node??
|
||||
type tokenManager struct {
|
||||
lock sync.RWMutex
|
||||
last string
|
||||
new string
|
||||
exit chan struct{}
|
||||
addrs sync.Map
|
||||
}
|
||||
|
||||
func newTokenManager() *tokenManager {
|
||||
token := utils.RandomString(8)
|
||||
return &tokenManager{last: token, new: token, exit: make(chan struct{})}
|
||||
}
|
||||
|
||||
func (tm *tokenManager) updateToken() {
|
||||
token := utils.RandomString(8)
|
||||
tm.lock.Lock()
|
||||
tm.last, tm.new = tm.new, token
|
||||
tm.lock.Unlock()
|
||||
}
|
||||
|
||||
func (tm *tokenManager) clear(now time.Time, expired time.Duration) {
|
||||
tm.addrs.Range(func(k interface{}, v interface{}) bool {
|
||||
if now.Sub(v.(time.Time)) >= expired {
|
||||
tm.addrs.Delete(k)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// Start starts the token manager.
|
||||
func (tm *tokenManager) Start(expired time.Duration) {
|
||||
tick1 := time.NewTicker(expired)
|
||||
tick2 := time.NewTicker(expired / 2)
|
||||
defer tick1.Stop()
|
||||
defer tick2.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-tm.exit:
|
||||
return
|
||||
case <-tick2.C:
|
||||
tm.updateToken()
|
||||
case now := <-tick1.C:
|
||||
tm.clear(now, expired)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the token manager.
|
||||
func (tm *tokenManager) Stop() {
|
||||
select {
|
||||
case <-tm.exit:
|
||||
default:
|
||||
close(tm.exit)
|
||||
}
|
||||
}
|
||||
|
||||
// Token allocates a token for a node addr and returns the token.
|
||||
func (tm *tokenManager) Token(addr *net.UDPAddr) (token string) {
|
||||
addrs := addr.String()
|
||||
tm.lock.RLock()
|
||||
token = tm.new
|
||||
tm.lock.RUnlock()
|
||||
tm.addrs.Store(addrs, time.Now())
|
||||
return
|
||||
}
|
||||
|
||||
// Check checks whether the token associated with the node addr is valid,
|
||||
// that's, it's not expired.
|
||||
func (tm *tokenManager) Check(addr *net.UDPAddr, token string) (ok bool) {
|
||||
tm.lock.RLock()
|
||||
last, new := tm.last, tm.new
|
||||
tm.lock.RUnlock()
|
||||
|
||||
if last == token || new == token {
|
||||
_, ok = tm.addrs.Load(addr.String())
|
||||
}
|
||||
|
||||
return
|
||||
}
|
152
dht/transaction_manager.go
Normal file
152
dht/transaction_manager.go
Normal file
@ -0,0 +1,152 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package dht
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/xgfone/bt/krpc"
|
||||
"github.com/xgfone/bt/metainfo"
|
||||
)
|
||||
|
||||
type transaction struct {
|
||||
ID string
|
||||
Query string
|
||||
Arg krpc.QueryArg
|
||||
Addr *net.UDPAddr
|
||||
Time time.Time
|
||||
Depth int
|
||||
|
||||
Visited metainfo.Hashes
|
||||
Callback func(Result)
|
||||
OnError func(t *transaction, code int, reason string)
|
||||
OnTimeout func(t *transaction)
|
||||
OnResponse func(t *transaction, radd *net.UDPAddr, msg krpc.Message)
|
||||
}
|
||||
|
||||
func (t *transaction) Done(r Result) {
|
||||
if t.Callback != nil {
|
||||
r.Addr = t.Addr
|
||||
t.Callback(r)
|
||||
}
|
||||
}
|
||||
|
||||
func noopResponse(*transaction, *net.UDPAddr, krpc.Message) {}
|
||||
func newTransaction(s *Server, a *net.UDPAddr, q string, qa krpc.QueryArg,
|
||||
callback ...func(Result)) *transaction {
|
||||
var cb func(Result)
|
||||
if len(callback) > 0 {
|
||||
cb = callback[0]
|
||||
}
|
||||
|
||||
return &transaction{
|
||||
Addr: a,
|
||||
Query: q,
|
||||
Arg: qa,
|
||||
Callback: cb,
|
||||
OnError: s.onError,
|
||||
OnTimeout: s.onTimeout,
|
||||
OnResponse: noopResponse,
|
||||
Time: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
type transactionkey struct {
|
||||
id string
|
||||
addr string
|
||||
}
|
||||
|
||||
type transactionManager struct {
|
||||
lock sync.Mutex
|
||||
exit chan struct{}
|
||||
trans map[transactionkey]*transaction
|
||||
tid uint32
|
||||
}
|
||||
|
||||
func newTransactionManager() *transactionManager {
|
||||
return &transactionManager{
|
||||
exit: make(chan struct{}),
|
||||
trans: make(map[transactionkey]*transaction, 128),
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the transaction manager.
|
||||
func (tm *transactionManager) Start(s *Server, interval time.Duration) {
|
||||
tick := time.NewTicker(interval)
|
||||
defer tick.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-tm.exit:
|
||||
return
|
||||
case now := <-tick.C:
|
||||
tm.lock.Lock()
|
||||
for k, t := range tm.trans {
|
||||
if now.Sub(t.Time) > interval {
|
||||
delete(tm.trans, k)
|
||||
t.OnTimeout(t)
|
||||
}
|
||||
}
|
||||
tm.lock.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the transaction manager.
|
||||
func (tm *transactionManager) Stop() {
|
||||
select {
|
||||
case <-tm.exit:
|
||||
default:
|
||||
close(tm.exit)
|
||||
}
|
||||
}
|
||||
|
||||
// GetTransactionID returns a new transaction id.
|
||||
func (tm *transactionManager) GetTransactionID() string {
|
||||
return strconv.FormatUint(uint64(atomic.AddUint32(&tm.tid, 1)), 36)
|
||||
}
|
||||
|
||||
// AddTransaction adds the new transaction.
|
||||
func (tm *transactionManager) AddTransaction(t *transaction) {
|
||||
key := transactionkey{id: t.ID, addr: t.Addr.String()}
|
||||
tm.lock.Lock()
|
||||
tm.trans[key] = t
|
||||
tm.lock.Unlock()
|
||||
}
|
||||
|
||||
// DeleteTransaction deletes the transaction.
|
||||
func (tm *transactionManager) DeleteTransaction(t *transaction) {
|
||||
key := transactionkey{id: t.ID, addr: t.Addr.String()}
|
||||
tm.lock.Lock()
|
||||
delete(tm.trans, key)
|
||||
tm.lock.Unlock()
|
||||
}
|
||||
|
||||
// PopTransaction deletes and returns the transaction by the transaction id
|
||||
// and the peer address.
|
||||
//
|
||||
// Return nil if there is no the transaction.
|
||||
func (tm *transactionManager) PopTransaction(tid string, addr *net.UDPAddr) (t *transaction) {
|
||||
key := transactionkey{id: tid, addr: addr.String()}
|
||||
tm.lock.Lock()
|
||||
if t = tm.trans[key]; t != nil {
|
||||
delete(tm.trans, key)
|
||||
}
|
||||
tm.lock.Unlock()
|
||||
return
|
||||
}
|
315
downloader/torrent_info.go
Normal file
315
downloader/torrent_info.go
Normal file
@ -0,0 +1,315 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package downloader is used to download the torrent or the real file
|
||||
// from the peer node by the peer wire protocol.
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha1"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
"github.com/xgfone/bt/metainfo"
|
||||
"github.com/xgfone/bt/peerprotocol"
|
||||
)
|
||||
|
||||
const peerBlockSize = 16384 // 16KiB.
|
||||
|
||||
// Request is used to send a download request.
|
||||
type request struct {
|
||||
Host string
|
||||
Port uint16
|
||||
PeerID metainfo.Hash
|
||||
InfoHash metainfo.Hash
|
||||
}
|
||||
|
||||
// TorrentResponse represents a torrent info response.
|
||||
type TorrentResponse struct {
|
||||
Host string // Which host the torrent is downloaded from
|
||||
Port uint16 // Which port the torrent is downloaded from
|
||||
PeerID metainfo.Hash // ID of the peer where torrent is downloaded from
|
||||
InfoHash metainfo.Hash // The SHA-1 hash of the torrent to be downloaded
|
||||
InfoBytes []byte // The content of the info part in the torrent
|
||||
}
|
||||
|
||||
// TorrentDownloaderConfig is used to configure the TorrentDownloader.
|
||||
type TorrentDownloaderConfig struct {
|
||||
// ID is the id of the downloader peer node.
|
||||
//
|
||||
// The default is a random id.
|
||||
ID metainfo.Hash
|
||||
|
||||
// WorkerNum is the number of the worker downloading the torrent concurrently.
|
||||
//
|
||||
// The default is 128.
|
||||
WorkerNum int
|
||||
|
||||
// ErrorLog is used to log the error.
|
||||
//
|
||||
// The default is log.Printf.
|
||||
ErrorLog func(format string, args ...interface{})
|
||||
}
|
||||
|
||||
func (c *TorrentDownloaderConfig) set(conf ...TorrentDownloaderConfig) {
|
||||
if len(conf) > 0 {
|
||||
*c = conf[0]
|
||||
}
|
||||
|
||||
if c.WorkerNum <= 0 {
|
||||
c.WorkerNum = 128
|
||||
}
|
||||
if c.ID.IsZero() {
|
||||
c.ID = metainfo.NewRandomHash()
|
||||
}
|
||||
if c.ErrorLog == nil {
|
||||
c.ErrorLog = log.Printf
|
||||
}
|
||||
}
|
||||
|
||||
// TorrentDownloader is used to download the torrent file from the peer.
|
||||
type TorrentDownloader struct {
|
||||
conf TorrentDownloaderConfig
|
||||
exit chan struct{}
|
||||
requests chan request
|
||||
responses chan TorrentResponse
|
||||
|
||||
ondht func(string, uint16)
|
||||
ebits peerprotocol.ExtensionBits
|
||||
ehmsg peerprotocol.ExtendedHandshakeMsg
|
||||
}
|
||||
|
||||
// NewTorrentDownloader returns a new TorrentDownloader.
|
||||
//
|
||||
// If id is ZERO, it is reset to a random id. workerNum is 128 by default.
|
||||
func NewTorrentDownloader(c ...TorrentDownloaderConfig) *TorrentDownloader {
|
||||
var conf TorrentDownloaderConfig
|
||||
conf.set(c...)
|
||||
|
||||
d := &TorrentDownloader{
|
||||
conf: conf,
|
||||
exit: make(chan struct{}),
|
||||
requests: make(chan request, conf.WorkerNum),
|
||||
responses: make(chan TorrentResponse, 1024),
|
||||
|
||||
ehmsg: peerprotocol.ExtendedHandshakeMsg{
|
||||
M: map[string]uint8{peerprotocol.ExtendedMessageNameMetadata: 1},
|
||||
},
|
||||
}
|
||||
|
||||
for i := 0; i < conf.WorkerNum; i++ {
|
||||
go d.worker()
|
||||
}
|
||||
|
||||
d.ebits.Set(peerprotocol.ExtensionBitExtended)
|
||||
return d
|
||||
}
|
||||
|
||||
// Request submits a download request.
|
||||
func (d *TorrentDownloader) Request(host string, port uint16, infohash metainfo.Hash) {
|
||||
d.requests <- request{Host: host, Port: port, InfoHash: infohash}
|
||||
}
|
||||
|
||||
// Response returns a response channel to get the downloaded torrent info.
|
||||
func (d *TorrentDownloader) Response() <-chan TorrentResponse {
|
||||
return d.responses
|
||||
}
|
||||
|
||||
// Close closes the downloader and releases the underlying resources.
|
||||
func (d *TorrentDownloader) Close() {
|
||||
select {
|
||||
case <-d.exit:
|
||||
default:
|
||||
close(d.exit)
|
||||
}
|
||||
}
|
||||
|
||||
// OnDHTNode sets the DHT node callback, which will enable DHT extenstion bit.
|
||||
//
|
||||
// In the callback function, you maybe ping it like DHT by UDP.
|
||||
// If the node responds, you can add the node in DHT routing table.
|
||||
//
|
||||
// BEP 5
|
||||
func (d *TorrentDownloader) OnDHTNode(cb func(host string, port uint16)) {
|
||||
d.ondht = cb
|
||||
if cb == nil {
|
||||
d.ebits.Unset(peerprotocol.ExtensionBitDHT)
|
||||
} else {
|
||||
d.ebits.Set(peerprotocol.ExtensionBitDHT)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *TorrentDownloader) worker() {
|
||||
for {
|
||||
select {
|
||||
case <-d.exit:
|
||||
return
|
||||
case r := <-d.requests:
|
||||
if err := d.download(r.Host, r.Port, r.PeerID, r.InfoHash); err != nil {
|
||||
d.conf.ErrorLog("fail to download the torrent '%s': %s",
|
||||
r.InfoHash.HexString(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *TorrentDownloader) download(host string, port uint16,
|
||||
peerID, infohash metainfo.Hash) (err error) {
|
||||
addr := net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10))
|
||||
conn, err := peerprotocol.NewPeerConnByDial(d.conf.ID, addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fail to dial to '%s': %s", addr, err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
conn.ExtensionBits = d.ebits
|
||||
rmsg, err := conn.Handshake(infohash)
|
||||
if err != nil || rmsg.InfoHash != infohash || !rmsg.IsSupportExtended() {
|
||||
return
|
||||
} else if !peerID.IsZero() && peerID != rmsg.PeerID {
|
||||
return fmt.Errorf("inconsistent peer id '%s'", rmsg.PeerID.HexString())
|
||||
}
|
||||
|
||||
if err = conn.SendExtHandshakeMsg(d.ehmsg); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var pieces [][]byte
|
||||
var piecesNum int
|
||||
var metadataSize int
|
||||
var utmetadataID uint8
|
||||
var msg peerprotocol.Message
|
||||
|
||||
for {
|
||||
if msg, err = conn.ReadMsg(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-d.exit:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if msg.Keepalive {
|
||||
continue
|
||||
}
|
||||
|
||||
switch msg.Type {
|
||||
case peerprotocol.Extended:
|
||||
case peerprotocol.Port:
|
||||
if d.ondht != nil {
|
||||
d.ondht(host, msg.Port)
|
||||
}
|
||||
continue
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
switch msg.ExtendedID {
|
||||
case peerprotocol.ExtendedIDHandshake:
|
||||
if utmetadataID > 0 {
|
||||
return fmt.Errorf("rehandshake from the peer '%s'", conn.RemoteAddr().String())
|
||||
}
|
||||
|
||||
var ehmsg peerprotocol.ExtendedHandshakeMsg
|
||||
if err = ehmsg.Decode(msg.ExtendedPayload); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
utmetadataID = ehmsg.M[peerprotocol.ExtendedMessageNameMetadata]
|
||||
if utmetadataID == 0 {
|
||||
return errors.New(`the peer does not support "ut_metadata"`)
|
||||
}
|
||||
|
||||
metadataSize = ehmsg.MetadataSize
|
||||
piecesNum = metadataSize / peerBlockSize
|
||||
if metadataSize%peerBlockSize != 0 {
|
||||
piecesNum++
|
||||
}
|
||||
|
||||
pieces = make([][]byte, piecesNum)
|
||||
go d.requestPieces(conn, utmetadataID, piecesNum)
|
||||
case 1:
|
||||
if pieces == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var utmsg peerprotocol.UtMetadataExtendedMsg
|
||||
if err = utmsg.DecodeFromPayload(msg.ExtendedPayload); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if utmsg.MsgType != peerprotocol.UtMetadataExtendedMsgTypeData {
|
||||
continue
|
||||
}
|
||||
|
||||
pieceLen := len(utmsg.Data)
|
||||
if (utmsg.Piece != piecesNum-1 && pieceLen != peerBlockSize) ||
|
||||
(utmsg.Piece == piecesNum-1 && pieceLen != metadataSize%peerBlockSize) {
|
||||
return
|
||||
}
|
||||
pieces[utmsg.Piece] = utmsg.Data
|
||||
|
||||
finish := true
|
||||
for _, piece := range pieces {
|
||||
if len(piece) == 0 {
|
||||
finish = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if finish {
|
||||
metadataInfo := bytes.Join(pieces, nil)
|
||||
if infohash == metainfo.Hash(sha1.Sum(metadataInfo)) {
|
||||
d.responses <- TorrentResponse{
|
||||
Host: host,
|
||||
Port: port,
|
||||
PeerID: rmsg.PeerID,
|
||||
InfoHash: infohash,
|
||||
InfoBytes: metadataInfo,
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *TorrentDownloader) requestPieces(conn *peerprotocol.PeerConn, utMetadataID uint8, piecesNum int) {
|
||||
for i := 0; i < piecesNum; i++ {
|
||||
payload, err := peerprotocol.UtMetadataExtendedMsg{
|
||||
MsgType: peerprotocol.UtMetadataExtendedMsgTypeRequest,
|
||||
Piece: i,
|
||||
}.EncodeToBytes()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
msg := peerprotocol.Message{
|
||||
Type: peerprotocol.Extended,
|
||||
ExtendedID: utMetadataID,
|
||||
ExtendedPayload: payload,
|
||||
}
|
||||
|
||||
if err = conn.WriteMsg(msg); err != nil {
|
||||
d.conf.ErrorLog("fail to send the ut_metadata request: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
16
krpc/doc.go
Normal file
16
krpc/doc.go
Normal file
@ -0,0 +1,16 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package krpc supplies the KRPC message used by DHT.
|
||||
package krpc
|
440
krpc/message.go
Normal file
440
krpc/message.go
Normal file
@ -0,0 +1,440 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package krpc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/xgfone/bt/bencode"
|
||||
"github.com/xgfone/bt/metainfo"
|
||||
)
|
||||
|
||||
// Predefine some error code.
|
||||
const (
|
||||
// BEP 5
|
||||
ErrorCodeGeneric = 201
|
||||
ErrorCodeServer = 202
|
||||
ErrorCodeProtocol = 203
|
||||
ErrorCodeMethodUnknown = 204
|
||||
|
||||
// BEP 44
|
||||
ErrorCodeMessageValueFieldTooBig = 205
|
||||
ErrorCodeInvalidSignature = 206
|
||||
ErrorCodeSaltFieldTooBig = 207
|
||||
ErrorCodeCasHashMismatched = 301
|
||||
ErrorCodeSequenceNumberLessThanCurrent = 302
|
||||
)
|
||||
|
||||
// Message represents messages that nodes in the network send to each other
|
||||
// as specified by the KRPC protocol, which are also referred to as the KRPC
|
||||
// messages.
|
||||
//
|
||||
// There are three types of messages: QUERY, RESPONSE, ERROR
|
||||
// The message is a dictonary that is then "bencoded"
|
||||
// (serialization & compression format adopted by the BitTorrent)
|
||||
// and sent via the UDP connection to peers.
|
||||
//
|
||||
// A KRPC message is a single dictionary with two keys common to every message
|
||||
// and additional keys depending on the type of message. Every message has a key
|
||||
// "t" with a string value representing a transaction ID. This transaction ID
|
||||
// is generated by the querying node and is echoed in the response, so responses
|
||||
// may be correlated with multiple queries to the same node. The transaction ID
|
||||
// should be encoded as a short string of binary numbers, typically 2 characters
|
||||
// are enough as they cover 2^16 outstanding queries. The other key contained
|
||||
// in every KRPC message is "y" with a single character value describing
|
||||
// the type of message. The value of the "y" key is one of "q" for query,
|
||||
// "r" for response, or "e" for error.
|
||||
type Message struct {
|
||||
T string `bencode:"t"` // required: transaction ID
|
||||
Y string `bencode:"y"` // required: type of the message: q for QUERY, r for RESPONSE, e for ERROR
|
||||
Q string `bencode:"q,omitempty"` // Query method (one of 4: "ping", "find_node", "get_peers", "announce_peer")
|
||||
A QueryArg `bencode:"a,omitempty"` // named arguments sent with a query
|
||||
R ResponseResult `bencode:"r,omitempty"` // RESPONSE type only
|
||||
E Error `bencode:"e,omitempty"` // ERROR type only
|
||||
|
||||
RO bool `bencode:"ro,omitempty"` // BEP 43: ReadOnly
|
||||
}
|
||||
|
||||
// NewQueryMsg return a QUERY message.
|
||||
func NewQueryMsg(tid, method string, arg QueryArg) Message {
|
||||
return Message{T: tid, Y: "q", Q: method, A: arg}
|
||||
}
|
||||
|
||||
// NewResponseMsg return a RESPONSE message.
|
||||
func NewResponseMsg(tid string, data ResponseResult) Message {
|
||||
return Message{T: tid, Y: "r", R: data}
|
||||
}
|
||||
|
||||
// NewErrorMsg return a ERROR message.
|
||||
func NewErrorMsg(tid string, code int, reason string) Message {
|
||||
return Message{T: tid, Y: "e", E: Error{Code: code, Reason: reason}}
|
||||
}
|
||||
|
||||
// IsQuery reports whether the message is an QUERY.
|
||||
func (m Message) IsQuery() bool {
|
||||
return m.Y == "q"
|
||||
}
|
||||
|
||||
// IsResponse reports whether the message is an RESPONSE.
|
||||
func (m Message) IsResponse() bool {
|
||||
return m.Y == "r"
|
||||
}
|
||||
|
||||
// IsError reports whether the message is an ERROR.
|
||||
func (m Message) IsError() bool {
|
||||
return m.Y == "e"
|
||||
}
|
||||
|
||||
// RID returns the value named "id" in "r".
|
||||
//
|
||||
// Return "" instead if no "id".
|
||||
func (m Message) RID() metainfo.Hash {
|
||||
return m.R.ID
|
||||
}
|
||||
|
||||
// QID returns the value named "id" in "a", that's, the query arguments.
|
||||
//
|
||||
// Return "" instead if no "id".
|
||||
func (m Message) QID() metainfo.Hash {
|
||||
return m.A.ID
|
||||
}
|
||||
|
||||
// ID returns the QID or RID.
|
||||
func (m Message) ID() metainfo.Hash {
|
||||
switch m.Y {
|
||||
case "q":
|
||||
return m.QID()
|
||||
case "r":
|
||||
return m.RID()
|
||||
default:
|
||||
panic(fmt.Errorf("unknown message type '%s'", m.Y))
|
||||
}
|
||||
}
|
||||
|
||||
// Error represents a response error.
|
||||
type Error struct {
|
||||
Code int
|
||||
Reason string
|
||||
}
|
||||
|
||||
// NewError returns a new Error.
|
||||
func NewError(code int, reason string) Error {
|
||||
return Error{Code: code, Reason: reason}
|
||||
}
|
||||
|
||||
func (e *Error) decode(vs []interface{}) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("unpacking %#v: %v", vs, r)
|
||||
}
|
||||
}()
|
||||
|
||||
e.Code = int(vs[0].(int64))
|
||||
e.Reason = vs[1].(string)
|
||||
return
|
||||
}
|
||||
|
||||
// UnmarshalBencode implements the interface bencode.Unmarshaler.
|
||||
func (e *Error) UnmarshalBencode(b []byte) (err error) {
|
||||
var iface interface{}
|
||||
if err = bencode.NewDecoder(bytes.NewBuffer(b)).Decode(&iface); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch v := iface.(type) {
|
||||
case []interface{}:
|
||||
err = e.decode(v)
|
||||
case string:
|
||||
e.Reason = v
|
||||
default:
|
||||
err = fmt.Errorf(`KRPC error bencode value has unexpected type: %T`, v)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalBencode implements the interface bencode.Marshaler.
|
||||
func (e Error) MarshalBencode() (ret []byte, err error) {
|
||||
if e.Code == 0 && e.Reason == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
buf := bytes.NewBuffer(nil)
|
||||
buf.Grow(32)
|
||||
err = bencode.NewEncoder(buf).Encode([]interface{}{e.Code, e.Reason})
|
||||
if err != nil {
|
||||
ret = buf.Bytes()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (e Error) Error() string {
|
||||
return fmt.Sprintf("KRPC error %d: %s", e.Code, e.Reason)
|
||||
}
|
||||
|
||||
// Want represents the type of nodes that the request wants.
|
||||
//
|
||||
// BEP 32
|
||||
type Want string
|
||||
|
||||
// Predefine some wants.
|
||||
//
|
||||
// BEP 32
|
||||
const (
|
||||
WantNodes Want = "n4"
|
||||
WantNodes6 Want = "n6"
|
||||
)
|
||||
|
||||
// QueryArg represents the arguments used by the QUERY message.
|
||||
type QueryArg struct {
|
||||
// ID is used to identify a querying node.
|
||||
ID metainfo.Hash `bencode:"id"` // BEP 5
|
||||
|
||||
// Target is used to identify the node sought by the queryer.
|
||||
//
|
||||
// find_node
|
||||
Target metainfo.Hash `bencode:"target,omitempty"` // BEP 5
|
||||
|
||||
// InfoHash is the infohash of the torrent file.
|
||||
//
|
||||
// get_peers, announce_peer
|
||||
InfoHash metainfo.Hash `bencode:"info_hash,omitempty"` // BEP 5
|
||||
|
||||
// Port is the port on where sender is listening to allow others
|
||||
// to download the torrent.
|
||||
//
|
||||
// announce_peer
|
||||
Port uint16 `bencode:"port,omitempty"` // BEP 5
|
||||
|
||||
// Token is the received one from an earlier "get_peers" query.
|
||||
//
|
||||
// announce_peer
|
||||
Token string `bencode:"token,omitempty"` // BEP 5
|
||||
|
||||
// ImpliedPort is used by the senders to apparent DHT port
|
||||
// to improve NAT support.
|
||||
//
|
||||
// announce_peer
|
||||
ImpliedPort bool `bencode:"implied_port,omitempty"` // BEP 5
|
||||
|
||||
// Wants is only used to represent to expect "nodes" for "n4"
|
||||
// or "nodes6" for "n6".
|
||||
//
|
||||
// Notice: It only governs the presence of the "nodes" and "nodes6"
|
||||
// parameters, not the interpretation of "values".
|
||||
//
|
||||
// find_node, get_peers
|
||||
Wants []Want `bencode:"want,omitempty"` // BEP 32
|
||||
}
|
||||
|
||||
// ContainsWant reports whether the request contains the given Want.
|
||||
func (a QueryArg) ContainsWant(w Want) bool {
|
||||
for _, want := range a.Wants {
|
||||
if want == w {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetPort returns the real port of the peer.
|
||||
func (a QueryArg) GetPort(port int) uint16 {
|
||||
if a.ImpliedPort {
|
||||
return uint16(port)
|
||||
}
|
||||
return a.Port
|
||||
}
|
||||
|
||||
// ResponseResult represents the results used by the RESPONSE message.
|
||||
type ResponseResult struct {
|
||||
// ID is used to indentify the queried node, that's, the response node.
|
||||
ID metainfo.Hash `bencode:"id"` // BEP 5
|
||||
|
||||
// Nodes is a string containing the compact node information for the list
|
||||
// of the ipv4 target node, or the K(8) closest good nodes in routing table
|
||||
// of the requested ipv4 target.
|
||||
//
|
||||
// find_node
|
||||
Nodes CompactIPv4Node `bencode:"nodes,omitempty"` // BEP 5
|
||||
|
||||
// Nodes6 is a string containing the compact node information for the list
|
||||
// of the ipv6 target node, or the K(8) closest good nodes in routing table
|
||||
// of the requested ipv6 target.
|
||||
//
|
||||
// find_node
|
||||
Nodes6 CompactIPv6Node `bencode:"nodes6,omitempty"` // BEP 32
|
||||
|
||||
// Token is used for future "announce_peer".
|
||||
//
|
||||
// get_peers
|
||||
Token string `bencode:"token,omitempty"` // BEP 5
|
||||
|
||||
// Values is a list of the torrent peers.
|
||||
//
|
||||
// get_peers
|
||||
Values CompactAddresses `bencode:"values,omitempty"` // BEP 5
|
||||
}
|
||||
|
||||
/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
|
||||
// CompactAddresses represents a group of the compact addresses.
|
||||
type CompactAddresses []metainfo.Address
|
||||
|
||||
// MarshalBinary implements the interface binary.BinaryMarshaler.
|
||||
func (cas CompactAddresses) MarshalBinary() ([]byte, error) {
|
||||
ss := make([]string, len(cas))
|
||||
for i, addr := range cas {
|
||||
data, err := addr.MarshalBinary()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ss[i] = string(data)
|
||||
}
|
||||
return bencode.EncodeBytes(ss)
|
||||
}
|
||||
|
||||
// UnmarshalBinary implements the interface binary.BinaryUnmarshaler.
|
||||
func (cas *CompactAddresses) UnmarshalBinary(b []byte) (err error) {
|
||||
var ss []string
|
||||
if err = bencode.DecodeBytes(b, &ss); err == nil {
|
||||
addrs := make(CompactAddresses, len(ss))
|
||||
for i, s := range ss {
|
||||
var addr metainfo.Address
|
||||
if err = addr.UnmarshalBinary([]byte(s)); err != nil {
|
||||
return
|
||||
}
|
||||
addrs[i] = addr
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
|
||||
// CompactIPv4Node is a set of IPv4 Nodes.
|
||||
type CompactIPv4Node []Node
|
||||
|
||||
// MarshalBinary implements the interface binary.BinaryMarshaler.
|
||||
func (cn CompactIPv4Node) MarshalBinary() ([]byte, error) {
|
||||
buf := bytes.NewBuffer(nil)
|
||||
buf.Grow(26 * len(cn))
|
||||
for _, ni := range cn {
|
||||
if ni.Addr.IP = ni.Addr.IP.To4(); len(ni.Addr.IP) == 0 {
|
||||
continue
|
||||
}
|
||||
if n, err := ni.WriteBinary(buf); err != nil {
|
||||
return nil, err
|
||||
} else if n != 26 {
|
||||
panic(fmt.Errorf("CompactIPv4NodeInfo: the invalid NodeInfo length '%d'", n))
|
||||
}
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// UnmarshalBinary implements the interface binary.BinaryUnmarshaler.
|
||||
func (cn *CompactIPv4Node) UnmarshalBinary(b []byte) (err error) {
|
||||
_len := len(b)
|
||||
if _len%26 != 0 {
|
||||
return fmt.Errorf("CompactIPv4NodeInfo: invalid bytes length '%d'", _len)
|
||||
}
|
||||
|
||||
nis := make([]Node, 0, _len/26)
|
||||
for i := 0; i < _len; i += 26 {
|
||||
var ni Node
|
||||
if err = ni.UnmarshalBinary(b[i : i+26]); err != nil {
|
||||
return
|
||||
}
|
||||
nis = append(nis, ni)
|
||||
}
|
||||
|
||||
*cn = nis
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalBencode implements the interface bencode.Marshaler.
|
||||
func (cn CompactIPv4Node) MarshalBencode() (b []byte, err error) {
|
||||
if b, err = cn.MarshalBinary(); err == nil {
|
||||
b, err = bencode.EncodeBytes(b)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// UnmarshalBencode implements the interface bencode.Unmarshaler.
|
||||
func (cn *CompactIPv4Node) UnmarshalBencode(b []byte) (err error) {
|
||||
var s string
|
||||
if err = bencode.DecodeBytes(b, &s); err == nil {
|
||||
err = cn.UnmarshalBinary([]byte(s))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
|
||||
// CompactIPv6Node is a set of IPv6 Nodes.
|
||||
type CompactIPv6Node []Node
|
||||
|
||||
// MarshalBinary implements the interface binary.BinaryMarshaler.
|
||||
func (cn CompactIPv6Node) MarshalBinary() ([]byte, error) {
|
||||
buf := bytes.NewBuffer(nil)
|
||||
buf.Grow(38 * len(cn))
|
||||
for _, ni := range cn {
|
||||
ni.Addr.IP = ni.Addr.IP.To16()
|
||||
if n, err := ni.WriteBinary(buf); err != nil {
|
||||
return nil, err
|
||||
} else if n != 38 {
|
||||
panic(fmt.Errorf("CompactIPv6NodeInfo: the invalid NodeInfo length '%d'", n))
|
||||
}
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// UnmarshalBinary implements the interface binary.BinaryUnmarshaler.
|
||||
func (cn *CompactIPv6Node) UnmarshalBinary(b []byte) (err error) {
|
||||
_len := len(b)
|
||||
if _len%38 != 0 {
|
||||
return fmt.Errorf("CompactIPv6NodeInfo: invalid bytes length '%d'", _len)
|
||||
}
|
||||
|
||||
nis := make([]Node, 0, _len/38)
|
||||
for i := 0; i < _len; i += 38 {
|
||||
var ni Node
|
||||
if err = ni.UnmarshalBinary(b[i : i+38]); err != nil {
|
||||
return
|
||||
}
|
||||
nis = append(nis, ni)
|
||||
}
|
||||
|
||||
*cn = nis
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalBencode implements the interface bencode.Marshaler.
|
||||
func (cn CompactIPv6Node) MarshalBencode() (b []byte, err error) {
|
||||
if b, err = cn.MarshalBinary(); err == nil {
|
||||
b, err = bencode.EncodeBytes(b)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// UnmarshalBencode implements the interface bencode.Unmarshaler.
|
||||
func (cn *CompactIPv6Node) UnmarshalBencode(b []byte) (err error) {
|
||||
var s string
|
||||
if err = bencode.DecodeBytes(b, &s); err == nil {
|
||||
err = cn.UnmarshalBinary([]byte(s))
|
||||
}
|
||||
return
|
||||
}
|
41
krpc/message_test.go
Normal file
41
krpc/message_test.go
Normal file
@ -0,0 +1,41 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package krpc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/xgfone/bt/bencode"
|
||||
)
|
||||
|
||||
func TestMessage(t *testing.T) {
|
||||
s, err := bencode.EncodeString(Message{RO: true})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var ms map[string]interface{}
|
||||
if err := bencode.DecodeString(s, &ms); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if len(ms) != 3 {
|
||||
t.Fatal()
|
||||
} else if v, ok := ms["t"]; !ok || v != "" {
|
||||
t.Fatal()
|
||||
} else if v, ok := ms["y"]; !ok || v != "" {
|
||||
t.Fatal()
|
||||
} else if v, ok := ms["ro"].(int64); !ok || v != 1 {
|
||||
t.Fatal()
|
||||
}
|
||||
}
|
84
krpc/node.go
Normal file
84
krpc/node.go
Normal file
@ -0,0 +1,84 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package krpc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/xgfone/bt/metainfo"
|
||||
)
|
||||
|
||||
// Node represents a node information.
|
||||
type Node struct {
|
||||
ID metainfo.Hash
|
||||
Addr metainfo.Address
|
||||
}
|
||||
|
||||
// NewNode returns a new Node.
|
||||
func NewNode(id metainfo.Hash, ip net.IP, port int) Node {
|
||||
return Node{ID: id, Addr: metainfo.NewAddress(ip, uint16(port))}
|
||||
}
|
||||
|
||||
// NewNodeByUDPAddr returns a new Node with the id and the UDP address.
|
||||
func NewNodeByUDPAddr(id metainfo.Hash, addr *net.UDPAddr) (n Node) {
|
||||
n.ID = id
|
||||
n.Addr.FromUDPAddr(addr)
|
||||
return
|
||||
}
|
||||
|
||||
func (n Node) String() string {
|
||||
return fmt.Sprintf("Node<%x@%s>", n.ID, n.Addr)
|
||||
}
|
||||
|
||||
// Equal reports whether n is equal to o.
|
||||
func (n Node) Equal(o Node) bool {
|
||||
return n.ID == o.ID && n.Addr.Equal(o.Addr)
|
||||
}
|
||||
|
||||
// WriteBinary is the same as MarshalBinary, but writes the result into w
|
||||
// instead of returning.
|
||||
func (n Node) WriteBinary(w io.Writer) (m int, err error) {
|
||||
var n1, n2 int
|
||||
if n1, err = w.Write(n.ID[:]); err == nil {
|
||||
m = n1
|
||||
if n2, err = n.Addr.WriteBinary(w); err == nil {
|
||||
m += n2
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalBinary implements the interface binary.BinaryMarshaler.
|
||||
func (n Node) MarshalBinary() (data []byte, err error) {
|
||||
buf := bytes.NewBuffer(nil)
|
||||
buf.Grow(48)
|
||||
if _, err = n.WriteBinary(buf); err == nil {
|
||||
data = buf.Bytes()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// UnmarshalBinary implements the interface binary.BinaryUnmarshaler.
|
||||
func (n *Node) UnmarshalBinary(b []byte) error {
|
||||
if len(b) < 26 {
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
|
||||
copy(n.ID[:], b[:20])
|
||||
return n.Addr.UnmarshalBinary(b[20:])
|
||||
}
|
342
metainfo/address.go
Normal file
342
metainfo/address.go
Normal file
@ -0,0 +1,342 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package metainfo
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
"github.com/xgfone/bt/bencode"
|
||||
)
|
||||
|
||||
// ErrInvalidAddr is returned when the compact address is invalid.
|
||||
var ErrInvalidAddr = fmt.Errorf("invalid compact information of ip and port")
|
||||
|
||||
// Address represents a client/server listening on a UDP port implementing
|
||||
// the DHT protocol.
|
||||
type Address struct {
|
||||
IP net.IP // For IPv4, its length must be 4.
|
||||
Port uint16
|
||||
}
|
||||
|
||||
// NewAddress returns a new Address.
|
||||
func NewAddress(ip net.IP, port uint16) Address {
|
||||
if ipv4 := ip.To4(); len(ipv4) > 0 {
|
||||
ip = ipv4
|
||||
}
|
||||
return Address{IP: ip, Port: port}
|
||||
}
|
||||
|
||||
// NewAddressFromString returns a new Address by the address string.
|
||||
func NewAddressFromString(s string) (addr Address, err error) {
|
||||
err = addr.FromString(s)
|
||||
return
|
||||
}
|
||||
|
||||
// NewAddressesFromString returns a list of Addresses by the address string.
|
||||
func NewAddressesFromString(s string) (addrs []Address, err error) {
|
||||
shost, sport, err := net.SplitHostPort(s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid address '%s': %s", s, err)
|
||||
}
|
||||
|
||||
var port uint16
|
||||
if sport != "" {
|
||||
v, err := strconv.ParseUint(sport, 10, 16)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid address '%s': %s", s, err)
|
||||
}
|
||||
port = uint16(v)
|
||||
}
|
||||
|
||||
ips, err := net.LookupIP(shost)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fail to lookup the domain '%s': %s", shost, err)
|
||||
}
|
||||
|
||||
addrs = make([]Address, len(ips))
|
||||
for i, ip := range ips {
|
||||
addrs[i] = Address{IP: ip, Port: port}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// FromString parses and sets the ip from the string addr.
|
||||
func (a *Address) FromString(addr string) (err error) {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid address '%s': %s", addr, err)
|
||||
}
|
||||
|
||||
if port != "" {
|
||||
v, err := strconv.ParseUint(port, 10, 16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid address '%s': %s", addr, err)
|
||||
}
|
||||
a.Port = uint16(v)
|
||||
}
|
||||
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fail to lookup the domain '%s': %s", host, err)
|
||||
} else if len(ips) == 0 {
|
||||
return fmt.Errorf("the domain '%s' has no ips", host)
|
||||
}
|
||||
|
||||
a.IP = ips[0]
|
||||
if ip := a.IP.To4(); len(ip) > 0 {
|
||||
a.IP = ip
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// FromUDPAddr sets the ip from net.UDPAddr.
|
||||
func (a *Address) FromUDPAddr(ua *net.UDPAddr) {
|
||||
a.Port = uint16(ua.Port)
|
||||
a.IP = ua.IP
|
||||
if ipv4 := a.IP.To4(); len(ipv4) > 0 {
|
||||
a.IP = ipv4
|
||||
}
|
||||
}
|
||||
|
||||
// UDPAddr creates a new net.UDPAddr.
|
||||
func (a Address) UDPAddr() *net.UDPAddr {
|
||||
return &net.UDPAddr{
|
||||
IP: a.IP,
|
||||
Port: int(a.Port),
|
||||
}
|
||||
}
|
||||
|
||||
func (a Address) String() string {
|
||||
if a.Port == 0 {
|
||||
return a.IP.String()
|
||||
}
|
||||
return net.JoinHostPort(a.IP.String(), strconv.FormatUint(uint64(a.Port), 10))
|
||||
}
|
||||
|
||||
// Equal reports whether n is equal to o, which is equal to
|
||||
// n.HasIPAndPort(o.IP, o.Port)
|
||||
func (a Address) Equal(o Address) bool {
|
||||
return a.Port == o.Port && a.IP.Equal(o.IP)
|
||||
}
|
||||
|
||||
// HasIPAndPort reports whether the current node has the ip and the port.
|
||||
func (a Address) HasIPAndPort(ip net.IP, port uint16) bool {
|
||||
return port == a.Port && a.IP.Equal(ip)
|
||||
}
|
||||
|
||||
// WriteBinary is the same as MarshalBinary, but writes the result into w
|
||||
// instead of returning.
|
||||
func (a Address) WriteBinary(w io.Writer) (m int, err error) {
|
||||
if m, err = w.Write(a.IP); err == nil {
|
||||
if err = binary.Write(w, binary.BigEndian, a.Port); err == nil {
|
||||
m += 2
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// UnmarshalBinary implements the interface binary.BinaryUnmarshaler.
|
||||
func (a *Address) UnmarshalBinary(b []byte) (err error) {
|
||||
_len := len(b) - 2
|
||||
switch _len {
|
||||
case net.IPv4len, net.IPv6len:
|
||||
default:
|
||||
return ErrInvalidAddr
|
||||
}
|
||||
|
||||
a.IP = make(net.IP, _len)
|
||||
copy(a.IP, b[:_len])
|
||||
a.Port = binary.BigEndian.Uint16(b[_len:])
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalBinary implements the interface binary.BinaryMarshaler.
|
||||
func (a Address) MarshalBinary() (data []byte, err error) {
|
||||
buf := bytes.NewBuffer(nil)
|
||||
buf.Grow(20)
|
||||
if _, err = a.WriteBinary(buf); err == nil {
|
||||
data = buf.Bytes()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Address) decode(vs []interface{}) (err error) {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
err = e.(error)
|
||||
}
|
||||
}()
|
||||
|
||||
host := vs[0].(string)
|
||||
if a.IP = net.ParseIP(host); len(a.IP) == 0 {
|
||||
return ErrInvalidAddr
|
||||
} else if ip := a.IP.To4(); len(ip) > 0 {
|
||||
a.IP = ip
|
||||
}
|
||||
|
||||
a.Port = uint16(vs[1].(int64))
|
||||
return
|
||||
}
|
||||
|
||||
// UnmarshalBencode implements the interface bencode.Unmarshaler.
|
||||
func (a *Address) UnmarshalBencode(b []byte) (err error) {
|
||||
var iface interface{}
|
||||
if err = bencode.NewDecoder(bytes.NewBuffer(b)).Decode(&iface); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch v := iface.(type) {
|
||||
case string:
|
||||
err = a.FromString(v)
|
||||
case []interface{}:
|
||||
err = a.decode(v)
|
||||
default:
|
||||
err = fmt.Errorf("unsupported type: %T", iface)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalBencode implements the interface bencode.Marshaler.
|
||||
func (a Address) MarshalBencode() (b []byte, err error) {
|
||||
buf := bytes.NewBuffer(nil)
|
||||
buf.Grow(32)
|
||||
err = bencode.NewEncoder(buf).Encode([]interface{}{a.IP.String(), a.Port})
|
||||
if err == nil {
|
||||
b = buf.Bytes()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
|
||||
// HostAddress is the same as the Address, but the host part may be
|
||||
// either a domain or a ip.
|
||||
type HostAddress struct {
|
||||
Host string
|
||||
Port uint16
|
||||
}
|
||||
|
||||
// NewHostAddress returns a new host addrress.
|
||||
func NewHostAddress(host string, port uint16) HostAddress {
|
||||
return HostAddress{Host: host, Port: port}
|
||||
}
|
||||
|
||||
// NewHostAddressFromString returns a new host address by the string.
|
||||
func NewHostAddressFromString(s string) (addr HostAddress, err error) {
|
||||
err = addr.FromString(s)
|
||||
return
|
||||
}
|
||||
|
||||
// FromString parses and sets the host from the string addr.
|
||||
func (a *HostAddress) FromString(addr string) (err error) {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid address '%s': %s", addr, err)
|
||||
} else if host == "" {
|
||||
return fmt.Errorf("invalid address '%s': missing host", addr)
|
||||
}
|
||||
|
||||
if port != "" {
|
||||
v, err := strconv.ParseUint(port, 10, 16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid address '%s': %s", addr, err)
|
||||
}
|
||||
a.Port = uint16(v)
|
||||
}
|
||||
|
||||
a.Host = host
|
||||
return
|
||||
}
|
||||
|
||||
func (a HostAddress) String() string {
|
||||
if a.Port == 0 {
|
||||
return a.Host
|
||||
}
|
||||
return net.JoinHostPort(a.Host, strconv.FormatUint(uint64(a.Port), 10))
|
||||
}
|
||||
|
||||
// Addresses parses the host address to a list of Addresses.
|
||||
func (a HostAddress) Addresses() (addrs []Address, err error) {
|
||||
if ip := net.ParseIP(a.Host); len(ip) > 0 {
|
||||
return []Address{NewAddress(ip, a.Port)}, nil
|
||||
}
|
||||
|
||||
ips, err := net.LookupIP(a.Host)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("fail to lookup the domain '%s': %s", a.Host, err)
|
||||
} else {
|
||||
addrs = make([]Address, len(ips))
|
||||
for i, ip := range ips {
|
||||
addrs[i] = NewAddress(ip, a.Port)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Equal reports whether a is equal to o.
|
||||
func (a HostAddress) Equal(o HostAddress) bool {
|
||||
return a.Port == o.Port && a.Host == o.Host
|
||||
}
|
||||
|
||||
func (a *HostAddress) decode(vs []interface{}) (err error) {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
err = e.(error)
|
||||
}
|
||||
}()
|
||||
|
||||
a.Host = vs[0].(string)
|
||||
a.Port = uint16(vs[1].(int64))
|
||||
return
|
||||
}
|
||||
|
||||
// UnmarshalBencode implements the interface bencode.Unmarshaler.
|
||||
func (a *HostAddress) UnmarshalBencode(b []byte) (err error) {
|
||||
var iface interface{}
|
||||
if err = bencode.NewDecoder(bytes.NewBuffer(b)).Decode(&iface); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch v := iface.(type) {
|
||||
case string:
|
||||
err = a.FromString(v)
|
||||
case []interface{}:
|
||||
err = a.decode(v)
|
||||
default:
|
||||
err = fmt.Errorf("unsupported type: %T", iface)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalBencode implements the interface bencode.Marshaler.
|
||||
func (a HostAddress) MarshalBencode() (b []byte, err error) {
|
||||
buf := bytes.NewBuffer(nil)
|
||||
buf.Grow(32)
|
||||
err = bencode.NewEncoder(buf).Encode([]interface{}{a.Host, a.Port})
|
||||
if err == nil {
|
||||
b = buf.Bytes()
|
||||
}
|
||||
return
|
||||
}
|
48
metainfo/address_test.go
Normal file
48
metainfo/address_test.go
Normal file
@ -0,0 +1,48 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package metainfo
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAddress(t *testing.T) {
|
||||
var addr1 Address
|
||||
if err := addr1.FromString("1.2.3.4:1234"); err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := addr1.MarshalBencode()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
} else if s := string(data); s != `l7:1.2.3.4i1234ee` {
|
||||
t.Errorf(`expected 'l7:1.2.3.4i1234ee', but got '%s'`, s)
|
||||
}
|
||||
|
||||
var addr2 Address
|
||||
if err = addr2.UnmarshalBencode(data); err != nil {
|
||||
t.Error(err)
|
||||
} else if addr2.String() != `1.2.3.4:1234` {
|
||||
t.Errorf("expected '1.2.3.4:1234', but got '%s'", addr2)
|
||||
}
|
||||
|
||||
if data, err = addr2.MarshalBinary(); err != nil {
|
||||
t.Error(err)
|
||||
} else if s := string(data); s != "\x01\x02\x03\x04\x04\xd2" {
|
||||
t.Errorf(`expected '\x01\x02\x03\x04\x04\xd2', but got '%#x'`, s)
|
||||
}
|
||||
}
|
17
metainfo/doc.go
Normal file
17
metainfo/doc.go
Normal file
@ -0,0 +1,17 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package metainfo is used to encode or decode the metainfo of the torrent file
|
||||
// and the MagNet link.
|
||||
package metainfo
|
65
metainfo/file.go
Normal file
65
metainfo/file.go
Normal file
@ -0,0 +1,65 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package metainfo
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// File represents a file in the multi-file case.
|
||||
type File struct {
|
||||
// Length is the length of the file in bytes.
|
||||
Length int64 `json:"length" bencode:"length"` // BEP 3
|
||||
|
||||
// Paths is a list containing one or more string elements that together
|
||||
// represent the path and filename. Each element in the list corresponds
|
||||
// to either a directory name or (in the case of the final element) the
|
||||
// filename.
|
||||
//
|
||||
// For example, a the file "dir1/dir2/file.ext" would consist of three
|
||||
// string elements: "dir1", "dir2", and "file.ext". This is encoded as
|
||||
// a bencoded list of strings such as l4:dir14:dir28:file.exte.
|
||||
Paths []string `json:"path" bencode:"path"` // BEP 3
|
||||
}
|
||||
|
||||
func (f File) String() string {
|
||||
return filepath.Join(f.Paths...)
|
||||
}
|
||||
|
||||
// Path returns the path of the current.
|
||||
func (f File) Path(info Info) string {
|
||||
if info.IsDir() {
|
||||
return f.String()
|
||||
}
|
||||
return info.Name
|
||||
}
|
||||
|
||||
// Offset returns the offset of the current file from the start.
|
||||
func (f File) Offset(info Info) (ret int64) {
|
||||
path := f.Path(info)
|
||||
for _, file := range info.AllFiles() {
|
||||
if path == file.Path(info) {
|
||||
return
|
||||
}
|
||||
ret += file.Length
|
||||
}
|
||||
panic("not found")
|
||||
}
|
||||
|
||||
type files []File
|
||||
|
||||
func (fs files) Len() int { return len(fs) }
|
||||
func (fs files) Less(i, j int) bool { return fs[i].String() < fs[j].String() }
|
||||
func (fs files) Swap(i, j int) { f := fs[i]; fs[i] = fs[j]; fs[j] = f }
|
138
metainfo/info.go
Normal file
138
metainfo/info.go
Normal file
@ -0,0 +1,138 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package metainfo
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Info is the file inforatino.
|
||||
type Info struct {
|
||||
// Name is the name of the file in the single file case.
|
||||
// Or, it is the name of the directory in the muliple file case.
|
||||
Name string `json:"name" bencode:"name"` // BEP 3
|
||||
|
||||
// PieceLength is the number of bytes in each piece, which is usually
|
||||
// a power of 2.
|
||||
PieceLength int64 `json:"piece length" bencode:"piece length"` // BEP 3
|
||||
|
||||
// Pieces is the concatenation of all 20-byte SHA1 hash values,
|
||||
// one per piece (byte string, i.e. not urlencoded).
|
||||
Pieces Hashes `json:"pieces" bencode:"pieces"` // BEP 3
|
||||
|
||||
// Length is the length of the file in bytes in the single file case.
|
||||
//
|
||||
// It's mutually exclusive with Files.
|
||||
Length int64 `json:"length,omitempty" bencode:"length,omitempty"` // BEP 3
|
||||
|
||||
// Files is the list of all the files in the multi-file case.
|
||||
//
|
||||
// For the purposes of the other keys, the multi-file case is treated
|
||||
// as only having a single file by concatenating the files in the order
|
||||
// they appear in the files list.
|
||||
//
|
||||
// It's mutually exclusive with Length.
|
||||
Files []File `json:"files,omitempty" bencode:"files,omitempty"` // BEP 3
|
||||
}
|
||||
|
||||
// NewInfoFromFilePath returns a new Info from a file or directory.
|
||||
func NewInfoFromFilePath(root string, pieceLength int64) (info Info, err error) {
|
||||
info.Name = filepath.Base(root)
|
||||
info.PieceLength = pieceLength
|
||||
err = filepath.Walk(root, func(path string, fi os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if path == root && !fi.IsDir() { // The root is a file.
|
||||
info.Length = fi.Size()
|
||||
return nil
|
||||
}
|
||||
|
||||
relPath, err := filepath.Rel(root, path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting relative path: %s", err)
|
||||
}
|
||||
|
||||
info.Files = append(info.Files, File{
|
||||
Paths: strings.Split(relPath, string(filepath.Separator)),
|
||||
Length: fi.Size(),
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
sort.Sort(files(info.Files))
|
||||
info.Pieces, err = GeneratePiecesFromFiles(info.AllFiles(), info.PieceLength,
|
||||
func(file File) (io.ReadCloser, error) {
|
||||
if _len := len(file.Paths); _len > 0 {
|
||||
paths := make([]string, 0, _len+1)
|
||||
paths = append(paths, root)
|
||||
paths = append(paths, file.Paths...)
|
||||
return os.Open(filepath.Join(paths...))
|
||||
}
|
||||
return os.Open(root)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error generating pieces: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// IsDir reports whether the name is a directory, that's, the file is not
|
||||
// a single file.
|
||||
func (info Info) IsDir() bool { return len(info.Files) != 0 }
|
||||
|
||||
// CountPieces returns the number of the pieces.
|
||||
func (info Info) CountPieces() int { return len(info.Pieces) }
|
||||
|
||||
// TotalLength returns the total length of the torrent file.
|
||||
func (info Info) TotalLength() (ret int64) {
|
||||
if info.IsDir() {
|
||||
for _, fi := range info.Files {
|
||||
ret += fi.Length
|
||||
}
|
||||
} else {
|
||||
ret = info.Length
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Piece returns the Piece by the index starting with 0.
|
||||
func (info Info) Piece(index int) Piece {
|
||||
if n := len(info.Pieces); index >= n {
|
||||
panic(fmt.Errorf("Info.Piece: index '%d' exceeds maximum '%d'", index, n))
|
||||
}
|
||||
return Piece{info: info, index: index}
|
||||
}
|
||||
|
||||
// AllFiles returns all the files.
|
||||
//
|
||||
// Notice: for the single file, the Path is nil.
|
||||
func (info Info) AllFiles() []File {
|
||||
if info.IsDir() {
|
||||
return info.Files
|
||||
}
|
||||
return []File{{Length: info.Length}}
|
||||
}
|
207
metainfo/infohash.go
Normal file
207
metainfo/infohash.go
Normal file
@ -0,0 +1,207 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package metainfo
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"encoding/base32"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
|
||||
"github.com/xgfone/bt/bencode"
|
||||
)
|
||||
|
||||
var zeroHash Hash
|
||||
|
||||
// HashSize is the size of the InfoHash.
|
||||
const HashSize = 20
|
||||
|
||||
// Hash is the 20-byte SHA1 hash used for info and pieces.
|
||||
type Hash [HashSize]byte
|
||||
|
||||
// NewRandomHash returns a random hash.
|
||||
func NewRandomHash() (h Hash) {
|
||||
rand.Read(h[:])
|
||||
return
|
||||
}
|
||||
|
||||
// NewHash converts the 20-bytes to Hash.
|
||||
func NewHash(b []byte) (h Hash) {
|
||||
copy(h[:], b[:HashSize])
|
||||
return
|
||||
}
|
||||
|
||||
// NewHashFromString returns a new Hash from a string.
|
||||
func NewHashFromString(s string) (h Hash) {
|
||||
err := h.FromString(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// NewHashFromHexString returns a new Hash from a hex string.
|
||||
func NewHashFromHexString(s string) (h Hash) {
|
||||
err := h.FromHexString(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// NewHashFromBytes returns a new Hash from a byte slice.
|
||||
func NewHashFromBytes(b []byte) (ret Hash) {
|
||||
hasher := sha1.New()
|
||||
hasher.Write(b)
|
||||
copy(ret[:], hasher.Sum(nil))
|
||||
return
|
||||
}
|
||||
|
||||
// Bytes returns the byte slice type.
|
||||
func (h Hash) Bytes() []byte {
|
||||
return h[:]
|
||||
}
|
||||
|
||||
// String is equal to HexString.
|
||||
func (h Hash) String() string {
|
||||
return h.HexString()
|
||||
}
|
||||
|
||||
// BytesString returns the bytes string, that's, string(h[:]).
|
||||
func (h Hash) BytesString() string {
|
||||
return string(h[:])
|
||||
}
|
||||
|
||||
// HexString returns the hex string format.
|
||||
func (h Hash) HexString() string {
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
// IsZero reports whether the whole hash is zero.
|
||||
func (h Hash) IsZero() bool {
|
||||
return h == zeroHash
|
||||
}
|
||||
|
||||
// MarshalBencode implements the interface bencode.Marshaler.
|
||||
func (h Hash) MarshalBencode() (b []byte, err error) {
|
||||
return bencode.EncodeBytes(h[:])
|
||||
}
|
||||
|
||||
// UnmarshalBencode implements the interface bencode.Unmarshaler.
|
||||
func (h *Hash) UnmarshalBencode(b []byte) (err error) {
|
||||
var s string
|
||||
if err = bencode.NewDecoder(bytes.NewBuffer(b)).Decode(&s); err == nil {
|
||||
err = h.FromString(s)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// FromString resets the info hash from the string.
|
||||
func (h *Hash) FromString(s string) (err error) {
|
||||
switch len(s) {
|
||||
case HashSize:
|
||||
copy(h[:], s)
|
||||
case 2 * HashSize:
|
||||
err = h.FromHexString(s)
|
||||
case 32:
|
||||
var bs []byte
|
||||
if bs, err = base32.StdEncoding.DecodeString(s); err == nil {
|
||||
copy(h[:], bs)
|
||||
}
|
||||
default:
|
||||
err = fmt.Errorf("hash string has bad length: %d", len(s))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FromHexString resets the info hash from the hex string.
|
||||
func (h *Hash) FromHexString(s string) (err error) {
|
||||
if len(s) != 2*HashSize {
|
||||
err = fmt.Errorf("hash hex string has bad length: %d", len(s))
|
||||
return
|
||||
}
|
||||
|
||||
n, err := hex.Decode(h[:], []byte(s))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if n != HashSize {
|
||||
panic(n)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Xor returns the hash of h XOR o.
|
||||
func (h Hash) Xor(o Hash) (ret Hash) {
|
||||
for i := range o {
|
||||
ret[i] = h[i] ^ o[i]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Compare returns 0 if h == o, -1 if h < o, or +1 if h > o.
|
||||
func (h Hash) Compare(o Hash) int { return bytes.Compare(h[:], o[:]) }
|
||||
|
||||
/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
|
||||
// Hashes is a set of Hashes.
|
||||
type Hashes []Hash
|
||||
|
||||
// Contains reports whether hs contains h.
|
||||
func (hs Hashes) Contains(h Hash) bool {
|
||||
for _, _h := range hs {
|
||||
if h == _h {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// MarshalBencode implements the interface bencode.Marshaler.
|
||||
func (hs Hashes) MarshalBencode() ([]byte, error) {
|
||||
buf := bytes.NewBuffer(nil)
|
||||
buf.Grow(HashSize * len(hs))
|
||||
for _, h := range hs {
|
||||
buf.Write(h[:])
|
||||
}
|
||||
return bencode.EncodeBytes(buf.Bytes())
|
||||
}
|
||||
|
||||
// UnmarshalBencode implements the interface bencode.Unmarshaler.
|
||||
func (hs *Hashes) UnmarshalBencode(b []byte) (err error) {
|
||||
var bs []byte
|
||||
if err = bencode.DecodeBytes(b, &bs); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_len := len(bs)
|
||||
if _len%HashSize != 0 {
|
||||
return fmt.Errorf("Hashes: invalid bytes length '%d'", _len)
|
||||
}
|
||||
|
||||
hashes := make(Hashes, 0, _len/HashSize)
|
||||
for i := 0; i < _len; i += HashSize {
|
||||
var h Hash
|
||||
copy(h[:], bs[i:i+HashSize])
|
||||
hashes = append(hashes, h)
|
||||
}
|
||||
|
||||
*hs = hashes
|
||||
return
|
||||
}
|
138
metainfo/magnet.go
Normal file
138
metainfo/magnet.go
Normal file
@ -0,0 +1,138 @@
|
||||
// Mozilla Public License Version 2.0
|
||||
// Modify from github.com/anacrolix/torrent/metainfo.
|
||||
|
||||
package metainfo
|
||||
|
||||
import (
|
||||
"encoding/base32"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Magnet link components.
|
||||
type Magnet struct {
|
||||
InfoHash Hash // From "xt"
|
||||
Trackers []string // From "tr"
|
||||
DisplayName string // From "dn" if not empty
|
||||
Params url.Values // All other values, such as "as", "xs", etc
|
||||
}
|
||||
|
||||
const xtPrefix = "urn:btih:"
|
||||
|
||||
// Peers returns the list of the addresses of the peers.
|
||||
//
|
||||
// See BEP 9
|
||||
func (m Magnet) Peers() (peers []HostAddress, err error) {
|
||||
vs := m.Params["x.pe"]
|
||||
peers = make([]HostAddress, 0, len(vs))
|
||||
for _, v := range vs {
|
||||
if v != "" {
|
||||
var addr HostAddress
|
||||
if err = addr.FromString(v); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
peers = append(peers, addr)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m Magnet) String() string {
|
||||
vs := make(url.Values, len(m.Params)+len(m.Trackers)+2)
|
||||
for k, v := range m.Params {
|
||||
vs[k] = append([]string(nil), v...)
|
||||
}
|
||||
|
||||
for _, tr := range m.Trackers {
|
||||
vs.Add("tr", tr)
|
||||
}
|
||||
if m.DisplayName != "" {
|
||||
vs.Add("dn", m.DisplayName)
|
||||
}
|
||||
|
||||
// Transmission and Deluge both expect "urn:btih:" to be unescaped.
|
||||
// Deluge wants it to be at the start of the magnet link.
|
||||
// The InfoHash field is expected to be BitTorrent in this implementation.
|
||||
u := url.URL{
|
||||
Scheme: "magnet",
|
||||
RawQuery: "xt=" + xtPrefix + m.InfoHash.HexString(),
|
||||
}
|
||||
if len(vs) != 0 {
|
||||
u.RawQuery += "&" + vs.Encode()
|
||||
}
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// ParseMagnetURI parses Magnet-formatted URIs into a Magnet instance.
|
||||
func ParseMagnetURI(uri string) (m Magnet, err error) {
|
||||
u, err := url.Parse(uri)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error parsing uri: %w", err)
|
||||
return
|
||||
} else if u.Scheme != "magnet" {
|
||||
err = fmt.Errorf("unexpected scheme %q", u.Scheme)
|
||||
return
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
xt := q.Get("xt")
|
||||
if m.InfoHash, err = parseInfohash(q.Get("xt")); err != nil {
|
||||
err = fmt.Errorf("error parsing infohash %q: %w", xt, err)
|
||||
return
|
||||
}
|
||||
dropFirst(q, "xt")
|
||||
|
||||
m.DisplayName = q.Get("dn")
|
||||
dropFirst(q, "dn")
|
||||
|
||||
m.Trackers = q["tr"]
|
||||
delete(q, "tr")
|
||||
|
||||
if len(q) == 0 {
|
||||
q = nil
|
||||
}
|
||||
|
||||
m.Params = q
|
||||
return
|
||||
}
|
||||
|
||||
func parseInfohash(xt string) (ih Hash, err error) {
|
||||
if !strings.HasPrefix(xt, xtPrefix) {
|
||||
err = errors.New("bad xt parameter prefix")
|
||||
return
|
||||
}
|
||||
|
||||
var n int
|
||||
encoded := xt[len(xtPrefix):]
|
||||
switch len(encoded) {
|
||||
case 40:
|
||||
n, err = hex.Decode(ih[:], []byte(encoded))
|
||||
case 32:
|
||||
n, err = base32.StdEncoding.Decode(ih[:], []byte(encoded))
|
||||
default:
|
||||
err = fmt.Errorf("unhandled xt parameter encoding (encoded length %d)", len(encoded))
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error decoding xt: %w", err)
|
||||
} else if n != 20 {
|
||||
panic(fmt.Errorf("invalid length '%d' of the decoded bytes", n))
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func dropFirst(vs url.Values, key string) {
|
||||
sl := vs[key]
|
||||
switch len(sl) {
|
||||
case 0, 1:
|
||||
vs.Del(key)
|
||||
default:
|
||||
vs[key] = sl[1:]
|
||||
}
|
||||
}
|
170
metainfo/metainfo.go
Normal file
170
metainfo/metainfo.go
Normal file
@ -0,0 +1,170 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package metainfo
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/xgfone/bt/bencode"
|
||||
"github.com/xgfone/bt/utils"
|
||||
)
|
||||
|
||||
// Bytes is the []byte type.
|
||||
type Bytes = bencode.RawMessage
|
||||
|
||||
// AnnounceList is a list of the announces.
|
||||
type AnnounceList [][]string
|
||||
|
||||
// Unique returns the list of the unique announces.
|
||||
func (al AnnounceList) Unique() (announces []string) {
|
||||
announces = make([]string, 0, len(al))
|
||||
for _, tier := range al {
|
||||
for _, v := range tier {
|
||||
if v != "" && utils.InStringSlice(announces, v) {
|
||||
announces = append(announces, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// URLList represents a list of the url.
|
||||
//
|
||||
// BEP 19
|
||||
type URLList []string
|
||||
|
||||
// FullURL returns the index-th full url.
|
||||
//
|
||||
// For the single-file case, name is the "name" of "info".
|
||||
// For the multi-file case, name is the path "name/path/file"
|
||||
// from "info" and "files".
|
||||
//
|
||||
// See http://bittorrent.org/beps/bep_0019.html
|
||||
func (us URLList) FullURL(index int, name string) (url string) {
|
||||
if url = us[index]; strings.HasSuffix(url, "/") {
|
||||
url += name
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalBencode implements the interface bencode.Marshaler.
|
||||
func (us URLList) MarshalBencode() (b []byte, err error) {
|
||||
return bencode.EncodeBytes([]string(us))
|
||||
}
|
||||
|
||||
// UnmarshalBencode implements the interface bencode.Unmarshaler.
|
||||
func (us *URLList) UnmarshalBencode(b []byte) (err error) {
|
||||
var v interface{}
|
||||
if err = bencode.DecodeBytes(b, &v); err == nil {
|
||||
switch vs := v.(type) {
|
||||
case string:
|
||||
*us = URLList{vs}
|
||||
case []interface{}:
|
||||
urls := make(URLList, len(vs))
|
||||
for i, u := range vs {
|
||||
s, ok := u.(string)
|
||||
if !ok {
|
||||
return errors.New("the element of 'url-list' is not string")
|
||||
}
|
||||
urls[i] = s
|
||||
}
|
||||
*us = urls
|
||||
default:
|
||||
return errors.New("invalid 'url-lsit'")
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MetaInfo represents the .torrent file.
|
||||
type MetaInfo struct {
|
||||
InfoBytes Bytes `bencode:"info"` // BEP 3
|
||||
Announce string `bencode:"announce,omitempty"` // BEP 3
|
||||
AnnounceList AnnounceList `bencode:"announce-list,omitempty"` // BEP 12
|
||||
Nodes []HostAddress `bencode:"nodes,omitempty"` // BEP 5
|
||||
URLList URLList `bencode:"url-list,omitempty"` // BEP 19
|
||||
|
||||
// Where's this specified?
|
||||
// Mentioned at https://wiki.theory.org/index.php/BitTorrentSpecification.
|
||||
// All of them are optional.
|
||||
|
||||
// CreationDate is the creation time of the torrent, in standard UNIX epoch
|
||||
// format (seconds since 1-Jan-1970 00:00:00 UTC).
|
||||
CreationDate int64 `bencode:"creation date,omitempty"`
|
||||
// Comment is the free-form textual comments of the author.
|
||||
Comment string `bencode:"comment,omitempty"`
|
||||
// CreatedBy is name and version of the program used to create the .torrent.
|
||||
CreatedBy string `bencode:"created by,omitempty"`
|
||||
// Encoding is the string encoding format used to generate the pieces part
|
||||
// of the info dictionary in the .torrent metafile.
|
||||
Encoding string `bencode:"encoding,omitempty"`
|
||||
}
|
||||
|
||||
// Load loads a MetaInfo from an io.Reader.
|
||||
func Load(r io.Reader) (mi MetaInfo, err error) {
|
||||
err = bencode.NewDecoder(r).Decode(&mi)
|
||||
return
|
||||
}
|
||||
|
||||
// LoadFromFile loads a MetaInfo from a file.
|
||||
func LoadFromFile(filename string) (mi MetaInfo, err error) {
|
||||
f, err := os.Open(filename)
|
||||
if err == nil {
|
||||
defer f.Close()
|
||||
mi, err = Load(f)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Announces returns all the announces.
|
||||
func (mi MetaInfo) Announces() AnnounceList {
|
||||
if len(mi.AnnounceList) > 0 {
|
||||
return mi.AnnounceList
|
||||
} else if mi.Announce != "" {
|
||||
return [][]string{{mi.Announce}}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Magnet creates a Magnet from a MetaInfo.
|
||||
func (mi *MetaInfo) Magnet(displayName string, infoHash Hash) (m Magnet) {
|
||||
for _, t := range mi.Announces().Unique() {
|
||||
m.Trackers = append(m.Trackers, t)
|
||||
}
|
||||
|
||||
m.DisplayName = displayName
|
||||
m.InfoHash = infoHash
|
||||
return
|
||||
}
|
||||
|
||||
// Write encodes the metainfo to w.
|
||||
func (mi MetaInfo) Write(w io.Writer) error {
|
||||
return bencode.NewEncoder(w).Encode(mi)
|
||||
}
|
||||
|
||||
// InfoHash returns the hash of the info.
|
||||
func (mi MetaInfo) InfoHash() Hash {
|
||||
return NewHashFromBytes(mi.InfoBytes)
|
||||
}
|
||||
|
||||
// Info parses the InfoBytes to the Info.
|
||||
func (mi MetaInfo) Info() (info Info, err error) {
|
||||
err = bencode.DecodeBytes(mi.InfoBytes, &info)
|
||||
return
|
||||
}
|
99
metainfo/piece.go
Normal file
99
metainfo/piece.go
Normal file
@ -0,0 +1,99 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package metainfo
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/xgfone/bt/utils"
|
||||
)
|
||||
|
||||
// Piece represents a torrent file piece.
|
||||
type Piece struct {
|
||||
info Info
|
||||
index int
|
||||
}
|
||||
|
||||
// Index returns the index of the current piece.
|
||||
func (p Piece) Index() int { return p.index }
|
||||
|
||||
// Offset returns the offset that the current piece is in all the files.
|
||||
func (p Piece) Offset() int64 { return int64(p.index) * p.info.PieceLength }
|
||||
|
||||
// Hash returns the hash representation of the piece.
|
||||
func (p Piece) Hash() (h Hash) { return p.info.Pieces[p.index] }
|
||||
|
||||
// Length returns the length of the current piece.
|
||||
func (p Piece) Length() int64 {
|
||||
if p.index == p.info.CountPieces()-1 {
|
||||
return p.info.TotalLength() - int64(p.index)*p.info.PieceLength
|
||||
}
|
||||
return p.info.PieceLength
|
||||
}
|
||||
|
||||
// GeneratePieces generates the pieces from the reader.
|
||||
func GeneratePieces(r io.Reader, pieceLength int64) (hs Hashes, err error) {
|
||||
buf := make([]byte, pieceLength)
|
||||
for {
|
||||
h := sha1.New()
|
||||
written, err := utils.CopyNBuffer(h, r, pieceLength, buf)
|
||||
if written > 0 {
|
||||
hs = append(hs, NewHash(h.Sum(nil)))
|
||||
}
|
||||
|
||||
if err == io.EOF {
|
||||
return hs, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func writeFiles(w io.Writer, files []File, open func(File) (io.ReadCloser, error)) error {
|
||||
buf := make([]byte, 8192)
|
||||
for _, file := range files {
|
||||
r, err := open(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error opening %s: %s", file, err)
|
||||
}
|
||||
|
||||
n, err := utils.CopyNBuffer(w, r, file.Length, buf)
|
||||
r.Close()
|
||||
|
||||
if n != file.Length {
|
||||
return fmt.Errorf("error copying %s: %s", file, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GeneratePiecesFromFiles generates the pieces from the files.
|
||||
func GeneratePiecesFromFiles(files []File, pieceLength int64,
|
||||
open func(File) (io.ReadCloser, error)) (Hashes, error) {
|
||||
if pieceLength <= 0 {
|
||||
return nil, errors.New("piece length must be a positive integer")
|
||||
}
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
defer pr.Close()
|
||||
|
||||
go func() { pw.CloseWithError(writeFiles(pw, files, open)) }()
|
||||
return GeneratePieces(pr, pieceLength)
|
||||
}
|
19
peerprotocol/doc.go
Normal file
19
peerprotocol/doc.go
Normal file
@ -0,0 +1,19 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package peerprotocol implements the core BT peer protocol. You can use it
|
||||
// to implement a peer client, and upload/download the file to/from other peers.
|
||||
//
|
||||
// See TorrentDownloader to download the torrent file from other peer.
|
||||
package peerprotocol
|
161
peerprotocol/extension.go
Normal file
161
peerprotocol/extension.go
Normal file
@ -0,0 +1,161 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package peerprotocol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
"github.com/xgfone/bt/bencode"
|
||||
)
|
||||
|
||||
var errInvalidIP = errors.New("invalid ipv4 or ipv6")
|
||||
|
||||
// Predefine some extended message identifiers.
|
||||
const (
|
||||
ExtendedIDHandshake = 0 // BEP 10
|
||||
)
|
||||
|
||||
// Predefine some extended message names.
|
||||
const (
|
||||
ExtendedMessageNameMetadata = "ut_metadata" // BEP 9
|
||||
ExtendedMessageNamePex = "ut_pex" // BEP 11
|
||||
)
|
||||
|
||||
// Predefine some "ut_metadata" extended message types.
|
||||
const (
|
||||
UtMetadataExtendedMsgTypeRequest = 0 // BEP 9
|
||||
UtMetadataExtendedMsgTypeData = 1 // BEP 9
|
||||
UtMetadataExtendedMsgTypeReject = 2 // BEP 9
|
||||
)
|
||||
|
||||
// CompactIP is used to handle the compact ipv4 or ipv6.
|
||||
type CompactIP net.IP
|
||||
|
||||
func (ci CompactIP) String() string {
|
||||
return net.IP(ci).String()
|
||||
}
|
||||
|
||||
// MarshalBencode implements the interface bencode.Marshaler.
|
||||
func (ci CompactIP) MarshalBencode() ([]byte, error) {
|
||||
ip := net.IP(ci)
|
||||
if ipv4 := ip.To4(); len(ipv4) != 0 {
|
||||
ip = ipv4
|
||||
}
|
||||
return bencode.EncodeBytes(ip[:])
|
||||
}
|
||||
|
||||
// UnmarshalBencode implements the interface bencode.Unmarshaler.
|
||||
func (ci *CompactIP) UnmarshalBencode(b []byte) (err error) {
|
||||
var ip net.IP
|
||||
if err = bencode.DecodeBytes(b, &ip); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch len(ip) {
|
||||
case net.IPv4len, net.IPv6len:
|
||||
default:
|
||||
return errInvalidIP
|
||||
}
|
||||
|
||||
if ipv4 := ip.To4(); len(ipv4) != 0 {
|
||||
ip = ipv4
|
||||
}
|
||||
|
||||
*ci = CompactIP(ip)
|
||||
return
|
||||
}
|
||||
|
||||
// ExtendedHandshakeMsg represent the extended handshake message.
|
||||
//
|
||||
// BEP 10
|
||||
type ExtendedHandshakeMsg struct {
|
||||
// M is the type of map[ExtendedMessageName]ExtendedMessageID.
|
||||
M map[string]uint8 `bencode:"m"` // BEP 10
|
||||
V string `bencode:"v,omitempty"` // BEP 10
|
||||
Reqq int `bencode:"reqq,omitempty"` // BEP 10
|
||||
|
||||
// Port is the local client port, which is redundant and no need
|
||||
// for the receiving side of the connection to send this.
|
||||
Port uint16 `bencode:"p,omitempty"` // BEP 10
|
||||
IPv6 net.IP `bencode:"ipv6,omitempty"` // BEP 10
|
||||
IPv4 CompactIP `bencode:"ipv4,omitempty"` // BEP 10
|
||||
YourIP CompactIP `bencode:"yourip,omitempty"` // BEP 10
|
||||
|
||||
MetadataSize int `bencode:"metadata_size,omitempty"` // BEP 9
|
||||
}
|
||||
|
||||
// Decode decodes the extended handshake message from b.
|
||||
func (ehm *ExtendedHandshakeMsg) Decode(b []byte) (err error) {
|
||||
return bencode.DecodeBytes(b, ehm)
|
||||
}
|
||||
|
||||
// Encode encodes the extended handshake message to b.
|
||||
func (ehm ExtendedHandshakeMsg) Encode() (b []byte, err error) {
|
||||
buf := bytes.NewBuffer(make([]byte, 0, 128))
|
||||
if err = bencode.NewEncoder(buf).Encode(ehm); err != nil {
|
||||
b = buf.Bytes()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// UtMetadataExtendedMsg represents the "ut_metadata" extended message.
|
||||
type UtMetadataExtendedMsg struct {
|
||||
MsgType uint8 `bencode:"msg_type"` // BEP 9
|
||||
Piece int `bencode:"piece"` // BEP 9
|
||||
|
||||
// They are only used by "data" type
|
||||
TotalSize int `bencode:"total_size,omitempty"` // BEP 9
|
||||
Data []byte `bencode:"-"`
|
||||
}
|
||||
|
||||
// EncodeToPayload encodes UtMetadataExtendedMsg to extended payload
|
||||
// and write the result into buf.
|
||||
func (um UtMetadataExtendedMsg) EncodeToPayload(buf *bytes.Buffer) (err error) {
|
||||
if um.MsgType != UtMetadataExtendedMsgTypeData {
|
||||
um.TotalSize = 0
|
||||
um.Data = nil
|
||||
}
|
||||
|
||||
buf.Grow(len(um.Data) + 50)
|
||||
if err = bencode.NewEncoder(buf).Encode(um); err == nil {
|
||||
_, err = buf.Write(um.Data)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// EncodeToBytes is equal to
|
||||
//
|
||||
// buf := new(bytes.Buffer)
|
||||
// err = um.EncodeToPayload(buf)
|
||||
// return buf.Bytes(), err
|
||||
//
|
||||
func (um UtMetadataExtendedMsg) EncodeToBytes() (b []byte, err error) {
|
||||
buf := bytes.NewBuffer(make([]byte, 0, 128))
|
||||
if err = um.EncodeToPayload(buf); err == nil {
|
||||
b = buf.Bytes()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// DecodeFromPayload decodes the extended payload to itself.
|
||||
func (um *UtMetadataExtendedMsg) DecodeFromPayload(b []byte) (err error) {
|
||||
dec := bencode.NewDecoder(bytes.NewReader(b))
|
||||
if err = dec.Decode(&um); err == nil {
|
||||
um.Data = b[dec.BytesParsed():]
|
||||
}
|
||||
return
|
||||
}
|
54
peerprotocol/extension_test.go
Normal file
54
peerprotocol/extension_test.go
Normal file
@ -0,0 +1,54 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package peerprotocol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCompactIP(t *testing.T) {
|
||||
ipv4 := CompactIP([]byte{1, 2, 3, 4})
|
||||
b, err := ipv4.MarshalBencode()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var ip CompactIP
|
||||
if err = ip.UnmarshalBencode(b); err != nil {
|
||||
t.Error(err)
|
||||
} else if ip.String() != "1.2.3.4" {
|
||||
t.Error(ip)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUtMetadataExtendedMsg(t *testing.T) {
|
||||
buf := new(bytes.Buffer)
|
||||
data := []byte{0x31, 0x32, 0x33, 0x34, 0x35}
|
||||
m1 := UtMetadataExtendedMsg{MsgType: 1, Piece: 2, TotalSize: 1024, Data: data}
|
||||
if err := m1.EncodeToPayload(buf); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
msg := Message{Type: Extended, ExtendedPayload: buf.Bytes()}
|
||||
m2, err := msg.UtMetadataExtendedMsg()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
} else if m2.MsgType != 1 || m2.Piece != 2 || m2.TotalSize != 1024 {
|
||||
t.Error(m2)
|
||||
} else if !bytes.Equal(m2.Data, data) {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
69
peerprotocol/fastset.go
Normal file
69
peerprotocol/fastset.go
Normal file
@ -0,0 +1,69 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package peerprotocol
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"encoding/binary"
|
||||
"net"
|
||||
|
||||
"github.com/xgfone/bt/metainfo"
|
||||
)
|
||||
|
||||
// GenerateAllowedFastSet generates some allowed fast set of the torrent file.
|
||||
//
|
||||
// Argument:
|
||||
// set: generated piece set, the length of which is the number to be generated.
|
||||
// sz: the number of pieces in torrent.
|
||||
// ip: the of the remote peer of the connection.
|
||||
// infohash: infohash of torrent.
|
||||
//
|
||||
// BEP 6
|
||||
func GenerateAllowedFastSet(set []uint32, sz uint32, ip net.IP, infohash metainfo.Hash) {
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
ip = ipv4
|
||||
}
|
||||
|
||||
iplen := len(ip)
|
||||
x := make([]byte, 20+iplen)
|
||||
for i, j := 0, iplen-1; i < j; i++ { // (1) compatible with IPv4/IPv6
|
||||
x[i] = ip[i] & 0xff // (1)
|
||||
}
|
||||
// x[iplen-1] = 0 // It is equal to 0 primitively.
|
||||
copy(x[iplen:], infohash[:]) // (2)
|
||||
|
||||
for cur, k := 0, len(set); cur < k; {
|
||||
sum := sha1.Sum(x) // (3)
|
||||
x = sum[:] // (3)
|
||||
for i := 0; i < 5 && cur < k; i++ { // (4)
|
||||
j := i * 4 // (5)
|
||||
y := binary.BigEndian.Uint32(x[j : j+4]) // (6)
|
||||
index := y % sz // (7)
|
||||
if !uint32Contains(set, index) { // (8)
|
||||
set[cur] = index // (9)
|
||||
cur++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func uint32Contains(ss []uint32, s uint32) bool {
|
||||
for _, v := range ss {
|
||||
if v == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
109
peerprotocol/fastset_test.go
Normal file
109
peerprotocol/fastset_test.go
Normal file
@ -0,0 +1,109 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package peerprotocol
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/xgfone/bt/metainfo"
|
||||
)
|
||||
|
||||
func TestGenerateAllowedFastSet(t *testing.T) {
|
||||
hexs := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||
infohash := metainfo.NewHashFromHexString(hexs)
|
||||
|
||||
sets := make([]uint32, 7)
|
||||
GenerateAllowedFastSet(sets, 1313, net.ParseIP("80.4.4.200"), infohash)
|
||||
for i, v := range sets {
|
||||
switch i {
|
||||
case 0:
|
||||
if v != 1059 {
|
||||
t.Errorf("unknown '%d' at the index '%d'", v, i)
|
||||
}
|
||||
case 1:
|
||||
if v != 431 {
|
||||
t.Errorf("unknown '%d' at the index '%d'", v, i)
|
||||
}
|
||||
case 2:
|
||||
if v != 808 {
|
||||
t.Errorf("unknown '%d' at the index '%d'", v, i)
|
||||
}
|
||||
case 3:
|
||||
if v != 1217 {
|
||||
t.Errorf("unknown '%d' at the index '%d'", v, i)
|
||||
}
|
||||
case 4:
|
||||
if v != 287 {
|
||||
t.Errorf("unknown '%d' at the index '%d'", v, i)
|
||||
}
|
||||
case 5:
|
||||
if v != 376 {
|
||||
t.Errorf("unknown '%d' at the index '%d'", v, i)
|
||||
}
|
||||
case 6:
|
||||
if v != 1188 {
|
||||
t.Errorf("unknown '%d' at the index '%d'", v, i)
|
||||
}
|
||||
default:
|
||||
t.Errorf("unknown '%d' at the index '%d'", v, i)
|
||||
}
|
||||
}
|
||||
|
||||
sets = make([]uint32, 9)
|
||||
GenerateAllowedFastSet(sets, 1313, net.ParseIP("80.4.4.200"), infohash)
|
||||
for i, v := range sets {
|
||||
switch i {
|
||||
case 0:
|
||||
if v != 1059 {
|
||||
t.Errorf("unknown '%d' at the index '%d'", v, i)
|
||||
}
|
||||
case 1:
|
||||
if v != 431 {
|
||||
t.Errorf("unknown '%d' at the index '%d'", v, i)
|
||||
}
|
||||
case 2:
|
||||
if v != 808 {
|
||||
t.Errorf("unknown '%d' at the index '%d'", v, i)
|
||||
}
|
||||
case 3:
|
||||
if v != 1217 {
|
||||
t.Errorf("unknown '%d' at the index '%d'", v, i)
|
||||
}
|
||||
case 4:
|
||||
if v != 287 {
|
||||
t.Errorf("unknown '%d' at the index '%d'", v, i)
|
||||
}
|
||||
case 5:
|
||||
if v != 376 {
|
||||
t.Errorf("unknown '%d' at the index '%d'", v, i)
|
||||
}
|
||||
case 6:
|
||||
if v != 1188 {
|
||||
t.Errorf("unknown '%d' at the index '%d'", v, i)
|
||||
}
|
||||
case 7:
|
||||
if v != 353 {
|
||||
t.Errorf("unknown '%d' at the index '%d'", v, i)
|
||||
}
|
||||
case 8:
|
||||
if v != 508 {
|
||||
t.Errorf("unknown '%d' at the index '%d'", v, i)
|
||||
}
|
||||
default:
|
||||
t.Errorf("unknown '%d' at the index '%d'", v, i)
|
||||
}
|
||||
}
|
||||
}
|
140
peerprotocol/handshake.go
Normal file
140
peerprotocol/handshake.go
Normal file
@ -0,0 +1,140 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package peerprotocol
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/xgfone/bt/metainfo"
|
||||
)
|
||||
|
||||
var errInvalidProtocolHeader = fmt.Errorf("unexpected peer protocol header string")
|
||||
|
||||
// Predefine some known extension bits.
|
||||
const (
|
||||
ExtensionBitDHT = 0 // BEP 5
|
||||
ExtensionBitFast = 2 // BEP 6
|
||||
ExtensionBitExtended = 20 // BEP 10
|
||||
|
||||
ExtensionBitMax = 64
|
||||
)
|
||||
|
||||
// ExtensionBits is the reserved bytes to be used by all extensions.
|
||||
//
|
||||
// BEP 10: The bit is counted starting at 0 from right to left.
|
||||
type ExtensionBits [8]byte
|
||||
|
||||
// String returns the hex string format.
|
||||
func (eb ExtensionBits) String() string {
|
||||
return hex.EncodeToString(eb[:])
|
||||
}
|
||||
|
||||
// Set sets the bit to 1, that's, to set it to be on.
|
||||
func (eb ExtensionBits) Set(bit uint) {
|
||||
eb[7-bit/8] |= 1 << (bit % 8)
|
||||
}
|
||||
|
||||
// Unset sets the bit to 0, that's, to set it to be off.
|
||||
func (eb ExtensionBits) Unset(bit uint) {
|
||||
eb[7-bit/8] &^= 1 << (bit % 8)
|
||||
}
|
||||
|
||||
// IsSet reports whether the bit is on.
|
||||
func (eb ExtensionBits) IsSet(bit uint) (yes bool) {
|
||||
return eb[7-bit/8]&(1<<(bit%8)) != 0
|
||||
}
|
||||
|
||||
// IsSupportDHT reports whether ExtensionBitDHT is set.
|
||||
func (eb ExtensionBits) IsSupportDHT() (yes bool) {
|
||||
return eb.IsSet(ExtensionBitDHT)
|
||||
}
|
||||
|
||||
// IsSupportFast reports whether ExtensionBitFast is set.
|
||||
func (eb ExtensionBits) IsSupportFast() (yes bool) {
|
||||
return eb.IsSet(ExtensionBitFast)
|
||||
}
|
||||
|
||||
// IsSupportExtended reports whether ExtensionBitExtended is set.
|
||||
func (eb ExtensionBits) IsSupportExtended() (yes bool) {
|
||||
return eb.IsSet(ExtensionBitExtended)
|
||||
}
|
||||
|
||||
// HandshakeMsg is the message used by the handshake
|
||||
type HandshakeMsg struct {
|
||||
ExtensionBits
|
||||
|
||||
PeerID metainfo.Hash
|
||||
InfoHash metainfo.Hash
|
||||
}
|
||||
|
||||
// NewHandshakeMsg returns a new HandshakeMsg.
|
||||
func NewHandshakeMsg(peerID, infoHash metainfo.Hash, es ...ExtensionBits) HandshakeMsg {
|
||||
var e ExtensionBits
|
||||
if len(es) > 0 {
|
||||
e = es[0]
|
||||
}
|
||||
return HandshakeMsg{ExtensionBits: e, PeerID: peerID, InfoHash: infoHash}
|
||||
}
|
||||
|
||||
// Handshake finishes the handshake with the peer.
|
||||
//
|
||||
// InfoHash may be ZERO, and it will read it from the peer then send it back
|
||||
// to the peer.
|
||||
//
|
||||
// BEP 3
|
||||
func Handshake(sock io.ReadWriter, msg HandshakeMsg) (ret HandshakeMsg, err error) {
|
||||
var read bool
|
||||
if msg.InfoHash.IsZero() {
|
||||
if err = getPeerHandshakeMsg(sock, &ret); err != nil {
|
||||
return
|
||||
}
|
||||
read = true
|
||||
msg.InfoHash = ret.InfoHash
|
||||
}
|
||||
|
||||
if _, err = io.WriteString(sock, ProtocolHeader); err != nil {
|
||||
return
|
||||
}
|
||||
if _, err = sock.Write(msg.ExtensionBits[:]); err != nil {
|
||||
return
|
||||
}
|
||||
if _, err = sock.Write(msg.InfoHash[:]); err != nil {
|
||||
return
|
||||
}
|
||||
if _, err = sock.Write(msg.PeerID[:]); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !read {
|
||||
err = getPeerHandshakeMsg(sock, &ret)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func getPeerHandshakeMsg(sock io.Reader, ret *HandshakeMsg) (err error) {
|
||||
var b [68]byte
|
||||
if _, err = io.ReadFull(sock, b[:]); err != nil {
|
||||
return
|
||||
} else if string(b[:20]) != ProtocolHeader {
|
||||
return errInvalidProtocolHeader
|
||||
}
|
||||
|
||||
copy(ret.ExtensionBits[:], b[20:28])
|
||||
copy(ret.InfoHash[:], b[28:48])
|
||||
copy(ret.PeerID[:], b[48:68])
|
||||
return
|
||||
}
|
277
peerprotocol/message.go
Normal file
277
peerprotocol/message.go
Normal file
@ -0,0 +1,277 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package peerprotocol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
)
|
||||
|
||||
var errMessageTooLong = fmt.Errorf("the peer message is too long")
|
||||
|
||||
// Message is the message used by the peer protocol, which contains
|
||||
// all the fields specified by the standard message types.
|
||||
type Message struct {
|
||||
Keepalive bool
|
||||
Type MessageType
|
||||
|
||||
// Index is used by these message types:
|
||||
//
|
||||
// BEP 3: Cancel, Request, Have, Piece
|
||||
// BEP 6: Reject, Suggest, AllowedFast
|
||||
//
|
||||
Index uint32
|
||||
|
||||
// Begin is used by these message types:
|
||||
//
|
||||
// BEP 3: Request, Cancel, Piece
|
||||
// BEP 6: Reject
|
||||
//
|
||||
Begin uint32
|
||||
|
||||
// Length is used by these message types:
|
||||
//
|
||||
// BEP 3: Request, Cancel
|
||||
// BEP 6: Reject
|
||||
//
|
||||
Length uint32
|
||||
|
||||
// Piece is used by these message types:
|
||||
//
|
||||
// BEP 3: Piece
|
||||
Piece []byte
|
||||
|
||||
// Bitfield is used by these message types:
|
||||
//
|
||||
// BEP 3: Bitfield
|
||||
Bitfield []bool
|
||||
|
||||
// ExtendedID and ExtendedPayload are used by these message types:
|
||||
//
|
||||
// BEP 10: Extended
|
||||
//
|
||||
ExtendedID uint8
|
||||
ExtendedPayload []byte
|
||||
|
||||
// Port is used by these message types:
|
||||
//
|
||||
// BEP 5: Port
|
||||
//
|
||||
Port uint16
|
||||
}
|
||||
|
||||
// DecodeToMessage is equal to msg.Decode(r, maxLength).
|
||||
func DecodeToMessage(r io.Reader, maxLength uint32) (msg Message, err error) {
|
||||
err = msg.Decode(r, maxLength)
|
||||
return
|
||||
}
|
||||
|
||||
// UtMetadataExtendedMsg decodes the extended payload as UtMetadataExtendedMsg.
|
||||
//
|
||||
// Notice: the message type must be Extended.
|
||||
func (m Message) UtMetadataExtendedMsg() (um UtMetadataExtendedMsg, err error) {
|
||||
if m.Type != Extended {
|
||||
panic("the message type is Extended")
|
||||
}
|
||||
err = um.DecodeFromPayload(m.ExtendedPayload)
|
||||
return
|
||||
}
|
||||
|
||||
// UnmarshalBinary implements the interface encoding.BinaryUnmarshaler,
|
||||
// which is equal to m.Decode(bytes.NewBuffer(data), 0).
|
||||
func (m *Message) UnmarshalBinary(data []byte) (err error) {
|
||||
return m.Decode(bytes.NewBuffer(data), 0)
|
||||
}
|
||||
|
||||
func readByte(r io.Reader) (b byte, err error) {
|
||||
var bs [1]byte
|
||||
if _, err = r.Read(bs[:]); err == nil {
|
||||
b = bs[0]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Decode reads the data from r and decodes it to Message.
|
||||
//
|
||||
// if maxLength is equal to 0, it is unlimited. Or, it will read maxLength bytes
|
||||
// at most.
|
||||
func (m *Message) Decode(r io.Reader, maxLength uint32) (err error) {
|
||||
var length uint32
|
||||
if err = binary.Read(r, binary.BigEndian, &length); err != nil {
|
||||
if err != io.EOF {
|
||||
err = fmt.Errorf("error reading peer message message length: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if length == 0 {
|
||||
m.Keepalive = true
|
||||
return
|
||||
} else if maxLength > 0 && length > maxLength {
|
||||
return errMessageTooLong
|
||||
}
|
||||
|
||||
m.Keepalive = false
|
||||
lr := &io.LimitedReader{R: r, N: int64(length)}
|
||||
|
||||
// Check that all of r was utilized.
|
||||
defer func() {
|
||||
if err == nil && lr.N != 0 {
|
||||
err = fmt.Errorf("%d bytes unused in message type %d", lr.N, m.Type)
|
||||
}
|
||||
}()
|
||||
|
||||
_type, err := readByte(lr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch m.Type = MessageType(_type); m.Type {
|
||||
case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
|
||||
case Have, AllowedFast, Suggest:
|
||||
err = binary.Read(lr, binary.BigEndian, &m.Index)
|
||||
case Request, Cancel, Reject:
|
||||
if err = binary.Read(lr, binary.BigEndian, &m.Index); err != nil {
|
||||
return
|
||||
}
|
||||
if err = binary.Read(lr, binary.BigEndian, &m.Begin); err != nil {
|
||||
return
|
||||
}
|
||||
if err = binary.Read(lr, binary.BigEndian, &m.Length); err != nil {
|
||||
return
|
||||
}
|
||||
case Bitfield:
|
||||
_len := length - 1
|
||||
bs := make([]byte, _len)
|
||||
if _, err = io.ReadFull(lr, bs); err == nil {
|
||||
m.Bitfield = make([]bool, 0, _len*8)
|
||||
for _, b := range bs {
|
||||
for i := byte(7); i >= 0; i-- {
|
||||
m.Bitfield = append(m.Bitfield, (b>>i)&1 == 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
case Piece:
|
||||
if err = binary.Read(lr, binary.BigEndian, &m.Index); err != nil {
|
||||
return
|
||||
}
|
||||
if err = binary.Read(lr, binary.BigEndian, &m.Begin); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: Should we use a []byte pool?
|
||||
m.Piece = make([]byte, lr.N)
|
||||
if _, err = io.ReadFull(lr, m.Piece); err != nil {
|
||||
return fmt.Errorf("reading piece data error: %s", err)
|
||||
}
|
||||
case Extended:
|
||||
if m.ExtendedID, err = readByte(lr); err == nil {
|
||||
m.ExtendedPayload, err = ioutil.ReadAll(lr)
|
||||
}
|
||||
case Port:
|
||||
err = binary.Read(lr, binary.BigEndian, &m.Port)
|
||||
default:
|
||||
err = fmt.Errorf("unknown message type %v", m.Type)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalBinary implements the interface encoding.BinaryMarshaler.
|
||||
func (m Message) MarshalBinary() (data []byte, err error) {
|
||||
// TODO: Should we use a buffer pool?
|
||||
buf := bytes.NewBuffer(make([]byte, 0, 4))
|
||||
if err = m.Encode(buf); err == nil {
|
||||
data = buf.Bytes()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Encode encodes the message to buf.
|
||||
func (m Message) Encode(buf *bytes.Buffer) (err error) {
|
||||
// The 4-bytes is the placeholder of the length.
|
||||
buf.Reset()
|
||||
buf.Write([]byte{0, 0, 0, 0})
|
||||
|
||||
// Write the non-keepalive message.
|
||||
if !m.Keepalive {
|
||||
if err = buf.WriteByte(byte(m.Type)); err != nil {
|
||||
return
|
||||
} else if err = m.marshalBinaryType(buf); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Calculate and reset the length of the message body.
|
||||
data := buf.Bytes()
|
||||
if payloadLen := len(data) - 4; payloadLen > 0 {
|
||||
binary.BigEndian.PutUint32(data[:4], uint32(payloadLen))
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (m Message) marshalBinaryType(buf *bytes.Buffer) (err error) {
|
||||
switch m.Type {
|
||||
case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
|
||||
case Have:
|
||||
err = binary.Write(buf, binary.BigEndian, m.Index)
|
||||
case Request, Cancel, Reject:
|
||||
if err = binary.Write(buf, binary.BigEndian, m.Index); err != nil {
|
||||
return
|
||||
}
|
||||
if err = binary.Write(buf, binary.BigEndian, m.Begin); err != nil {
|
||||
return
|
||||
}
|
||||
if err = binary.Write(buf, binary.BigEndian, m.Length); err != nil {
|
||||
return
|
||||
}
|
||||
case Bitfield:
|
||||
_len := (len(m.Bitfield) + 7) / 8
|
||||
bs := make([]byte, _len)
|
||||
for i, has := range m.Bitfield {
|
||||
if has {
|
||||
bs[i/8] |= (1 << byte(7-i%8))
|
||||
}
|
||||
}
|
||||
buf.Write(bs)
|
||||
case Piece:
|
||||
if err = binary.Write(buf, binary.BigEndian, m.Index); err != nil {
|
||||
return
|
||||
}
|
||||
if err = binary.Write(buf, binary.BigEndian, m.Begin); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if n, err := buf.Write(m.Piece); err != nil {
|
||||
return err
|
||||
} else if _len := len(m.Piece); n != _len {
|
||||
return fmt.Errorf("expect writing %d bytes, but wrote %d", _len, n)
|
||||
}
|
||||
case Extended:
|
||||
if err = buf.WriteByte(byte(m.ExtendedID)); err != nil {
|
||||
_, err = buf.Write(m.ExtendedPayload)
|
||||
}
|
||||
case Port:
|
||||
err = binary.Write(buf, binary.BigEndian, m.Port)
|
||||
default:
|
||||
err = fmt.Errorf("unknown message type: %v", m.Type)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
522
peerprotocol/peerconn.go
Normal file
522
peerprotocol/peerconn.go
Normal file
@ -0,0 +1,522 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package peerprotocol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/xgfone/bt/bencode"
|
||||
"github.com/xgfone/bt/metainfo"
|
||||
)
|
||||
|
||||
// Predefine some errors about extension support.
|
||||
var (
|
||||
ErrNotSupportDHT = fmt.Errorf("not support DHT extension")
|
||||
ErrNotSupportFast = fmt.Errorf("not support Fast extension")
|
||||
ErrNotSupportExtended = fmt.Errorf("not support Extended extension")
|
||||
)
|
||||
|
||||
// Bep3Handler is used to handle the BEP 3 type message if Handler has also
|
||||
// implemented the interface.
|
||||
type Bep3Handler interface {
|
||||
Choke(pc *PeerConn) error
|
||||
Unchoke(pc *PeerConn) error
|
||||
Interested(pc *PeerConn) error
|
||||
NotInterested(pc *PeerConn) error
|
||||
Have(pc *PeerConn, index uint32) error
|
||||
Bitfield(pc *PeerConn, bits []bool) error
|
||||
Request(pc *PeerConn, index uint32, begin uint32, length uint32) error
|
||||
Piece(pc *PeerConn, index uint32, begin uint32, piece []byte) error
|
||||
Cancel(pc *PeerConn, index uint32, begin uint32, length uint32) error
|
||||
}
|
||||
|
||||
// Bep5Handler is used to handle the BEP 5 type message if Handler has also
|
||||
// implemented the interface.
|
||||
//
|
||||
// Notice: the server must enable the DHT extension bit.
|
||||
type Bep5Handler interface {
|
||||
Port(pc *PeerConn, port uint16) error
|
||||
}
|
||||
|
||||
// Bep6Handler is used to handle the BEP 6 type message if Handler has also
|
||||
// implemented the interface.
|
||||
//
|
||||
// Notice: the server must enable the Fast extension bit.
|
||||
type Bep6Handler interface {
|
||||
HaveAll(pc *PeerConn) error
|
||||
HaveNone(pc *PeerConn) error
|
||||
Suggest(pc *PeerConn, index uint32) error
|
||||
AllowedFast(pc *PeerConn, index uint32) error
|
||||
Reject(pc *PeerConn, index uint32, begin uint32, length uint32) error
|
||||
}
|
||||
|
||||
// Bep10Handler is used to handle the BEP 10 extended peer message
|
||||
// if Handler has also implemented the interface.
|
||||
//
|
||||
// Notice: the server must enable the Extended extension bit.
|
||||
type Bep10Handler interface {
|
||||
OnHandShake(conn *PeerConn, exthmsg ExtendedHandshakeMsg) error
|
||||
OnPayload(conn *PeerConn, extid uint8, payload []byte) error
|
||||
}
|
||||
|
||||
// PeerConn is used to manage the connection to the peer.
|
||||
type PeerConn struct {
|
||||
net.Conn
|
||||
ExtensionBits
|
||||
ID metainfo.Hash
|
||||
|
||||
// Timeout is used to control the timeout of reading/writing the message.
|
||||
//
|
||||
// The default is 0, which represents no timeout.
|
||||
Timeout time.Duration
|
||||
|
||||
// MaxLength is used to limit the maximum number of the message body.
|
||||
//
|
||||
// The default is 0, which represents no limit.
|
||||
MaxLength uint32
|
||||
|
||||
// These two states is controlled by the local client peer.
|
||||
//
|
||||
// Choked is used to indicate whether or not the local client has choked
|
||||
// the remote peer, that's, if Choked is true, the local client will
|
||||
// discard all the pending requests from the remote peer and not answer
|
||||
// any requests until the local client is unchoked.
|
||||
//
|
||||
// Interested is used to indicate whether or not the local client is
|
||||
// interested in something the remote peer has to offer, that's,
|
||||
// if Interested is true, the local client will begin requesting blocks
|
||||
// when the remote client unchokes them.
|
||||
//
|
||||
Choked bool // The default should be true.
|
||||
Interested bool
|
||||
|
||||
// These two states is controlled by the remote peer.
|
||||
//
|
||||
// PeerChoked is used to indicate whether or not the remote peer has choked
|
||||
// the local client, that's, if PeerChoked is true, the remote peer will
|
||||
// discard all the pending requests from the local client and not answer
|
||||
// any requests until the remote peer is unchoked.
|
||||
//
|
||||
// PeerInterested is used to indicate whether or not the remote peer is
|
||||
// interested in something the local client has to offer, that's,
|
||||
// if PeerInterested is true, the remote peer will begin requesting blocks
|
||||
// when the local client unchokes them.
|
||||
//
|
||||
PeerChoked bool // The default should be true.
|
||||
PeerInterested bool
|
||||
|
||||
// Data is used to store the context data associated with the connection.
|
||||
Data interface{}
|
||||
}
|
||||
|
||||
// NewPeerConn returns a new PeerConn.
|
||||
func NewPeerConn(id metainfo.Hash, conn net.Conn) *PeerConn {
|
||||
return &PeerConn{ID: id, Conn: conn, Choked: true, PeerChoked: true}
|
||||
}
|
||||
|
||||
// NewPeerConnByDial returns a new PeerConn by dialing to addr with the "tcp" network.
|
||||
func NewPeerConnByDial(id metainfo.Hash, addr string) (pc *PeerConn, err error) {
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
if err == nil {
|
||||
pc = NewPeerConn(id, conn)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (pc *PeerConn) setReadTimeout() {
|
||||
if pc.Timeout > 0 {
|
||||
pc.Conn.SetReadDeadline(time.Now().Add(pc.Timeout))
|
||||
}
|
||||
}
|
||||
|
||||
func (pc *PeerConn) setWriteTimeout() {
|
||||
if pc.Timeout > 0 {
|
||||
pc.Conn.SetWriteDeadline(time.Now().Add(pc.Timeout))
|
||||
}
|
||||
}
|
||||
|
||||
// SetChoked sets the Choked state of the local client peer.
|
||||
//
|
||||
// Notice: if the current state is not Choked, it will send a Choked message
|
||||
// to the remote peer.
|
||||
func (pc *PeerConn) SetChoked() (err error) {
|
||||
if !pc.Choked {
|
||||
if err = pc.SendChoke(); err == nil {
|
||||
pc.Choked = true
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// SetUnchoked sets the Unchoked state of the local client peer.
|
||||
//
|
||||
// Notice: if the current state is not Unchoked, it will send a Unchoked message
|
||||
// to the remote peer.
|
||||
func (pc *PeerConn) SetUnchoked() (err error) {
|
||||
if pc.Choked {
|
||||
if err = pc.SendUnchoke(); err == nil {
|
||||
pc.Choked = false
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// SetInterested sets the Interested state of the local client peer.
|
||||
//
|
||||
// Notice: if the current state is not Interested, it will send a Interested
|
||||
// message to the remote peer.
|
||||
func (pc *PeerConn) SetInterested() (err error) {
|
||||
if !pc.Interested {
|
||||
if err = pc.SendInterested(); err == nil {
|
||||
pc.Interested = true
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// SetNotInterested sets the NotInterested state of the local client peer.
|
||||
//
|
||||
// Notice: if the current state is not NotInterested, it will send
|
||||
// a NotInterested message to the remote peer.
|
||||
func (pc *PeerConn) SetNotInterested() (err error) {
|
||||
if pc.Interested {
|
||||
if err = pc.SendNotInterested(); err == nil {
|
||||
pc.Interested = false
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Handshake does a handshake with the peer.
|
||||
//
|
||||
// BEP 3
|
||||
func (pc *PeerConn) Handshake(infoHash metainfo.Hash) (HandshakeMsg, error) {
|
||||
m := HandshakeMsg{ExtensionBits: pc.ExtensionBits, PeerID: pc.ID, InfoHash: infoHash}
|
||||
pc.setReadTimeout()
|
||||
return Handshake(pc.Conn, m)
|
||||
}
|
||||
|
||||
// ReadMsg reads the message.
|
||||
//
|
||||
// BEP 3
|
||||
func (pc *PeerConn) ReadMsg() (m Message, err error) {
|
||||
pc.setReadTimeout()
|
||||
err = m.Decode(pc.Conn, pc.MaxLength)
|
||||
return
|
||||
}
|
||||
|
||||
// WriteMsg writes the message to the peer.
|
||||
//
|
||||
// BEP 3
|
||||
func (pc *PeerConn) WriteMsg(m Message) (err error) {
|
||||
buf := bytes.NewBuffer(make([]byte, 0, 128))
|
||||
if err = m.Encode(buf); err == nil {
|
||||
pc.setWriteTimeout()
|
||||
|
||||
var n int
|
||||
if n, err = pc.Conn.Write(buf.Bytes()); err == nil && n < buf.Len() {
|
||||
err = io.ErrShortWrite
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// SendKeepalive sends a Keepalive message to the peer.
|
||||
//
|
||||
// BEP 3
|
||||
func (pc *PeerConn) SendKeepalive() error {
|
||||
return pc.WriteMsg(Message{Keepalive: true})
|
||||
}
|
||||
|
||||
// SendChoke sends a Choke message to the peer.
|
||||
//
|
||||
// BEP 3
|
||||
func (pc *PeerConn) SendChoke() error {
|
||||
return pc.WriteMsg(Message{Type: Choke})
|
||||
}
|
||||
|
||||
// SendUnchoke sends a Unchoke message to the peer.
|
||||
//
|
||||
// BEP 3
|
||||
func (pc *PeerConn) SendUnchoke() error {
|
||||
return pc.WriteMsg(Message{Type: Unchoke})
|
||||
}
|
||||
|
||||
// SendInterested sends a Interested message to the peer.
|
||||
//
|
||||
// BEP 3
|
||||
func (pc *PeerConn) SendInterested() error {
|
||||
return pc.WriteMsg(Message{Type: Interested})
|
||||
}
|
||||
|
||||
// SendNotInterested sends a NotInterested message to the peer.
|
||||
//
|
||||
// BEP 3
|
||||
func (pc *PeerConn) SendNotInterested() error {
|
||||
return pc.WriteMsg(Message{Type: NotInterested})
|
||||
}
|
||||
|
||||
// SendBitfield sends a Bitfield message to the peer.
|
||||
//
|
||||
// BEP 3
|
||||
func (pc *PeerConn) SendBitfield(bits []bool) error {
|
||||
return pc.WriteMsg(Message{Type: Bitfield, Bitfield: bits})
|
||||
}
|
||||
|
||||
// SendHave sends a Have message to the peer.
|
||||
//
|
||||
// BEP 3
|
||||
func (pc *PeerConn) SendHave(index uint32) error {
|
||||
return pc.WriteMsg(Message{Type: Have, Index: index})
|
||||
}
|
||||
|
||||
// SendRequest sends a Request message to the peer.
|
||||
//
|
||||
// BEP 3
|
||||
func (pc *PeerConn) SendRequest(index, begin, length uint32) error {
|
||||
return pc.WriteMsg(Message{Type: Request, Index: index, Begin: begin, Length: length})
|
||||
}
|
||||
|
||||
// SendCancel sends a Cancel message to the peer.
|
||||
//
|
||||
// BEP 3
|
||||
func (pc *PeerConn) SendCancel(index, begin, length uint32) error {
|
||||
return pc.WriteMsg(Message{Type: Cancel, Index: index, Begin: begin, Length: length})
|
||||
}
|
||||
|
||||
// SendPiece sends a Piece message to the peer.
|
||||
//
|
||||
// BEP 3
|
||||
func (pc *PeerConn) SendPiece(index, begin uint32, piece []byte) error {
|
||||
return pc.WriteMsg(Message{Type: Piece, Index: index, Begin: begin, Piece: piece})
|
||||
}
|
||||
|
||||
// SendPort sends a Port message to the peer.
|
||||
//
|
||||
// BEP 5
|
||||
func (pc *PeerConn) SendPort(port uint16) error {
|
||||
return pc.WriteMsg(Message{Type: Port, Port: port})
|
||||
}
|
||||
|
||||
// SendHaveAll sends a HaveAll message to the peer.
|
||||
//
|
||||
// BEP 6
|
||||
func (pc *PeerConn) SendHaveAll() error {
|
||||
return pc.WriteMsg(Message{Type: HaveAll})
|
||||
}
|
||||
|
||||
// SendHaveNone sends a HaveNone message to the peer.
|
||||
//
|
||||
// BEP 6
|
||||
func (pc *PeerConn) SendHaveNone() error {
|
||||
return pc.WriteMsg(Message{Type: HaveNone})
|
||||
}
|
||||
|
||||
// SendSuggest sends a Suggest message to the peer.
|
||||
//
|
||||
// BEP 6
|
||||
func (pc *PeerConn) SendSuggest(index uint32) error {
|
||||
return pc.WriteMsg(Message{Type: Suggest, Index: index})
|
||||
}
|
||||
|
||||
// SendReject sends a Reject message to the peer.
|
||||
//
|
||||
// BEP 6
|
||||
func (pc *PeerConn) SendReject(index, begin, length uint32) error {
|
||||
return pc.WriteMsg(Message{Type: Reject, Index: index, Begin: begin, Length: length})
|
||||
}
|
||||
|
||||
// SendAllowedFast sends a AllowedFast message to the peer.
|
||||
//
|
||||
// BEP 6
|
||||
func (pc *PeerConn) SendAllowedFast(index uint32) error {
|
||||
return pc.WriteMsg(Message{Type: AllowedFast, Index: index})
|
||||
}
|
||||
|
||||
// SendExtHandshakeMsg sends the Extended Handshake message to the peer.
|
||||
//
|
||||
// BEP 10
|
||||
func (pc *PeerConn) SendExtHandshakeMsg(m ExtendedHandshakeMsg) (err error) {
|
||||
buf := bytes.NewBuffer(make([]byte, 0, 128))
|
||||
if err = bencode.NewEncoder(buf).Encode(m); err == nil {
|
||||
err = pc.SendExtMsg(ExtendedIDHandshake, buf.Bytes())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// SendExtMsg sends the Extended message with the extended id and the payload.
|
||||
//
|
||||
// BEP 10
|
||||
func (pc *PeerConn) SendExtMsg(extID uint8, payload []byte) error {
|
||||
return pc.WriteMsg(Message{Type: Extended, ExtendedID: extID, ExtendedPayload: payload})
|
||||
}
|
||||
|
||||
// HandleMessage calls the method of the handler to handle the message.
|
||||
//
|
||||
// If handler has also implemented the interfaces Bep3Handler, Bep5Handler,
|
||||
// Bep6Handler or Bep10Handler, their methods will be called instead of
|
||||
// Handler.OnMessage for the corresponding type message.
|
||||
func (pc *PeerConn) HandleMessage(msg Message, handler Handler) (err error) {
|
||||
if msg.Keepalive {
|
||||
return
|
||||
}
|
||||
|
||||
switch msg.Type {
|
||||
// BEP 3 - The BitTorrent Protocol Specification
|
||||
case Choke:
|
||||
pc.PeerChoked = true
|
||||
if h, ok := handler.(Bep3Handler); ok {
|
||||
err = h.Choke(pc)
|
||||
} else {
|
||||
err = handler.OnMessage(pc, msg)
|
||||
}
|
||||
case Unchoke:
|
||||
pc.PeerChoked = false
|
||||
if h, ok := handler.(Bep3Handler); ok {
|
||||
err = h.Unchoke(pc)
|
||||
} else {
|
||||
err = handler.OnMessage(pc, msg)
|
||||
}
|
||||
case Interested:
|
||||
pc.PeerInterested = true
|
||||
if h, ok := handler.(Bep3Handler); ok {
|
||||
err = h.Interested(pc)
|
||||
} else {
|
||||
err = handler.OnMessage(pc, msg)
|
||||
}
|
||||
case NotInterested:
|
||||
pc.PeerInterested = false
|
||||
if h, ok := handler.(Bep3Handler); ok {
|
||||
err = h.NotInterested(pc)
|
||||
} else {
|
||||
err = handler.OnMessage(pc, msg)
|
||||
}
|
||||
case Have:
|
||||
if h, ok := handler.(Bep3Handler); ok {
|
||||
err = h.Have(pc, msg.Index)
|
||||
} else {
|
||||
err = handler.OnMessage(pc, msg)
|
||||
}
|
||||
case Bitfield:
|
||||
if h, ok := handler.(Bep3Handler); ok {
|
||||
err = h.Bitfield(pc, msg.Bitfield)
|
||||
} else {
|
||||
err = handler.OnMessage(pc, msg)
|
||||
}
|
||||
case Request:
|
||||
if h, ok := handler.(Bep3Handler); ok {
|
||||
err = h.Request(pc, msg.Index, msg.Begin, msg.Length)
|
||||
} else {
|
||||
err = handler.OnMessage(pc, msg)
|
||||
}
|
||||
case Piece:
|
||||
if h, ok := handler.(Bep3Handler); ok {
|
||||
err = h.Piece(pc, msg.Index, msg.Begin, msg.Piece)
|
||||
} else {
|
||||
err = handler.OnMessage(pc, msg)
|
||||
}
|
||||
case Cancel:
|
||||
if h, ok := handler.(Bep3Handler); ok {
|
||||
err = h.Cancel(pc, msg.Index, msg.Begin, msg.Length)
|
||||
} else {
|
||||
err = handler.OnMessage(pc, msg)
|
||||
}
|
||||
|
||||
// BEP 5 - DHT Protocol
|
||||
case Port:
|
||||
if !pc.IsSupportDHT() {
|
||||
return ErrNotSupportDHT
|
||||
} else if h, ok := handler.(Bep5Handler); ok {
|
||||
err = h.Port(pc, msg.Port)
|
||||
} else {
|
||||
err = handler.OnMessage(pc, msg)
|
||||
}
|
||||
|
||||
// BEP 6 - Fast Extension
|
||||
case Suggest:
|
||||
if !pc.IsSupportFast() {
|
||||
return ErrNotSupportFast
|
||||
} else if h, ok := handler.(Bep6Handler); ok {
|
||||
err = h.Suggest(pc, msg.Index)
|
||||
} else {
|
||||
err = handler.OnMessage(pc, msg)
|
||||
}
|
||||
case HaveAll:
|
||||
if !pc.IsSupportFast() {
|
||||
return ErrNotSupportFast
|
||||
} else if h, ok := handler.(Bep6Handler); ok {
|
||||
err = h.HaveAll(pc)
|
||||
} else {
|
||||
err = handler.OnMessage(pc, msg)
|
||||
}
|
||||
case HaveNone:
|
||||
if !pc.IsSupportFast() {
|
||||
return ErrNotSupportFast
|
||||
} else if h, ok := handler.(Bep6Handler); ok {
|
||||
err = h.HaveNone(pc)
|
||||
} else {
|
||||
err = handler.OnMessage(pc, msg)
|
||||
}
|
||||
case Reject:
|
||||
if !pc.IsSupportFast() {
|
||||
return ErrNotSupportFast
|
||||
} else if h, ok := handler.(Bep6Handler); ok {
|
||||
err = h.Reject(pc, msg.Index, msg.Begin, msg.Length)
|
||||
} else {
|
||||
err = handler.OnMessage(pc, msg)
|
||||
}
|
||||
case AllowedFast:
|
||||
if !pc.IsSupportFast() {
|
||||
return ErrNotSupportFast
|
||||
} else if h, ok := handler.(Bep6Handler); ok {
|
||||
err = h.AllowedFast(pc, msg.Index)
|
||||
} else {
|
||||
err = handler.OnMessage(pc, msg)
|
||||
}
|
||||
|
||||
// BEP 10 - Extension Protocol
|
||||
case Extended:
|
||||
if !pc.IsSupportExtended() {
|
||||
return ErrNotSupportExtended
|
||||
} else if h, ok := handler.(Bep10Handler); ok {
|
||||
err = pc.handleExtMsg(h, msg)
|
||||
} else {
|
||||
err = handler.OnMessage(pc, msg)
|
||||
}
|
||||
|
||||
// Other
|
||||
default:
|
||||
err = handler.OnMessage(pc, msg)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (pc *PeerConn) handleExtMsg(h Bep10Handler, m Message) (err error) {
|
||||
if m.ExtendedID == ExtendedIDHandshake {
|
||||
var ehmsg ExtendedHandshakeMsg
|
||||
if err = bencode.DecodeBytes(m.ExtendedPayload, &ehmsg); err == nil {
|
||||
err = h.OnHandShake(pc, ehmsg)
|
||||
}
|
||||
} else {
|
||||
err = h.OnPayload(pc, m.ExtendedID, m.ExtendedPayload)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
97
peerprotocol/protocol.go
Normal file
97
peerprotocol/protocol.go
Normal file
@ -0,0 +1,97 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package peerprotocol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ProtocolHeader is the BT protocal prefix.
|
||||
//
|
||||
// BEP 3
|
||||
const ProtocolHeader = "\x13BitTorrent protocol"
|
||||
|
||||
// Predefine some message types.
|
||||
const (
|
||||
// BEP 3
|
||||
Choke MessageType = 0
|
||||
Unchoke MessageType = 1
|
||||
Interested MessageType = 2
|
||||
NotInterested MessageType = 3
|
||||
Have MessageType = 4
|
||||
Bitfield MessageType = 5
|
||||
Request MessageType = 6
|
||||
Piece MessageType = 7
|
||||
Cancel MessageType = 8
|
||||
|
||||
// BEP 5
|
||||
Port MessageType = 9
|
||||
|
||||
// BEP 6 - Fast extension
|
||||
Suggest MessageType = 0x0d // 13
|
||||
HaveAll MessageType = 0x0e // 14
|
||||
HaveNone MessageType = 0x0f // 15
|
||||
Reject MessageType = 0x10 // 16
|
||||
AllowedFast MessageType = 0x11 // 17
|
||||
|
||||
// BEP 10
|
||||
Extended MessageType = 20
|
||||
)
|
||||
|
||||
// MessageType is used to represent the message type.
|
||||
type MessageType byte
|
||||
|
||||
func (mt MessageType) String() string {
|
||||
switch mt {
|
||||
case Choke:
|
||||
return "Choke"
|
||||
case Unchoke:
|
||||
return "Unchoke"
|
||||
case Interested:
|
||||
return "Interested"
|
||||
case NotInterested:
|
||||
return "NotInterested"
|
||||
case Have:
|
||||
return "Have"
|
||||
case Bitfield:
|
||||
return "Bitfield"
|
||||
case Request:
|
||||
return "Request"
|
||||
case Piece:
|
||||
return "Piece"
|
||||
case Cancel:
|
||||
return "Cancel"
|
||||
case Port:
|
||||
return "Port"
|
||||
case Suggest:
|
||||
return "Suggest"
|
||||
case HaveAll:
|
||||
return "HaveAll"
|
||||
case HaveNone:
|
||||
return "HaveNone"
|
||||
case Reject:
|
||||
return "Reject"
|
||||
case AllowedFast:
|
||||
return "AllowedFast"
|
||||
case Extended:
|
||||
return "Extended"
|
||||
}
|
||||
return fmt.Sprintf("MessageType(%d)", mt)
|
||||
}
|
||||
|
||||
// FastExtension reports whether the message type is fast extension.
|
||||
func (mt MessageType) FastExtension() bool {
|
||||
return mt >= Suggest && mt <= AllowedFast
|
||||
}
|
169
peerprotocol/server.go
Normal file
169
peerprotocol/server.go
Normal file
@ -0,0 +1,169 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package peerprotocol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/xgfone/bt/metainfo"
|
||||
)
|
||||
|
||||
// Handler is used to handle the incoming peer connection.
|
||||
type Handler interface {
|
||||
// OnHandShake is used to check whether the handshake extension is acceptable.
|
||||
OnHandShake(conn *PeerConn, hmsg HandshakeMsg) error
|
||||
|
||||
// OnMessage is used to handle the incoming peer message.
|
||||
//
|
||||
// If requires, it should write the response to the peer.
|
||||
OnMessage(conn *PeerConn, msg Message) error
|
||||
}
|
||||
|
||||
// Config is used to configure the server.
|
||||
type Config struct {
|
||||
// ExtBits is used to handshake with the client.
|
||||
ExtBits ExtensionBits
|
||||
|
||||
// MaxLength is used to limit the maximum number of the message body.
|
||||
//
|
||||
// The default is 0, which represents no limit.
|
||||
MaxLength uint32
|
||||
|
||||
// Timeout is used to control the timeout of the read/write the message.
|
||||
//
|
||||
// The default is 0, which represents no timeout.
|
||||
Timeout time.Duration
|
||||
|
||||
// ErrorLog is used to log the error.
|
||||
ErrorLog func(format string, args ...interface{}) // Default: log.Printf
|
||||
|
||||
// HandleMessage is used to handle the incoming message. So you can
|
||||
// customize it to add the request queue.
|
||||
//
|
||||
// The default handler is to forward to pc.HandleMessage(msg, handler).
|
||||
HandleMessage func(pc *PeerConn, msg Message, handler Handler) error
|
||||
}
|
||||
|
||||
func (c *Config) set(conf ...Config) {
|
||||
if len(conf) > 0 {
|
||||
*c = conf[0]
|
||||
}
|
||||
|
||||
if c.ErrorLog == nil {
|
||||
c.ErrorLog = log.Printf
|
||||
}
|
||||
if c.HandleMessage == nil {
|
||||
c.HandleMessage = func(pc *PeerConn, m Message, h Handler) error {
|
||||
return pc.HandleMessage(m, h)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Server is used to implement the peer protocol server.
|
||||
type Server struct {
|
||||
ln net.Listener
|
||||
id metainfo.Hash
|
||||
h Handler
|
||||
c Config
|
||||
}
|
||||
|
||||
// NewServerByListen returns a new Server by listening on the address.
|
||||
func NewServerByListen(network, address string, id metainfo.Hash, h Handler,
|
||||
c ...Config) (*Server, error) {
|
||||
ln, err := net.Listen(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewServer(ln, id, h, c...), nil
|
||||
}
|
||||
|
||||
// NewServer returns a new Server.
|
||||
func NewServer(ln net.Listener, id metainfo.Hash, h Handler, c ...Config) *Server {
|
||||
if id.IsZero() {
|
||||
panic("the peer node id must not be empty")
|
||||
}
|
||||
|
||||
var conf Config
|
||||
conf.set(c...)
|
||||
return &Server{ln: ln, id: id, h: h, c: conf}
|
||||
}
|
||||
|
||||
// Run starts the peer protocol server.
|
||||
func (s *Server) Run() {
|
||||
for {
|
||||
conn, err := s.ln.Accept()
|
||||
if err != nil {
|
||||
s.c.ErrorLog("fail to accept new connection: %s", err)
|
||||
}
|
||||
go s.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleConn(conn net.Conn) {
|
||||
pc := &PeerConn{
|
||||
ID: s.id,
|
||||
Conn: conn,
|
||||
ExtensionBits: s.c.ExtBits,
|
||||
Timeout: s.c.Timeout,
|
||||
MaxLength: s.c.MaxLength,
|
||||
Choked: true,
|
||||
PeerChoked: true,
|
||||
}
|
||||
|
||||
if err := s.handlePeerMessage(pc); err != nil {
|
||||
s.c.ErrorLog(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handlePeerMessage(pc *PeerConn) (err error) {
|
||||
defer pc.Close()
|
||||
m, err := pc.Handshake(metainfo.Hash{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("fail to handshake with '%s': %s", pc.RemoteAddr().String(), err)
|
||||
} else if err = s.h.OnHandShake(pc, m); err != nil {
|
||||
return fmt.Errorf("handshake error with '%s': %s", pc.RemoteAddr().String(), err)
|
||||
}
|
||||
|
||||
return s.loopRun(pc, s.h)
|
||||
}
|
||||
|
||||
// LoopRun loops running Read-Handle message.
|
||||
func (s *Server) loopRun(pc *PeerConn, handler Handler) error {
|
||||
for {
|
||||
msg, err := pc.ReadMsg()
|
||||
switch err {
|
||||
case nil:
|
||||
case io.EOF:
|
||||
return nil
|
||||
default:
|
||||
s := err.Error()
|
||||
if strings.Contains(s, "closed") {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("fail to decode the message from '%s': %s",
|
||||
pc.RemoteAddr().String(), s)
|
||||
}
|
||||
|
||||
if err = s.c.HandleMessage(pc, msg, handler); err != nil {
|
||||
return fmt.Errorf("fail to handle peer message from '%s': %s",
|
||||
pc.RemoteAddr().String(), err)
|
||||
}
|
||||
}
|
||||
}
|
312
tracker/httptracker/http.go
Normal file
312
tracker/httptracker/http.go
Normal file
@ -0,0 +1,312 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package httptracker implements the tracker protocol based on HTTP/HTTPS.
|
||||
//
|
||||
// You can use the package to implement a HTTP tracker server to track the
|
||||
// information that other peers upload or download the file, or to create
|
||||
// a HTTP tracker client to communicate with the HTTP tracker server.
|
||||
package httptracker
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/xgfone/bt/bencode"
|
||||
"github.com/xgfone/bt/metainfo"
|
||||
)
|
||||
|
||||
// AnnounceRequest is the tracker announce requests.
|
||||
//
|
||||
// BEP 3
|
||||
type AnnounceRequest struct {
|
||||
// InfoHash is the sha1 hash of the bencoded form of the info value from the metainfo file.
|
||||
InfoHash metainfo.Hash `bencode:"info_hash"` // BEP 3
|
||||
|
||||
// PeerID is the id of the downloader.
|
||||
//
|
||||
// Each downloader generates its own id at random at the start of a new download.
|
||||
PeerID metainfo.Hash `bencode:"peer_id"` // BEP 3
|
||||
|
||||
// Uploaded is the total amount uploaded so far, encoded in base ten ascii.
|
||||
Uploaded int64 `bencode:"uploaded"` // BEP 3
|
||||
|
||||
// Downloaded is the total amount downloaded so far, encoded in base ten ascii.
|
||||
Downloaded int64 `bencode:"downloaded"` // BEP 3
|
||||
|
||||
// Left is the number of bytes this peer still has to download,
|
||||
// encoded in base ten ascii.
|
||||
//
|
||||
// Note that this can't be computed from downloaded and the file length
|
||||
// since it might be a resume, and there's a chance that some of the
|
||||
// downloaded data failed an integrity check and had to be re-downloaded.
|
||||
//
|
||||
// If less than 0, math.MaxInt64 will be used for HTTP trackers instead.
|
||||
Left int64 `bencode:"left"` // BEP 3
|
||||
|
||||
// Port is the port that this peer is listening on.
|
||||
//
|
||||
// Common behavior is for a downloader to try to listen on port 6881,
|
||||
// and if that port is taken try 6882, then 6883, etc. and give up after 6889.
|
||||
Port uint16 `bencode:"port"` // BEP 3
|
||||
|
||||
// IP is the ip or DNS name which this peer is at, which generally used
|
||||
// for the origin if it's on the same machine as the tracker.
|
||||
//
|
||||
// Optional.
|
||||
IP string `bencode:"ip,omitempty"` // BEP 3
|
||||
|
||||
// If not present, this is one of the announcements done at regular intervals.
|
||||
// An announcement using started is sent when a download first begins,
|
||||
// and one using completed is sent when the download is complete.
|
||||
// No completed is sent if the file was complete when started.
|
||||
// Downloaders send an announcement using stopped when they cease downloading.
|
||||
//
|
||||
// Optional
|
||||
Event uint32 `bencode:"event,omitempty"` // BEP 3
|
||||
|
||||
// Compact indicates whether it hopes the tracker to return the compact
|
||||
// peer lists.
|
||||
//
|
||||
// Optional
|
||||
Compact bool `bencode:"compact,omitempty"` // BEP 23
|
||||
|
||||
// NumWant is the number of peers that the client would like to receive
|
||||
// from the tracker. This value is permitted to be zero. If omitted,
|
||||
// typically defaults to 50 peers.
|
||||
//
|
||||
// See https://wiki.theory.org/index.php/BitTorrentSpecification
|
||||
//
|
||||
// Optional.
|
||||
NumWant int32 `bencode:"numwant,omitempty"`
|
||||
|
||||
Key int32 `bencode:"key,omitempty"`
|
||||
}
|
||||
|
||||
// ToQuery converts the Request to URL Query.
|
||||
func (r AnnounceRequest) ToQuery() (vs url.Values) {
|
||||
vs = make(url.Values, 9)
|
||||
vs.Set("info_hash", r.InfoHash.BytesString())
|
||||
vs.Set("peer_id", r.PeerID.BytesString())
|
||||
vs.Set("uploaded", strconv.FormatInt(r.Uploaded, 10))
|
||||
vs.Set("downloaded", strconv.FormatInt(r.Downloaded, 10))
|
||||
vs.Set("left", strconv.FormatInt(r.Left, 10))
|
||||
|
||||
if r.IP != "" {
|
||||
vs.Set("ip", r.IP)
|
||||
}
|
||||
if r.Event > 0 {
|
||||
vs.Set("event", strconv.FormatInt(int64(r.Event), 10))
|
||||
}
|
||||
if r.Port > 0 {
|
||||
vs.Set("port", strconv.FormatUint(uint64(r.Port), 10))
|
||||
}
|
||||
if r.NumWant > 0 {
|
||||
vs.Set("numwant", strconv.FormatUint(uint64(r.NumWant), 10))
|
||||
}
|
||||
if r.Key != 0 {
|
||||
vs.Set("key", strconv.FormatInt(int64(r.Key), 10))
|
||||
}
|
||||
|
||||
// BEP 23
|
||||
if r.Compact {
|
||||
vs.Set("compact", "1")
|
||||
} else {
|
||||
vs.Set("compact", "0")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// FromQuery converts URL Query to itself.
|
||||
func (r *AnnounceRequest) FromQuery(vs url.Values) (err error) {
|
||||
if err = r.InfoHash.FromString(vs.Get("info_hash")); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = r.PeerID.FromString(vs.Get("peer_id")); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
v, err := strconv.ParseInt(vs.Get("uploaded"), 10, 64)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r.Uploaded = v
|
||||
|
||||
v, err = strconv.ParseInt(vs.Get("downloaded"), 10, 64)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r.Downloaded = v
|
||||
|
||||
v, err = strconv.ParseInt(vs.Get("left"), 10, 64)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r.Left = v
|
||||
|
||||
if s := vs.Get("event"); s != "" {
|
||||
v, err := strconv.ParseUint(s, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.Event = uint32(v)
|
||||
}
|
||||
|
||||
if s := vs.Get("port"); s != "" {
|
||||
v, err := strconv.ParseUint(s, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.Port = uint16(v)
|
||||
}
|
||||
|
||||
if s := vs.Get("numwant"); s != "" {
|
||||
v, err := strconv.ParseUint(s, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.NumWant = int32(v)
|
||||
}
|
||||
|
||||
if s := vs.Get("key"); s != "" {
|
||||
v, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.Key = int32(v)
|
||||
}
|
||||
|
||||
r.IP = vs.Get("ip")
|
||||
switch vs.Get("compact") {
|
||||
case "1":
|
||||
r.Compact = true
|
||||
case "0":
|
||||
r.Compact = false
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// AnnounceResponse is a announce response.
|
||||
type AnnounceResponse struct {
|
||||
FailureReason string `bencode:"failure reason,omitempty"`
|
||||
|
||||
// Interval is the seconds the downloader should wait before next rerequest.
|
||||
Interval uint32 `bencode:"interval,omitempty"` // BEP 3
|
||||
|
||||
// Peers is the list of the peers.
|
||||
Peers Peers `bencode:"peers,omitempty"` // BEP 3, BEP 23
|
||||
|
||||
// Peers6 is only used for ipv6 in the compact case.
|
||||
Peers6 Peers6 `bencode:"peers6,omitempty"` // BEP 7
|
||||
|
||||
// Where's this specified?
|
||||
// Mentioned at https://wiki.theory.org/index.php/BitTorrentSpecification.
|
||||
|
||||
// Complete is the number of peers with the entire file.
|
||||
Complete uint32 `bencode:"complete,omitempty"`
|
||||
// Incomplete is the number of non-seeder peers.
|
||||
Incomplete uint32 `bencode:"incomplete,omitempty"`
|
||||
// TrackerID is that the client should send back on its next announcements.
|
||||
// If absent and a previous announce sent a tracker id,
|
||||
// do not discard the old value; keep using it.
|
||||
TrackerID string `bencode:"tracker id,omitempty"`
|
||||
}
|
||||
|
||||
// ScrapeResponseResult is the result of the scraped file.
|
||||
type ScrapeResponseResult struct {
|
||||
// Complete is the number of active peers that have completed downloading.
|
||||
Complete uint32 `bencode:"complete"` // BEP 48
|
||||
|
||||
// Incomplete is the number of active peers that have not completed downloading.
|
||||
Incomplete uint32 `bencode:"incomplete"` // BEP 48
|
||||
|
||||
// The number of peers that have ever completed downloading.
|
||||
Downloaded uint32 `bencode:"downloaded"` // BEP 48
|
||||
}
|
||||
|
||||
// ScrapeResponse represents a Scrape response.
|
||||
//
|
||||
// BEP 48
|
||||
type ScrapeResponse struct {
|
||||
FailureReason string `bencode:"failure_reason,omitempty"`
|
||||
|
||||
Files map[metainfo.Hash]ScrapeResponseResult `bencode:"files,omitempty"`
|
||||
}
|
||||
|
||||
// DecodeFrom reads the []byte data from r and decodes them to sr by bencode.
|
||||
//
|
||||
// r may be the body of the request from the http client.
|
||||
func (sr *ScrapeResponse) DecodeFrom(r io.Reader) (err error) {
|
||||
return bencode.NewDecoder(r).Decode(sr)
|
||||
}
|
||||
|
||||
// EncodeTo encodes the response to []byte by bencode and write the result into w.
|
||||
//
|
||||
// w may be http.ResponseWriter.
|
||||
func (sr ScrapeResponse) EncodeTo(w io.Writer) (err error) {
|
||||
return bencode.NewEncoder(w).Encode(sr)
|
||||
}
|
||||
|
||||
// TrackerClient represents a tracker client based on HTTP/HTTPS.
|
||||
type TrackerClient struct {
|
||||
AnnounceURL string
|
||||
ScrapeURL string
|
||||
}
|
||||
|
||||
// NewTrackerClient returns a new HTTPTrackerClient.
|
||||
//
|
||||
// scrapeURL may be empty, which will replace the "announce" in announceURL
|
||||
// with "scrape" to generate the scrapeURL.
|
||||
func NewTrackerClient(announceURL, scrapeURL string) *TrackerClient {
|
||||
if scrapeURL == "" {
|
||||
scrapeURL = strings.Replace(announceURL, "announce", "scrape", -1)
|
||||
}
|
||||
return &TrackerClient{AnnounceURL: announceURL, ScrapeURL: scrapeURL}
|
||||
}
|
||||
|
||||
func (t *TrackerClient) send(u string, vs url.Values, r interface{}) (err error) {
|
||||
sym := "?"
|
||||
if strings.IndexByte(u, '?') > 0 {
|
||||
sym = "&"
|
||||
}
|
||||
|
||||
resp, err := http.Get(u + sym + vs.Encode())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return bencode.NewDecoder(resp.Body).Decode(r)
|
||||
}
|
||||
|
||||
// Announce sends a Announce request to the tracker.
|
||||
func (t *TrackerClient) Announce(req AnnounceRequest) (resp AnnounceResponse, err error) {
|
||||
err = t.send(t.AnnounceURL, req.ToQuery(), &resp)
|
||||
return
|
||||
}
|
||||
|
||||
// Scrape sends a Scrape request to the tracker.
|
||||
func (t *TrackerClient) Scrape(infohashes []metainfo.Hash) (resp ScrapeResponse, err error) {
|
||||
hs := make([]string, len(infohashes))
|
||||
for i, h := range infohashes {
|
||||
hs[i] = h.BytesString()
|
||||
}
|
||||
err = t.send(t.ScrapeURL, url.Values{"info_hash": hs}, &resp)
|
||||
return
|
||||
}
|
192
tracker/httptracker/http_peer.go
Normal file
192
tracker/httptracker/http_peer.go
Normal file
@ -0,0 +1,192 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package httptracker
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
"github.com/xgfone/bt/bencode"
|
||||
"github.com/xgfone/bt/metainfo"
|
||||
)
|
||||
|
||||
var errInvalidPeer = errors.New("invalid peer information format")
|
||||
|
||||
// Peer is a tracker peer.
|
||||
type Peer struct {
|
||||
// ID is the peer's self-selected ID.
|
||||
ID string `bencode:"peer id"` // BEP 3
|
||||
|
||||
// IP is the IP address or dns name.
|
||||
IP string `bencode:"ip"` // BEP 3
|
||||
Port uint16 `bencode:"port"` // BEP 3
|
||||
}
|
||||
|
||||
// Addresses returns the list of the addresses that the peer listens on.
|
||||
func (p Peer) Addresses() (addrs []metainfo.Address, err error) {
|
||||
if ip := net.ParseIP(p.IP); len(ip) != 0 {
|
||||
return []metainfo.Address{{IP: ip, Port: p.Port}}, nil
|
||||
}
|
||||
|
||||
ips, err := net.LookupIP(p.IP)
|
||||
if _len := len(ips); err == nil && len(ips) != 0 {
|
||||
addrs = make([]metainfo.Address, _len)
|
||||
for i, ip := range ips {
|
||||
addrs[i] = metainfo.Address{IP: ip, Port: p.Port}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Peers is a set of the peers.
|
||||
type Peers []Peer
|
||||
|
||||
// UnmarshalBencode implements the interface bencode.Unmarshaler.
|
||||
func (ps *Peers) UnmarshalBencode(b []byte) (err error) {
|
||||
var v interface{}
|
||||
if err = bencode.DecodeBytes(b, &v); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch vs := v.(type) {
|
||||
case string: // BEP 23
|
||||
_len := len(vs)
|
||||
if _len%6 != 0 {
|
||||
return metainfo.ErrInvalidAddr
|
||||
}
|
||||
|
||||
peers := make(Peers, 0, _len/6)
|
||||
for i := 0; i < _len; i += 6 {
|
||||
var addr metainfo.Address
|
||||
if err = addr.UnmarshalBinary([]byte(vs[i : i+6])); err != nil {
|
||||
return
|
||||
}
|
||||
peers = append(peers, Peer{IP: addr.IP.String(), Port: addr.Port})
|
||||
}
|
||||
|
||||
*ps = peers
|
||||
case []interface{}: // BEP 3
|
||||
peers := make(Peers, len(vs))
|
||||
for i, p := range vs {
|
||||
m, ok := p.(map[string]interface{})
|
||||
if !ok {
|
||||
return errInvalidPeer
|
||||
}
|
||||
|
||||
pid, ok := m["peer id"].(string)
|
||||
if !ok {
|
||||
return errInvalidPeer
|
||||
}
|
||||
|
||||
ip, ok := m["ip"].(string)
|
||||
if !ok {
|
||||
return errInvalidPeer
|
||||
}
|
||||
|
||||
port, ok := m["port"].(int64)
|
||||
if !ok {
|
||||
return errInvalidPeer
|
||||
}
|
||||
|
||||
peers[i] = Peer{ID: pid, IP: ip, Port: uint16(port)}
|
||||
}
|
||||
*ps = peers
|
||||
default:
|
||||
return errInvalidPeer
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalBencode implements the interface bencode.Marshaler.
|
||||
func (ps Peers) MarshalBencode() (b []byte, err error) {
|
||||
for _, p := range ps {
|
||||
if p.ID == "" {
|
||||
return ps.marshalCompactBencode() // BEP 23
|
||||
}
|
||||
}
|
||||
|
||||
// BEP 3
|
||||
buf := bytes.NewBuffer(make([]byte, 0, 64*len(ps)))
|
||||
buf.WriteByte('l')
|
||||
for _, p := range ps {
|
||||
if err = bencode.NewEncoder(buf).Encode(p); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
buf.WriteByte('e')
|
||||
b = buf.Bytes()
|
||||
return
|
||||
}
|
||||
|
||||
func (ps Peers) marshalCompactBencode() (b []byte, err error) {
|
||||
buf := bytes.NewBuffer(make([]byte, 0, 6*len(ps)))
|
||||
for _, peer := range ps {
|
||||
ip := net.ParseIP(peer.IP).To4()
|
||||
if len(ip) == 0 {
|
||||
return nil, errInvalidPeer
|
||||
}
|
||||
buf.Write(ip[:])
|
||||
binary.Write(buf, binary.BigEndian, peer.Port)
|
||||
}
|
||||
return bencode.EncodeBytes(buf.Bytes())
|
||||
}
|
||||
|
||||
// Peers6 is a set of the peers for IPv6 in the compact case.
|
||||
//
|
||||
// BEP 7
|
||||
type Peers6 []Peer
|
||||
|
||||
// UnmarshalBencode implements the interface bencode.Unmarshaler.
|
||||
func (ps *Peers6) UnmarshalBencode(b []byte) (err error) {
|
||||
var s string
|
||||
if err = bencode.DecodeBytes(b, &s); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_len := len(s)
|
||||
if _len%18 != 0 {
|
||||
return metainfo.ErrInvalidAddr
|
||||
}
|
||||
|
||||
peers := make(Peers6, 0, _len/18)
|
||||
for i := 0; i < _len; i += 18 {
|
||||
var addr metainfo.Address
|
||||
if err = addr.UnmarshalBinary([]byte(s[i : i+18])); err != nil {
|
||||
return
|
||||
}
|
||||
peers = append(peers, Peer{IP: addr.IP.String(), Port: addr.Port})
|
||||
}
|
||||
|
||||
*ps = peers
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalBencode implements the interface bencode.Marshaler.
|
||||
func (ps Peers6) MarshalBencode() (b []byte, err error) {
|
||||
buf := bytes.NewBuffer(make([]byte, 0, 18*len(ps)))
|
||||
for _, peer := range ps {
|
||||
ip := net.ParseIP(peer.IP).To16()
|
||||
if len(ip) == 0 {
|
||||
return nil, errInvalidPeer
|
||||
}
|
||||
|
||||
buf.Write(ip[:])
|
||||
binary.Write(buf, binary.BigEndian, peer.Port)
|
||||
}
|
||||
return bencode.EncodeBytes(buf.Bytes())
|
||||
}
|
75
tracker/httptracker/http_peer_test.go
Normal file
75
tracker/httptracker/http_peer_test.go
Normal file
@ -0,0 +1,75 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package httptracker
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPeers(t *testing.T) {
|
||||
peers := Peers{
|
||||
{ID: "123", IP: "1.1.1.1", Port: 80},
|
||||
{ID: "456", IP: "2.2.2.2", Port: 81},
|
||||
}
|
||||
|
||||
b, err := peers.MarshalBencode()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var ps Peers
|
||||
if err = ps.UnmarshalBencode(b); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if !reflect.DeepEqual(ps, peers) {
|
||||
t.Errorf("%v != %v", ps, peers)
|
||||
}
|
||||
|
||||
/// For BEP 23
|
||||
peers = Peers{
|
||||
{IP: "1.1.1.1", Port: 80},
|
||||
{IP: "2.2.2.2", Port: 81},
|
||||
}
|
||||
|
||||
b, err = peers.MarshalBencode()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err = ps.UnmarshalBencode(b); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if !reflect.DeepEqual(ps, peers) {
|
||||
t.Errorf("%v != %v", ps, peers)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeers6(t *testing.T) {
|
||||
peers := Peers6{
|
||||
{IP: "fe80::5054:ff:fef0:1ab", Port: 80},
|
||||
{IP: "fe80::5054:ff:fe29:205d", Port: 81},
|
||||
}
|
||||
|
||||
b, err := peers.MarshalBencode()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var ps Peers6
|
||||
if err = ps.UnmarshalBencode(b); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if !reflect.DeepEqual(ps, peers) {
|
||||
t.Errorf("%v != %v", ps, peers)
|
||||
}
|
||||
}
|
72
tracker/httptracker/http_test.go
Normal file
72
tracker/httptracker/http_test.go
Normal file
@ -0,0 +1,72 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package httptracker
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/xgfone/bt/metainfo"
|
||||
)
|
||||
|
||||
func TestHTTPAnnounceRequest(t *testing.T) {
|
||||
infohash := metainfo.NewRandomHash()
|
||||
peerid := metainfo.NewRandomHash()
|
||||
v1 := AnnounceRequest{
|
||||
InfoHash: infohash,
|
||||
PeerID: peerid,
|
||||
Uploaded: 789,
|
||||
Downloaded: 456,
|
||||
Left: 123,
|
||||
Port: 80,
|
||||
Event: 123,
|
||||
Compact: true,
|
||||
}
|
||||
vs := v1.ToQuery()
|
||||
|
||||
var v2 AnnounceRequest
|
||||
if err := v2.FromQuery(vs); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if v2.InfoHash != infohash {
|
||||
t.Error(v2.InfoHash)
|
||||
}
|
||||
if v2.PeerID != peerid {
|
||||
t.Error(v2.PeerID)
|
||||
}
|
||||
if v2.Uploaded != 789 {
|
||||
t.Error(v2.Uploaded)
|
||||
}
|
||||
if v2.Downloaded != 456 {
|
||||
t.Error(v2.Downloaded)
|
||||
}
|
||||
if v2.Left != 123 {
|
||||
t.Error(v2.Left)
|
||||
}
|
||||
if v2.Port != 80 {
|
||||
t.Error(v2.Port)
|
||||
}
|
||||
if v2.Event != 123 {
|
||||
t.Error(v2.Event)
|
||||
}
|
||||
if !v2.Compact {
|
||||
t.Error(v2.Compact)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(v1, v2) {
|
||||
t.Errorf("%v != %v", v1, v2)
|
||||
}
|
||||
}
|
266
tracker/tracker.go
Normal file
266
tracker/tracker.go
Normal file
@ -0,0 +1,266 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package tracker supplies some common type interfaces of the BT tracker
|
||||
// protocol.
|
||||
package tracker
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
|
||||
"github.com/xgfone/bt/metainfo"
|
||||
"github.com/xgfone/bt/tracker/httptracker"
|
||||
"github.com/xgfone/bt/tracker/udptracker"
|
||||
)
|
||||
|
||||
// Predefine some announce events.
|
||||
//
|
||||
// BEP 3
|
||||
const (
|
||||
None uint32 = iota
|
||||
Completed // The local peer just completed the torrent.
|
||||
Started // The local peer has just resumed this torrent.
|
||||
Stopped // The local peer is leaving the swarm.
|
||||
)
|
||||
|
||||
// AnnounceRequest is the common Announce request.
|
||||
//
|
||||
// BEP 3, 15
|
||||
type AnnounceRequest struct {
|
||||
InfoHash metainfo.Hash
|
||||
PeerID metainfo.Hash
|
||||
|
||||
Uploaded int64
|
||||
Downloaded int64
|
||||
Left int64
|
||||
Event uint32
|
||||
|
||||
IP net.IP
|
||||
Key int32
|
||||
NumWant int32 // -1 for default
|
||||
Port uint16
|
||||
}
|
||||
|
||||
// ToHTTPAnnounceRequest creates a new httptracker.AnnounceRequest from itself.
|
||||
func (ar AnnounceRequest) ToHTTPAnnounceRequest() httptracker.AnnounceRequest {
|
||||
var ip string
|
||||
if len(ar.IP) != 0 {
|
||||
ip = ar.IP.String()
|
||||
}
|
||||
|
||||
return httptracker.AnnounceRequest{
|
||||
InfoHash: ar.InfoHash,
|
||||
PeerID: ar.PeerID,
|
||||
Uploaded: ar.Uploaded,
|
||||
Downloaded: ar.Downloaded,
|
||||
Left: ar.Left,
|
||||
Port: ar.Port,
|
||||
IP: ip,
|
||||
Event: ar.Event,
|
||||
NumWant: ar.NumWant,
|
||||
Key: ar.Key,
|
||||
}
|
||||
}
|
||||
|
||||
// ToUDPAnnounceRequest creates a new udptracker.AnnounceRequest from itself.
|
||||
func (ar AnnounceRequest) ToUDPAnnounceRequest() udptracker.AnnounceRequest {
|
||||
return udptracker.AnnounceRequest{
|
||||
InfoHash: ar.InfoHash,
|
||||
PeerID: ar.PeerID,
|
||||
Downloaded: ar.Downloaded,
|
||||
Left: ar.Left,
|
||||
Uploaded: ar.Uploaded,
|
||||
Event: ar.Event,
|
||||
IP: ar.IP,
|
||||
Key: ar.Key,
|
||||
NumWant: ar.NumWant,
|
||||
Port: ar.Port,
|
||||
}
|
||||
}
|
||||
|
||||
// AnnounceResponse is a common Announce response.
|
||||
//
|
||||
// BEP 3, 15
|
||||
type AnnounceResponse struct {
|
||||
Interval uint32
|
||||
Leechers uint32
|
||||
Seeders uint32
|
||||
Addresses []metainfo.Address
|
||||
}
|
||||
|
||||
// FromHTTPAnnounceResponse sets itself from r.
|
||||
func (ar *AnnounceResponse) FromHTTPAnnounceResponse(r httptracker.AnnounceResponse) {
|
||||
ar.Interval = r.Interval
|
||||
ar.Leechers = r.Incomplete
|
||||
ar.Seeders = r.Complete
|
||||
ar.Addresses = make([]metainfo.Address, 0, len(r.Peers)+len(r.Peers6))
|
||||
for _, peer := range r.Peers {
|
||||
addrs, _ := peer.Addresses()
|
||||
ar.Addresses = append(ar.Addresses, addrs...)
|
||||
}
|
||||
for _, peer := range r.Peers6 {
|
||||
addrs, _ := peer.Addresses()
|
||||
ar.Addresses = append(ar.Addresses, addrs...)
|
||||
}
|
||||
}
|
||||
|
||||
// FromUDPAnnounceResponse sets itself from r.
|
||||
func (ar *AnnounceResponse) FromUDPAnnounceResponse(r udptracker.AnnounceResponse) {
|
||||
ar.Interval = r.Interval
|
||||
ar.Leechers = r.Leechers
|
||||
ar.Seeders = r.Seeders
|
||||
ar.Addresses = r.Addresses
|
||||
}
|
||||
|
||||
// ScrapeResponseResult is a commont Scrape response result.
|
||||
type ScrapeResponseResult struct {
|
||||
// Seeders is the number of active peers that have completed downloading.
|
||||
Seeders uint32 `bencode:"complete"` // BEP 15, 48
|
||||
|
||||
// Leechers is the number of active peers that have not completed downloading.
|
||||
Leechers uint32 `bencode:"incomplete"` // BEP 15, 48
|
||||
|
||||
// Completed is the total number of peers that have ever completed downloading.
|
||||
Completed uint32 `bencode:"downloaded"` // BEP 15, 48
|
||||
}
|
||||
|
||||
// EncodeTo encodes the response to buf.
|
||||
func (r ScrapeResponseResult) EncodeTo(buf *bytes.Buffer) {
|
||||
binary.Write(buf, binary.BigEndian, r.Seeders)
|
||||
binary.Write(buf, binary.BigEndian, r.Completed)
|
||||
binary.Write(buf, binary.BigEndian, r.Leechers)
|
||||
}
|
||||
|
||||
// DecodeFrom decodes the response from b.
|
||||
func (r *ScrapeResponseResult) DecodeFrom(b []byte) {
|
||||
r.Seeders = binary.BigEndian.Uint32(b[:4])
|
||||
r.Completed = binary.BigEndian.Uint32(b[4:8])
|
||||
r.Leechers = binary.BigEndian.Uint32(b[8:12])
|
||||
}
|
||||
|
||||
// ScrapeResponse is a commont Scrape response.
|
||||
type ScrapeResponse map[metainfo.Hash]ScrapeResponseResult
|
||||
|
||||
// FromHTTPScrapeResponse sets itself from r.
|
||||
func (sr ScrapeResponse) FromHTTPScrapeResponse(r httptracker.ScrapeResponse) {
|
||||
for k, v := range r.Files {
|
||||
sr[k] = ScrapeResponseResult{
|
||||
Seeders: v.Complete,
|
||||
Leechers: v.Incomplete,
|
||||
Completed: v.Downloaded,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FromUDPScrapeResponse sets itself from hs and r.
|
||||
func (sr ScrapeResponse) FromUDPScrapeResponse(hs []metainfo.Hash,
|
||||
r []udptracker.ScrapeResponse) {
|
||||
klen := len(hs)
|
||||
if _len := len(r); _len < klen {
|
||||
klen = _len
|
||||
}
|
||||
|
||||
for i := 0; i < klen; i++ {
|
||||
sr[hs[i]] = ScrapeResponseResult{
|
||||
Seeders: r[i].Seeders,
|
||||
Leechers: r[i].Leechers,
|
||||
Completed: r[i].Completed,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Client is the interface of BT tracker client.
|
||||
type Client interface {
|
||||
Announce(AnnounceRequest) (AnnounceResponse, error)
|
||||
Scrape([]metainfo.Hash) (ScrapeResponse, error)
|
||||
}
|
||||
|
||||
// NewClient returns a new Client.
|
||||
func NewClient(connURL string) (c Client, err error) {
|
||||
u, err := url.Parse(connURL)
|
||||
if err == nil {
|
||||
switch u.Scheme {
|
||||
case "http", "https":
|
||||
c = &tclient{http: httptracker.NewTrackerClient(connURL, "")}
|
||||
case "udp", "udp4", "udp6":
|
||||
var utc *udptracker.TrackerClient
|
||||
utc, err = udptracker.NewTrackerClientByDial(u.Scheme, u.Host)
|
||||
if err == nil {
|
||||
var e []udptracker.Extension
|
||||
if p := u.RequestURI(); p != "" {
|
||||
e = []udptracker.Extension{udptracker.NewURLData([]byte(p))}
|
||||
}
|
||||
c = &tclient{exts: e, udp: utc}
|
||||
}
|
||||
default:
|
||||
err = fmt.Errorf("unknown url scheme '%s'", u.Scheme)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type tclient struct {
|
||||
http *httptracker.TrackerClient // BEP 3
|
||||
udp *udptracker.TrackerClient // BEP 15
|
||||
exts []udptracker.Extension // BEP 41
|
||||
}
|
||||
|
||||
func (c *tclient) Announce(req AnnounceRequest) (resp AnnounceResponse, err error) {
|
||||
if c.http != nil {
|
||||
var r httptracker.AnnounceResponse
|
||||
if r, err = c.http.Announce(req.ToHTTPAnnounceRequest()); err != nil {
|
||||
return
|
||||
} else if r.FailureReason != "" {
|
||||
err = errors.New(r.FailureReason)
|
||||
return
|
||||
}
|
||||
resp.FromHTTPAnnounceResponse(r)
|
||||
return
|
||||
}
|
||||
|
||||
r := req.ToUDPAnnounceRequest()
|
||||
r.Exts = c.exts
|
||||
rs, err := c.udp.Announce(r)
|
||||
if err == nil {
|
||||
resp.FromUDPAnnounceResponse(rs)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *tclient) Scrape(hs []metainfo.Hash) (resp ScrapeResponse, err error) {
|
||||
if c.http != nil {
|
||||
var r httptracker.ScrapeResponse
|
||||
if r, err = c.http.Scrape(hs); err != nil {
|
||||
return
|
||||
} else if r.FailureReason != "" {
|
||||
err = errors.New(r.FailureReason)
|
||||
return
|
||||
}
|
||||
resp = make(ScrapeResponse, len(r.Files))
|
||||
resp.FromHTTPScrapeResponse(r)
|
||||
return
|
||||
}
|
||||
|
||||
r, err := c.udp.Scrape(hs)
|
||||
if err == nil {
|
||||
resp = make(ScrapeResponse, len(r))
|
||||
resp.FromUDPScrapeResponse(hs, r)
|
||||
}
|
||||
return
|
||||
}
|
121
tracker/tracker_test.go
Normal file
121
tracker/tracker_test.go
Normal file
@ -0,0 +1,121 @@
|
||||
package tracker
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/xgfone/bt/metainfo"
|
||||
"github.com/xgfone/bt/tracker/udptracker"
|
||||
)
|
||||
|
||||
type testHandler struct{}
|
||||
|
||||
func (testHandler) OnConnect(raddr *net.UDPAddr) (err error) { return }
|
||||
func (testHandler) OnAnnounce(raddr *net.UDPAddr, req udptracker.AnnounceRequest) (
|
||||
r udptracker.AnnounceResponse, err error) {
|
||||
if req.Port != 80 {
|
||||
err = errors.New("port is not 80")
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Exts) > 0 {
|
||||
for i, ext := range req.Exts {
|
||||
switch ext.Type {
|
||||
case udptracker.URLData:
|
||||
fmt.Printf("Extensions[%d]: URLData(%s)\n", i, string(ext.Data))
|
||||
default:
|
||||
fmt.Printf("Extensions[%d]: %s\n", i, ext.Type.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
r = udptracker.AnnounceResponse{
|
||||
Interval: 1,
|
||||
Leechers: 2,
|
||||
Seeders: 3,
|
||||
Addresses: []metainfo.Address{{IP: net.ParseIP("127.0.0.1"), Port: 8001}},
|
||||
}
|
||||
return
|
||||
}
|
||||
func (testHandler) OnScrap(raddr *net.UDPAddr, infohashes []metainfo.Hash) (
|
||||
rs []udptracker.ScrapeResponse, err error) {
|
||||
rs = make([]udptracker.ScrapeResponse, len(infohashes))
|
||||
for i := range infohashes {
|
||||
rs[i] = udptracker.ScrapeResponse{
|
||||
Seeders: uint32(i)*10 + 1,
|
||||
Leechers: uint32(i)*10 + 2,
|
||||
Completed: uint32(i)*10 + 3,
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func ExampleClient() {
|
||||
// Start the UDP tracker server
|
||||
sconn, err := net.ListenPacket("udp4", "127.0.0.1:8001")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
server := udptracker.NewTrackerServer(sconn, testHandler{})
|
||||
defer server.Close()
|
||||
go server.Run()
|
||||
|
||||
// Wait for the server to be started
|
||||
time.Sleep(time.Second)
|
||||
|
||||
// Create a client and dial to the UDP tracker server.
|
||||
client, err := NewClient("udp://127.0.0.1:8001/path?a=1&b=2")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Send the ANNOUNCE request to the UDP tracker server,
|
||||
// and get the ANNOUNCE response.
|
||||
req := AnnounceRequest{IP: net.ParseIP("127.0.0.1"), Port: 80}
|
||||
resp, err := client.Announce(req)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("Interval: %d\n", resp.Interval)
|
||||
fmt.Printf("Leechers: %d\n", resp.Leechers)
|
||||
fmt.Printf("Seeders: %d\n", resp.Seeders)
|
||||
for i, addr := range resp.Addresses {
|
||||
fmt.Printf("Address[%d].IP: %s\n", i, addr.IP.String())
|
||||
fmt.Printf("Address[%d].Port: %d\n", i, addr.Port)
|
||||
}
|
||||
|
||||
// Send the SCRAPE request to the UDP tracker server,
|
||||
// and get the SCRAPE respsone.
|
||||
h1 := metainfo.Hash{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
|
||||
h2 := metainfo.Hash{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}
|
||||
rs, err := client.Scrape([]metainfo.Hash{h1, h2})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
} else if len(rs) != 2 {
|
||||
log.Fatalf("%+v", rs)
|
||||
}
|
||||
|
||||
for i, r := range rs {
|
||||
fmt.Printf("%s.Seeders: %d\n", i.HexString(), r.Seeders)
|
||||
fmt.Printf("%s.Leechers: %d\n", i.HexString(), r.Leechers)
|
||||
fmt.Printf("%s.Completed: %d\n", i.HexString(), r.Completed)
|
||||
}
|
||||
|
||||
// Output:
|
||||
// Extensions[0]: URLData(/path?a=1&b=2)
|
||||
// Interval: 1
|
||||
// Leechers: 2
|
||||
// Seeders: 3
|
||||
// Address[0].IP: 127.0.0.1
|
||||
// Address[0].Port: 8001
|
||||
// 0101010101010101010101010101010101010101.Seeders: 1
|
||||
// 0101010101010101010101010101010101010101.Leechers: 2
|
||||
// 0101010101010101010101010101010101010101.Completed: 3
|
||||
// 0202020202020202020202020202020202020202.Seeders: 11
|
||||
// 0202020202020202020202020202020202020202.Leechers: 12
|
||||
// 0202020202020202020202020202020202020202.Completed: 13
|
||||
}
|
271
tracker/udptracker/udp.go
Normal file
271
tracker/udptracker/udp.go
Normal file
@ -0,0 +1,271 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package udptracker implements the tracker protocol based on UDP.
|
||||
//
|
||||
// You can use the package to implement a UDP tracker server to track the
|
||||
// information that other peers upload or download the file, or to create
|
||||
// a UDP tracker client to communicate with the UDP tracker server.
|
||||
package udptracker
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/xgfone/bt/metainfo"
|
||||
)
|
||||
|
||||
// ProtocolID is magic constant for the udp tracker connection.
|
||||
//
|
||||
// BEP 15
|
||||
const ProtocolID = uint64(0x41727101980)
|
||||
|
||||
// Predefine some actions.
|
||||
//
|
||||
// BEP 15
|
||||
const (
|
||||
ActionConnect = uint32(0)
|
||||
ActionAnnounce = uint32(1)
|
||||
ActionScrape = uint32(2)
|
||||
ActionError = uint32(3)
|
||||
)
|
||||
|
||||
// AnnounceRequest represents the announce request used by UDP tracker.
|
||||
//
|
||||
// BEP 15
|
||||
type AnnounceRequest struct {
|
||||
InfoHash metainfo.Hash
|
||||
PeerID metainfo.Hash
|
||||
|
||||
Downloaded int64
|
||||
Left int64
|
||||
Uploaded int64
|
||||
Event uint32
|
||||
|
||||
IP net.IP
|
||||
Key int32
|
||||
NumWant int32 // -1 for default
|
||||
Port uint16
|
||||
|
||||
Exts []Extension // BEP 41
|
||||
}
|
||||
|
||||
// DecodeFrom decodes the request from b.
|
||||
func (r *AnnounceRequest) DecodeFrom(b []byte, ipv4 bool) {
|
||||
r.InfoHash = metainfo.NewHash(b[0:20])
|
||||
r.PeerID = metainfo.NewHash(b[20:40])
|
||||
r.Downloaded = int64(binary.BigEndian.Uint64(b[40:48]))
|
||||
r.Left = int64(binary.BigEndian.Uint64(b[48:56]))
|
||||
r.Uploaded = int64(binary.BigEndian.Uint64(b[56:64]))
|
||||
r.Event = binary.BigEndian.Uint32(b[64:68])
|
||||
|
||||
if ipv4 {
|
||||
r.IP = make(net.IP, net.IPv4len)
|
||||
copy(r.IP, b[68:72])
|
||||
b = b[72:]
|
||||
} else {
|
||||
r.IP = make(net.IP, net.IPv6len)
|
||||
copy(r.IP, b[68:84])
|
||||
b = b[84:]
|
||||
}
|
||||
|
||||
r.Key = int32(binary.BigEndian.Uint32(b[0:4]))
|
||||
r.NumWant = int32(binary.BigEndian.Uint32(b[4:8]))
|
||||
r.Port = binary.BigEndian.Uint16(b[8:10])
|
||||
|
||||
b = b[10:]
|
||||
for len(b) > 0 {
|
||||
var ext Extension
|
||||
parsed := ext.DecodeFrom(b)
|
||||
r.Exts = append(r.Exts, ext)
|
||||
b = b[parsed:]
|
||||
}
|
||||
}
|
||||
|
||||
// EncodeTo encodes the request to buf.
|
||||
func (r AnnounceRequest) EncodeTo(buf *bytes.Buffer) {
|
||||
buf.Grow(82)
|
||||
buf.Write(r.InfoHash[:])
|
||||
buf.Write(r.PeerID[:])
|
||||
|
||||
binary.Write(buf, binary.BigEndian, r.Downloaded)
|
||||
binary.Write(buf, binary.BigEndian, r.Left)
|
||||
binary.Write(buf, binary.BigEndian, r.Uploaded)
|
||||
binary.Write(buf, binary.BigEndian, r.Event)
|
||||
|
||||
if ip := r.IP.To4(); ip != nil {
|
||||
buf.Write(ip[:])
|
||||
} else {
|
||||
buf.Write(r.IP[:])
|
||||
}
|
||||
|
||||
binary.Write(buf, binary.BigEndian, r.Key)
|
||||
binary.Write(buf, binary.BigEndian, r.NumWant)
|
||||
binary.Write(buf, binary.BigEndian, r.Port)
|
||||
|
||||
for _, ext := range r.Exts {
|
||||
ext.EncodeTo(buf)
|
||||
}
|
||||
}
|
||||
|
||||
// AnnounceResponse represents the announce response used by UDP tracker.
|
||||
//
|
||||
// BEP 15
|
||||
type AnnounceResponse struct {
|
||||
Interval uint32
|
||||
Leechers uint32
|
||||
Seeders uint32
|
||||
Addresses []metainfo.Address
|
||||
}
|
||||
|
||||
// EncodeTo encodes the response to buf.
|
||||
func (r AnnounceResponse) EncodeTo(buf *bytes.Buffer) {
|
||||
buf.Grow(12 + len(r.Addresses)*18)
|
||||
binary.Write(buf, binary.BigEndian, r.Interval)
|
||||
binary.Write(buf, binary.BigEndian, r.Leechers)
|
||||
binary.Write(buf, binary.BigEndian, r.Seeders)
|
||||
for _, addr := range r.Addresses {
|
||||
if ip := addr.IP.To4(); ip != nil {
|
||||
buf.Write(ip[:])
|
||||
} else {
|
||||
buf.Write(addr.IP[:])
|
||||
}
|
||||
binary.Write(buf, binary.BigEndian, addr.Port)
|
||||
}
|
||||
}
|
||||
|
||||
// DecodeFrom decodes the response from b.
|
||||
func (r *AnnounceResponse) DecodeFrom(b []byte, ipv4 bool) {
|
||||
r.Interval = binary.BigEndian.Uint32(b[:4])
|
||||
r.Leechers = binary.BigEndian.Uint32(b[4:8])
|
||||
r.Seeders = binary.BigEndian.Uint32(b[8:12])
|
||||
|
||||
b = b[12:]
|
||||
iplen := net.IPv6len
|
||||
if ipv4 {
|
||||
iplen = net.IPv4len
|
||||
}
|
||||
|
||||
_len := len(b)
|
||||
step := iplen + 2
|
||||
r.Addresses = make([]metainfo.Address, 0, _len/step)
|
||||
for i := step; i <= _len; i += step {
|
||||
ip := make(net.IP, iplen)
|
||||
copy(ip, b[i-step:i-2])
|
||||
port := binary.BigEndian.Uint16(b[i-2 : i])
|
||||
r.Addresses = append(r.Addresses, metainfo.Address{IP: ip, Port: port})
|
||||
}
|
||||
}
|
||||
|
||||
// ScrapeResponse represents the UDP SCRAPE response.
|
||||
//
|
||||
// BEP 15
|
||||
type ScrapeResponse struct {
|
||||
Seeders uint32
|
||||
Leechers uint32
|
||||
Completed uint32
|
||||
}
|
||||
|
||||
// EncodeTo encodes the response to buf.
|
||||
func (r ScrapeResponse) EncodeTo(buf *bytes.Buffer) {
|
||||
binary.Write(buf, binary.BigEndian, r.Seeders)
|
||||
binary.Write(buf, binary.BigEndian, r.Completed)
|
||||
binary.Write(buf, binary.BigEndian, r.Leechers)
|
||||
}
|
||||
|
||||
// DecodeFrom decodes the response from b.
|
||||
func (r *ScrapeResponse) DecodeFrom(b []byte) {
|
||||
r.Seeders = binary.BigEndian.Uint32(b[:4])
|
||||
r.Completed = binary.BigEndian.Uint32(b[4:8])
|
||||
r.Leechers = binary.BigEndian.Uint32(b[8:12])
|
||||
}
|
||||
|
||||
// Predefine some UDP extension types.
|
||||
//
|
||||
// BEP 41
|
||||
const (
|
||||
EndOfOptions ExtensionType = iota
|
||||
Nop
|
||||
URLData
|
||||
)
|
||||
|
||||
// NewEndOfOptions returns a new EndOfOptions UDP extension.
|
||||
func NewEndOfOptions() Extension { return Extension{Type: EndOfOptions} }
|
||||
|
||||
// NewNop returns a new Nop UDP extension.
|
||||
func NewNop() Extension { return Extension{Type: Nop} }
|
||||
|
||||
// NewURLData returns a new URLData UDP extension.
|
||||
func NewURLData(data []byte) Extension {
|
||||
return Extension{Type: URLData, Length: uint8(len(data)), Data: data}
|
||||
}
|
||||
|
||||
// ExtensionType represents the type of UDP extension.
|
||||
type ExtensionType uint8
|
||||
|
||||
func (et ExtensionType) String() string {
|
||||
switch et {
|
||||
case EndOfOptions:
|
||||
return "EndOfOptions"
|
||||
case Nop:
|
||||
return "Nop"
|
||||
case URLData:
|
||||
return "URLData"
|
||||
default:
|
||||
return fmt.Sprintf("ExtensionType(%d)", et)
|
||||
}
|
||||
}
|
||||
|
||||
// Extension represent the extension used by the UDP ANNOUNCE request.
|
||||
//
|
||||
// BEP 41
|
||||
type Extension struct {
|
||||
Type ExtensionType
|
||||
Length uint8
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// EncodeTo encodes the response to buf.
|
||||
func (e Extension) EncodeTo(buf *bytes.Buffer) {
|
||||
if _len := uint8(len(e.Data)); e.Length == 0 && _len != 0 {
|
||||
e.Length = _len
|
||||
} else if _len != e.Length {
|
||||
panic("the length of data is inconsistent")
|
||||
}
|
||||
|
||||
buf.WriteByte(byte(e.Type))
|
||||
buf.WriteByte(e.Length)
|
||||
buf.Write(e.Data)
|
||||
}
|
||||
|
||||
// DecodeFrom decodes the response from b.
|
||||
func (e *Extension) DecodeFrom(b []byte) (parsed int) {
|
||||
switch len(b) {
|
||||
case 0:
|
||||
case 1:
|
||||
e.Type = ExtensionType(b[0])
|
||||
parsed = 1
|
||||
default:
|
||||
e.Type = ExtensionType(b[0])
|
||||
e.Length = b[1]
|
||||
parsed = 2
|
||||
if e.Length > 0 {
|
||||
parsed += int(e.Length)
|
||||
e.Data = b[2:parsed]
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
289
tracker/udptracker/udp_client.go
Normal file
289
tracker/udptracker/udp_client.go
Normal file
@ -0,0 +1,289 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package udptracker
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/xgfone/bt/metainfo"
|
||||
)
|
||||
|
||||
// NewTrackerClientByDial returns a new TrackerClient by dialing.
|
||||
func NewTrackerClientByDial(network, address string, c ...TrackerClientConfig) (
|
||||
*TrackerClient, error) {
|
||||
conn, err := net.Dial(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewTrackerClient(conn.(*net.UDPConn), c...), nil
|
||||
}
|
||||
|
||||
// NewTrackerClient returns a new TrackerClient.
|
||||
func NewTrackerClient(conn *net.UDPConn, c ...TrackerClientConfig) *TrackerClient {
|
||||
var conf TrackerClientConfig
|
||||
conf.set(c...)
|
||||
ipv4 := strings.Contains(conn.LocalAddr().String(), ".")
|
||||
return &TrackerClient{conn: conn, conf: conf, ipv4: ipv4}
|
||||
}
|
||||
|
||||
// TrackerClientConfig is used to configure the TrackerClient.
|
||||
type TrackerClientConfig struct {
|
||||
// ReadTimeout is used to receive the response.
|
||||
ReadTimeout time.Duration // Default: 5s
|
||||
MaxBufSize int // Default: 2048
|
||||
}
|
||||
|
||||
func (c *TrackerClientConfig) set(conf ...TrackerClientConfig) {
|
||||
if len(conf) > 0 {
|
||||
*c = conf[0]
|
||||
}
|
||||
|
||||
if c.MaxBufSize <= 0 {
|
||||
c.MaxBufSize = 2048
|
||||
}
|
||||
if c.ReadTimeout <= 0 {
|
||||
c.ReadTimeout = time.Second * 5
|
||||
}
|
||||
}
|
||||
|
||||
// TrackerClient is a tracker client based on UDP.
|
||||
//
|
||||
// Notice: the request is synchronized, that's, the last request is not returned,
|
||||
// the next request must not be sent.
|
||||
//
|
||||
// BEP 15
|
||||
type TrackerClient struct {
|
||||
ipv4 bool
|
||||
conf TrackerClientConfig
|
||||
conn *net.UDPConn
|
||||
last time.Time
|
||||
cid uint64
|
||||
tid uint32
|
||||
}
|
||||
|
||||
// Close closes the UDP tracker client.
|
||||
func (utc *TrackerClient) Close() { utc.conn.Close() }
|
||||
|
||||
func (utc *TrackerClient) readResp(b []byte) (int, error) {
|
||||
utc.conn.SetReadDeadline(time.Now().Add(utc.conf.ReadTimeout))
|
||||
return utc.conn.Read(b)
|
||||
}
|
||||
|
||||
func (utc *TrackerClient) getTranID() uint32 {
|
||||
return atomic.AddUint32(&utc.tid, 1)
|
||||
}
|
||||
|
||||
func (utc *TrackerClient) parseError(b []byte) (tid uint32, reason string) {
|
||||
tid = binary.BigEndian.Uint32(b[:4])
|
||||
reason = string(b[4:])
|
||||
return
|
||||
}
|
||||
|
||||
func (utc *TrackerClient) send(b []byte) (err error) {
|
||||
n, err := utc.conn.Write(b)
|
||||
if err == nil && n < len(b) {
|
||||
err = io.ErrShortWrite
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (utc *TrackerClient) connect() (err error) {
|
||||
tid := utc.getTranID()
|
||||
buf := bytes.NewBuffer(make([]byte, 0, 16))
|
||||
binary.Write(buf, binary.BigEndian, ProtocolID)
|
||||
binary.Write(buf, binary.BigEndian, ActionConnect)
|
||||
binary.Write(buf, binary.BigEndian, tid)
|
||||
if err = utc.send(buf.Bytes()); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
data := make([]byte, 32)
|
||||
n, err := utc.readResp(data)
|
||||
if err != nil {
|
||||
return
|
||||
} else if n < 8 {
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
|
||||
data = data[:n]
|
||||
switch binary.BigEndian.Uint32(data[:4]) {
|
||||
case ActionConnect:
|
||||
case ActionError:
|
||||
_, reason := utc.parseError(data[4:])
|
||||
return errors.New(reason)
|
||||
default:
|
||||
return errors.New("tracker response not connect action")
|
||||
}
|
||||
|
||||
if n < 16 {
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
|
||||
if binary.BigEndian.Uint32(data[4:8]) != tid {
|
||||
return errors.New("invalid transaction id")
|
||||
}
|
||||
|
||||
utc.cid = binary.BigEndian.Uint64(data[8:16])
|
||||
utc.last = time.Now()
|
||||
return
|
||||
}
|
||||
|
||||
func (utc *TrackerClient) getConnectionID() (cid uint64, err error) {
|
||||
cid = utc.cid
|
||||
if time.Now().Sub(utc.last) > time.Minute {
|
||||
if err = utc.connect(); err == nil {
|
||||
cid = utc.cid
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (utc *TrackerClient) announce(req AnnounceRequest) (r AnnounceResponse, err error) {
|
||||
cid, err := utc.getConnectionID()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
tid := utc.getTranID()
|
||||
buf := bytes.NewBuffer(make([]byte, 0, 110))
|
||||
binary.Write(buf, binary.BigEndian, cid)
|
||||
binary.Write(buf, binary.BigEndian, ActionAnnounce)
|
||||
binary.Write(buf, binary.BigEndian, tid)
|
||||
req.EncodeTo(buf)
|
||||
b := buf.Bytes()
|
||||
if err = utc.send(b); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
data := make([]byte, utc.conf.MaxBufSize)
|
||||
n, err := utc.readResp(data)
|
||||
if err != nil {
|
||||
return
|
||||
} else if n < 8 {
|
||||
err = io.ErrShortBuffer
|
||||
return
|
||||
}
|
||||
|
||||
data = data[:n]
|
||||
switch binary.BigEndian.Uint32(data[:4]) {
|
||||
case ActionAnnounce:
|
||||
case ActionError:
|
||||
_, reason := utc.parseError(data[4:])
|
||||
err = errors.New(reason)
|
||||
return
|
||||
default:
|
||||
err = errors.New("tracker response not connect action")
|
||||
return
|
||||
}
|
||||
|
||||
if n < 16 {
|
||||
err = io.ErrShortBuffer
|
||||
return
|
||||
}
|
||||
|
||||
if binary.BigEndian.Uint32(data[4:8]) != tid {
|
||||
err = errors.New("invalid transaction id")
|
||||
return
|
||||
}
|
||||
|
||||
r.DecodeFrom(data[8:], utc.ipv4)
|
||||
return
|
||||
}
|
||||
|
||||
// Announce sends a Announce request to the tracker.
|
||||
//
|
||||
// Notice:
|
||||
// 1. if it does not connect to the UDP tracker server, it will connect to it,
|
||||
// then send the ANNOUNCE request.
|
||||
// 2. If returning an error, you should retry it.
|
||||
// See http://www.bittorrent.org/beps/bep_0015.html#time-outs
|
||||
func (utc *TrackerClient) Announce(r AnnounceRequest) (AnnounceResponse, error) {
|
||||
return utc.announce(r)
|
||||
}
|
||||
|
||||
func (utc *TrackerClient) scrape(infohashes []metainfo.Hash) (rs []ScrapeResponse, err error) {
|
||||
cid, err := utc.getConnectionID()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
tid := utc.getTranID()
|
||||
buf := bytes.NewBuffer(make([]byte, 0, 16+len(infohashes)*20))
|
||||
binary.Write(buf, binary.BigEndian, cid)
|
||||
binary.Write(buf, binary.BigEndian, ActionScrape)
|
||||
binary.Write(buf, binary.BigEndian, tid)
|
||||
for _, h := range infohashes {
|
||||
buf.Write(h[:])
|
||||
}
|
||||
if err = utc.send(buf.Bytes()); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
data := make([]byte, utc.conf.MaxBufSize)
|
||||
n, err := utc.readResp(data)
|
||||
if err != nil {
|
||||
return
|
||||
} else if n < 8 {
|
||||
err = io.ErrShortBuffer
|
||||
return
|
||||
}
|
||||
|
||||
data = data[:n]
|
||||
switch binary.BigEndian.Uint32(data[:4]) {
|
||||
case ActionScrape:
|
||||
case ActionError:
|
||||
_, reason := utc.parseError(data[4:])
|
||||
err = errors.New(reason)
|
||||
return
|
||||
default:
|
||||
err = errors.New("tracker response not connect action")
|
||||
return
|
||||
}
|
||||
|
||||
if binary.BigEndian.Uint32(data[4:8]) != tid {
|
||||
err = errors.New("invalid transaction id")
|
||||
return
|
||||
}
|
||||
|
||||
data = data[8:]
|
||||
_len := len(data)
|
||||
rs = make([]ScrapeResponse, 0, _len/12)
|
||||
for i := 12; i <= _len; i += 12 {
|
||||
var r ScrapeResponse
|
||||
r.DecodeFrom(data[i-12 : i])
|
||||
rs = append(rs, r)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Scrape sends a Scrape request to the tracker.
|
||||
//
|
||||
// Notice:
|
||||
// 1. if it does not connect to the UDP tracker server, it will connect to it,
|
||||
// then send the ANNOUNCE request.
|
||||
// 2. If returning an error, you should retry it.
|
||||
// See http://www.bittorrent.org/beps/bep_0015.html#time-outs
|
||||
func (utc *TrackerClient) Scrape(hs []metainfo.Hash) ([]ScrapeResponse, error) {
|
||||
return utc.scrape(hs)
|
||||
}
|
281
tracker/udptracker/udp_server.go
Normal file
281
tracker/udptracker/udp_server.go
Normal file
@ -0,0 +1,281 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package udptracker
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/xgfone/bt/metainfo"
|
||||
)
|
||||
|
||||
// TrackerServerHandler is used to handle the request from the client.
|
||||
type TrackerServerHandler interface {
|
||||
// OnConnect is used to check whether to make the connection or not.
|
||||
OnConnect(raddr *net.UDPAddr) (err error)
|
||||
OnAnnounce(raddr *net.UDPAddr, req AnnounceRequest) (AnnounceResponse, error)
|
||||
OnScrap(raddr *net.UDPAddr, infohashes []metainfo.Hash) ([]ScrapeResponse, error)
|
||||
}
|
||||
|
||||
func encodeResponseHeader(buf *bytes.Buffer, action, tid uint32) {
|
||||
binary.Write(buf, binary.BigEndian, action)
|
||||
binary.Write(buf, binary.BigEndian, tid)
|
||||
}
|
||||
|
||||
type wrappedPeerAddr struct {
|
||||
Addr *net.UDPAddr
|
||||
Time time.Time
|
||||
}
|
||||
|
||||
// TrackerServerConfig is used to configure the TrackerServer.
|
||||
type TrackerServerConfig struct {
|
||||
MaxBufSize int // Default: 2048
|
||||
ErrorLog func(format string, args ...interface{}) // Default: log.Printf
|
||||
}
|
||||
|
||||
func (c *TrackerServerConfig) setDefault() {
|
||||
if c.MaxBufSize <= 0 {
|
||||
c.MaxBufSize = 2048
|
||||
}
|
||||
if c.ErrorLog == nil {
|
||||
c.ErrorLog = log.Printf
|
||||
}
|
||||
}
|
||||
|
||||
// TrackerServer is a tracker server based on UDP.
|
||||
type TrackerServer struct {
|
||||
conn net.PacketConn
|
||||
conf TrackerServerConfig
|
||||
handler TrackerServerHandler
|
||||
bufpool sync.Pool
|
||||
|
||||
cid uint64
|
||||
exit chan struct{}
|
||||
lock sync.RWMutex
|
||||
conns map[uint64]wrappedPeerAddr
|
||||
}
|
||||
|
||||
// NewTrackerServer returns a new TrackerServer.
|
||||
func NewTrackerServer(c net.PacketConn, h TrackerServerHandler,
|
||||
config ...TrackerServerConfig) *TrackerServer {
|
||||
var conf TrackerServerConfig
|
||||
if len(config) > 0 {
|
||||
conf = config[0]
|
||||
}
|
||||
conf.setDefault()
|
||||
|
||||
s := &TrackerServer{
|
||||
conf: conf,
|
||||
conn: c,
|
||||
handler: h,
|
||||
exit: make(chan struct{}),
|
||||
conns: make(map[uint64]wrappedPeerAddr, 128),
|
||||
}
|
||||
s.bufpool.New = func() interface{} { return make([]byte, conf.MaxBufSize) }
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Close closes the tracker server.
|
||||
func (uts *TrackerServer) Close() {
|
||||
select {
|
||||
case <-uts.exit:
|
||||
default:
|
||||
close(uts.exit)
|
||||
uts.conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (uts *TrackerServer) cleanConnectionID(interval time.Duration) {
|
||||
tick := time.NewTicker(interval)
|
||||
defer tick.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-uts.exit:
|
||||
return
|
||||
case now := <-tick.C:
|
||||
uts.lock.RLock()
|
||||
for cid, wa := range uts.conns {
|
||||
if now.Sub(wa.Time) > interval {
|
||||
delete(uts.conns, cid)
|
||||
}
|
||||
}
|
||||
uts.lock.RUnlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Run starts the tracker server.
|
||||
func (uts *TrackerServer) Run() {
|
||||
go uts.cleanConnectionID(time.Minute * 2)
|
||||
for {
|
||||
buf := uts.bufpool.Get().([]byte)
|
||||
n, raddr, err := uts.conn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
if !strings.Contains(err.Error(), "closed") {
|
||||
uts.conf.ErrorLog("failed to read udp tracker request: %s", err)
|
||||
}
|
||||
return
|
||||
} else if n < 16 {
|
||||
continue
|
||||
}
|
||||
go uts.handleRequest(raddr.(*net.UDPAddr), buf, n)
|
||||
}
|
||||
}
|
||||
|
||||
func (uts *TrackerServer) handleRequest(raddr *net.UDPAddr, buf []byte, n int) {
|
||||
defer uts.bufpool.Put(buf)
|
||||
uts.handlePacket(raddr, buf[:n])
|
||||
}
|
||||
|
||||
func (uts *TrackerServer) send(raddr *net.UDPAddr, b []byte) {
|
||||
n, err := uts.conn.WriteTo(b, raddr)
|
||||
if err != nil {
|
||||
uts.conf.ErrorLog("fail to send the udp tracker response to '%s': %s",
|
||||
raddr.String(), err)
|
||||
} else if n < len(b) {
|
||||
uts.conf.ErrorLog("too short udp tracker response sent to '%s'", raddr.String())
|
||||
}
|
||||
}
|
||||
|
||||
func (uts *TrackerServer) getConnectionID() uint64 {
|
||||
return atomic.AddUint64(&uts.cid, 1)
|
||||
}
|
||||
|
||||
func (uts *TrackerServer) addConnection(cid uint64, raddr *net.UDPAddr) {
|
||||
now := time.Now()
|
||||
uts.lock.Lock()
|
||||
uts.conns[cid] = wrappedPeerAddr{Addr: raddr, Time: now}
|
||||
uts.lock.Unlock()
|
||||
}
|
||||
|
||||
func (uts *TrackerServer) checkConnection(cid uint64, raddr *net.UDPAddr) (ok bool) {
|
||||
uts.lock.RLock()
|
||||
if w, _ok := uts.conns[cid]; _ok && w.Addr.Port == raddr.Port &&
|
||||
bytes.Equal(w.Addr.IP, raddr.IP) {
|
||||
ok = true
|
||||
}
|
||||
uts.lock.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
func (uts *TrackerServer) sendError(raddr *net.UDPAddr, tid uint32, reason string) {
|
||||
buf := bytes.NewBuffer(make([]byte, 0, 8+len(reason)))
|
||||
encodeResponseHeader(buf, ActionError, tid)
|
||||
buf.WriteString(reason)
|
||||
uts.send(raddr, buf.Bytes())
|
||||
}
|
||||
|
||||
func (uts *TrackerServer) sendConnResp(raddr *net.UDPAddr, tid uint32, cid uint64) {
|
||||
buf := bytes.NewBuffer(make([]byte, 0, 16))
|
||||
encodeResponseHeader(buf, ActionConnect, tid)
|
||||
binary.Write(buf, binary.BigEndian, cid)
|
||||
uts.send(raddr, buf.Bytes())
|
||||
}
|
||||
|
||||
func (uts *TrackerServer) sendAnnounceResp(raddr *net.UDPAddr, tid uint32,
|
||||
resp AnnounceResponse) {
|
||||
buf := bytes.NewBuffer(make([]byte, 0, 8+12+len(resp.Addresses)*18))
|
||||
encodeResponseHeader(buf, ActionAnnounce, tid)
|
||||
resp.EncodeTo(buf)
|
||||
uts.send(raddr, buf.Bytes())
|
||||
}
|
||||
|
||||
func (uts *TrackerServer) sendScrapResp(raddr *net.UDPAddr, tid uint32,
|
||||
rs []ScrapeResponse) {
|
||||
buf := bytes.NewBuffer(make([]byte, 0, 8+len(rs)*12))
|
||||
encodeResponseHeader(buf, ActionScrape, tid)
|
||||
for _, r := range rs {
|
||||
r.EncodeTo(buf)
|
||||
}
|
||||
uts.send(raddr, buf.Bytes())
|
||||
}
|
||||
|
||||
func (uts *TrackerServer) handlePacket(raddr *net.UDPAddr, b []byte) {
|
||||
cid := binary.BigEndian.Uint64(b[:8])
|
||||
action := binary.BigEndian.Uint32(b[8:12])
|
||||
tid := binary.BigEndian.Uint32(b[12:16])
|
||||
b = b[16:]
|
||||
|
||||
// Handle the connection request.
|
||||
if cid == ProtocolID && action == ActionConnect {
|
||||
if err := uts.handler.OnConnect(raddr); err != nil {
|
||||
uts.sendError(raddr, tid, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
cid := uts.getConnectionID()
|
||||
uts.addConnection(cid, raddr)
|
||||
uts.sendConnResp(raddr, tid, cid)
|
||||
return
|
||||
}
|
||||
|
||||
// Check whether the request is connected.
|
||||
if !uts.checkConnection(cid, raddr) {
|
||||
uts.sendError(raddr, tid, "connection is expired")
|
||||
return
|
||||
}
|
||||
|
||||
switch action {
|
||||
case ActionAnnounce:
|
||||
var req AnnounceRequest
|
||||
if raddr.IP.To4() != nil { // For ipv4
|
||||
if len(b) < 82 {
|
||||
uts.sendError(raddr, tid, "invalid announce request")
|
||||
return
|
||||
}
|
||||
req.DecodeFrom(b, true)
|
||||
} else { // For ipv6
|
||||
if len(b) < 94 {
|
||||
uts.sendError(raddr, tid, "invalid announce request")
|
||||
return
|
||||
}
|
||||
req.DecodeFrom(b, false)
|
||||
}
|
||||
|
||||
resp, err := uts.handler.OnAnnounce(raddr, req)
|
||||
if err != nil {
|
||||
uts.sendError(raddr, tid, err.Error())
|
||||
} else {
|
||||
uts.sendAnnounceResp(raddr, tid, resp)
|
||||
}
|
||||
case ActionScrape:
|
||||
_len := len(b)
|
||||
infohashes := make([]metainfo.Hash, 0, _len/20)
|
||||
for i, _len := 20, len(b); i <= _len; i += 20 {
|
||||
infohashes = append(infohashes, metainfo.NewHash(b[i-20:i]))
|
||||
}
|
||||
|
||||
if len(infohashes) == 0 {
|
||||
uts.sendError(raddr, tid, "no infohash")
|
||||
return
|
||||
}
|
||||
|
||||
resps, err := uts.handler.OnScrap(raddr, infohashes)
|
||||
if err != nil {
|
||||
uts.sendError(raddr, tid, err.Error())
|
||||
} else {
|
||||
uts.sendScrapResp(raddr, tid, resps)
|
||||
}
|
||||
default:
|
||||
uts.sendError(raddr, tid, "unkwnown action")
|
||||
}
|
||||
}
|
135
tracker/udptracker/udp_test.go
Normal file
135
tracker/udptracker/udp_test.go
Normal file
@ -0,0 +1,135 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package udptracker
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/xgfone/bt/metainfo"
|
||||
)
|
||||
|
||||
type testHandler struct{}
|
||||
|
||||
func (testHandler) OnConnect(raddr *net.UDPAddr) (err error) { return }
|
||||
func (testHandler) OnAnnounce(raddr *net.UDPAddr, req AnnounceRequest) (
|
||||
r AnnounceResponse, err error) {
|
||||
if req.Port != 80 {
|
||||
err = errors.New("port is not 80")
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Exts) > 0 {
|
||||
for i, ext := range req.Exts {
|
||||
switch ext.Type {
|
||||
case URLData:
|
||||
fmt.Printf("Extensions[%d]: URLData(%s)\n", i, string(ext.Data))
|
||||
default:
|
||||
fmt.Printf("Extensions[%d]: %s\n", i, ext.Type.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
r = AnnounceResponse{
|
||||
Interval: 1,
|
||||
Leechers: 2,
|
||||
Seeders: 3,
|
||||
Addresses: []metainfo.Address{{IP: net.ParseIP("127.0.0.1"), Port: 8001}},
|
||||
}
|
||||
return
|
||||
}
|
||||
func (testHandler) OnScrap(raddr *net.UDPAddr, infohashes []metainfo.Hash) (
|
||||
rs []ScrapeResponse, err error) {
|
||||
rs = make([]ScrapeResponse, len(infohashes))
|
||||
for i := range infohashes {
|
||||
rs[i] = ScrapeResponse{
|
||||
Seeders: uint32(i)*10 + 1,
|
||||
Leechers: uint32(i)*10 + 2,
|
||||
Completed: uint32(i)*10 + 3,
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func ExampleTrackerClient() {
|
||||
// Start the UDP tracker server
|
||||
sconn, err := net.ListenPacket("udp4", "127.0.0.1:8001")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
server := NewTrackerServer(sconn, testHandler{})
|
||||
defer server.Close()
|
||||
go server.Run()
|
||||
|
||||
// Wait for the server to be started
|
||||
time.Sleep(time.Second)
|
||||
|
||||
// Create a client and dial to the UDP tracker server.
|
||||
client, err := NewTrackerClientByDial("udp4", "127.0.0.1:8001")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Send the ANNOUNCE request to the UDP tracker server,
|
||||
// and get the ANNOUNCE response.
|
||||
exts := []Extension{NewURLData([]byte("data")), NewNop()}
|
||||
req := AnnounceRequest{IP: net.ParseIP("127.0.0.1"), Port: 80, Exts: exts}
|
||||
resp, err := client.Announce(req)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("Interval: %d\n", resp.Interval)
|
||||
fmt.Printf("Leechers: %d\n", resp.Leechers)
|
||||
fmt.Printf("Seeders: %d\n", resp.Seeders)
|
||||
for i, addr := range resp.Addresses {
|
||||
fmt.Printf("Address[%d].IP: %s\n", i, addr.IP.String())
|
||||
fmt.Printf("Address[%d].Port: %d\n", i, addr.Port)
|
||||
}
|
||||
|
||||
// Send the SCRAPE request to the UDP tracker server,
|
||||
// and get the SCRAPE respsone.
|
||||
hs := []metainfo.Hash{metainfo.NewRandomHash(), metainfo.NewRandomHash()}
|
||||
rs, err := client.Scrape(hs)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
} else if len(rs) != 2 {
|
||||
log.Fatalf("%+v", rs)
|
||||
}
|
||||
|
||||
for i, r := range rs {
|
||||
fmt.Printf("%d.Seeders: %d\n", i, r.Seeders)
|
||||
fmt.Printf("%d.Leechers: %d\n", i, r.Leechers)
|
||||
fmt.Printf("%d.Completed: %d\n", i, r.Completed)
|
||||
}
|
||||
|
||||
// Output:
|
||||
// Extensions[0]: URLData(data)
|
||||
// Extensions[1]: Nop
|
||||
// Interval: 1
|
||||
// Leechers: 2
|
||||
// Seeders: 3
|
||||
// Address[0].IP: 127.0.0.1
|
||||
// Address[0].Port: 8001
|
||||
// 0.Seeders: 1
|
||||
// 0.Leechers: 2
|
||||
// 0.Completed: 3
|
||||
// 1.Seeders: 11
|
||||
// 1.Leechers: 12
|
||||
// 1.Completed: 13
|
||||
}
|
50
utils/bool.go
Normal file
50
utils/bool.go
Normal file
@ -0,0 +1,50 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Bool is used to implement a atomic bool.
|
||||
type Bool struct {
|
||||
v uint32
|
||||
}
|
||||
|
||||
// NewBool returns a new Bool with the initialized value.
|
||||
func NewBool(t bool) (b Bool) {
|
||||
if t {
|
||||
b.v = 1
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Get returns the bool value.
|
||||
func (b *Bool) Get() bool { return atomic.LoadUint32(&b.v) == 1 }
|
||||
|
||||
// SetTrue sets the bool value to true.
|
||||
func (b *Bool) SetTrue() { atomic.StoreUint32(&b.v, 1) }
|
||||
|
||||
// SetFalse sets the bool value to false.
|
||||
func (b *Bool) SetFalse() { atomic.StoreUint32(&b.v, 0) }
|
||||
|
||||
// Set sets the bool value to t.
|
||||
func (b *Bool) Set(t bool) {
|
||||
if t {
|
||||
b.SetTrue()
|
||||
} else {
|
||||
b.SetFalse()
|
||||
}
|
||||
}
|
16
utils/doc.go
Normal file
16
utils/doc.go
Normal file
@ -0,0 +1,16 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package utils supplies some convenient functions.
|
||||
package utils
|
30
utils/io.go
Normal file
30
utils/io.go
Normal file
@ -0,0 +1,30 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package utils
|
||||
|
||||
import "io"
|
||||
|
||||
// CopyNBuffer is the same as io.CopyN, but uses the given buf as the buffer.
|
||||
func CopyNBuffer(dst io.Writer, src io.Reader, n int64, buf []byte) (written int64, err error) {
|
||||
written, err = io.CopyBuffer(dst, io.LimitReader(src, n), buf)
|
||||
if written == n {
|
||||
return n, nil
|
||||
} else if written < n && err == nil {
|
||||
// src stopped early; must have been EOF.
|
||||
err = io.EOF
|
||||
}
|
||||
|
||||
return
|
||||
}
|
26
utils/slice.go
Normal file
26
utils/slice.go
Normal file
@ -0,0 +1,26 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package utils
|
||||
|
||||
// InStringSlice reports whether s is in ss.
|
||||
func InStringSlice(ss []string, s string) bool {
|
||||
for _, v := range ss {
|
||||
if v == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
29
utils/slice_test.go
Normal file
29
utils/slice_test.go
Normal file
@ -0,0 +1,29 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInStringSlice(t *testing.T) {
|
||||
if !InStringSlice([]string{"a", "b"}, "a") {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
if InStringSlice([]string{"a", "b"}, "z") {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
24
utils/string.go
Normal file
24
utils/string.go
Normal file
@ -0,0 +1,24 @@
|
||||
// Copyright 2020 xgfone
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package utils
|
||||
|
||||
import "crypto/rand"
|
||||
|
||||
// RandomString generates a size-length string randomly.
|
||||
func RandomString(size int) string {
|
||||
bs := make([]byte, size)
|
||||
rand.Read(bs)
|
||||
return string(bs)
|
||||
}
|
Reference in New Issue
Block a user