first commit

This commit is contained in:
xgfone
2020-06-07 13:43:15 +08:00
commit 6bfa2f6700
65 changed files with 10388 additions and 0 deletions

40
.gitignore vendored Normal file
View File

@ -0,0 +1,40 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
*.test
*.prof
# Intellij IDE
.idea/
*.iml
# test
test_*
# log
*.log
# Mac
.DS_Store
# VS Code
.vscode/

10
.travis.yml Normal file
View File

@ -0,0 +1,10 @@
language: go
go:
- 1.11.x
- 1.12.x
- 1.13.x
- 1.14.x
env:
- GO111MODULE=on
script:
- go test -race ./...

202
LICENSE Normal file
View File

@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright {yyyy} {name of copyright owner}
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

41
README.md Normal file
View File

@ -0,0 +1,41 @@
# BT - Another Implementation Based On Golang [![Build Status](https://travis-ci.org/xgfone/bt.svg?branch=master)](https://travis-ci.org/xgfone/bt) [![GoDoc](https://godoc.org/github.com/xgfone/bt?status.svg)](https://pkg.go.dev/github.com/xgfone/bt) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg?style=flat-square)](https://raw.githubusercontent.com/xgfone/bt/master/LICENSE)
A pure golang implementation of [BitTorrent](http://bittorrent.org/beps/bep_0000.html) library, which is inspired by [dht](https://github.com/shiyanhui/dht) and [torrent](https://github.com/anacrolix/torrent).
## Features
- Support IPv4/IPv6.
- Pure Go implementation.
- Multi-BEPs implementation. [See below](#the-implemented-specifications)
- Only library without any denpendencies. For the command tools, see [bttools](https://github.com/xgfone/bttools).
## Install
```shell
$ go get github.com/xgfone/bt
```
## Example
See [godoc](https://pkg.go.dev/github.com/xgfone/bt) or [bttools](https://github.com/xgfone/bttools).
## The Implemented Specifications
- [x] [**BEP 03:** The BitTorrent Protocol Specification](http://bittorrent.org/beps/bep_0003.html)
- [x] [**BEP 05:** DHT Protocol](http://bittorrent.org/beps/bep_0005.html)
- [x] [**BEP 06:** Fast Extension](http://bittorrent.org/beps/bep_0006.html)
- [x] [**BEP 07:** IPv6 Tracker Extension](http://bittorrent.org/beps/bep_0007.html)
- [x] [**BEP 09:** Extension for Peers to Send Metadata Files](http://bittorrent.org/beps/bep_0009.html)
- [x] [**BEP 10:** Extension Protocol](http://bittorrent.org/beps/bep_0010.html)
- [ ] [**BEP 11:** Peer Exchange (PEX)](http://bittorrent.org/beps/bep_0011.html)
- [x] [**BEP 12:** Multitracker Metadata Extension](http://bittorrent.org/beps/bep_0012.html)
- [x] [**BEP 15:** UDP Tracker Protocol for BitTorrent](http://bittorrent.org/beps/bep_0015.html)
- [x] [**BEP 19:** WebSeed - HTTP/FTP Seeding (GetRight style)](http://bittorrent.org/beps/bep_0019.html) (Only `url-list` in metainfo)
- [x] [**BEP 23:** Tracker Returns Compact Peer Lists](http://bittorrent.org/beps/bep_0023.html)
- [x] [**BEP 32:** IPv6 extension for DHT](http://bittorrent.org/beps/bep_0032.html)
- [ ] [**BEP 33:** DHT scrape](http://bittorrent.org/beps/bep_0033.html)
- [x] [**BEP 41:** UDP Tracker Protocol Extensions](http://bittorrent.org/beps/bep_0041.html)
- [x] [**BEP 43:** Read-only DHT Nodes](http://bittorrent.org/beps/bep_0043.html)
- [ ] [**BEP 44:** Storing arbitrary data in the DHT](http://bittorrent.org/beps/bep_0044.html)
- [x] [**BEP 48:** Tracker Protocol Extension: Scrape](http://bittorrent.org/beps/bep_0048.html)

8
bencode/AUTHORS Normal file
View File

@ -0,0 +1,8 @@
Jeff Wendling <leterip@gmail.com>
Liam Edwards-Playne <liamz.co>
Casey Bodley <cbodley@gmail.com>
Conrad Pankoff <deoxxa@fknsrs.biz>
Cenk Alti <cenkalti@gmail.com>
Jan Winkelmann <j-winkelmann@tuhh.de>
Patrick Mézard <patrick@mezard.eu>
Glen De Cauwsemaecker <contact@glendc.com>

19
bencode/LICENSE Normal file
View File

@ -0,0 +1,19 @@
Copyright (c) 2013 The Authors (see AUTHORS)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

614
bencode/decode.go Normal file
View File

@ -0,0 +1,614 @@
package bencode
import (
"bufio"
"bytes"
"encoding"
"errors"
"fmt"
"io"
"reflect"
"strconv"
"strings"
)
var (
reflectByteSliceType = reflect.TypeOf([]byte(nil))
reflectStringType = reflect.TypeOf("")
)
// Unmarshaler is the interface implemented by types that can unmarshal
// a bencode description of themselves.
// The input can be assumed to be a valid encoding of a bencode value.
// UnmarshalBencode must copy the bencode data if it wishes to retain the data after returning.
type Unmarshaler interface {
UnmarshalBencode([]byte) error
}
// A Decoder reads and decodes bencoded data from an input stream.
type Decoder struct {
r *bufio.Reader
raw bool
buf []byte
n int
failUnordered bool
}
// SetFailOnUnorderedKeys will cause the decoder to fail when encountering
// unordered keys. The default is to not fail.
func (d *Decoder) SetFailOnUnorderedKeys(fail bool) {
d.failUnordered = fail
}
// BytesParsed returns the number of bytes that have actually been parsed
func (d *Decoder) BytesParsed() int {
return d.n
}
// read also writes into the buffer when d.raw is set.
func (d *Decoder) read(p []byte) (n int, err error) {
n, err = d.r.Read(p)
if d.raw {
d.buf = append(d.buf, p[:n]...)
}
d.n += n
return
}
// readBytes also writes into the buffer when d.raw is set.
func (d *Decoder) readBytes(delim byte) (line []byte, err error) {
line, err = d.r.ReadBytes(delim)
if d.raw {
d.buf = append(d.buf, line...)
}
d.n += len(line)
return
}
// readByte also writes into the buffer when d.raw is set.
func (d *Decoder) readByte() (b byte, err error) {
b, err = d.r.ReadByte()
if d.raw {
d.buf = append(d.buf, b)
}
d.n++
return
}
// readFull also writes into the buffer when d.raw is set.
func (d *Decoder) readFull(p []byte) (n int, err error) {
n, err = io.ReadFull(d.r, p)
if d.raw {
d.buf = append(d.buf, p[:n]...)
}
d.n += n
return
}
func (d *Decoder) peekByte() (b byte, err error) {
ch, err := d.r.Peek(1)
if err != nil {
return
}
b = ch[0]
return
}
// NewDecoder returns a new decoder that reads from r
func NewDecoder(r io.Reader) *Decoder {
return &Decoder{r: bufio.NewReader(r)}
}
// Decode reads the bencoded value from its input and stores it in the value pointed to by val.
// Decode allocates maps/slices as necessary with the following additional rules:
// To decode a bencoded value into a nil interface value, the type stored in the interface value is one of:
// int64 for bencoded integers
// string for bencoded strings
// []interface{} for bencoded lists
// map[string]interface{} for bencoded dicts
// To unmarshal bencode into a value implementing the Unmarshaler interface,
// Unmarshal calls that value's UnmarshalBencode method.
// Otherwise, if the value implements encoding.TextUnmarshaler
// and the input is a bencode string, Unmarshal calls that value's
// UnmarshalText method with the decoded form of the string.
func (d *Decoder) Decode(val interface{}) error {
rv := reflect.ValueOf(val)
if rv.Kind() != reflect.Ptr || rv.IsNil() {
return errors.New("Unwritable type passed into decode")
}
return d.decodeInto(rv)
}
// DecodeString reads the data in the string and stores it into the value pointed to by val.
// Read the docs for Decode for more information.
func DecodeString(in string, val interface{}) error {
buf := strings.NewReader(in)
d := NewDecoder(buf)
return d.Decode(val)
}
// DecodeBytes reads the data in b and stores it into the value pointed to by val.
// Read the docs for Decode for more information.
func DecodeBytes(b []byte, val interface{}) error {
r := bytes.NewReader(b)
d := NewDecoder(r)
return d.Decode(val)
}
func indirect(v reflect.Value, alloc bool) reflect.Value {
for {
switch v.Kind() {
case reflect.Interface:
if v.IsNil() {
if !alloc {
return reflect.Value{}
}
return v
}
case reflect.Ptr:
if v.IsNil() {
if !alloc {
return reflect.Value{}
}
v.Set(reflect.New(v.Type().Elem()))
}
default:
return v
}
v = v.Elem()
}
}
func (d *Decoder) decodeInto(val reflect.Value) (err error) {
var v reflect.Value
if d.raw {
v = val
} else {
var (
unmarshaler Unmarshaler
textUnmarshaler encoding.TextUnmarshaler
)
unmarshaler, textUnmarshaler, v = d.indirect(val)
// if we're decoding into an Unmarshaler,
// we pass on the next bencode value to this value instead,
// so it can decide what to do with it.
if unmarshaler != nil {
var x RawMessage
if err := d.decodeInto(reflect.ValueOf(&x)); err != nil {
return err
}
return unmarshaler.UnmarshalBencode([]byte(x))
}
// if we're decoding into an TextUnmarshaler,
// we'll assume that the bencode value is a string,
// we decode it as such and pass the result onto the unmarshaler.
if textUnmarshaler != nil {
var b []byte
ref := reflect.ValueOf(&b)
if err := d.decodeString(reflect.Indirect(ref)); err != nil {
return err
}
return textUnmarshaler.UnmarshalText(b)
}
// if we're decoding into a RawMessage set raw to true for the rest of
// the call stack, and switch out the value with an interface{}.
if _, ok := v.Interface().(RawMessage); ok {
v = reflect.Value{} // explicitly make v invalid
// set d.raw for the lifetime of this function call, and set the raw
// message when the function is exiting.
d.buf = d.buf[:0]
d.raw = true
defer func() {
d.raw = false
v := indirect(val, true)
v.SetBytes(append([]byte(nil), d.buf...))
}()
}
}
next, err := d.peekByte()
if err != nil {
return
}
switch next {
case 'i':
err = d.decodeInt(v)
case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
err = d.decodeString(v)
case 'l':
err = d.decodeList(v)
case 'd':
err = d.decodeDict(v)
default:
err = errors.New("Invalid input")
}
return
}
func (d *Decoder) decodeInt(v reflect.Value) error {
// we need to read an i, some digits, and an e.
ch, err := d.readByte()
if err != nil {
return err
}
if ch != 'i' {
panic("got not an i when peek returned an i")
}
line, err := d.readBytes('e')
if err != nil || d.raw {
return err
}
digits := string(line[:len(line)-1])
switch v.Kind() {
default:
return fmt.Errorf("Cannot store int64 into %s", v.Type())
case reflect.Interface:
n, err := strconv.ParseInt(digits, 10, 64)
if err != nil {
return err
}
v.Set(reflect.ValueOf(n))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
n, err := strconv.ParseInt(digits, 10, 64)
if err != nil {
return err
}
v.SetInt(n)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
n, err := strconv.ParseUint(digits, 10, 64)
if err != nil {
return err
}
v.SetUint(n)
case reflect.Bool:
n, err := strconv.ParseUint(digits, 10, 64)
if err != nil {
return err
}
v.SetBool(n != 0)
}
return nil
}
func (d *Decoder) decodeString(v reflect.Value) error {
// read until a colon to get the number of digits to read after
line, err := d.readBytes(':')
if err != nil {
return err
}
// parse it into an int for making a slice
l32, err := strconv.ParseInt(string(line[:len(line)-1]), 10, 32)
l := int(l32)
if err != nil {
return err
}
if l < 0 {
return fmt.Errorf("invalid negative string length: %d", l)
}
// read exactly l bytes out and make our string
buf := make([]byte, l)
_, err = d.readFull(buf)
if err != nil || d.raw {
return err
}
switch v.Kind() {
default:
return fmt.Errorf("Cannot store string into %s", v.Type())
case reflect.Slice:
if v.Type() != reflectByteSliceType {
return fmt.Errorf("Cannot store string into %s", v.Type())
}
v.SetBytes(buf)
case reflect.String:
v.SetString(string(buf))
case reflect.Interface:
v.Set(reflect.ValueOf(string(buf)))
}
return nil
}
func (d *Decoder) decodeList(v reflect.Value) error {
if !d.raw {
// if we have an interface, just put a []interface{} in it!
if v.Kind() == reflect.Interface {
var x []interface{}
defer func(p reflect.Value) { p.Set(v) }(v)
v = reflect.ValueOf(&x).Elem()
}
if v.Kind() != reflect.Array && v.Kind() != reflect.Slice {
return fmt.Errorf("Cant store a []interface{} into %s", v.Type())
}
}
// read out the l that prefixes the list
ch, err := d.readByte()
if err != nil {
return err
}
if ch != 'l' {
panic("got something other than a list head after a peek")
}
// if we're decoding in raw mode,
// we only want to read into the buffer,
// without actually parsing any values
if d.raw {
var ch byte
for {
// peek for the end token and read it out
ch, err = d.peekByte()
if err != nil {
return err
}
if ch == 'e' {
_, err = d.readByte() // consume the end
return err
}
// decode the next value
err = d.decodeInto(v)
if err != nil {
return err
}
}
}
for i := 0; ; i++ {
// peek for the end token and read it out
ch, err := d.peekByte()
if err != nil {
return err
}
switch ch {
case 'e':
_, err := d.readByte() // consume the end
return err
}
// grow it if required
if i >= v.Cap() && v.IsValid() {
newcap := v.Cap() + v.Cap()/2
if newcap < 4 {
newcap = 4
}
newv := reflect.MakeSlice(v.Type(), v.Len(), newcap)
reflect.Copy(newv, v)
v.Set(newv)
}
// reslice into cap (its a slice now since it had to have grown)
if i >= v.Len() && v.IsValid() {
v.SetLen(i + 1)
}
// decode a value into the index
if err := d.decodeInto(v.Index(i)); err != nil {
return err
}
}
}
func (d *Decoder) decodeDict(v reflect.Value) error {
// if we have an interface{}, just put a map[string]interface{} in it!
if !d.raw && v.Kind() == reflect.Interface {
var x map[string]interface{}
defer func(p reflect.Value) { p.Set(v) }(v)
v = reflect.ValueOf(&x).Elem()
}
// consume the head token
ch, err := d.readByte()
if err != nil {
return err
}
if ch != 'd' {
panic("got an incorrect token when it was checked already")
}
if d.raw {
// if we're decoding in raw mode,
// we only want to read into the buffer,
// without actually parsing any values
for {
// peek the next value type
ch, err := d.peekByte()
if err != nil {
return err
}
if ch == 'e' {
_, err = d.readByte() // consume the end token
return err
}
err = d.decodeString(v)
if err != nil {
return err
}
err = d.decodeInto(v)
if err != nil {
return err
}
}
}
// check for correct type
var (
mapElem reflect.Value
isMap bool
vals map[string]reflect.Value
)
switch v.Kind() {
case reflect.Map:
t := v.Type()
if t.Key() != reflectStringType {
return fmt.Errorf("Can't store a map[string]interface{} into %s", v.Type())
}
if v.IsNil() {
v.Set(reflect.MakeMap(t))
}
isMap = true
mapElem = reflect.New(t.Elem()).Elem()
case reflect.Struct:
vals = make(map[string]reflect.Value)
setStructValues(vals, v)
default:
return fmt.Errorf("Can't store a map[string]interface{} into %s", v.Type())
}
var (
lastKey string
first bool = true
)
for {
var subv reflect.Value
// peek the next value type
ch, err := d.peekByte()
if err != nil {
return err
}
if ch == 'e' {
_, err = d.readByte() // consume the end token
return err
}
// peek the next value we're suppsed to read
var key string
if err := d.decodeString(reflect.ValueOf(&key).Elem()); err != nil {
return err
}
// check for unordered keys
if !first && d.failUnordered && lastKey > key {
return fmt.Errorf("unordered dictionary: %q appears before %q",
lastKey, key)
}
lastKey, first = key, false
if isMap {
mapElem.Set(reflect.Zero(v.Type().Elem()))
subv = mapElem
} else {
subv = vals[key]
}
if !subv.IsValid() {
// if it's invalid, grab but ignore the next value
var x interface{}
err := d.decodeInto(reflect.ValueOf(&x).Elem())
if err != nil {
return err
}
continue
}
// subv now contains what we load into
if err := d.decodeInto(subv); err != nil {
return err
}
if isMap {
v.SetMapIndex(reflect.ValueOf(key), subv)
}
}
}
// indirect walks down v allocating pointers as needed,
// until it gets to a non-pointer.
// if it encounters an (Text)Unmarshaler, indirect stops and returns that.
func (d *Decoder) indirect(v reflect.Value) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) {
// If v is a named type and is addressable,
// start with its address, so that if the type has pointer methods,
// we find them.
if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() {
v = v.Addr()
}
for {
// Load value from interface, but only if the result will be
// usefully addressable.
if v.Kind() == reflect.Interface && !v.IsNil() {
e := v.Elem()
if e.Kind() == reflect.Ptr && !e.IsNil() {
v = e
continue
}
}
if v.Kind() != reflect.Ptr || v.IsNil() {
break
}
vi := v.Interface()
if u, ok := vi.(Unmarshaler); ok {
return u, nil, reflect.Value{}
}
if u, ok := vi.(encoding.TextUnmarshaler); ok {
return nil, u, reflect.Value{}
}
v = v.Elem()
}
return nil, nil, indirect(v, true)
}
func setStructValues(m map[string]reflect.Value, v reflect.Value) {
t := v.Type()
if t.Kind() != reflect.Struct {
return
}
// do embedded fields first
for i := 0; i < v.NumField(); i++ {
f := t.Field(i)
if f.PkgPath != "" {
continue
}
v := v.FieldByIndex(f.Index)
if f.Anonymous && f.Tag == "" {
setStructValues(m, v)
}
}
// overwrite embedded struct tags and names
for i := 0; i < v.NumField(); i++ {
f := t.Field(i)
if f.PkgPath != "" {
continue
}
v := v.FieldByIndex(f.Index)
name, _ := parseTag(f.Tag.Get("bencode"))
if name == "" {
if f.Anonymous {
// it's a struct and its fields have already been added to the map
continue
}
name = f.Name
}
if isValidTag(name) {
m[name] = v
}
}
}

400
bencode/decode_test.go Normal file
View File

@ -0,0 +1,400 @@
package bencode
import (
"bytes"
"fmt"
"reflect"
"sort"
"strings"
"testing"
"time"
)
func TestDecode(t *testing.T) {
type testCase struct {
in string
val interface{}
expect interface{}
err bool
unorderedFail bool
}
type dT struct {
X string
Y int
Z string `bencode:"zff"`
}
type Embedded struct {
B string
}
type issue22 struct {
X string `bencode:"x"`
Time myTimeType `bencode:"t"`
Foo myBoolType `bencode:"f"`
Bar myStringType `bencode:"b"`
Slice mySliceType `bencode:"s"`
Y string `bencode:"y"`
}
type issue26 struct {
X string `bencode:"x"`
Foo myBoolTextType `bencode:"f"`
Bar myTextStringType `bencode:"b"`
Slice myTextSliceType `bencode:"s"`
Y string `bencode:"y"`
}
type issue22WithErrorChild struct {
Name string `bencode:"n"`
Error errorMarshalType `bencode:"e"`
}
type issue26WithErrorChild struct {
Name string `bencode:"n"`
Error errorTextMarshalType `bencode:"e"`
}
type discardNonFieldDef struct {
B string
D string
}
type twoDefsForSameKey struct {
A string
A2 string `bencode:"A"`
A3 string `bencode:"A"`
}
now := time.Now()
var decodeCases = []testCase{
// integers
{`i5e`, new(int), int(5), false, false},
{`i-10e`, new(int), int(-10), false, false},
{`i8e`, new(uint), uint(8), false, false},
{`i8e`, new(uint8), uint8(8), false, false},
{`i8e`, new(uint16), uint16(8), false, false},
{`i8e`, new(uint32), uint32(8), false, false},
{`i8e`, new(uint64), uint64(8), false, false},
{`i8e`, new(int), int(8), false, false},
{`i8e`, new(int8), int8(8), false, false},
{`i8e`, new(int16), int16(8), false, false},
{`i8e`, new(int32), int32(8), false, false},
{`i8e`, new(int64), int64(8), false, false},
{`i0e`, new(*int), new(int), false, false},
{`i-2e`, new(uint), nil, true, false},
// bools
{`i1e`, new(bool), true, false, false},
{`i0e`, new(bool), false, false, false},
{`i0e`, new(*bool), new(bool), false, false},
{`i8e`, new(bool), true, false, false},
// strings
{`3:foo`, new(string), "foo", false, false},
{`4:foob`, new(string), "foob", false, false},
{`0:`, new(*string), new(string), false, false},
{`6:short`, new(string), nil, true, false},
// lists
{`l3:foo3:bare`, new([]string), []string{"foo", "bar"}, false, false},
{`li15ei20ee`, new([]int), []int{15, 20}, false, false},
{`ld3:fooi0eed3:bari1eee`, new([]map[string]int), []map[string]int{
{"foo": 0},
{"bar": 1},
}, false, false},
// dicts
{`d3:foo3:bar4:foob3:fooe`, new(map[string]string), map[string]string{
"foo": "bar",
"foob": "foo",
}, false, false},
{`d1:X3:foo1:Yi10e3:zff3:bare`, new(dT), dT{"foo", 10, "bar"}, false, false},
// encoding/json takes, if set, the tag as name and doesn't falls back to the
// struct field's name.
{`d1:X3:foo1:Yi10e1:Z3:bare`, new(dT), dT{"foo", 10, ""}, false, false},
{`d1:X3:foo1:Yi10e1:h3:bare`, new(dT), dT{"foo", 10, ""}, false, false},
{`d3:fooli0ei1ee3:barli2ei3eee`, new(map[string][]int), map[string][]int{
"foo": []int{0, 1},
"bar": []int{2, 3},
}, false, false},
{`de`, new(map[string]string), map[string]string{}, false, false},
// into interfaces
{`i5e`, new(interface{}), int64(5), false, false},
{`li5ee`, new(interface{}), []interface{}{int64(5)}, false, false},
{`5:hello`, new(interface{}), "hello", false, false},
{`d5:helloi5ee`, new(interface{}), map[string]interface{}{"hello": int64(5)}, false, false},
// into values whose type support the Unmarshaler interface
{`1:y`, new(myTimeType), nil, true, false},
{fmt.Sprintf("i%de", now.Unix()), new(myTimeType), myTimeType{time.Unix(now.Unix(), 0)}, false, false},
{`1:y`, new(myBoolType), myBoolType(true), false, false},
{`i42e`, new(myBoolType), nil, true, false},
{`1:n`, new(myBoolType), myBoolType(false), false, false},
{`1:n`, new(errorMarshalType), nil, true, false},
{`li102ei111ei111ee`, new(myStringType), myStringType("foo"), false, false},
{`i42e`, new(myStringType), nil, true, false},
{`d1:ai1e3:foo3:bare`, new(mySliceType), mySliceType{"a", int64(1), "foo", "bar"}, false, false},
{`i42e`, new(mySliceType), nil, true, false},
// into values who have a child which type supports the Unmarshaler interface
{
fmt.Sprintf(`d1:b3:foo1:f1:y1:sd1:f3:foo1:ai42ee1:ti%de1:x1:x1:y1:ye`, now.Unix()),
new(issue22),
issue22{
X: "x",
Time: myTimeType{time.Unix(now.Unix(), 0)},
Foo: myBoolType(true),
Bar: myStringType("foo"),
Slice: mySliceType{"a", int64(42), "f", "foo"},
Y: "y",
},
false,
false,
},
{
`d1:ei42e1:n3:fooe`,
new(issue22WithErrorChild),
nil,
true,
false,
},
// into values whose type support the TextUnmarshaler interface
{`1:y`, new(myBoolTextType), myBoolTextType(true), false, false},
{`1:n`, new(myBoolTextType), myBoolTextType(false), false, false},
{`i42e`, new(myBoolTextType), nil, true, false},
{`1:n`, new(errorTextMarshalType), nil, true, false},
{`7:foo_bar`, new(myTextStringType), myTextStringType("bar"), false, false},
{`i42e`, new(myTextStringType), nil, true, false},
{`7:a,b,c,d`, new(myTextSliceType), myTextSliceType{"a", "b", "c", "d"}, false, false},
{`i42e`, new(myTextSliceType), nil, true, false},
// into values who have a child which type supports the TextUnmarshaler interface
{
`d1:b7:foo_bar1:f1:y1:s5:1,2,31:x1:x1:y1:ye`,
new(issue26),
issue26{
X: "x",
Foo: myBoolTextType(true),
Bar: myTextStringType("bar"),
Slice: myTextSliceType{"1", "2", "3"},
Y: "y",
},
false,
false,
},
{
`d1:ei42e1:n3:fooe`,
new(issue26WithErrorChild),
nil,
true,
false,
},
// malformed
{`i53:foo`, new(interface{}), nil, true, false},
{`6:foo`, new(interface{}), nil, true, false},
{`di5ei2ee`, new(interface{}), nil, true, false},
{`d3:fooe`, new(interface{}), nil, true, false},
{`l3:foo3:bar`, new(interface{}), nil, true, false},
{`d-1:`, new(interface{}), nil, true, false},
// embedded structs
{`d1:A3:foo1:B3:bare`, new(struct {
A string
Embedded
}), struct {
A string
Embedded
}{"foo", Embedded{"bar"}}, false, false},
// Embedded structs with a valid tag are encoded as a definition
{`d1:B3:bar6:nestedd1:B3:fooee`, new(struct {
Embedded `bencode:"nested"`
}), struct {
Embedded `bencode:"nested"`
}{Embedded{"foo"}}, false, false},
// Don't fail when reading keys missing from the struct
{"d1:A7:discard1:B4:take1:C7:discard1:D4:takee",
new(discardNonFieldDef),
discardNonFieldDef{"take", "take"},
false,
false,
},
// Don't fail when reading the same key twice
{"d1:A1:a1:A1:b1:A1:c1:A1:de", new(twoDefsForSameKey),
twoDefsForSameKey{"", "", "d"}, false, false},
// Empty struct
{"de", new(struct{}), struct{}{}, false, false},
// Fail on unordered dictionaries
{"d1:Yi10e1:X1:a3:zff1:ce", new(dT), dT{}, true, true},
{"d3:zff1:c1:Yi10e1:X1:ae", new(dT), dT{}, true, true},
}
for i, tt := range decodeCases {
dec := NewDecoder(strings.NewReader(tt.in))
dec.SetFailOnUnorderedKeys(tt.unorderedFail)
err := dec.Decode(tt.val)
if !tt.err && err != nil {
t.Errorf("#%d (%v): Unexpected err: %v", i, tt.in, err)
continue
}
if tt.err && err == nil {
t.Errorf("#%d (%v): Expected err is nil", i, tt.in)
continue
}
v := reflect.ValueOf(tt.val).Elem().Interface()
if !reflect.DeepEqual(v, tt.expect) && !tt.err {
t.Errorf("#%d (%v): Val: %#v != %#v", i, tt.in, v, tt.expect)
}
}
}
func TestRawDecode(t *testing.T) {
type testCase struct {
in string
expect []byte
err bool
}
var rawDecodeCases = []testCase{
{`i5e`, []byte(`i5e`), false},
{`5:hello`, []byte(`5:hello`), false},
{`li5ei10e5:helloe`, []byte(`li5ei10e5:helloe`), false},
{`llleee`, []byte(`llleee`), false},
{`li5eli5eli5eeee`, []byte(`li5eli5eli5eeee`), false},
{`d5:helloi5ee`, []byte(`d5:helloi5ee`), false},
}
for i, tt := range rawDecodeCases {
var x RawMessage
err := DecodeString(tt.in, &x)
if !tt.err && err != nil {
t.Errorf("#%d: Unexpected err: %v", i, err)
continue
}
if tt.err && err == nil {
t.Errorf("#%d: Expected err is nil", i)
continue
}
if !reflect.DeepEqual(x, RawMessage(tt.expect)) && !tt.err {
t.Errorf("#%d: Val: %#v != %#v", i, x, tt.expect)
}
}
}
type myStringType string
// UnmarshalBencode implements Unmarshaler.UnmarshalBencode
func (mst *myStringType) UnmarshalBencode(b []byte) error {
var raw []byte
err := DecodeBytes(b, &raw)
if err != nil {
return err
}
*mst = myStringType(raw)
return nil
}
type mySliceType []interface{}
// UnmarshalBencode implements Unmarshaler.UnmarshalBencode
func (mst *mySliceType) UnmarshalBencode(b []byte) error {
m := make(map[string]interface{})
err := DecodeBytes(b, &m)
if err != nil {
return err
}
keys := make([]string, 0, len(m))
for k := range m {
keys = append(keys, k)
}
sort.Strings(keys)
raw := make([]interface{}, 0, len(m)*2)
for _, key := range keys {
raw = append(raw, key, m[key])
}
*mst = mySliceType(raw)
return nil
}
type myTextStringType string
// UnmarshalText implements TextUnmarshaler.UnmarshalText
func (mst *myTextStringType) UnmarshalText(b []byte) error {
*mst = myTextStringType(bytes.TrimPrefix(b, []byte("foo_")))
return nil
}
type myTextSliceType []string
// UnmarshalText implements TextUnmarshaler.UnmarshalText
func (mst *myTextSliceType) UnmarshalText(b []byte) error {
raw := string(b)
*mst = strings.Split(raw, ",")
return nil
}
func TestNestedRawDecode(t *testing.T) {
type testCase struct {
in string
val interface{}
expect interface{}
err bool
}
type message struct {
Key string
Val int
Raw RawMessage
}
var cases = []testCase{
{`li5e5:hellod1:a1:beli5eee`, new([]RawMessage), []RawMessage{
RawMessage(`i5e`),
RawMessage(`5:hello`),
RawMessage(`d1:a1:be`),
RawMessage(`li5ee`),
}, false},
{`d1:a1:b1:c1:de`, new(map[string]RawMessage), map[string]RawMessage{
"a": RawMessage(`1:b`),
"c": RawMessage(`1:d`),
}, false},
{`d3:Key5:hello3:Rawldedei5e1:ae3:Vali10ee`, new(message), message{
Key: "hello",
Val: 10,
Raw: RawMessage(`ldedei5e1:ae`),
}, false},
}
for i, tt := range cases {
err := DecodeString(tt.in, tt.val)
if !tt.err && err != nil {
t.Errorf("#%d: Unexpected err: %v", i, err)
continue
}
if tt.err && err == nil {
t.Errorf("#%d: Expected err is nil", i)
continue
}
v := reflect.ValueOf(tt.val).Elem().Interface()
if !reflect.DeepEqual(v, tt.expect) && !tt.err {
t.Errorf("#%d: Val:\n%#v !=\n%#v", i, v, tt.expect)
}
}
}

6
bencode/doc.go Normal file
View File

@ -0,0 +1,6 @@
// Package bencode implements encoding and decoding of bencoded objects,
// which has a similar API to the encoding/json package and many other
// serialization formats.
//
// Notice: the package is moved and modified from github.com/zeebo/bencode@v1.0.0
package bencode

322
bencode/encode.go Normal file
View File

@ -0,0 +1,322 @@
package bencode
import (
"bytes"
"encoding"
"fmt"
"io"
"reflect"
"sort"
)
type sortValues []reflect.Value
func (p sortValues) Len() int { return len(p) }
func (p sortValues) Less(i, j int) bool { return p[i].String() < p[j].String() }
func (p sortValues) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
// Marshaler is the interface implemented by types
// that can marshal themselves into valid bencode.
type Marshaler interface {
MarshalBencode() ([]byte, error)
}
// An Encoder writes bencoded objects to an output stream.
type Encoder struct {
w io.Writer
}
// NewEncoder returns a new encoder that writes to w.
func NewEncoder(w io.Writer) *Encoder {
return &Encoder{w}
}
// Encode writes the bencoded data of val to its output stream.
// If an encountered value implements the Marshaler interface,
// its MarshalBencode method is called to produce the bencode output for this value.
// If no MarshalBencode method is present but the value implements encoding.TextMarshaler instead,
// its MarshalText method is called, which encodes the result as a bencode string.
// See the documentation for Decode about the conversion of Go values to
// bencoded data.
func (e *Encoder) Encode(val interface{}) error {
return encodeValue(e.w, reflect.ValueOf(val))
}
// EncodeString returns the bencoded data of val as a string.
func EncodeString(val interface{}) (string, error) {
buf := new(bytes.Buffer)
e := NewEncoder(buf)
if err := e.Encode(val); err != nil {
return "", err
}
return buf.String(), nil
}
// EncodeBytes returns the bencoded data of val as a slice of bytes.
func EncodeBytes(val interface{}) ([]byte, error) {
buf := new(bytes.Buffer)
e := NewEncoder(buf)
if err := e.Encode(val); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func isNilValue(v reflect.Value) bool {
return (v.Kind() == reflect.Interface || v.Kind() == reflect.Ptr) &&
v.IsNil()
}
func encodeValue(w io.Writer, val reflect.Value) error {
marshaler, textMarshaler, v := indirectEncodeValue(val)
// marshal a type using the Marshaler type
// if it implements that interface.
if marshaler != nil {
bytes, err := marshaler.MarshalBencode()
if err != nil {
return err
}
_, err = w.Write(bytes)
return err
}
// marshal a type using the TextMarshaler type
// if it implements that interface.
if textMarshaler != nil {
bytes, err := textMarshaler.MarshalText()
if err != nil {
return err
}
_, err = fmt.Fprintf(w, "%d:%s", len(bytes), bytes)
return err
}
// if indirection returns us an invalid value that means there was a nil
// pointer in the path somewhere.
if !v.IsValid() {
return nil
}
// send in a raw message if we have that type
if rm, ok := v.Interface().(RawMessage); ok {
_, err := io.Copy(w, bytes.NewReader(rm))
return err
}
switch v.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
_, err := fmt.Fprintf(w, "i%de", v.Int())
return err
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
_, err := fmt.Fprintf(w, "i%de", v.Uint())
return err
case reflect.Bool:
i := 0
if v.Bool() {
i = 1
}
_, err := fmt.Fprintf(w, "i%de", i)
return err
case reflect.String:
_, err := fmt.Fprintf(w, "%d:%s", len(v.String()), v.String())
return err
case reflect.Slice, reflect.Array:
// handle byte slices like strings
if byteSlice, ok := val.Interface().([]byte); ok {
_, err := fmt.Fprintf(w, "%d:", len(byteSlice))
if err == nil {
_, err = w.Write(byteSlice)
}
return err
}
if _, err := fmt.Fprint(w, "l"); err != nil {
return err
}
for i := 0; i < v.Len(); i++ {
if err := encodeValue(w, v.Index(i)); err != nil {
return err
}
}
_, err := fmt.Fprint(w, "e")
return err
case reflect.Map:
if _, err := fmt.Fprint(w, "d"); err != nil {
return err
}
var (
keys sortValues = v.MapKeys()
mval reflect.Value
)
sort.Sort(keys)
for i := range keys {
mval = v.MapIndex(keys[i])
if isNilValue(mval) {
continue
}
if err := encodeValue(w, keys[i]); err != nil {
return err
}
if err := encodeValue(w, mval); err != nil {
return err
}
}
_, err := fmt.Fprint(w, "e")
return err
case reflect.Struct:
if _, err := fmt.Fprint(w, "d"); err != nil {
return err
}
// add embedded structs to the dictionary
dict := make(dictionary, 0, v.NumField())
dict, err := readStruct(dict, v)
if err != nil {
return err
}
// sort the dictionary by keys
sort.Sort(dict)
// encode the dictionary in order
for _, def := range dict {
// encode the key
err := encodeValue(w, reflect.ValueOf(def.key))
if err != nil {
return err
}
// encode the value
err = encodeValue(w, def.value)
if err != nil {
return err
}
}
_, err = fmt.Fprint(w, "e")
return err
}
return fmt.Errorf("Can't encode type: %s", v.Type())
}
// indirectEncodeValue walks down v allocating pointers as needed,
// until it gets to a non-pointer.
// if it encounters an (Text)Marshaler, indirect stops and returns that.
func indirectEncodeValue(v reflect.Value) (Marshaler, encoding.TextMarshaler, reflect.Value) {
// If v is a named type and is addressable,
// start with its address, so that if the type has pointer methods,
// we find them.
if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() {
v = v.Addr()
}
for {
if v.Kind() == reflect.Ptr && v.IsNil() {
break
}
vi := v.Interface()
if m, ok := vi.(Marshaler); ok {
return m, nil, reflect.Value{}
}
if m, ok := vi.(encoding.TextMarshaler); ok {
return nil, m, reflect.Value{}
}
if v.Kind() != reflect.Ptr {
break
}
v = v.Elem()
}
return nil, nil, indirect(v, false)
}
type definition struct {
key string
value reflect.Value
}
type dictionary []definition
func (d dictionary) Len() int { return len(d) }
func (d dictionary) Less(i, j int) bool { return d[i].key < d[j].key }
func (d dictionary) Swap(i, j int) { d[i], d[j] = d[j], d[i] }
func readStruct(dict dictionary, v reflect.Value) (dictionary, error) {
t := v.Type()
var (
fieldValue reflect.Value
rkey string
)
for i := 0; i < t.NumField(); i++ {
key := t.Field(i)
rkey = key.Name
fieldValue = v.FieldByIndex(key.Index)
// filter out unexported values etc.
if !fieldValue.CanInterface() {
continue
}
// filter out nil pointer values
if isNilValue(fieldValue) {
continue
}
// * Near identical to usage in JSON except with key 'bencode'
//
// * Struct values encode as BEncode dictionaries. Each exported
// struct field becomes a set in the dictionary unless
// - the field's tag is "-", or
// - the field is empty and its tag specifies the "omitempty"
// option.
//
// * The default key string is the struct field name but can be
// specified in the struct field's tag value. The "bencode"
// key in struct field's tag value is the key name, followed
// by an optional comma and options.
tagValue := key.Tag.Get("bencode")
if tagValue != "" {
// Keys with '-' are omit from output
if tagValue == "-" {
continue
}
name, options := parseTag(tagValue)
// Keys with 'omitempty' are omitted if the field is empty
if options.Contains("omitempty") && isEmptyValue(fieldValue) {
continue
}
// All other values are treated as the key string
if isValidTag(name) {
rkey = name
}
}
if key.Anonymous && key.Type.Kind() == reflect.Struct && tagValue == "" {
var err error
dict, err = readStruct(dict, fieldValue)
if err != nil {
return nil, err
}
} else {
dict = append(dict, definition{rkey, fieldValue})
}
}
return dict, nil
}

View File

@ -0,0 +1,109 @@
package bencode
import (
"errors"
"time"
)
type myBoolType bool
// MarshalBencode implements Marshaler.MarshalBencode
func (mbt myBoolType) MarshalBencode() ([]byte, error) {
var c string
if mbt {
c = "y"
} else {
c = "n"
}
return EncodeBytes(c)
}
// UnmarshalBencode implements Unmarshaler.UnmarshalBencode
func (mbt *myBoolType) UnmarshalBencode(b []byte) error {
var str string
err := DecodeBytes(b, &str)
if err != nil {
return err
}
switch str {
case "y":
*mbt = true
case "n":
*mbt = false
default:
err = errors.New("invalid myBoolType")
}
return err
}
type myBoolTextType bool
// MarshalText implements TextMarshaler.MarshalText
func (mbt myBoolTextType) MarshalText() ([]byte, error) {
if mbt {
return []byte("y"), nil
}
return []byte("n"), nil
}
// UnmarshalText implements TextUnmarshaler.UnmarshalText
func (mbt *myBoolTextType) UnmarshalText(b []byte) error {
switch string(b) {
case "y":
*mbt = true
case "n":
*mbt = false
default:
return errors.New("invalid myBoolType")
}
return nil
}
type myTimeType struct {
time.Time
}
// MarshalBencode implements Marshaler.MarshalBencode
func (mtt myTimeType) MarshalBencode() ([]byte, error) {
return EncodeBytes(mtt.Time.Unix())
}
// UnmarshalBencode implements Unmarshaler.UnmarshalBencode
func (mtt *myTimeType) UnmarshalBencode(b []byte) error {
var epoch int64
err := DecodeBytes(b, &epoch)
if err != nil {
return err
}
mtt.Time = time.Unix(epoch, 0)
return nil
}
type errorMarshalType struct{}
// MarshalBencode implements Marshaler.MarshalBencode
func (emt errorMarshalType) MarshalBencode() ([]byte, error) {
return nil, errors.New("oops")
}
// UnmarshalBencode implements Unmarshaler.UnmarshalBencode
func (emt errorMarshalType) UnmarshalBencode([]byte) error {
return errors.New("oops")
}
type errorTextMarshalType struct{}
// MarshalText implements TextMarshaler.MarshalText
func (emt errorTextMarshalType) MarshalText() ([]byte, error) {
return nil, errors.New("oops")
}
// UnmarshalText implements TextUnmarshaler.UnmarshalText
func (emt errorTextMarshalType) UnmarshalText([]byte) error {
return errors.New("oops")
}

354
bencode/encode_test.go Normal file
View File

@ -0,0 +1,354 @@
package bencode
import (
"errors"
"fmt"
"testing"
"time"
)
func TestEncode(t *testing.T) {
type encodeTestCase struct {
in interface{}
out string
err bool
}
type eT struct {
A string
X string `bencode:"D"`
Y string `bencode:"B"`
Z string `bencode:"C"`
}
type sortProblem struct {
A string
B string `bencode:","`
}
type issue18Sub struct {
Name string
}
type issue18 struct {
T *issue18Sub
}
type Embedded struct {
B string
}
type issue22 struct {
Time myTimeType `bencode:"t"`
Foo myBoolType `bencode:"f"`
}
type issue22WithErrorChild struct {
Name string `bencode:"n"`
Error errorMarshalType `bencode:"e"`
}
type issue26 struct {
Answer int64 `bencode:"a"`
Foo myBoolTextType `bencode:"f"`
Name string `bencode:"n"`
}
type issue26WithErrorChild struct {
Name string `bencode:"n"`
Error errorTextMarshalType `bencode:"e"`
}
type issue28 struct {
X string `bencode:"x"`
Time *myTimeType `bencode:"t"`
Foo myBoolPtrType `bencode:"f"`
Y string `bencode:"y"`
}
now := time.Now()
var encodeCases = []encodeTestCase{
// integers
{10, `i10e`, false},
{-10, `i-10e`, false},
{0, `i0e`, false},
{int(10), `i10e`, false},
{int8(10), `i10e`, false},
{int16(10), `i10e`, false},
{int32(10), `i10e`, false},
{int64(10), `i10e`, false},
{uint(10), `i10e`, false},
{uint8(10), `i10e`, false},
{uint16(10), `i10e`, false},
{uint32(10), `i10e`, false},
{uint64(10), `i10e`, false},
{(*int)(nil), ``, false},
// ptr-to-integer
{func() *int {
i := 42
return &i
}(), `i42e`, false},
// strings
{"foo", `3:foo`, false},
{"barbb", `5:barbb`, false},
{"", `0:`, false},
{(*string)(nil), ``, false},
// ptr-to-string
{func() *string {
str := "foo"
return &str
}(), `3:foo`, false},
// lists
{[]interface{}{"foo", 20}, `l3:fooi20ee`, false},
{[]interface{}{90, 20}, `li90ei20ee`, false},
{[]interface{}{[]interface{}{"foo", "bar"}, 20}, `ll3:foo3:barei20ee`, false},
{[]map[string]int{
{"a": 0, "b": 1},
{"c": 2, "d": 3},
}, `ld1:ai0e1:bi1eed1:ci2e1:di3eee`, false},
{[][]byte{
[]byte{'0', '2', '4', '6', '8'},
[]byte{'a', 'c', 'e'},
}, `l5:024683:acee`, false},
{(*[]interface{})(nil), ``, false},
// boolean
{true, "i1e", false},
{false, "i0e", false},
{(*bool)(nil), ``, false},
// dicts
{map[string]interface{}{
"a": "foo",
"c": "bar",
"b": "tes",
}, `d1:a3:foo1:b3:tes1:c3:bare`, false},
{eT{"foo", "bar", "far", "boo"}, `d1:A3:foo1:B3:far1:C3:boo1:D3:bare`, false},
{map[string][]int{
"a": {0, 1},
"b": {2, 3},
}, `d1:ali0ei1ee1:bli2ei3eee`, false},
{struct{ A, b int }{1, 2}, "d1:Ai1ee", false},
{(*struct{ A int })(nil), ``, false},
// raw
{RawMessage(`i5e`), `i5e`, false},
{[]RawMessage{
RawMessage(`i5e`),
RawMessage(`5:hello`),
RawMessage(`ldededee`),
}, `li5e5:helloldededeee`, false},
{map[string]RawMessage{
"a": RawMessage(`i5e`),
"b": RawMessage(`5:hello`),
"c": RawMessage(`ldededee`),
}, `d1:ai5e1:b5:hello1:cldededeee`, false},
// problem sorting
{sortProblem{A: "foo", B: "bar"}, `d1:A3:foo1:B3:bare`, false},
// nil values dropped from maps and structs
{map[string]*int{"a": nil}, `de`, false},
{struct{ A *int }{nil}, `de`, false},
{issue18{}, `de`, false},
{map[string]interface{}{"a": nil}, `de`, false},
{struct{ A interface{} }{nil}, `de`, false},
// embedded structs
{struct {
A string
Embedded
}{"foo", Embedded{"bar"}}, `d1:A3:foo1:B3:bare`, false},
{struct {
A string
Embedded `bencode:"C"`
}{"foo", Embedded{"bar"}}, `d1:A3:foo1:Cd1:B3:baree`, false},
// embedded structs order issue #20
{struct {
Embedded
A string
}{Embedded{"bar"}, "foo"}, `d1:A3:foo1:B3:bare`, false},
// types which implement the Marshal interface will
// be marshalled using this interface
{myBoolType(true), `1:y`, false},
{myBoolType(false), `1:n`, false},
{myTimeType{now}, fmt.Sprintf("i%de", now.Unix()), false},
{errorMarshalType{}, "", true},
// pointers to types which implement the Marshal interface will
// be marshalled using this interface
{func() *myBoolType {
b := myBoolType(true)
return &b
}(), `1:y`, false},
{func() *myTimeType {
t := myTimeType{now}
return &t
}(), fmt.Sprintf("i%de", now.Unix()), false},
{func() *errorMarshalType {
e := errorMarshalType{}
return &e
}(), "", true},
// nil-pointers to types which implement the Marshal interface will be ignored
{(*myBoolType)(nil), "", false},
{(*myTimeType)(nil), "", false},
{(*errorMarshalType)(nil), "", false},
// ptr-types which implements the Marshal interface will
// be marshalled using this interface
{func() *myBoolPtrType {
b := myBoolPtrType(true)
return &b
}(), `1:y`, false},
{func() *myBoolPtrType {
b := myBoolPtrType(false)
return &b
}(), `1:n`, false},
{(*myBoolPtrType)(nil), ``, false},
// structures can also have children which support
// the Marshal interface
{
issue22{Time: myTimeType{now}, Foo: myBoolType(true)},
fmt.Sprintf("d1:f1:y1:ti%dee", now.Unix()),
false,
},
{ // an error will be returned if a child can't be marshalled
issue22WithErrorChild{Name: "Foo", Error: errorMarshalType{}},
"", true,
},
// structures passed by reference which have children that support
// the (Text)Marshal interface (by value or by reference),
// will be marshaled using that interface
{
&issue22{Time: myTimeType{now}, Foo: myBoolType(true)},
fmt.Sprintf("d1:f1:y1:ti%dee", now.Unix()),
false,
},
{ // an error will be returned if a child can't be marshalled
&issue22WithErrorChild{Name: "Foo", Error: errorMarshalType{}},
"", true,
},
// types which implement the TextMarshal interface will
// be marshalled into a bencode string value using this interface
{myBoolTextType(true), `1:y`, false},
{myBoolTextType(false), `1:n`, false},
{errorTextMarshalType{}, "", true},
// structures can also have children which support
// the TextMarshal interface
{
issue26{Answer: 42, Foo: myBoolTextType(true), Name: "Nova"},
`d1:ai42e1:f1:y1:n4:Novae`,
false,
},
{ // an error will be returned if a child TextMarshaler returns an error
issue26WithErrorChild{Name: "Foo", Error: errorTextMarshalType{}},
"", true,
},
// ptr types which are used as value types,
// but which ptr version implement the Marshaler/TextMarshaler interface,
// will still get marshalling using this interface, when possible
{
&issue28{X: "x", Time: &myTimeType{now}, Foo: myBoolPtrType(true), Y: "y"},
fmt.Sprintf(`d1:f1:y1:ti%de1:x1:x1:y1:ye`, now.Unix()),
false,
},
}
for i, tt := range encodeCases {
t.Logf("%d: %#v", i, tt.in)
data, err := EncodeString(tt.in)
if !tt.err && err != nil {
t.Errorf("#%d: Unexpected err: %v", i, err)
continue
}
if tt.err && err == nil {
t.Errorf("#%d: Expected err is nil", i)
continue
}
if tt.out != data {
t.Errorf("#%d: Val: %q != %q", i, data, tt.out)
}
}
}
type myBoolPtrType bool
// MarshalBencode implements Marshaler.MarshalBencode
func (mbt *myBoolPtrType) MarshalBencode() ([]byte, error) {
var c string
if *mbt {
c = "y"
} else {
c = "n"
}
return EncodeBytes(c)
}
// UnmarshalBencode implements Unmarshaler.UnmarshalBencode
func (mbt *myBoolPtrType) UnmarshalBencode(b []byte) error {
var str string
err := DecodeBytes(b, &str)
if err != nil {
return err
}
switch str {
case "y":
*mbt = true
case "n":
*mbt = false
default:
err = errors.New("invalid myBoolType")
}
return err
}
func TestEncodeOmit(t *testing.T) {
type encodeTestCase struct {
in interface{}
out string
err bool
}
type eT struct {
A string `bencode:",omitempty"`
B int `bencode:",omitempty"`
C *int `bencode:",omitempty"`
}
var encodeCases = []encodeTestCase{
{eT{}, `de`, false},
{eT{A: "a"}, `d1:A1:ae`, false},
{eT{B: 5}, `d1:Bi5ee`, false},
{eT{C: new(int)}, `d1:Ci0ee`, false},
}
for i, tt := range encodeCases {
data, err := EncodeString(tt.in)
if !tt.err && err != nil {
t.Errorf("#%d: Unexpected err: %v", i, err)
continue
}
if tt.err && err == nil {
t.Errorf("#%d: Expected err is nil", i)
continue
}
if tt.out != data {
t.Errorf("#%d: Val: %q != %q", i, data, tt.out)
}
}
}

67
bencode/example_test.go Normal file
View File

@ -0,0 +1,67 @@
package bencode
import (
"fmt"
"io"
)
var (
data string
r io.Reader
w io.Writer
)
func ExampleDecodeString() {
var torrent interface{}
if err := DecodeString(data, &torrent); err != nil {
panic(err)
}
}
func ExampleEncodeString() {
var torrent interface{}
data, err := EncodeString(torrent)
if err != nil {
panic(err)
}
fmt.Println(data)
}
func ExampleDecodeBytes() {
var torrent interface{}
if err := DecodeBytes([]byte(data), &torrent); err != nil {
panic(err)
}
}
func ExampleEncodeBytes() {
var torrent interface{}
data, err := EncodeBytes(torrent)
if err != nil {
panic(err)
}
fmt.Println(data)
}
func ExampleEncoder_Encode() {
var x struct {
Foo string
Bar []string `bencode:"name"`
}
enc := NewEncoder(w)
if err := enc.Encode(x); err != nil {
panic(err)
}
}
func ExampleDecoder_Decode() {
dec := NewDecoder(r)
var torrent struct {
Announce string
List [][]string `bencode:"announce-list"`
}
if err := dec.Decode(&torrent); err != nil {
panic(err)
}
}

5
bencode/raw.go Normal file
View File

@ -0,0 +1,5 @@
package bencode
// RawMessage is a special type that will store the raw bencode data when
// encoding or decoding.
type RawMessage []byte

76
bencode/tag.go Normal file
View File

@ -0,0 +1,76 @@
package bencode
import (
"reflect"
"strings"
"unicode"
)
// tagOptions is the string following a comma in a struct field's "bencode"
// tag, or the empty string. It does not include the leading comma.
type tagOptions string
// parseTag splits a struct field's tag into its name and
// comma-separated options.
func parseTag(tag string) (string, tagOptions) {
if idx := strings.Index(tag, ","); idx != -1 {
return tag[:idx], tagOptions(tag[idx+1:])
}
return tag, tagOptions("")
}
// Contains returns whether checks that a comma-separated list of options
// contains a particular substr flag. substr must be surrounded by a
// string boundary or commas.
func (options tagOptions) Contains(optionName string) bool {
s := string(options)
for s != "" {
var next string
i := strings.Index(s, ",")
if i >= 0 {
s, next = s[:i], s[i+1:]
}
if s == optionName {
return true
}
s = next
}
return false
}
func isValidTag(key string) bool {
if key == "" {
return false
}
for _, c := range key {
if c != ' ' && c != '$' && c != '-' && c != '_' && c != '.' && !unicode.IsLetter(c) && !unicode.IsDigit(c) {
return false
}
}
return true
}
func matchName(key string) func(string) bool {
return func(s string) bool {
return strings.ToLower(key) == strings.ToLower(s)
}
}
func isEmptyValue(v reflect.Value) bool {
switch v.Kind() {
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
return v.Len() == 0
case reflect.Bool:
return !v.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return v.Uint() == 0
case reflect.Float32, reflect.Float64:
return v.Float() == 0
case reflect.Interface, reflect.Ptr:
return v.IsNil()
default:
return v.IsZero()
}
}

181
dht/blacklist.go Normal file
View File

@ -0,0 +1,181 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dht
import (
"sync"
"time"
)
// Blacklist is used to manage the ip blacklist.
//
// Notice: The implementation should clear the address existed for long time.
type Blacklist interface {
// In reports whether the address, ip and port, is in the blacklist.
In(ip string, port int) bool
// If port is equal to 0, it should ignore port and only use ip when matching.
Add(ip string, port int)
// If port is equal to 0, it should delete the address by only the ip.
Del(ip string, port int)
// Close is used to notice the implementation to release the underlying
// resource.
Close()
}
type noopBlacklist struct{}
func (nbl noopBlacklist) In(ip string, port int) bool { return false }
func (nbl noopBlacklist) Add(ip string, port int) {}
func (nbl noopBlacklist) Del(ip string, port int) {}
func (nbl noopBlacklist) Close() {}
// NewNoopBlacklist returns a no-op Blacklist.
func NewNoopBlacklist() Blacklist { return noopBlacklist{} }
// DebugBlacklist returns a new Blacklist to log the information as debug.
func DebugBlacklist(bl Blacklist, logf func(string, ...interface{})) Blacklist {
return logBlacklist{Blacklist: bl, logf: logf}
}
type logBlacklist struct {
Blacklist
logf func(string, ...interface{})
}
func (dbl logBlacklist) Add(ip string, port int) {
dbl.logf("add the blacklist: ip=%s, port=%d", ip, port)
dbl.Blacklist.Add(ip, port)
}
func (dbl logBlacklist) Del(ip string, port int) {
dbl.logf("delete the blacklist: ip=%s, port=%d", ip, port)
dbl.Blacklist.Del(ip, port)
}
/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// NewMemoryBlacklist returns a blacklst implementation based on memory.
//
// if maxnum is equal to 0, no limit.
func NewMemoryBlacklist(maxnum int, duration time.Duration) Blacklist {
bl := &blacklist{
num: maxnum,
ips: make(map[string]*wrappedPort, 128),
exit: make(chan struct{}),
}
go bl.loop(duration)
return bl
}
type wrappedPort struct {
Time time.Time
Enable bool
Ports map[int]struct{}
}
type blacklist struct {
exit chan struct{}
lock sync.RWMutex
ips map[string]*wrappedPort
num int
}
func (bl *blacklist) loop(interval time.Duration) {
tick := time.NewTicker(interval)
defer tick.Stop()
for {
select {
case <-bl.exit:
return
case now := <-tick.C:
bl.lock.Lock()
for ip, wp := range bl.ips {
if now.Sub(wp.Time) > interval {
delete(bl.ips, ip)
}
}
bl.lock.Unlock()
}
}
}
func (bl *blacklist) Close() {
select {
case <-bl.exit:
default:
close(bl.exit)
}
}
// In reports whether the address, ip and port, is in the blacklist.
func (bl *blacklist) In(ip string, port int) (yes bool) {
bl.lock.RLock()
if wp, ok := bl.ips[ip]; ok {
if wp.Enable {
_, yes = wp.Ports[port]
} else {
yes = true
}
}
bl.lock.RUnlock()
return
}
func (bl *blacklist) Add(ip string, port int) {
bl.lock.Lock()
wp, ok := bl.ips[ip]
if !ok {
if bl.num > 0 && len(bl.ips) >= bl.num {
bl.lock.Unlock()
return
}
wp = &wrappedPort{Enable: true}
bl.ips[ip] = wp
}
if port < 1 {
wp.Enable = false
wp.Ports = nil
} else if wp.Ports == nil {
wp.Ports = map[int]struct{}{port: struct{}{}}
} else {
wp.Ports[port] = struct{}{}
}
wp.Time = time.Now()
bl.lock.Unlock()
}
func (bl *blacklist) Del(ip string, port int) {
bl.lock.Lock()
if wp, ok := bl.ips[ip]; ok {
if port < 1 {
delete(bl.ips, ip)
} else if wp.Enable {
switch len(wp.Ports) {
case 0, 1:
delete(bl.ips, ip)
default:
delete(wp.Ports, port)
wp.Time = time.Now()
}
}
}
bl.lock.Unlock()
}

80
dht/blacklist_test.go Normal file
View File

@ -0,0 +1,80 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dht
import (
"testing"
"time"
)
func (bl *blacklist) portsLen() (n int) {
bl.lock.RLock()
for _, wp := range bl.ips {
n += len(wp.Ports)
}
bl.lock.RUnlock()
return
}
func (bl *blacklist) getIPs() []string {
bl.lock.RLock()
ips := make([]string, 0, len(bl.ips))
for ip := range bl.ips {
ips = append(ips, ip)
}
bl.lock.RUnlock()
return ips
}
func TestMemoryBlacklist(t *testing.T) {
bl := NewMemoryBlacklist(3, time.Second).(*blacklist)
defer bl.Close()
bl.Add("1.1.1.1", 123)
bl.Add("1.1.1.1", 456)
bl.Add("1.1.1.1", 789)
bl.Add("2.2.2.2", 111)
bl.Add("3.3.3.3", 0)
bl.Add("4.4.4.4", 222)
ips := bl.getIPs()
if len(ips) != 3 {
t.Error(ips)
} else {
for _, ip := range ips {
switch ip {
case "1.1.1.1", "2.2.2.2", "3.3.3.3":
default:
t.Error(ip)
}
}
}
if n := bl.portsLen(); n != 4 {
t.Errorf("expect port num 4, but got %d", n)
}
if bl.In("1.1.1.1", 111) || !bl.In("1.1.1.1", 123) {
t.Fail()
}
if !bl.In("3.3.3.3", 111) || bl.In("4.4.4.4", 222) {
t.Fail()
}
bl.Del("3.3.3.3", 0)
if bl.In("3.3.3.3", 111) {
t.Fail()
}
}

864
dht/dht_server.go Normal file
View File

@ -0,0 +1,864 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package dht implements the DHT Protocol. And you can use it to build or join
// the DHT swarm network.
package dht
import (
"bytes"
"fmt"
"io"
"log"
"net"
"sync"
"time"
"github.com/xgfone/bt/bencode"
"github.com/xgfone/bt/krpc"
"github.com/xgfone/bt/metainfo"
)
const (
queryMethodPing = "ping"
queryMethodFindNode = "find_node"
queryMethodGetPeers = "get_peers"
queryMethodAnnouncePeer = "announce_peer"
)
var errUnsupportedIPProtocol = fmt.Errorf("unsupported ip protocol")
// Predefine some ip protocol stacks.
const (
IPv4Protocol IPProtocolStack = 4
IPv6Protocol IPProtocolStack = 6
)
// IPProtocolStack represents the ip protocol stack, such as IPv4 or IPv6
type IPProtocolStack uint8
// Result is used to pass the response result to the callback function.
type Result struct {
// Addr is the address of the peer where the request is sent to.
//
// Notice: it may be nil for "get_peers" request.
Addr *net.UDPAddr
// For Error
Code int // 0 represents the success.
Reason string // Reason indicates the reason why the request failed when Code > 0.
Timeout bool // Timeout indicates whether the response is timeout.
// The list of the address of the peers returned by GetPeers.
Peers []metainfo.Address
}
// Config is used to configure the DHT server.
type Config struct {
// K is the size of the bucket of the routing table.
//
// The default is 8.
K int
// ID is the id of the current DHT server node.
//
// The default is generated randomly
ID metainfo.Hash
// IPProtocols is used to specify the supported IP Protocol Stack.
//
// If not given, it will detect the address family of the server listens
// automatically and set the supported ip protocol stack by the address
// family. If fails, the default is []IPProtocolStack{IPv4Protocol}.
// So, if the server is listening on the empty address, such as ":6881",
// and only supports IPv6, you should specify it to "IPv6Protocol".
IPProtocols []IPProtocolStack
// ReadOnly indicates whether the current node is read-only.
//
// If true, the DHT server will enter in the read-only mode.
//
// The default is false.
//
// BEP 43
ReadOnly bool
// MsgSize is the maximum size of the DHT message.
//
// The default is 4096.
MsgSize int
// SearchDepth is used to control the depth to send the "get_peers" or
// "find_node" query recursively to get the peers storing the torrent
// infohash.
//
// The default depth is 8.
SearchDepth int
// RespTimeout is the response timeout, that's, the response is valid
// only before the timeout reaches.
//
// The default is "10s".
RespTimeout time.Duration
// RoutingTableStorage is used to store the nodes in the routing table.
//
// The default is nil.
RoutingTableStorage RoutingTableStorage
// PeerManager is used to manage the peers on torrent infohash,
// which is called when receiving the "get_peers" query.
//
// The default uses the inner token-peer manager.
PeerManager PeerManager
// Blacklist is used to manage the ip blacklist.
//
// The default is NewMemoryBlacklist(1024, time.Hour*24*7).
Blacklist Blacklist
// ErrorLog is used to log the error.
//
// The default is log.Printf.
ErrorLog func(format string, args ...interface{})
// OnSearch is called when someone searches the torrent infohash,
// that's, the "get_peers" query.
//
// The default callback does noting.
OnSearch func(infohash string, ip net.IP, port uint16)
// OnTorrent is called when someone has the torrent infohash
// or someone has just downloaded the torrent infohash,
// that's, the "get_peers" response or "announce_peer" query.
//
// The default callback does noting.
OnTorrent func(infohash string, ip net.IP, port uint16)
// HandleInMessage is used to intercept the incoming DHT message.
// For example, you can debug the message as the log.
//
// Return true if going on handling by the default. Or return false.
//
// The default is nil.
HandleInMessage func(*net.UDPAddr, *krpc.Message) bool
// HandleOutMessage is used to intercept the outgoing DHT message.
// For example, you can debug the message as the log.
//
// Return (false, nil) if going on handling by the default.
//
// The default is nil.
HandleOutMessage func(*net.UDPAddr, *krpc.Message) (wrote bool, err error)
}
func (c Config) in(*net.UDPAddr, *krpc.Message) bool { return true }
func (c Config) out(*net.UDPAddr, *krpc.Message) (bool, error) { return false, nil }
func (c *Config) set(conf ...Config) {
if len(conf) > 0 {
*c = conf[0]
}
if c.K <= 0 {
c.K = 8
}
if c.ID.IsZero() {
c.ID = metainfo.NewRandomHash()
}
if c.MsgSize <= 0 {
c.MsgSize = 4096
}
if c.ErrorLog == nil {
c.ErrorLog = log.Printf
}
if c.SearchDepth < 1 {
c.SearchDepth = 8
}
if c.RoutingTableStorage == nil {
c.RoutingTableStorage = noopStorage{}
}
if c.Blacklist == nil {
c.Blacklist = NewMemoryBlacklist(1024, time.Hour*24*7)
}
if c.RespTimeout == 0 {
c.RespTimeout = time.Second * 10
}
if c.OnSearch == nil {
c.OnSearch = func(string, net.IP, uint16) {}
}
if c.OnTorrent == nil {
c.OnTorrent = func(string, net.IP, uint16) {}
}
if c.HandleInMessage == nil {
c.HandleInMessage = c.in
}
if c.HandleOutMessage == nil {
c.HandleOutMessage = c.out
}
}
// Server is a DHT server.
type Server struct {
conf Config
exit chan struct{}
conn net.PacketConn
once sync.Once
ipv4 bool
ipv6 bool
want []krpc.Want
peerManager PeerManager
// routingTable *routingTable
routingTable4 *routingTable
routingTable6 *routingTable
tokenManager *tokenManager
tokenPeerManager *tokenPeerManager
transactionManager *transactionManager
}
// NewServer returns a new DHT server.
func NewServer(conn net.PacketConn, config ...Config) *Server {
var conf Config
conf.set(config...)
if len(conf.IPProtocols) == 0 {
host, _, err := net.SplitHostPort(conn.LocalAddr().String())
if err != nil {
panic(err)
} else if ip := net.ParseIP(host); ipIsZero(ip) || ip.To4() != nil {
conf.IPProtocols = []IPProtocolStack{IPv4Protocol}
} else {
conf.IPProtocols = []IPProtocolStack{IPv6Protocol}
}
}
var ipv4, ipv6 bool
var want []krpc.Want
for _, ip := range conf.IPProtocols {
switch ip {
case IPv4Protocol:
ipv4 = true
want = append(want, krpc.WantNodes)
case IPv6Protocol:
ipv6 = true
want = append(want, krpc.WantNodes6)
}
}
s := &Server{
ipv4: ipv4,
ipv6: ipv6,
want: want,
conn: conn,
conf: conf,
exit: make(chan struct{}),
peerManager: conf.PeerManager,
tokenManager: newTokenManager(),
tokenPeerManager: newTokenPeerManager(),
transactionManager: newTransactionManager(),
}
s.routingTable4 = newRoutingTable(s, false)
s.routingTable6 = newRoutingTable(s, true)
if s.peerManager == nil {
s.peerManager = s.tokenPeerManager
}
return s
}
// ID returns the ID of the DHT server node.
func (s *Server) ID() metainfo.Hash { return s.conf.ID }
// Bootstrap initializes the routing table at first.
//
// Notice: If the routing table has had some nodes, it does noting.
func (s *Server) Bootstrap(addrs []string) {
if (s.ipv4 && s.routingTable4.Len() == 0) ||
(s.ipv6 && s.routingTable6.Len() == 0) {
for _, addr := range addrs {
as, err := metainfo.NewAddressesFromString(addr)
if err != nil {
s.conf.ErrorLog(err.Error())
continue
}
for _, a := range as {
if err = s.FindNode(a.UDPAddr(), s.conf.ID); err != nil {
s.conf.ErrorLog(`fail to bootstrap '%s': %s`, a.String(), err)
}
}
}
}
}
// Node4Num returns the number of the ipv4 nodes in the routing table.
func (s *Server) Node4Num() int { return s.routingTable4.Len() }
// Node6Num returns the number of the ipv6 nodes in the routing table.
func (s *Server) Node6Num() int { return s.routingTable6.Len() }
// AddNode adds the node into the routing table.
//
// The returned value:
// NodeAdded: The node is added successfully.
// NodeNotAdded: The node is not added and is discarded.
// NodeExistAndUpdated: The node has existed, and its status has been updated.
// NodeExistAndChanged: The node has existed, but the address is inconsistent.
// The current node will be discarded.
//
func (s *Server) AddNode(node krpc.Node) int {
// For IPv6
if isIPv6(node.Addr.IP) {
if s.ipv6 {
return s.routingTable6.AddNode(node)
}
return NodeNotAdded
}
// For IPv4
if s.ipv4 {
return s.routingTable4.AddNode(node)
}
return NodeNotAdded
}
func (s *Server) addNode(a *net.UDPAddr, id metainfo.Hash, ro bool) (r int) {
if ro { // BEP 43
return NodeNotAdded
}
if r = s.AddNode(krpc.NewNodeByUDPAddr(id, a)); r == NodeExistAndChanged {
s.conf.Blacklist.Add(a.IP.String(), a.Port)
}
return
}
func (s *Server) addNode2(node krpc.Node, ro bool) int {
if ro { // BEP 43
return NodeNotAdded
}
return s.AddNode(node)
}
func (s *Server) stop() {
close(s.exit)
s.routingTable4.Stop()
s.routingTable6.Stop()
s.tokenManager.Stop()
s.tokenPeerManager.Stop()
s.transactionManager.Stop()
s.conf.Blacklist.Close()
}
// Close stops the DHT server.
func (s *Server) Close() { s.once.Do(s.stop) }
// Sync is used to synchronize the routing table to the underlying storage.
func (s *Server) Sync() {
s.routingTable4.Sync()
s.routingTable6.Sync()
}
// Run starts the DHT server.
func (s *Server) Run() {
go s.routingTable4.Start(time.Minute * 5)
go s.routingTable6.Start(time.Minute * 5)
go s.tokenManager.Start(time.Minute * 10)
go s.tokenPeerManager.Start(time.Hour * 24)
go s.transactionManager.Start(s, s.conf.RespTimeout)
buf := make([]byte, s.conf.MsgSize)
for {
n, raddr, err := s.conn.ReadFrom(buf)
if err != nil {
s.conf.ErrorLog("fail to read the dht message: %s", err)
return
}
s.handlePacket(raddr.(*net.UDPAddr), buf[:n])
}
}
func (s *Server) isDisabled(raddr *net.UDPAddr) bool {
if isIPv6(raddr.IP) {
if !s.ipv6 {
return true
}
} else if !s.ipv4 {
return true
}
return false
}
// HandlePacket handles the incoming DHT message.
func (s *Server) handlePacket(raddr *net.UDPAddr, data []byte) {
if s.isDisabled(raddr) {
return
}
// Check whether the raddr is in the ip blacklist. If yes, discard it.
if s.conf.Blacklist.In(raddr.IP.String(), raddr.Port) {
return
}
var msg krpc.Message
if err := bencode.DecodeBytes(data, &msg); err != nil {
s.conf.ErrorLog("decode krpc message error: %s", err)
return
} else if msg.T == "" {
s.conf.ErrorLog("no transaction id from '%s'", raddr)
return
}
// TODO: Should we use a task pool??
go s.handleMessage(raddr, msg)
}
func (s *Server) handleMessage(raddr *net.UDPAddr, m krpc.Message) {
if !s.conf.HandleInMessage(raddr, &m) {
return
}
switch m.Y {
case "q":
if !m.A.ID.IsZero() {
r := s.addNode(raddr, m.A.ID, m.RO)
if r != NodeExistAndChanged && !s.conf.ReadOnly { // BEP 43
s.handleQuery(raddr, m)
}
}
case "r":
if !m.R.ID.IsZero() {
if s.addNode(raddr, m.R.ID, m.RO) == NodeExistAndChanged {
return
}
if t := s.transactionManager.PopTransaction(m.T, raddr); t != nil {
t.OnResponse(t, raddr, m)
}
}
case "e":
if t := s.transactionManager.PopTransaction(m.T, raddr); t != nil {
t.OnError(t, m.E.Code, m.E.Reason)
}
default:
s.conf.ErrorLog("unknown dht message type '%s'", m.Y)
}
}
func (s *Server) handleQuery(raddr *net.UDPAddr, m krpc.Message) {
switch m.Q {
case queryMethodPing:
s.reply(raddr, m.T, krpc.ResponseResult{})
case queryMethodFindNode: // See BEP 32
var r krpc.ResponseResult
n4 := m.A.ContainsWant(krpc.WantNodes)
n6 := m.A.ContainsWant(krpc.WantNodes6)
if !n4 && !n6 {
if isIPv6(raddr.IP) {
r.Nodes6 = s.routingTable6.Closest(m.A.InfoHash, s.conf.K)
} else {
r.Nodes = s.routingTable4.Closest(m.A.InfoHash, s.conf.K)
}
} else {
if n4 {
r.Nodes = s.routingTable4.Closest(m.A.InfoHash, s.conf.K)
}
if n6 {
r.Nodes6 = s.routingTable6.Closest(m.A.InfoHash, s.conf.K)
}
}
s.reply(raddr, m.T, r)
case queryMethodGetPeers: // See BEP 32
n4 := m.A.ContainsWant(krpc.WantNodes)
n6 := m.A.ContainsWant(krpc.WantNodes6)
// Get the ipv4/ipv6 peers storing the torrent infohash.
var r krpc.ResponseResult
if !n4 && !n6 {
r.Values = s.peerManager.GetPeers(m.A.InfoHash, s.conf.K, isIPv6(raddr.IP))
} else {
if n4 {
r.Values = s.peerManager.GetPeers(m.A.InfoHash, s.conf.K, false)
}
if n6 {
values := s.peerManager.GetPeers(m.A.InfoHash, s.conf.K, true)
if len(r.Values) == 0 {
r.Values = values
} else {
r.Values = append(r.Values, values...)
}
}
}
// No Peers, and return the closest other nodes.
if len(r.Values) == 0 {
if !n4 && !n6 {
if isIPv6(raddr.IP) {
r.Nodes6 = s.routingTable6.Closest(m.A.InfoHash, s.conf.K)
} else {
r.Nodes = s.routingTable4.Closest(m.A.InfoHash, s.conf.K)
}
} else {
if n4 {
r.Nodes = s.routingTable4.Closest(m.A.InfoHash, s.conf.K)
}
if n6 {
r.Nodes6 = s.routingTable6.Closest(m.A.InfoHash, s.conf.K)
}
}
}
r.Token = s.tokenManager.Token(raddr)
s.reply(raddr, m.T, r)
s.conf.OnSearch(m.A.InfoHash.HexString(), raddr.IP, uint16(raddr.Port))
case queryMethodAnnouncePeer:
if s.tokenManager.Check(raddr, m.A.Token) {
return
}
s.reply(raddr, m.T, krpc.ResponseResult{})
s.conf.OnTorrent(m.A.InfoHash.HexString(), raddr.IP, m.A.GetPort(raddr.Port))
default:
s.sendError(raddr, m.T, "unknown query method", krpc.ErrorCodeMethodUnknown)
}
}
func (s *Server) send(raddr *net.UDPAddr, m krpc.Message) (wrote bool, err error) {
// // TODO: Should we check the ip blacklist??
// if s.conf.Blacklist.In(raddr.IP.String(), raddr.Port) {
// return
// }
m.RO = s.conf.ReadOnly // BEP 43
if wrote, err = s.conf.HandleOutMessage(raddr, &m); !wrote && err == nil {
wrote, err = s._send(raddr, m)
}
return
}
func (s *Server) _send(raddr *net.UDPAddr, m krpc.Message) (wrote bool, err error) {
if m.T == "" || m.Y == "" {
panic(`DHT message "t" or "y" must not be empty`)
}
buf := bytes.NewBuffer(nil)
buf.Grow(128)
if err = bencode.NewEncoder(buf).Encode(m); err != nil {
panic(err)
}
n, err := s.conn.WriteTo(buf.Bytes(), raddr)
if err != nil {
err = fmt.Errorf("error writing %d bytes to %s: %s", buf.Len(), raddr, err)
s.conf.Blacklist.Add(raddr.IP.String(), 0)
return
}
wrote = true
if n != buf.Len() {
err = io.ErrShortWrite
}
return
}
func (s *Server) sendError(raddr *net.UDPAddr, tid, reason string, code int) {
if _, err := s.send(raddr, krpc.NewErrorMsg(tid, code, reason)); err != nil {
s.conf.ErrorLog("error replying to %s: %s", raddr.String(), err.Error())
}
}
func (s *Server) reply(raddr *net.UDPAddr, tid string, r krpc.ResponseResult) {
r.ID = s.conf.ID
if _, err := s.send(raddr, krpc.NewResponseMsg(tid, r)); err != nil {
s.conf.ErrorLog("error replying to %s: %s", raddr.String(), err.Error())
}
}
func (s *Server) request(t *transaction) (err error) {
if s.isDisabled(t.Addr) {
return errUnsupportedIPProtocol
}
t.Arg.ID = s.conf.ID
if t.ID == "" {
t.ID = s.transactionManager.GetTransactionID()
}
if _, err = s.send(t.Addr, krpc.NewQueryMsg(t.ID, t.Query, t.Arg)); err != nil {
s.conf.ErrorLog("error replying to %s: %s", t.Addr.String(), err.Error())
} else {
s.transactionManager.AddTransaction(t)
}
return
}
func (s *Server) onError(t *transaction, code int, reason string) {
s.conf.ErrorLog("got an error to ping '%s': code=%d, reason=%s",
t.Addr.String(), code, reason)
t.Done(Result{Code: code, Reason: reason})
}
func (s *Server) onTimeout(t *transaction) {
// TODO: Should we use a task pool??
t.Done(Result{Timeout: true})
s.conf.ErrorLog("transaction '%s' timeout: query=%s, raddr=%s",
t.ID, t.Query, t.Addr.String())
}
func (s *Server) onPingResp(t *transaction, a *net.UDPAddr, m krpc.Message) {
t.Done(Result{})
}
func (s *Server) onGetPeersResp(t *transaction, a *net.UDPAddr, m krpc.Message) {
// Store the response node with the token.
if m.R.Token != "" {
s.tokenPeerManager.Set(m.R.ID, a, m.R.Token)
}
// Get the peers.
if len(m.R.Values) > 0 {
t.Done(Result{Peers: m.R.Values})
for _, addr := range m.R.Values {
s.conf.OnTorrent(t.Arg.InfoHash.HexString(), addr.IP, addr.Port)
}
return
}
// Search the torrent infohash recursively.
t.Depth--
if t.Depth < 1 {
t.Done(Result{})
return
}
var found bool
ids := t.Visited
nodes := make([]krpc.Node, 0, len(m.R.Nodes)+len(m.R.Nodes6))
for _, node := range m.R.Nodes {
if node.ID == t.Arg.InfoHash {
found = true
}
if ids.Contains(node.ID) {
continue
}
if s.addNode2(node, m.RO) == NodeAdded {
nodes = append(nodes, node)
ids = append(ids, node.ID)
}
}
for _, node := range m.R.Nodes6 {
if node.ID == t.Arg.InfoHash {
found = true
}
if ids.Contains(node.ID) {
continue
}
if s.addNode2(node, m.RO) == NodeAdded {
nodes = append(nodes, node)
ids = append(ids, node.ID)
}
}
if found || len(nodes) == 0 {
t.Done(Result{})
return
}
for _, node := range nodes {
s.getPeers(t.Arg.InfoHash, node.Addr, t.Depth, ids, t.Callback)
}
}
func (s *Server) getPeers(info metainfo.Hash, addr metainfo.Address, depth int,
ids metainfo.Hashes, cb ...func(Result)) {
arg := krpc.QueryArg{InfoHash: info, Wants: s.want}
t := newTransaction(s, addr.UDPAddr(), queryMethodGetPeers, arg, cb...)
t.OnResponse = s.onGetPeersResp
t.Depth = depth
t.Visited = ids
if err := s.request(t); err != nil {
s.conf.ErrorLog("fail to send query message to '%s': %s", addr.String(), err)
}
}
// Ping sends a PING query to addr, and the callback function cb will be called
// when the response or error is returned, or it's timeout.
func (s *Server) Ping(addr *net.UDPAddr, cb ...func(Result)) (err error) {
t := newTransaction(s, addr, queryMethodPing, krpc.QueryArg{}, cb...)
t.OnResponse = s.onPingResp
return s.request(t)
}
// GetPeers searches the peer storing the torrent by the infohash of the torrent,
// which will search it recursively until some peers are returned or it reaches
// the maximun depth, that's, ServerConfig.SearchDepth.
//
// If cb is given, it will be called when some peers are returned.
// Notice: it may be called for many times.
func (s *Server) GetPeers(infohash metainfo.Hash, cb ...func(Result)) {
if infohash.IsZero() {
panic("the infohash of the torrent is ZERO")
}
var nodes []krpc.Node
if s.ipv4 {
nodes = s.routingTable4.Closest(infohash, s.conf.K)
}
if s.ipv6 {
nodes = append(nodes, s.routingTable6.Closest(infohash, s.conf.K)...)
}
if len(nodes) == 0 {
if len(cb) != 0 && cb[0] != nil {
cb[0](Result{})
}
return
}
ids := make(metainfo.Hashes, len(nodes))
for i, node := range nodes {
ids[i] = node.ID
}
for _, node := range nodes {
s.getPeers(infohash, node.Addr, s.conf.SearchDepth, ids, cb...)
}
}
// AnnouncePeer announces the torrent infohash to the K closest nodes,
// and returns the nodes to which it sends the announce_peer query.
func (s *Server) AnnouncePeer(infohash metainfo.Hash, port uint16, impliedPort bool) []krpc.Node {
if infohash.IsZero() {
panic("the infohash of the torrent is ZERO")
}
var nodes []krpc.Node
if s.ipv4 {
nodes = s.routingTable4.Closest(infohash, s.conf.K)
}
if s.ipv6 {
nodes = append(nodes, s.routingTable6.Closest(infohash, s.conf.K)...)
}
sentNodes := make([]krpc.Node, 0, len(nodes))
for _, node := range nodes {
addr := node.Addr.UDPAddr()
token := s.tokenPeerManager.Get(infohash, addr)
if token == "" {
continue
}
arg := krpc.QueryArg{ImpliedPort: impliedPort, InfoHash: infohash, Port: port, Token: token}
t := newTransaction(s, addr, queryMethodAnnouncePeer, arg)
if err := s.request(t); err != nil {
s.conf.ErrorLog("fail to send query message to '%s': %s", addr.String(), err)
} else {
sentNodes = append(sentNodes, node)
}
}
return sentNodes
}
// FindNode sends the "find_node" query to the addr to find the target node.
//
// Notice: In general, it's used to bootstrap the routing table.
func (s *Server) FindNode(addr *net.UDPAddr, target metainfo.Hash) error {
if target.IsZero() {
panic("the target is ZERO")
}
return s.findNode(target, addr, s.conf.SearchDepth, nil)
}
func (s *Server) findNode(target metainfo.Hash, addr *net.UDPAddr, depth int,
ids metainfo.Hashes) error {
arg := krpc.QueryArg{Target: target, Wants: s.want}
t := newTransaction(s, addr, queryMethodFindNode, arg)
t.OnResponse = s.onFindNodeResp
t.Visited = ids
return s.request(t)
}
func (s *Server) onFindNodeResp(t *transaction, a *net.UDPAddr, m krpc.Message) {
// Search the target node recursively.
t.Depth--
if t.Depth < 1 {
return
}
var found bool
ids := t.Visited
nodes := make([]krpc.Node, 0, len(m.R.Nodes)+len(m.R.Nodes6))
for _, node := range m.R.Nodes {
if node.ID == t.Arg.Target {
found = true
}
if ids.Contains(node.ID) {
continue
}
if s.addNode2(node, m.RO) == NodeAdded {
nodes = append(nodes, node)
ids = append(ids, node.ID)
}
}
for _, node := range m.R.Nodes6 {
if node.ID == t.Arg.Target {
found = true
}
if ids.Contains(node.ID) {
continue
}
if s.addNode2(node, m.RO) == NodeAdded {
nodes = append(nodes, node)
ids = append(ids, node.ID)
}
}
if found || len(nodes) == 0 {
return
}
for _, node := range nodes {
err := s.findNode(t.Arg.Target, node.Addr.UDPAddr(), t.Depth, ids)
if err != nil {
s.conf.ErrorLog(`fail to send "find_node" query to '%s': %s`,
node.Addr.String(), err)
}
}
}
func isIPv6(ip net.IP) bool {
if ip.To4() == nil {
return true
}
return false
}
func ipIsZero(ip net.IP) bool {
for _, b := range ip {
if b != 0 {
return false
}
}
return true
}

182
dht/dht_server_test.go Normal file
View File

@ -0,0 +1,182 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dht
import (
"fmt"
"net"
"strconv"
"sync"
"time"
"github.com/xgfone/bt/metainfo"
)
type testPeerManager struct {
lock sync.RWMutex
peers map[metainfo.Hash][]metainfo.Address
}
func newTestPeerManager() *testPeerManager {
return &testPeerManager{peers: make(map[metainfo.Hash][]metainfo.Address)}
}
func (pm *testPeerManager) AddPeer(infohash metainfo.Hash, addr metainfo.Address) {
pm.lock.Lock()
var exist bool
for _, orig := range pm.peers[infohash] {
if orig.Equal(addr) {
exist = true
break
}
}
if !exist {
pm.peers[infohash] = append(pm.peers[infohash], addr)
}
pm.lock.Unlock()
}
func (pm *testPeerManager) GetPeers(infohash metainfo.Hash, maxnum int,
ipv6 bool) (addrs []metainfo.Address) {
// We only supports IPv4, so ignore the ipv6 argument.
pm.lock.RLock()
_addrs := pm.peers[infohash]
if _len := len(_addrs); _len > 0 {
if _len > maxnum {
_len = maxnum
}
addrs = _addrs[:_len]
}
pm.lock.RUnlock()
return
}
func onSearch(infohash string, ip net.IP, port uint16) {
addr := net.JoinHostPort(ip.String(), strconv.FormatUint(uint64(port), 10))
fmt.Printf("%s is searching %s\n", addr, infohash)
}
func onTorrent(infohash string, ip net.IP, port uint16) {
addr := net.JoinHostPort(ip.String(), strconv.FormatUint(uint64(port), 10))
fmt.Printf("%s has downloaded %s\n", addr, infohash)
}
func newDHTServer(id metainfo.Hash, addr string, pm PeerManager) (s *Server, err error) {
conn, err := net.ListenPacket("udp", addr)
if err == nil {
c := Config{ID: id, PeerManager: pm, OnSearch: onSearch, OnTorrent: onTorrent}
s = NewServer(conn, c)
}
return
}
func ExampleServer() {
// For test, we predefine some node ids and infohash.
id1 := metainfo.NewRandomHash()
id2 := metainfo.NewRandomHash()
id3 := metainfo.NewRandomHash()
infohash := metainfo.Hash{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
11, 12, 13, 14, 15, 16, 17, 18, 19, 20}
// Create first DHT server
pm := newTestPeerManager()
server1, err := newDHTServer(id1, ":9001", pm)
if err != nil {
fmt.Println(err)
return
}
defer server1.Close()
// Create second DHT server
server2, err := newDHTServer(id2, ":9002", nil)
if err != nil {
fmt.Println(err)
return
}
defer server2.Close()
// Create third DHT server
server3, err := newDHTServer(id3, ":9003", nil)
if err != nil {
fmt.Println(err)
return
}
defer server3.Close()
// Start the three DHT servers
go server1.Run()
go server2.Run()
go server3.Run()
// Wait that the DHT servers start.
time.Sleep(time.Second)
// Bootstrap the routing table to create a DHT network with the three servers
server1.Bootstrap([]string{"127.0.0.1:9002", "127.0.0.1:9003"})
server2.Bootstrap([]string{"127.0.0.1:9001", "127.0.0.1:9003"})
server3.Bootstrap([]string{"127.0.0.1:9001", "127.0.0.1:9002"})
// Wait that the DHT servers learn the routing table
time.Sleep(time.Second)
fmt.Println("Server1:", server1.Node4Num())
fmt.Println("Server2:", server2.Node4Num())
fmt.Println("Server3:", server3.Node4Num())
server1.GetPeers(infohash, func(r Result) {
if len(r.Peers) == 0 {
fmt.Printf("no peers for %s\n", infohash)
} else {
for _, peer := range r.Peers {
fmt.Printf("%s: %s\n", infohash, peer.String())
}
}
})
// Wait that the last get_peers ends.
time.Sleep(time.Second)
// Add the peer to let the DHT server1 has the peer.
pm.AddPeer(infohash, metainfo.NewAddress(net.ParseIP("127.0.0.1"), 9001))
// Search the torrent infohash again, but from DHT server2,
// which will search the DHT server1 recursively.
server2.GetPeers(infohash, func(r Result) {
if len(r.Peers) == 0 {
fmt.Printf("no peers for %s\n", infohash)
} else {
for _, peer := range r.Peers {
fmt.Printf("%s: %s\n", infohash, peer.String())
}
}
})
// Wait that the recursive call ends.
time.Sleep(time.Second)
// Unordered output:
// Server1: 2
// Server2: 2
// Server3: 2
// 127.0.0.1:9001 is searching 0102030405060708090a0b0c0d0e0f1011121314
// 127.0.0.1:9001 is searching 0102030405060708090a0b0c0d0e0f1011121314
// no peers for 0102030405060708090a0b0c0d0e0f1011121314
// no peers for 0102030405060708090a0b0c0d0e0f1011121314
// 127.0.0.1:9002 is searching 0102030405060708090a0b0c0d0e0f1011121314
// 127.0.0.1:9002 is searching 0102030405060708090a0b0c0d0e0f1011121314
// no peers for 0102030405060708090a0b0c0d0e0f1011121314
// 0102030405060708090a0b0c0d0e0f1011121314: 127.0.0.1:9001
// 127.0.0.1:9001 has downloaded 0102030405060708090a0b0c0d0e0f1011121314
}

139
dht/peer_manager.go Normal file
View File

@ -0,0 +1,139 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dht
import (
"net"
"sync"
"time"
"github.com/xgfone/bt/metainfo"
)
// PeerManager is used to manage the peers.
type PeerManager interface {
// If ipv6 is true, only return ipv6 addresses. Or return ipv4 addresses.
GetPeers(infohash metainfo.Hash, maxnum int, ipv6 bool) []metainfo.Address
}
type peer struct {
ID metainfo.Hash
IP net.IP
Port uint16
Token string
Time time.Time
}
type tokenPeerManager struct {
lock sync.RWMutex
exit chan struct{}
peers map[metainfo.Hash]map[string]peer
}
func newTokenPeerManager() *tokenPeerManager {
return &tokenPeerManager{
exit: make(chan struct{}),
peers: make(map[metainfo.Hash]map[string]peer, 128),
}
}
// Start starts the token-peer manager.
func (tpm *tokenPeerManager) Start(interval time.Duration) {
tick := time.NewTicker(interval)
defer tick.Stop()
for {
select {
case <-tpm.exit:
return
case now := <-tick.C:
tpm.lock.Lock()
for id, peers := range tpm.peers {
for addr, peer := range peers {
if now.Sub(peer.Time) > interval {
delete(peers, addr)
}
}
if len(peers) == 0 {
delete(tpm.peers, id)
}
}
tpm.lock.Unlock()
}
}
}
func (tpm *tokenPeerManager) Set(id metainfo.Hash, addr *net.UDPAddr, token string) {
addrkey := addr.String()
tpm.lock.Lock()
peers, ok := tpm.peers[id]
if !ok {
peers = make(map[string]peer, 4)
tpm.peers[id] = peers
}
peers[addrkey] = peer{
ID: id,
IP: addr.IP,
Port: uint16(addr.Port),
Token: token,
Time: time.Now(),
}
tpm.lock.Unlock()
}
func (tpm *tokenPeerManager) Get(id metainfo.Hash, addr *net.UDPAddr) (token string) {
addrkey := addr.String()
tpm.lock.RLock()
if peers, ok := tpm.peers[id]; ok {
if peer, ok := peers[addrkey]; ok {
token = peer.Token
}
}
tpm.lock.RUnlock()
return
}
func (tpm *tokenPeerManager) Stop() {
select {
case <-tpm.exit:
default:
close(tpm.exit)
}
}
func (tpm *tokenPeerManager) GetPeers(infohash metainfo.Hash, maxnum int,
ipv6 bool) (addrs []metainfo.Address) {
addrs = make([]metainfo.Address, 0, maxnum)
tpm.lock.RLock()
if peers, ok := tpm.peers[infohash]; ok {
for _, peer := range peers {
if maxnum < 1 {
break
}
if ipv6 { // For IPv6
if isIPv6(peer.IP) {
maxnum--
addrs = append(addrs, metainfo.NewAddress(peer.IP, peer.Port))
}
} else if !isIPv6(peer.IP) { // For IPv4
maxnum--
addrs = append(addrs, metainfo.NewAddress(peer.IP, peer.Port))
}
}
}
tpm.lock.RUnlock()
return
}

439
dht/routing_table.go Normal file
View File

@ -0,0 +1,439 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dht
import (
"sort"
"sync"
"time"
"github.com/xgfone/bt/krpc"
"github.com/xgfone/bt/metainfo"
)
const bktlen = 160
const (
badTimeout = time.Minute * 20
dubiousTimeout = time.Minute * 15
)
type routingTable struct {
k int
s *Server
ipv6 bool
sync chan struct{}
exit chan struct{}
root metainfo.Hash
lock sync.RWMutex
bkts []*bucket
}
func newRoutingTable(s *Server, ipv6 bool) *routingTable {
rt := &routingTable{
k: s.conf.K,
s: s,
ipv6: ipv6,
root: s.conf.ID,
sync: make(chan struct{}),
exit: make(chan struct{}),
bkts: make([]*bucket, bktlen),
}
for i := range rt.bkts {
rt.bkts[i] = &bucket{table: rt}
}
// Load all the nodes from the storage.
nodes, err := s.conf.RoutingTableStorage.Load(s.conf.ID, ipv6)
if err != nil {
s.conf.ErrorLog("fail to load routing table(ipv6=%v): %s", ipv6, err)
} else {
now := time.Now()
for _, node := range nodes {
if now.Sub(node.LastChanged) < dubiousTimeout {
rt.addNode(node.Node, node.LastChanged)
}
}
}
return rt
}
// Sync dumps the information of the routing table to the underlying storage.
func (rt *routingTable) Sync() {
select {
case rt.sync <- struct{}{}:
case <-rt.exit:
}
}
// Start starts the routing table.
func (rt *routingTable) Start(interval time.Duration) {
tick := time.NewTicker(interval)
defer tick.Stop()
for {
select {
case <-rt.exit:
rt.dump()
return
case <-rt.sync:
rt.dump()
case now := <-tick.C:
rt.lock.Lock()
rt.checkAllBuckets(now)
rt.lock.Unlock()
}
}
}
// Dump stores all the nodes into the underlying storage.
func (rt *routingTable) dump() {
nodes := make([]RoutingTableNode, 0, 128)
rt.lock.RLock()
for _, bkt := range rt.bkts {
for _, n := range bkt.Nodes {
nodes = append(nodes, RoutingTableNode{
Node: krpc.Node{ID: n.ID, Addr: n.Addr},
LastChanged: n.LastChanged,
})
}
}
rt.lock.RUnlock()
if len(nodes) > 0 {
err := rt.s.conf.RoutingTableStorage.Dump(rt.root, nodes, rt.ipv6)
if err != nil {
rt.s.conf.ErrorLog("fail to dump nodes in routing table(ipv6=%v): %s", rt.ipv6, err)
}
}
}
func (rt *routingTable) checkAllBuckets(now time.Time) {
defer func() {
if err := recover(); err != nil {
rt.s.conf.ErrorLog("panic: %v", err)
}
}()
for _, bkt := range rt.bkts {
bkt.CheckAllNodes(now)
}
}
func (rt *routingTable) Len() (n int) {
rt.lock.RLock()
for _, bkt := range rt.bkts {
n += len(bkt.Nodes)
}
rt.lock.RUnlock()
return
}
// Stop stops the routing table.
func (rt *routingTable) Stop() {
select {
case <-rt.exit:
default:
close(rt.exit)
}
}
// AddNode adds the node into the routing table.
//
// The returned value:
// NodeAdded: The node is added successfully.
// NodeNotAdded: The node is not added and is discarded.
// NodeExistAndUpdated: The node has existed, and its status has been updated.
// NodeExistAndChanged: The node has existed, but the address is inconsistent.
// The current node will be discarded.
//
func (rt *routingTable) AddNode(n krpc.Node) (r int) {
if n.ID == rt.root { // Don't add itself.
return NodeNotAdded
}
return rt.addNode(n, time.Now())
}
func (rt *routingTable) addNode(n krpc.Node, now time.Time) (r int) {
bktid := bucketid(rt.root, n.ID)
rt.lock.Lock()
r = rt.bkts[bktid].AddNode(n, now)
rt.lock.Unlock()
return
}
// Closest returns the maxnum number of the nodes which are the closest to target.
func (rt *routingTable) Closest(target metainfo.Hash, maxnum int) (nodes []krpc.Node) {
close := nodesByDistance{target: target, maxnum: maxnum}
rt.lock.RLock()
for _, b := range rt.bkts {
for _, n := range b.Nodes {
close.push(n.Node)
}
}
rt.lock.RUnlock()
return close.nodes
}
type nodesByDistance struct {
maxnum int
target metainfo.Hash
nodes []krpc.Node
}
func (ns *nodesByDistance) push(node krpc.Node) {
ix := sort.Search(len(ns.nodes), func(i int) bool {
return distcmp(ns.target, ns.nodes[i].ID, node.ID) > 0
})
if len(ns.nodes) < ns.maxnum {
ns.nodes = append(ns.nodes, node)
}
// slide existing nodes down to make room.
// this will overwrite the entry we just appended.
if ix != len(ns.nodes) {
copy(ns.nodes[ix+1:], ns.nodes[ix:])
ns.nodes[ix] = node
}
}
// distcmp compares the distances a->target and b->target.
// Returns -1 if a is closer to target, 1 if b is closer to target
// and 0 if they are equal.
func distcmp(target, a, b metainfo.Hash) int {
for i := range target {
da := a[i] ^ target[i]
db := b[i] ^ target[i]
if da > db {
return 1
} else if da < db {
return -1
}
}
return 0
}
/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
/// Bucket
// Predefine some values returned by the method AddNode of routing table.
const (
NodeAdded = iota
NodeNotAdded
NodeExistAndUpdated
NodeExistAndChanged
)
type bucket struct {
table *routingTable
Nodes []*wrappedNode
LastChanged time.Time
}
func (b *bucket) emit(f func(metainfo.Hash, RoutingTableNode) error, n *wrappedNode) {
node := RoutingTableNode{Node: n.Node, LastChanged: n.LastChanged}
f(b.table.root, node)
}
func (b *bucket) UpdateLastChangedTime(now time.Time) {
b.LastChanged = now
}
func (b *bucket) AddNode(n krpc.Node, now time.Time) (status int) {
// Update the old one.
for _, orig := range b.Nodes {
if orig.ID == n.ID {
if orig.Addr.Equal(n.Addr) {
orig.UpdateLastChangedTime(now)
status = NodeExistAndUpdated
return
}
// // TODO: Should we replace the old one??
// b.UpdateLastChangedTime(now)
// copy(b.Nodes[i:], b.Nodes[i+1:])
// b.Nodes[len(b.Nodes)-1] = newWrappedNode(b, n, now)
// status = nodeExistAndChanged
return NodeExistAndChanged
}
}
// The bucket is not full and append it.
if _len := len(b.Nodes); _len < b.table.k {
b.UpdateLastChangedTime(now)
b.Nodes = append(b.Nodes, newWrappedNode(b, n, now))
status = NodeAdded
return
}
// The bucket is full and remove the bad node, or discard the current node.
for i, node := range b.Nodes {
if node.IsBad(now) {
b.UpdateLastChangedTime(now)
copy(b.Nodes[i:], b.Nodes[i+1:])
b.Nodes = append(b.Nodes, newWrappedNode(b, n, now))
status = NodeAdded
return
}
}
// It will discard the current node and return false.
status = NodeNotAdded
return
}
func (b *bucket) delNodeByIndex(index int) {
b.Nodes = append(b.Nodes[:index], b.Nodes[index+1:]...)
}
func (b *bucket) CheckAllNodes(now time.Time) {
_len := len(b.Nodes)
if _len == 0 || now.Sub(b.LastChanged) < dubiousTimeout {
return
}
// Check all the dubious nodes.
indexes := make([]int, 0, len(b.Nodes))
for i, node := range b.Nodes {
switch status := node.Status(now); status {
case nodeStatusGood:
case nodeStatusDubious:
// Try to send the PING query to the dubious node to check whether it is alive.
if err := b.table.s.Ping(node.Node.Addr.UDPAddr()); err != nil {
b.table.s.conf.ErrorLog("fail to ping '%s': %s", node.Node.String(), err)
}
case nodeStatusBad:
// Remove the bad node
indexes = append(indexes, i)
default:
panic(status)
}
}
// Remove all the bad nodes.
for _, index := range indexes {
b.delNodeByIndex(index)
}
}
func bucketid(ownerid, nid metainfo.Hash) int {
if ownerid == nid {
return 0
}
var i int
var bite byte
var bitDiff int
var v byte
for i, bite = range ownerid {
v = bite ^ nid[i]
switch {
case v&0x80 > 0:
bitDiff = 7
goto calc
case v&0x40 > 0:
bitDiff = 6
goto calc
case v&0x20 > 0:
bitDiff = 5
goto calc
case v&0x10 > 0:
bitDiff = 4
goto calc
case v&0x08 > 0:
bitDiff = 3
goto calc
case v&0x04 > 0:
bitDiff = 2
goto calc
case v&0x02 > 0:
bitDiff = 1
goto calc
case v&0x01 > 0:
bitDiff = 0
goto calc
}
}
calc:
return i*8 + (8 - bitDiff)
}
/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
/// Node
const (
nodeStatusGood = iota
nodeStatusDubious
nodeStatusBad
)
type wrappedNode struct {
krpc.Node
LastChanged time.Time
bkt *bucket
}
func newWrappedNode(bkt *bucket, n krpc.Node, now time.Time) *wrappedNode {
return &wrappedNode{bkt: bkt, Node: n, LastChanged: now}
}
func (n *wrappedNode) UpdateLastChangedTime(now ...time.Time) {
var _now time.Time
if len(now) > 0 {
_now = now[0]
} else {
_now = time.Now()
}
n.LastChanged = _now
n.bkt.UpdateLastChangedTime(_now)
return
}
func (n *wrappedNode) IsBad(now ...time.Time) bool {
var _now time.Time
if len(now) > 0 {
_now = now[0]
} else {
_now = time.Now()
}
return _now.Sub(n.LastChanged) > badTimeout
}
func (n *wrappedNode) IsDubious(now ...time.Time) bool {
var _now time.Time
if len(now) > 0 {
_now = now[0]
} else {
_now = time.Now()
}
return _now.Sub(n.LastChanged) > dubiousTimeout
}
func (n *wrappedNode) Status(now time.Time) int {
duration := now.Sub(n.LastChanged)
switch {
case duration > badTimeout:
return nodeStatusBad
case duration > dubiousTimeout:
return nodeStatusDubious
default:
return nodeStatusGood
}
}

View File

@ -0,0 +1,42 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dht
import (
"time"
"github.com/xgfone/bt/krpc"
"github.com/xgfone/bt/metainfo"
)
// RoutingTableNode represents the node with last changed time in the routing table.
type RoutingTableNode struct {
Node krpc.Node
LastChanged time.Time
}
// RoutingTableStorage is used to store the nodes in the routing table.
type RoutingTableStorage interface {
Load(ownid metainfo.Hash, ipv6 bool) (nodes []RoutingTableNode, err error)
Dump(ownid metainfo.Hash, nodes []RoutingTableNode, ipv6 bool) (err error)
}
// NewNoopRoutingTableStorage returns a no-op RoutingTableStorage.
func NewNoopRoutingTableStorage() RoutingTableStorage { return noopStorage{} }
type noopStorage struct{}
func (s noopStorage) Load(metainfo.Hash, bool) (nodes []RoutingTableNode, err error) { return }
func (s noopStorage) Dump(metainfo.Hash, []RoutingTableNode, bool) (err error) { return }

107
dht/token_manager.go Normal file
View File

@ -0,0 +1,107 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dht
import (
"net"
"sync"
"time"
"github.com/xgfone/bt/utils"
)
// TokenManager is used to manage and validate the token.
//
// TODO: Should we allocate the different token for each node??
type tokenManager struct {
lock sync.RWMutex
last string
new string
exit chan struct{}
addrs sync.Map
}
func newTokenManager() *tokenManager {
token := utils.RandomString(8)
return &tokenManager{last: token, new: token, exit: make(chan struct{})}
}
func (tm *tokenManager) updateToken() {
token := utils.RandomString(8)
tm.lock.Lock()
tm.last, tm.new = tm.new, token
tm.lock.Unlock()
}
func (tm *tokenManager) clear(now time.Time, expired time.Duration) {
tm.addrs.Range(func(k interface{}, v interface{}) bool {
if now.Sub(v.(time.Time)) >= expired {
tm.addrs.Delete(k)
}
return true
})
}
// Start starts the token manager.
func (tm *tokenManager) Start(expired time.Duration) {
tick1 := time.NewTicker(expired)
tick2 := time.NewTicker(expired / 2)
defer tick1.Stop()
defer tick2.Stop()
for {
select {
case <-tm.exit:
return
case <-tick2.C:
tm.updateToken()
case now := <-tick1.C:
tm.clear(now, expired)
}
}
}
// Stop stops the token manager.
func (tm *tokenManager) Stop() {
select {
case <-tm.exit:
default:
close(tm.exit)
}
}
// Token allocates a token for a node addr and returns the token.
func (tm *tokenManager) Token(addr *net.UDPAddr) (token string) {
addrs := addr.String()
tm.lock.RLock()
token = tm.new
tm.lock.RUnlock()
tm.addrs.Store(addrs, time.Now())
return
}
// Check checks whether the token associated with the node addr is valid,
// that's, it's not expired.
func (tm *tokenManager) Check(addr *net.UDPAddr, token string) (ok bool) {
tm.lock.RLock()
last, new := tm.last, tm.new
tm.lock.RUnlock()
if last == token || new == token {
_, ok = tm.addrs.Load(addr.String())
}
return
}

152
dht/transaction_manager.go Normal file
View File

@ -0,0 +1,152 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dht
import (
"net"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/xgfone/bt/krpc"
"github.com/xgfone/bt/metainfo"
)
type transaction struct {
ID string
Query string
Arg krpc.QueryArg
Addr *net.UDPAddr
Time time.Time
Depth int
Visited metainfo.Hashes
Callback func(Result)
OnError func(t *transaction, code int, reason string)
OnTimeout func(t *transaction)
OnResponse func(t *transaction, radd *net.UDPAddr, msg krpc.Message)
}
func (t *transaction) Done(r Result) {
if t.Callback != nil {
r.Addr = t.Addr
t.Callback(r)
}
}
func noopResponse(*transaction, *net.UDPAddr, krpc.Message) {}
func newTransaction(s *Server, a *net.UDPAddr, q string, qa krpc.QueryArg,
callback ...func(Result)) *transaction {
var cb func(Result)
if len(callback) > 0 {
cb = callback[0]
}
return &transaction{
Addr: a,
Query: q,
Arg: qa,
Callback: cb,
OnError: s.onError,
OnTimeout: s.onTimeout,
OnResponse: noopResponse,
Time: time.Now(),
}
}
type transactionkey struct {
id string
addr string
}
type transactionManager struct {
lock sync.Mutex
exit chan struct{}
trans map[transactionkey]*transaction
tid uint32
}
func newTransactionManager() *transactionManager {
return &transactionManager{
exit: make(chan struct{}),
trans: make(map[transactionkey]*transaction, 128),
}
}
// Start starts the transaction manager.
func (tm *transactionManager) Start(s *Server, interval time.Duration) {
tick := time.NewTicker(interval)
defer tick.Stop()
for {
select {
case <-tm.exit:
return
case now := <-tick.C:
tm.lock.Lock()
for k, t := range tm.trans {
if now.Sub(t.Time) > interval {
delete(tm.trans, k)
t.OnTimeout(t)
}
}
tm.lock.Unlock()
}
}
}
// Stop stops the transaction manager.
func (tm *transactionManager) Stop() {
select {
case <-tm.exit:
default:
close(tm.exit)
}
}
// GetTransactionID returns a new transaction id.
func (tm *transactionManager) GetTransactionID() string {
return strconv.FormatUint(uint64(atomic.AddUint32(&tm.tid, 1)), 36)
}
// AddTransaction adds the new transaction.
func (tm *transactionManager) AddTransaction(t *transaction) {
key := transactionkey{id: t.ID, addr: t.Addr.String()}
tm.lock.Lock()
tm.trans[key] = t
tm.lock.Unlock()
}
// DeleteTransaction deletes the transaction.
func (tm *transactionManager) DeleteTransaction(t *transaction) {
key := transactionkey{id: t.ID, addr: t.Addr.String()}
tm.lock.Lock()
delete(tm.trans, key)
tm.lock.Unlock()
}
// PopTransaction deletes and returns the transaction by the transaction id
// and the peer address.
//
// Return nil if there is no the transaction.
func (tm *transactionManager) PopTransaction(tid string, addr *net.UDPAddr) (t *transaction) {
key := transactionkey{id: tid, addr: addr.String()}
tm.lock.Lock()
if t = tm.trans[key]; t != nil {
delete(tm.trans, key)
}
tm.lock.Unlock()
return
}

315
downloader/torrent_info.go Normal file
View File

@ -0,0 +1,315 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package downloader is used to download the torrent or the real file
// from the peer node by the peer wire protocol.
package downloader
import (
"bytes"
"crypto/sha1"
"errors"
"fmt"
"log"
"net"
"strconv"
"github.com/xgfone/bt/metainfo"
"github.com/xgfone/bt/peerprotocol"
)
const peerBlockSize = 16384 // 16KiB.
// Request is used to send a download request.
type request struct {
Host string
Port uint16
PeerID metainfo.Hash
InfoHash metainfo.Hash
}
// TorrentResponse represents a torrent info response.
type TorrentResponse struct {
Host string // Which host the torrent is downloaded from
Port uint16 // Which port the torrent is downloaded from
PeerID metainfo.Hash // ID of the peer where torrent is downloaded from
InfoHash metainfo.Hash // The SHA-1 hash of the torrent to be downloaded
InfoBytes []byte // The content of the info part in the torrent
}
// TorrentDownloaderConfig is used to configure the TorrentDownloader.
type TorrentDownloaderConfig struct {
// ID is the id of the downloader peer node.
//
// The default is a random id.
ID metainfo.Hash
// WorkerNum is the number of the worker downloading the torrent concurrently.
//
// The default is 128.
WorkerNum int
// ErrorLog is used to log the error.
//
// The default is log.Printf.
ErrorLog func(format string, args ...interface{})
}
func (c *TorrentDownloaderConfig) set(conf ...TorrentDownloaderConfig) {
if len(conf) > 0 {
*c = conf[0]
}
if c.WorkerNum <= 0 {
c.WorkerNum = 128
}
if c.ID.IsZero() {
c.ID = metainfo.NewRandomHash()
}
if c.ErrorLog == nil {
c.ErrorLog = log.Printf
}
}
// TorrentDownloader is used to download the torrent file from the peer.
type TorrentDownloader struct {
conf TorrentDownloaderConfig
exit chan struct{}
requests chan request
responses chan TorrentResponse
ondht func(string, uint16)
ebits peerprotocol.ExtensionBits
ehmsg peerprotocol.ExtendedHandshakeMsg
}
// NewTorrentDownloader returns a new TorrentDownloader.
//
// If id is ZERO, it is reset to a random id. workerNum is 128 by default.
func NewTorrentDownloader(c ...TorrentDownloaderConfig) *TorrentDownloader {
var conf TorrentDownloaderConfig
conf.set(c...)
d := &TorrentDownloader{
conf: conf,
exit: make(chan struct{}),
requests: make(chan request, conf.WorkerNum),
responses: make(chan TorrentResponse, 1024),
ehmsg: peerprotocol.ExtendedHandshakeMsg{
M: map[string]uint8{peerprotocol.ExtendedMessageNameMetadata: 1},
},
}
for i := 0; i < conf.WorkerNum; i++ {
go d.worker()
}
d.ebits.Set(peerprotocol.ExtensionBitExtended)
return d
}
// Request submits a download request.
func (d *TorrentDownloader) Request(host string, port uint16, infohash metainfo.Hash) {
d.requests <- request{Host: host, Port: port, InfoHash: infohash}
}
// Response returns a response channel to get the downloaded torrent info.
func (d *TorrentDownloader) Response() <-chan TorrentResponse {
return d.responses
}
// Close closes the downloader and releases the underlying resources.
func (d *TorrentDownloader) Close() {
select {
case <-d.exit:
default:
close(d.exit)
}
}
// OnDHTNode sets the DHT node callback, which will enable DHT extenstion bit.
//
// In the callback function, you maybe ping it like DHT by UDP.
// If the node responds, you can add the node in DHT routing table.
//
// BEP 5
func (d *TorrentDownloader) OnDHTNode(cb func(host string, port uint16)) {
d.ondht = cb
if cb == nil {
d.ebits.Unset(peerprotocol.ExtensionBitDHT)
} else {
d.ebits.Set(peerprotocol.ExtensionBitDHT)
}
}
func (d *TorrentDownloader) worker() {
for {
select {
case <-d.exit:
return
case r := <-d.requests:
if err := d.download(r.Host, r.Port, r.PeerID, r.InfoHash); err != nil {
d.conf.ErrorLog("fail to download the torrent '%s': %s",
r.InfoHash.HexString(), err)
}
}
}
}
func (d *TorrentDownloader) download(host string, port uint16,
peerID, infohash metainfo.Hash) (err error) {
addr := net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10))
conn, err := peerprotocol.NewPeerConnByDial(d.conf.ID, addr)
if err != nil {
return fmt.Errorf("fail to dial to '%s': %s", addr, err)
}
defer conn.Close()
conn.ExtensionBits = d.ebits
rmsg, err := conn.Handshake(infohash)
if err != nil || rmsg.InfoHash != infohash || !rmsg.IsSupportExtended() {
return
} else if !peerID.IsZero() && peerID != rmsg.PeerID {
return fmt.Errorf("inconsistent peer id '%s'", rmsg.PeerID.HexString())
}
if err = conn.SendExtHandshakeMsg(d.ehmsg); err != nil {
return
}
var pieces [][]byte
var piecesNum int
var metadataSize int
var utmetadataID uint8
var msg peerprotocol.Message
for {
if msg, err = conn.ReadMsg(); err != nil {
return err
}
select {
case <-d.exit:
return
default:
}
if msg.Keepalive {
continue
}
switch msg.Type {
case peerprotocol.Extended:
case peerprotocol.Port:
if d.ondht != nil {
d.ondht(host, msg.Port)
}
continue
default:
continue
}
switch msg.ExtendedID {
case peerprotocol.ExtendedIDHandshake:
if utmetadataID > 0 {
return fmt.Errorf("rehandshake from the peer '%s'", conn.RemoteAddr().String())
}
var ehmsg peerprotocol.ExtendedHandshakeMsg
if err = ehmsg.Decode(msg.ExtendedPayload); err != nil {
return
}
utmetadataID = ehmsg.M[peerprotocol.ExtendedMessageNameMetadata]
if utmetadataID == 0 {
return errors.New(`the peer does not support "ut_metadata"`)
}
metadataSize = ehmsg.MetadataSize
piecesNum = metadataSize / peerBlockSize
if metadataSize%peerBlockSize != 0 {
piecesNum++
}
pieces = make([][]byte, piecesNum)
go d.requestPieces(conn, utmetadataID, piecesNum)
case 1:
if pieces == nil {
return
}
var utmsg peerprotocol.UtMetadataExtendedMsg
if err = utmsg.DecodeFromPayload(msg.ExtendedPayload); err != nil {
return
}
if utmsg.MsgType != peerprotocol.UtMetadataExtendedMsgTypeData {
continue
}
pieceLen := len(utmsg.Data)
if (utmsg.Piece != piecesNum-1 && pieceLen != peerBlockSize) ||
(utmsg.Piece == piecesNum-1 && pieceLen != metadataSize%peerBlockSize) {
return
}
pieces[utmsg.Piece] = utmsg.Data
finish := true
for _, piece := range pieces {
if len(piece) == 0 {
finish = false
break
}
}
if finish {
metadataInfo := bytes.Join(pieces, nil)
if infohash == metainfo.Hash(sha1.Sum(metadataInfo)) {
d.responses <- TorrentResponse{
Host: host,
Port: port,
PeerID: rmsg.PeerID,
InfoHash: infohash,
InfoBytes: metadataInfo,
}
}
return
}
}
}
}
func (d *TorrentDownloader) requestPieces(conn *peerprotocol.PeerConn, utMetadataID uint8, piecesNum int) {
for i := 0; i < piecesNum; i++ {
payload, err := peerprotocol.UtMetadataExtendedMsg{
MsgType: peerprotocol.UtMetadataExtendedMsgTypeRequest,
Piece: i,
}.EncodeToBytes()
if err != nil {
panic(err)
}
msg := peerprotocol.Message{
Type: peerprotocol.Extended,
ExtendedID: utMetadataID,
ExtendedPayload: payload,
}
if err = conn.WriteMsg(msg); err != nil {
d.conf.ErrorLog("fail to send the ut_metadata request: %s", err)
return
}
}
}

3
go.mod Normal file
View File

@ -0,0 +1,3 @@
module github.com/xgfone/bt
go 1.11

16
krpc/doc.go Normal file
View File

@ -0,0 +1,16 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package krpc supplies the KRPC message used by DHT.
package krpc

440
krpc/message.go Normal file
View File

@ -0,0 +1,440 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package krpc
import (
"bytes"
"fmt"
"github.com/xgfone/bt/bencode"
"github.com/xgfone/bt/metainfo"
)
// Predefine some error code.
const (
// BEP 5
ErrorCodeGeneric = 201
ErrorCodeServer = 202
ErrorCodeProtocol = 203
ErrorCodeMethodUnknown = 204
// BEP 44
ErrorCodeMessageValueFieldTooBig = 205
ErrorCodeInvalidSignature = 206
ErrorCodeSaltFieldTooBig = 207
ErrorCodeCasHashMismatched = 301
ErrorCodeSequenceNumberLessThanCurrent = 302
)
// Message represents messages that nodes in the network send to each other
// as specified by the KRPC protocol, which are also referred to as the KRPC
// messages.
//
// There are three types of messages: QUERY, RESPONSE, ERROR
// The message is a dictonary that is then "bencoded"
// (serialization & compression format adopted by the BitTorrent)
// and sent via the UDP connection to peers.
//
// A KRPC message is a single dictionary with two keys common to every message
// and additional keys depending on the type of message. Every message has a key
// "t" with a string value representing a transaction ID. This transaction ID
// is generated by the querying node and is echoed in the response, so responses
// may be correlated with multiple queries to the same node. The transaction ID
// should be encoded as a short string of binary numbers, typically 2 characters
// are enough as they cover 2^16 outstanding queries. The other key contained
// in every KRPC message is "y" with a single character value describing
// the type of message. The value of the "y" key is one of "q" for query,
// "r" for response, or "e" for error.
type Message struct {
T string `bencode:"t"` // required: transaction ID
Y string `bencode:"y"` // required: type of the message: q for QUERY, r for RESPONSE, e for ERROR
Q string `bencode:"q,omitempty"` // Query method (one of 4: "ping", "find_node", "get_peers", "announce_peer")
A QueryArg `bencode:"a,omitempty"` // named arguments sent with a query
R ResponseResult `bencode:"r,omitempty"` // RESPONSE type only
E Error `bencode:"e,omitempty"` // ERROR type only
RO bool `bencode:"ro,omitempty"` // BEP 43: ReadOnly
}
// NewQueryMsg return a QUERY message.
func NewQueryMsg(tid, method string, arg QueryArg) Message {
return Message{T: tid, Y: "q", Q: method, A: arg}
}
// NewResponseMsg return a RESPONSE message.
func NewResponseMsg(tid string, data ResponseResult) Message {
return Message{T: tid, Y: "r", R: data}
}
// NewErrorMsg return a ERROR message.
func NewErrorMsg(tid string, code int, reason string) Message {
return Message{T: tid, Y: "e", E: Error{Code: code, Reason: reason}}
}
// IsQuery reports whether the message is an QUERY.
func (m Message) IsQuery() bool {
return m.Y == "q"
}
// IsResponse reports whether the message is an RESPONSE.
func (m Message) IsResponse() bool {
return m.Y == "r"
}
// IsError reports whether the message is an ERROR.
func (m Message) IsError() bool {
return m.Y == "e"
}
// RID returns the value named "id" in "r".
//
// Return "" instead if no "id".
func (m Message) RID() metainfo.Hash {
return m.R.ID
}
// QID returns the value named "id" in "a", that's, the query arguments.
//
// Return "" instead if no "id".
func (m Message) QID() metainfo.Hash {
return m.A.ID
}
// ID returns the QID or RID.
func (m Message) ID() metainfo.Hash {
switch m.Y {
case "q":
return m.QID()
case "r":
return m.RID()
default:
panic(fmt.Errorf("unknown message type '%s'", m.Y))
}
}
// Error represents a response error.
type Error struct {
Code int
Reason string
}
// NewError returns a new Error.
func NewError(code int, reason string) Error {
return Error{Code: code, Reason: reason}
}
func (e *Error) decode(vs []interface{}) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("unpacking %#v: %v", vs, r)
}
}()
e.Code = int(vs[0].(int64))
e.Reason = vs[1].(string)
return
}
// UnmarshalBencode implements the interface bencode.Unmarshaler.
func (e *Error) UnmarshalBencode(b []byte) (err error) {
var iface interface{}
if err = bencode.NewDecoder(bytes.NewBuffer(b)).Decode(&iface); err != nil {
return
}
switch v := iface.(type) {
case []interface{}:
err = e.decode(v)
case string:
e.Reason = v
default:
err = fmt.Errorf(`KRPC error bencode value has unexpected type: %T`, v)
}
return
}
// MarshalBencode implements the interface bencode.Marshaler.
func (e Error) MarshalBencode() (ret []byte, err error) {
if e.Code == 0 && e.Reason == "" {
return nil, nil
}
buf := bytes.NewBuffer(nil)
buf.Grow(32)
err = bencode.NewEncoder(buf).Encode([]interface{}{e.Code, e.Reason})
if err != nil {
ret = buf.Bytes()
}
return
}
func (e Error) Error() string {
return fmt.Sprintf("KRPC error %d: %s", e.Code, e.Reason)
}
// Want represents the type of nodes that the request wants.
//
// BEP 32
type Want string
// Predefine some wants.
//
// BEP 32
const (
WantNodes Want = "n4"
WantNodes6 Want = "n6"
)
// QueryArg represents the arguments used by the QUERY message.
type QueryArg struct {
// ID is used to identify a querying node.
ID metainfo.Hash `bencode:"id"` // BEP 5
// Target is used to identify the node sought by the queryer.
//
// find_node
Target metainfo.Hash `bencode:"target,omitempty"` // BEP 5
// InfoHash is the infohash of the torrent file.
//
// get_peers, announce_peer
InfoHash metainfo.Hash `bencode:"info_hash,omitempty"` // BEP 5
// Port is the port on where sender is listening to allow others
// to download the torrent.
//
// announce_peer
Port uint16 `bencode:"port,omitempty"` // BEP 5
// Token is the received one from an earlier "get_peers" query.
//
// announce_peer
Token string `bencode:"token,omitempty"` // BEP 5
// ImpliedPort is used by the senders to apparent DHT port
// to improve NAT support.
//
// announce_peer
ImpliedPort bool `bencode:"implied_port,omitempty"` // BEP 5
// Wants is only used to represent to expect "nodes" for "n4"
// or "nodes6" for "n6".
//
// Notice: It only governs the presence of the "nodes" and "nodes6"
// parameters, not the interpretation of "values".
//
// find_node, get_peers
Wants []Want `bencode:"want,omitempty"` // BEP 32
}
// ContainsWant reports whether the request contains the given Want.
func (a QueryArg) ContainsWant(w Want) bool {
for _, want := range a.Wants {
if want == w {
return true
}
}
return false
}
// GetPort returns the real port of the peer.
func (a QueryArg) GetPort(port int) uint16 {
if a.ImpliedPort {
return uint16(port)
}
return a.Port
}
// ResponseResult represents the results used by the RESPONSE message.
type ResponseResult struct {
// ID is used to indentify the queried node, that's, the response node.
ID metainfo.Hash `bencode:"id"` // BEP 5
// Nodes is a string containing the compact node information for the list
// of the ipv4 target node, or the K(8) closest good nodes in routing table
// of the requested ipv4 target.
//
// find_node
Nodes CompactIPv4Node `bencode:"nodes,omitempty"` // BEP 5
// Nodes6 is a string containing the compact node information for the list
// of the ipv6 target node, or the K(8) closest good nodes in routing table
// of the requested ipv6 target.
//
// find_node
Nodes6 CompactIPv6Node `bencode:"nodes6,omitempty"` // BEP 32
// Token is used for future "announce_peer".
//
// get_peers
Token string `bencode:"token,omitempty"` // BEP 5
// Values is a list of the torrent peers.
//
// get_peers
Values CompactAddresses `bencode:"values,omitempty"` // BEP 5
}
/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// CompactAddresses represents a group of the compact addresses.
type CompactAddresses []metainfo.Address
// MarshalBinary implements the interface binary.BinaryMarshaler.
func (cas CompactAddresses) MarshalBinary() ([]byte, error) {
ss := make([]string, len(cas))
for i, addr := range cas {
data, err := addr.MarshalBinary()
if err != nil {
return nil, err
}
ss[i] = string(data)
}
return bencode.EncodeBytes(ss)
}
// UnmarshalBinary implements the interface binary.BinaryUnmarshaler.
func (cas *CompactAddresses) UnmarshalBinary(b []byte) (err error) {
var ss []string
if err = bencode.DecodeBytes(b, &ss); err == nil {
addrs := make(CompactAddresses, len(ss))
for i, s := range ss {
var addr metainfo.Address
if err = addr.UnmarshalBinary([]byte(s)); err != nil {
return
}
addrs[i] = addr
}
}
return
}
/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// CompactIPv4Node is a set of IPv4 Nodes.
type CompactIPv4Node []Node
// MarshalBinary implements the interface binary.BinaryMarshaler.
func (cn CompactIPv4Node) MarshalBinary() ([]byte, error) {
buf := bytes.NewBuffer(nil)
buf.Grow(26 * len(cn))
for _, ni := range cn {
if ni.Addr.IP = ni.Addr.IP.To4(); len(ni.Addr.IP) == 0 {
continue
}
if n, err := ni.WriteBinary(buf); err != nil {
return nil, err
} else if n != 26 {
panic(fmt.Errorf("CompactIPv4NodeInfo: the invalid NodeInfo length '%d'", n))
}
}
return buf.Bytes(), nil
}
// UnmarshalBinary implements the interface binary.BinaryUnmarshaler.
func (cn *CompactIPv4Node) UnmarshalBinary(b []byte) (err error) {
_len := len(b)
if _len%26 != 0 {
return fmt.Errorf("CompactIPv4NodeInfo: invalid bytes length '%d'", _len)
}
nis := make([]Node, 0, _len/26)
for i := 0; i < _len; i += 26 {
var ni Node
if err = ni.UnmarshalBinary(b[i : i+26]); err != nil {
return
}
nis = append(nis, ni)
}
*cn = nis
return
}
// MarshalBencode implements the interface bencode.Marshaler.
func (cn CompactIPv4Node) MarshalBencode() (b []byte, err error) {
if b, err = cn.MarshalBinary(); err == nil {
b, err = bencode.EncodeBytes(b)
}
return
}
// UnmarshalBencode implements the interface bencode.Unmarshaler.
func (cn *CompactIPv4Node) UnmarshalBencode(b []byte) (err error) {
var s string
if err = bencode.DecodeBytes(b, &s); err == nil {
err = cn.UnmarshalBinary([]byte(s))
}
return
}
/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// CompactIPv6Node is a set of IPv6 Nodes.
type CompactIPv6Node []Node
// MarshalBinary implements the interface binary.BinaryMarshaler.
func (cn CompactIPv6Node) MarshalBinary() ([]byte, error) {
buf := bytes.NewBuffer(nil)
buf.Grow(38 * len(cn))
for _, ni := range cn {
ni.Addr.IP = ni.Addr.IP.To16()
if n, err := ni.WriteBinary(buf); err != nil {
return nil, err
} else if n != 38 {
panic(fmt.Errorf("CompactIPv6NodeInfo: the invalid NodeInfo length '%d'", n))
}
}
return buf.Bytes(), nil
}
// UnmarshalBinary implements the interface binary.BinaryUnmarshaler.
func (cn *CompactIPv6Node) UnmarshalBinary(b []byte) (err error) {
_len := len(b)
if _len%38 != 0 {
return fmt.Errorf("CompactIPv6NodeInfo: invalid bytes length '%d'", _len)
}
nis := make([]Node, 0, _len/38)
for i := 0; i < _len; i += 38 {
var ni Node
if err = ni.UnmarshalBinary(b[i : i+38]); err != nil {
return
}
nis = append(nis, ni)
}
*cn = nis
return
}
// MarshalBencode implements the interface bencode.Marshaler.
func (cn CompactIPv6Node) MarshalBencode() (b []byte, err error) {
if b, err = cn.MarshalBinary(); err == nil {
b, err = bencode.EncodeBytes(b)
}
return
}
// UnmarshalBencode implements the interface bencode.Unmarshaler.
func (cn *CompactIPv6Node) UnmarshalBencode(b []byte) (err error) {
var s string
if err = bencode.DecodeBytes(b, &s); err == nil {
err = cn.UnmarshalBinary([]byte(s))
}
return
}

41
krpc/message_test.go Normal file
View File

@ -0,0 +1,41 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package krpc
import (
"testing"
"github.com/xgfone/bt/bencode"
)
func TestMessage(t *testing.T) {
s, err := bencode.EncodeString(Message{RO: true})
if err != nil {
t.Fatal(err)
}
var ms map[string]interface{}
if err := bencode.DecodeString(s, &ms); err != nil {
t.Fatal(err)
} else if len(ms) != 3 {
t.Fatal()
} else if v, ok := ms["t"]; !ok || v != "" {
t.Fatal()
} else if v, ok := ms["y"]; !ok || v != "" {
t.Fatal()
} else if v, ok := ms["ro"].(int64); !ok || v != 1 {
t.Fatal()
}
}

84
krpc/node.go Normal file
View File

@ -0,0 +1,84 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package krpc
import (
"bytes"
"fmt"
"io"
"net"
"github.com/xgfone/bt/metainfo"
)
// Node represents a node information.
type Node struct {
ID metainfo.Hash
Addr metainfo.Address
}
// NewNode returns a new Node.
func NewNode(id metainfo.Hash, ip net.IP, port int) Node {
return Node{ID: id, Addr: metainfo.NewAddress(ip, uint16(port))}
}
// NewNodeByUDPAddr returns a new Node with the id and the UDP address.
func NewNodeByUDPAddr(id metainfo.Hash, addr *net.UDPAddr) (n Node) {
n.ID = id
n.Addr.FromUDPAddr(addr)
return
}
func (n Node) String() string {
return fmt.Sprintf("Node<%x@%s>", n.ID, n.Addr)
}
// Equal reports whether n is equal to o.
func (n Node) Equal(o Node) bool {
return n.ID == o.ID && n.Addr.Equal(o.Addr)
}
// WriteBinary is the same as MarshalBinary, but writes the result into w
// instead of returning.
func (n Node) WriteBinary(w io.Writer) (m int, err error) {
var n1, n2 int
if n1, err = w.Write(n.ID[:]); err == nil {
m = n1
if n2, err = n.Addr.WriteBinary(w); err == nil {
m += n2
}
}
return
}
// MarshalBinary implements the interface binary.BinaryMarshaler.
func (n Node) MarshalBinary() (data []byte, err error) {
buf := bytes.NewBuffer(nil)
buf.Grow(48)
if _, err = n.WriteBinary(buf); err == nil {
data = buf.Bytes()
}
return
}
// UnmarshalBinary implements the interface binary.BinaryUnmarshaler.
func (n *Node) UnmarshalBinary(b []byte) error {
if len(b) < 26 {
return io.ErrShortBuffer
}
copy(n.ID[:], b[:20])
return n.Addr.UnmarshalBinary(b[20:])
}

342
metainfo/address.go Normal file
View File

@ -0,0 +1,342 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package metainfo
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"net"
"strconv"
"github.com/xgfone/bt/bencode"
)
// ErrInvalidAddr is returned when the compact address is invalid.
var ErrInvalidAddr = fmt.Errorf("invalid compact information of ip and port")
// Address represents a client/server listening on a UDP port implementing
// the DHT protocol.
type Address struct {
IP net.IP // For IPv4, its length must be 4.
Port uint16
}
// NewAddress returns a new Address.
func NewAddress(ip net.IP, port uint16) Address {
if ipv4 := ip.To4(); len(ipv4) > 0 {
ip = ipv4
}
return Address{IP: ip, Port: port}
}
// NewAddressFromString returns a new Address by the address string.
func NewAddressFromString(s string) (addr Address, err error) {
err = addr.FromString(s)
return
}
// NewAddressesFromString returns a list of Addresses by the address string.
func NewAddressesFromString(s string) (addrs []Address, err error) {
shost, sport, err := net.SplitHostPort(s)
if err != nil {
return nil, fmt.Errorf("invalid address '%s': %s", s, err)
}
var port uint16
if sport != "" {
v, err := strconv.ParseUint(sport, 10, 16)
if err != nil {
return nil, fmt.Errorf("invalid address '%s': %s", s, err)
}
port = uint16(v)
}
ips, err := net.LookupIP(shost)
if err != nil {
return nil, fmt.Errorf("fail to lookup the domain '%s': %s", shost, err)
}
addrs = make([]Address, len(ips))
for i, ip := range ips {
addrs[i] = Address{IP: ip, Port: port}
}
return
}
// FromString parses and sets the ip from the string addr.
func (a *Address) FromString(addr string) (err error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return fmt.Errorf("invalid address '%s': %s", addr, err)
}
if port != "" {
v, err := strconv.ParseUint(port, 10, 16)
if err != nil {
return fmt.Errorf("invalid address '%s': %s", addr, err)
}
a.Port = uint16(v)
}
ips, err := net.LookupIP(host)
if err != nil {
return fmt.Errorf("fail to lookup the domain '%s': %s", host, err)
} else if len(ips) == 0 {
return fmt.Errorf("the domain '%s' has no ips", host)
}
a.IP = ips[0]
if ip := a.IP.To4(); len(ip) > 0 {
a.IP = ip
}
return
}
// FromUDPAddr sets the ip from net.UDPAddr.
func (a *Address) FromUDPAddr(ua *net.UDPAddr) {
a.Port = uint16(ua.Port)
a.IP = ua.IP
if ipv4 := a.IP.To4(); len(ipv4) > 0 {
a.IP = ipv4
}
}
// UDPAddr creates a new net.UDPAddr.
func (a Address) UDPAddr() *net.UDPAddr {
return &net.UDPAddr{
IP: a.IP,
Port: int(a.Port),
}
}
func (a Address) String() string {
if a.Port == 0 {
return a.IP.String()
}
return net.JoinHostPort(a.IP.String(), strconv.FormatUint(uint64(a.Port), 10))
}
// Equal reports whether n is equal to o, which is equal to
// n.HasIPAndPort(o.IP, o.Port)
func (a Address) Equal(o Address) bool {
return a.Port == o.Port && a.IP.Equal(o.IP)
}
// HasIPAndPort reports whether the current node has the ip and the port.
func (a Address) HasIPAndPort(ip net.IP, port uint16) bool {
return port == a.Port && a.IP.Equal(ip)
}
// WriteBinary is the same as MarshalBinary, but writes the result into w
// instead of returning.
func (a Address) WriteBinary(w io.Writer) (m int, err error) {
if m, err = w.Write(a.IP); err == nil {
if err = binary.Write(w, binary.BigEndian, a.Port); err == nil {
m += 2
}
}
return
}
// UnmarshalBinary implements the interface binary.BinaryUnmarshaler.
func (a *Address) UnmarshalBinary(b []byte) (err error) {
_len := len(b) - 2
switch _len {
case net.IPv4len, net.IPv6len:
default:
return ErrInvalidAddr
}
a.IP = make(net.IP, _len)
copy(a.IP, b[:_len])
a.Port = binary.BigEndian.Uint16(b[_len:])
return
}
// MarshalBinary implements the interface binary.BinaryMarshaler.
func (a Address) MarshalBinary() (data []byte, err error) {
buf := bytes.NewBuffer(nil)
buf.Grow(20)
if _, err = a.WriteBinary(buf); err == nil {
data = buf.Bytes()
}
return
}
func (a *Address) decode(vs []interface{}) (err error) {
defer func() {
if e := recover(); e != nil {
err = e.(error)
}
}()
host := vs[0].(string)
if a.IP = net.ParseIP(host); len(a.IP) == 0 {
return ErrInvalidAddr
} else if ip := a.IP.To4(); len(ip) > 0 {
a.IP = ip
}
a.Port = uint16(vs[1].(int64))
return
}
// UnmarshalBencode implements the interface bencode.Unmarshaler.
func (a *Address) UnmarshalBencode(b []byte) (err error) {
var iface interface{}
if err = bencode.NewDecoder(bytes.NewBuffer(b)).Decode(&iface); err != nil {
return
}
switch v := iface.(type) {
case string:
err = a.FromString(v)
case []interface{}:
err = a.decode(v)
default:
err = fmt.Errorf("unsupported type: %T", iface)
}
return
}
// MarshalBencode implements the interface bencode.Marshaler.
func (a Address) MarshalBencode() (b []byte, err error) {
buf := bytes.NewBuffer(nil)
buf.Grow(32)
err = bencode.NewEncoder(buf).Encode([]interface{}{a.IP.String(), a.Port})
if err == nil {
b = buf.Bytes()
}
return
}
/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// HostAddress is the same as the Address, but the host part may be
// either a domain or a ip.
type HostAddress struct {
Host string
Port uint16
}
// NewHostAddress returns a new host addrress.
func NewHostAddress(host string, port uint16) HostAddress {
return HostAddress{Host: host, Port: port}
}
// NewHostAddressFromString returns a new host address by the string.
func NewHostAddressFromString(s string) (addr HostAddress, err error) {
err = addr.FromString(s)
return
}
// FromString parses and sets the host from the string addr.
func (a *HostAddress) FromString(addr string) (err error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return fmt.Errorf("invalid address '%s': %s", addr, err)
} else if host == "" {
return fmt.Errorf("invalid address '%s': missing host", addr)
}
if port != "" {
v, err := strconv.ParseUint(port, 10, 16)
if err != nil {
return fmt.Errorf("invalid address '%s': %s", addr, err)
}
a.Port = uint16(v)
}
a.Host = host
return
}
func (a HostAddress) String() string {
if a.Port == 0 {
return a.Host
}
return net.JoinHostPort(a.Host, strconv.FormatUint(uint64(a.Port), 10))
}
// Addresses parses the host address to a list of Addresses.
func (a HostAddress) Addresses() (addrs []Address, err error) {
if ip := net.ParseIP(a.Host); len(ip) > 0 {
return []Address{NewAddress(ip, a.Port)}, nil
}
ips, err := net.LookupIP(a.Host)
if err != nil {
err = fmt.Errorf("fail to lookup the domain '%s': %s", a.Host, err)
} else {
addrs = make([]Address, len(ips))
for i, ip := range ips {
addrs[i] = NewAddress(ip, a.Port)
}
}
return
}
// Equal reports whether a is equal to o.
func (a HostAddress) Equal(o HostAddress) bool {
return a.Port == o.Port && a.Host == o.Host
}
func (a *HostAddress) decode(vs []interface{}) (err error) {
defer func() {
if e := recover(); e != nil {
err = e.(error)
}
}()
a.Host = vs[0].(string)
a.Port = uint16(vs[1].(int64))
return
}
// UnmarshalBencode implements the interface bencode.Unmarshaler.
func (a *HostAddress) UnmarshalBencode(b []byte) (err error) {
var iface interface{}
if err = bencode.NewDecoder(bytes.NewBuffer(b)).Decode(&iface); err != nil {
return
}
switch v := iface.(type) {
case string:
err = a.FromString(v)
case []interface{}:
err = a.decode(v)
default:
err = fmt.Errorf("unsupported type: %T", iface)
}
return
}
// MarshalBencode implements the interface bencode.Marshaler.
func (a HostAddress) MarshalBencode() (b []byte, err error) {
buf := bytes.NewBuffer(nil)
buf.Grow(32)
err = bencode.NewEncoder(buf).Encode([]interface{}{a.Host, a.Port})
if err == nil {
b = buf.Bytes()
}
return
}

48
metainfo/address_test.go Normal file
View File

@ -0,0 +1,48 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package metainfo
import (
"testing"
)
func TestAddress(t *testing.T) {
var addr1 Address
if err := addr1.FromString("1.2.3.4:1234"); err != nil {
t.Error(err)
return
}
data, err := addr1.MarshalBencode()
if err != nil {
t.Error(err)
return
} else if s := string(data); s != `l7:1.2.3.4i1234ee` {
t.Errorf(`expected 'l7:1.2.3.4i1234ee', but got '%s'`, s)
}
var addr2 Address
if err = addr2.UnmarshalBencode(data); err != nil {
t.Error(err)
} else if addr2.String() != `1.2.3.4:1234` {
t.Errorf("expected '1.2.3.4:1234', but got '%s'", addr2)
}
if data, err = addr2.MarshalBinary(); err != nil {
t.Error(err)
} else if s := string(data); s != "\x01\x02\x03\x04\x04\xd2" {
t.Errorf(`expected '\x01\x02\x03\x04\x04\xd2', but got '%#x'`, s)
}
}

17
metainfo/doc.go Normal file
View File

@ -0,0 +1,17 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package metainfo is used to encode or decode the metainfo of the torrent file
// and the MagNet link.
package metainfo

65
metainfo/file.go Normal file
View File

@ -0,0 +1,65 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package metainfo
import (
"path/filepath"
)
// File represents a file in the multi-file case.
type File struct {
// Length is the length of the file in bytes.
Length int64 `json:"length" bencode:"length"` // BEP 3
// Paths is a list containing one or more string elements that together
// represent the path and filename. Each element in the list corresponds
// to either a directory name or (in the case of the final element) the
// filename.
//
// For example, a the file "dir1/dir2/file.ext" would consist of three
// string elements: "dir1", "dir2", and "file.ext". This is encoded as
// a bencoded list of strings such as l4:dir14:dir28:file.exte.
Paths []string `json:"path" bencode:"path"` // BEP 3
}
func (f File) String() string {
return filepath.Join(f.Paths...)
}
// Path returns the path of the current.
func (f File) Path(info Info) string {
if info.IsDir() {
return f.String()
}
return info.Name
}
// Offset returns the offset of the current file from the start.
func (f File) Offset(info Info) (ret int64) {
path := f.Path(info)
for _, file := range info.AllFiles() {
if path == file.Path(info) {
return
}
ret += file.Length
}
panic("not found")
}
type files []File
func (fs files) Len() int { return len(fs) }
func (fs files) Less(i, j int) bool { return fs[i].String() < fs[j].String() }
func (fs files) Swap(i, j int) { f := fs[i]; fs[i] = fs[j]; fs[j] = f }

138
metainfo/info.go Normal file
View File

@ -0,0 +1,138 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package metainfo
import (
"fmt"
"io"
"os"
"path/filepath"
"sort"
"strings"
)
// Info is the file inforatino.
type Info struct {
// Name is the name of the file in the single file case.
// Or, it is the name of the directory in the muliple file case.
Name string `json:"name" bencode:"name"` // BEP 3
// PieceLength is the number of bytes in each piece, which is usually
// a power of 2.
PieceLength int64 `json:"piece length" bencode:"piece length"` // BEP 3
// Pieces is the concatenation of all 20-byte SHA1 hash values,
// one per piece (byte string, i.e. not urlencoded).
Pieces Hashes `json:"pieces" bencode:"pieces"` // BEP 3
// Length is the length of the file in bytes in the single file case.
//
// It's mutually exclusive with Files.
Length int64 `json:"length,omitempty" bencode:"length,omitempty"` // BEP 3
// Files is the list of all the files in the multi-file case.
//
// For the purposes of the other keys, the multi-file case is treated
// as only having a single file by concatenating the files in the order
// they appear in the files list.
//
// It's mutually exclusive with Length.
Files []File `json:"files,omitempty" bencode:"files,omitempty"` // BEP 3
}
// NewInfoFromFilePath returns a new Info from a file or directory.
func NewInfoFromFilePath(root string, pieceLength int64) (info Info, err error) {
info.Name = filepath.Base(root)
info.PieceLength = pieceLength
err = filepath.Walk(root, func(path string, fi os.FileInfo, err error) error {
if err != nil {
return err
}
if path == root && !fi.IsDir() { // The root is a file.
info.Length = fi.Size()
return nil
}
relPath, err := filepath.Rel(root, path)
if err != nil {
return fmt.Errorf("error getting relative path: %s", err)
}
info.Files = append(info.Files, File{
Paths: strings.Split(relPath, string(filepath.Separator)),
Length: fi.Size(),
})
return nil
})
if err == nil {
sort.Sort(files(info.Files))
info.Pieces, err = GeneratePiecesFromFiles(info.AllFiles(), info.PieceLength,
func(file File) (io.ReadCloser, error) {
if _len := len(file.Paths); _len > 0 {
paths := make([]string, 0, _len+1)
paths = append(paths, root)
paths = append(paths, file.Paths...)
return os.Open(filepath.Join(paths...))
}
return os.Open(root)
})
if err != nil {
err = fmt.Errorf("error generating pieces: %s", err)
}
}
return
}
// IsDir reports whether the name is a directory, that's, the file is not
// a single file.
func (info Info) IsDir() bool { return len(info.Files) != 0 }
// CountPieces returns the number of the pieces.
func (info Info) CountPieces() int { return len(info.Pieces) }
// TotalLength returns the total length of the torrent file.
func (info Info) TotalLength() (ret int64) {
if info.IsDir() {
for _, fi := range info.Files {
ret += fi.Length
}
} else {
ret = info.Length
}
return
}
// Piece returns the Piece by the index starting with 0.
func (info Info) Piece(index int) Piece {
if n := len(info.Pieces); index >= n {
panic(fmt.Errorf("Info.Piece: index '%d' exceeds maximum '%d'", index, n))
}
return Piece{info: info, index: index}
}
// AllFiles returns all the files.
//
// Notice: for the single file, the Path is nil.
func (info Info) AllFiles() []File {
if info.IsDir() {
return info.Files
}
return []File{{Length: info.Length}}
}

207
metainfo/infohash.go Normal file
View File

@ -0,0 +1,207 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package metainfo
import (
"bytes"
"crypto/rand"
"crypto/sha1"
"encoding/base32"
"encoding/hex"
"fmt"
"github.com/xgfone/bt/bencode"
)
var zeroHash Hash
// HashSize is the size of the InfoHash.
const HashSize = 20
// Hash is the 20-byte SHA1 hash used for info and pieces.
type Hash [HashSize]byte
// NewRandomHash returns a random hash.
func NewRandomHash() (h Hash) {
rand.Read(h[:])
return
}
// NewHash converts the 20-bytes to Hash.
func NewHash(b []byte) (h Hash) {
copy(h[:], b[:HashSize])
return
}
// NewHashFromString returns a new Hash from a string.
func NewHashFromString(s string) (h Hash) {
err := h.FromString(s)
if err != nil {
panic(err)
}
return
}
// NewHashFromHexString returns a new Hash from a hex string.
func NewHashFromHexString(s string) (h Hash) {
err := h.FromHexString(s)
if err != nil {
panic(err)
}
return
}
// NewHashFromBytes returns a new Hash from a byte slice.
func NewHashFromBytes(b []byte) (ret Hash) {
hasher := sha1.New()
hasher.Write(b)
copy(ret[:], hasher.Sum(nil))
return
}
// Bytes returns the byte slice type.
func (h Hash) Bytes() []byte {
return h[:]
}
// String is equal to HexString.
func (h Hash) String() string {
return h.HexString()
}
// BytesString returns the bytes string, that's, string(h[:]).
func (h Hash) BytesString() string {
return string(h[:])
}
// HexString returns the hex string format.
func (h Hash) HexString() string {
return hex.EncodeToString(h[:])
}
// IsZero reports whether the whole hash is zero.
func (h Hash) IsZero() bool {
return h == zeroHash
}
// MarshalBencode implements the interface bencode.Marshaler.
func (h Hash) MarshalBencode() (b []byte, err error) {
return bencode.EncodeBytes(h[:])
}
// UnmarshalBencode implements the interface bencode.Unmarshaler.
func (h *Hash) UnmarshalBencode(b []byte) (err error) {
var s string
if err = bencode.NewDecoder(bytes.NewBuffer(b)).Decode(&s); err == nil {
err = h.FromString(s)
}
return
}
// FromString resets the info hash from the string.
func (h *Hash) FromString(s string) (err error) {
switch len(s) {
case HashSize:
copy(h[:], s)
case 2 * HashSize:
err = h.FromHexString(s)
case 32:
var bs []byte
if bs, err = base32.StdEncoding.DecodeString(s); err == nil {
copy(h[:], bs)
}
default:
err = fmt.Errorf("hash string has bad length: %d", len(s))
}
return nil
}
// FromHexString resets the info hash from the hex string.
func (h *Hash) FromHexString(s string) (err error) {
if len(s) != 2*HashSize {
err = fmt.Errorf("hash hex string has bad length: %d", len(s))
return
}
n, err := hex.Decode(h[:], []byte(s))
if err != nil {
return
}
if n != HashSize {
panic(n)
}
return
}
// Xor returns the hash of h XOR o.
func (h Hash) Xor(o Hash) (ret Hash) {
for i := range o {
ret[i] = h[i] ^ o[i]
}
return
}
// Compare returns 0 if h == o, -1 if h < o, or +1 if h > o.
func (h Hash) Compare(o Hash) int { return bytes.Compare(h[:], o[:]) }
/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// Hashes is a set of Hashes.
type Hashes []Hash
// Contains reports whether hs contains h.
func (hs Hashes) Contains(h Hash) bool {
for _, _h := range hs {
if h == _h {
return true
}
}
return false
}
// MarshalBencode implements the interface bencode.Marshaler.
func (hs Hashes) MarshalBencode() ([]byte, error) {
buf := bytes.NewBuffer(nil)
buf.Grow(HashSize * len(hs))
for _, h := range hs {
buf.Write(h[:])
}
return bencode.EncodeBytes(buf.Bytes())
}
// UnmarshalBencode implements the interface bencode.Unmarshaler.
func (hs *Hashes) UnmarshalBencode(b []byte) (err error) {
var bs []byte
if err = bencode.DecodeBytes(b, &bs); err != nil {
return
}
_len := len(bs)
if _len%HashSize != 0 {
return fmt.Errorf("Hashes: invalid bytes length '%d'", _len)
}
hashes := make(Hashes, 0, _len/HashSize)
for i := 0; i < _len; i += HashSize {
var h Hash
copy(h[:], bs[i:i+HashSize])
hashes = append(hashes, h)
}
*hs = hashes
return
}

138
metainfo/magnet.go Normal file
View File

@ -0,0 +1,138 @@
// Mozilla Public License Version 2.0
// Modify from github.com/anacrolix/torrent/metainfo.
package metainfo
import (
"encoding/base32"
"encoding/hex"
"errors"
"fmt"
"net/url"
"strings"
)
// Magnet link components.
type Magnet struct {
InfoHash Hash // From "xt"
Trackers []string // From "tr"
DisplayName string // From "dn" if not empty
Params url.Values // All other values, such as "as", "xs", etc
}
const xtPrefix = "urn:btih:"
// Peers returns the list of the addresses of the peers.
//
// See BEP 9
func (m Magnet) Peers() (peers []HostAddress, err error) {
vs := m.Params["x.pe"]
peers = make([]HostAddress, 0, len(vs))
for _, v := range vs {
if v != "" {
var addr HostAddress
if err = addr.FromString(v); err != nil {
return
}
peers = append(peers, addr)
}
}
return
}
func (m Magnet) String() string {
vs := make(url.Values, len(m.Params)+len(m.Trackers)+2)
for k, v := range m.Params {
vs[k] = append([]string(nil), v...)
}
for _, tr := range m.Trackers {
vs.Add("tr", tr)
}
if m.DisplayName != "" {
vs.Add("dn", m.DisplayName)
}
// Transmission and Deluge both expect "urn:btih:" to be unescaped.
// Deluge wants it to be at the start of the magnet link.
// The InfoHash field is expected to be BitTorrent in this implementation.
u := url.URL{
Scheme: "magnet",
RawQuery: "xt=" + xtPrefix + m.InfoHash.HexString(),
}
if len(vs) != 0 {
u.RawQuery += "&" + vs.Encode()
}
return u.String()
}
// ParseMagnetURI parses Magnet-formatted URIs into a Magnet instance.
func ParseMagnetURI(uri string) (m Magnet, err error) {
u, err := url.Parse(uri)
if err != nil {
err = fmt.Errorf("error parsing uri: %w", err)
return
} else if u.Scheme != "magnet" {
err = fmt.Errorf("unexpected scheme %q", u.Scheme)
return
}
q := u.Query()
xt := q.Get("xt")
if m.InfoHash, err = parseInfohash(q.Get("xt")); err != nil {
err = fmt.Errorf("error parsing infohash %q: %w", xt, err)
return
}
dropFirst(q, "xt")
m.DisplayName = q.Get("dn")
dropFirst(q, "dn")
m.Trackers = q["tr"]
delete(q, "tr")
if len(q) == 0 {
q = nil
}
m.Params = q
return
}
func parseInfohash(xt string) (ih Hash, err error) {
if !strings.HasPrefix(xt, xtPrefix) {
err = errors.New("bad xt parameter prefix")
return
}
var n int
encoded := xt[len(xtPrefix):]
switch len(encoded) {
case 40:
n, err = hex.Decode(ih[:], []byte(encoded))
case 32:
n, err = base32.StdEncoding.Decode(ih[:], []byte(encoded))
default:
err = fmt.Errorf("unhandled xt parameter encoding (encoded length %d)", len(encoded))
return
}
if err != nil {
err = fmt.Errorf("error decoding xt: %w", err)
} else if n != 20 {
panic(fmt.Errorf("invalid length '%d' of the decoded bytes", n))
}
return
}
func dropFirst(vs url.Values, key string) {
sl := vs[key]
switch len(sl) {
case 0, 1:
vs.Del(key)
default:
vs[key] = sl[1:]
}
}

170
metainfo/metainfo.go Normal file
View File

@ -0,0 +1,170 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package metainfo
import (
"errors"
"io"
"os"
"strings"
"github.com/xgfone/bt/bencode"
"github.com/xgfone/bt/utils"
)
// Bytes is the []byte type.
type Bytes = bencode.RawMessage
// AnnounceList is a list of the announces.
type AnnounceList [][]string
// Unique returns the list of the unique announces.
func (al AnnounceList) Unique() (announces []string) {
announces = make([]string, 0, len(al))
for _, tier := range al {
for _, v := range tier {
if v != "" && utils.InStringSlice(announces, v) {
announces = append(announces, v)
}
}
}
return nil
}
// URLList represents a list of the url.
//
// BEP 19
type URLList []string
// FullURL returns the index-th full url.
//
// For the single-file case, name is the "name" of "info".
// For the multi-file case, name is the path "name/path/file"
// from "info" and "files".
//
// See http://bittorrent.org/beps/bep_0019.html
func (us URLList) FullURL(index int, name string) (url string) {
if url = us[index]; strings.HasSuffix(url, "/") {
url += name
}
return
}
// MarshalBencode implements the interface bencode.Marshaler.
func (us URLList) MarshalBencode() (b []byte, err error) {
return bencode.EncodeBytes([]string(us))
}
// UnmarshalBencode implements the interface bencode.Unmarshaler.
func (us *URLList) UnmarshalBencode(b []byte) (err error) {
var v interface{}
if err = bencode.DecodeBytes(b, &v); err == nil {
switch vs := v.(type) {
case string:
*us = URLList{vs}
case []interface{}:
urls := make(URLList, len(vs))
for i, u := range vs {
s, ok := u.(string)
if !ok {
return errors.New("the element of 'url-list' is not string")
}
urls[i] = s
}
*us = urls
default:
return errors.New("invalid 'url-lsit'")
}
}
return
}
// MetaInfo represents the .torrent file.
type MetaInfo struct {
InfoBytes Bytes `bencode:"info"` // BEP 3
Announce string `bencode:"announce,omitempty"` // BEP 3
AnnounceList AnnounceList `bencode:"announce-list,omitempty"` // BEP 12
Nodes []HostAddress `bencode:"nodes,omitempty"` // BEP 5
URLList URLList `bencode:"url-list,omitempty"` // BEP 19
// Where's this specified?
// Mentioned at https://wiki.theory.org/index.php/BitTorrentSpecification.
// All of them are optional.
// CreationDate is the creation time of the torrent, in standard UNIX epoch
// format (seconds since 1-Jan-1970 00:00:00 UTC).
CreationDate int64 `bencode:"creation date,omitempty"`
// Comment is the free-form textual comments of the author.
Comment string `bencode:"comment,omitempty"`
// CreatedBy is name and version of the program used to create the .torrent.
CreatedBy string `bencode:"created by,omitempty"`
// Encoding is the string encoding format used to generate the pieces part
// of the info dictionary in the .torrent metafile.
Encoding string `bencode:"encoding,omitempty"`
}
// Load loads a MetaInfo from an io.Reader.
func Load(r io.Reader) (mi MetaInfo, err error) {
err = bencode.NewDecoder(r).Decode(&mi)
return
}
// LoadFromFile loads a MetaInfo from a file.
func LoadFromFile(filename string) (mi MetaInfo, err error) {
f, err := os.Open(filename)
if err == nil {
defer f.Close()
mi, err = Load(f)
}
return
}
// Announces returns all the announces.
func (mi MetaInfo) Announces() AnnounceList {
if len(mi.AnnounceList) > 0 {
return mi.AnnounceList
} else if mi.Announce != "" {
return [][]string{{mi.Announce}}
}
return nil
}
// Magnet creates a Magnet from a MetaInfo.
func (mi *MetaInfo) Magnet(displayName string, infoHash Hash) (m Magnet) {
for _, t := range mi.Announces().Unique() {
m.Trackers = append(m.Trackers, t)
}
m.DisplayName = displayName
m.InfoHash = infoHash
return
}
// Write encodes the metainfo to w.
func (mi MetaInfo) Write(w io.Writer) error {
return bencode.NewEncoder(w).Encode(mi)
}
// InfoHash returns the hash of the info.
func (mi MetaInfo) InfoHash() Hash {
return NewHashFromBytes(mi.InfoBytes)
}
// Info parses the InfoBytes to the Info.
func (mi MetaInfo) Info() (info Info, err error) {
err = bencode.DecodeBytes(mi.InfoBytes, &info)
return
}

99
metainfo/piece.go Normal file
View File

@ -0,0 +1,99 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package metainfo
import (
"crypto/sha1"
"errors"
"fmt"
"io"
"github.com/xgfone/bt/utils"
)
// Piece represents a torrent file piece.
type Piece struct {
info Info
index int
}
// Index returns the index of the current piece.
func (p Piece) Index() int { return p.index }
// Offset returns the offset that the current piece is in all the files.
func (p Piece) Offset() int64 { return int64(p.index) * p.info.PieceLength }
// Hash returns the hash representation of the piece.
func (p Piece) Hash() (h Hash) { return p.info.Pieces[p.index] }
// Length returns the length of the current piece.
func (p Piece) Length() int64 {
if p.index == p.info.CountPieces()-1 {
return p.info.TotalLength() - int64(p.index)*p.info.PieceLength
}
return p.info.PieceLength
}
// GeneratePieces generates the pieces from the reader.
func GeneratePieces(r io.Reader, pieceLength int64) (hs Hashes, err error) {
buf := make([]byte, pieceLength)
for {
h := sha1.New()
written, err := utils.CopyNBuffer(h, r, pieceLength, buf)
if written > 0 {
hs = append(hs, NewHash(h.Sum(nil)))
}
if err == io.EOF {
return hs, nil
}
if err != nil {
return nil, err
}
}
}
func writeFiles(w io.Writer, files []File, open func(File) (io.ReadCloser, error)) error {
buf := make([]byte, 8192)
for _, file := range files {
r, err := open(file)
if err != nil {
return fmt.Errorf("error opening %s: %s", file, err)
}
n, err := utils.CopyNBuffer(w, r, file.Length, buf)
r.Close()
if n != file.Length {
return fmt.Errorf("error copying %s: %s", file, err)
}
}
return nil
}
// GeneratePiecesFromFiles generates the pieces from the files.
func GeneratePiecesFromFiles(files []File, pieceLength int64,
open func(File) (io.ReadCloser, error)) (Hashes, error) {
if pieceLength <= 0 {
return nil, errors.New("piece length must be a positive integer")
}
pr, pw := io.Pipe()
defer pr.Close()
go func() { pw.CloseWithError(writeFiles(pw, files, open)) }()
return GeneratePieces(pr, pieceLength)
}

19
peerprotocol/doc.go Normal file
View File

@ -0,0 +1,19 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package peerprotocol implements the core BT peer protocol. You can use it
// to implement a peer client, and upload/download the file to/from other peers.
//
// See TorrentDownloader to download the torrent file from other peer.
package peerprotocol

161
peerprotocol/extension.go Normal file
View File

@ -0,0 +1,161 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package peerprotocol
import (
"bytes"
"errors"
"net"
"github.com/xgfone/bt/bencode"
)
var errInvalidIP = errors.New("invalid ipv4 or ipv6")
// Predefine some extended message identifiers.
const (
ExtendedIDHandshake = 0 // BEP 10
)
// Predefine some extended message names.
const (
ExtendedMessageNameMetadata = "ut_metadata" // BEP 9
ExtendedMessageNamePex = "ut_pex" // BEP 11
)
// Predefine some "ut_metadata" extended message types.
const (
UtMetadataExtendedMsgTypeRequest = 0 // BEP 9
UtMetadataExtendedMsgTypeData = 1 // BEP 9
UtMetadataExtendedMsgTypeReject = 2 // BEP 9
)
// CompactIP is used to handle the compact ipv4 or ipv6.
type CompactIP net.IP
func (ci CompactIP) String() string {
return net.IP(ci).String()
}
// MarshalBencode implements the interface bencode.Marshaler.
func (ci CompactIP) MarshalBencode() ([]byte, error) {
ip := net.IP(ci)
if ipv4 := ip.To4(); len(ipv4) != 0 {
ip = ipv4
}
return bencode.EncodeBytes(ip[:])
}
// UnmarshalBencode implements the interface bencode.Unmarshaler.
func (ci *CompactIP) UnmarshalBencode(b []byte) (err error) {
var ip net.IP
if err = bencode.DecodeBytes(b, &ip); err != nil {
return
}
switch len(ip) {
case net.IPv4len, net.IPv6len:
default:
return errInvalidIP
}
if ipv4 := ip.To4(); len(ipv4) != 0 {
ip = ipv4
}
*ci = CompactIP(ip)
return
}
// ExtendedHandshakeMsg represent the extended handshake message.
//
// BEP 10
type ExtendedHandshakeMsg struct {
// M is the type of map[ExtendedMessageName]ExtendedMessageID.
M map[string]uint8 `bencode:"m"` // BEP 10
V string `bencode:"v,omitempty"` // BEP 10
Reqq int `bencode:"reqq,omitempty"` // BEP 10
// Port is the local client port, which is redundant and no need
// for the receiving side of the connection to send this.
Port uint16 `bencode:"p,omitempty"` // BEP 10
IPv6 net.IP `bencode:"ipv6,omitempty"` // BEP 10
IPv4 CompactIP `bencode:"ipv4,omitempty"` // BEP 10
YourIP CompactIP `bencode:"yourip,omitempty"` // BEP 10
MetadataSize int `bencode:"metadata_size,omitempty"` // BEP 9
}
// Decode decodes the extended handshake message from b.
func (ehm *ExtendedHandshakeMsg) Decode(b []byte) (err error) {
return bencode.DecodeBytes(b, ehm)
}
// Encode encodes the extended handshake message to b.
func (ehm ExtendedHandshakeMsg) Encode() (b []byte, err error) {
buf := bytes.NewBuffer(make([]byte, 0, 128))
if err = bencode.NewEncoder(buf).Encode(ehm); err != nil {
b = buf.Bytes()
}
return
}
// UtMetadataExtendedMsg represents the "ut_metadata" extended message.
type UtMetadataExtendedMsg struct {
MsgType uint8 `bencode:"msg_type"` // BEP 9
Piece int `bencode:"piece"` // BEP 9
// They are only used by "data" type
TotalSize int `bencode:"total_size,omitempty"` // BEP 9
Data []byte `bencode:"-"`
}
// EncodeToPayload encodes UtMetadataExtendedMsg to extended payload
// and write the result into buf.
func (um UtMetadataExtendedMsg) EncodeToPayload(buf *bytes.Buffer) (err error) {
if um.MsgType != UtMetadataExtendedMsgTypeData {
um.TotalSize = 0
um.Data = nil
}
buf.Grow(len(um.Data) + 50)
if err = bencode.NewEncoder(buf).Encode(um); err == nil {
_, err = buf.Write(um.Data)
}
return
}
// EncodeToBytes is equal to
//
// buf := new(bytes.Buffer)
// err = um.EncodeToPayload(buf)
// return buf.Bytes(), err
//
func (um UtMetadataExtendedMsg) EncodeToBytes() (b []byte, err error) {
buf := bytes.NewBuffer(make([]byte, 0, 128))
if err = um.EncodeToPayload(buf); err == nil {
b = buf.Bytes()
}
return
}
// DecodeFromPayload decodes the extended payload to itself.
func (um *UtMetadataExtendedMsg) DecodeFromPayload(b []byte) (err error) {
dec := bencode.NewDecoder(bytes.NewReader(b))
if err = dec.Decode(&um); err == nil {
um.Data = b[dec.BytesParsed():]
}
return
}

View File

@ -0,0 +1,54 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package peerprotocol
import (
"bytes"
"testing"
)
func TestCompactIP(t *testing.T) {
ipv4 := CompactIP([]byte{1, 2, 3, 4})
b, err := ipv4.MarshalBencode()
if err != nil {
t.Fatal(err)
}
var ip CompactIP
if err = ip.UnmarshalBencode(b); err != nil {
t.Error(err)
} else if ip.String() != "1.2.3.4" {
t.Error(ip)
}
}
func TestUtMetadataExtendedMsg(t *testing.T) {
buf := new(bytes.Buffer)
data := []byte{0x31, 0x32, 0x33, 0x34, 0x35}
m1 := UtMetadataExtendedMsg{MsgType: 1, Piece: 2, TotalSize: 1024, Data: data}
if err := m1.EncodeToPayload(buf); err != nil {
t.Fatal(err)
}
msg := Message{Type: Extended, ExtendedPayload: buf.Bytes()}
m2, err := msg.UtMetadataExtendedMsg()
if err != nil {
t.Fatal(err)
} else if m2.MsgType != 1 || m2.Piece != 2 || m2.TotalSize != 1024 {
t.Error(m2)
} else if !bytes.Equal(m2.Data, data) {
t.Fail()
}
}

69
peerprotocol/fastset.go Normal file
View File

@ -0,0 +1,69 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package peerprotocol
import (
"crypto/sha1"
"encoding/binary"
"net"
"github.com/xgfone/bt/metainfo"
)
// GenerateAllowedFastSet generates some allowed fast set of the torrent file.
//
// Argument:
// set: generated piece set, the length of which is the number to be generated.
// sz: the number of pieces in torrent.
// ip: the of the remote peer of the connection.
// infohash: infohash of torrent.
//
// BEP 6
func GenerateAllowedFastSet(set []uint32, sz uint32, ip net.IP, infohash metainfo.Hash) {
if ipv4 := ip.To4(); ipv4 != nil {
ip = ipv4
}
iplen := len(ip)
x := make([]byte, 20+iplen)
for i, j := 0, iplen-1; i < j; i++ { // (1) compatible with IPv4/IPv6
x[i] = ip[i] & 0xff // (1)
}
// x[iplen-1] = 0 // It is equal to 0 primitively.
copy(x[iplen:], infohash[:]) // (2)
for cur, k := 0, len(set); cur < k; {
sum := sha1.Sum(x) // (3)
x = sum[:] // (3)
for i := 0; i < 5 && cur < k; i++ { // (4)
j := i * 4 // (5)
y := binary.BigEndian.Uint32(x[j : j+4]) // (6)
index := y % sz // (7)
if !uint32Contains(set, index) { // (8)
set[cur] = index // (9)
cur++
}
}
}
}
func uint32Contains(ss []uint32, s uint32) bool {
for _, v := range ss {
if v == s {
return true
}
}
return false
}

View File

@ -0,0 +1,109 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package peerprotocol
import (
"net"
"testing"
"github.com/xgfone/bt/metainfo"
)
func TestGenerateAllowedFastSet(t *testing.T) {
hexs := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
infohash := metainfo.NewHashFromHexString(hexs)
sets := make([]uint32, 7)
GenerateAllowedFastSet(sets, 1313, net.ParseIP("80.4.4.200"), infohash)
for i, v := range sets {
switch i {
case 0:
if v != 1059 {
t.Errorf("unknown '%d' at the index '%d'", v, i)
}
case 1:
if v != 431 {
t.Errorf("unknown '%d' at the index '%d'", v, i)
}
case 2:
if v != 808 {
t.Errorf("unknown '%d' at the index '%d'", v, i)
}
case 3:
if v != 1217 {
t.Errorf("unknown '%d' at the index '%d'", v, i)
}
case 4:
if v != 287 {
t.Errorf("unknown '%d' at the index '%d'", v, i)
}
case 5:
if v != 376 {
t.Errorf("unknown '%d' at the index '%d'", v, i)
}
case 6:
if v != 1188 {
t.Errorf("unknown '%d' at the index '%d'", v, i)
}
default:
t.Errorf("unknown '%d' at the index '%d'", v, i)
}
}
sets = make([]uint32, 9)
GenerateAllowedFastSet(sets, 1313, net.ParseIP("80.4.4.200"), infohash)
for i, v := range sets {
switch i {
case 0:
if v != 1059 {
t.Errorf("unknown '%d' at the index '%d'", v, i)
}
case 1:
if v != 431 {
t.Errorf("unknown '%d' at the index '%d'", v, i)
}
case 2:
if v != 808 {
t.Errorf("unknown '%d' at the index '%d'", v, i)
}
case 3:
if v != 1217 {
t.Errorf("unknown '%d' at the index '%d'", v, i)
}
case 4:
if v != 287 {
t.Errorf("unknown '%d' at the index '%d'", v, i)
}
case 5:
if v != 376 {
t.Errorf("unknown '%d' at the index '%d'", v, i)
}
case 6:
if v != 1188 {
t.Errorf("unknown '%d' at the index '%d'", v, i)
}
case 7:
if v != 353 {
t.Errorf("unknown '%d' at the index '%d'", v, i)
}
case 8:
if v != 508 {
t.Errorf("unknown '%d' at the index '%d'", v, i)
}
default:
t.Errorf("unknown '%d' at the index '%d'", v, i)
}
}
}

140
peerprotocol/handshake.go Normal file
View File

@ -0,0 +1,140 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package peerprotocol
import (
"encoding/hex"
"fmt"
"io"
"github.com/xgfone/bt/metainfo"
)
var errInvalidProtocolHeader = fmt.Errorf("unexpected peer protocol header string")
// Predefine some known extension bits.
const (
ExtensionBitDHT = 0 // BEP 5
ExtensionBitFast = 2 // BEP 6
ExtensionBitExtended = 20 // BEP 10
ExtensionBitMax = 64
)
// ExtensionBits is the reserved bytes to be used by all extensions.
//
// BEP 10: The bit is counted starting at 0 from right to left.
type ExtensionBits [8]byte
// String returns the hex string format.
func (eb ExtensionBits) String() string {
return hex.EncodeToString(eb[:])
}
// Set sets the bit to 1, that's, to set it to be on.
func (eb ExtensionBits) Set(bit uint) {
eb[7-bit/8] |= 1 << (bit % 8)
}
// Unset sets the bit to 0, that's, to set it to be off.
func (eb ExtensionBits) Unset(bit uint) {
eb[7-bit/8] &^= 1 << (bit % 8)
}
// IsSet reports whether the bit is on.
func (eb ExtensionBits) IsSet(bit uint) (yes bool) {
return eb[7-bit/8]&(1<<(bit%8)) != 0
}
// IsSupportDHT reports whether ExtensionBitDHT is set.
func (eb ExtensionBits) IsSupportDHT() (yes bool) {
return eb.IsSet(ExtensionBitDHT)
}
// IsSupportFast reports whether ExtensionBitFast is set.
func (eb ExtensionBits) IsSupportFast() (yes bool) {
return eb.IsSet(ExtensionBitFast)
}
// IsSupportExtended reports whether ExtensionBitExtended is set.
func (eb ExtensionBits) IsSupportExtended() (yes bool) {
return eb.IsSet(ExtensionBitExtended)
}
// HandshakeMsg is the message used by the handshake
type HandshakeMsg struct {
ExtensionBits
PeerID metainfo.Hash
InfoHash metainfo.Hash
}
// NewHandshakeMsg returns a new HandshakeMsg.
func NewHandshakeMsg(peerID, infoHash metainfo.Hash, es ...ExtensionBits) HandshakeMsg {
var e ExtensionBits
if len(es) > 0 {
e = es[0]
}
return HandshakeMsg{ExtensionBits: e, PeerID: peerID, InfoHash: infoHash}
}
// Handshake finishes the handshake with the peer.
//
// InfoHash may be ZERO, and it will read it from the peer then send it back
// to the peer.
//
// BEP 3
func Handshake(sock io.ReadWriter, msg HandshakeMsg) (ret HandshakeMsg, err error) {
var read bool
if msg.InfoHash.IsZero() {
if err = getPeerHandshakeMsg(sock, &ret); err != nil {
return
}
read = true
msg.InfoHash = ret.InfoHash
}
if _, err = io.WriteString(sock, ProtocolHeader); err != nil {
return
}
if _, err = sock.Write(msg.ExtensionBits[:]); err != nil {
return
}
if _, err = sock.Write(msg.InfoHash[:]); err != nil {
return
}
if _, err = sock.Write(msg.PeerID[:]); err != nil {
return
}
if !read {
err = getPeerHandshakeMsg(sock, &ret)
}
return
}
func getPeerHandshakeMsg(sock io.Reader, ret *HandshakeMsg) (err error) {
var b [68]byte
if _, err = io.ReadFull(sock, b[:]); err != nil {
return
} else if string(b[:20]) != ProtocolHeader {
return errInvalidProtocolHeader
}
copy(ret.ExtensionBits[:], b[20:28])
copy(ret.InfoHash[:], b[28:48])
copy(ret.PeerID[:], b[48:68])
return
}

277
peerprotocol/message.go Normal file
View File

@ -0,0 +1,277 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package peerprotocol
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
)
var errMessageTooLong = fmt.Errorf("the peer message is too long")
// Message is the message used by the peer protocol, which contains
// all the fields specified by the standard message types.
type Message struct {
Keepalive bool
Type MessageType
// Index is used by these message types:
//
// BEP 3: Cancel, Request, Have, Piece
// BEP 6: Reject, Suggest, AllowedFast
//
Index uint32
// Begin is used by these message types:
//
// BEP 3: Request, Cancel, Piece
// BEP 6: Reject
//
Begin uint32
// Length is used by these message types:
//
// BEP 3: Request, Cancel
// BEP 6: Reject
//
Length uint32
// Piece is used by these message types:
//
// BEP 3: Piece
Piece []byte
// Bitfield is used by these message types:
//
// BEP 3: Bitfield
Bitfield []bool
// ExtendedID and ExtendedPayload are used by these message types:
//
// BEP 10: Extended
//
ExtendedID uint8
ExtendedPayload []byte
// Port is used by these message types:
//
// BEP 5: Port
//
Port uint16
}
// DecodeToMessage is equal to msg.Decode(r, maxLength).
func DecodeToMessage(r io.Reader, maxLength uint32) (msg Message, err error) {
err = msg.Decode(r, maxLength)
return
}
// UtMetadataExtendedMsg decodes the extended payload as UtMetadataExtendedMsg.
//
// Notice: the message type must be Extended.
func (m Message) UtMetadataExtendedMsg() (um UtMetadataExtendedMsg, err error) {
if m.Type != Extended {
panic("the message type is Extended")
}
err = um.DecodeFromPayload(m.ExtendedPayload)
return
}
// UnmarshalBinary implements the interface encoding.BinaryUnmarshaler,
// which is equal to m.Decode(bytes.NewBuffer(data), 0).
func (m *Message) UnmarshalBinary(data []byte) (err error) {
return m.Decode(bytes.NewBuffer(data), 0)
}
func readByte(r io.Reader) (b byte, err error) {
var bs [1]byte
if _, err = r.Read(bs[:]); err == nil {
b = bs[0]
}
return
}
// Decode reads the data from r and decodes it to Message.
//
// if maxLength is equal to 0, it is unlimited. Or, it will read maxLength bytes
// at most.
func (m *Message) Decode(r io.Reader, maxLength uint32) (err error) {
var length uint32
if err = binary.Read(r, binary.BigEndian, &length); err != nil {
if err != io.EOF {
err = fmt.Errorf("error reading peer message message length: %s", err)
}
return
}
if length == 0 {
m.Keepalive = true
return
} else if maxLength > 0 && length > maxLength {
return errMessageTooLong
}
m.Keepalive = false
lr := &io.LimitedReader{R: r, N: int64(length)}
// Check that all of r was utilized.
defer func() {
if err == nil && lr.N != 0 {
err = fmt.Errorf("%d bytes unused in message type %d", lr.N, m.Type)
}
}()
_type, err := readByte(lr)
if err != nil {
return
}
switch m.Type = MessageType(_type); m.Type {
case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
case Have, AllowedFast, Suggest:
err = binary.Read(lr, binary.BigEndian, &m.Index)
case Request, Cancel, Reject:
if err = binary.Read(lr, binary.BigEndian, &m.Index); err != nil {
return
}
if err = binary.Read(lr, binary.BigEndian, &m.Begin); err != nil {
return
}
if err = binary.Read(lr, binary.BigEndian, &m.Length); err != nil {
return
}
case Bitfield:
_len := length - 1
bs := make([]byte, _len)
if _, err = io.ReadFull(lr, bs); err == nil {
m.Bitfield = make([]bool, 0, _len*8)
for _, b := range bs {
for i := byte(7); i >= 0; i-- {
m.Bitfield = append(m.Bitfield, (b>>i)&1 == 1)
}
}
}
case Piece:
if err = binary.Read(lr, binary.BigEndian, &m.Index); err != nil {
return
}
if err = binary.Read(lr, binary.BigEndian, &m.Begin); err != nil {
return
}
// TODO: Should we use a []byte pool?
m.Piece = make([]byte, lr.N)
if _, err = io.ReadFull(lr, m.Piece); err != nil {
return fmt.Errorf("reading piece data error: %s", err)
}
case Extended:
if m.ExtendedID, err = readByte(lr); err == nil {
m.ExtendedPayload, err = ioutil.ReadAll(lr)
}
case Port:
err = binary.Read(lr, binary.BigEndian, &m.Port)
default:
err = fmt.Errorf("unknown message type %v", m.Type)
}
return
}
// MarshalBinary implements the interface encoding.BinaryMarshaler.
func (m Message) MarshalBinary() (data []byte, err error) {
// TODO: Should we use a buffer pool?
buf := bytes.NewBuffer(make([]byte, 0, 4))
if err = m.Encode(buf); err == nil {
data = buf.Bytes()
}
return
}
// Encode encodes the message to buf.
func (m Message) Encode(buf *bytes.Buffer) (err error) {
// The 4-bytes is the placeholder of the length.
buf.Reset()
buf.Write([]byte{0, 0, 0, 0})
// Write the non-keepalive message.
if !m.Keepalive {
if err = buf.WriteByte(byte(m.Type)); err != nil {
return
} else if err = m.marshalBinaryType(buf); err != nil {
return
}
// Calculate and reset the length of the message body.
data := buf.Bytes()
if payloadLen := len(data) - 4; payloadLen > 0 {
binary.BigEndian.PutUint32(data[:4], uint32(payloadLen))
}
}
return
}
func (m Message) marshalBinaryType(buf *bytes.Buffer) (err error) {
switch m.Type {
case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
case Have:
err = binary.Write(buf, binary.BigEndian, m.Index)
case Request, Cancel, Reject:
if err = binary.Write(buf, binary.BigEndian, m.Index); err != nil {
return
}
if err = binary.Write(buf, binary.BigEndian, m.Begin); err != nil {
return
}
if err = binary.Write(buf, binary.BigEndian, m.Length); err != nil {
return
}
case Bitfield:
_len := (len(m.Bitfield) + 7) / 8
bs := make([]byte, _len)
for i, has := range m.Bitfield {
if has {
bs[i/8] |= (1 << byte(7-i%8))
}
}
buf.Write(bs)
case Piece:
if err = binary.Write(buf, binary.BigEndian, m.Index); err != nil {
return
}
if err = binary.Write(buf, binary.BigEndian, m.Begin); err != nil {
return
}
if n, err := buf.Write(m.Piece); err != nil {
return err
} else if _len := len(m.Piece); n != _len {
return fmt.Errorf("expect writing %d bytes, but wrote %d", _len, n)
}
case Extended:
if err = buf.WriteByte(byte(m.ExtendedID)); err != nil {
_, err = buf.Write(m.ExtendedPayload)
}
case Port:
err = binary.Write(buf, binary.BigEndian, m.Port)
default:
err = fmt.Errorf("unknown message type: %v", m.Type)
}
return
}

522
peerprotocol/peerconn.go Normal file
View File

@ -0,0 +1,522 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package peerprotocol
import (
"bytes"
"fmt"
"io"
"net"
"time"
"github.com/xgfone/bt/bencode"
"github.com/xgfone/bt/metainfo"
)
// Predefine some errors about extension support.
var (
ErrNotSupportDHT = fmt.Errorf("not support DHT extension")
ErrNotSupportFast = fmt.Errorf("not support Fast extension")
ErrNotSupportExtended = fmt.Errorf("not support Extended extension")
)
// Bep3Handler is used to handle the BEP 3 type message if Handler has also
// implemented the interface.
type Bep3Handler interface {
Choke(pc *PeerConn) error
Unchoke(pc *PeerConn) error
Interested(pc *PeerConn) error
NotInterested(pc *PeerConn) error
Have(pc *PeerConn, index uint32) error
Bitfield(pc *PeerConn, bits []bool) error
Request(pc *PeerConn, index uint32, begin uint32, length uint32) error
Piece(pc *PeerConn, index uint32, begin uint32, piece []byte) error
Cancel(pc *PeerConn, index uint32, begin uint32, length uint32) error
}
// Bep5Handler is used to handle the BEP 5 type message if Handler has also
// implemented the interface.
//
// Notice: the server must enable the DHT extension bit.
type Bep5Handler interface {
Port(pc *PeerConn, port uint16) error
}
// Bep6Handler is used to handle the BEP 6 type message if Handler has also
// implemented the interface.
//
// Notice: the server must enable the Fast extension bit.
type Bep6Handler interface {
HaveAll(pc *PeerConn) error
HaveNone(pc *PeerConn) error
Suggest(pc *PeerConn, index uint32) error
AllowedFast(pc *PeerConn, index uint32) error
Reject(pc *PeerConn, index uint32, begin uint32, length uint32) error
}
// Bep10Handler is used to handle the BEP 10 extended peer message
// if Handler has also implemented the interface.
//
// Notice: the server must enable the Extended extension bit.
type Bep10Handler interface {
OnHandShake(conn *PeerConn, exthmsg ExtendedHandshakeMsg) error
OnPayload(conn *PeerConn, extid uint8, payload []byte) error
}
// PeerConn is used to manage the connection to the peer.
type PeerConn struct {
net.Conn
ExtensionBits
ID metainfo.Hash
// Timeout is used to control the timeout of reading/writing the message.
//
// The default is 0, which represents no timeout.
Timeout time.Duration
// MaxLength is used to limit the maximum number of the message body.
//
// The default is 0, which represents no limit.
MaxLength uint32
// These two states is controlled by the local client peer.
//
// Choked is used to indicate whether or not the local client has choked
// the remote peer, that's, if Choked is true, the local client will
// discard all the pending requests from the remote peer and not answer
// any requests until the local client is unchoked.
//
// Interested is used to indicate whether or not the local client is
// interested in something the remote peer has to offer, that's,
// if Interested is true, the local client will begin requesting blocks
// when the remote client unchokes them.
//
Choked bool // The default should be true.
Interested bool
// These two states is controlled by the remote peer.
//
// PeerChoked is used to indicate whether or not the remote peer has choked
// the local client, that's, if PeerChoked is true, the remote peer will
// discard all the pending requests from the local client and not answer
// any requests until the remote peer is unchoked.
//
// PeerInterested is used to indicate whether or not the remote peer is
// interested in something the local client has to offer, that's,
// if PeerInterested is true, the remote peer will begin requesting blocks
// when the local client unchokes them.
//
PeerChoked bool // The default should be true.
PeerInterested bool
// Data is used to store the context data associated with the connection.
Data interface{}
}
// NewPeerConn returns a new PeerConn.
func NewPeerConn(id metainfo.Hash, conn net.Conn) *PeerConn {
return &PeerConn{ID: id, Conn: conn, Choked: true, PeerChoked: true}
}
// NewPeerConnByDial returns a new PeerConn by dialing to addr with the "tcp" network.
func NewPeerConnByDial(id metainfo.Hash, addr string) (pc *PeerConn, err error) {
conn, err := net.Dial("tcp", addr)
if err == nil {
pc = NewPeerConn(id, conn)
}
return
}
func (pc *PeerConn) setReadTimeout() {
if pc.Timeout > 0 {
pc.Conn.SetReadDeadline(time.Now().Add(pc.Timeout))
}
}
func (pc *PeerConn) setWriteTimeout() {
if pc.Timeout > 0 {
pc.Conn.SetWriteDeadline(time.Now().Add(pc.Timeout))
}
}
// SetChoked sets the Choked state of the local client peer.
//
// Notice: if the current state is not Choked, it will send a Choked message
// to the remote peer.
func (pc *PeerConn) SetChoked() (err error) {
if !pc.Choked {
if err = pc.SendChoke(); err == nil {
pc.Choked = true
}
}
return
}
// SetUnchoked sets the Unchoked state of the local client peer.
//
// Notice: if the current state is not Unchoked, it will send a Unchoked message
// to the remote peer.
func (pc *PeerConn) SetUnchoked() (err error) {
if pc.Choked {
if err = pc.SendUnchoke(); err == nil {
pc.Choked = false
}
}
return
}
// SetInterested sets the Interested state of the local client peer.
//
// Notice: if the current state is not Interested, it will send a Interested
// message to the remote peer.
func (pc *PeerConn) SetInterested() (err error) {
if !pc.Interested {
if err = pc.SendInterested(); err == nil {
pc.Interested = true
}
}
return
}
// SetNotInterested sets the NotInterested state of the local client peer.
//
// Notice: if the current state is not NotInterested, it will send
// a NotInterested message to the remote peer.
func (pc *PeerConn) SetNotInterested() (err error) {
if pc.Interested {
if err = pc.SendNotInterested(); err == nil {
pc.Interested = false
}
}
return
}
// Handshake does a handshake with the peer.
//
// BEP 3
func (pc *PeerConn) Handshake(infoHash metainfo.Hash) (HandshakeMsg, error) {
m := HandshakeMsg{ExtensionBits: pc.ExtensionBits, PeerID: pc.ID, InfoHash: infoHash}
pc.setReadTimeout()
return Handshake(pc.Conn, m)
}
// ReadMsg reads the message.
//
// BEP 3
func (pc *PeerConn) ReadMsg() (m Message, err error) {
pc.setReadTimeout()
err = m.Decode(pc.Conn, pc.MaxLength)
return
}
// WriteMsg writes the message to the peer.
//
// BEP 3
func (pc *PeerConn) WriteMsg(m Message) (err error) {
buf := bytes.NewBuffer(make([]byte, 0, 128))
if err = m.Encode(buf); err == nil {
pc.setWriteTimeout()
var n int
if n, err = pc.Conn.Write(buf.Bytes()); err == nil && n < buf.Len() {
err = io.ErrShortWrite
}
}
return
}
// SendKeepalive sends a Keepalive message to the peer.
//
// BEP 3
func (pc *PeerConn) SendKeepalive() error {
return pc.WriteMsg(Message{Keepalive: true})
}
// SendChoke sends a Choke message to the peer.
//
// BEP 3
func (pc *PeerConn) SendChoke() error {
return pc.WriteMsg(Message{Type: Choke})
}
// SendUnchoke sends a Unchoke message to the peer.
//
// BEP 3
func (pc *PeerConn) SendUnchoke() error {
return pc.WriteMsg(Message{Type: Unchoke})
}
// SendInterested sends a Interested message to the peer.
//
// BEP 3
func (pc *PeerConn) SendInterested() error {
return pc.WriteMsg(Message{Type: Interested})
}
// SendNotInterested sends a NotInterested message to the peer.
//
// BEP 3
func (pc *PeerConn) SendNotInterested() error {
return pc.WriteMsg(Message{Type: NotInterested})
}
// SendBitfield sends a Bitfield message to the peer.
//
// BEP 3
func (pc *PeerConn) SendBitfield(bits []bool) error {
return pc.WriteMsg(Message{Type: Bitfield, Bitfield: bits})
}
// SendHave sends a Have message to the peer.
//
// BEP 3
func (pc *PeerConn) SendHave(index uint32) error {
return pc.WriteMsg(Message{Type: Have, Index: index})
}
// SendRequest sends a Request message to the peer.
//
// BEP 3
func (pc *PeerConn) SendRequest(index, begin, length uint32) error {
return pc.WriteMsg(Message{Type: Request, Index: index, Begin: begin, Length: length})
}
// SendCancel sends a Cancel message to the peer.
//
// BEP 3
func (pc *PeerConn) SendCancel(index, begin, length uint32) error {
return pc.WriteMsg(Message{Type: Cancel, Index: index, Begin: begin, Length: length})
}
// SendPiece sends a Piece message to the peer.
//
// BEP 3
func (pc *PeerConn) SendPiece(index, begin uint32, piece []byte) error {
return pc.WriteMsg(Message{Type: Piece, Index: index, Begin: begin, Piece: piece})
}
// SendPort sends a Port message to the peer.
//
// BEP 5
func (pc *PeerConn) SendPort(port uint16) error {
return pc.WriteMsg(Message{Type: Port, Port: port})
}
// SendHaveAll sends a HaveAll message to the peer.
//
// BEP 6
func (pc *PeerConn) SendHaveAll() error {
return pc.WriteMsg(Message{Type: HaveAll})
}
// SendHaveNone sends a HaveNone message to the peer.
//
// BEP 6
func (pc *PeerConn) SendHaveNone() error {
return pc.WriteMsg(Message{Type: HaveNone})
}
// SendSuggest sends a Suggest message to the peer.
//
// BEP 6
func (pc *PeerConn) SendSuggest(index uint32) error {
return pc.WriteMsg(Message{Type: Suggest, Index: index})
}
// SendReject sends a Reject message to the peer.
//
// BEP 6
func (pc *PeerConn) SendReject(index, begin, length uint32) error {
return pc.WriteMsg(Message{Type: Reject, Index: index, Begin: begin, Length: length})
}
// SendAllowedFast sends a AllowedFast message to the peer.
//
// BEP 6
func (pc *PeerConn) SendAllowedFast(index uint32) error {
return pc.WriteMsg(Message{Type: AllowedFast, Index: index})
}
// SendExtHandshakeMsg sends the Extended Handshake message to the peer.
//
// BEP 10
func (pc *PeerConn) SendExtHandshakeMsg(m ExtendedHandshakeMsg) (err error) {
buf := bytes.NewBuffer(make([]byte, 0, 128))
if err = bencode.NewEncoder(buf).Encode(m); err == nil {
err = pc.SendExtMsg(ExtendedIDHandshake, buf.Bytes())
}
return
}
// SendExtMsg sends the Extended message with the extended id and the payload.
//
// BEP 10
func (pc *PeerConn) SendExtMsg(extID uint8, payload []byte) error {
return pc.WriteMsg(Message{Type: Extended, ExtendedID: extID, ExtendedPayload: payload})
}
// HandleMessage calls the method of the handler to handle the message.
//
// If handler has also implemented the interfaces Bep3Handler, Bep5Handler,
// Bep6Handler or Bep10Handler, their methods will be called instead of
// Handler.OnMessage for the corresponding type message.
func (pc *PeerConn) HandleMessage(msg Message, handler Handler) (err error) {
if msg.Keepalive {
return
}
switch msg.Type {
// BEP 3 - The BitTorrent Protocol Specification
case Choke:
pc.PeerChoked = true
if h, ok := handler.(Bep3Handler); ok {
err = h.Choke(pc)
} else {
err = handler.OnMessage(pc, msg)
}
case Unchoke:
pc.PeerChoked = false
if h, ok := handler.(Bep3Handler); ok {
err = h.Unchoke(pc)
} else {
err = handler.OnMessage(pc, msg)
}
case Interested:
pc.PeerInterested = true
if h, ok := handler.(Bep3Handler); ok {
err = h.Interested(pc)
} else {
err = handler.OnMessage(pc, msg)
}
case NotInterested:
pc.PeerInterested = false
if h, ok := handler.(Bep3Handler); ok {
err = h.NotInterested(pc)
} else {
err = handler.OnMessage(pc, msg)
}
case Have:
if h, ok := handler.(Bep3Handler); ok {
err = h.Have(pc, msg.Index)
} else {
err = handler.OnMessage(pc, msg)
}
case Bitfield:
if h, ok := handler.(Bep3Handler); ok {
err = h.Bitfield(pc, msg.Bitfield)
} else {
err = handler.OnMessage(pc, msg)
}
case Request:
if h, ok := handler.(Bep3Handler); ok {
err = h.Request(pc, msg.Index, msg.Begin, msg.Length)
} else {
err = handler.OnMessage(pc, msg)
}
case Piece:
if h, ok := handler.(Bep3Handler); ok {
err = h.Piece(pc, msg.Index, msg.Begin, msg.Piece)
} else {
err = handler.OnMessage(pc, msg)
}
case Cancel:
if h, ok := handler.(Bep3Handler); ok {
err = h.Cancel(pc, msg.Index, msg.Begin, msg.Length)
} else {
err = handler.OnMessage(pc, msg)
}
// BEP 5 - DHT Protocol
case Port:
if !pc.IsSupportDHT() {
return ErrNotSupportDHT
} else if h, ok := handler.(Bep5Handler); ok {
err = h.Port(pc, msg.Port)
} else {
err = handler.OnMessage(pc, msg)
}
// BEP 6 - Fast Extension
case Suggest:
if !pc.IsSupportFast() {
return ErrNotSupportFast
} else if h, ok := handler.(Bep6Handler); ok {
err = h.Suggest(pc, msg.Index)
} else {
err = handler.OnMessage(pc, msg)
}
case HaveAll:
if !pc.IsSupportFast() {
return ErrNotSupportFast
} else if h, ok := handler.(Bep6Handler); ok {
err = h.HaveAll(pc)
} else {
err = handler.OnMessage(pc, msg)
}
case HaveNone:
if !pc.IsSupportFast() {
return ErrNotSupportFast
} else if h, ok := handler.(Bep6Handler); ok {
err = h.HaveNone(pc)
} else {
err = handler.OnMessage(pc, msg)
}
case Reject:
if !pc.IsSupportFast() {
return ErrNotSupportFast
} else if h, ok := handler.(Bep6Handler); ok {
err = h.Reject(pc, msg.Index, msg.Begin, msg.Length)
} else {
err = handler.OnMessage(pc, msg)
}
case AllowedFast:
if !pc.IsSupportFast() {
return ErrNotSupportFast
} else if h, ok := handler.(Bep6Handler); ok {
err = h.AllowedFast(pc, msg.Index)
} else {
err = handler.OnMessage(pc, msg)
}
// BEP 10 - Extension Protocol
case Extended:
if !pc.IsSupportExtended() {
return ErrNotSupportExtended
} else if h, ok := handler.(Bep10Handler); ok {
err = pc.handleExtMsg(h, msg)
} else {
err = handler.OnMessage(pc, msg)
}
// Other
default:
err = handler.OnMessage(pc, msg)
}
return
}
func (pc *PeerConn) handleExtMsg(h Bep10Handler, m Message) (err error) {
if m.ExtendedID == ExtendedIDHandshake {
var ehmsg ExtendedHandshakeMsg
if err = bencode.DecodeBytes(m.ExtendedPayload, &ehmsg); err == nil {
err = h.OnHandShake(pc, ehmsg)
}
} else {
err = h.OnPayload(pc, m.ExtendedID, m.ExtendedPayload)
}
return
}

97
peerprotocol/protocol.go Normal file
View File

@ -0,0 +1,97 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package peerprotocol
import (
"fmt"
)
// ProtocolHeader is the BT protocal prefix.
//
// BEP 3
const ProtocolHeader = "\x13BitTorrent protocol"
// Predefine some message types.
const (
// BEP 3
Choke MessageType = 0
Unchoke MessageType = 1
Interested MessageType = 2
NotInterested MessageType = 3
Have MessageType = 4
Bitfield MessageType = 5
Request MessageType = 6
Piece MessageType = 7
Cancel MessageType = 8
// BEP 5
Port MessageType = 9
// BEP 6 - Fast extension
Suggest MessageType = 0x0d // 13
HaveAll MessageType = 0x0e // 14
HaveNone MessageType = 0x0f // 15
Reject MessageType = 0x10 // 16
AllowedFast MessageType = 0x11 // 17
// BEP 10
Extended MessageType = 20
)
// MessageType is used to represent the message type.
type MessageType byte
func (mt MessageType) String() string {
switch mt {
case Choke:
return "Choke"
case Unchoke:
return "Unchoke"
case Interested:
return "Interested"
case NotInterested:
return "NotInterested"
case Have:
return "Have"
case Bitfield:
return "Bitfield"
case Request:
return "Request"
case Piece:
return "Piece"
case Cancel:
return "Cancel"
case Port:
return "Port"
case Suggest:
return "Suggest"
case HaveAll:
return "HaveAll"
case HaveNone:
return "HaveNone"
case Reject:
return "Reject"
case AllowedFast:
return "AllowedFast"
case Extended:
return "Extended"
}
return fmt.Sprintf("MessageType(%d)", mt)
}
// FastExtension reports whether the message type is fast extension.
func (mt MessageType) FastExtension() bool {
return mt >= Suggest && mt <= AllowedFast
}

169
peerprotocol/server.go Normal file
View File

@ -0,0 +1,169 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package peerprotocol
import (
"fmt"
"io"
"log"
"net"
"strings"
"time"
"github.com/xgfone/bt/metainfo"
)
// Handler is used to handle the incoming peer connection.
type Handler interface {
// OnHandShake is used to check whether the handshake extension is acceptable.
OnHandShake(conn *PeerConn, hmsg HandshakeMsg) error
// OnMessage is used to handle the incoming peer message.
//
// If requires, it should write the response to the peer.
OnMessage(conn *PeerConn, msg Message) error
}
// Config is used to configure the server.
type Config struct {
// ExtBits is used to handshake with the client.
ExtBits ExtensionBits
// MaxLength is used to limit the maximum number of the message body.
//
// The default is 0, which represents no limit.
MaxLength uint32
// Timeout is used to control the timeout of the read/write the message.
//
// The default is 0, which represents no timeout.
Timeout time.Duration
// ErrorLog is used to log the error.
ErrorLog func(format string, args ...interface{}) // Default: log.Printf
// HandleMessage is used to handle the incoming message. So you can
// customize it to add the request queue.
//
// The default handler is to forward to pc.HandleMessage(msg, handler).
HandleMessage func(pc *PeerConn, msg Message, handler Handler) error
}
func (c *Config) set(conf ...Config) {
if len(conf) > 0 {
*c = conf[0]
}
if c.ErrorLog == nil {
c.ErrorLog = log.Printf
}
if c.HandleMessage == nil {
c.HandleMessage = func(pc *PeerConn, m Message, h Handler) error {
return pc.HandleMessage(m, h)
}
}
}
// Server is used to implement the peer protocol server.
type Server struct {
ln net.Listener
id metainfo.Hash
h Handler
c Config
}
// NewServerByListen returns a new Server by listening on the address.
func NewServerByListen(network, address string, id metainfo.Hash, h Handler,
c ...Config) (*Server, error) {
ln, err := net.Listen(network, address)
if err != nil {
return nil, err
}
return NewServer(ln, id, h, c...), nil
}
// NewServer returns a new Server.
func NewServer(ln net.Listener, id metainfo.Hash, h Handler, c ...Config) *Server {
if id.IsZero() {
panic("the peer node id must not be empty")
}
var conf Config
conf.set(c...)
return &Server{ln: ln, id: id, h: h, c: conf}
}
// Run starts the peer protocol server.
func (s *Server) Run() {
for {
conn, err := s.ln.Accept()
if err != nil {
s.c.ErrorLog("fail to accept new connection: %s", err)
}
go s.handleConn(conn)
}
}
func (s *Server) handleConn(conn net.Conn) {
pc := &PeerConn{
ID: s.id,
Conn: conn,
ExtensionBits: s.c.ExtBits,
Timeout: s.c.Timeout,
MaxLength: s.c.MaxLength,
Choked: true,
PeerChoked: true,
}
if err := s.handlePeerMessage(pc); err != nil {
s.c.ErrorLog(err.Error())
}
}
func (s *Server) handlePeerMessage(pc *PeerConn) (err error) {
defer pc.Close()
m, err := pc.Handshake(metainfo.Hash{})
if err != nil {
return fmt.Errorf("fail to handshake with '%s': %s", pc.RemoteAddr().String(), err)
} else if err = s.h.OnHandShake(pc, m); err != nil {
return fmt.Errorf("handshake error with '%s': %s", pc.RemoteAddr().String(), err)
}
return s.loopRun(pc, s.h)
}
// LoopRun loops running Read-Handle message.
func (s *Server) loopRun(pc *PeerConn, handler Handler) error {
for {
msg, err := pc.ReadMsg()
switch err {
case nil:
case io.EOF:
return nil
default:
s := err.Error()
if strings.Contains(s, "closed") {
return nil
}
return fmt.Errorf("fail to decode the message from '%s': %s",
pc.RemoteAddr().String(), s)
}
if err = s.c.HandleMessage(pc, msg, handler); err != nil {
return fmt.Errorf("fail to handle peer message from '%s': %s",
pc.RemoteAddr().String(), err)
}
}
}

312
tracker/httptracker/http.go Normal file
View File

@ -0,0 +1,312 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package httptracker implements the tracker protocol based on HTTP/HTTPS.
//
// You can use the package to implement a HTTP tracker server to track the
// information that other peers upload or download the file, or to create
// a HTTP tracker client to communicate with the HTTP tracker server.
package httptracker
import (
"io"
"net/http"
"net/url"
"strconv"
"strings"
"github.com/xgfone/bt/bencode"
"github.com/xgfone/bt/metainfo"
)
// AnnounceRequest is the tracker announce requests.
//
// BEP 3
type AnnounceRequest struct {
// InfoHash is the sha1 hash of the bencoded form of the info value from the metainfo file.
InfoHash metainfo.Hash `bencode:"info_hash"` // BEP 3
// PeerID is the id of the downloader.
//
// Each downloader generates its own id at random at the start of a new download.
PeerID metainfo.Hash `bencode:"peer_id"` // BEP 3
// Uploaded is the total amount uploaded so far, encoded in base ten ascii.
Uploaded int64 `bencode:"uploaded"` // BEP 3
// Downloaded is the total amount downloaded so far, encoded in base ten ascii.
Downloaded int64 `bencode:"downloaded"` // BEP 3
// Left is the number of bytes this peer still has to download,
// encoded in base ten ascii.
//
// Note that this can't be computed from downloaded and the file length
// since it might be a resume, and there's a chance that some of the
// downloaded data failed an integrity check and had to be re-downloaded.
//
// If less than 0, math.MaxInt64 will be used for HTTP trackers instead.
Left int64 `bencode:"left"` // BEP 3
// Port is the port that this peer is listening on.
//
// Common behavior is for a downloader to try to listen on port 6881,
// and if that port is taken try 6882, then 6883, etc. and give up after 6889.
Port uint16 `bencode:"port"` // BEP 3
// IP is the ip or DNS name which this peer is at, which generally used
// for the origin if it's on the same machine as the tracker.
//
// Optional.
IP string `bencode:"ip,omitempty"` // BEP 3
// If not present, this is one of the announcements done at regular intervals.
// An announcement using started is sent when a download first begins,
// and one using completed is sent when the download is complete.
// No completed is sent if the file was complete when started.
// Downloaders send an announcement using stopped when they cease downloading.
//
// Optional
Event uint32 `bencode:"event,omitempty"` // BEP 3
// Compact indicates whether it hopes the tracker to return the compact
// peer lists.
//
// Optional
Compact bool `bencode:"compact,omitempty"` // BEP 23
// NumWant is the number of peers that the client would like to receive
// from the tracker. This value is permitted to be zero. If omitted,
// typically defaults to 50 peers.
//
// See https://wiki.theory.org/index.php/BitTorrentSpecification
//
// Optional.
NumWant int32 `bencode:"numwant,omitempty"`
Key int32 `bencode:"key,omitempty"`
}
// ToQuery converts the Request to URL Query.
func (r AnnounceRequest) ToQuery() (vs url.Values) {
vs = make(url.Values, 9)
vs.Set("info_hash", r.InfoHash.BytesString())
vs.Set("peer_id", r.PeerID.BytesString())
vs.Set("uploaded", strconv.FormatInt(r.Uploaded, 10))
vs.Set("downloaded", strconv.FormatInt(r.Downloaded, 10))
vs.Set("left", strconv.FormatInt(r.Left, 10))
if r.IP != "" {
vs.Set("ip", r.IP)
}
if r.Event > 0 {
vs.Set("event", strconv.FormatInt(int64(r.Event), 10))
}
if r.Port > 0 {
vs.Set("port", strconv.FormatUint(uint64(r.Port), 10))
}
if r.NumWant > 0 {
vs.Set("numwant", strconv.FormatUint(uint64(r.NumWant), 10))
}
if r.Key != 0 {
vs.Set("key", strconv.FormatInt(int64(r.Key), 10))
}
// BEP 23
if r.Compact {
vs.Set("compact", "1")
} else {
vs.Set("compact", "0")
}
return
}
// FromQuery converts URL Query to itself.
func (r *AnnounceRequest) FromQuery(vs url.Values) (err error) {
if err = r.InfoHash.FromString(vs.Get("info_hash")); err != nil {
return
}
if err = r.PeerID.FromString(vs.Get("peer_id")); err != nil {
return
}
v, err := strconv.ParseInt(vs.Get("uploaded"), 10, 64)
if err != nil {
return
}
r.Uploaded = v
v, err = strconv.ParseInt(vs.Get("downloaded"), 10, 64)
if err != nil {
return
}
r.Downloaded = v
v, err = strconv.ParseInt(vs.Get("left"), 10, 64)
if err != nil {
return
}
r.Left = v
if s := vs.Get("event"); s != "" {
v, err := strconv.ParseUint(s, 10, 64)
if err != nil {
return err
}
r.Event = uint32(v)
}
if s := vs.Get("port"); s != "" {
v, err := strconv.ParseUint(s, 10, 64)
if err != nil {
return err
}
r.Port = uint16(v)
}
if s := vs.Get("numwant"); s != "" {
v, err := strconv.ParseUint(s, 10, 64)
if err != nil {
return err
}
r.NumWant = int32(v)
}
if s := vs.Get("key"); s != "" {
v, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return err
}
r.Key = int32(v)
}
r.IP = vs.Get("ip")
switch vs.Get("compact") {
case "1":
r.Compact = true
case "0":
r.Compact = false
}
return
}
// AnnounceResponse is a announce response.
type AnnounceResponse struct {
FailureReason string `bencode:"failure reason,omitempty"`
// Interval is the seconds the downloader should wait before next rerequest.
Interval uint32 `bencode:"interval,omitempty"` // BEP 3
// Peers is the list of the peers.
Peers Peers `bencode:"peers,omitempty"` // BEP 3, BEP 23
// Peers6 is only used for ipv6 in the compact case.
Peers6 Peers6 `bencode:"peers6,omitempty"` // BEP 7
// Where's this specified?
// Mentioned at https://wiki.theory.org/index.php/BitTorrentSpecification.
// Complete is the number of peers with the entire file.
Complete uint32 `bencode:"complete,omitempty"`
// Incomplete is the number of non-seeder peers.
Incomplete uint32 `bencode:"incomplete,omitempty"`
// TrackerID is that the client should send back on its next announcements.
// If absent and a previous announce sent a tracker id,
// do not discard the old value; keep using it.
TrackerID string `bencode:"tracker id,omitempty"`
}
// ScrapeResponseResult is the result of the scraped file.
type ScrapeResponseResult struct {
// Complete is the number of active peers that have completed downloading.
Complete uint32 `bencode:"complete"` // BEP 48
// Incomplete is the number of active peers that have not completed downloading.
Incomplete uint32 `bencode:"incomplete"` // BEP 48
// The number of peers that have ever completed downloading.
Downloaded uint32 `bencode:"downloaded"` // BEP 48
}
// ScrapeResponse represents a Scrape response.
//
// BEP 48
type ScrapeResponse struct {
FailureReason string `bencode:"failure_reason,omitempty"`
Files map[metainfo.Hash]ScrapeResponseResult `bencode:"files,omitempty"`
}
// DecodeFrom reads the []byte data from r and decodes them to sr by bencode.
//
// r may be the body of the request from the http client.
func (sr *ScrapeResponse) DecodeFrom(r io.Reader) (err error) {
return bencode.NewDecoder(r).Decode(sr)
}
// EncodeTo encodes the response to []byte by bencode and write the result into w.
//
// w may be http.ResponseWriter.
func (sr ScrapeResponse) EncodeTo(w io.Writer) (err error) {
return bencode.NewEncoder(w).Encode(sr)
}
// TrackerClient represents a tracker client based on HTTP/HTTPS.
type TrackerClient struct {
AnnounceURL string
ScrapeURL string
}
// NewTrackerClient returns a new HTTPTrackerClient.
//
// scrapeURL may be empty, which will replace the "announce" in announceURL
// with "scrape" to generate the scrapeURL.
func NewTrackerClient(announceURL, scrapeURL string) *TrackerClient {
if scrapeURL == "" {
scrapeURL = strings.Replace(announceURL, "announce", "scrape", -1)
}
return &TrackerClient{AnnounceURL: announceURL, ScrapeURL: scrapeURL}
}
func (t *TrackerClient) send(u string, vs url.Values, r interface{}) (err error) {
sym := "?"
if strings.IndexByte(u, '?') > 0 {
sym = "&"
}
resp, err := http.Get(u + sym + vs.Encode())
if err != nil {
return
}
defer resp.Body.Close()
return bencode.NewDecoder(resp.Body).Decode(r)
}
// Announce sends a Announce request to the tracker.
func (t *TrackerClient) Announce(req AnnounceRequest) (resp AnnounceResponse, err error) {
err = t.send(t.AnnounceURL, req.ToQuery(), &resp)
return
}
// Scrape sends a Scrape request to the tracker.
func (t *TrackerClient) Scrape(infohashes []metainfo.Hash) (resp ScrapeResponse, err error) {
hs := make([]string, len(infohashes))
for i, h := range infohashes {
hs[i] = h.BytesString()
}
err = t.send(t.ScrapeURL, url.Values{"info_hash": hs}, &resp)
return
}

View File

@ -0,0 +1,192 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httptracker
import (
"bytes"
"encoding/binary"
"errors"
"net"
"github.com/xgfone/bt/bencode"
"github.com/xgfone/bt/metainfo"
)
var errInvalidPeer = errors.New("invalid peer information format")
// Peer is a tracker peer.
type Peer struct {
// ID is the peer's self-selected ID.
ID string `bencode:"peer id"` // BEP 3
// IP is the IP address or dns name.
IP string `bencode:"ip"` // BEP 3
Port uint16 `bencode:"port"` // BEP 3
}
// Addresses returns the list of the addresses that the peer listens on.
func (p Peer) Addresses() (addrs []metainfo.Address, err error) {
if ip := net.ParseIP(p.IP); len(ip) != 0 {
return []metainfo.Address{{IP: ip, Port: p.Port}}, nil
}
ips, err := net.LookupIP(p.IP)
if _len := len(ips); err == nil && len(ips) != 0 {
addrs = make([]metainfo.Address, _len)
for i, ip := range ips {
addrs[i] = metainfo.Address{IP: ip, Port: p.Port}
}
}
return
}
// Peers is a set of the peers.
type Peers []Peer
// UnmarshalBencode implements the interface bencode.Unmarshaler.
func (ps *Peers) UnmarshalBencode(b []byte) (err error) {
var v interface{}
if err = bencode.DecodeBytes(b, &v); err != nil {
return
}
switch vs := v.(type) {
case string: // BEP 23
_len := len(vs)
if _len%6 != 0 {
return metainfo.ErrInvalidAddr
}
peers := make(Peers, 0, _len/6)
for i := 0; i < _len; i += 6 {
var addr metainfo.Address
if err = addr.UnmarshalBinary([]byte(vs[i : i+6])); err != nil {
return
}
peers = append(peers, Peer{IP: addr.IP.String(), Port: addr.Port})
}
*ps = peers
case []interface{}: // BEP 3
peers := make(Peers, len(vs))
for i, p := range vs {
m, ok := p.(map[string]interface{})
if !ok {
return errInvalidPeer
}
pid, ok := m["peer id"].(string)
if !ok {
return errInvalidPeer
}
ip, ok := m["ip"].(string)
if !ok {
return errInvalidPeer
}
port, ok := m["port"].(int64)
if !ok {
return errInvalidPeer
}
peers[i] = Peer{ID: pid, IP: ip, Port: uint16(port)}
}
*ps = peers
default:
return errInvalidPeer
}
return
}
// MarshalBencode implements the interface bencode.Marshaler.
func (ps Peers) MarshalBencode() (b []byte, err error) {
for _, p := range ps {
if p.ID == "" {
return ps.marshalCompactBencode() // BEP 23
}
}
// BEP 3
buf := bytes.NewBuffer(make([]byte, 0, 64*len(ps)))
buf.WriteByte('l')
for _, p := range ps {
if err = bencode.NewEncoder(buf).Encode(p); err != nil {
return
}
}
buf.WriteByte('e')
b = buf.Bytes()
return
}
func (ps Peers) marshalCompactBencode() (b []byte, err error) {
buf := bytes.NewBuffer(make([]byte, 0, 6*len(ps)))
for _, peer := range ps {
ip := net.ParseIP(peer.IP).To4()
if len(ip) == 0 {
return nil, errInvalidPeer
}
buf.Write(ip[:])
binary.Write(buf, binary.BigEndian, peer.Port)
}
return bencode.EncodeBytes(buf.Bytes())
}
// Peers6 is a set of the peers for IPv6 in the compact case.
//
// BEP 7
type Peers6 []Peer
// UnmarshalBencode implements the interface bencode.Unmarshaler.
func (ps *Peers6) UnmarshalBencode(b []byte) (err error) {
var s string
if err = bencode.DecodeBytes(b, &s); err != nil {
return
}
_len := len(s)
if _len%18 != 0 {
return metainfo.ErrInvalidAddr
}
peers := make(Peers6, 0, _len/18)
for i := 0; i < _len; i += 18 {
var addr metainfo.Address
if err = addr.UnmarshalBinary([]byte(s[i : i+18])); err != nil {
return
}
peers = append(peers, Peer{IP: addr.IP.String(), Port: addr.Port})
}
*ps = peers
return
}
// MarshalBencode implements the interface bencode.Marshaler.
func (ps Peers6) MarshalBencode() (b []byte, err error) {
buf := bytes.NewBuffer(make([]byte, 0, 18*len(ps)))
for _, peer := range ps {
ip := net.ParseIP(peer.IP).To16()
if len(ip) == 0 {
return nil, errInvalidPeer
}
buf.Write(ip[:])
binary.Write(buf, binary.BigEndian, peer.Port)
}
return bencode.EncodeBytes(buf.Bytes())
}

View File

@ -0,0 +1,75 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httptracker
import (
"reflect"
"testing"
)
func TestPeers(t *testing.T) {
peers := Peers{
{ID: "123", IP: "1.1.1.1", Port: 80},
{ID: "456", IP: "2.2.2.2", Port: 81},
}
b, err := peers.MarshalBencode()
if err != nil {
t.Fatal(err)
}
var ps Peers
if err = ps.UnmarshalBencode(b); err != nil {
t.Fatal(err)
} else if !reflect.DeepEqual(ps, peers) {
t.Errorf("%v != %v", ps, peers)
}
/// For BEP 23
peers = Peers{
{IP: "1.1.1.1", Port: 80},
{IP: "2.2.2.2", Port: 81},
}
b, err = peers.MarshalBencode()
if err != nil {
t.Fatal(err)
}
if err = ps.UnmarshalBencode(b); err != nil {
t.Fatal(err)
} else if !reflect.DeepEqual(ps, peers) {
t.Errorf("%v != %v", ps, peers)
}
}
func TestPeers6(t *testing.T) {
peers := Peers6{
{IP: "fe80::5054:ff:fef0:1ab", Port: 80},
{IP: "fe80::5054:ff:fe29:205d", Port: 81},
}
b, err := peers.MarshalBencode()
if err != nil {
t.Fatal(err)
}
var ps Peers6
if err = ps.UnmarshalBencode(b); err != nil {
t.Fatal(err)
} else if !reflect.DeepEqual(ps, peers) {
t.Errorf("%v != %v", ps, peers)
}
}

View File

@ -0,0 +1,72 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httptracker
import (
"reflect"
"testing"
"github.com/xgfone/bt/metainfo"
)
func TestHTTPAnnounceRequest(t *testing.T) {
infohash := metainfo.NewRandomHash()
peerid := metainfo.NewRandomHash()
v1 := AnnounceRequest{
InfoHash: infohash,
PeerID: peerid,
Uploaded: 789,
Downloaded: 456,
Left: 123,
Port: 80,
Event: 123,
Compact: true,
}
vs := v1.ToQuery()
var v2 AnnounceRequest
if err := v2.FromQuery(vs); err != nil {
t.Fatal(err)
}
if v2.InfoHash != infohash {
t.Error(v2.InfoHash)
}
if v2.PeerID != peerid {
t.Error(v2.PeerID)
}
if v2.Uploaded != 789 {
t.Error(v2.Uploaded)
}
if v2.Downloaded != 456 {
t.Error(v2.Downloaded)
}
if v2.Left != 123 {
t.Error(v2.Left)
}
if v2.Port != 80 {
t.Error(v2.Port)
}
if v2.Event != 123 {
t.Error(v2.Event)
}
if !v2.Compact {
t.Error(v2.Compact)
}
if !reflect.DeepEqual(v1, v2) {
t.Errorf("%v != %v", v1, v2)
}
}

266
tracker/tracker.go Normal file
View File

@ -0,0 +1,266 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package tracker supplies some common type interfaces of the BT tracker
// protocol.
package tracker
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"net"
"net/url"
"github.com/xgfone/bt/metainfo"
"github.com/xgfone/bt/tracker/httptracker"
"github.com/xgfone/bt/tracker/udptracker"
)
// Predefine some announce events.
//
// BEP 3
const (
None uint32 = iota
Completed // The local peer just completed the torrent.
Started // The local peer has just resumed this torrent.
Stopped // The local peer is leaving the swarm.
)
// AnnounceRequest is the common Announce request.
//
// BEP 3, 15
type AnnounceRequest struct {
InfoHash metainfo.Hash
PeerID metainfo.Hash
Uploaded int64
Downloaded int64
Left int64
Event uint32
IP net.IP
Key int32
NumWant int32 // -1 for default
Port uint16
}
// ToHTTPAnnounceRequest creates a new httptracker.AnnounceRequest from itself.
func (ar AnnounceRequest) ToHTTPAnnounceRequest() httptracker.AnnounceRequest {
var ip string
if len(ar.IP) != 0 {
ip = ar.IP.String()
}
return httptracker.AnnounceRequest{
InfoHash: ar.InfoHash,
PeerID: ar.PeerID,
Uploaded: ar.Uploaded,
Downloaded: ar.Downloaded,
Left: ar.Left,
Port: ar.Port,
IP: ip,
Event: ar.Event,
NumWant: ar.NumWant,
Key: ar.Key,
}
}
// ToUDPAnnounceRequest creates a new udptracker.AnnounceRequest from itself.
func (ar AnnounceRequest) ToUDPAnnounceRequest() udptracker.AnnounceRequest {
return udptracker.AnnounceRequest{
InfoHash: ar.InfoHash,
PeerID: ar.PeerID,
Downloaded: ar.Downloaded,
Left: ar.Left,
Uploaded: ar.Uploaded,
Event: ar.Event,
IP: ar.IP,
Key: ar.Key,
NumWant: ar.NumWant,
Port: ar.Port,
}
}
// AnnounceResponse is a common Announce response.
//
// BEP 3, 15
type AnnounceResponse struct {
Interval uint32
Leechers uint32
Seeders uint32
Addresses []metainfo.Address
}
// FromHTTPAnnounceResponse sets itself from r.
func (ar *AnnounceResponse) FromHTTPAnnounceResponse(r httptracker.AnnounceResponse) {
ar.Interval = r.Interval
ar.Leechers = r.Incomplete
ar.Seeders = r.Complete
ar.Addresses = make([]metainfo.Address, 0, len(r.Peers)+len(r.Peers6))
for _, peer := range r.Peers {
addrs, _ := peer.Addresses()
ar.Addresses = append(ar.Addresses, addrs...)
}
for _, peer := range r.Peers6 {
addrs, _ := peer.Addresses()
ar.Addresses = append(ar.Addresses, addrs...)
}
}
// FromUDPAnnounceResponse sets itself from r.
func (ar *AnnounceResponse) FromUDPAnnounceResponse(r udptracker.AnnounceResponse) {
ar.Interval = r.Interval
ar.Leechers = r.Leechers
ar.Seeders = r.Seeders
ar.Addresses = r.Addresses
}
// ScrapeResponseResult is a commont Scrape response result.
type ScrapeResponseResult struct {
// Seeders is the number of active peers that have completed downloading.
Seeders uint32 `bencode:"complete"` // BEP 15, 48
// Leechers is the number of active peers that have not completed downloading.
Leechers uint32 `bencode:"incomplete"` // BEP 15, 48
// Completed is the total number of peers that have ever completed downloading.
Completed uint32 `bencode:"downloaded"` // BEP 15, 48
}
// EncodeTo encodes the response to buf.
func (r ScrapeResponseResult) EncodeTo(buf *bytes.Buffer) {
binary.Write(buf, binary.BigEndian, r.Seeders)
binary.Write(buf, binary.BigEndian, r.Completed)
binary.Write(buf, binary.BigEndian, r.Leechers)
}
// DecodeFrom decodes the response from b.
func (r *ScrapeResponseResult) DecodeFrom(b []byte) {
r.Seeders = binary.BigEndian.Uint32(b[:4])
r.Completed = binary.BigEndian.Uint32(b[4:8])
r.Leechers = binary.BigEndian.Uint32(b[8:12])
}
// ScrapeResponse is a commont Scrape response.
type ScrapeResponse map[metainfo.Hash]ScrapeResponseResult
// FromHTTPScrapeResponse sets itself from r.
func (sr ScrapeResponse) FromHTTPScrapeResponse(r httptracker.ScrapeResponse) {
for k, v := range r.Files {
sr[k] = ScrapeResponseResult{
Seeders: v.Complete,
Leechers: v.Incomplete,
Completed: v.Downloaded,
}
}
}
// FromUDPScrapeResponse sets itself from hs and r.
func (sr ScrapeResponse) FromUDPScrapeResponse(hs []metainfo.Hash,
r []udptracker.ScrapeResponse) {
klen := len(hs)
if _len := len(r); _len < klen {
klen = _len
}
for i := 0; i < klen; i++ {
sr[hs[i]] = ScrapeResponseResult{
Seeders: r[i].Seeders,
Leechers: r[i].Leechers,
Completed: r[i].Completed,
}
}
}
// Client is the interface of BT tracker client.
type Client interface {
Announce(AnnounceRequest) (AnnounceResponse, error)
Scrape([]metainfo.Hash) (ScrapeResponse, error)
}
// NewClient returns a new Client.
func NewClient(connURL string) (c Client, err error) {
u, err := url.Parse(connURL)
if err == nil {
switch u.Scheme {
case "http", "https":
c = &tclient{http: httptracker.NewTrackerClient(connURL, "")}
case "udp", "udp4", "udp6":
var utc *udptracker.TrackerClient
utc, err = udptracker.NewTrackerClientByDial(u.Scheme, u.Host)
if err == nil {
var e []udptracker.Extension
if p := u.RequestURI(); p != "" {
e = []udptracker.Extension{udptracker.NewURLData([]byte(p))}
}
c = &tclient{exts: e, udp: utc}
}
default:
err = fmt.Errorf("unknown url scheme '%s'", u.Scheme)
}
}
return
}
type tclient struct {
http *httptracker.TrackerClient // BEP 3
udp *udptracker.TrackerClient // BEP 15
exts []udptracker.Extension // BEP 41
}
func (c *tclient) Announce(req AnnounceRequest) (resp AnnounceResponse, err error) {
if c.http != nil {
var r httptracker.AnnounceResponse
if r, err = c.http.Announce(req.ToHTTPAnnounceRequest()); err != nil {
return
} else if r.FailureReason != "" {
err = errors.New(r.FailureReason)
return
}
resp.FromHTTPAnnounceResponse(r)
return
}
r := req.ToUDPAnnounceRequest()
r.Exts = c.exts
rs, err := c.udp.Announce(r)
if err == nil {
resp.FromUDPAnnounceResponse(rs)
}
return
}
func (c *tclient) Scrape(hs []metainfo.Hash) (resp ScrapeResponse, err error) {
if c.http != nil {
var r httptracker.ScrapeResponse
if r, err = c.http.Scrape(hs); err != nil {
return
} else if r.FailureReason != "" {
err = errors.New(r.FailureReason)
return
}
resp = make(ScrapeResponse, len(r.Files))
resp.FromHTTPScrapeResponse(r)
return
}
r, err := c.udp.Scrape(hs)
if err == nil {
resp = make(ScrapeResponse, len(r))
resp.FromUDPScrapeResponse(hs, r)
}
return
}

121
tracker/tracker_test.go Normal file
View File

@ -0,0 +1,121 @@
package tracker
import (
"errors"
"fmt"
"log"
"net"
"time"
"github.com/xgfone/bt/metainfo"
"github.com/xgfone/bt/tracker/udptracker"
)
type testHandler struct{}
func (testHandler) OnConnect(raddr *net.UDPAddr) (err error) { return }
func (testHandler) OnAnnounce(raddr *net.UDPAddr, req udptracker.AnnounceRequest) (
r udptracker.AnnounceResponse, err error) {
if req.Port != 80 {
err = errors.New("port is not 80")
return
}
if len(req.Exts) > 0 {
for i, ext := range req.Exts {
switch ext.Type {
case udptracker.URLData:
fmt.Printf("Extensions[%d]: URLData(%s)\n", i, string(ext.Data))
default:
fmt.Printf("Extensions[%d]: %s\n", i, ext.Type.String())
}
}
}
r = udptracker.AnnounceResponse{
Interval: 1,
Leechers: 2,
Seeders: 3,
Addresses: []metainfo.Address{{IP: net.ParseIP("127.0.0.1"), Port: 8001}},
}
return
}
func (testHandler) OnScrap(raddr *net.UDPAddr, infohashes []metainfo.Hash) (
rs []udptracker.ScrapeResponse, err error) {
rs = make([]udptracker.ScrapeResponse, len(infohashes))
for i := range infohashes {
rs[i] = udptracker.ScrapeResponse{
Seeders: uint32(i)*10 + 1,
Leechers: uint32(i)*10 + 2,
Completed: uint32(i)*10 + 3,
}
}
return
}
func ExampleClient() {
// Start the UDP tracker server
sconn, err := net.ListenPacket("udp4", "127.0.0.1:8001")
if err != nil {
log.Fatal(err)
}
server := udptracker.NewTrackerServer(sconn, testHandler{})
defer server.Close()
go server.Run()
// Wait for the server to be started
time.Sleep(time.Second)
// Create a client and dial to the UDP tracker server.
client, err := NewClient("udp://127.0.0.1:8001/path?a=1&b=2")
if err != nil {
log.Fatal(err)
}
// Send the ANNOUNCE request to the UDP tracker server,
// and get the ANNOUNCE response.
req := AnnounceRequest{IP: net.ParseIP("127.0.0.1"), Port: 80}
resp, err := client.Announce(req)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Interval: %d\n", resp.Interval)
fmt.Printf("Leechers: %d\n", resp.Leechers)
fmt.Printf("Seeders: %d\n", resp.Seeders)
for i, addr := range resp.Addresses {
fmt.Printf("Address[%d].IP: %s\n", i, addr.IP.String())
fmt.Printf("Address[%d].Port: %d\n", i, addr.Port)
}
// Send the SCRAPE request to the UDP tracker server,
// and get the SCRAPE respsone.
h1 := metainfo.Hash{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
h2 := metainfo.Hash{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}
rs, err := client.Scrape([]metainfo.Hash{h1, h2})
if err != nil {
log.Fatal(err)
} else if len(rs) != 2 {
log.Fatalf("%+v", rs)
}
for i, r := range rs {
fmt.Printf("%s.Seeders: %d\n", i.HexString(), r.Seeders)
fmt.Printf("%s.Leechers: %d\n", i.HexString(), r.Leechers)
fmt.Printf("%s.Completed: %d\n", i.HexString(), r.Completed)
}
// Output:
// Extensions[0]: URLData(/path?a=1&b=2)
// Interval: 1
// Leechers: 2
// Seeders: 3
// Address[0].IP: 127.0.0.1
// Address[0].Port: 8001
// 0101010101010101010101010101010101010101.Seeders: 1
// 0101010101010101010101010101010101010101.Leechers: 2
// 0101010101010101010101010101010101010101.Completed: 3
// 0202020202020202020202020202020202020202.Seeders: 11
// 0202020202020202020202020202020202020202.Leechers: 12
// 0202020202020202020202020202020202020202.Completed: 13
}

271
tracker/udptracker/udp.go Normal file
View File

@ -0,0 +1,271 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package udptracker implements the tracker protocol based on UDP.
//
// You can use the package to implement a UDP tracker server to track the
// information that other peers upload or download the file, or to create
// a UDP tracker client to communicate with the UDP tracker server.
package udptracker
import (
"bytes"
"encoding/binary"
"fmt"
"net"
"github.com/xgfone/bt/metainfo"
)
// ProtocolID is magic constant for the udp tracker connection.
//
// BEP 15
const ProtocolID = uint64(0x41727101980)
// Predefine some actions.
//
// BEP 15
const (
ActionConnect = uint32(0)
ActionAnnounce = uint32(1)
ActionScrape = uint32(2)
ActionError = uint32(3)
)
// AnnounceRequest represents the announce request used by UDP tracker.
//
// BEP 15
type AnnounceRequest struct {
InfoHash metainfo.Hash
PeerID metainfo.Hash
Downloaded int64
Left int64
Uploaded int64
Event uint32
IP net.IP
Key int32
NumWant int32 // -1 for default
Port uint16
Exts []Extension // BEP 41
}
// DecodeFrom decodes the request from b.
func (r *AnnounceRequest) DecodeFrom(b []byte, ipv4 bool) {
r.InfoHash = metainfo.NewHash(b[0:20])
r.PeerID = metainfo.NewHash(b[20:40])
r.Downloaded = int64(binary.BigEndian.Uint64(b[40:48]))
r.Left = int64(binary.BigEndian.Uint64(b[48:56]))
r.Uploaded = int64(binary.BigEndian.Uint64(b[56:64]))
r.Event = binary.BigEndian.Uint32(b[64:68])
if ipv4 {
r.IP = make(net.IP, net.IPv4len)
copy(r.IP, b[68:72])
b = b[72:]
} else {
r.IP = make(net.IP, net.IPv6len)
copy(r.IP, b[68:84])
b = b[84:]
}
r.Key = int32(binary.BigEndian.Uint32(b[0:4]))
r.NumWant = int32(binary.BigEndian.Uint32(b[4:8]))
r.Port = binary.BigEndian.Uint16(b[8:10])
b = b[10:]
for len(b) > 0 {
var ext Extension
parsed := ext.DecodeFrom(b)
r.Exts = append(r.Exts, ext)
b = b[parsed:]
}
}
// EncodeTo encodes the request to buf.
func (r AnnounceRequest) EncodeTo(buf *bytes.Buffer) {
buf.Grow(82)
buf.Write(r.InfoHash[:])
buf.Write(r.PeerID[:])
binary.Write(buf, binary.BigEndian, r.Downloaded)
binary.Write(buf, binary.BigEndian, r.Left)
binary.Write(buf, binary.BigEndian, r.Uploaded)
binary.Write(buf, binary.BigEndian, r.Event)
if ip := r.IP.To4(); ip != nil {
buf.Write(ip[:])
} else {
buf.Write(r.IP[:])
}
binary.Write(buf, binary.BigEndian, r.Key)
binary.Write(buf, binary.BigEndian, r.NumWant)
binary.Write(buf, binary.BigEndian, r.Port)
for _, ext := range r.Exts {
ext.EncodeTo(buf)
}
}
// AnnounceResponse represents the announce response used by UDP tracker.
//
// BEP 15
type AnnounceResponse struct {
Interval uint32
Leechers uint32
Seeders uint32
Addresses []metainfo.Address
}
// EncodeTo encodes the response to buf.
func (r AnnounceResponse) EncodeTo(buf *bytes.Buffer) {
buf.Grow(12 + len(r.Addresses)*18)
binary.Write(buf, binary.BigEndian, r.Interval)
binary.Write(buf, binary.BigEndian, r.Leechers)
binary.Write(buf, binary.BigEndian, r.Seeders)
for _, addr := range r.Addresses {
if ip := addr.IP.To4(); ip != nil {
buf.Write(ip[:])
} else {
buf.Write(addr.IP[:])
}
binary.Write(buf, binary.BigEndian, addr.Port)
}
}
// DecodeFrom decodes the response from b.
func (r *AnnounceResponse) DecodeFrom(b []byte, ipv4 bool) {
r.Interval = binary.BigEndian.Uint32(b[:4])
r.Leechers = binary.BigEndian.Uint32(b[4:8])
r.Seeders = binary.BigEndian.Uint32(b[8:12])
b = b[12:]
iplen := net.IPv6len
if ipv4 {
iplen = net.IPv4len
}
_len := len(b)
step := iplen + 2
r.Addresses = make([]metainfo.Address, 0, _len/step)
for i := step; i <= _len; i += step {
ip := make(net.IP, iplen)
copy(ip, b[i-step:i-2])
port := binary.BigEndian.Uint16(b[i-2 : i])
r.Addresses = append(r.Addresses, metainfo.Address{IP: ip, Port: port})
}
}
// ScrapeResponse represents the UDP SCRAPE response.
//
// BEP 15
type ScrapeResponse struct {
Seeders uint32
Leechers uint32
Completed uint32
}
// EncodeTo encodes the response to buf.
func (r ScrapeResponse) EncodeTo(buf *bytes.Buffer) {
binary.Write(buf, binary.BigEndian, r.Seeders)
binary.Write(buf, binary.BigEndian, r.Completed)
binary.Write(buf, binary.BigEndian, r.Leechers)
}
// DecodeFrom decodes the response from b.
func (r *ScrapeResponse) DecodeFrom(b []byte) {
r.Seeders = binary.BigEndian.Uint32(b[:4])
r.Completed = binary.BigEndian.Uint32(b[4:8])
r.Leechers = binary.BigEndian.Uint32(b[8:12])
}
// Predefine some UDP extension types.
//
// BEP 41
const (
EndOfOptions ExtensionType = iota
Nop
URLData
)
// NewEndOfOptions returns a new EndOfOptions UDP extension.
func NewEndOfOptions() Extension { return Extension{Type: EndOfOptions} }
// NewNop returns a new Nop UDP extension.
func NewNop() Extension { return Extension{Type: Nop} }
// NewURLData returns a new URLData UDP extension.
func NewURLData(data []byte) Extension {
return Extension{Type: URLData, Length: uint8(len(data)), Data: data}
}
// ExtensionType represents the type of UDP extension.
type ExtensionType uint8
func (et ExtensionType) String() string {
switch et {
case EndOfOptions:
return "EndOfOptions"
case Nop:
return "Nop"
case URLData:
return "URLData"
default:
return fmt.Sprintf("ExtensionType(%d)", et)
}
}
// Extension represent the extension used by the UDP ANNOUNCE request.
//
// BEP 41
type Extension struct {
Type ExtensionType
Length uint8
Data []byte
}
// EncodeTo encodes the response to buf.
func (e Extension) EncodeTo(buf *bytes.Buffer) {
if _len := uint8(len(e.Data)); e.Length == 0 && _len != 0 {
e.Length = _len
} else if _len != e.Length {
panic("the length of data is inconsistent")
}
buf.WriteByte(byte(e.Type))
buf.WriteByte(e.Length)
buf.Write(e.Data)
}
// DecodeFrom decodes the response from b.
func (e *Extension) DecodeFrom(b []byte) (parsed int) {
switch len(b) {
case 0:
case 1:
e.Type = ExtensionType(b[0])
parsed = 1
default:
e.Type = ExtensionType(b[0])
e.Length = b[1]
parsed = 2
if e.Length > 0 {
parsed += int(e.Length)
e.Data = b[2:parsed]
}
}
return
}

View File

@ -0,0 +1,289 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package udptracker
import (
"bytes"
"encoding/binary"
"errors"
"io"
"net"
"strings"
"sync/atomic"
"time"
"github.com/xgfone/bt/metainfo"
)
// NewTrackerClientByDial returns a new TrackerClient by dialing.
func NewTrackerClientByDial(network, address string, c ...TrackerClientConfig) (
*TrackerClient, error) {
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
return NewTrackerClient(conn.(*net.UDPConn), c...), nil
}
// NewTrackerClient returns a new TrackerClient.
func NewTrackerClient(conn *net.UDPConn, c ...TrackerClientConfig) *TrackerClient {
var conf TrackerClientConfig
conf.set(c...)
ipv4 := strings.Contains(conn.LocalAddr().String(), ".")
return &TrackerClient{conn: conn, conf: conf, ipv4: ipv4}
}
// TrackerClientConfig is used to configure the TrackerClient.
type TrackerClientConfig struct {
// ReadTimeout is used to receive the response.
ReadTimeout time.Duration // Default: 5s
MaxBufSize int // Default: 2048
}
func (c *TrackerClientConfig) set(conf ...TrackerClientConfig) {
if len(conf) > 0 {
*c = conf[0]
}
if c.MaxBufSize <= 0 {
c.MaxBufSize = 2048
}
if c.ReadTimeout <= 0 {
c.ReadTimeout = time.Second * 5
}
}
// TrackerClient is a tracker client based on UDP.
//
// Notice: the request is synchronized, that's, the last request is not returned,
// the next request must not be sent.
//
// BEP 15
type TrackerClient struct {
ipv4 bool
conf TrackerClientConfig
conn *net.UDPConn
last time.Time
cid uint64
tid uint32
}
// Close closes the UDP tracker client.
func (utc *TrackerClient) Close() { utc.conn.Close() }
func (utc *TrackerClient) readResp(b []byte) (int, error) {
utc.conn.SetReadDeadline(time.Now().Add(utc.conf.ReadTimeout))
return utc.conn.Read(b)
}
func (utc *TrackerClient) getTranID() uint32 {
return atomic.AddUint32(&utc.tid, 1)
}
func (utc *TrackerClient) parseError(b []byte) (tid uint32, reason string) {
tid = binary.BigEndian.Uint32(b[:4])
reason = string(b[4:])
return
}
func (utc *TrackerClient) send(b []byte) (err error) {
n, err := utc.conn.Write(b)
if err == nil && n < len(b) {
err = io.ErrShortWrite
}
return
}
func (utc *TrackerClient) connect() (err error) {
tid := utc.getTranID()
buf := bytes.NewBuffer(make([]byte, 0, 16))
binary.Write(buf, binary.BigEndian, ProtocolID)
binary.Write(buf, binary.BigEndian, ActionConnect)
binary.Write(buf, binary.BigEndian, tid)
if err = utc.send(buf.Bytes()); err != nil {
return
}
data := make([]byte, 32)
n, err := utc.readResp(data)
if err != nil {
return
} else if n < 8 {
return io.ErrShortBuffer
}
data = data[:n]
switch binary.BigEndian.Uint32(data[:4]) {
case ActionConnect:
case ActionError:
_, reason := utc.parseError(data[4:])
return errors.New(reason)
default:
return errors.New("tracker response not connect action")
}
if n < 16 {
return io.ErrShortBuffer
}
if binary.BigEndian.Uint32(data[4:8]) != tid {
return errors.New("invalid transaction id")
}
utc.cid = binary.BigEndian.Uint64(data[8:16])
utc.last = time.Now()
return
}
func (utc *TrackerClient) getConnectionID() (cid uint64, err error) {
cid = utc.cid
if time.Now().Sub(utc.last) > time.Minute {
if err = utc.connect(); err == nil {
cid = utc.cid
}
}
return
}
func (utc *TrackerClient) announce(req AnnounceRequest) (r AnnounceResponse, err error) {
cid, err := utc.getConnectionID()
if err != nil {
return
}
tid := utc.getTranID()
buf := bytes.NewBuffer(make([]byte, 0, 110))
binary.Write(buf, binary.BigEndian, cid)
binary.Write(buf, binary.BigEndian, ActionAnnounce)
binary.Write(buf, binary.BigEndian, tid)
req.EncodeTo(buf)
b := buf.Bytes()
if err = utc.send(b); err != nil {
return
}
data := make([]byte, utc.conf.MaxBufSize)
n, err := utc.readResp(data)
if err != nil {
return
} else if n < 8 {
err = io.ErrShortBuffer
return
}
data = data[:n]
switch binary.BigEndian.Uint32(data[:4]) {
case ActionAnnounce:
case ActionError:
_, reason := utc.parseError(data[4:])
err = errors.New(reason)
return
default:
err = errors.New("tracker response not connect action")
return
}
if n < 16 {
err = io.ErrShortBuffer
return
}
if binary.BigEndian.Uint32(data[4:8]) != tid {
err = errors.New("invalid transaction id")
return
}
r.DecodeFrom(data[8:], utc.ipv4)
return
}
// Announce sends a Announce request to the tracker.
//
// Notice:
// 1. if it does not connect to the UDP tracker server, it will connect to it,
// then send the ANNOUNCE request.
// 2. If returning an error, you should retry it.
// See http://www.bittorrent.org/beps/bep_0015.html#time-outs
func (utc *TrackerClient) Announce(r AnnounceRequest) (AnnounceResponse, error) {
return utc.announce(r)
}
func (utc *TrackerClient) scrape(infohashes []metainfo.Hash) (rs []ScrapeResponse, err error) {
cid, err := utc.getConnectionID()
if err != nil {
return
}
tid := utc.getTranID()
buf := bytes.NewBuffer(make([]byte, 0, 16+len(infohashes)*20))
binary.Write(buf, binary.BigEndian, cid)
binary.Write(buf, binary.BigEndian, ActionScrape)
binary.Write(buf, binary.BigEndian, tid)
for _, h := range infohashes {
buf.Write(h[:])
}
if err = utc.send(buf.Bytes()); err != nil {
return
}
data := make([]byte, utc.conf.MaxBufSize)
n, err := utc.readResp(data)
if err != nil {
return
} else if n < 8 {
err = io.ErrShortBuffer
return
}
data = data[:n]
switch binary.BigEndian.Uint32(data[:4]) {
case ActionScrape:
case ActionError:
_, reason := utc.parseError(data[4:])
err = errors.New(reason)
return
default:
err = errors.New("tracker response not connect action")
return
}
if binary.BigEndian.Uint32(data[4:8]) != tid {
err = errors.New("invalid transaction id")
return
}
data = data[8:]
_len := len(data)
rs = make([]ScrapeResponse, 0, _len/12)
for i := 12; i <= _len; i += 12 {
var r ScrapeResponse
r.DecodeFrom(data[i-12 : i])
rs = append(rs, r)
}
return
}
// Scrape sends a Scrape request to the tracker.
//
// Notice:
// 1. if it does not connect to the UDP tracker server, it will connect to it,
// then send the ANNOUNCE request.
// 2. If returning an error, you should retry it.
// See http://www.bittorrent.org/beps/bep_0015.html#time-outs
func (utc *TrackerClient) Scrape(hs []metainfo.Hash) ([]ScrapeResponse, error) {
return utc.scrape(hs)
}

View File

@ -0,0 +1,281 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package udptracker
import (
"bytes"
"encoding/binary"
"log"
"net"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/xgfone/bt/metainfo"
)
// TrackerServerHandler is used to handle the request from the client.
type TrackerServerHandler interface {
// OnConnect is used to check whether to make the connection or not.
OnConnect(raddr *net.UDPAddr) (err error)
OnAnnounce(raddr *net.UDPAddr, req AnnounceRequest) (AnnounceResponse, error)
OnScrap(raddr *net.UDPAddr, infohashes []metainfo.Hash) ([]ScrapeResponse, error)
}
func encodeResponseHeader(buf *bytes.Buffer, action, tid uint32) {
binary.Write(buf, binary.BigEndian, action)
binary.Write(buf, binary.BigEndian, tid)
}
type wrappedPeerAddr struct {
Addr *net.UDPAddr
Time time.Time
}
// TrackerServerConfig is used to configure the TrackerServer.
type TrackerServerConfig struct {
MaxBufSize int // Default: 2048
ErrorLog func(format string, args ...interface{}) // Default: log.Printf
}
func (c *TrackerServerConfig) setDefault() {
if c.MaxBufSize <= 0 {
c.MaxBufSize = 2048
}
if c.ErrorLog == nil {
c.ErrorLog = log.Printf
}
}
// TrackerServer is a tracker server based on UDP.
type TrackerServer struct {
conn net.PacketConn
conf TrackerServerConfig
handler TrackerServerHandler
bufpool sync.Pool
cid uint64
exit chan struct{}
lock sync.RWMutex
conns map[uint64]wrappedPeerAddr
}
// NewTrackerServer returns a new TrackerServer.
func NewTrackerServer(c net.PacketConn, h TrackerServerHandler,
config ...TrackerServerConfig) *TrackerServer {
var conf TrackerServerConfig
if len(config) > 0 {
conf = config[0]
}
conf.setDefault()
s := &TrackerServer{
conf: conf,
conn: c,
handler: h,
exit: make(chan struct{}),
conns: make(map[uint64]wrappedPeerAddr, 128),
}
s.bufpool.New = func() interface{} { return make([]byte, conf.MaxBufSize) }
return s
}
// Close closes the tracker server.
func (uts *TrackerServer) Close() {
select {
case <-uts.exit:
default:
close(uts.exit)
uts.conn.Close()
}
}
func (uts *TrackerServer) cleanConnectionID(interval time.Duration) {
tick := time.NewTicker(interval)
defer tick.Stop()
for {
select {
case <-uts.exit:
return
case now := <-tick.C:
uts.lock.RLock()
for cid, wa := range uts.conns {
if now.Sub(wa.Time) > interval {
delete(uts.conns, cid)
}
}
uts.lock.RUnlock()
}
}
}
// Run starts the tracker server.
func (uts *TrackerServer) Run() {
go uts.cleanConnectionID(time.Minute * 2)
for {
buf := uts.bufpool.Get().([]byte)
n, raddr, err := uts.conn.ReadFrom(buf)
if err != nil {
if !strings.Contains(err.Error(), "closed") {
uts.conf.ErrorLog("failed to read udp tracker request: %s", err)
}
return
} else if n < 16 {
continue
}
go uts.handleRequest(raddr.(*net.UDPAddr), buf, n)
}
}
func (uts *TrackerServer) handleRequest(raddr *net.UDPAddr, buf []byte, n int) {
defer uts.bufpool.Put(buf)
uts.handlePacket(raddr, buf[:n])
}
func (uts *TrackerServer) send(raddr *net.UDPAddr, b []byte) {
n, err := uts.conn.WriteTo(b, raddr)
if err != nil {
uts.conf.ErrorLog("fail to send the udp tracker response to '%s': %s",
raddr.String(), err)
} else if n < len(b) {
uts.conf.ErrorLog("too short udp tracker response sent to '%s'", raddr.String())
}
}
func (uts *TrackerServer) getConnectionID() uint64 {
return atomic.AddUint64(&uts.cid, 1)
}
func (uts *TrackerServer) addConnection(cid uint64, raddr *net.UDPAddr) {
now := time.Now()
uts.lock.Lock()
uts.conns[cid] = wrappedPeerAddr{Addr: raddr, Time: now}
uts.lock.Unlock()
}
func (uts *TrackerServer) checkConnection(cid uint64, raddr *net.UDPAddr) (ok bool) {
uts.lock.RLock()
if w, _ok := uts.conns[cid]; _ok && w.Addr.Port == raddr.Port &&
bytes.Equal(w.Addr.IP, raddr.IP) {
ok = true
}
uts.lock.RUnlock()
return
}
func (uts *TrackerServer) sendError(raddr *net.UDPAddr, tid uint32, reason string) {
buf := bytes.NewBuffer(make([]byte, 0, 8+len(reason)))
encodeResponseHeader(buf, ActionError, tid)
buf.WriteString(reason)
uts.send(raddr, buf.Bytes())
}
func (uts *TrackerServer) sendConnResp(raddr *net.UDPAddr, tid uint32, cid uint64) {
buf := bytes.NewBuffer(make([]byte, 0, 16))
encodeResponseHeader(buf, ActionConnect, tid)
binary.Write(buf, binary.BigEndian, cid)
uts.send(raddr, buf.Bytes())
}
func (uts *TrackerServer) sendAnnounceResp(raddr *net.UDPAddr, tid uint32,
resp AnnounceResponse) {
buf := bytes.NewBuffer(make([]byte, 0, 8+12+len(resp.Addresses)*18))
encodeResponseHeader(buf, ActionAnnounce, tid)
resp.EncodeTo(buf)
uts.send(raddr, buf.Bytes())
}
func (uts *TrackerServer) sendScrapResp(raddr *net.UDPAddr, tid uint32,
rs []ScrapeResponse) {
buf := bytes.NewBuffer(make([]byte, 0, 8+len(rs)*12))
encodeResponseHeader(buf, ActionScrape, tid)
for _, r := range rs {
r.EncodeTo(buf)
}
uts.send(raddr, buf.Bytes())
}
func (uts *TrackerServer) handlePacket(raddr *net.UDPAddr, b []byte) {
cid := binary.BigEndian.Uint64(b[:8])
action := binary.BigEndian.Uint32(b[8:12])
tid := binary.BigEndian.Uint32(b[12:16])
b = b[16:]
// Handle the connection request.
if cid == ProtocolID && action == ActionConnect {
if err := uts.handler.OnConnect(raddr); err != nil {
uts.sendError(raddr, tid, err.Error())
return
}
cid := uts.getConnectionID()
uts.addConnection(cid, raddr)
uts.sendConnResp(raddr, tid, cid)
return
}
// Check whether the request is connected.
if !uts.checkConnection(cid, raddr) {
uts.sendError(raddr, tid, "connection is expired")
return
}
switch action {
case ActionAnnounce:
var req AnnounceRequest
if raddr.IP.To4() != nil { // For ipv4
if len(b) < 82 {
uts.sendError(raddr, tid, "invalid announce request")
return
}
req.DecodeFrom(b, true)
} else { // For ipv6
if len(b) < 94 {
uts.sendError(raddr, tid, "invalid announce request")
return
}
req.DecodeFrom(b, false)
}
resp, err := uts.handler.OnAnnounce(raddr, req)
if err != nil {
uts.sendError(raddr, tid, err.Error())
} else {
uts.sendAnnounceResp(raddr, tid, resp)
}
case ActionScrape:
_len := len(b)
infohashes := make([]metainfo.Hash, 0, _len/20)
for i, _len := 20, len(b); i <= _len; i += 20 {
infohashes = append(infohashes, metainfo.NewHash(b[i-20:i]))
}
if len(infohashes) == 0 {
uts.sendError(raddr, tid, "no infohash")
return
}
resps, err := uts.handler.OnScrap(raddr, infohashes)
if err != nil {
uts.sendError(raddr, tid, err.Error())
} else {
uts.sendScrapResp(raddr, tid, resps)
}
default:
uts.sendError(raddr, tid, "unkwnown action")
}
}

View File

@ -0,0 +1,135 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package udptracker
import (
"errors"
"fmt"
"log"
"net"
"time"
"github.com/xgfone/bt/metainfo"
)
type testHandler struct{}
func (testHandler) OnConnect(raddr *net.UDPAddr) (err error) { return }
func (testHandler) OnAnnounce(raddr *net.UDPAddr, req AnnounceRequest) (
r AnnounceResponse, err error) {
if req.Port != 80 {
err = errors.New("port is not 80")
return
}
if len(req.Exts) > 0 {
for i, ext := range req.Exts {
switch ext.Type {
case URLData:
fmt.Printf("Extensions[%d]: URLData(%s)\n", i, string(ext.Data))
default:
fmt.Printf("Extensions[%d]: %s\n", i, ext.Type.String())
}
}
}
r = AnnounceResponse{
Interval: 1,
Leechers: 2,
Seeders: 3,
Addresses: []metainfo.Address{{IP: net.ParseIP("127.0.0.1"), Port: 8001}},
}
return
}
func (testHandler) OnScrap(raddr *net.UDPAddr, infohashes []metainfo.Hash) (
rs []ScrapeResponse, err error) {
rs = make([]ScrapeResponse, len(infohashes))
for i := range infohashes {
rs[i] = ScrapeResponse{
Seeders: uint32(i)*10 + 1,
Leechers: uint32(i)*10 + 2,
Completed: uint32(i)*10 + 3,
}
}
return
}
func ExampleTrackerClient() {
// Start the UDP tracker server
sconn, err := net.ListenPacket("udp4", "127.0.0.1:8001")
if err != nil {
log.Fatal(err)
}
server := NewTrackerServer(sconn, testHandler{})
defer server.Close()
go server.Run()
// Wait for the server to be started
time.Sleep(time.Second)
// Create a client and dial to the UDP tracker server.
client, err := NewTrackerClientByDial("udp4", "127.0.0.1:8001")
if err != nil {
log.Fatal(err)
}
// Send the ANNOUNCE request to the UDP tracker server,
// and get the ANNOUNCE response.
exts := []Extension{NewURLData([]byte("data")), NewNop()}
req := AnnounceRequest{IP: net.ParseIP("127.0.0.1"), Port: 80, Exts: exts}
resp, err := client.Announce(req)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Interval: %d\n", resp.Interval)
fmt.Printf("Leechers: %d\n", resp.Leechers)
fmt.Printf("Seeders: %d\n", resp.Seeders)
for i, addr := range resp.Addresses {
fmt.Printf("Address[%d].IP: %s\n", i, addr.IP.String())
fmt.Printf("Address[%d].Port: %d\n", i, addr.Port)
}
// Send the SCRAPE request to the UDP tracker server,
// and get the SCRAPE respsone.
hs := []metainfo.Hash{metainfo.NewRandomHash(), metainfo.NewRandomHash()}
rs, err := client.Scrape(hs)
if err != nil {
log.Fatal(err)
} else if len(rs) != 2 {
log.Fatalf("%+v", rs)
}
for i, r := range rs {
fmt.Printf("%d.Seeders: %d\n", i, r.Seeders)
fmt.Printf("%d.Leechers: %d\n", i, r.Leechers)
fmt.Printf("%d.Completed: %d\n", i, r.Completed)
}
// Output:
// Extensions[0]: URLData(data)
// Extensions[1]: Nop
// Interval: 1
// Leechers: 2
// Seeders: 3
// Address[0].IP: 127.0.0.1
// Address[0].Port: 8001
// 0.Seeders: 1
// 0.Leechers: 2
// 0.Completed: 3
// 1.Seeders: 11
// 1.Leechers: 12
// 1.Completed: 13
}

50
utils/bool.go Normal file
View File

@ -0,0 +1,50 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import (
"sync/atomic"
)
// Bool is used to implement a atomic bool.
type Bool struct {
v uint32
}
// NewBool returns a new Bool with the initialized value.
func NewBool(t bool) (b Bool) {
if t {
b.v = 1
}
return
}
// Get returns the bool value.
func (b *Bool) Get() bool { return atomic.LoadUint32(&b.v) == 1 }
// SetTrue sets the bool value to true.
func (b *Bool) SetTrue() { atomic.StoreUint32(&b.v, 1) }
// SetFalse sets the bool value to false.
func (b *Bool) SetFalse() { atomic.StoreUint32(&b.v, 0) }
// Set sets the bool value to t.
func (b *Bool) Set(t bool) {
if t {
b.SetTrue()
} else {
b.SetFalse()
}
}

16
utils/doc.go Normal file
View File

@ -0,0 +1,16 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package utils supplies some convenient functions.
package utils

30
utils/io.go Normal file
View File

@ -0,0 +1,30 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import "io"
// CopyNBuffer is the same as io.CopyN, but uses the given buf as the buffer.
func CopyNBuffer(dst io.Writer, src io.Reader, n int64, buf []byte) (written int64, err error) {
written, err = io.CopyBuffer(dst, io.LimitReader(src, n), buf)
if written == n {
return n, nil
} else if written < n && err == nil {
// src stopped early; must have been EOF.
err = io.EOF
}
return
}

26
utils/slice.go Normal file
View File

@ -0,0 +1,26 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
// InStringSlice reports whether s is in ss.
func InStringSlice(ss []string, s string) bool {
for _, v := range ss {
if v == s {
return true
}
}
return false
}

29
utils/slice_test.go Normal file
View File

@ -0,0 +1,29 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import (
"testing"
)
func TestInStringSlice(t *testing.T) {
if !InStringSlice([]string{"a", "b"}, "a") {
t.Fail()
}
if InStringSlice([]string{"a", "b"}, "z") {
t.Fail()
}
}

24
utils/string.go Normal file
View File

@ -0,0 +1,24 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import "crypto/rand"
// RandomString generates a size-length string randomly.
func RandomString(size int) string {
bs := make([]byte, size)
rand.Read(bs)
return string(bs)
}