From 6bfa2f6700e2bb73788a6adc1827b202587167c6 Mon Sep 17 00:00:00 2001 From: xgfone Date: Sun, 7 Jun 2020 13:43:15 +0800 Subject: [PATCH] first commit --- .gitignore | 40 ++ .travis.yml | 10 + LICENSE | 202 ++++++ README.md | 41 ++ bencode/AUTHORS | 8 + bencode/LICENSE | 19 + bencode/decode.go | 614 ++++++++++++++++++ bencode/decode_test.go | 400 ++++++++++++ bencode/doc.go | 6 + bencode/encode.go | 322 ++++++++++ bencode/encode_decode_test.go | 109 ++++ bencode/encode_test.go | 354 +++++++++++ bencode/example_test.go | 67 ++ bencode/raw.go | 5 + bencode/tag.go | 76 +++ dht/blacklist.go | 181 ++++++ dht/blacklist_test.go | 80 +++ dht/dht_server.go | 864 ++++++++++++++++++++++++++ dht/dht_server_test.go | 182 ++++++ dht/peer_manager.go | 139 +++++ dht/routing_table.go | 439 +++++++++++++ dht/routing_table_storage.go | 42 ++ dht/token_manager.go | 107 ++++ dht/transaction_manager.go | 152 +++++ downloader/torrent_info.go | 315 ++++++++++ go.mod | 3 + krpc/doc.go | 16 + krpc/message.go | 440 +++++++++++++ krpc/message_test.go | 41 ++ krpc/node.go | 84 +++ metainfo/address.go | 342 ++++++++++ metainfo/address_test.go | 48 ++ metainfo/doc.go | 17 + metainfo/file.go | 65 ++ metainfo/info.go | 138 ++++ metainfo/infohash.go | 207 ++++++ metainfo/magnet.go | 138 ++++ metainfo/metainfo.go | 170 +++++ metainfo/piece.go | 99 +++ peerprotocol/doc.go | 19 + peerprotocol/extension.go | 161 +++++ peerprotocol/extension_test.go | 54 ++ peerprotocol/fastset.go | 69 ++ peerprotocol/fastset_test.go | 109 ++++ peerprotocol/handshake.go | 140 +++++ peerprotocol/message.go | 277 +++++++++ peerprotocol/peerconn.go | 522 ++++++++++++++++ peerprotocol/protocol.go | 97 +++ peerprotocol/server.go | 169 +++++ tracker/httptracker/http.go | 312 ++++++++++ tracker/httptracker/http_peer.go | 192 ++++++ tracker/httptracker/http_peer_test.go | 75 +++ tracker/httptracker/http_test.go | 72 +++ tracker/tracker.go | 266 ++++++++ tracker/tracker_test.go | 121 ++++ tracker/udptracker/udp.go | 271 ++++++++ tracker/udptracker/udp_client.go | 289 +++++++++ tracker/udptracker/udp_server.go | 281 +++++++++ tracker/udptracker/udp_test.go | 135 ++++ utils/bool.go | 50 ++ utils/doc.go | 16 + utils/io.go | 30 + utils/slice.go | 26 + utils/slice_test.go | 29 + utils/string.go | 24 + 65 files changed, 10388 insertions(+) create mode 100644 .gitignore create mode 100644 .travis.yml create mode 100644 LICENSE create mode 100644 README.md create mode 100644 bencode/AUTHORS create mode 100644 bencode/LICENSE create mode 100644 bencode/decode.go create mode 100644 bencode/decode_test.go create mode 100644 bencode/doc.go create mode 100644 bencode/encode.go create mode 100644 bencode/encode_decode_test.go create mode 100644 bencode/encode_test.go create mode 100644 bencode/example_test.go create mode 100644 bencode/raw.go create mode 100644 bencode/tag.go create mode 100644 dht/blacklist.go create mode 100644 dht/blacklist_test.go create mode 100644 dht/dht_server.go create mode 100644 dht/dht_server_test.go create mode 100644 dht/peer_manager.go create mode 100644 dht/routing_table.go create mode 100644 dht/routing_table_storage.go create mode 100644 dht/token_manager.go create mode 100644 dht/transaction_manager.go create mode 100644 downloader/torrent_info.go create mode 100644 go.mod create mode 100644 krpc/doc.go create mode 100644 krpc/message.go create mode 100644 krpc/message_test.go create mode 100644 krpc/node.go create mode 100644 metainfo/address.go create mode 100644 metainfo/address_test.go create mode 100644 metainfo/doc.go create mode 100644 metainfo/file.go create mode 100644 metainfo/info.go create mode 100644 metainfo/infohash.go create mode 100644 metainfo/magnet.go create mode 100644 metainfo/metainfo.go create mode 100644 metainfo/piece.go create mode 100644 peerprotocol/doc.go create mode 100644 peerprotocol/extension.go create mode 100644 peerprotocol/extension_test.go create mode 100644 peerprotocol/fastset.go create mode 100644 peerprotocol/fastset_test.go create mode 100644 peerprotocol/handshake.go create mode 100644 peerprotocol/message.go create mode 100644 peerprotocol/peerconn.go create mode 100644 peerprotocol/protocol.go create mode 100644 peerprotocol/server.go create mode 100644 tracker/httptracker/http.go create mode 100644 tracker/httptracker/http_peer.go create mode 100644 tracker/httptracker/http_peer_test.go create mode 100644 tracker/httptracker/http_test.go create mode 100644 tracker/tracker.go create mode 100644 tracker/tracker_test.go create mode 100644 tracker/udptracker/udp.go create mode 100644 tracker/udptracker/udp_client.go create mode 100644 tracker/udptracker/udp_server.go create mode 100644 tracker/udptracker/udp_test.go create mode 100644 utils/bool.go create mode 100644 utils/doc.go create mode 100644 utils/io.go create mode 100644 utils/slice.go create mode 100644 utils/slice_test.go create mode 100644 utils/string.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..afef657 --- /dev/null +++ b/.gitignore @@ -0,0 +1,40 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.test +*.prof + +# Intellij IDE +.idea/ +*.iml + +# test +test_* + +# log +*.log + +# Mac +.DS_Store + +# VS Code +.vscode/ diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..fdbd0e5 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,10 @@ +language: go +go: + - 1.11.x + - 1.12.x + - 1.13.x + - 1.14.x +env: + - GO111MODULE=on +script: + - go test -race ./... diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..8f71f43 --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + diff --git a/README.md b/README.md new file mode 100644 index 0000000..a76bf42 --- /dev/null +++ b/README.md @@ -0,0 +1,41 @@ +# BT - Another Implementation Based On Golang [![Build Status](https://travis-ci.org/xgfone/bt.svg?branch=master)](https://travis-ci.org/xgfone/bt) [![GoDoc](https://godoc.org/github.com/xgfone/bt?status.svg)](https://pkg.go.dev/github.com/xgfone/bt) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg?style=flat-square)](https://raw.githubusercontent.com/xgfone/bt/master/LICENSE) + +A pure golang implementation of [BitTorrent](http://bittorrent.org/beps/bep_0000.html) library, which is inspired by [dht](https://github.com/shiyanhui/dht) and [torrent](https://github.com/anacrolix/torrent). + + +## Features + +- Support IPv4/IPv6. +- Pure Go implementation. +- Multi-BEPs implementation. [See below](#the-implemented-specifications) +- Only library without any denpendencies. For the command tools, see [bttools](https://github.com/xgfone/bttools). + + +## Install +```shell +$ go get github.com/xgfone/bt +``` + + +## Example +See [godoc](https://pkg.go.dev/github.com/xgfone/bt) or [bttools](https://github.com/xgfone/bttools). + + +## The Implemented Specifications +- [x] [**BEP 03:** The BitTorrent Protocol Specification](http://bittorrent.org/beps/bep_0003.html) +- [x] [**BEP 05:** DHT Protocol](http://bittorrent.org/beps/bep_0005.html) +- [x] [**BEP 06:** Fast Extension](http://bittorrent.org/beps/bep_0006.html) +- [x] [**BEP 07:** IPv6 Tracker Extension](http://bittorrent.org/beps/bep_0007.html) +- [x] [**BEP 09:** Extension for Peers to Send Metadata Files](http://bittorrent.org/beps/bep_0009.html) +- [x] [**BEP 10:** Extension Protocol](http://bittorrent.org/beps/bep_0010.html) +- [ ] [**BEP 11:** Peer Exchange (PEX)](http://bittorrent.org/beps/bep_0011.html) +- [x] [**BEP 12:** Multitracker Metadata Extension](http://bittorrent.org/beps/bep_0012.html) +- [x] [**BEP 15:** UDP Tracker Protocol for BitTorrent](http://bittorrent.org/beps/bep_0015.html) +- [x] [**BEP 19:** WebSeed - HTTP/FTP Seeding (GetRight style)](http://bittorrent.org/beps/bep_0019.html) (Only `url-list` in metainfo) +- [x] [**BEP 23:** Tracker Returns Compact Peer Lists](http://bittorrent.org/beps/bep_0023.html) +- [x] [**BEP 32:** IPv6 extension for DHT](http://bittorrent.org/beps/bep_0032.html) +- [ ] [**BEP 33:** DHT scrape](http://bittorrent.org/beps/bep_0033.html) +- [x] [**BEP 41:** UDP Tracker Protocol Extensions](http://bittorrent.org/beps/bep_0041.html) +- [x] [**BEP 43:** Read-only DHT Nodes](http://bittorrent.org/beps/bep_0043.html) +- [ ] [**BEP 44:** Storing arbitrary data in the DHT](http://bittorrent.org/beps/bep_0044.html) +- [x] [**BEP 48:** Tracker Protocol Extension: Scrape](http://bittorrent.org/beps/bep_0048.html) diff --git a/bencode/AUTHORS b/bencode/AUTHORS new file mode 100644 index 0000000..ff4a705 --- /dev/null +++ b/bencode/AUTHORS @@ -0,0 +1,8 @@ +Jeff Wendling +Liam Edwards-Playne +Casey Bodley +Conrad Pankoff +Cenk Alti +Jan Winkelmann +Patrick Mézard +Glen De Cauwsemaecker diff --git a/bencode/LICENSE b/bencode/LICENSE new file mode 100644 index 0000000..67a8689 --- /dev/null +++ b/bencode/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2013 The Authors (see AUTHORS) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/bencode/decode.go b/bencode/decode.go new file mode 100644 index 0000000..168e867 --- /dev/null +++ b/bencode/decode.go @@ -0,0 +1,614 @@ +package bencode + +import ( + "bufio" + "bytes" + "encoding" + "errors" + "fmt" + "io" + "reflect" + "strconv" + "strings" +) + +var ( + reflectByteSliceType = reflect.TypeOf([]byte(nil)) + reflectStringType = reflect.TypeOf("") +) + +// Unmarshaler is the interface implemented by types that can unmarshal +// a bencode description of themselves. +// The input can be assumed to be a valid encoding of a bencode value. +// UnmarshalBencode must copy the bencode data if it wishes to retain the data after returning. +type Unmarshaler interface { + UnmarshalBencode([]byte) error +} + +// A Decoder reads and decodes bencoded data from an input stream. +type Decoder struct { + r *bufio.Reader + raw bool + buf []byte + n int + failUnordered bool +} + +// SetFailOnUnorderedKeys will cause the decoder to fail when encountering +// unordered keys. The default is to not fail. +func (d *Decoder) SetFailOnUnorderedKeys(fail bool) { + d.failUnordered = fail +} + +// BytesParsed returns the number of bytes that have actually been parsed +func (d *Decoder) BytesParsed() int { + return d.n +} + +// read also writes into the buffer when d.raw is set. +func (d *Decoder) read(p []byte) (n int, err error) { + n, err = d.r.Read(p) + if d.raw { + d.buf = append(d.buf, p[:n]...) + } + d.n += n + return +} + +// readBytes also writes into the buffer when d.raw is set. +func (d *Decoder) readBytes(delim byte) (line []byte, err error) { + line, err = d.r.ReadBytes(delim) + if d.raw { + d.buf = append(d.buf, line...) + } + d.n += len(line) + return +} + +// readByte also writes into the buffer when d.raw is set. +func (d *Decoder) readByte() (b byte, err error) { + b, err = d.r.ReadByte() + if d.raw { + d.buf = append(d.buf, b) + } + d.n++ + return +} + +// readFull also writes into the buffer when d.raw is set. +func (d *Decoder) readFull(p []byte) (n int, err error) { + n, err = io.ReadFull(d.r, p) + if d.raw { + d.buf = append(d.buf, p[:n]...) + } + d.n += n + return +} + +func (d *Decoder) peekByte() (b byte, err error) { + ch, err := d.r.Peek(1) + if err != nil { + return + } + b = ch[0] + return +} + +// NewDecoder returns a new decoder that reads from r +func NewDecoder(r io.Reader) *Decoder { + return &Decoder{r: bufio.NewReader(r)} +} + +// Decode reads the bencoded value from its input and stores it in the value pointed to by val. +// Decode allocates maps/slices as necessary with the following additional rules: +// To decode a bencoded value into a nil interface value, the type stored in the interface value is one of: +// int64 for bencoded integers +// string for bencoded strings +// []interface{} for bencoded lists +// map[string]interface{} for bencoded dicts +// To unmarshal bencode into a value implementing the Unmarshaler interface, +// Unmarshal calls that value's UnmarshalBencode method. +// Otherwise, if the value implements encoding.TextUnmarshaler +// and the input is a bencode string, Unmarshal calls that value's +// UnmarshalText method with the decoded form of the string. +func (d *Decoder) Decode(val interface{}) error { + rv := reflect.ValueOf(val) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return errors.New("Unwritable type passed into decode") + } + + return d.decodeInto(rv) +} + +// DecodeString reads the data in the string and stores it into the value pointed to by val. +// Read the docs for Decode for more information. +func DecodeString(in string, val interface{}) error { + buf := strings.NewReader(in) + d := NewDecoder(buf) + return d.Decode(val) +} + +// DecodeBytes reads the data in b and stores it into the value pointed to by val. +// Read the docs for Decode for more information. +func DecodeBytes(b []byte, val interface{}) error { + r := bytes.NewReader(b) + d := NewDecoder(r) + return d.Decode(val) +} + +func indirect(v reflect.Value, alloc bool) reflect.Value { + for { + switch v.Kind() { + case reflect.Interface: + if v.IsNil() { + if !alloc { + return reflect.Value{} + } + return v + } + + case reflect.Ptr: + if v.IsNil() { + if !alloc { + return reflect.Value{} + } + v.Set(reflect.New(v.Type().Elem())) + } + + default: + return v + } + + v = v.Elem() + } +} + +func (d *Decoder) decodeInto(val reflect.Value) (err error) { + var v reflect.Value + if d.raw { + v = val + } else { + var ( + unmarshaler Unmarshaler + textUnmarshaler encoding.TextUnmarshaler + ) + unmarshaler, textUnmarshaler, v = d.indirect(val) + + // if we're decoding into an Unmarshaler, + // we pass on the next bencode value to this value instead, + // so it can decide what to do with it. + if unmarshaler != nil { + var x RawMessage + if err := d.decodeInto(reflect.ValueOf(&x)); err != nil { + return err + } + return unmarshaler.UnmarshalBencode([]byte(x)) + } + + // if we're decoding into an TextUnmarshaler, + // we'll assume that the bencode value is a string, + // we decode it as such and pass the result onto the unmarshaler. + if textUnmarshaler != nil { + var b []byte + ref := reflect.ValueOf(&b) + if err := d.decodeString(reflect.Indirect(ref)); err != nil { + return err + } + return textUnmarshaler.UnmarshalText(b) + } + + // if we're decoding into a RawMessage set raw to true for the rest of + // the call stack, and switch out the value with an interface{}. + if _, ok := v.Interface().(RawMessage); ok { + v = reflect.Value{} // explicitly make v invalid + + // set d.raw for the lifetime of this function call, and set the raw + // message when the function is exiting. + d.buf = d.buf[:0] + d.raw = true + defer func() { + d.raw = false + v := indirect(val, true) + v.SetBytes(append([]byte(nil), d.buf...)) + }() + } + } + + next, err := d.peekByte() + if err != nil { + return + } + + switch next { + case 'i': + err = d.decodeInt(v) + case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + err = d.decodeString(v) + case 'l': + err = d.decodeList(v) + case 'd': + err = d.decodeDict(v) + default: + err = errors.New("Invalid input") + } + + return +} + +func (d *Decoder) decodeInt(v reflect.Value) error { + // we need to read an i, some digits, and an e. + ch, err := d.readByte() + if err != nil { + return err + } + if ch != 'i' { + panic("got not an i when peek returned an i") + } + + line, err := d.readBytes('e') + if err != nil || d.raw { + return err + } + + digits := string(line[:len(line)-1]) + + switch v.Kind() { + default: + return fmt.Errorf("Cannot store int64 into %s", v.Type()) + case reflect.Interface: + n, err := strconv.ParseInt(digits, 10, 64) + if err != nil { + return err + } + v.Set(reflect.ValueOf(n)) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + n, err := strconv.ParseInt(digits, 10, 64) + if err != nil { + return err + } + v.SetInt(n) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + n, err := strconv.ParseUint(digits, 10, 64) + if err != nil { + return err + } + v.SetUint(n) + case reflect.Bool: + n, err := strconv.ParseUint(digits, 10, 64) + if err != nil { + return err + } + v.SetBool(n != 0) + } + + return nil +} + +func (d *Decoder) decodeString(v reflect.Value) error { + // read until a colon to get the number of digits to read after + line, err := d.readBytes(':') + if err != nil { + return err + } + + // parse it into an int for making a slice + l32, err := strconv.ParseInt(string(line[:len(line)-1]), 10, 32) + l := int(l32) + if err != nil { + return err + } + if l < 0 { + return fmt.Errorf("invalid negative string length: %d", l) + } + + // read exactly l bytes out and make our string + buf := make([]byte, l) + _, err = d.readFull(buf) + if err != nil || d.raw { + return err + } + + switch v.Kind() { + default: + return fmt.Errorf("Cannot store string into %s", v.Type()) + case reflect.Slice: + if v.Type() != reflectByteSliceType { + return fmt.Errorf("Cannot store string into %s", v.Type()) + } + v.SetBytes(buf) + case reflect.String: + v.SetString(string(buf)) + case reflect.Interface: + v.Set(reflect.ValueOf(string(buf))) + } + return nil +} + +func (d *Decoder) decodeList(v reflect.Value) error { + if !d.raw { + // if we have an interface, just put a []interface{} in it! + if v.Kind() == reflect.Interface { + var x []interface{} + defer func(p reflect.Value) { p.Set(v) }(v) + v = reflect.ValueOf(&x).Elem() + } + + if v.Kind() != reflect.Array && v.Kind() != reflect.Slice { + return fmt.Errorf("Cant store a []interface{} into %s", v.Type()) + } + } + + // read out the l that prefixes the list + ch, err := d.readByte() + if err != nil { + return err + } + if ch != 'l' { + panic("got something other than a list head after a peek") + } + + // if we're decoding in raw mode, + // we only want to read into the buffer, + // without actually parsing any values + if d.raw { + var ch byte + for { + // peek for the end token and read it out + ch, err = d.peekByte() + if err != nil { + return err + } + if ch == 'e' { + _, err = d.readByte() // consume the end + return err + } + + // decode the next value + err = d.decodeInto(v) + if err != nil { + return err + } + } + } + + for i := 0; ; i++ { + // peek for the end token and read it out + ch, err := d.peekByte() + if err != nil { + return err + } + switch ch { + case 'e': + _, err := d.readByte() // consume the end + return err + } + + // grow it if required + if i >= v.Cap() && v.IsValid() { + newcap := v.Cap() + v.Cap()/2 + if newcap < 4 { + newcap = 4 + } + newv := reflect.MakeSlice(v.Type(), v.Len(), newcap) + reflect.Copy(newv, v) + v.Set(newv) + } + + // reslice into cap (its a slice now since it had to have grown) + if i >= v.Len() && v.IsValid() { + v.SetLen(i + 1) + } + + // decode a value into the index + if err := d.decodeInto(v.Index(i)); err != nil { + return err + } + } +} + +func (d *Decoder) decodeDict(v reflect.Value) error { + // if we have an interface{}, just put a map[string]interface{} in it! + if !d.raw && v.Kind() == reflect.Interface { + var x map[string]interface{} + defer func(p reflect.Value) { p.Set(v) }(v) + v = reflect.ValueOf(&x).Elem() + } + + // consume the head token + ch, err := d.readByte() + if err != nil { + return err + } + if ch != 'd' { + panic("got an incorrect token when it was checked already") + } + + if d.raw { + // if we're decoding in raw mode, + // we only want to read into the buffer, + // without actually parsing any values + for { + // peek the next value type + ch, err := d.peekByte() + if err != nil { + return err + } + if ch == 'e' { + _, err = d.readByte() // consume the end token + return err + } + + err = d.decodeString(v) + if err != nil { + return err + } + + err = d.decodeInto(v) + if err != nil { + return err + } + } + } + + // check for correct type + var ( + mapElem reflect.Value + isMap bool + vals map[string]reflect.Value + ) + + switch v.Kind() { + case reflect.Map: + t := v.Type() + if t.Key() != reflectStringType { + return fmt.Errorf("Can't store a map[string]interface{} into %s", v.Type()) + } + if v.IsNil() { + v.Set(reflect.MakeMap(t)) + } + + isMap = true + mapElem = reflect.New(t.Elem()).Elem() + case reflect.Struct: + vals = make(map[string]reflect.Value) + setStructValues(vals, v) + default: + return fmt.Errorf("Can't store a map[string]interface{} into %s", v.Type()) + } + + var ( + lastKey string + first bool = true + ) + + for { + var subv reflect.Value + + // peek the next value type + ch, err := d.peekByte() + if err != nil { + return err + } + if ch == 'e' { + _, err = d.readByte() // consume the end token + return err + } + + // peek the next value we're suppsed to read + var key string + if err := d.decodeString(reflect.ValueOf(&key).Elem()); err != nil { + return err + } + + // check for unordered keys + if !first && d.failUnordered && lastKey > key { + return fmt.Errorf("unordered dictionary: %q appears before %q", + lastKey, key) + } + lastKey, first = key, false + + if isMap { + mapElem.Set(reflect.Zero(v.Type().Elem())) + subv = mapElem + } else { + subv = vals[key] + } + + if !subv.IsValid() { + // if it's invalid, grab but ignore the next value + var x interface{} + err := d.decodeInto(reflect.ValueOf(&x).Elem()) + if err != nil { + return err + } + + continue + } + + // subv now contains what we load into + if err := d.decodeInto(subv); err != nil { + return err + } + + if isMap { + v.SetMapIndex(reflect.ValueOf(key), subv) + } + } +} + +// indirect walks down v allocating pointers as needed, +// until it gets to a non-pointer. +// if it encounters an (Text)Unmarshaler, indirect stops and returns that. +func (d *Decoder) indirect(v reflect.Value) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) { + // If v is a named type and is addressable, + // start with its address, so that if the type has pointer methods, + // we find them. + if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() { + v = v.Addr() + } + for { + // Load value from interface, but only if the result will be + // usefully addressable. + if v.Kind() == reflect.Interface && !v.IsNil() { + e := v.Elem() + if e.Kind() == reflect.Ptr && !e.IsNil() { + v = e + continue + } + } + + if v.Kind() != reflect.Ptr || v.IsNil() { + break + } + + vi := v.Interface() + if u, ok := vi.(Unmarshaler); ok { + return u, nil, reflect.Value{} + } + if u, ok := vi.(encoding.TextUnmarshaler); ok { + return nil, u, reflect.Value{} + } + + v = v.Elem() + } + return nil, nil, indirect(v, true) +} + +func setStructValues(m map[string]reflect.Value, v reflect.Value) { + t := v.Type() + if t.Kind() != reflect.Struct { + return + } + + // do embedded fields first + for i := 0; i < v.NumField(); i++ { + f := t.Field(i) + if f.PkgPath != "" { + continue + } + v := v.FieldByIndex(f.Index) + if f.Anonymous && f.Tag == "" { + setStructValues(m, v) + } + } + + // overwrite embedded struct tags and names + for i := 0; i < v.NumField(); i++ { + f := t.Field(i) + if f.PkgPath != "" { + continue + } + v := v.FieldByIndex(f.Index) + name, _ := parseTag(f.Tag.Get("bencode")) + if name == "" { + if f.Anonymous { + // it's a struct and its fields have already been added to the map + continue + } + name = f.Name + } + if isValidTag(name) { + m[name] = v + } + } +} diff --git a/bencode/decode_test.go b/bencode/decode_test.go new file mode 100644 index 0000000..18a0a40 --- /dev/null +++ b/bencode/decode_test.go @@ -0,0 +1,400 @@ +package bencode + +import ( + "bytes" + "fmt" + "reflect" + "sort" + "strings" + "testing" + "time" +) + +func TestDecode(t *testing.T) { + type testCase struct { + in string + val interface{} + expect interface{} + err bool + unorderedFail bool + } + + type dT struct { + X string + Y int + Z string `bencode:"zff"` + } + + type Embedded struct { + B string + } + + type issue22 struct { + X string `bencode:"x"` + Time myTimeType `bencode:"t"` + Foo myBoolType `bencode:"f"` + Bar myStringType `bencode:"b"` + Slice mySliceType `bencode:"s"` + Y string `bencode:"y"` + } + + type issue26 struct { + X string `bencode:"x"` + Foo myBoolTextType `bencode:"f"` + Bar myTextStringType `bencode:"b"` + Slice myTextSliceType `bencode:"s"` + Y string `bencode:"y"` + } + + type issue22WithErrorChild struct { + Name string `bencode:"n"` + Error errorMarshalType `bencode:"e"` + } + + type issue26WithErrorChild struct { + Name string `bencode:"n"` + Error errorTextMarshalType `bencode:"e"` + } + + type discardNonFieldDef struct { + B string + D string + } + + type twoDefsForSameKey struct { + A string + A2 string `bencode:"A"` + A3 string `bencode:"A"` + } + + now := time.Now() + + var decodeCases = []testCase{ + // integers + {`i5e`, new(int), int(5), false, false}, + {`i-10e`, new(int), int(-10), false, false}, + {`i8e`, new(uint), uint(8), false, false}, + {`i8e`, new(uint8), uint8(8), false, false}, + {`i8e`, new(uint16), uint16(8), false, false}, + {`i8e`, new(uint32), uint32(8), false, false}, + {`i8e`, new(uint64), uint64(8), false, false}, + {`i8e`, new(int), int(8), false, false}, + {`i8e`, new(int8), int8(8), false, false}, + {`i8e`, new(int16), int16(8), false, false}, + {`i8e`, new(int32), int32(8), false, false}, + {`i8e`, new(int64), int64(8), false, false}, + {`i0e`, new(*int), new(int), false, false}, + {`i-2e`, new(uint), nil, true, false}, + + // bools + {`i1e`, new(bool), true, false, false}, + {`i0e`, new(bool), false, false, false}, + {`i0e`, new(*bool), new(bool), false, false}, + {`i8e`, new(bool), true, false, false}, + + // strings + {`3:foo`, new(string), "foo", false, false}, + {`4:foob`, new(string), "foob", false, false}, + {`0:`, new(*string), new(string), false, false}, + {`6:short`, new(string), nil, true, false}, + + // lists + {`l3:foo3:bare`, new([]string), []string{"foo", "bar"}, false, false}, + {`li15ei20ee`, new([]int), []int{15, 20}, false, false}, + {`ld3:fooi0eed3:bari1eee`, new([]map[string]int), []map[string]int{ + {"foo": 0}, + {"bar": 1}, + }, false, false}, + + // dicts + + {`d3:foo3:bar4:foob3:fooe`, new(map[string]string), map[string]string{ + "foo": "bar", + "foob": "foo", + }, false, false}, + {`d1:X3:foo1:Yi10e3:zff3:bare`, new(dT), dT{"foo", 10, "bar"}, false, false}, + + // encoding/json takes, if set, the tag as name and doesn't falls back to the + // struct field's name. + {`d1:X3:foo1:Yi10e1:Z3:bare`, new(dT), dT{"foo", 10, ""}, false, false}, + + {`d1:X3:foo1:Yi10e1:h3:bare`, new(dT), dT{"foo", 10, ""}, false, false}, + {`d3:fooli0ei1ee3:barli2ei3eee`, new(map[string][]int), map[string][]int{ + "foo": []int{0, 1}, + "bar": []int{2, 3}, + }, false, false}, + {`de`, new(map[string]string), map[string]string{}, false, false}, + + // into interfaces + {`i5e`, new(interface{}), int64(5), false, false}, + {`li5ee`, new(interface{}), []interface{}{int64(5)}, false, false}, + {`5:hello`, new(interface{}), "hello", false, false}, + {`d5:helloi5ee`, new(interface{}), map[string]interface{}{"hello": int64(5)}, false, false}, + + // into values whose type support the Unmarshaler interface + {`1:y`, new(myTimeType), nil, true, false}, + {fmt.Sprintf("i%de", now.Unix()), new(myTimeType), myTimeType{time.Unix(now.Unix(), 0)}, false, false}, + {`1:y`, new(myBoolType), myBoolType(true), false, false}, + {`i42e`, new(myBoolType), nil, true, false}, + {`1:n`, new(myBoolType), myBoolType(false), false, false}, + {`1:n`, new(errorMarshalType), nil, true, false}, + {`li102ei111ei111ee`, new(myStringType), myStringType("foo"), false, false}, + {`i42e`, new(myStringType), nil, true, false}, + {`d1:ai1e3:foo3:bare`, new(mySliceType), mySliceType{"a", int64(1), "foo", "bar"}, false, false}, + {`i42e`, new(mySliceType), nil, true, false}, + + // into values who have a child which type supports the Unmarshaler interface + { + fmt.Sprintf(`d1:b3:foo1:f1:y1:sd1:f3:foo1:ai42ee1:ti%de1:x1:x1:y1:ye`, now.Unix()), + new(issue22), + issue22{ + X: "x", + Time: myTimeType{time.Unix(now.Unix(), 0)}, + Foo: myBoolType(true), + Bar: myStringType("foo"), + Slice: mySliceType{"a", int64(42), "f", "foo"}, + Y: "y", + }, + false, + false, + }, + { + `d1:ei42e1:n3:fooe`, + new(issue22WithErrorChild), + nil, + true, + false, + }, + + // into values whose type support the TextUnmarshaler interface + {`1:y`, new(myBoolTextType), myBoolTextType(true), false, false}, + {`1:n`, new(myBoolTextType), myBoolTextType(false), false, false}, + {`i42e`, new(myBoolTextType), nil, true, false}, + {`1:n`, new(errorTextMarshalType), nil, true, false}, + {`7:foo_bar`, new(myTextStringType), myTextStringType("bar"), false, false}, + {`i42e`, new(myTextStringType), nil, true, false}, + {`7:a,b,c,d`, new(myTextSliceType), myTextSliceType{"a", "b", "c", "d"}, false, false}, + {`i42e`, new(myTextSliceType), nil, true, false}, + + // into values who have a child which type supports the TextUnmarshaler interface + { + `d1:b7:foo_bar1:f1:y1:s5:1,2,31:x1:x1:y1:ye`, + new(issue26), + issue26{ + X: "x", + Foo: myBoolTextType(true), + Bar: myTextStringType("bar"), + Slice: myTextSliceType{"1", "2", "3"}, + Y: "y", + }, + false, + false, + }, + { + `d1:ei42e1:n3:fooe`, + new(issue26WithErrorChild), + nil, + true, + false, + }, + + // malformed + {`i53:foo`, new(interface{}), nil, true, false}, + {`6:foo`, new(interface{}), nil, true, false}, + {`di5ei2ee`, new(interface{}), nil, true, false}, + {`d3:fooe`, new(interface{}), nil, true, false}, + {`l3:foo3:bar`, new(interface{}), nil, true, false}, + {`d-1:`, new(interface{}), nil, true, false}, + + // embedded structs + {`d1:A3:foo1:B3:bare`, new(struct { + A string + Embedded + }), struct { + A string + Embedded + }{"foo", Embedded{"bar"}}, false, false}, + + // Embedded structs with a valid tag are encoded as a definition + {`d1:B3:bar6:nestedd1:B3:fooee`, new(struct { + Embedded `bencode:"nested"` + }), struct { + Embedded `bencode:"nested"` + }{Embedded{"foo"}}, false, false}, + + // Don't fail when reading keys missing from the struct + {"d1:A7:discard1:B4:take1:C7:discard1:D4:takee", + new(discardNonFieldDef), + discardNonFieldDef{"take", "take"}, + false, + false, + }, + + // Don't fail when reading the same key twice + {"d1:A1:a1:A1:b1:A1:c1:A1:de", new(twoDefsForSameKey), + twoDefsForSameKey{"", "", "d"}, false, false}, + + // Empty struct + {"de", new(struct{}), struct{}{}, false, false}, + + // Fail on unordered dictionaries + {"d1:Yi10e1:X1:a3:zff1:ce", new(dT), dT{}, true, true}, + {"d3:zff1:c1:Yi10e1:X1:ae", new(dT), dT{}, true, true}, + } + + for i, tt := range decodeCases { + dec := NewDecoder(strings.NewReader(tt.in)) + dec.SetFailOnUnorderedKeys(tt.unorderedFail) + err := dec.Decode(tt.val) + if !tt.err && err != nil { + t.Errorf("#%d (%v): Unexpected err: %v", i, tt.in, err) + continue + } + if tt.err && err == nil { + t.Errorf("#%d (%v): Expected err is nil", i, tt.in) + continue + } + v := reflect.ValueOf(tt.val).Elem().Interface() + if !reflect.DeepEqual(v, tt.expect) && !tt.err { + t.Errorf("#%d (%v): Val: %#v != %#v", i, tt.in, v, tt.expect) + } + } +} + +func TestRawDecode(t *testing.T) { + type testCase struct { + in string + expect []byte + err bool + } + + var rawDecodeCases = []testCase{ + {`i5e`, []byte(`i5e`), false}, + {`5:hello`, []byte(`5:hello`), false}, + {`li5ei10e5:helloe`, []byte(`li5ei10e5:helloe`), false}, + {`llleee`, []byte(`llleee`), false}, + {`li5eli5eli5eeee`, []byte(`li5eli5eli5eeee`), false}, + {`d5:helloi5ee`, []byte(`d5:helloi5ee`), false}, + } + + for i, tt := range rawDecodeCases { + var x RawMessage + err := DecodeString(tt.in, &x) + if !tt.err && err != nil { + t.Errorf("#%d: Unexpected err: %v", i, err) + continue + } + if tt.err && err == nil { + t.Errorf("#%d: Expected err is nil", i) + continue + } + if !reflect.DeepEqual(x, RawMessage(tt.expect)) && !tt.err { + t.Errorf("#%d: Val: %#v != %#v", i, x, tt.expect) + } + } +} + +type myStringType string + +// UnmarshalBencode implements Unmarshaler.UnmarshalBencode +func (mst *myStringType) UnmarshalBencode(b []byte) error { + var raw []byte + err := DecodeBytes(b, &raw) + if err != nil { + return err + } + + *mst = myStringType(raw) + return nil +} + +type mySliceType []interface{} + +// UnmarshalBencode implements Unmarshaler.UnmarshalBencode +func (mst *mySliceType) UnmarshalBencode(b []byte) error { + m := make(map[string]interface{}) + err := DecodeBytes(b, &m) + if err != nil { + return err + } + + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + + raw := make([]interface{}, 0, len(m)*2) + for _, key := range keys { + raw = append(raw, key, m[key]) + } + + *mst = mySliceType(raw) + return nil +} + +type myTextStringType string + +// UnmarshalText implements TextUnmarshaler.UnmarshalText +func (mst *myTextStringType) UnmarshalText(b []byte) error { + *mst = myTextStringType(bytes.TrimPrefix(b, []byte("foo_"))) + return nil +} + +type myTextSliceType []string + +// UnmarshalText implements TextUnmarshaler.UnmarshalText +func (mst *myTextSliceType) UnmarshalText(b []byte) error { + raw := string(b) + *mst = strings.Split(raw, ",") + return nil +} + +func TestNestedRawDecode(t *testing.T) { + type testCase struct { + in string + val interface{} + expect interface{} + err bool + } + + type message struct { + Key string + Val int + Raw RawMessage + } + + var cases = []testCase{ + {`li5e5:hellod1:a1:beli5eee`, new([]RawMessage), []RawMessage{ + RawMessage(`i5e`), + RawMessage(`5:hello`), + RawMessage(`d1:a1:be`), + RawMessage(`li5ee`), + }, false}, + {`d1:a1:b1:c1:de`, new(map[string]RawMessage), map[string]RawMessage{ + "a": RawMessage(`1:b`), + "c": RawMessage(`1:d`), + }, false}, + {`d3:Key5:hello3:Rawldedei5e1:ae3:Vali10ee`, new(message), message{ + Key: "hello", + Val: 10, + Raw: RawMessage(`ldedei5e1:ae`), + }, false}, + } + + for i, tt := range cases { + err := DecodeString(tt.in, tt.val) + if !tt.err && err != nil { + t.Errorf("#%d: Unexpected err: %v", i, err) + continue + } + if tt.err && err == nil { + t.Errorf("#%d: Expected err is nil", i) + continue + } + v := reflect.ValueOf(tt.val).Elem().Interface() + if !reflect.DeepEqual(v, tt.expect) && !tt.err { + t.Errorf("#%d: Val:\n%#v !=\n%#v", i, v, tt.expect) + } + } +} diff --git a/bencode/doc.go b/bencode/doc.go new file mode 100644 index 0000000..fdc39ac --- /dev/null +++ b/bencode/doc.go @@ -0,0 +1,6 @@ +// Package bencode implements encoding and decoding of bencoded objects, +// which has a similar API to the encoding/json package and many other +// serialization formats. +// +// Notice: the package is moved and modified from github.com/zeebo/bencode@v1.0.0 +package bencode diff --git a/bencode/encode.go b/bencode/encode.go new file mode 100644 index 0000000..042a99a --- /dev/null +++ b/bencode/encode.go @@ -0,0 +1,322 @@ +package bencode + +import ( + "bytes" + "encoding" + "fmt" + "io" + "reflect" + "sort" +) + +type sortValues []reflect.Value + +func (p sortValues) Len() int { return len(p) } +func (p sortValues) Less(i, j int) bool { return p[i].String() < p[j].String() } +func (p sortValues) Swap(i, j int) { p[i], p[j] = p[j], p[i] } + +// Marshaler is the interface implemented by types +// that can marshal themselves into valid bencode. +type Marshaler interface { + MarshalBencode() ([]byte, error) +} + +// An Encoder writes bencoded objects to an output stream. +type Encoder struct { + w io.Writer +} + +// NewEncoder returns a new encoder that writes to w. +func NewEncoder(w io.Writer) *Encoder { + return &Encoder{w} +} + +// Encode writes the bencoded data of val to its output stream. +// If an encountered value implements the Marshaler interface, +// its MarshalBencode method is called to produce the bencode output for this value. +// If no MarshalBencode method is present but the value implements encoding.TextMarshaler instead, +// its MarshalText method is called, which encodes the result as a bencode string. +// See the documentation for Decode about the conversion of Go values to +// bencoded data. +func (e *Encoder) Encode(val interface{}) error { + return encodeValue(e.w, reflect.ValueOf(val)) +} + +// EncodeString returns the bencoded data of val as a string. +func EncodeString(val interface{}) (string, error) { + buf := new(bytes.Buffer) + e := NewEncoder(buf) + if err := e.Encode(val); err != nil { + return "", err + } + return buf.String(), nil +} + +// EncodeBytes returns the bencoded data of val as a slice of bytes. +func EncodeBytes(val interface{}) ([]byte, error) { + buf := new(bytes.Buffer) + e := NewEncoder(buf) + if err := e.Encode(val); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func isNilValue(v reflect.Value) bool { + return (v.Kind() == reflect.Interface || v.Kind() == reflect.Ptr) && + v.IsNil() +} + +func encodeValue(w io.Writer, val reflect.Value) error { + marshaler, textMarshaler, v := indirectEncodeValue(val) + + // marshal a type using the Marshaler type + // if it implements that interface. + if marshaler != nil { + bytes, err := marshaler.MarshalBencode() + if err != nil { + return err + } + + _, err = w.Write(bytes) + return err + } + + // marshal a type using the TextMarshaler type + // if it implements that interface. + if textMarshaler != nil { + bytes, err := textMarshaler.MarshalText() + if err != nil { + return err + } + + _, err = fmt.Fprintf(w, "%d:%s", len(bytes), bytes) + return err + } + + // if indirection returns us an invalid value that means there was a nil + // pointer in the path somewhere. + if !v.IsValid() { + return nil + } + + // send in a raw message if we have that type + if rm, ok := v.Interface().(RawMessage); ok { + _, err := io.Copy(w, bytes.NewReader(rm)) + return err + } + + switch v.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + _, err := fmt.Fprintf(w, "i%de", v.Int()) + return err + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + _, err := fmt.Fprintf(w, "i%de", v.Uint()) + return err + + case reflect.Bool: + i := 0 + if v.Bool() { + i = 1 + } + _, err := fmt.Fprintf(w, "i%de", i) + return err + + case reflect.String: + _, err := fmt.Fprintf(w, "%d:%s", len(v.String()), v.String()) + return err + + case reflect.Slice, reflect.Array: + // handle byte slices like strings + if byteSlice, ok := val.Interface().([]byte); ok { + _, err := fmt.Fprintf(w, "%d:", len(byteSlice)) + + if err == nil { + _, err = w.Write(byteSlice) + } + + return err + } + + if _, err := fmt.Fprint(w, "l"); err != nil { + return err + } + + for i := 0; i < v.Len(); i++ { + if err := encodeValue(w, v.Index(i)); err != nil { + return err + } + } + + _, err := fmt.Fprint(w, "e") + return err + + case reflect.Map: + if _, err := fmt.Fprint(w, "d"); err != nil { + return err + } + var ( + keys sortValues = v.MapKeys() + mval reflect.Value + ) + sort.Sort(keys) + for i := range keys { + mval = v.MapIndex(keys[i]) + if isNilValue(mval) { + continue + } + if err := encodeValue(w, keys[i]); err != nil { + return err + } + if err := encodeValue(w, mval); err != nil { + return err + } + } + _, err := fmt.Fprint(w, "e") + return err + + case reflect.Struct: + if _, err := fmt.Fprint(w, "d"); err != nil { + return err + } + + // add embedded structs to the dictionary + dict := make(dictionary, 0, v.NumField()) + dict, err := readStruct(dict, v) + if err != nil { + return err + } + + // sort the dictionary by keys + sort.Sort(dict) + + // encode the dictionary in order + for _, def := range dict { + // encode the key + err := encodeValue(w, reflect.ValueOf(def.key)) + if err != nil { + return err + } + + // encode the value + err = encodeValue(w, def.value) + if err != nil { + return err + } + } + + _, err = fmt.Fprint(w, "e") + return err + } + + return fmt.Errorf("Can't encode type: %s", v.Type()) +} + +// indirectEncodeValue walks down v allocating pointers as needed, +// until it gets to a non-pointer. +// if it encounters an (Text)Marshaler, indirect stops and returns that. +func indirectEncodeValue(v reflect.Value) (Marshaler, encoding.TextMarshaler, reflect.Value) { + // If v is a named type and is addressable, + // start with its address, so that if the type has pointer methods, + // we find them. + if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() { + v = v.Addr() + } + for { + if v.Kind() == reflect.Ptr && v.IsNil() { + break + } + + vi := v.Interface() + if m, ok := vi.(Marshaler); ok { + return m, nil, reflect.Value{} + } + if m, ok := vi.(encoding.TextMarshaler); ok { + return nil, m, reflect.Value{} + } + + if v.Kind() != reflect.Ptr { + break + } + + v = v.Elem() + } + return nil, nil, indirect(v, false) +} + +type definition struct { + key string + value reflect.Value +} + +type dictionary []definition + +func (d dictionary) Len() int { return len(d) } +func (d dictionary) Less(i, j int) bool { return d[i].key < d[j].key } +func (d dictionary) Swap(i, j int) { d[i], d[j] = d[j], d[i] } + +func readStruct(dict dictionary, v reflect.Value) (dictionary, error) { + t := v.Type() + var ( + fieldValue reflect.Value + rkey string + ) + for i := 0; i < t.NumField(); i++ { + key := t.Field(i) + rkey = key.Name + fieldValue = v.FieldByIndex(key.Index) + + // filter out unexported values etc. + if !fieldValue.CanInterface() { + continue + } + + // filter out nil pointer values + if isNilValue(fieldValue) { + continue + } + + // * Near identical to usage in JSON except with key 'bencode' + // + // * Struct values encode as BEncode dictionaries. Each exported + // struct field becomes a set in the dictionary unless + // - the field's tag is "-", or + // - the field is empty and its tag specifies the "omitempty" + // option. + // + // * The default key string is the struct field name but can be + // specified in the struct field's tag value. The "bencode" + // key in struct field's tag value is the key name, followed + // by an optional comma and options. + tagValue := key.Tag.Get("bencode") + if tagValue != "" { + // Keys with '-' are omit from output + if tagValue == "-" { + continue + } + + name, options := parseTag(tagValue) + // Keys with 'omitempty' are omitted if the field is empty + if options.Contains("omitempty") && isEmptyValue(fieldValue) { + continue + } + + // All other values are treated as the key string + if isValidTag(name) { + rkey = name + } + } + + if key.Anonymous && key.Type.Kind() == reflect.Struct && tagValue == "" { + var err error + dict, err = readStruct(dict, fieldValue) + if err != nil { + return nil, err + } + } else { + dict = append(dict, definition{rkey, fieldValue}) + } + } + return dict, nil +} diff --git a/bencode/encode_decode_test.go b/bencode/encode_decode_test.go new file mode 100644 index 0000000..48d5694 --- /dev/null +++ b/bencode/encode_decode_test.go @@ -0,0 +1,109 @@ +package bencode + +import ( + "errors" + "time" +) + +type myBoolType bool + +// MarshalBencode implements Marshaler.MarshalBencode +func (mbt myBoolType) MarshalBencode() ([]byte, error) { + var c string + if mbt { + c = "y" + } else { + c = "n" + } + + return EncodeBytes(c) +} + +// UnmarshalBencode implements Unmarshaler.UnmarshalBencode +func (mbt *myBoolType) UnmarshalBencode(b []byte) error { + var str string + err := DecodeBytes(b, &str) + if err != nil { + return err + } + + switch str { + case "y": + *mbt = true + case "n": + *mbt = false + default: + err = errors.New("invalid myBoolType") + } + + return err +} + +type myBoolTextType bool + +// MarshalText implements TextMarshaler.MarshalText +func (mbt myBoolTextType) MarshalText() ([]byte, error) { + if mbt { + return []byte("y"), nil + } + + return []byte("n"), nil +} + +// UnmarshalText implements TextUnmarshaler.UnmarshalText +func (mbt *myBoolTextType) UnmarshalText(b []byte) error { + switch string(b) { + case "y": + *mbt = true + case "n": + *mbt = false + default: + return errors.New("invalid myBoolType") + } + return nil +} + +type myTimeType struct { + time.Time +} + +// MarshalBencode implements Marshaler.MarshalBencode +func (mtt myTimeType) MarshalBencode() ([]byte, error) { + return EncodeBytes(mtt.Time.Unix()) +} + +// UnmarshalBencode implements Unmarshaler.UnmarshalBencode +func (mtt *myTimeType) UnmarshalBencode(b []byte) error { + var epoch int64 + err := DecodeBytes(b, &epoch) + if err != nil { + return err + } + + mtt.Time = time.Unix(epoch, 0) + return nil +} + +type errorMarshalType struct{} + +// MarshalBencode implements Marshaler.MarshalBencode +func (emt errorMarshalType) MarshalBencode() ([]byte, error) { + return nil, errors.New("oops") +} + +// UnmarshalBencode implements Unmarshaler.UnmarshalBencode +func (emt errorMarshalType) UnmarshalBencode([]byte) error { + return errors.New("oops") +} + +type errorTextMarshalType struct{} + +// MarshalText implements TextMarshaler.MarshalText +func (emt errorTextMarshalType) MarshalText() ([]byte, error) { + return nil, errors.New("oops") +} + +// UnmarshalText implements TextUnmarshaler.UnmarshalText +func (emt errorTextMarshalType) UnmarshalText([]byte) error { + return errors.New("oops") +} diff --git a/bencode/encode_test.go b/bencode/encode_test.go new file mode 100644 index 0000000..22327bb --- /dev/null +++ b/bencode/encode_test.go @@ -0,0 +1,354 @@ +package bencode + +import ( + "errors" + "fmt" + "testing" + "time" +) + +func TestEncode(t *testing.T) { + type encodeTestCase struct { + in interface{} + out string + err bool + } + + type eT struct { + A string + X string `bencode:"D"` + Y string `bencode:"B"` + Z string `bencode:"C"` + } + + type sortProblem struct { + A string + B string `bencode:","` + } + + type issue18Sub struct { + Name string + } + + type issue18 struct { + T *issue18Sub + } + + type Embedded struct { + B string + } + + type issue22 struct { + Time myTimeType `bencode:"t"` + Foo myBoolType `bencode:"f"` + } + + type issue22WithErrorChild struct { + Name string `bencode:"n"` + Error errorMarshalType `bencode:"e"` + } + + type issue26 struct { + Answer int64 `bencode:"a"` + Foo myBoolTextType `bencode:"f"` + Name string `bencode:"n"` + } + + type issue26WithErrorChild struct { + Name string `bencode:"n"` + Error errorTextMarshalType `bencode:"e"` + } + + type issue28 struct { + X string `bencode:"x"` + Time *myTimeType `bencode:"t"` + Foo myBoolPtrType `bencode:"f"` + Y string `bencode:"y"` + } + + now := time.Now() + + var encodeCases = []encodeTestCase{ + // integers + {10, `i10e`, false}, + {-10, `i-10e`, false}, + {0, `i0e`, false}, + {int(10), `i10e`, false}, + {int8(10), `i10e`, false}, + {int16(10), `i10e`, false}, + {int32(10), `i10e`, false}, + {int64(10), `i10e`, false}, + {uint(10), `i10e`, false}, + {uint8(10), `i10e`, false}, + {uint16(10), `i10e`, false}, + {uint32(10), `i10e`, false}, + {uint64(10), `i10e`, false}, + {(*int)(nil), ``, false}, + + // ptr-to-integer + {func() *int { + i := 42 + return &i + }(), `i42e`, false}, + + // strings + {"foo", `3:foo`, false}, + {"barbb", `5:barbb`, false}, + {"", `0:`, false}, + {(*string)(nil), ``, false}, + + // ptr-to-string + {func() *string { + str := "foo" + return &str + }(), `3:foo`, false}, + + // lists + {[]interface{}{"foo", 20}, `l3:fooi20ee`, false}, + {[]interface{}{90, 20}, `li90ei20ee`, false}, + {[]interface{}{[]interface{}{"foo", "bar"}, 20}, `ll3:foo3:barei20ee`, false}, + {[]map[string]int{ + {"a": 0, "b": 1}, + {"c": 2, "d": 3}, + }, `ld1:ai0e1:bi1eed1:ci2e1:di3eee`, false}, + {[][]byte{ + []byte{'0', '2', '4', '6', '8'}, + []byte{'a', 'c', 'e'}, + }, `l5:024683:acee`, false}, + {(*[]interface{})(nil), ``, false}, + + // boolean + {true, "i1e", false}, + {false, "i0e", false}, + {(*bool)(nil), ``, false}, + + // dicts + {map[string]interface{}{ + "a": "foo", + "c": "bar", + "b": "tes", + }, `d1:a3:foo1:b3:tes1:c3:bare`, false}, + {eT{"foo", "bar", "far", "boo"}, `d1:A3:foo1:B3:far1:C3:boo1:D3:bare`, false}, + {map[string][]int{ + "a": {0, 1}, + "b": {2, 3}, + }, `d1:ali0ei1ee1:bli2ei3eee`, false}, + {struct{ A, b int }{1, 2}, "d1:Ai1ee", false}, + {(*struct{ A int })(nil), ``, false}, + + // raw + {RawMessage(`i5e`), `i5e`, false}, + {[]RawMessage{ + RawMessage(`i5e`), + RawMessage(`5:hello`), + RawMessage(`ldededee`), + }, `li5e5:helloldededeee`, false}, + {map[string]RawMessage{ + "a": RawMessage(`i5e`), + "b": RawMessage(`5:hello`), + "c": RawMessage(`ldededee`), + }, `d1:ai5e1:b5:hello1:cldededeee`, false}, + + // problem sorting + {sortProblem{A: "foo", B: "bar"}, `d1:A3:foo1:B3:bare`, false}, + + // nil values dropped from maps and structs + {map[string]*int{"a": nil}, `de`, false}, + {struct{ A *int }{nil}, `de`, false}, + {issue18{}, `de`, false}, + {map[string]interface{}{"a": nil}, `de`, false}, + {struct{ A interface{} }{nil}, `de`, false}, + + // embedded structs + {struct { + A string + Embedded + }{"foo", Embedded{"bar"}}, `d1:A3:foo1:B3:bare`, false}, + {struct { + A string + Embedded `bencode:"C"` + }{"foo", Embedded{"bar"}}, `d1:A3:foo1:Cd1:B3:baree`, false}, + + // embedded structs order issue #20 + {struct { + Embedded + A string + }{Embedded{"bar"}, "foo"}, `d1:A3:foo1:B3:bare`, false}, + + // types which implement the Marshal interface will + // be marshalled using this interface + {myBoolType(true), `1:y`, false}, + {myBoolType(false), `1:n`, false}, + {myTimeType{now}, fmt.Sprintf("i%de", now.Unix()), false}, + {errorMarshalType{}, "", true}, + + // pointers to types which implement the Marshal interface will + // be marshalled using this interface + {func() *myBoolType { + b := myBoolType(true) + return &b + }(), `1:y`, false}, + {func() *myTimeType { + t := myTimeType{now} + return &t + }(), fmt.Sprintf("i%de", now.Unix()), false}, + {func() *errorMarshalType { + e := errorMarshalType{} + return &e + }(), "", true}, + + // nil-pointers to types which implement the Marshal interface will be ignored + {(*myBoolType)(nil), "", false}, + {(*myTimeType)(nil), "", false}, + {(*errorMarshalType)(nil), "", false}, + + // ptr-types which implements the Marshal interface will + // be marshalled using this interface + {func() *myBoolPtrType { + b := myBoolPtrType(true) + return &b + }(), `1:y`, false}, + {func() *myBoolPtrType { + b := myBoolPtrType(false) + return &b + }(), `1:n`, false}, + {(*myBoolPtrType)(nil), ``, false}, + + // structures can also have children which support + // the Marshal interface + { + issue22{Time: myTimeType{now}, Foo: myBoolType(true)}, + fmt.Sprintf("d1:f1:y1:ti%dee", now.Unix()), + false, + }, + { // an error will be returned if a child can't be marshalled + issue22WithErrorChild{Name: "Foo", Error: errorMarshalType{}}, + "", true, + }, + // structures passed by reference which have children that support + // the (Text)Marshal interface (by value or by reference), + // will be marshaled using that interface + { + &issue22{Time: myTimeType{now}, Foo: myBoolType(true)}, + fmt.Sprintf("d1:f1:y1:ti%dee", now.Unix()), + false, + }, + { // an error will be returned if a child can't be marshalled + &issue22WithErrorChild{Name: "Foo", Error: errorMarshalType{}}, + "", true, + }, + + // types which implement the TextMarshal interface will + // be marshalled into a bencode string value using this interface + {myBoolTextType(true), `1:y`, false}, + {myBoolTextType(false), `1:n`, false}, + {errorTextMarshalType{}, "", true}, + + // structures can also have children which support + // the TextMarshal interface + { + issue26{Answer: 42, Foo: myBoolTextType(true), Name: "Nova"}, + `d1:ai42e1:f1:y1:n4:Novae`, + false, + }, + { // an error will be returned if a child TextMarshaler returns an error + issue26WithErrorChild{Name: "Foo", Error: errorTextMarshalType{}}, + "", true, + }, + + // ptr types which are used as value types, + // but which ptr version implement the Marshaler/TextMarshaler interface, + // will still get marshalling using this interface, when possible + { + &issue28{X: "x", Time: &myTimeType{now}, Foo: myBoolPtrType(true), Y: "y"}, + fmt.Sprintf(`d1:f1:y1:ti%de1:x1:x1:y1:ye`, now.Unix()), + false, + }, + } + + for i, tt := range encodeCases { + t.Logf("%d: %#v", i, tt.in) + data, err := EncodeString(tt.in) + if !tt.err && err != nil { + t.Errorf("#%d: Unexpected err: %v", i, err) + continue + } + if tt.err && err == nil { + t.Errorf("#%d: Expected err is nil", i) + continue + } + if tt.out != data { + t.Errorf("#%d: Val: %q != %q", i, data, tt.out) + } + } +} + +type myBoolPtrType bool + +// MarshalBencode implements Marshaler.MarshalBencode +func (mbt *myBoolPtrType) MarshalBencode() ([]byte, error) { + var c string + if *mbt { + c = "y" + } else { + c = "n" + } + + return EncodeBytes(c) +} + +// UnmarshalBencode implements Unmarshaler.UnmarshalBencode +func (mbt *myBoolPtrType) UnmarshalBencode(b []byte) error { + var str string + err := DecodeBytes(b, &str) + if err != nil { + return err + } + + switch str { + case "y": + *mbt = true + case "n": + *mbt = false + default: + err = errors.New("invalid myBoolType") + } + + return err +} + +func TestEncodeOmit(t *testing.T) { + type encodeTestCase struct { + in interface{} + out string + err bool + } + + type eT struct { + A string `bencode:",omitempty"` + B int `bencode:",omitempty"` + C *int `bencode:",omitempty"` + } + + var encodeCases = []encodeTestCase{ + {eT{}, `de`, false}, + {eT{A: "a"}, `d1:A1:ae`, false}, + {eT{B: 5}, `d1:Bi5ee`, false}, + {eT{C: new(int)}, `d1:Ci0ee`, false}, + } + + for i, tt := range encodeCases { + data, err := EncodeString(tt.in) + if !tt.err && err != nil { + t.Errorf("#%d: Unexpected err: %v", i, err) + continue + } + if tt.err && err == nil { + t.Errorf("#%d: Expected err is nil", i) + continue + } + if tt.out != data { + t.Errorf("#%d: Val: %q != %q", i, data, tt.out) + } + } +} diff --git a/bencode/example_test.go b/bencode/example_test.go new file mode 100644 index 0000000..98d361a --- /dev/null +++ b/bencode/example_test.go @@ -0,0 +1,67 @@ +package bencode + +import ( + "fmt" + "io" +) + +var ( + data string + r io.Reader + w io.Writer +) + +func ExampleDecodeString() { + var torrent interface{} + if err := DecodeString(data, &torrent); err != nil { + panic(err) + } +} + +func ExampleEncodeString() { + var torrent interface{} + data, err := EncodeString(torrent) + if err != nil { + panic(err) + } + fmt.Println(data) +} + +func ExampleDecodeBytes() { + var torrent interface{} + if err := DecodeBytes([]byte(data), &torrent); err != nil { + panic(err) + } +} + +func ExampleEncodeBytes() { + var torrent interface{} + data, err := EncodeBytes(torrent) + if err != nil { + panic(err) + } + fmt.Println(data) +} + +func ExampleEncoder_Encode() { + var x struct { + Foo string + Bar []string `bencode:"name"` + } + + enc := NewEncoder(w) + if err := enc.Encode(x); err != nil { + panic(err) + } +} + +func ExampleDecoder_Decode() { + dec := NewDecoder(r) + var torrent struct { + Announce string + List [][]string `bencode:"announce-list"` + } + if err := dec.Decode(&torrent); err != nil { + panic(err) + } +} diff --git a/bencode/raw.go b/bencode/raw.go new file mode 100644 index 0000000..8d35a39 --- /dev/null +++ b/bencode/raw.go @@ -0,0 +1,5 @@ +package bencode + +// RawMessage is a special type that will store the raw bencode data when +// encoding or decoding. +type RawMessage []byte diff --git a/bencode/tag.go b/bencode/tag.go new file mode 100644 index 0000000..ebea05a --- /dev/null +++ b/bencode/tag.go @@ -0,0 +1,76 @@ +package bencode + +import ( + "reflect" + "strings" + "unicode" +) + +// tagOptions is the string following a comma in a struct field's "bencode" +// tag, or the empty string. It does not include the leading comma. +type tagOptions string + +// parseTag splits a struct field's tag into its name and +// comma-separated options. +func parseTag(tag string) (string, tagOptions) { + if idx := strings.Index(tag, ","); idx != -1 { + return tag[:idx], tagOptions(tag[idx+1:]) + } + return tag, tagOptions("") +} + +// Contains returns whether checks that a comma-separated list of options +// contains a particular substr flag. substr must be surrounded by a +// string boundary or commas. +func (options tagOptions) Contains(optionName string) bool { + s := string(options) + for s != "" { + var next string + i := strings.Index(s, ",") + if i >= 0 { + s, next = s[:i], s[i+1:] + } + if s == optionName { + return true + } + s = next + } + return false +} + +func isValidTag(key string) bool { + if key == "" { + return false + } + for _, c := range key { + if c != ' ' && c != '$' && c != '-' && c != '_' && c != '.' && !unicode.IsLetter(c) && !unicode.IsDigit(c) { + return false + } + } + return true +} + +func matchName(key string) func(string) bool { + return func(s string) bool { + return strings.ToLower(key) == strings.ToLower(s) + } +} + +func isEmptyValue(v reflect.Value) bool { + switch v.Kind() { + case reflect.Array, reflect.Map, reflect.Slice, reflect.String: + return v.Len() == 0 + case reflect.Bool: + return !v.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Interface, reflect.Ptr: + return v.IsNil() + default: + return v.IsZero() + } +} diff --git a/dht/blacklist.go b/dht/blacklist.go new file mode 100644 index 0000000..d31056e --- /dev/null +++ b/dht/blacklist.go @@ -0,0 +1,181 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dht + +import ( + "sync" + "time" +) + +// Blacklist is used to manage the ip blacklist. +// +// Notice: The implementation should clear the address existed for long time. +type Blacklist interface { + // In reports whether the address, ip and port, is in the blacklist. + In(ip string, port int) bool + + // If port is equal to 0, it should ignore port and only use ip when matching. + Add(ip string, port int) + + // If port is equal to 0, it should delete the address by only the ip. + Del(ip string, port int) + + // Close is used to notice the implementation to release the underlying + // resource. + Close() +} + +type noopBlacklist struct{} + +func (nbl noopBlacklist) In(ip string, port int) bool { return false } +func (nbl noopBlacklist) Add(ip string, port int) {} +func (nbl noopBlacklist) Del(ip string, port int) {} +func (nbl noopBlacklist) Close() {} + +// NewNoopBlacklist returns a no-op Blacklist. +func NewNoopBlacklist() Blacklist { return noopBlacklist{} } + +// DebugBlacklist returns a new Blacklist to log the information as debug. +func DebugBlacklist(bl Blacklist, logf func(string, ...interface{})) Blacklist { + return logBlacklist{Blacklist: bl, logf: logf} +} + +type logBlacklist struct { + Blacklist + logf func(string, ...interface{}) +} + +func (dbl logBlacklist) Add(ip string, port int) { + dbl.logf("add the blacklist: ip=%s, port=%d", ip, port) + dbl.Blacklist.Add(ip, port) +} + +func (dbl logBlacklist) Del(ip string, port int) { + dbl.logf("delete the blacklist: ip=%s, port=%d", ip, port) + dbl.Blacklist.Del(ip, port) +} + +/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + +// NewMemoryBlacklist returns a blacklst implementation based on memory. +// +// if maxnum is equal to 0, no limit. +func NewMemoryBlacklist(maxnum int, duration time.Duration) Blacklist { + bl := &blacklist{ + num: maxnum, + ips: make(map[string]*wrappedPort, 128), + exit: make(chan struct{}), + } + go bl.loop(duration) + return bl +} + +type wrappedPort struct { + Time time.Time + Enable bool + Ports map[int]struct{} +} + +type blacklist struct { + exit chan struct{} + lock sync.RWMutex + ips map[string]*wrappedPort + num int +} + +func (bl *blacklist) loop(interval time.Duration) { + tick := time.NewTicker(interval) + defer tick.Stop() + for { + select { + case <-bl.exit: + return + case now := <-tick.C: + bl.lock.Lock() + for ip, wp := range bl.ips { + if now.Sub(wp.Time) > interval { + delete(bl.ips, ip) + } + } + bl.lock.Unlock() + } + } +} + +func (bl *blacklist) Close() { + select { + case <-bl.exit: + default: + close(bl.exit) + } +} + +// In reports whether the address, ip and port, is in the blacklist. +func (bl *blacklist) In(ip string, port int) (yes bool) { + bl.lock.RLock() + if wp, ok := bl.ips[ip]; ok { + if wp.Enable { + _, yes = wp.Ports[port] + } else { + yes = true + } + } + bl.lock.RUnlock() + return +} + +func (bl *blacklist) Add(ip string, port int) { + bl.lock.Lock() + wp, ok := bl.ips[ip] + if !ok { + if bl.num > 0 && len(bl.ips) >= bl.num { + bl.lock.Unlock() + return + } + + wp = &wrappedPort{Enable: true} + bl.ips[ip] = wp + } + + if port < 1 { + wp.Enable = false + wp.Ports = nil + } else if wp.Ports == nil { + wp.Ports = map[int]struct{}{port: struct{}{}} + } else { + wp.Ports[port] = struct{}{} + } + + wp.Time = time.Now() + bl.lock.Unlock() +} + +func (bl *blacklist) Del(ip string, port int) { + bl.lock.Lock() + if wp, ok := bl.ips[ip]; ok { + if port < 1 { + delete(bl.ips, ip) + } else if wp.Enable { + switch len(wp.Ports) { + case 0, 1: + delete(bl.ips, ip) + default: + delete(wp.Ports, port) + wp.Time = time.Now() + } + } + } + bl.lock.Unlock() +} diff --git a/dht/blacklist_test.go b/dht/blacklist_test.go new file mode 100644 index 0000000..ca1cb99 --- /dev/null +++ b/dht/blacklist_test.go @@ -0,0 +1,80 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dht + +import ( + "testing" + "time" +) + +func (bl *blacklist) portsLen() (n int) { + bl.lock.RLock() + for _, wp := range bl.ips { + n += len(wp.Ports) + } + bl.lock.RUnlock() + return +} + +func (bl *blacklist) getIPs() []string { + bl.lock.RLock() + ips := make([]string, 0, len(bl.ips)) + for ip := range bl.ips { + ips = append(ips, ip) + } + bl.lock.RUnlock() + return ips +} + +func TestMemoryBlacklist(t *testing.T) { + bl := NewMemoryBlacklist(3, time.Second).(*blacklist) + defer bl.Close() + + bl.Add("1.1.1.1", 123) + bl.Add("1.1.1.1", 456) + bl.Add("1.1.1.1", 789) + bl.Add("2.2.2.2", 111) + bl.Add("3.3.3.3", 0) + bl.Add("4.4.4.4", 222) + + ips := bl.getIPs() + if len(ips) != 3 { + t.Error(ips) + } else { + for _, ip := range ips { + switch ip { + case "1.1.1.1", "2.2.2.2", "3.3.3.3": + default: + t.Error(ip) + } + } + } + + if n := bl.portsLen(); n != 4 { + t.Errorf("expect port num 4, but got %d", n) + } + + if bl.In("1.1.1.1", 111) || !bl.In("1.1.1.1", 123) { + t.Fail() + } + if !bl.In("3.3.3.3", 111) || bl.In("4.4.4.4", 222) { + t.Fail() + } + + bl.Del("3.3.3.3", 0) + if bl.In("3.3.3.3", 111) { + t.Fail() + } +} diff --git a/dht/dht_server.go b/dht/dht_server.go new file mode 100644 index 0000000..c00d24b --- /dev/null +++ b/dht/dht_server.go @@ -0,0 +1,864 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package dht implements the DHT Protocol. And you can use it to build or join +// the DHT swarm network. +package dht + +import ( + "bytes" + "fmt" + "io" + "log" + "net" + "sync" + "time" + + "github.com/xgfone/bt/bencode" + "github.com/xgfone/bt/krpc" + "github.com/xgfone/bt/metainfo" +) + +const ( + queryMethodPing = "ping" + queryMethodFindNode = "find_node" + queryMethodGetPeers = "get_peers" + queryMethodAnnouncePeer = "announce_peer" +) + +var errUnsupportedIPProtocol = fmt.Errorf("unsupported ip protocol") + +// Predefine some ip protocol stacks. +const ( + IPv4Protocol IPProtocolStack = 4 + IPv6Protocol IPProtocolStack = 6 +) + +// IPProtocolStack represents the ip protocol stack, such as IPv4 or IPv6 +type IPProtocolStack uint8 + +// Result is used to pass the response result to the callback function. +type Result struct { + // Addr is the address of the peer where the request is sent to. + // + // Notice: it may be nil for "get_peers" request. + Addr *net.UDPAddr + + // For Error + Code int // 0 represents the success. + Reason string // Reason indicates the reason why the request failed when Code > 0. + Timeout bool // Timeout indicates whether the response is timeout. + + // The list of the address of the peers returned by GetPeers. + Peers []metainfo.Address +} + +// Config is used to configure the DHT server. +type Config struct { + // K is the size of the bucket of the routing table. + // + // The default is 8. + K int + + // ID is the id of the current DHT server node. + // + // The default is generated randomly + ID metainfo.Hash + + // IPProtocols is used to specify the supported IP Protocol Stack. + // + // If not given, it will detect the address family of the server listens + // automatically and set the supported ip protocol stack by the address + // family. If fails, the default is []IPProtocolStack{IPv4Protocol}. + // So, if the server is listening on the empty address, such as ":6881", + // and only supports IPv6, you should specify it to "IPv6Protocol". + IPProtocols []IPProtocolStack + + // ReadOnly indicates whether the current node is read-only. + // + // If true, the DHT server will enter in the read-only mode. + // + // The default is false. + // + // BEP 43 + ReadOnly bool + + // MsgSize is the maximum size of the DHT message. + // + // The default is 4096. + MsgSize int + + // SearchDepth is used to control the depth to send the "get_peers" or + // "find_node" query recursively to get the peers storing the torrent + // infohash. + // + // The default depth is 8. + SearchDepth int + + // RespTimeout is the response timeout, that's, the response is valid + // only before the timeout reaches. + // + // The default is "10s". + RespTimeout time.Duration + + // RoutingTableStorage is used to store the nodes in the routing table. + // + // The default is nil. + RoutingTableStorage RoutingTableStorage + + // PeerManager is used to manage the peers on torrent infohash, + // which is called when receiving the "get_peers" query. + // + // The default uses the inner token-peer manager. + PeerManager PeerManager + + // Blacklist is used to manage the ip blacklist. + // + // The default is NewMemoryBlacklist(1024, time.Hour*24*7). + Blacklist Blacklist + + // ErrorLog is used to log the error. + // + // The default is log.Printf. + ErrorLog func(format string, args ...interface{}) + + // OnSearch is called when someone searches the torrent infohash, + // that's, the "get_peers" query. + // + // The default callback does noting. + OnSearch func(infohash string, ip net.IP, port uint16) + + // OnTorrent is called when someone has the torrent infohash + // or someone has just downloaded the torrent infohash, + // that's, the "get_peers" response or "announce_peer" query. + // + // The default callback does noting. + OnTorrent func(infohash string, ip net.IP, port uint16) + + // HandleInMessage is used to intercept the incoming DHT message. + // For example, you can debug the message as the log. + // + // Return true if going on handling by the default. Or return false. + // + // The default is nil. + HandleInMessage func(*net.UDPAddr, *krpc.Message) bool + + // HandleOutMessage is used to intercept the outgoing DHT message. + // For example, you can debug the message as the log. + // + // Return (false, nil) if going on handling by the default. + // + // The default is nil. + HandleOutMessage func(*net.UDPAddr, *krpc.Message) (wrote bool, err error) +} + +func (c Config) in(*net.UDPAddr, *krpc.Message) bool { return true } +func (c Config) out(*net.UDPAddr, *krpc.Message) (bool, error) { return false, nil } + +func (c *Config) set(conf ...Config) { + if len(conf) > 0 { + *c = conf[0] + } + + if c.K <= 0 { + c.K = 8 + } + if c.ID.IsZero() { + c.ID = metainfo.NewRandomHash() + } + if c.MsgSize <= 0 { + c.MsgSize = 4096 + } + if c.ErrorLog == nil { + c.ErrorLog = log.Printf + } + if c.SearchDepth < 1 { + c.SearchDepth = 8 + } + if c.RoutingTableStorage == nil { + c.RoutingTableStorage = noopStorage{} + } + if c.Blacklist == nil { + c.Blacklist = NewMemoryBlacklist(1024, time.Hour*24*7) + } + if c.RespTimeout == 0 { + c.RespTimeout = time.Second * 10 + } + if c.OnSearch == nil { + c.OnSearch = func(string, net.IP, uint16) {} + } + if c.OnTorrent == nil { + c.OnTorrent = func(string, net.IP, uint16) {} + } + if c.HandleInMessage == nil { + c.HandleInMessage = c.in + } + if c.HandleOutMessage == nil { + c.HandleOutMessage = c.out + } +} + +// Server is a DHT server. +type Server struct { + conf Config + exit chan struct{} + conn net.PacketConn + once sync.Once + + ipv4 bool + ipv6 bool + want []krpc.Want + + peerManager PeerManager + // routingTable *routingTable + routingTable4 *routingTable + routingTable6 *routingTable + tokenManager *tokenManager + tokenPeerManager *tokenPeerManager + transactionManager *transactionManager +} + +// NewServer returns a new DHT server. +func NewServer(conn net.PacketConn, config ...Config) *Server { + var conf Config + conf.set(config...) + + if len(conf.IPProtocols) == 0 { + host, _, err := net.SplitHostPort(conn.LocalAddr().String()) + if err != nil { + panic(err) + } else if ip := net.ParseIP(host); ipIsZero(ip) || ip.To4() != nil { + conf.IPProtocols = []IPProtocolStack{IPv4Protocol} + } else { + conf.IPProtocols = []IPProtocolStack{IPv6Protocol} + } + } + + var ipv4, ipv6 bool + var want []krpc.Want + for _, ip := range conf.IPProtocols { + switch ip { + case IPv4Protocol: + ipv4 = true + want = append(want, krpc.WantNodes) + case IPv6Protocol: + ipv6 = true + want = append(want, krpc.WantNodes6) + } + } + + s := &Server{ + ipv4: ipv4, + ipv6: ipv6, + want: want, + conn: conn, + conf: conf, + exit: make(chan struct{}), + peerManager: conf.PeerManager, + tokenManager: newTokenManager(), + tokenPeerManager: newTokenPeerManager(), + transactionManager: newTransactionManager(), + } + + s.routingTable4 = newRoutingTable(s, false) + s.routingTable6 = newRoutingTable(s, true) + if s.peerManager == nil { + s.peerManager = s.tokenPeerManager + } + + return s +} + +// ID returns the ID of the DHT server node. +func (s *Server) ID() metainfo.Hash { return s.conf.ID } + +// Bootstrap initializes the routing table at first. +// +// Notice: If the routing table has had some nodes, it does noting. +func (s *Server) Bootstrap(addrs []string) { + if (s.ipv4 && s.routingTable4.Len() == 0) || + (s.ipv6 && s.routingTable6.Len() == 0) { + for _, addr := range addrs { + as, err := metainfo.NewAddressesFromString(addr) + if err != nil { + s.conf.ErrorLog(err.Error()) + continue + } + + for _, a := range as { + if err = s.FindNode(a.UDPAddr(), s.conf.ID); err != nil { + s.conf.ErrorLog(`fail to bootstrap '%s': %s`, a.String(), err) + } + } + } + } +} + +// Node4Num returns the number of the ipv4 nodes in the routing table. +func (s *Server) Node4Num() int { return s.routingTable4.Len() } + +// Node6Num returns the number of the ipv6 nodes in the routing table. +func (s *Server) Node6Num() int { return s.routingTable6.Len() } + +// AddNode adds the node into the routing table. +// +// The returned value: +// NodeAdded: The node is added successfully. +// NodeNotAdded: The node is not added and is discarded. +// NodeExistAndUpdated: The node has existed, and its status has been updated. +// NodeExistAndChanged: The node has existed, but the address is inconsistent. +// The current node will be discarded. +// +func (s *Server) AddNode(node krpc.Node) int { + // For IPv6 + if isIPv6(node.Addr.IP) { + if s.ipv6 { + return s.routingTable6.AddNode(node) + } + return NodeNotAdded + } + + // For IPv4 + if s.ipv4 { + return s.routingTable4.AddNode(node) + } + + return NodeNotAdded +} + +func (s *Server) addNode(a *net.UDPAddr, id metainfo.Hash, ro bool) (r int) { + if ro { // BEP 43 + return NodeNotAdded + } + + if r = s.AddNode(krpc.NewNodeByUDPAddr(id, a)); r == NodeExistAndChanged { + s.conf.Blacklist.Add(a.IP.String(), a.Port) + } + + return +} + +func (s *Server) addNode2(node krpc.Node, ro bool) int { + if ro { // BEP 43 + return NodeNotAdded + } + return s.AddNode(node) +} + +func (s *Server) stop() { + close(s.exit) + s.routingTable4.Stop() + s.routingTable6.Stop() + s.tokenManager.Stop() + s.tokenPeerManager.Stop() + s.transactionManager.Stop() + s.conf.Blacklist.Close() +} + +// Close stops the DHT server. +func (s *Server) Close() { s.once.Do(s.stop) } + +// Sync is used to synchronize the routing table to the underlying storage. +func (s *Server) Sync() { + s.routingTable4.Sync() + s.routingTable6.Sync() +} + +// Run starts the DHT server. +func (s *Server) Run() { + go s.routingTable4.Start(time.Minute * 5) + go s.routingTable6.Start(time.Minute * 5) + go s.tokenManager.Start(time.Minute * 10) + go s.tokenPeerManager.Start(time.Hour * 24) + go s.transactionManager.Start(s, s.conf.RespTimeout) + + buf := make([]byte, s.conf.MsgSize) + for { + n, raddr, err := s.conn.ReadFrom(buf) + if err != nil { + s.conf.ErrorLog("fail to read the dht message: %s", err) + return + } + + s.handlePacket(raddr.(*net.UDPAddr), buf[:n]) + } +} + +func (s *Server) isDisabled(raddr *net.UDPAddr) bool { + if isIPv6(raddr.IP) { + if !s.ipv6 { + return true + } + } else if !s.ipv4 { + return true + } + return false +} + +// HandlePacket handles the incoming DHT message. +func (s *Server) handlePacket(raddr *net.UDPAddr, data []byte) { + if s.isDisabled(raddr) { + return + } + + // Check whether the raddr is in the ip blacklist. If yes, discard it. + if s.conf.Blacklist.In(raddr.IP.String(), raddr.Port) { + return + } + + var msg krpc.Message + if err := bencode.DecodeBytes(data, &msg); err != nil { + s.conf.ErrorLog("decode krpc message error: %s", err) + return + } else if msg.T == "" { + s.conf.ErrorLog("no transaction id from '%s'", raddr) + return + } + + // TODO: Should we use a task pool?? + go s.handleMessage(raddr, msg) +} + +func (s *Server) handleMessage(raddr *net.UDPAddr, m krpc.Message) { + if !s.conf.HandleInMessage(raddr, &m) { + return + } + + switch m.Y { + case "q": + if !m.A.ID.IsZero() { + r := s.addNode(raddr, m.A.ID, m.RO) + if r != NodeExistAndChanged && !s.conf.ReadOnly { // BEP 43 + s.handleQuery(raddr, m) + } + } + case "r": + if !m.R.ID.IsZero() { + if s.addNode(raddr, m.R.ID, m.RO) == NodeExistAndChanged { + return + } + + if t := s.transactionManager.PopTransaction(m.T, raddr); t != nil { + t.OnResponse(t, raddr, m) + } + } + case "e": + if t := s.transactionManager.PopTransaction(m.T, raddr); t != nil { + t.OnError(t, m.E.Code, m.E.Reason) + } + default: + s.conf.ErrorLog("unknown dht message type '%s'", m.Y) + } +} + +func (s *Server) handleQuery(raddr *net.UDPAddr, m krpc.Message) { + switch m.Q { + case queryMethodPing: + s.reply(raddr, m.T, krpc.ResponseResult{}) + case queryMethodFindNode: // See BEP 32 + var r krpc.ResponseResult + n4 := m.A.ContainsWant(krpc.WantNodes) + n6 := m.A.ContainsWant(krpc.WantNodes6) + if !n4 && !n6 { + if isIPv6(raddr.IP) { + r.Nodes6 = s.routingTable6.Closest(m.A.InfoHash, s.conf.K) + } else { + r.Nodes = s.routingTable4.Closest(m.A.InfoHash, s.conf.K) + } + } else { + if n4 { + r.Nodes = s.routingTable4.Closest(m.A.InfoHash, s.conf.K) + } + if n6 { + r.Nodes6 = s.routingTable6.Closest(m.A.InfoHash, s.conf.K) + } + } + s.reply(raddr, m.T, r) + case queryMethodGetPeers: // See BEP 32 + n4 := m.A.ContainsWant(krpc.WantNodes) + n6 := m.A.ContainsWant(krpc.WantNodes6) + + // Get the ipv4/ipv6 peers storing the torrent infohash. + var r krpc.ResponseResult + if !n4 && !n6 { + r.Values = s.peerManager.GetPeers(m.A.InfoHash, s.conf.K, isIPv6(raddr.IP)) + } else { + if n4 { + r.Values = s.peerManager.GetPeers(m.A.InfoHash, s.conf.K, false) + } + + if n6 { + values := s.peerManager.GetPeers(m.A.InfoHash, s.conf.K, true) + if len(r.Values) == 0 { + r.Values = values + } else { + r.Values = append(r.Values, values...) + } + } + } + + // No Peers, and return the closest other nodes. + if len(r.Values) == 0 { + if !n4 && !n6 { + if isIPv6(raddr.IP) { + r.Nodes6 = s.routingTable6.Closest(m.A.InfoHash, s.conf.K) + } else { + r.Nodes = s.routingTable4.Closest(m.A.InfoHash, s.conf.K) + } + } else { + if n4 { + r.Nodes = s.routingTable4.Closest(m.A.InfoHash, s.conf.K) + } + if n6 { + r.Nodes6 = s.routingTable6.Closest(m.A.InfoHash, s.conf.K) + } + } + } + + r.Token = s.tokenManager.Token(raddr) + s.reply(raddr, m.T, r) + s.conf.OnSearch(m.A.InfoHash.HexString(), raddr.IP, uint16(raddr.Port)) + case queryMethodAnnouncePeer: + if s.tokenManager.Check(raddr, m.A.Token) { + return + } + s.reply(raddr, m.T, krpc.ResponseResult{}) + s.conf.OnTorrent(m.A.InfoHash.HexString(), raddr.IP, m.A.GetPort(raddr.Port)) + default: + s.sendError(raddr, m.T, "unknown query method", krpc.ErrorCodeMethodUnknown) + } +} + +func (s *Server) send(raddr *net.UDPAddr, m krpc.Message) (wrote bool, err error) { + // // TODO: Should we check the ip blacklist?? + // if s.conf.Blacklist.In(raddr.IP.String(), raddr.Port) { + // return + // } + + m.RO = s.conf.ReadOnly // BEP 43 + if wrote, err = s.conf.HandleOutMessage(raddr, &m); !wrote && err == nil { + wrote, err = s._send(raddr, m) + } + + return +} + +func (s *Server) _send(raddr *net.UDPAddr, m krpc.Message) (wrote bool, err error) { + if m.T == "" || m.Y == "" { + panic(`DHT message "t" or "y" must not be empty`) + } + + buf := bytes.NewBuffer(nil) + buf.Grow(128) + if err = bencode.NewEncoder(buf).Encode(m); err != nil { + panic(err) + } + + n, err := s.conn.WriteTo(buf.Bytes(), raddr) + if err != nil { + err = fmt.Errorf("error writing %d bytes to %s: %s", buf.Len(), raddr, err) + s.conf.Blacklist.Add(raddr.IP.String(), 0) + return + } + + wrote = true + if n != buf.Len() { + err = io.ErrShortWrite + } + + return +} + +func (s *Server) sendError(raddr *net.UDPAddr, tid, reason string, code int) { + if _, err := s.send(raddr, krpc.NewErrorMsg(tid, code, reason)); err != nil { + s.conf.ErrorLog("error replying to %s: %s", raddr.String(), err.Error()) + } +} + +func (s *Server) reply(raddr *net.UDPAddr, tid string, r krpc.ResponseResult) { + r.ID = s.conf.ID + if _, err := s.send(raddr, krpc.NewResponseMsg(tid, r)); err != nil { + s.conf.ErrorLog("error replying to %s: %s", raddr.String(), err.Error()) + } +} + +func (s *Server) request(t *transaction) (err error) { + if s.isDisabled(t.Addr) { + return errUnsupportedIPProtocol + } + + t.Arg.ID = s.conf.ID + if t.ID == "" { + t.ID = s.transactionManager.GetTransactionID() + } + if _, err = s.send(t.Addr, krpc.NewQueryMsg(t.ID, t.Query, t.Arg)); err != nil { + s.conf.ErrorLog("error replying to %s: %s", t.Addr.String(), err.Error()) + } else { + s.transactionManager.AddTransaction(t) + } + return +} + +func (s *Server) onError(t *transaction, code int, reason string) { + s.conf.ErrorLog("got an error to ping '%s': code=%d, reason=%s", + t.Addr.String(), code, reason) + t.Done(Result{Code: code, Reason: reason}) +} + +func (s *Server) onTimeout(t *transaction) { + // TODO: Should we use a task pool?? + t.Done(Result{Timeout: true}) + s.conf.ErrorLog("transaction '%s' timeout: query=%s, raddr=%s", + t.ID, t.Query, t.Addr.String()) +} + +func (s *Server) onPingResp(t *transaction, a *net.UDPAddr, m krpc.Message) { + t.Done(Result{}) +} + +func (s *Server) onGetPeersResp(t *transaction, a *net.UDPAddr, m krpc.Message) { + // Store the response node with the token. + if m.R.Token != "" { + s.tokenPeerManager.Set(m.R.ID, a, m.R.Token) + } + + // Get the peers. + if len(m.R.Values) > 0 { + t.Done(Result{Peers: m.R.Values}) + for _, addr := range m.R.Values { + s.conf.OnTorrent(t.Arg.InfoHash.HexString(), addr.IP, addr.Port) + } + return + } + + // Search the torrent infohash recursively. + t.Depth-- + if t.Depth < 1 { + t.Done(Result{}) + return + } + + var found bool + ids := t.Visited + nodes := make([]krpc.Node, 0, len(m.R.Nodes)+len(m.R.Nodes6)) + for _, node := range m.R.Nodes { + if node.ID == t.Arg.InfoHash { + found = true + } + if ids.Contains(node.ID) { + continue + } + if s.addNode2(node, m.RO) == NodeAdded { + nodes = append(nodes, node) + ids = append(ids, node.ID) + } + } + for _, node := range m.R.Nodes6 { + if node.ID == t.Arg.InfoHash { + found = true + } + if ids.Contains(node.ID) { + continue + } + if s.addNode2(node, m.RO) == NodeAdded { + nodes = append(nodes, node) + ids = append(ids, node.ID) + } + } + + if found || len(nodes) == 0 { + t.Done(Result{}) + return + } + + for _, node := range nodes { + s.getPeers(t.Arg.InfoHash, node.Addr, t.Depth, ids, t.Callback) + } +} + +func (s *Server) getPeers(info metainfo.Hash, addr metainfo.Address, depth int, + ids metainfo.Hashes, cb ...func(Result)) { + arg := krpc.QueryArg{InfoHash: info, Wants: s.want} + t := newTransaction(s, addr.UDPAddr(), queryMethodGetPeers, arg, cb...) + t.OnResponse = s.onGetPeersResp + t.Depth = depth + t.Visited = ids + if err := s.request(t); err != nil { + s.conf.ErrorLog("fail to send query message to '%s': %s", addr.String(), err) + } +} + +// Ping sends a PING query to addr, and the callback function cb will be called +// when the response or error is returned, or it's timeout. +func (s *Server) Ping(addr *net.UDPAddr, cb ...func(Result)) (err error) { + t := newTransaction(s, addr, queryMethodPing, krpc.QueryArg{}, cb...) + t.OnResponse = s.onPingResp + return s.request(t) +} + +// GetPeers searches the peer storing the torrent by the infohash of the torrent, +// which will search it recursively until some peers are returned or it reaches +// the maximun depth, that's, ServerConfig.SearchDepth. +// +// If cb is given, it will be called when some peers are returned. +// Notice: it may be called for many times. +func (s *Server) GetPeers(infohash metainfo.Hash, cb ...func(Result)) { + if infohash.IsZero() { + panic("the infohash of the torrent is ZERO") + } + + var nodes []krpc.Node + if s.ipv4 { + nodes = s.routingTable4.Closest(infohash, s.conf.K) + } + if s.ipv6 { + nodes = append(nodes, s.routingTable6.Closest(infohash, s.conf.K)...) + } + + if len(nodes) == 0 { + if len(cb) != 0 && cb[0] != nil { + cb[0](Result{}) + } + return + } + + ids := make(metainfo.Hashes, len(nodes)) + for i, node := range nodes { + ids[i] = node.ID + } + + for _, node := range nodes { + s.getPeers(infohash, node.Addr, s.conf.SearchDepth, ids, cb...) + } + +} + +// AnnouncePeer announces the torrent infohash to the K closest nodes, +// and returns the nodes to which it sends the announce_peer query. +func (s *Server) AnnouncePeer(infohash metainfo.Hash, port uint16, impliedPort bool) []krpc.Node { + if infohash.IsZero() { + panic("the infohash of the torrent is ZERO") + } + + var nodes []krpc.Node + if s.ipv4 { + nodes = s.routingTable4.Closest(infohash, s.conf.K) + } + if s.ipv6 { + nodes = append(nodes, s.routingTable6.Closest(infohash, s.conf.K)...) + } + + sentNodes := make([]krpc.Node, 0, len(nodes)) + for _, node := range nodes { + addr := node.Addr.UDPAddr() + token := s.tokenPeerManager.Get(infohash, addr) + if token == "" { + continue + } + + arg := krpc.QueryArg{ImpliedPort: impliedPort, InfoHash: infohash, Port: port, Token: token} + t := newTransaction(s, addr, queryMethodAnnouncePeer, arg) + if err := s.request(t); err != nil { + s.conf.ErrorLog("fail to send query message to '%s': %s", addr.String(), err) + } else { + sentNodes = append(sentNodes, node) + } + } + + return sentNodes +} + +// FindNode sends the "find_node" query to the addr to find the target node. +// +// Notice: In general, it's used to bootstrap the routing table. +func (s *Server) FindNode(addr *net.UDPAddr, target metainfo.Hash) error { + if target.IsZero() { + panic("the target is ZERO") + } + + return s.findNode(target, addr, s.conf.SearchDepth, nil) +} + +func (s *Server) findNode(target metainfo.Hash, addr *net.UDPAddr, depth int, + ids metainfo.Hashes) error { + arg := krpc.QueryArg{Target: target, Wants: s.want} + t := newTransaction(s, addr, queryMethodFindNode, arg) + t.OnResponse = s.onFindNodeResp + t.Visited = ids + return s.request(t) +} + +func (s *Server) onFindNodeResp(t *transaction, a *net.UDPAddr, m krpc.Message) { + // Search the target node recursively. + t.Depth-- + if t.Depth < 1 { + return + } + + var found bool + ids := t.Visited + nodes := make([]krpc.Node, 0, len(m.R.Nodes)+len(m.R.Nodes6)) + for _, node := range m.R.Nodes { + if node.ID == t.Arg.Target { + found = true + } + if ids.Contains(node.ID) { + continue + } + if s.addNode2(node, m.RO) == NodeAdded { + nodes = append(nodes, node) + ids = append(ids, node.ID) + } + } + for _, node := range m.R.Nodes6 { + if node.ID == t.Arg.Target { + found = true + } + if ids.Contains(node.ID) { + continue + } + if s.addNode2(node, m.RO) == NodeAdded { + nodes = append(nodes, node) + ids = append(ids, node.ID) + } + } + + if found || len(nodes) == 0 { + return + } + + for _, node := range nodes { + err := s.findNode(t.Arg.Target, node.Addr.UDPAddr(), t.Depth, ids) + if err != nil { + s.conf.ErrorLog(`fail to send "find_node" query to '%s': %s`, + node.Addr.String(), err) + } + } +} + +func isIPv6(ip net.IP) bool { + if ip.To4() == nil { + return true + } + return false +} + +func ipIsZero(ip net.IP) bool { + for _, b := range ip { + if b != 0 { + return false + } + } + return true +} diff --git a/dht/dht_server_test.go b/dht/dht_server_test.go new file mode 100644 index 0000000..e530e81 --- /dev/null +++ b/dht/dht_server_test.go @@ -0,0 +1,182 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dht + +import ( + "fmt" + "net" + "strconv" + "sync" + "time" + + "github.com/xgfone/bt/metainfo" +) + +type testPeerManager struct { + lock sync.RWMutex + peers map[metainfo.Hash][]metainfo.Address +} + +func newTestPeerManager() *testPeerManager { + return &testPeerManager{peers: make(map[metainfo.Hash][]metainfo.Address)} +} + +func (pm *testPeerManager) AddPeer(infohash metainfo.Hash, addr metainfo.Address) { + pm.lock.Lock() + var exist bool + for _, orig := range pm.peers[infohash] { + if orig.Equal(addr) { + exist = true + break + } + } + if !exist { + pm.peers[infohash] = append(pm.peers[infohash], addr) + } + pm.lock.Unlock() +} + +func (pm *testPeerManager) GetPeers(infohash metainfo.Hash, maxnum int, + ipv6 bool) (addrs []metainfo.Address) { + // We only supports IPv4, so ignore the ipv6 argument. + pm.lock.RLock() + _addrs := pm.peers[infohash] + if _len := len(_addrs); _len > 0 { + if _len > maxnum { + _len = maxnum + } + addrs = _addrs[:_len] + } + pm.lock.RUnlock() + return +} + +func onSearch(infohash string, ip net.IP, port uint16) { + addr := net.JoinHostPort(ip.String(), strconv.FormatUint(uint64(port), 10)) + fmt.Printf("%s is searching %s\n", addr, infohash) +} + +func onTorrent(infohash string, ip net.IP, port uint16) { + addr := net.JoinHostPort(ip.String(), strconv.FormatUint(uint64(port), 10)) + fmt.Printf("%s has downloaded %s\n", addr, infohash) +} + +func newDHTServer(id metainfo.Hash, addr string, pm PeerManager) (s *Server, err error) { + conn, err := net.ListenPacket("udp", addr) + if err == nil { + c := Config{ID: id, PeerManager: pm, OnSearch: onSearch, OnTorrent: onTorrent} + s = NewServer(conn, c) + } + return +} + +func ExampleServer() { + // For test, we predefine some node ids and infohash. + id1 := metainfo.NewRandomHash() + id2 := metainfo.NewRandomHash() + id3 := metainfo.NewRandomHash() + infohash := metainfo.Hash{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20} + + // Create first DHT server + pm := newTestPeerManager() + server1, err := newDHTServer(id1, ":9001", pm) + if err != nil { + fmt.Println(err) + return + } + defer server1.Close() + + // Create second DHT server + server2, err := newDHTServer(id2, ":9002", nil) + if err != nil { + fmt.Println(err) + return + } + defer server2.Close() + + // Create third DHT server + server3, err := newDHTServer(id3, ":9003", nil) + if err != nil { + fmt.Println(err) + return + } + defer server3.Close() + + // Start the three DHT servers + go server1.Run() + go server2.Run() + go server3.Run() + + // Wait that the DHT servers start. + time.Sleep(time.Second) + + // Bootstrap the routing table to create a DHT network with the three servers + server1.Bootstrap([]string{"127.0.0.1:9002", "127.0.0.1:9003"}) + server2.Bootstrap([]string{"127.0.0.1:9001", "127.0.0.1:9003"}) + server3.Bootstrap([]string{"127.0.0.1:9001", "127.0.0.1:9002"}) + + // Wait that the DHT servers learn the routing table + time.Sleep(time.Second) + + fmt.Println("Server1:", server1.Node4Num()) + fmt.Println("Server2:", server2.Node4Num()) + fmt.Println("Server3:", server3.Node4Num()) + + server1.GetPeers(infohash, func(r Result) { + if len(r.Peers) == 0 { + fmt.Printf("no peers for %s\n", infohash) + } else { + for _, peer := range r.Peers { + fmt.Printf("%s: %s\n", infohash, peer.String()) + } + } + }) + + // Wait that the last get_peers ends. + time.Sleep(time.Second) + + // Add the peer to let the DHT server1 has the peer. + pm.AddPeer(infohash, metainfo.NewAddress(net.ParseIP("127.0.0.1"), 9001)) + + // Search the torrent infohash again, but from DHT server2, + // which will search the DHT server1 recursively. + server2.GetPeers(infohash, func(r Result) { + if len(r.Peers) == 0 { + fmt.Printf("no peers for %s\n", infohash) + } else { + for _, peer := range r.Peers { + fmt.Printf("%s: %s\n", infohash, peer.String()) + } + } + }) + + // Wait that the recursive call ends. + time.Sleep(time.Second) + + // Unordered output: + // Server1: 2 + // Server2: 2 + // Server3: 2 + // 127.0.0.1:9001 is searching 0102030405060708090a0b0c0d0e0f1011121314 + // 127.0.0.1:9001 is searching 0102030405060708090a0b0c0d0e0f1011121314 + // no peers for 0102030405060708090a0b0c0d0e0f1011121314 + // no peers for 0102030405060708090a0b0c0d0e0f1011121314 + // 127.0.0.1:9002 is searching 0102030405060708090a0b0c0d0e0f1011121314 + // 127.0.0.1:9002 is searching 0102030405060708090a0b0c0d0e0f1011121314 + // no peers for 0102030405060708090a0b0c0d0e0f1011121314 + // 0102030405060708090a0b0c0d0e0f1011121314: 127.0.0.1:9001 + // 127.0.0.1:9001 has downloaded 0102030405060708090a0b0c0d0e0f1011121314 +} diff --git a/dht/peer_manager.go b/dht/peer_manager.go new file mode 100644 index 0000000..03d5006 --- /dev/null +++ b/dht/peer_manager.go @@ -0,0 +1,139 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dht + +import ( + "net" + "sync" + "time" + + "github.com/xgfone/bt/metainfo" +) + +// PeerManager is used to manage the peers. +type PeerManager interface { + // If ipv6 is true, only return ipv6 addresses. Or return ipv4 addresses. + GetPeers(infohash metainfo.Hash, maxnum int, ipv6 bool) []metainfo.Address +} + +type peer struct { + ID metainfo.Hash + IP net.IP + Port uint16 + Token string + Time time.Time +} + +type tokenPeerManager struct { + lock sync.RWMutex + exit chan struct{} + peers map[metainfo.Hash]map[string]peer +} + +func newTokenPeerManager() *tokenPeerManager { + return &tokenPeerManager{ + exit: make(chan struct{}), + peers: make(map[metainfo.Hash]map[string]peer, 128), + } +} + +// Start starts the token-peer manager. +func (tpm *tokenPeerManager) Start(interval time.Duration) { + tick := time.NewTicker(interval) + defer tick.Stop() + for { + select { + case <-tpm.exit: + return + case now := <-tick.C: + tpm.lock.Lock() + for id, peers := range tpm.peers { + for addr, peer := range peers { + if now.Sub(peer.Time) > interval { + delete(peers, addr) + } + } + + if len(peers) == 0 { + delete(tpm.peers, id) + } + } + tpm.lock.Unlock() + } + } +} + +func (tpm *tokenPeerManager) Set(id metainfo.Hash, addr *net.UDPAddr, token string) { + addrkey := addr.String() + tpm.lock.Lock() + peers, ok := tpm.peers[id] + if !ok { + peers = make(map[string]peer, 4) + tpm.peers[id] = peers + } + peers[addrkey] = peer{ + ID: id, + IP: addr.IP, + Port: uint16(addr.Port), + Token: token, + Time: time.Now(), + } + tpm.lock.Unlock() +} + +func (tpm *tokenPeerManager) Get(id metainfo.Hash, addr *net.UDPAddr) (token string) { + addrkey := addr.String() + tpm.lock.RLock() + if peers, ok := tpm.peers[id]; ok { + if peer, ok := peers[addrkey]; ok { + token = peer.Token + } + } + tpm.lock.RUnlock() + return +} + +func (tpm *tokenPeerManager) Stop() { + select { + case <-tpm.exit: + default: + close(tpm.exit) + } +} + +func (tpm *tokenPeerManager) GetPeers(infohash metainfo.Hash, maxnum int, + ipv6 bool) (addrs []metainfo.Address) { + addrs = make([]metainfo.Address, 0, maxnum) + tpm.lock.RLock() + if peers, ok := tpm.peers[infohash]; ok { + for _, peer := range peers { + if maxnum < 1 { + break + } + + if ipv6 { // For IPv6 + if isIPv6(peer.IP) { + maxnum-- + addrs = append(addrs, metainfo.NewAddress(peer.IP, peer.Port)) + } + } else if !isIPv6(peer.IP) { // For IPv4 + maxnum-- + addrs = append(addrs, metainfo.NewAddress(peer.IP, peer.Port)) + } + } + } + tpm.lock.RUnlock() + return +} diff --git a/dht/routing_table.go b/dht/routing_table.go new file mode 100644 index 0000000..a4f9da9 --- /dev/null +++ b/dht/routing_table.go @@ -0,0 +1,439 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dht + +import ( + "sort" + "sync" + "time" + + "github.com/xgfone/bt/krpc" + "github.com/xgfone/bt/metainfo" +) + +const bktlen = 160 + +const ( + badTimeout = time.Minute * 20 + dubiousTimeout = time.Minute * 15 +) + +type routingTable struct { + k int + s *Server + ipv6 bool + sync chan struct{} + exit chan struct{} + root metainfo.Hash + lock sync.RWMutex + bkts []*bucket +} + +func newRoutingTable(s *Server, ipv6 bool) *routingTable { + rt := &routingTable{ + k: s.conf.K, + s: s, + ipv6: ipv6, + root: s.conf.ID, + sync: make(chan struct{}), + exit: make(chan struct{}), + bkts: make([]*bucket, bktlen), + } + + for i := range rt.bkts { + rt.bkts[i] = &bucket{table: rt} + } + + // Load all the nodes from the storage. + nodes, err := s.conf.RoutingTableStorage.Load(s.conf.ID, ipv6) + if err != nil { + s.conf.ErrorLog("fail to load routing table(ipv6=%v): %s", ipv6, err) + } else { + now := time.Now() + for _, node := range nodes { + if now.Sub(node.LastChanged) < dubiousTimeout { + rt.addNode(node.Node, node.LastChanged) + } + } + } + + return rt +} + +// Sync dumps the information of the routing table to the underlying storage. +func (rt *routingTable) Sync() { + select { + case rt.sync <- struct{}{}: + case <-rt.exit: + } +} + +// Start starts the routing table. +func (rt *routingTable) Start(interval time.Duration) { + tick := time.NewTicker(interval) + defer tick.Stop() + + for { + select { + case <-rt.exit: + rt.dump() + return + case <-rt.sync: + rt.dump() + case now := <-tick.C: + rt.lock.Lock() + rt.checkAllBuckets(now) + rt.lock.Unlock() + } + } +} + +// Dump stores all the nodes into the underlying storage. +func (rt *routingTable) dump() { + nodes := make([]RoutingTableNode, 0, 128) + rt.lock.RLock() + for _, bkt := range rt.bkts { + for _, n := range bkt.Nodes { + nodes = append(nodes, RoutingTableNode{ + Node: krpc.Node{ID: n.ID, Addr: n.Addr}, + LastChanged: n.LastChanged, + }) + } + } + rt.lock.RUnlock() + + if len(nodes) > 0 { + err := rt.s.conf.RoutingTableStorage.Dump(rt.root, nodes, rt.ipv6) + if err != nil { + rt.s.conf.ErrorLog("fail to dump nodes in routing table(ipv6=%v): %s", rt.ipv6, err) + } + } +} + +func (rt *routingTable) checkAllBuckets(now time.Time) { + defer func() { + if err := recover(); err != nil { + rt.s.conf.ErrorLog("panic: %v", err) + } + }() + + for _, bkt := range rt.bkts { + bkt.CheckAllNodes(now) + } +} + +func (rt *routingTable) Len() (n int) { + rt.lock.RLock() + for _, bkt := range rt.bkts { + n += len(bkt.Nodes) + } + rt.lock.RUnlock() + return +} + +// Stop stops the routing table. +func (rt *routingTable) Stop() { + select { + case <-rt.exit: + default: + close(rt.exit) + } +} + +// AddNode adds the node into the routing table. +// +// The returned value: +// NodeAdded: The node is added successfully. +// NodeNotAdded: The node is not added and is discarded. +// NodeExistAndUpdated: The node has existed, and its status has been updated. +// NodeExistAndChanged: The node has existed, but the address is inconsistent. +// The current node will be discarded. +// +func (rt *routingTable) AddNode(n krpc.Node) (r int) { + if n.ID == rt.root { // Don't add itself. + return NodeNotAdded + } + return rt.addNode(n, time.Now()) +} + +func (rt *routingTable) addNode(n krpc.Node, now time.Time) (r int) { + bktid := bucketid(rt.root, n.ID) + rt.lock.Lock() + r = rt.bkts[bktid].AddNode(n, now) + rt.lock.Unlock() + return +} + +// Closest returns the maxnum number of the nodes which are the closest to target. +func (rt *routingTable) Closest(target metainfo.Hash, maxnum int) (nodes []krpc.Node) { + close := nodesByDistance{target: target, maxnum: maxnum} + rt.lock.RLock() + for _, b := range rt.bkts { + for _, n := range b.Nodes { + close.push(n.Node) + } + } + rt.lock.RUnlock() + return close.nodes +} + +type nodesByDistance struct { + maxnum int + target metainfo.Hash + nodes []krpc.Node +} + +func (ns *nodesByDistance) push(node krpc.Node) { + ix := sort.Search(len(ns.nodes), func(i int) bool { + return distcmp(ns.target, ns.nodes[i].ID, node.ID) > 0 + }) + + if len(ns.nodes) < ns.maxnum { + ns.nodes = append(ns.nodes, node) + } + + // slide existing nodes down to make room. + // this will overwrite the entry we just appended. + if ix != len(ns.nodes) { + copy(ns.nodes[ix+1:], ns.nodes[ix:]) + ns.nodes[ix] = node + } +} + +// distcmp compares the distances a->target and b->target. +// Returns -1 if a is closer to target, 1 if b is closer to target +// and 0 if they are equal. +func distcmp(target, a, b metainfo.Hash) int { + for i := range target { + da := a[i] ^ target[i] + db := b[i] ^ target[i] + if da > db { + return 1 + } else if da < db { + return -1 + } + } + return 0 +} + +/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> +/// Bucket + +// Predefine some values returned by the method AddNode of routing table. +const ( + NodeAdded = iota + NodeNotAdded + NodeExistAndUpdated + NodeExistAndChanged +) + +type bucket struct { + table *routingTable + Nodes []*wrappedNode + LastChanged time.Time +} + +func (b *bucket) emit(f func(metainfo.Hash, RoutingTableNode) error, n *wrappedNode) { + node := RoutingTableNode{Node: n.Node, LastChanged: n.LastChanged} + f(b.table.root, node) +} + +func (b *bucket) UpdateLastChangedTime(now time.Time) { + b.LastChanged = now +} + +func (b *bucket) AddNode(n krpc.Node, now time.Time) (status int) { + // Update the old one. + for _, orig := range b.Nodes { + if orig.ID == n.ID { + if orig.Addr.Equal(n.Addr) { + orig.UpdateLastChangedTime(now) + status = NodeExistAndUpdated + return + } + + // // TODO: Should we replace the old one?? + // b.UpdateLastChangedTime(now) + // copy(b.Nodes[i:], b.Nodes[i+1:]) + // b.Nodes[len(b.Nodes)-1] = newWrappedNode(b, n, now) + // status = nodeExistAndChanged + return NodeExistAndChanged + } + } + + // The bucket is not full and append it. + if _len := len(b.Nodes); _len < b.table.k { + b.UpdateLastChangedTime(now) + b.Nodes = append(b.Nodes, newWrappedNode(b, n, now)) + status = NodeAdded + return + } + + // The bucket is full and remove the bad node, or discard the current node. + for i, node := range b.Nodes { + if node.IsBad(now) { + b.UpdateLastChangedTime(now) + copy(b.Nodes[i:], b.Nodes[i+1:]) + b.Nodes = append(b.Nodes, newWrappedNode(b, n, now)) + status = NodeAdded + return + } + } + + // It will discard the current node and return false. + status = NodeNotAdded + return +} + +func (b *bucket) delNodeByIndex(index int) { + b.Nodes = append(b.Nodes[:index], b.Nodes[index+1:]...) +} + +func (b *bucket) CheckAllNodes(now time.Time) { + _len := len(b.Nodes) + if _len == 0 || now.Sub(b.LastChanged) < dubiousTimeout { + return + } + + // Check all the dubious nodes. + indexes := make([]int, 0, len(b.Nodes)) + for i, node := range b.Nodes { + switch status := node.Status(now); status { + case nodeStatusGood: + case nodeStatusDubious: + // Try to send the PING query to the dubious node to check whether it is alive. + if err := b.table.s.Ping(node.Node.Addr.UDPAddr()); err != nil { + b.table.s.conf.ErrorLog("fail to ping '%s': %s", node.Node.String(), err) + } + case nodeStatusBad: + // Remove the bad node + indexes = append(indexes, i) + default: + panic(status) + } + } + + // Remove all the bad nodes. + for _, index := range indexes { + b.delNodeByIndex(index) + } +} + +func bucketid(ownerid, nid metainfo.Hash) int { + if ownerid == nid { + return 0 + } + + var i int + var bite byte + var bitDiff int + var v byte + for i, bite = range ownerid { + v = bite ^ nid[i] + switch { + case v&0x80 > 0: + bitDiff = 7 + goto calc + case v&0x40 > 0: + bitDiff = 6 + goto calc + case v&0x20 > 0: + bitDiff = 5 + goto calc + case v&0x10 > 0: + bitDiff = 4 + goto calc + case v&0x08 > 0: + bitDiff = 3 + goto calc + case v&0x04 > 0: + bitDiff = 2 + goto calc + case v&0x02 > 0: + bitDiff = 1 + goto calc + case v&0x01 > 0: + bitDiff = 0 + goto calc + } + } + +calc: + return i*8 + (8 - bitDiff) +} + +/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> +/// Node + +const ( + nodeStatusGood = iota + nodeStatusDubious + nodeStatusBad +) + +type wrappedNode struct { + krpc.Node + LastChanged time.Time + bkt *bucket +} + +func newWrappedNode(bkt *bucket, n krpc.Node, now time.Time) *wrappedNode { + return &wrappedNode{bkt: bkt, Node: n, LastChanged: now} +} + +func (n *wrappedNode) UpdateLastChangedTime(now ...time.Time) { + var _now time.Time + if len(now) > 0 { + _now = now[0] + } else { + _now = time.Now() + } + + n.LastChanged = _now + n.bkt.UpdateLastChangedTime(_now) + return +} + +func (n *wrappedNode) IsBad(now ...time.Time) bool { + var _now time.Time + if len(now) > 0 { + _now = now[0] + } else { + _now = time.Now() + } + return _now.Sub(n.LastChanged) > badTimeout +} + +func (n *wrappedNode) IsDubious(now ...time.Time) bool { + var _now time.Time + if len(now) > 0 { + _now = now[0] + } else { + _now = time.Now() + } + return _now.Sub(n.LastChanged) > dubiousTimeout +} + +func (n *wrappedNode) Status(now time.Time) int { + duration := now.Sub(n.LastChanged) + switch { + case duration > badTimeout: + return nodeStatusBad + case duration > dubiousTimeout: + return nodeStatusDubious + default: + return nodeStatusGood + } +} diff --git a/dht/routing_table_storage.go b/dht/routing_table_storage.go new file mode 100644 index 0000000..d2d774b --- /dev/null +++ b/dht/routing_table_storage.go @@ -0,0 +1,42 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dht + +import ( + "time" + + "github.com/xgfone/bt/krpc" + "github.com/xgfone/bt/metainfo" +) + +// RoutingTableNode represents the node with last changed time in the routing table. +type RoutingTableNode struct { + Node krpc.Node + LastChanged time.Time +} + +// RoutingTableStorage is used to store the nodes in the routing table. +type RoutingTableStorage interface { + Load(ownid metainfo.Hash, ipv6 bool) (nodes []RoutingTableNode, err error) + Dump(ownid metainfo.Hash, nodes []RoutingTableNode, ipv6 bool) (err error) +} + +// NewNoopRoutingTableStorage returns a no-op RoutingTableStorage. +func NewNoopRoutingTableStorage() RoutingTableStorage { return noopStorage{} } + +type noopStorage struct{} + +func (s noopStorage) Load(metainfo.Hash, bool) (nodes []RoutingTableNode, err error) { return } +func (s noopStorage) Dump(metainfo.Hash, []RoutingTableNode, bool) (err error) { return } diff --git a/dht/token_manager.go b/dht/token_manager.go new file mode 100644 index 0000000..8e6b6a9 --- /dev/null +++ b/dht/token_manager.go @@ -0,0 +1,107 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dht + +import ( + "net" + "sync" + "time" + + "github.com/xgfone/bt/utils" +) + +// TokenManager is used to manage and validate the token. +// +// TODO: Should we allocate the different token for each node?? +type tokenManager struct { + lock sync.RWMutex + last string + new string + exit chan struct{} + addrs sync.Map +} + +func newTokenManager() *tokenManager { + token := utils.RandomString(8) + return &tokenManager{last: token, new: token, exit: make(chan struct{})} +} + +func (tm *tokenManager) updateToken() { + token := utils.RandomString(8) + tm.lock.Lock() + tm.last, tm.new = tm.new, token + tm.lock.Unlock() +} + +func (tm *tokenManager) clear(now time.Time, expired time.Duration) { + tm.addrs.Range(func(k interface{}, v interface{}) bool { + if now.Sub(v.(time.Time)) >= expired { + tm.addrs.Delete(k) + } + return true + }) +} + +// Start starts the token manager. +func (tm *tokenManager) Start(expired time.Duration) { + tick1 := time.NewTicker(expired) + tick2 := time.NewTicker(expired / 2) + defer tick1.Stop() + defer tick2.Stop() + + for { + select { + case <-tm.exit: + return + case <-tick2.C: + tm.updateToken() + case now := <-tick1.C: + tm.clear(now, expired) + } + } +} + +// Stop stops the token manager. +func (tm *tokenManager) Stop() { + select { + case <-tm.exit: + default: + close(tm.exit) + } +} + +// Token allocates a token for a node addr and returns the token. +func (tm *tokenManager) Token(addr *net.UDPAddr) (token string) { + addrs := addr.String() + tm.lock.RLock() + token = tm.new + tm.lock.RUnlock() + tm.addrs.Store(addrs, time.Now()) + return +} + +// Check checks whether the token associated with the node addr is valid, +// that's, it's not expired. +func (tm *tokenManager) Check(addr *net.UDPAddr, token string) (ok bool) { + tm.lock.RLock() + last, new := tm.last, tm.new + tm.lock.RUnlock() + + if last == token || new == token { + _, ok = tm.addrs.Load(addr.String()) + } + + return +} diff --git a/dht/transaction_manager.go b/dht/transaction_manager.go new file mode 100644 index 0000000..5a05a44 --- /dev/null +++ b/dht/transaction_manager.go @@ -0,0 +1,152 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dht + +import ( + "net" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/xgfone/bt/krpc" + "github.com/xgfone/bt/metainfo" +) + +type transaction struct { + ID string + Query string + Arg krpc.QueryArg + Addr *net.UDPAddr + Time time.Time + Depth int + + Visited metainfo.Hashes + Callback func(Result) + OnError func(t *transaction, code int, reason string) + OnTimeout func(t *transaction) + OnResponse func(t *transaction, radd *net.UDPAddr, msg krpc.Message) +} + +func (t *transaction) Done(r Result) { + if t.Callback != nil { + r.Addr = t.Addr + t.Callback(r) + } +} + +func noopResponse(*transaction, *net.UDPAddr, krpc.Message) {} +func newTransaction(s *Server, a *net.UDPAddr, q string, qa krpc.QueryArg, + callback ...func(Result)) *transaction { + var cb func(Result) + if len(callback) > 0 { + cb = callback[0] + } + + return &transaction{ + Addr: a, + Query: q, + Arg: qa, + Callback: cb, + OnError: s.onError, + OnTimeout: s.onTimeout, + OnResponse: noopResponse, + Time: time.Now(), + } +} + +type transactionkey struct { + id string + addr string +} + +type transactionManager struct { + lock sync.Mutex + exit chan struct{} + trans map[transactionkey]*transaction + tid uint32 +} + +func newTransactionManager() *transactionManager { + return &transactionManager{ + exit: make(chan struct{}), + trans: make(map[transactionkey]*transaction, 128), + } +} + +// Start starts the transaction manager. +func (tm *transactionManager) Start(s *Server, interval time.Duration) { + tick := time.NewTicker(interval) + defer tick.Stop() + for { + select { + case <-tm.exit: + return + case now := <-tick.C: + tm.lock.Lock() + for k, t := range tm.trans { + if now.Sub(t.Time) > interval { + delete(tm.trans, k) + t.OnTimeout(t) + } + } + tm.lock.Unlock() + } + } +} + +// Stop stops the transaction manager. +func (tm *transactionManager) Stop() { + select { + case <-tm.exit: + default: + close(tm.exit) + } +} + +// GetTransactionID returns a new transaction id. +func (tm *transactionManager) GetTransactionID() string { + return strconv.FormatUint(uint64(atomic.AddUint32(&tm.tid, 1)), 36) +} + +// AddTransaction adds the new transaction. +func (tm *transactionManager) AddTransaction(t *transaction) { + key := transactionkey{id: t.ID, addr: t.Addr.String()} + tm.lock.Lock() + tm.trans[key] = t + tm.lock.Unlock() +} + +// DeleteTransaction deletes the transaction. +func (tm *transactionManager) DeleteTransaction(t *transaction) { + key := transactionkey{id: t.ID, addr: t.Addr.String()} + tm.lock.Lock() + delete(tm.trans, key) + tm.lock.Unlock() +} + +// PopTransaction deletes and returns the transaction by the transaction id +// and the peer address. +// +// Return nil if there is no the transaction. +func (tm *transactionManager) PopTransaction(tid string, addr *net.UDPAddr) (t *transaction) { + key := transactionkey{id: tid, addr: addr.String()} + tm.lock.Lock() + if t = tm.trans[key]; t != nil { + delete(tm.trans, key) + } + tm.lock.Unlock() + return +} diff --git a/downloader/torrent_info.go b/downloader/torrent_info.go new file mode 100644 index 0000000..53f69d1 --- /dev/null +++ b/downloader/torrent_info.go @@ -0,0 +1,315 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package downloader is used to download the torrent or the real file +// from the peer node by the peer wire protocol. +package downloader + +import ( + "bytes" + "crypto/sha1" + "errors" + "fmt" + "log" + "net" + "strconv" + + "github.com/xgfone/bt/metainfo" + "github.com/xgfone/bt/peerprotocol" +) + +const peerBlockSize = 16384 // 16KiB. + +// Request is used to send a download request. +type request struct { + Host string + Port uint16 + PeerID metainfo.Hash + InfoHash metainfo.Hash +} + +// TorrentResponse represents a torrent info response. +type TorrentResponse struct { + Host string // Which host the torrent is downloaded from + Port uint16 // Which port the torrent is downloaded from + PeerID metainfo.Hash // ID of the peer where torrent is downloaded from + InfoHash metainfo.Hash // The SHA-1 hash of the torrent to be downloaded + InfoBytes []byte // The content of the info part in the torrent +} + +// TorrentDownloaderConfig is used to configure the TorrentDownloader. +type TorrentDownloaderConfig struct { + // ID is the id of the downloader peer node. + // + // The default is a random id. + ID metainfo.Hash + + // WorkerNum is the number of the worker downloading the torrent concurrently. + // + // The default is 128. + WorkerNum int + + // ErrorLog is used to log the error. + // + // The default is log.Printf. + ErrorLog func(format string, args ...interface{}) +} + +func (c *TorrentDownloaderConfig) set(conf ...TorrentDownloaderConfig) { + if len(conf) > 0 { + *c = conf[0] + } + + if c.WorkerNum <= 0 { + c.WorkerNum = 128 + } + if c.ID.IsZero() { + c.ID = metainfo.NewRandomHash() + } + if c.ErrorLog == nil { + c.ErrorLog = log.Printf + } +} + +// TorrentDownloader is used to download the torrent file from the peer. +type TorrentDownloader struct { + conf TorrentDownloaderConfig + exit chan struct{} + requests chan request + responses chan TorrentResponse + + ondht func(string, uint16) + ebits peerprotocol.ExtensionBits + ehmsg peerprotocol.ExtendedHandshakeMsg +} + +// NewTorrentDownloader returns a new TorrentDownloader. +// +// If id is ZERO, it is reset to a random id. workerNum is 128 by default. +func NewTorrentDownloader(c ...TorrentDownloaderConfig) *TorrentDownloader { + var conf TorrentDownloaderConfig + conf.set(c...) + + d := &TorrentDownloader{ + conf: conf, + exit: make(chan struct{}), + requests: make(chan request, conf.WorkerNum), + responses: make(chan TorrentResponse, 1024), + + ehmsg: peerprotocol.ExtendedHandshakeMsg{ + M: map[string]uint8{peerprotocol.ExtendedMessageNameMetadata: 1}, + }, + } + + for i := 0; i < conf.WorkerNum; i++ { + go d.worker() + } + + d.ebits.Set(peerprotocol.ExtensionBitExtended) + return d +} + +// Request submits a download request. +func (d *TorrentDownloader) Request(host string, port uint16, infohash metainfo.Hash) { + d.requests <- request{Host: host, Port: port, InfoHash: infohash} +} + +// Response returns a response channel to get the downloaded torrent info. +func (d *TorrentDownloader) Response() <-chan TorrentResponse { + return d.responses +} + +// Close closes the downloader and releases the underlying resources. +func (d *TorrentDownloader) Close() { + select { + case <-d.exit: + default: + close(d.exit) + } +} + +// OnDHTNode sets the DHT node callback, which will enable DHT extenstion bit. +// +// In the callback function, you maybe ping it like DHT by UDP. +// If the node responds, you can add the node in DHT routing table. +// +// BEP 5 +func (d *TorrentDownloader) OnDHTNode(cb func(host string, port uint16)) { + d.ondht = cb + if cb == nil { + d.ebits.Unset(peerprotocol.ExtensionBitDHT) + } else { + d.ebits.Set(peerprotocol.ExtensionBitDHT) + } +} + +func (d *TorrentDownloader) worker() { + for { + select { + case <-d.exit: + return + case r := <-d.requests: + if err := d.download(r.Host, r.Port, r.PeerID, r.InfoHash); err != nil { + d.conf.ErrorLog("fail to download the torrent '%s': %s", + r.InfoHash.HexString(), err) + } + } + } +} + +func (d *TorrentDownloader) download(host string, port uint16, + peerID, infohash metainfo.Hash) (err error) { + addr := net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10)) + conn, err := peerprotocol.NewPeerConnByDial(d.conf.ID, addr) + if err != nil { + return fmt.Errorf("fail to dial to '%s': %s", addr, err) + } + defer conn.Close() + + conn.ExtensionBits = d.ebits + rmsg, err := conn.Handshake(infohash) + if err != nil || rmsg.InfoHash != infohash || !rmsg.IsSupportExtended() { + return + } else if !peerID.IsZero() && peerID != rmsg.PeerID { + return fmt.Errorf("inconsistent peer id '%s'", rmsg.PeerID.HexString()) + } + + if err = conn.SendExtHandshakeMsg(d.ehmsg); err != nil { + return + } + + var pieces [][]byte + var piecesNum int + var metadataSize int + var utmetadataID uint8 + var msg peerprotocol.Message + + for { + if msg, err = conn.ReadMsg(); err != nil { + return err + } + + select { + case <-d.exit: + return + default: + } + + if msg.Keepalive { + continue + } + + switch msg.Type { + case peerprotocol.Extended: + case peerprotocol.Port: + if d.ondht != nil { + d.ondht(host, msg.Port) + } + continue + default: + continue + } + + switch msg.ExtendedID { + case peerprotocol.ExtendedIDHandshake: + if utmetadataID > 0 { + return fmt.Errorf("rehandshake from the peer '%s'", conn.RemoteAddr().String()) + } + + var ehmsg peerprotocol.ExtendedHandshakeMsg + if err = ehmsg.Decode(msg.ExtendedPayload); err != nil { + return + } + + utmetadataID = ehmsg.M[peerprotocol.ExtendedMessageNameMetadata] + if utmetadataID == 0 { + return errors.New(`the peer does not support "ut_metadata"`) + } + + metadataSize = ehmsg.MetadataSize + piecesNum = metadataSize / peerBlockSize + if metadataSize%peerBlockSize != 0 { + piecesNum++ + } + + pieces = make([][]byte, piecesNum) + go d.requestPieces(conn, utmetadataID, piecesNum) + case 1: + if pieces == nil { + return + } + + var utmsg peerprotocol.UtMetadataExtendedMsg + if err = utmsg.DecodeFromPayload(msg.ExtendedPayload); err != nil { + return + } + + if utmsg.MsgType != peerprotocol.UtMetadataExtendedMsgTypeData { + continue + } + + pieceLen := len(utmsg.Data) + if (utmsg.Piece != piecesNum-1 && pieceLen != peerBlockSize) || + (utmsg.Piece == piecesNum-1 && pieceLen != metadataSize%peerBlockSize) { + return + } + pieces[utmsg.Piece] = utmsg.Data + + finish := true + for _, piece := range pieces { + if len(piece) == 0 { + finish = false + break + } + } + + if finish { + metadataInfo := bytes.Join(pieces, nil) + if infohash == metainfo.Hash(sha1.Sum(metadataInfo)) { + d.responses <- TorrentResponse{ + Host: host, + Port: port, + PeerID: rmsg.PeerID, + InfoHash: infohash, + InfoBytes: metadataInfo, + } + } + return + } + } + } +} + +func (d *TorrentDownloader) requestPieces(conn *peerprotocol.PeerConn, utMetadataID uint8, piecesNum int) { + for i := 0; i < piecesNum; i++ { + payload, err := peerprotocol.UtMetadataExtendedMsg{ + MsgType: peerprotocol.UtMetadataExtendedMsgTypeRequest, + Piece: i, + }.EncodeToBytes() + if err != nil { + panic(err) + } + + msg := peerprotocol.Message{ + Type: peerprotocol.Extended, + ExtendedID: utMetadataID, + ExtendedPayload: payload, + } + + if err = conn.WriteMsg(msg); err != nil { + d.conf.ErrorLog("fail to send the ut_metadata request: %s", err) + return + } + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..9597f68 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/xgfone/bt + +go 1.11 diff --git a/krpc/doc.go b/krpc/doc.go new file mode 100644 index 0000000..20b8891 --- /dev/null +++ b/krpc/doc.go @@ -0,0 +1,16 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package krpc supplies the KRPC message used by DHT. +package krpc diff --git a/krpc/message.go b/krpc/message.go new file mode 100644 index 0000000..e6f24b6 --- /dev/null +++ b/krpc/message.go @@ -0,0 +1,440 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package krpc + +import ( + "bytes" + "fmt" + + "github.com/xgfone/bt/bencode" + "github.com/xgfone/bt/metainfo" +) + +// Predefine some error code. +const ( + // BEP 5 + ErrorCodeGeneric = 201 + ErrorCodeServer = 202 + ErrorCodeProtocol = 203 + ErrorCodeMethodUnknown = 204 + + // BEP 44 + ErrorCodeMessageValueFieldTooBig = 205 + ErrorCodeInvalidSignature = 206 + ErrorCodeSaltFieldTooBig = 207 + ErrorCodeCasHashMismatched = 301 + ErrorCodeSequenceNumberLessThanCurrent = 302 +) + +// Message represents messages that nodes in the network send to each other +// as specified by the KRPC protocol, which are also referred to as the KRPC +// messages. +// +// There are three types of messages: QUERY, RESPONSE, ERROR +// The message is a dictonary that is then "bencoded" +// (serialization & compression format adopted by the BitTorrent) +// and sent via the UDP connection to peers. +// +// A KRPC message is a single dictionary with two keys common to every message +// and additional keys depending on the type of message. Every message has a key +// "t" with a string value representing a transaction ID. This transaction ID +// is generated by the querying node and is echoed in the response, so responses +// may be correlated with multiple queries to the same node. The transaction ID +// should be encoded as a short string of binary numbers, typically 2 characters +// are enough as they cover 2^16 outstanding queries. The other key contained +// in every KRPC message is "y" with a single character value describing +// the type of message. The value of the "y" key is one of "q" for query, +// "r" for response, or "e" for error. +type Message struct { + T string `bencode:"t"` // required: transaction ID + Y string `bencode:"y"` // required: type of the message: q for QUERY, r for RESPONSE, e for ERROR + Q string `bencode:"q,omitempty"` // Query method (one of 4: "ping", "find_node", "get_peers", "announce_peer") + A QueryArg `bencode:"a,omitempty"` // named arguments sent with a query + R ResponseResult `bencode:"r,omitempty"` // RESPONSE type only + E Error `bencode:"e,omitempty"` // ERROR type only + + RO bool `bencode:"ro,omitempty"` // BEP 43: ReadOnly +} + +// NewQueryMsg return a QUERY message. +func NewQueryMsg(tid, method string, arg QueryArg) Message { + return Message{T: tid, Y: "q", Q: method, A: arg} +} + +// NewResponseMsg return a RESPONSE message. +func NewResponseMsg(tid string, data ResponseResult) Message { + return Message{T: tid, Y: "r", R: data} +} + +// NewErrorMsg return a ERROR message. +func NewErrorMsg(tid string, code int, reason string) Message { + return Message{T: tid, Y: "e", E: Error{Code: code, Reason: reason}} +} + +// IsQuery reports whether the message is an QUERY. +func (m Message) IsQuery() bool { + return m.Y == "q" +} + +// IsResponse reports whether the message is an RESPONSE. +func (m Message) IsResponse() bool { + return m.Y == "r" +} + +// IsError reports whether the message is an ERROR. +func (m Message) IsError() bool { + return m.Y == "e" +} + +// RID returns the value named "id" in "r". +// +// Return "" instead if no "id". +func (m Message) RID() metainfo.Hash { + return m.R.ID +} + +// QID returns the value named "id" in "a", that's, the query arguments. +// +// Return "" instead if no "id". +func (m Message) QID() metainfo.Hash { + return m.A.ID +} + +// ID returns the QID or RID. +func (m Message) ID() metainfo.Hash { + switch m.Y { + case "q": + return m.QID() + case "r": + return m.RID() + default: + panic(fmt.Errorf("unknown message type '%s'", m.Y)) + } +} + +// Error represents a response error. +type Error struct { + Code int + Reason string +} + +// NewError returns a new Error. +func NewError(code int, reason string) Error { + return Error{Code: code, Reason: reason} +} + +func (e *Error) decode(vs []interface{}) (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("unpacking %#v: %v", vs, r) + } + }() + + e.Code = int(vs[0].(int64)) + e.Reason = vs[1].(string) + return +} + +// UnmarshalBencode implements the interface bencode.Unmarshaler. +func (e *Error) UnmarshalBencode(b []byte) (err error) { + var iface interface{} + if err = bencode.NewDecoder(bytes.NewBuffer(b)).Decode(&iface); err != nil { + return + } + + switch v := iface.(type) { + case []interface{}: + err = e.decode(v) + case string: + e.Reason = v + default: + err = fmt.Errorf(`KRPC error bencode value has unexpected type: %T`, v) + } + + return +} + +// MarshalBencode implements the interface bencode.Marshaler. +func (e Error) MarshalBencode() (ret []byte, err error) { + if e.Code == 0 && e.Reason == "" { + return nil, nil + } + + buf := bytes.NewBuffer(nil) + buf.Grow(32) + err = bencode.NewEncoder(buf).Encode([]interface{}{e.Code, e.Reason}) + if err != nil { + ret = buf.Bytes() + } + return +} + +func (e Error) Error() string { + return fmt.Sprintf("KRPC error %d: %s", e.Code, e.Reason) +} + +// Want represents the type of nodes that the request wants. +// +// BEP 32 +type Want string + +// Predefine some wants. +// +// BEP 32 +const ( + WantNodes Want = "n4" + WantNodes6 Want = "n6" +) + +// QueryArg represents the arguments used by the QUERY message. +type QueryArg struct { + // ID is used to identify a querying node. + ID metainfo.Hash `bencode:"id"` // BEP 5 + + // Target is used to identify the node sought by the queryer. + // + // find_node + Target metainfo.Hash `bencode:"target,omitempty"` // BEP 5 + + // InfoHash is the infohash of the torrent file. + // + // get_peers, announce_peer + InfoHash metainfo.Hash `bencode:"info_hash,omitempty"` // BEP 5 + + // Port is the port on where sender is listening to allow others + // to download the torrent. + // + // announce_peer + Port uint16 `bencode:"port,omitempty"` // BEP 5 + + // Token is the received one from an earlier "get_peers" query. + // + // announce_peer + Token string `bencode:"token,omitempty"` // BEP 5 + + // ImpliedPort is used by the senders to apparent DHT port + // to improve NAT support. + // + // announce_peer + ImpliedPort bool `bencode:"implied_port,omitempty"` // BEP 5 + + // Wants is only used to represent to expect "nodes" for "n4" + // or "nodes6" for "n6". + // + // Notice: It only governs the presence of the "nodes" and "nodes6" + // parameters, not the interpretation of "values". + // + // find_node, get_peers + Wants []Want `bencode:"want,omitempty"` // BEP 32 +} + +// ContainsWant reports whether the request contains the given Want. +func (a QueryArg) ContainsWant(w Want) bool { + for _, want := range a.Wants { + if want == w { + return true + } + } + return false +} + +// GetPort returns the real port of the peer. +func (a QueryArg) GetPort(port int) uint16 { + if a.ImpliedPort { + return uint16(port) + } + return a.Port +} + +// ResponseResult represents the results used by the RESPONSE message. +type ResponseResult struct { + // ID is used to indentify the queried node, that's, the response node. + ID metainfo.Hash `bencode:"id"` // BEP 5 + + // Nodes is a string containing the compact node information for the list + // of the ipv4 target node, or the K(8) closest good nodes in routing table + // of the requested ipv4 target. + // + // find_node + Nodes CompactIPv4Node `bencode:"nodes,omitempty"` // BEP 5 + + // Nodes6 is a string containing the compact node information for the list + // of the ipv6 target node, or the K(8) closest good nodes in routing table + // of the requested ipv6 target. + // + // find_node + Nodes6 CompactIPv6Node `bencode:"nodes6,omitempty"` // BEP 32 + + // Token is used for future "announce_peer". + // + // get_peers + Token string `bencode:"token,omitempty"` // BEP 5 + + // Values is a list of the torrent peers. + // + // get_peers + Values CompactAddresses `bencode:"values,omitempty"` // BEP 5 +} + +/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + +// CompactAddresses represents a group of the compact addresses. +type CompactAddresses []metainfo.Address + +// MarshalBinary implements the interface binary.BinaryMarshaler. +func (cas CompactAddresses) MarshalBinary() ([]byte, error) { + ss := make([]string, len(cas)) + for i, addr := range cas { + data, err := addr.MarshalBinary() + if err != nil { + return nil, err + } + ss[i] = string(data) + } + return bencode.EncodeBytes(ss) +} + +// UnmarshalBinary implements the interface binary.BinaryUnmarshaler. +func (cas *CompactAddresses) UnmarshalBinary(b []byte) (err error) { + var ss []string + if err = bencode.DecodeBytes(b, &ss); err == nil { + addrs := make(CompactAddresses, len(ss)) + for i, s := range ss { + var addr metainfo.Address + if err = addr.UnmarshalBinary([]byte(s)); err != nil { + return + } + addrs[i] = addr + } + } + + return +} + +/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + +// CompactIPv4Node is a set of IPv4 Nodes. +type CompactIPv4Node []Node + +// MarshalBinary implements the interface binary.BinaryMarshaler. +func (cn CompactIPv4Node) MarshalBinary() ([]byte, error) { + buf := bytes.NewBuffer(nil) + buf.Grow(26 * len(cn)) + for _, ni := range cn { + if ni.Addr.IP = ni.Addr.IP.To4(); len(ni.Addr.IP) == 0 { + continue + } + if n, err := ni.WriteBinary(buf); err != nil { + return nil, err + } else if n != 26 { + panic(fmt.Errorf("CompactIPv4NodeInfo: the invalid NodeInfo length '%d'", n)) + } + } + return buf.Bytes(), nil +} + +// UnmarshalBinary implements the interface binary.BinaryUnmarshaler. +func (cn *CompactIPv4Node) UnmarshalBinary(b []byte) (err error) { + _len := len(b) + if _len%26 != 0 { + return fmt.Errorf("CompactIPv4NodeInfo: invalid bytes length '%d'", _len) + } + + nis := make([]Node, 0, _len/26) + for i := 0; i < _len; i += 26 { + var ni Node + if err = ni.UnmarshalBinary(b[i : i+26]); err != nil { + return + } + nis = append(nis, ni) + } + + *cn = nis + return +} + +// MarshalBencode implements the interface bencode.Marshaler. +func (cn CompactIPv4Node) MarshalBencode() (b []byte, err error) { + if b, err = cn.MarshalBinary(); err == nil { + b, err = bencode.EncodeBytes(b) + } + return +} + +// UnmarshalBencode implements the interface bencode.Unmarshaler. +func (cn *CompactIPv4Node) UnmarshalBencode(b []byte) (err error) { + var s string + if err = bencode.DecodeBytes(b, &s); err == nil { + err = cn.UnmarshalBinary([]byte(s)) + } + return +} + +/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + +// CompactIPv6Node is a set of IPv6 Nodes. +type CompactIPv6Node []Node + +// MarshalBinary implements the interface binary.BinaryMarshaler. +func (cn CompactIPv6Node) MarshalBinary() ([]byte, error) { + buf := bytes.NewBuffer(nil) + buf.Grow(38 * len(cn)) + for _, ni := range cn { + ni.Addr.IP = ni.Addr.IP.To16() + if n, err := ni.WriteBinary(buf); err != nil { + return nil, err + } else if n != 38 { + panic(fmt.Errorf("CompactIPv6NodeInfo: the invalid NodeInfo length '%d'", n)) + } + } + return buf.Bytes(), nil +} + +// UnmarshalBinary implements the interface binary.BinaryUnmarshaler. +func (cn *CompactIPv6Node) UnmarshalBinary(b []byte) (err error) { + _len := len(b) + if _len%38 != 0 { + return fmt.Errorf("CompactIPv6NodeInfo: invalid bytes length '%d'", _len) + } + + nis := make([]Node, 0, _len/38) + for i := 0; i < _len; i += 38 { + var ni Node + if err = ni.UnmarshalBinary(b[i : i+38]); err != nil { + return + } + nis = append(nis, ni) + } + + *cn = nis + return +} + +// MarshalBencode implements the interface bencode.Marshaler. +func (cn CompactIPv6Node) MarshalBencode() (b []byte, err error) { + if b, err = cn.MarshalBinary(); err == nil { + b, err = bencode.EncodeBytes(b) + } + return +} + +// UnmarshalBencode implements the interface bencode.Unmarshaler. +func (cn *CompactIPv6Node) UnmarshalBencode(b []byte) (err error) { + var s string + if err = bencode.DecodeBytes(b, &s); err == nil { + err = cn.UnmarshalBinary([]byte(s)) + } + return +} diff --git a/krpc/message_test.go b/krpc/message_test.go new file mode 100644 index 0000000..95bf086 --- /dev/null +++ b/krpc/message_test.go @@ -0,0 +1,41 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package krpc + +import ( + "testing" + + "github.com/xgfone/bt/bencode" +) + +func TestMessage(t *testing.T) { + s, err := bencode.EncodeString(Message{RO: true}) + if err != nil { + t.Fatal(err) + } + + var ms map[string]interface{} + if err := bencode.DecodeString(s, &ms); err != nil { + t.Fatal(err) + } else if len(ms) != 3 { + t.Fatal() + } else if v, ok := ms["t"]; !ok || v != "" { + t.Fatal() + } else if v, ok := ms["y"]; !ok || v != "" { + t.Fatal() + } else if v, ok := ms["ro"].(int64); !ok || v != 1 { + t.Fatal() + } +} diff --git a/krpc/node.go b/krpc/node.go new file mode 100644 index 0000000..ce83ddc --- /dev/null +++ b/krpc/node.go @@ -0,0 +1,84 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package krpc + +import ( + "bytes" + "fmt" + "io" + "net" + + "github.com/xgfone/bt/metainfo" +) + +// Node represents a node information. +type Node struct { + ID metainfo.Hash + Addr metainfo.Address +} + +// NewNode returns a new Node. +func NewNode(id metainfo.Hash, ip net.IP, port int) Node { + return Node{ID: id, Addr: metainfo.NewAddress(ip, uint16(port))} +} + +// NewNodeByUDPAddr returns a new Node with the id and the UDP address. +func NewNodeByUDPAddr(id metainfo.Hash, addr *net.UDPAddr) (n Node) { + n.ID = id + n.Addr.FromUDPAddr(addr) + return +} + +func (n Node) String() string { + return fmt.Sprintf("Node<%x@%s>", n.ID, n.Addr) +} + +// Equal reports whether n is equal to o. +func (n Node) Equal(o Node) bool { + return n.ID == o.ID && n.Addr.Equal(o.Addr) +} + +// WriteBinary is the same as MarshalBinary, but writes the result into w +// instead of returning. +func (n Node) WriteBinary(w io.Writer) (m int, err error) { + var n1, n2 int + if n1, err = w.Write(n.ID[:]); err == nil { + m = n1 + if n2, err = n.Addr.WriteBinary(w); err == nil { + m += n2 + } + } + return +} + +// MarshalBinary implements the interface binary.BinaryMarshaler. +func (n Node) MarshalBinary() (data []byte, err error) { + buf := bytes.NewBuffer(nil) + buf.Grow(48) + if _, err = n.WriteBinary(buf); err == nil { + data = buf.Bytes() + } + return +} + +// UnmarshalBinary implements the interface binary.BinaryUnmarshaler. +func (n *Node) UnmarshalBinary(b []byte) error { + if len(b) < 26 { + return io.ErrShortBuffer + } + + copy(n.ID[:], b[:20]) + return n.Addr.UnmarshalBinary(b[20:]) +} diff --git a/metainfo/address.go b/metainfo/address.go new file mode 100644 index 0000000..b31bac3 --- /dev/null +++ b/metainfo/address.go @@ -0,0 +1,342 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metainfo + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "net" + "strconv" + + "github.com/xgfone/bt/bencode" +) + +// ErrInvalidAddr is returned when the compact address is invalid. +var ErrInvalidAddr = fmt.Errorf("invalid compact information of ip and port") + +// Address represents a client/server listening on a UDP port implementing +// the DHT protocol. +type Address struct { + IP net.IP // For IPv4, its length must be 4. + Port uint16 +} + +// NewAddress returns a new Address. +func NewAddress(ip net.IP, port uint16) Address { + if ipv4 := ip.To4(); len(ipv4) > 0 { + ip = ipv4 + } + return Address{IP: ip, Port: port} +} + +// NewAddressFromString returns a new Address by the address string. +func NewAddressFromString(s string) (addr Address, err error) { + err = addr.FromString(s) + return +} + +// NewAddressesFromString returns a list of Addresses by the address string. +func NewAddressesFromString(s string) (addrs []Address, err error) { + shost, sport, err := net.SplitHostPort(s) + if err != nil { + return nil, fmt.Errorf("invalid address '%s': %s", s, err) + } + + var port uint16 + if sport != "" { + v, err := strconv.ParseUint(sport, 10, 16) + if err != nil { + return nil, fmt.Errorf("invalid address '%s': %s", s, err) + } + port = uint16(v) + } + + ips, err := net.LookupIP(shost) + if err != nil { + return nil, fmt.Errorf("fail to lookup the domain '%s': %s", shost, err) + } + + addrs = make([]Address, len(ips)) + for i, ip := range ips { + addrs[i] = Address{IP: ip, Port: port} + } + + return +} + +// FromString parses and sets the ip from the string addr. +func (a *Address) FromString(addr string) (err error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return fmt.Errorf("invalid address '%s': %s", addr, err) + } + + if port != "" { + v, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return fmt.Errorf("invalid address '%s': %s", addr, err) + } + a.Port = uint16(v) + } + + ips, err := net.LookupIP(host) + if err != nil { + return fmt.Errorf("fail to lookup the domain '%s': %s", host, err) + } else if len(ips) == 0 { + return fmt.Errorf("the domain '%s' has no ips", host) + } + + a.IP = ips[0] + if ip := a.IP.To4(); len(ip) > 0 { + a.IP = ip + } + + return +} + +// FromUDPAddr sets the ip from net.UDPAddr. +func (a *Address) FromUDPAddr(ua *net.UDPAddr) { + a.Port = uint16(ua.Port) + a.IP = ua.IP + if ipv4 := a.IP.To4(); len(ipv4) > 0 { + a.IP = ipv4 + } +} + +// UDPAddr creates a new net.UDPAddr. +func (a Address) UDPAddr() *net.UDPAddr { + return &net.UDPAddr{ + IP: a.IP, + Port: int(a.Port), + } +} + +func (a Address) String() string { + if a.Port == 0 { + return a.IP.String() + } + return net.JoinHostPort(a.IP.String(), strconv.FormatUint(uint64(a.Port), 10)) +} + +// Equal reports whether n is equal to o, which is equal to +// n.HasIPAndPort(o.IP, o.Port) +func (a Address) Equal(o Address) bool { + return a.Port == o.Port && a.IP.Equal(o.IP) +} + +// HasIPAndPort reports whether the current node has the ip and the port. +func (a Address) HasIPAndPort(ip net.IP, port uint16) bool { + return port == a.Port && a.IP.Equal(ip) +} + +// WriteBinary is the same as MarshalBinary, but writes the result into w +// instead of returning. +func (a Address) WriteBinary(w io.Writer) (m int, err error) { + if m, err = w.Write(a.IP); err == nil { + if err = binary.Write(w, binary.BigEndian, a.Port); err == nil { + m += 2 + } + } + return +} + +// UnmarshalBinary implements the interface binary.BinaryUnmarshaler. +func (a *Address) UnmarshalBinary(b []byte) (err error) { + _len := len(b) - 2 + switch _len { + case net.IPv4len, net.IPv6len: + default: + return ErrInvalidAddr + } + + a.IP = make(net.IP, _len) + copy(a.IP, b[:_len]) + a.Port = binary.BigEndian.Uint16(b[_len:]) + return +} + +// MarshalBinary implements the interface binary.BinaryMarshaler. +func (a Address) MarshalBinary() (data []byte, err error) { + buf := bytes.NewBuffer(nil) + buf.Grow(20) + if _, err = a.WriteBinary(buf); err == nil { + data = buf.Bytes() + } + return +} + +func (a *Address) decode(vs []interface{}) (err error) { + defer func() { + if e := recover(); e != nil { + err = e.(error) + } + }() + + host := vs[0].(string) + if a.IP = net.ParseIP(host); len(a.IP) == 0 { + return ErrInvalidAddr + } else if ip := a.IP.To4(); len(ip) > 0 { + a.IP = ip + } + + a.Port = uint16(vs[1].(int64)) + return +} + +// UnmarshalBencode implements the interface bencode.Unmarshaler. +func (a *Address) UnmarshalBencode(b []byte) (err error) { + var iface interface{} + if err = bencode.NewDecoder(bytes.NewBuffer(b)).Decode(&iface); err != nil { + return + } + + switch v := iface.(type) { + case string: + err = a.FromString(v) + case []interface{}: + err = a.decode(v) + default: + err = fmt.Errorf("unsupported type: %T", iface) + } + + return +} + +// MarshalBencode implements the interface bencode.Marshaler. +func (a Address) MarshalBencode() (b []byte, err error) { + buf := bytes.NewBuffer(nil) + buf.Grow(32) + err = bencode.NewEncoder(buf).Encode([]interface{}{a.IP.String(), a.Port}) + if err == nil { + b = buf.Bytes() + } + return +} + +/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + +// HostAddress is the same as the Address, but the host part may be +// either a domain or a ip. +type HostAddress struct { + Host string + Port uint16 +} + +// NewHostAddress returns a new host addrress. +func NewHostAddress(host string, port uint16) HostAddress { + return HostAddress{Host: host, Port: port} +} + +// NewHostAddressFromString returns a new host address by the string. +func NewHostAddressFromString(s string) (addr HostAddress, err error) { + err = addr.FromString(s) + return +} + +// FromString parses and sets the host from the string addr. +func (a *HostAddress) FromString(addr string) (err error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return fmt.Errorf("invalid address '%s': %s", addr, err) + } else if host == "" { + return fmt.Errorf("invalid address '%s': missing host", addr) + } + + if port != "" { + v, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return fmt.Errorf("invalid address '%s': %s", addr, err) + } + a.Port = uint16(v) + } + + a.Host = host + return +} + +func (a HostAddress) String() string { + if a.Port == 0 { + return a.Host + } + return net.JoinHostPort(a.Host, strconv.FormatUint(uint64(a.Port), 10)) +} + +// Addresses parses the host address to a list of Addresses. +func (a HostAddress) Addresses() (addrs []Address, err error) { + if ip := net.ParseIP(a.Host); len(ip) > 0 { + return []Address{NewAddress(ip, a.Port)}, nil + } + + ips, err := net.LookupIP(a.Host) + if err != nil { + err = fmt.Errorf("fail to lookup the domain '%s': %s", a.Host, err) + } else { + addrs = make([]Address, len(ips)) + for i, ip := range ips { + addrs[i] = NewAddress(ip, a.Port) + } + } + + return +} + +// Equal reports whether a is equal to o. +func (a HostAddress) Equal(o HostAddress) bool { + return a.Port == o.Port && a.Host == o.Host +} + +func (a *HostAddress) decode(vs []interface{}) (err error) { + defer func() { + if e := recover(); e != nil { + err = e.(error) + } + }() + + a.Host = vs[0].(string) + a.Port = uint16(vs[1].(int64)) + return +} + +// UnmarshalBencode implements the interface bencode.Unmarshaler. +func (a *HostAddress) UnmarshalBencode(b []byte) (err error) { + var iface interface{} + if err = bencode.NewDecoder(bytes.NewBuffer(b)).Decode(&iface); err != nil { + return + } + + switch v := iface.(type) { + case string: + err = a.FromString(v) + case []interface{}: + err = a.decode(v) + default: + err = fmt.Errorf("unsupported type: %T", iface) + } + + return +} + +// MarshalBencode implements the interface bencode.Marshaler. +func (a HostAddress) MarshalBencode() (b []byte, err error) { + buf := bytes.NewBuffer(nil) + buf.Grow(32) + err = bencode.NewEncoder(buf).Encode([]interface{}{a.Host, a.Port}) + if err == nil { + b = buf.Bytes() + } + return +} diff --git a/metainfo/address_test.go b/metainfo/address_test.go new file mode 100644 index 0000000..929794e --- /dev/null +++ b/metainfo/address_test.go @@ -0,0 +1,48 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metainfo + +import ( + "testing" +) + +func TestAddress(t *testing.T) { + var addr1 Address + if err := addr1.FromString("1.2.3.4:1234"); err != nil { + t.Error(err) + return + } + + data, err := addr1.MarshalBencode() + if err != nil { + t.Error(err) + return + } else if s := string(data); s != `l7:1.2.3.4i1234ee` { + t.Errorf(`expected 'l7:1.2.3.4i1234ee', but got '%s'`, s) + } + + var addr2 Address + if err = addr2.UnmarshalBencode(data); err != nil { + t.Error(err) + } else if addr2.String() != `1.2.3.4:1234` { + t.Errorf("expected '1.2.3.4:1234', but got '%s'", addr2) + } + + if data, err = addr2.MarshalBinary(); err != nil { + t.Error(err) + } else if s := string(data); s != "\x01\x02\x03\x04\x04\xd2" { + t.Errorf(`expected '\x01\x02\x03\x04\x04\xd2', but got '%#x'`, s) + } +} diff --git a/metainfo/doc.go b/metainfo/doc.go new file mode 100644 index 0000000..2fe213b --- /dev/null +++ b/metainfo/doc.go @@ -0,0 +1,17 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package metainfo is used to encode or decode the metainfo of the torrent file +// and the MagNet link. +package metainfo diff --git a/metainfo/file.go b/metainfo/file.go new file mode 100644 index 0000000..54080eb --- /dev/null +++ b/metainfo/file.go @@ -0,0 +1,65 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metainfo + +import ( + "path/filepath" +) + +// File represents a file in the multi-file case. +type File struct { + // Length is the length of the file in bytes. + Length int64 `json:"length" bencode:"length"` // BEP 3 + + // Paths is a list containing one or more string elements that together + // represent the path and filename. Each element in the list corresponds + // to either a directory name or (in the case of the final element) the + // filename. + // + // For example, a the file "dir1/dir2/file.ext" would consist of three + // string elements: "dir1", "dir2", and "file.ext". This is encoded as + // a bencoded list of strings such as l4:dir14:dir28:file.exte. + Paths []string `json:"path" bencode:"path"` // BEP 3 +} + +func (f File) String() string { + return filepath.Join(f.Paths...) +} + +// Path returns the path of the current. +func (f File) Path(info Info) string { + if info.IsDir() { + return f.String() + } + return info.Name +} + +// Offset returns the offset of the current file from the start. +func (f File) Offset(info Info) (ret int64) { + path := f.Path(info) + for _, file := range info.AllFiles() { + if path == file.Path(info) { + return + } + ret += file.Length + } + panic("not found") +} + +type files []File + +func (fs files) Len() int { return len(fs) } +func (fs files) Less(i, j int) bool { return fs[i].String() < fs[j].String() } +func (fs files) Swap(i, j int) { f := fs[i]; fs[i] = fs[j]; fs[j] = f } diff --git a/metainfo/info.go b/metainfo/info.go new file mode 100644 index 0000000..13bca7d --- /dev/null +++ b/metainfo/info.go @@ -0,0 +1,138 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metainfo + +import ( + "fmt" + "io" + "os" + "path/filepath" + "sort" + "strings" +) + +// Info is the file inforatino. +type Info struct { + // Name is the name of the file in the single file case. + // Or, it is the name of the directory in the muliple file case. + Name string `json:"name" bencode:"name"` // BEP 3 + + // PieceLength is the number of bytes in each piece, which is usually + // a power of 2. + PieceLength int64 `json:"piece length" bencode:"piece length"` // BEP 3 + + // Pieces is the concatenation of all 20-byte SHA1 hash values, + // one per piece (byte string, i.e. not urlencoded). + Pieces Hashes `json:"pieces" bencode:"pieces"` // BEP 3 + + // Length is the length of the file in bytes in the single file case. + // + // It's mutually exclusive with Files. + Length int64 `json:"length,omitempty" bencode:"length,omitempty"` // BEP 3 + + // Files is the list of all the files in the multi-file case. + // + // For the purposes of the other keys, the multi-file case is treated + // as only having a single file by concatenating the files in the order + // they appear in the files list. + // + // It's mutually exclusive with Length. + Files []File `json:"files,omitempty" bencode:"files,omitempty"` // BEP 3 +} + +// NewInfoFromFilePath returns a new Info from a file or directory. +func NewInfoFromFilePath(root string, pieceLength int64) (info Info, err error) { + info.Name = filepath.Base(root) + info.PieceLength = pieceLength + err = filepath.Walk(root, func(path string, fi os.FileInfo, err error) error { + if err != nil { + return err + } + + if path == root && !fi.IsDir() { // The root is a file. + info.Length = fi.Size() + return nil + } + + relPath, err := filepath.Rel(root, path) + if err != nil { + return fmt.Errorf("error getting relative path: %s", err) + } + + info.Files = append(info.Files, File{ + Paths: strings.Split(relPath, string(filepath.Separator)), + Length: fi.Size(), + }) + + return nil + }) + + if err == nil { + sort.Sort(files(info.Files)) + info.Pieces, err = GeneratePiecesFromFiles(info.AllFiles(), info.PieceLength, + func(file File) (io.ReadCloser, error) { + if _len := len(file.Paths); _len > 0 { + paths := make([]string, 0, _len+1) + paths = append(paths, root) + paths = append(paths, file.Paths...) + return os.Open(filepath.Join(paths...)) + } + return os.Open(root) + }) + + if err != nil { + err = fmt.Errorf("error generating pieces: %s", err) + } + } + + return +} + +// IsDir reports whether the name is a directory, that's, the file is not +// a single file. +func (info Info) IsDir() bool { return len(info.Files) != 0 } + +// CountPieces returns the number of the pieces. +func (info Info) CountPieces() int { return len(info.Pieces) } + +// TotalLength returns the total length of the torrent file. +func (info Info) TotalLength() (ret int64) { + if info.IsDir() { + for _, fi := range info.Files { + ret += fi.Length + } + } else { + ret = info.Length + } + return +} + +// Piece returns the Piece by the index starting with 0. +func (info Info) Piece(index int) Piece { + if n := len(info.Pieces); index >= n { + panic(fmt.Errorf("Info.Piece: index '%d' exceeds maximum '%d'", index, n)) + } + return Piece{info: info, index: index} +} + +// AllFiles returns all the files. +// +// Notice: for the single file, the Path is nil. +func (info Info) AllFiles() []File { + if info.IsDir() { + return info.Files + } + return []File{{Length: info.Length}} +} diff --git a/metainfo/infohash.go b/metainfo/infohash.go new file mode 100644 index 0000000..4d1ecff --- /dev/null +++ b/metainfo/infohash.go @@ -0,0 +1,207 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metainfo + +import ( + "bytes" + "crypto/rand" + "crypto/sha1" + "encoding/base32" + "encoding/hex" + "fmt" + + "github.com/xgfone/bt/bencode" +) + +var zeroHash Hash + +// HashSize is the size of the InfoHash. +const HashSize = 20 + +// Hash is the 20-byte SHA1 hash used for info and pieces. +type Hash [HashSize]byte + +// NewRandomHash returns a random hash. +func NewRandomHash() (h Hash) { + rand.Read(h[:]) + return +} + +// NewHash converts the 20-bytes to Hash. +func NewHash(b []byte) (h Hash) { + copy(h[:], b[:HashSize]) + return +} + +// NewHashFromString returns a new Hash from a string. +func NewHashFromString(s string) (h Hash) { + err := h.FromString(s) + if err != nil { + panic(err) + } + return +} + +// NewHashFromHexString returns a new Hash from a hex string. +func NewHashFromHexString(s string) (h Hash) { + err := h.FromHexString(s) + if err != nil { + panic(err) + } + return +} + +// NewHashFromBytes returns a new Hash from a byte slice. +func NewHashFromBytes(b []byte) (ret Hash) { + hasher := sha1.New() + hasher.Write(b) + copy(ret[:], hasher.Sum(nil)) + return +} + +// Bytes returns the byte slice type. +func (h Hash) Bytes() []byte { + return h[:] +} + +// String is equal to HexString. +func (h Hash) String() string { + return h.HexString() +} + +// BytesString returns the bytes string, that's, string(h[:]). +func (h Hash) BytesString() string { + return string(h[:]) +} + +// HexString returns the hex string format. +func (h Hash) HexString() string { + return hex.EncodeToString(h[:]) +} + +// IsZero reports whether the whole hash is zero. +func (h Hash) IsZero() bool { + return h == zeroHash +} + +// MarshalBencode implements the interface bencode.Marshaler. +func (h Hash) MarshalBencode() (b []byte, err error) { + return bencode.EncodeBytes(h[:]) +} + +// UnmarshalBencode implements the interface bencode.Unmarshaler. +func (h *Hash) UnmarshalBencode(b []byte) (err error) { + var s string + if err = bencode.NewDecoder(bytes.NewBuffer(b)).Decode(&s); err == nil { + err = h.FromString(s) + } + return +} + +// FromString resets the info hash from the string. +func (h *Hash) FromString(s string) (err error) { + switch len(s) { + case HashSize: + copy(h[:], s) + case 2 * HashSize: + err = h.FromHexString(s) + case 32: + var bs []byte + if bs, err = base32.StdEncoding.DecodeString(s); err == nil { + copy(h[:], bs) + } + default: + err = fmt.Errorf("hash string has bad length: %d", len(s)) + } + + return nil +} + +// FromHexString resets the info hash from the hex string. +func (h *Hash) FromHexString(s string) (err error) { + if len(s) != 2*HashSize { + err = fmt.Errorf("hash hex string has bad length: %d", len(s)) + return + } + + n, err := hex.Decode(h[:], []byte(s)) + if err != nil { + return + } + + if n != HashSize { + panic(n) + } + return +} + +// Xor returns the hash of h XOR o. +func (h Hash) Xor(o Hash) (ret Hash) { + for i := range o { + ret[i] = h[i] ^ o[i] + } + return +} + +// Compare returns 0 if h == o, -1 if h < o, or +1 if h > o. +func (h Hash) Compare(o Hash) int { return bytes.Compare(h[:], o[:]) } + +/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + +// Hashes is a set of Hashes. +type Hashes []Hash + +// Contains reports whether hs contains h. +func (hs Hashes) Contains(h Hash) bool { + for _, _h := range hs { + if h == _h { + return true + } + } + return false +} + +// MarshalBencode implements the interface bencode.Marshaler. +func (hs Hashes) MarshalBencode() ([]byte, error) { + buf := bytes.NewBuffer(nil) + buf.Grow(HashSize * len(hs)) + for _, h := range hs { + buf.Write(h[:]) + } + return bencode.EncodeBytes(buf.Bytes()) +} + +// UnmarshalBencode implements the interface bencode.Unmarshaler. +func (hs *Hashes) UnmarshalBencode(b []byte) (err error) { + var bs []byte + if err = bencode.DecodeBytes(b, &bs); err != nil { + return + } + + _len := len(bs) + if _len%HashSize != 0 { + return fmt.Errorf("Hashes: invalid bytes length '%d'", _len) + } + + hashes := make(Hashes, 0, _len/HashSize) + for i := 0; i < _len; i += HashSize { + var h Hash + copy(h[:], bs[i:i+HashSize]) + hashes = append(hashes, h) + } + + *hs = hashes + return +} diff --git a/metainfo/magnet.go b/metainfo/magnet.go new file mode 100644 index 0000000..a3e0128 --- /dev/null +++ b/metainfo/magnet.go @@ -0,0 +1,138 @@ +// Mozilla Public License Version 2.0 +// Modify from github.com/anacrolix/torrent/metainfo. + +package metainfo + +import ( + "encoding/base32" + "encoding/hex" + "errors" + "fmt" + "net/url" + "strings" +) + +// Magnet link components. +type Magnet struct { + InfoHash Hash // From "xt" + Trackers []string // From "tr" + DisplayName string // From "dn" if not empty + Params url.Values // All other values, such as "as", "xs", etc +} + +const xtPrefix = "urn:btih:" + +// Peers returns the list of the addresses of the peers. +// +// See BEP 9 +func (m Magnet) Peers() (peers []HostAddress, err error) { + vs := m.Params["x.pe"] + peers = make([]HostAddress, 0, len(vs)) + for _, v := range vs { + if v != "" { + var addr HostAddress + if err = addr.FromString(v); err != nil { + return + } + + peers = append(peers, addr) + } + } + return +} + +func (m Magnet) String() string { + vs := make(url.Values, len(m.Params)+len(m.Trackers)+2) + for k, v := range m.Params { + vs[k] = append([]string(nil), v...) + } + + for _, tr := range m.Trackers { + vs.Add("tr", tr) + } + if m.DisplayName != "" { + vs.Add("dn", m.DisplayName) + } + + // Transmission and Deluge both expect "urn:btih:" to be unescaped. + // Deluge wants it to be at the start of the magnet link. + // The InfoHash field is expected to be BitTorrent in this implementation. + u := url.URL{ + Scheme: "magnet", + RawQuery: "xt=" + xtPrefix + m.InfoHash.HexString(), + } + if len(vs) != 0 { + u.RawQuery += "&" + vs.Encode() + } + return u.String() +} + +// ParseMagnetURI parses Magnet-formatted URIs into a Magnet instance. +func ParseMagnetURI(uri string) (m Magnet, err error) { + u, err := url.Parse(uri) + if err != nil { + err = fmt.Errorf("error parsing uri: %w", err) + return + } else if u.Scheme != "magnet" { + err = fmt.Errorf("unexpected scheme %q", u.Scheme) + return + } + + q := u.Query() + xt := q.Get("xt") + if m.InfoHash, err = parseInfohash(q.Get("xt")); err != nil { + err = fmt.Errorf("error parsing infohash %q: %w", xt, err) + return + } + dropFirst(q, "xt") + + m.DisplayName = q.Get("dn") + dropFirst(q, "dn") + + m.Trackers = q["tr"] + delete(q, "tr") + + if len(q) == 0 { + q = nil + } + + m.Params = q + return +} + +func parseInfohash(xt string) (ih Hash, err error) { + if !strings.HasPrefix(xt, xtPrefix) { + err = errors.New("bad xt parameter prefix") + return + } + + var n int + encoded := xt[len(xtPrefix):] + switch len(encoded) { + case 40: + n, err = hex.Decode(ih[:], []byte(encoded)) + case 32: + n, err = base32.StdEncoding.Decode(ih[:], []byte(encoded)) + default: + err = fmt.Errorf("unhandled xt parameter encoding (encoded length %d)", len(encoded)) + return + } + + if err != nil { + err = fmt.Errorf("error decoding xt: %w", err) + } else if n != 20 { + panic(fmt.Errorf("invalid length '%d' of the decoded bytes", n)) + } + + return +} + +func dropFirst(vs url.Values, key string) { + sl := vs[key] + switch len(sl) { + case 0, 1: + vs.Del(key) + default: + vs[key] = sl[1:] + } +} diff --git a/metainfo/metainfo.go b/metainfo/metainfo.go new file mode 100644 index 0000000..b8535fb --- /dev/null +++ b/metainfo/metainfo.go @@ -0,0 +1,170 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metainfo + +import ( + "errors" + "io" + "os" + "strings" + + "github.com/xgfone/bt/bencode" + "github.com/xgfone/bt/utils" +) + +// Bytes is the []byte type. +type Bytes = bencode.RawMessage + +// AnnounceList is a list of the announces. +type AnnounceList [][]string + +// Unique returns the list of the unique announces. +func (al AnnounceList) Unique() (announces []string) { + announces = make([]string, 0, len(al)) + for _, tier := range al { + for _, v := range tier { + if v != "" && utils.InStringSlice(announces, v) { + announces = append(announces, v) + } + } + } + + return nil +} + +// URLList represents a list of the url. +// +// BEP 19 +type URLList []string + +// FullURL returns the index-th full url. +// +// For the single-file case, name is the "name" of "info". +// For the multi-file case, name is the path "name/path/file" +// from "info" and "files". +// +// See http://bittorrent.org/beps/bep_0019.html +func (us URLList) FullURL(index int, name string) (url string) { + if url = us[index]; strings.HasSuffix(url, "/") { + url += name + } + return +} + +// MarshalBencode implements the interface bencode.Marshaler. +func (us URLList) MarshalBencode() (b []byte, err error) { + return bencode.EncodeBytes([]string(us)) +} + +// UnmarshalBencode implements the interface bencode.Unmarshaler. +func (us *URLList) UnmarshalBencode(b []byte) (err error) { + var v interface{} + if err = bencode.DecodeBytes(b, &v); err == nil { + switch vs := v.(type) { + case string: + *us = URLList{vs} + case []interface{}: + urls := make(URLList, len(vs)) + for i, u := range vs { + s, ok := u.(string) + if !ok { + return errors.New("the element of 'url-list' is not string") + } + urls[i] = s + } + *us = urls + default: + return errors.New("invalid 'url-lsit'") + } + } + return +} + +// MetaInfo represents the .torrent file. +type MetaInfo struct { + InfoBytes Bytes `bencode:"info"` // BEP 3 + Announce string `bencode:"announce,omitempty"` // BEP 3 + AnnounceList AnnounceList `bencode:"announce-list,omitempty"` // BEP 12 + Nodes []HostAddress `bencode:"nodes,omitempty"` // BEP 5 + URLList URLList `bencode:"url-list,omitempty"` // BEP 19 + + // Where's this specified? + // Mentioned at https://wiki.theory.org/index.php/BitTorrentSpecification. + // All of them are optional. + + // CreationDate is the creation time of the torrent, in standard UNIX epoch + // format (seconds since 1-Jan-1970 00:00:00 UTC). + CreationDate int64 `bencode:"creation date,omitempty"` + // Comment is the free-form textual comments of the author. + Comment string `bencode:"comment,omitempty"` + // CreatedBy is name and version of the program used to create the .torrent. + CreatedBy string `bencode:"created by,omitempty"` + // Encoding is the string encoding format used to generate the pieces part + // of the info dictionary in the .torrent metafile. + Encoding string `bencode:"encoding,omitempty"` +} + +// Load loads a MetaInfo from an io.Reader. +func Load(r io.Reader) (mi MetaInfo, err error) { + err = bencode.NewDecoder(r).Decode(&mi) + return +} + +// LoadFromFile loads a MetaInfo from a file. +func LoadFromFile(filename string) (mi MetaInfo, err error) { + f, err := os.Open(filename) + if err == nil { + defer f.Close() + mi, err = Load(f) + } + return +} + +// Announces returns all the announces. +func (mi MetaInfo) Announces() AnnounceList { + if len(mi.AnnounceList) > 0 { + return mi.AnnounceList + } else if mi.Announce != "" { + return [][]string{{mi.Announce}} + } + return nil +} + +// Magnet creates a Magnet from a MetaInfo. +func (mi *MetaInfo) Magnet(displayName string, infoHash Hash) (m Magnet) { + for _, t := range mi.Announces().Unique() { + m.Trackers = append(m.Trackers, t) + } + + m.DisplayName = displayName + m.InfoHash = infoHash + return +} + +// Write encodes the metainfo to w. +func (mi MetaInfo) Write(w io.Writer) error { + return bencode.NewEncoder(w).Encode(mi) +} + +// InfoHash returns the hash of the info. +func (mi MetaInfo) InfoHash() Hash { + return NewHashFromBytes(mi.InfoBytes) +} + +// Info parses the InfoBytes to the Info. +func (mi MetaInfo) Info() (info Info, err error) { + err = bencode.DecodeBytes(mi.InfoBytes, &info) + return +} diff --git a/metainfo/piece.go b/metainfo/piece.go new file mode 100644 index 0000000..898ced1 --- /dev/null +++ b/metainfo/piece.go @@ -0,0 +1,99 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metainfo + +import ( + "crypto/sha1" + "errors" + "fmt" + "io" + + "github.com/xgfone/bt/utils" +) + +// Piece represents a torrent file piece. +type Piece struct { + info Info + index int +} + +// Index returns the index of the current piece. +func (p Piece) Index() int { return p.index } + +// Offset returns the offset that the current piece is in all the files. +func (p Piece) Offset() int64 { return int64(p.index) * p.info.PieceLength } + +// Hash returns the hash representation of the piece. +func (p Piece) Hash() (h Hash) { return p.info.Pieces[p.index] } + +// Length returns the length of the current piece. +func (p Piece) Length() int64 { + if p.index == p.info.CountPieces()-1 { + return p.info.TotalLength() - int64(p.index)*p.info.PieceLength + } + return p.info.PieceLength +} + +// GeneratePieces generates the pieces from the reader. +func GeneratePieces(r io.Reader, pieceLength int64) (hs Hashes, err error) { + buf := make([]byte, pieceLength) + for { + h := sha1.New() + written, err := utils.CopyNBuffer(h, r, pieceLength, buf) + if written > 0 { + hs = append(hs, NewHash(h.Sum(nil))) + } + + if err == io.EOF { + return hs, nil + } + + if err != nil { + return nil, err + } + } +} + +func writeFiles(w io.Writer, files []File, open func(File) (io.ReadCloser, error)) error { + buf := make([]byte, 8192) + for _, file := range files { + r, err := open(file) + if err != nil { + return fmt.Errorf("error opening %s: %s", file, err) + } + + n, err := utils.CopyNBuffer(w, r, file.Length, buf) + r.Close() + + if n != file.Length { + return fmt.Errorf("error copying %s: %s", file, err) + } + } + return nil +} + +// GeneratePiecesFromFiles generates the pieces from the files. +func GeneratePiecesFromFiles(files []File, pieceLength int64, + open func(File) (io.ReadCloser, error)) (Hashes, error) { + if pieceLength <= 0 { + return nil, errors.New("piece length must be a positive integer") + } + + pr, pw := io.Pipe() + defer pr.Close() + + go func() { pw.CloseWithError(writeFiles(pw, files, open)) }() + return GeneratePieces(pr, pieceLength) +} diff --git a/peerprotocol/doc.go b/peerprotocol/doc.go new file mode 100644 index 0000000..2d90e98 --- /dev/null +++ b/peerprotocol/doc.go @@ -0,0 +1,19 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package peerprotocol implements the core BT peer protocol. You can use it +// to implement a peer client, and upload/download the file to/from other peers. +// +// See TorrentDownloader to download the torrent file from other peer. +package peerprotocol diff --git a/peerprotocol/extension.go b/peerprotocol/extension.go new file mode 100644 index 0000000..7ab53ff --- /dev/null +++ b/peerprotocol/extension.go @@ -0,0 +1,161 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package peerprotocol + +import ( + "bytes" + "errors" + "net" + + "github.com/xgfone/bt/bencode" +) + +var errInvalidIP = errors.New("invalid ipv4 or ipv6") + +// Predefine some extended message identifiers. +const ( + ExtendedIDHandshake = 0 // BEP 10 +) + +// Predefine some extended message names. +const ( + ExtendedMessageNameMetadata = "ut_metadata" // BEP 9 + ExtendedMessageNamePex = "ut_pex" // BEP 11 +) + +// Predefine some "ut_metadata" extended message types. +const ( + UtMetadataExtendedMsgTypeRequest = 0 // BEP 9 + UtMetadataExtendedMsgTypeData = 1 // BEP 9 + UtMetadataExtendedMsgTypeReject = 2 // BEP 9 +) + +// CompactIP is used to handle the compact ipv4 or ipv6. +type CompactIP net.IP + +func (ci CompactIP) String() string { + return net.IP(ci).String() +} + +// MarshalBencode implements the interface bencode.Marshaler. +func (ci CompactIP) MarshalBencode() ([]byte, error) { + ip := net.IP(ci) + if ipv4 := ip.To4(); len(ipv4) != 0 { + ip = ipv4 + } + return bencode.EncodeBytes(ip[:]) +} + +// UnmarshalBencode implements the interface bencode.Unmarshaler. +func (ci *CompactIP) UnmarshalBencode(b []byte) (err error) { + var ip net.IP + if err = bencode.DecodeBytes(b, &ip); err != nil { + return + } + + switch len(ip) { + case net.IPv4len, net.IPv6len: + default: + return errInvalidIP + } + + if ipv4 := ip.To4(); len(ipv4) != 0 { + ip = ipv4 + } + + *ci = CompactIP(ip) + return +} + +// ExtendedHandshakeMsg represent the extended handshake message. +// +// BEP 10 +type ExtendedHandshakeMsg struct { + // M is the type of map[ExtendedMessageName]ExtendedMessageID. + M map[string]uint8 `bencode:"m"` // BEP 10 + V string `bencode:"v,omitempty"` // BEP 10 + Reqq int `bencode:"reqq,omitempty"` // BEP 10 + + // Port is the local client port, which is redundant and no need + // for the receiving side of the connection to send this. + Port uint16 `bencode:"p,omitempty"` // BEP 10 + IPv6 net.IP `bencode:"ipv6,omitempty"` // BEP 10 + IPv4 CompactIP `bencode:"ipv4,omitempty"` // BEP 10 + YourIP CompactIP `bencode:"yourip,omitempty"` // BEP 10 + + MetadataSize int `bencode:"metadata_size,omitempty"` // BEP 9 +} + +// Decode decodes the extended handshake message from b. +func (ehm *ExtendedHandshakeMsg) Decode(b []byte) (err error) { + return bencode.DecodeBytes(b, ehm) +} + +// Encode encodes the extended handshake message to b. +func (ehm ExtendedHandshakeMsg) Encode() (b []byte, err error) { + buf := bytes.NewBuffer(make([]byte, 0, 128)) + if err = bencode.NewEncoder(buf).Encode(ehm); err != nil { + b = buf.Bytes() + } + return +} + +// UtMetadataExtendedMsg represents the "ut_metadata" extended message. +type UtMetadataExtendedMsg struct { + MsgType uint8 `bencode:"msg_type"` // BEP 9 + Piece int `bencode:"piece"` // BEP 9 + + // They are only used by "data" type + TotalSize int `bencode:"total_size,omitempty"` // BEP 9 + Data []byte `bencode:"-"` +} + +// EncodeToPayload encodes UtMetadataExtendedMsg to extended payload +// and write the result into buf. +func (um UtMetadataExtendedMsg) EncodeToPayload(buf *bytes.Buffer) (err error) { + if um.MsgType != UtMetadataExtendedMsgTypeData { + um.TotalSize = 0 + um.Data = nil + } + + buf.Grow(len(um.Data) + 50) + if err = bencode.NewEncoder(buf).Encode(um); err == nil { + _, err = buf.Write(um.Data) + } + return +} + +// EncodeToBytes is equal to +// +// buf := new(bytes.Buffer) +// err = um.EncodeToPayload(buf) +// return buf.Bytes(), err +// +func (um UtMetadataExtendedMsg) EncodeToBytes() (b []byte, err error) { + buf := bytes.NewBuffer(make([]byte, 0, 128)) + if err = um.EncodeToPayload(buf); err == nil { + b = buf.Bytes() + } + return +} + +// DecodeFromPayload decodes the extended payload to itself. +func (um *UtMetadataExtendedMsg) DecodeFromPayload(b []byte) (err error) { + dec := bencode.NewDecoder(bytes.NewReader(b)) + if err = dec.Decode(&um); err == nil { + um.Data = b[dec.BytesParsed():] + } + return +} diff --git a/peerprotocol/extension_test.go b/peerprotocol/extension_test.go new file mode 100644 index 0000000..c964993 --- /dev/null +++ b/peerprotocol/extension_test.go @@ -0,0 +1,54 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package peerprotocol + +import ( + "bytes" + "testing" +) + +func TestCompactIP(t *testing.T) { + ipv4 := CompactIP([]byte{1, 2, 3, 4}) + b, err := ipv4.MarshalBencode() + if err != nil { + t.Fatal(err) + } + + var ip CompactIP + if err = ip.UnmarshalBencode(b); err != nil { + t.Error(err) + } else if ip.String() != "1.2.3.4" { + t.Error(ip) + } +} + +func TestUtMetadataExtendedMsg(t *testing.T) { + buf := new(bytes.Buffer) + data := []byte{0x31, 0x32, 0x33, 0x34, 0x35} + m1 := UtMetadataExtendedMsg{MsgType: 1, Piece: 2, TotalSize: 1024, Data: data} + if err := m1.EncodeToPayload(buf); err != nil { + t.Fatal(err) + } + + msg := Message{Type: Extended, ExtendedPayload: buf.Bytes()} + m2, err := msg.UtMetadataExtendedMsg() + if err != nil { + t.Fatal(err) + } else if m2.MsgType != 1 || m2.Piece != 2 || m2.TotalSize != 1024 { + t.Error(m2) + } else if !bytes.Equal(m2.Data, data) { + t.Fail() + } +} diff --git a/peerprotocol/fastset.go b/peerprotocol/fastset.go new file mode 100644 index 0000000..05f7b28 --- /dev/null +++ b/peerprotocol/fastset.go @@ -0,0 +1,69 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package peerprotocol + +import ( + "crypto/sha1" + "encoding/binary" + "net" + + "github.com/xgfone/bt/metainfo" +) + +// GenerateAllowedFastSet generates some allowed fast set of the torrent file. +// +// Argument: +// set: generated piece set, the length of which is the number to be generated. +// sz: the number of pieces in torrent. +// ip: the of the remote peer of the connection. +// infohash: infohash of torrent. +// +// BEP 6 +func GenerateAllowedFastSet(set []uint32, sz uint32, ip net.IP, infohash metainfo.Hash) { + if ipv4 := ip.To4(); ipv4 != nil { + ip = ipv4 + } + + iplen := len(ip) + x := make([]byte, 20+iplen) + for i, j := 0, iplen-1; i < j; i++ { // (1) compatible with IPv4/IPv6 + x[i] = ip[i] & 0xff // (1) + } + // x[iplen-1] = 0 // It is equal to 0 primitively. + copy(x[iplen:], infohash[:]) // (2) + + for cur, k := 0, len(set); cur < k; { + sum := sha1.Sum(x) // (3) + x = sum[:] // (3) + for i := 0; i < 5 && cur < k; i++ { // (4) + j := i * 4 // (5) + y := binary.BigEndian.Uint32(x[j : j+4]) // (6) + index := y % sz // (7) + if !uint32Contains(set, index) { // (8) + set[cur] = index // (9) + cur++ + } + } + } +} + +func uint32Contains(ss []uint32, s uint32) bool { + for _, v := range ss { + if v == s { + return true + } + } + return false +} diff --git a/peerprotocol/fastset_test.go b/peerprotocol/fastset_test.go new file mode 100644 index 0000000..e472fd6 --- /dev/null +++ b/peerprotocol/fastset_test.go @@ -0,0 +1,109 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package peerprotocol + +import ( + "net" + "testing" + + "github.com/xgfone/bt/metainfo" +) + +func TestGenerateAllowedFastSet(t *testing.T) { + hexs := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + infohash := metainfo.NewHashFromHexString(hexs) + + sets := make([]uint32, 7) + GenerateAllowedFastSet(sets, 1313, net.ParseIP("80.4.4.200"), infohash) + for i, v := range sets { + switch i { + case 0: + if v != 1059 { + t.Errorf("unknown '%d' at the index '%d'", v, i) + } + case 1: + if v != 431 { + t.Errorf("unknown '%d' at the index '%d'", v, i) + } + case 2: + if v != 808 { + t.Errorf("unknown '%d' at the index '%d'", v, i) + } + case 3: + if v != 1217 { + t.Errorf("unknown '%d' at the index '%d'", v, i) + } + case 4: + if v != 287 { + t.Errorf("unknown '%d' at the index '%d'", v, i) + } + case 5: + if v != 376 { + t.Errorf("unknown '%d' at the index '%d'", v, i) + } + case 6: + if v != 1188 { + t.Errorf("unknown '%d' at the index '%d'", v, i) + } + default: + t.Errorf("unknown '%d' at the index '%d'", v, i) + } + } + + sets = make([]uint32, 9) + GenerateAllowedFastSet(sets, 1313, net.ParseIP("80.4.4.200"), infohash) + for i, v := range sets { + switch i { + case 0: + if v != 1059 { + t.Errorf("unknown '%d' at the index '%d'", v, i) + } + case 1: + if v != 431 { + t.Errorf("unknown '%d' at the index '%d'", v, i) + } + case 2: + if v != 808 { + t.Errorf("unknown '%d' at the index '%d'", v, i) + } + case 3: + if v != 1217 { + t.Errorf("unknown '%d' at the index '%d'", v, i) + } + case 4: + if v != 287 { + t.Errorf("unknown '%d' at the index '%d'", v, i) + } + case 5: + if v != 376 { + t.Errorf("unknown '%d' at the index '%d'", v, i) + } + case 6: + if v != 1188 { + t.Errorf("unknown '%d' at the index '%d'", v, i) + } + case 7: + if v != 353 { + t.Errorf("unknown '%d' at the index '%d'", v, i) + } + case 8: + if v != 508 { + t.Errorf("unknown '%d' at the index '%d'", v, i) + } + default: + t.Errorf("unknown '%d' at the index '%d'", v, i) + } + } +} diff --git a/peerprotocol/handshake.go b/peerprotocol/handshake.go new file mode 100644 index 0000000..63f4e2c --- /dev/null +++ b/peerprotocol/handshake.go @@ -0,0 +1,140 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package peerprotocol + +import ( + "encoding/hex" + "fmt" + "io" + + "github.com/xgfone/bt/metainfo" +) + +var errInvalidProtocolHeader = fmt.Errorf("unexpected peer protocol header string") + +// Predefine some known extension bits. +const ( + ExtensionBitDHT = 0 // BEP 5 + ExtensionBitFast = 2 // BEP 6 + ExtensionBitExtended = 20 // BEP 10 + + ExtensionBitMax = 64 +) + +// ExtensionBits is the reserved bytes to be used by all extensions. +// +// BEP 10: The bit is counted starting at 0 from right to left. +type ExtensionBits [8]byte + +// String returns the hex string format. +func (eb ExtensionBits) String() string { + return hex.EncodeToString(eb[:]) +} + +// Set sets the bit to 1, that's, to set it to be on. +func (eb ExtensionBits) Set(bit uint) { + eb[7-bit/8] |= 1 << (bit % 8) +} + +// Unset sets the bit to 0, that's, to set it to be off. +func (eb ExtensionBits) Unset(bit uint) { + eb[7-bit/8] &^= 1 << (bit % 8) +} + +// IsSet reports whether the bit is on. +func (eb ExtensionBits) IsSet(bit uint) (yes bool) { + return eb[7-bit/8]&(1<<(bit%8)) != 0 +} + +// IsSupportDHT reports whether ExtensionBitDHT is set. +func (eb ExtensionBits) IsSupportDHT() (yes bool) { + return eb.IsSet(ExtensionBitDHT) +} + +// IsSupportFast reports whether ExtensionBitFast is set. +func (eb ExtensionBits) IsSupportFast() (yes bool) { + return eb.IsSet(ExtensionBitFast) +} + +// IsSupportExtended reports whether ExtensionBitExtended is set. +func (eb ExtensionBits) IsSupportExtended() (yes bool) { + return eb.IsSet(ExtensionBitExtended) +} + +// HandshakeMsg is the message used by the handshake +type HandshakeMsg struct { + ExtensionBits + + PeerID metainfo.Hash + InfoHash metainfo.Hash +} + +// NewHandshakeMsg returns a new HandshakeMsg. +func NewHandshakeMsg(peerID, infoHash metainfo.Hash, es ...ExtensionBits) HandshakeMsg { + var e ExtensionBits + if len(es) > 0 { + e = es[0] + } + return HandshakeMsg{ExtensionBits: e, PeerID: peerID, InfoHash: infoHash} +} + +// Handshake finishes the handshake with the peer. +// +// InfoHash may be ZERO, and it will read it from the peer then send it back +// to the peer. +// +// BEP 3 +func Handshake(sock io.ReadWriter, msg HandshakeMsg) (ret HandshakeMsg, err error) { + var read bool + if msg.InfoHash.IsZero() { + if err = getPeerHandshakeMsg(sock, &ret); err != nil { + return + } + read = true + msg.InfoHash = ret.InfoHash + } + + if _, err = io.WriteString(sock, ProtocolHeader); err != nil { + return + } + if _, err = sock.Write(msg.ExtensionBits[:]); err != nil { + return + } + if _, err = sock.Write(msg.InfoHash[:]); err != nil { + return + } + if _, err = sock.Write(msg.PeerID[:]); err != nil { + return + } + + if !read { + err = getPeerHandshakeMsg(sock, &ret) + } + return +} + +func getPeerHandshakeMsg(sock io.Reader, ret *HandshakeMsg) (err error) { + var b [68]byte + if _, err = io.ReadFull(sock, b[:]); err != nil { + return + } else if string(b[:20]) != ProtocolHeader { + return errInvalidProtocolHeader + } + + copy(ret.ExtensionBits[:], b[20:28]) + copy(ret.InfoHash[:], b[28:48]) + copy(ret.PeerID[:], b[48:68]) + return +} diff --git a/peerprotocol/message.go b/peerprotocol/message.go new file mode 100644 index 0000000..c573b95 --- /dev/null +++ b/peerprotocol/message.go @@ -0,0 +1,277 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package peerprotocol + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "io/ioutil" +) + +var errMessageTooLong = fmt.Errorf("the peer message is too long") + +// Message is the message used by the peer protocol, which contains +// all the fields specified by the standard message types. +type Message struct { + Keepalive bool + Type MessageType + + // Index is used by these message types: + // + // BEP 3: Cancel, Request, Have, Piece + // BEP 6: Reject, Suggest, AllowedFast + // + Index uint32 + + // Begin is used by these message types: + // + // BEP 3: Request, Cancel, Piece + // BEP 6: Reject + // + Begin uint32 + + // Length is used by these message types: + // + // BEP 3: Request, Cancel + // BEP 6: Reject + // + Length uint32 + + // Piece is used by these message types: + // + // BEP 3: Piece + Piece []byte + + // Bitfield is used by these message types: + // + // BEP 3: Bitfield + Bitfield []bool + + // ExtendedID and ExtendedPayload are used by these message types: + // + // BEP 10: Extended + // + ExtendedID uint8 + ExtendedPayload []byte + + // Port is used by these message types: + // + // BEP 5: Port + // + Port uint16 +} + +// DecodeToMessage is equal to msg.Decode(r, maxLength). +func DecodeToMessage(r io.Reader, maxLength uint32) (msg Message, err error) { + err = msg.Decode(r, maxLength) + return +} + +// UtMetadataExtendedMsg decodes the extended payload as UtMetadataExtendedMsg. +// +// Notice: the message type must be Extended. +func (m Message) UtMetadataExtendedMsg() (um UtMetadataExtendedMsg, err error) { + if m.Type != Extended { + panic("the message type is Extended") + } + err = um.DecodeFromPayload(m.ExtendedPayload) + return +} + +// UnmarshalBinary implements the interface encoding.BinaryUnmarshaler, +// which is equal to m.Decode(bytes.NewBuffer(data), 0). +func (m *Message) UnmarshalBinary(data []byte) (err error) { + return m.Decode(bytes.NewBuffer(data), 0) +} + +func readByte(r io.Reader) (b byte, err error) { + var bs [1]byte + if _, err = r.Read(bs[:]); err == nil { + b = bs[0] + } + return +} + +// Decode reads the data from r and decodes it to Message. +// +// if maxLength is equal to 0, it is unlimited. Or, it will read maxLength bytes +// at most. +func (m *Message) Decode(r io.Reader, maxLength uint32) (err error) { + var length uint32 + if err = binary.Read(r, binary.BigEndian, &length); err != nil { + if err != io.EOF { + err = fmt.Errorf("error reading peer message message length: %s", err) + } + return + } + + if length == 0 { + m.Keepalive = true + return + } else if maxLength > 0 && length > maxLength { + return errMessageTooLong + } + + m.Keepalive = false + lr := &io.LimitedReader{R: r, N: int64(length)} + + // Check that all of r was utilized. + defer func() { + if err == nil && lr.N != 0 { + err = fmt.Errorf("%d bytes unused in message type %d", lr.N, m.Type) + } + }() + + _type, err := readByte(lr) + if err != nil { + return + } + + switch m.Type = MessageType(_type); m.Type { + case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone: + case Have, AllowedFast, Suggest: + err = binary.Read(lr, binary.BigEndian, &m.Index) + case Request, Cancel, Reject: + if err = binary.Read(lr, binary.BigEndian, &m.Index); err != nil { + return + } + if err = binary.Read(lr, binary.BigEndian, &m.Begin); err != nil { + return + } + if err = binary.Read(lr, binary.BigEndian, &m.Length); err != nil { + return + } + case Bitfield: + _len := length - 1 + bs := make([]byte, _len) + if _, err = io.ReadFull(lr, bs); err == nil { + m.Bitfield = make([]bool, 0, _len*8) + for _, b := range bs { + for i := byte(7); i >= 0; i-- { + m.Bitfield = append(m.Bitfield, (b>>i)&1 == 1) + } + } + } + case Piece: + if err = binary.Read(lr, binary.BigEndian, &m.Index); err != nil { + return + } + if err = binary.Read(lr, binary.BigEndian, &m.Begin); err != nil { + return + } + + // TODO: Should we use a []byte pool? + m.Piece = make([]byte, lr.N) + if _, err = io.ReadFull(lr, m.Piece); err != nil { + return fmt.Errorf("reading piece data error: %s", err) + } + case Extended: + if m.ExtendedID, err = readByte(lr); err == nil { + m.ExtendedPayload, err = ioutil.ReadAll(lr) + } + case Port: + err = binary.Read(lr, binary.BigEndian, &m.Port) + default: + err = fmt.Errorf("unknown message type %v", m.Type) + } + + return +} + +// MarshalBinary implements the interface encoding.BinaryMarshaler. +func (m Message) MarshalBinary() (data []byte, err error) { + // TODO: Should we use a buffer pool? + buf := bytes.NewBuffer(make([]byte, 0, 4)) + if err = m.Encode(buf); err == nil { + data = buf.Bytes() + } + return +} + +// Encode encodes the message to buf. +func (m Message) Encode(buf *bytes.Buffer) (err error) { + // The 4-bytes is the placeholder of the length. + buf.Reset() + buf.Write([]byte{0, 0, 0, 0}) + + // Write the non-keepalive message. + if !m.Keepalive { + if err = buf.WriteByte(byte(m.Type)); err != nil { + return + } else if err = m.marshalBinaryType(buf); err != nil { + return + } + + // Calculate and reset the length of the message body. + data := buf.Bytes() + if payloadLen := len(data) - 4; payloadLen > 0 { + binary.BigEndian.PutUint32(data[:4], uint32(payloadLen)) + } + } + + return +} + +func (m Message) marshalBinaryType(buf *bytes.Buffer) (err error) { + switch m.Type { + case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone: + case Have: + err = binary.Write(buf, binary.BigEndian, m.Index) + case Request, Cancel, Reject: + if err = binary.Write(buf, binary.BigEndian, m.Index); err != nil { + return + } + if err = binary.Write(buf, binary.BigEndian, m.Begin); err != nil { + return + } + if err = binary.Write(buf, binary.BigEndian, m.Length); err != nil { + return + } + case Bitfield: + _len := (len(m.Bitfield) + 7) / 8 + bs := make([]byte, _len) + for i, has := range m.Bitfield { + if has { + bs[i/8] |= (1 << byte(7-i%8)) + } + } + buf.Write(bs) + case Piece: + if err = binary.Write(buf, binary.BigEndian, m.Index); err != nil { + return + } + if err = binary.Write(buf, binary.BigEndian, m.Begin); err != nil { + return + } + + if n, err := buf.Write(m.Piece); err != nil { + return err + } else if _len := len(m.Piece); n != _len { + return fmt.Errorf("expect writing %d bytes, but wrote %d", _len, n) + } + case Extended: + if err = buf.WriteByte(byte(m.ExtendedID)); err != nil { + _, err = buf.Write(m.ExtendedPayload) + } + case Port: + err = binary.Write(buf, binary.BigEndian, m.Port) + default: + err = fmt.Errorf("unknown message type: %v", m.Type) + } + + return +} diff --git a/peerprotocol/peerconn.go b/peerprotocol/peerconn.go new file mode 100644 index 0000000..d5a72af --- /dev/null +++ b/peerprotocol/peerconn.go @@ -0,0 +1,522 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package peerprotocol + +import ( + "bytes" + "fmt" + "io" + "net" + "time" + + "github.com/xgfone/bt/bencode" + "github.com/xgfone/bt/metainfo" +) + +// Predefine some errors about extension support. +var ( + ErrNotSupportDHT = fmt.Errorf("not support DHT extension") + ErrNotSupportFast = fmt.Errorf("not support Fast extension") + ErrNotSupportExtended = fmt.Errorf("not support Extended extension") +) + +// Bep3Handler is used to handle the BEP 3 type message if Handler has also +// implemented the interface. +type Bep3Handler interface { + Choke(pc *PeerConn) error + Unchoke(pc *PeerConn) error + Interested(pc *PeerConn) error + NotInterested(pc *PeerConn) error + Have(pc *PeerConn, index uint32) error + Bitfield(pc *PeerConn, bits []bool) error + Request(pc *PeerConn, index uint32, begin uint32, length uint32) error + Piece(pc *PeerConn, index uint32, begin uint32, piece []byte) error + Cancel(pc *PeerConn, index uint32, begin uint32, length uint32) error +} + +// Bep5Handler is used to handle the BEP 5 type message if Handler has also +// implemented the interface. +// +// Notice: the server must enable the DHT extension bit. +type Bep5Handler interface { + Port(pc *PeerConn, port uint16) error +} + +// Bep6Handler is used to handle the BEP 6 type message if Handler has also +// implemented the interface. +// +// Notice: the server must enable the Fast extension bit. +type Bep6Handler interface { + HaveAll(pc *PeerConn) error + HaveNone(pc *PeerConn) error + Suggest(pc *PeerConn, index uint32) error + AllowedFast(pc *PeerConn, index uint32) error + Reject(pc *PeerConn, index uint32, begin uint32, length uint32) error +} + +// Bep10Handler is used to handle the BEP 10 extended peer message +// if Handler has also implemented the interface. +// +// Notice: the server must enable the Extended extension bit. +type Bep10Handler interface { + OnHandShake(conn *PeerConn, exthmsg ExtendedHandshakeMsg) error + OnPayload(conn *PeerConn, extid uint8, payload []byte) error +} + +// PeerConn is used to manage the connection to the peer. +type PeerConn struct { + net.Conn + ExtensionBits + ID metainfo.Hash + + // Timeout is used to control the timeout of reading/writing the message. + // + // The default is 0, which represents no timeout. + Timeout time.Duration + + // MaxLength is used to limit the maximum number of the message body. + // + // The default is 0, which represents no limit. + MaxLength uint32 + + // These two states is controlled by the local client peer. + // + // Choked is used to indicate whether or not the local client has choked + // the remote peer, that's, if Choked is true, the local client will + // discard all the pending requests from the remote peer and not answer + // any requests until the local client is unchoked. + // + // Interested is used to indicate whether or not the local client is + // interested in something the remote peer has to offer, that's, + // if Interested is true, the local client will begin requesting blocks + // when the remote client unchokes them. + // + Choked bool // The default should be true. + Interested bool + + // These two states is controlled by the remote peer. + // + // PeerChoked is used to indicate whether or not the remote peer has choked + // the local client, that's, if PeerChoked is true, the remote peer will + // discard all the pending requests from the local client and not answer + // any requests until the remote peer is unchoked. + // + // PeerInterested is used to indicate whether or not the remote peer is + // interested in something the local client has to offer, that's, + // if PeerInterested is true, the remote peer will begin requesting blocks + // when the local client unchokes them. + // + PeerChoked bool // The default should be true. + PeerInterested bool + + // Data is used to store the context data associated with the connection. + Data interface{} +} + +// NewPeerConn returns a new PeerConn. +func NewPeerConn(id metainfo.Hash, conn net.Conn) *PeerConn { + return &PeerConn{ID: id, Conn: conn, Choked: true, PeerChoked: true} +} + +// NewPeerConnByDial returns a new PeerConn by dialing to addr with the "tcp" network. +func NewPeerConnByDial(id metainfo.Hash, addr string) (pc *PeerConn, err error) { + conn, err := net.Dial("tcp", addr) + if err == nil { + pc = NewPeerConn(id, conn) + } + return +} + +func (pc *PeerConn) setReadTimeout() { + if pc.Timeout > 0 { + pc.Conn.SetReadDeadline(time.Now().Add(pc.Timeout)) + } +} + +func (pc *PeerConn) setWriteTimeout() { + if pc.Timeout > 0 { + pc.Conn.SetWriteDeadline(time.Now().Add(pc.Timeout)) + } +} + +// SetChoked sets the Choked state of the local client peer. +// +// Notice: if the current state is not Choked, it will send a Choked message +// to the remote peer. +func (pc *PeerConn) SetChoked() (err error) { + if !pc.Choked { + if err = pc.SendChoke(); err == nil { + pc.Choked = true + } + } + return +} + +// SetUnchoked sets the Unchoked state of the local client peer. +// +// Notice: if the current state is not Unchoked, it will send a Unchoked message +// to the remote peer. +func (pc *PeerConn) SetUnchoked() (err error) { + if pc.Choked { + if err = pc.SendUnchoke(); err == nil { + pc.Choked = false + } + } + return +} + +// SetInterested sets the Interested state of the local client peer. +// +// Notice: if the current state is not Interested, it will send a Interested +// message to the remote peer. +func (pc *PeerConn) SetInterested() (err error) { + if !pc.Interested { + if err = pc.SendInterested(); err == nil { + pc.Interested = true + } + } + return +} + +// SetNotInterested sets the NotInterested state of the local client peer. +// +// Notice: if the current state is not NotInterested, it will send +// a NotInterested message to the remote peer. +func (pc *PeerConn) SetNotInterested() (err error) { + if pc.Interested { + if err = pc.SendNotInterested(); err == nil { + pc.Interested = false + } + } + return +} + +// Handshake does a handshake with the peer. +// +// BEP 3 +func (pc *PeerConn) Handshake(infoHash metainfo.Hash) (HandshakeMsg, error) { + m := HandshakeMsg{ExtensionBits: pc.ExtensionBits, PeerID: pc.ID, InfoHash: infoHash} + pc.setReadTimeout() + return Handshake(pc.Conn, m) +} + +// ReadMsg reads the message. +// +// BEP 3 +func (pc *PeerConn) ReadMsg() (m Message, err error) { + pc.setReadTimeout() + err = m.Decode(pc.Conn, pc.MaxLength) + return +} + +// WriteMsg writes the message to the peer. +// +// BEP 3 +func (pc *PeerConn) WriteMsg(m Message) (err error) { + buf := bytes.NewBuffer(make([]byte, 0, 128)) + if err = m.Encode(buf); err == nil { + pc.setWriteTimeout() + + var n int + if n, err = pc.Conn.Write(buf.Bytes()); err == nil && n < buf.Len() { + err = io.ErrShortWrite + } + } + return +} + +// SendKeepalive sends a Keepalive message to the peer. +// +// BEP 3 +func (pc *PeerConn) SendKeepalive() error { + return pc.WriteMsg(Message{Keepalive: true}) +} + +// SendChoke sends a Choke message to the peer. +// +// BEP 3 +func (pc *PeerConn) SendChoke() error { + return pc.WriteMsg(Message{Type: Choke}) +} + +// SendUnchoke sends a Unchoke message to the peer. +// +// BEP 3 +func (pc *PeerConn) SendUnchoke() error { + return pc.WriteMsg(Message{Type: Unchoke}) +} + +// SendInterested sends a Interested message to the peer. +// +// BEP 3 +func (pc *PeerConn) SendInterested() error { + return pc.WriteMsg(Message{Type: Interested}) +} + +// SendNotInterested sends a NotInterested message to the peer. +// +// BEP 3 +func (pc *PeerConn) SendNotInterested() error { + return pc.WriteMsg(Message{Type: NotInterested}) +} + +// SendBitfield sends a Bitfield message to the peer. +// +// BEP 3 +func (pc *PeerConn) SendBitfield(bits []bool) error { + return pc.WriteMsg(Message{Type: Bitfield, Bitfield: bits}) +} + +// SendHave sends a Have message to the peer. +// +// BEP 3 +func (pc *PeerConn) SendHave(index uint32) error { + return pc.WriteMsg(Message{Type: Have, Index: index}) +} + +// SendRequest sends a Request message to the peer. +// +// BEP 3 +func (pc *PeerConn) SendRequest(index, begin, length uint32) error { + return pc.WriteMsg(Message{Type: Request, Index: index, Begin: begin, Length: length}) +} + +// SendCancel sends a Cancel message to the peer. +// +// BEP 3 +func (pc *PeerConn) SendCancel(index, begin, length uint32) error { + return pc.WriteMsg(Message{Type: Cancel, Index: index, Begin: begin, Length: length}) +} + +// SendPiece sends a Piece message to the peer. +// +// BEP 3 +func (pc *PeerConn) SendPiece(index, begin uint32, piece []byte) error { + return pc.WriteMsg(Message{Type: Piece, Index: index, Begin: begin, Piece: piece}) +} + +// SendPort sends a Port message to the peer. +// +// BEP 5 +func (pc *PeerConn) SendPort(port uint16) error { + return pc.WriteMsg(Message{Type: Port, Port: port}) +} + +// SendHaveAll sends a HaveAll message to the peer. +// +// BEP 6 +func (pc *PeerConn) SendHaveAll() error { + return pc.WriteMsg(Message{Type: HaveAll}) +} + +// SendHaveNone sends a HaveNone message to the peer. +// +// BEP 6 +func (pc *PeerConn) SendHaveNone() error { + return pc.WriteMsg(Message{Type: HaveNone}) +} + +// SendSuggest sends a Suggest message to the peer. +// +// BEP 6 +func (pc *PeerConn) SendSuggest(index uint32) error { + return pc.WriteMsg(Message{Type: Suggest, Index: index}) +} + +// SendReject sends a Reject message to the peer. +// +// BEP 6 +func (pc *PeerConn) SendReject(index, begin, length uint32) error { + return pc.WriteMsg(Message{Type: Reject, Index: index, Begin: begin, Length: length}) +} + +// SendAllowedFast sends a AllowedFast message to the peer. +// +// BEP 6 +func (pc *PeerConn) SendAllowedFast(index uint32) error { + return pc.WriteMsg(Message{Type: AllowedFast, Index: index}) +} + +// SendExtHandshakeMsg sends the Extended Handshake message to the peer. +// +// BEP 10 +func (pc *PeerConn) SendExtHandshakeMsg(m ExtendedHandshakeMsg) (err error) { + buf := bytes.NewBuffer(make([]byte, 0, 128)) + if err = bencode.NewEncoder(buf).Encode(m); err == nil { + err = pc.SendExtMsg(ExtendedIDHandshake, buf.Bytes()) + } + return +} + +// SendExtMsg sends the Extended message with the extended id and the payload. +// +// BEP 10 +func (pc *PeerConn) SendExtMsg(extID uint8, payload []byte) error { + return pc.WriteMsg(Message{Type: Extended, ExtendedID: extID, ExtendedPayload: payload}) +} + +// HandleMessage calls the method of the handler to handle the message. +// +// If handler has also implemented the interfaces Bep3Handler, Bep5Handler, +// Bep6Handler or Bep10Handler, their methods will be called instead of +// Handler.OnMessage for the corresponding type message. +func (pc *PeerConn) HandleMessage(msg Message, handler Handler) (err error) { + if msg.Keepalive { + return + } + + switch msg.Type { + // BEP 3 - The BitTorrent Protocol Specification + case Choke: + pc.PeerChoked = true + if h, ok := handler.(Bep3Handler); ok { + err = h.Choke(pc) + } else { + err = handler.OnMessage(pc, msg) + } + case Unchoke: + pc.PeerChoked = false + if h, ok := handler.(Bep3Handler); ok { + err = h.Unchoke(pc) + } else { + err = handler.OnMessage(pc, msg) + } + case Interested: + pc.PeerInterested = true + if h, ok := handler.(Bep3Handler); ok { + err = h.Interested(pc) + } else { + err = handler.OnMessage(pc, msg) + } + case NotInterested: + pc.PeerInterested = false + if h, ok := handler.(Bep3Handler); ok { + err = h.NotInterested(pc) + } else { + err = handler.OnMessage(pc, msg) + } + case Have: + if h, ok := handler.(Bep3Handler); ok { + err = h.Have(pc, msg.Index) + } else { + err = handler.OnMessage(pc, msg) + } + case Bitfield: + if h, ok := handler.(Bep3Handler); ok { + err = h.Bitfield(pc, msg.Bitfield) + } else { + err = handler.OnMessage(pc, msg) + } + case Request: + if h, ok := handler.(Bep3Handler); ok { + err = h.Request(pc, msg.Index, msg.Begin, msg.Length) + } else { + err = handler.OnMessage(pc, msg) + } + case Piece: + if h, ok := handler.(Bep3Handler); ok { + err = h.Piece(pc, msg.Index, msg.Begin, msg.Piece) + } else { + err = handler.OnMessage(pc, msg) + } + case Cancel: + if h, ok := handler.(Bep3Handler); ok { + err = h.Cancel(pc, msg.Index, msg.Begin, msg.Length) + } else { + err = handler.OnMessage(pc, msg) + } + + // BEP 5 - DHT Protocol + case Port: + if !pc.IsSupportDHT() { + return ErrNotSupportDHT + } else if h, ok := handler.(Bep5Handler); ok { + err = h.Port(pc, msg.Port) + } else { + err = handler.OnMessage(pc, msg) + } + + // BEP 6 - Fast Extension + case Suggest: + if !pc.IsSupportFast() { + return ErrNotSupportFast + } else if h, ok := handler.(Bep6Handler); ok { + err = h.Suggest(pc, msg.Index) + } else { + err = handler.OnMessage(pc, msg) + } + case HaveAll: + if !pc.IsSupportFast() { + return ErrNotSupportFast + } else if h, ok := handler.(Bep6Handler); ok { + err = h.HaveAll(pc) + } else { + err = handler.OnMessage(pc, msg) + } + case HaveNone: + if !pc.IsSupportFast() { + return ErrNotSupportFast + } else if h, ok := handler.(Bep6Handler); ok { + err = h.HaveNone(pc) + } else { + err = handler.OnMessage(pc, msg) + } + case Reject: + if !pc.IsSupportFast() { + return ErrNotSupportFast + } else if h, ok := handler.(Bep6Handler); ok { + err = h.Reject(pc, msg.Index, msg.Begin, msg.Length) + } else { + err = handler.OnMessage(pc, msg) + } + case AllowedFast: + if !pc.IsSupportFast() { + return ErrNotSupportFast + } else if h, ok := handler.(Bep6Handler); ok { + err = h.AllowedFast(pc, msg.Index) + } else { + err = handler.OnMessage(pc, msg) + } + + // BEP 10 - Extension Protocol + case Extended: + if !pc.IsSupportExtended() { + return ErrNotSupportExtended + } else if h, ok := handler.(Bep10Handler); ok { + err = pc.handleExtMsg(h, msg) + } else { + err = handler.OnMessage(pc, msg) + } + + // Other + default: + err = handler.OnMessage(pc, msg) + } + + return +} + +func (pc *PeerConn) handleExtMsg(h Bep10Handler, m Message) (err error) { + if m.ExtendedID == ExtendedIDHandshake { + var ehmsg ExtendedHandshakeMsg + if err = bencode.DecodeBytes(m.ExtendedPayload, &ehmsg); err == nil { + err = h.OnHandShake(pc, ehmsg) + } + } else { + err = h.OnPayload(pc, m.ExtendedID, m.ExtendedPayload) + } + + return +} diff --git a/peerprotocol/protocol.go b/peerprotocol/protocol.go new file mode 100644 index 0000000..49e1d0f --- /dev/null +++ b/peerprotocol/protocol.go @@ -0,0 +1,97 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package peerprotocol + +import ( + "fmt" +) + +// ProtocolHeader is the BT protocal prefix. +// +// BEP 3 +const ProtocolHeader = "\x13BitTorrent protocol" + +// Predefine some message types. +const ( + // BEP 3 + Choke MessageType = 0 + Unchoke MessageType = 1 + Interested MessageType = 2 + NotInterested MessageType = 3 + Have MessageType = 4 + Bitfield MessageType = 5 + Request MessageType = 6 + Piece MessageType = 7 + Cancel MessageType = 8 + + // BEP 5 + Port MessageType = 9 + + // BEP 6 - Fast extension + Suggest MessageType = 0x0d // 13 + HaveAll MessageType = 0x0e // 14 + HaveNone MessageType = 0x0f // 15 + Reject MessageType = 0x10 // 16 + AllowedFast MessageType = 0x11 // 17 + + // BEP 10 + Extended MessageType = 20 +) + +// MessageType is used to represent the message type. +type MessageType byte + +func (mt MessageType) String() string { + switch mt { + case Choke: + return "Choke" + case Unchoke: + return "Unchoke" + case Interested: + return "Interested" + case NotInterested: + return "NotInterested" + case Have: + return "Have" + case Bitfield: + return "Bitfield" + case Request: + return "Request" + case Piece: + return "Piece" + case Cancel: + return "Cancel" + case Port: + return "Port" + case Suggest: + return "Suggest" + case HaveAll: + return "HaveAll" + case HaveNone: + return "HaveNone" + case Reject: + return "Reject" + case AllowedFast: + return "AllowedFast" + case Extended: + return "Extended" + } + return fmt.Sprintf("MessageType(%d)", mt) +} + +// FastExtension reports whether the message type is fast extension. +func (mt MessageType) FastExtension() bool { + return mt >= Suggest && mt <= AllowedFast +} diff --git a/peerprotocol/server.go b/peerprotocol/server.go new file mode 100644 index 0000000..910b178 --- /dev/null +++ b/peerprotocol/server.go @@ -0,0 +1,169 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package peerprotocol + +import ( + "fmt" + "io" + "log" + "net" + "strings" + "time" + + "github.com/xgfone/bt/metainfo" +) + +// Handler is used to handle the incoming peer connection. +type Handler interface { + // OnHandShake is used to check whether the handshake extension is acceptable. + OnHandShake(conn *PeerConn, hmsg HandshakeMsg) error + + // OnMessage is used to handle the incoming peer message. + // + // If requires, it should write the response to the peer. + OnMessage(conn *PeerConn, msg Message) error +} + +// Config is used to configure the server. +type Config struct { + // ExtBits is used to handshake with the client. + ExtBits ExtensionBits + + // MaxLength is used to limit the maximum number of the message body. + // + // The default is 0, which represents no limit. + MaxLength uint32 + + // Timeout is used to control the timeout of the read/write the message. + // + // The default is 0, which represents no timeout. + Timeout time.Duration + + // ErrorLog is used to log the error. + ErrorLog func(format string, args ...interface{}) // Default: log.Printf + + // HandleMessage is used to handle the incoming message. So you can + // customize it to add the request queue. + // + // The default handler is to forward to pc.HandleMessage(msg, handler). + HandleMessage func(pc *PeerConn, msg Message, handler Handler) error +} + +func (c *Config) set(conf ...Config) { + if len(conf) > 0 { + *c = conf[0] + } + + if c.ErrorLog == nil { + c.ErrorLog = log.Printf + } + if c.HandleMessage == nil { + c.HandleMessage = func(pc *PeerConn, m Message, h Handler) error { + return pc.HandleMessage(m, h) + } + } +} + +// Server is used to implement the peer protocol server. +type Server struct { + ln net.Listener + id metainfo.Hash + h Handler + c Config +} + +// NewServerByListen returns a new Server by listening on the address. +func NewServerByListen(network, address string, id metainfo.Hash, h Handler, + c ...Config) (*Server, error) { + ln, err := net.Listen(network, address) + if err != nil { + return nil, err + } + return NewServer(ln, id, h, c...), nil +} + +// NewServer returns a new Server. +func NewServer(ln net.Listener, id metainfo.Hash, h Handler, c ...Config) *Server { + if id.IsZero() { + panic("the peer node id must not be empty") + } + + var conf Config + conf.set(c...) + return &Server{ln: ln, id: id, h: h, c: conf} +} + +// Run starts the peer protocol server. +func (s *Server) Run() { + for { + conn, err := s.ln.Accept() + if err != nil { + s.c.ErrorLog("fail to accept new connection: %s", err) + } + go s.handleConn(conn) + } +} + +func (s *Server) handleConn(conn net.Conn) { + pc := &PeerConn{ + ID: s.id, + Conn: conn, + ExtensionBits: s.c.ExtBits, + Timeout: s.c.Timeout, + MaxLength: s.c.MaxLength, + Choked: true, + PeerChoked: true, + } + + if err := s.handlePeerMessage(pc); err != nil { + s.c.ErrorLog(err.Error()) + } +} + +func (s *Server) handlePeerMessage(pc *PeerConn) (err error) { + defer pc.Close() + m, err := pc.Handshake(metainfo.Hash{}) + if err != nil { + return fmt.Errorf("fail to handshake with '%s': %s", pc.RemoteAddr().String(), err) + } else if err = s.h.OnHandShake(pc, m); err != nil { + return fmt.Errorf("handshake error with '%s': %s", pc.RemoteAddr().String(), err) + } + + return s.loopRun(pc, s.h) +} + +// LoopRun loops running Read-Handle message. +func (s *Server) loopRun(pc *PeerConn, handler Handler) error { + for { + msg, err := pc.ReadMsg() + switch err { + case nil: + case io.EOF: + return nil + default: + s := err.Error() + if strings.Contains(s, "closed") { + return nil + } + return fmt.Errorf("fail to decode the message from '%s': %s", + pc.RemoteAddr().String(), s) + } + + if err = s.c.HandleMessage(pc, msg, handler); err != nil { + return fmt.Errorf("fail to handle peer message from '%s': %s", + pc.RemoteAddr().String(), err) + } + } +} diff --git a/tracker/httptracker/http.go b/tracker/httptracker/http.go new file mode 100644 index 0000000..03e6e43 --- /dev/null +++ b/tracker/httptracker/http.go @@ -0,0 +1,312 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package httptracker implements the tracker protocol based on HTTP/HTTPS. +// +// You can use the package to implement a HTTP tracker server to track the +// information that other peers upload or download the file, or to create +// a HTTP tracker client to communicate with the HTTP tracker server. +package httptracker + +import ( + "io" + "net/http" + "net/url" + "strconv" + "strings" + + "github.com/xgfone/bt/bencode" + "github.com/xgfone/bt/metainfo" +) + +// AnnounceRequest is the tracker announce requests. +// +// BEP 3 +type AnnounceRequest struct { + // InfoHash is the sha1 hash of the bencoded form of the info value from the metainfo file. + InfoHash metainfo.Hash `bencode:"info_hash"` // BEP 3 + + // PeerID is the id of the downloader. + // + // Each downloader generates its own id at random at the start of a new download. + PeerID metainfo.Hash `bencode:"peer_id"` // BEP 3 + + // Uploaded is the total amount uploaded so far, encoded in base ten ascii. + Uploaded int64 `bencode:"uploaded"` // BEP 3 + + // Downloaded is the total amount downloaded so far, encoded in base ten ascii. + Downloaded int64 `bencode:"downloaded"` // BEP 3 + + // Left is the number of bytes this peer still has to download, + // encoded in base ten ascii. + // + // Note that this can't be computed from downloaded and the file length + // since it might be a resume, and there's a chance that some of the + // downloaded data failed an integrity check and had to be re-downloaded. + // + // If less than 0, math.MaxInt64 will be used for HTTP trackers instead. + Left int64 `bencode:"left"` // BEP 3 + + // Port is the port that this peer is listening on. + // + // Common behavior is for a downloader to try to listen on port 6881, + // and if that port is taken try 6882, then 6883, etc. and give up after 6889. + Port uint16 `bencode:"port"` // BEP 3 + + // IP is the ip or DNS name which this peer is at, which generally used + // for the origin if it's on the same machine as the tracker. + // + // Optional. + IP string `bencode:"ip,omitempty"` // BEP 3 + + // If not present, this is one of the announcements done at regular intervals. + // An announcement using started is sent when a download first begins, + // and one using completed is sent when the download is complete. + // No completed is sent if the file was complete when started. + // Downloaders send an announcement using stopped when they cease downloading. + // + // Optional + Event uint32 `bencode:"event,omitempty"` // BEP 3 + + // Compact indicates whether it hopes the tracker to return the compact + // peer lists. + // + // Optional + Compact bool `bencode:"compact,omitempty"` // BEP 23 + + // NumWant is the number of peers that the client would like to receive + // from the tracker. This value is permitted to be zero. If omitted, + // typically defaults to 50 peers. + // + // See https://wiki.theory.org/index.php/BitTorrentSpecification + // + // Optional. + NumWant int32 `bencode:"numwant,omitempty"` + + Key int32 `bencode:"key,omitempty"` +} + +// ToQuery converts the Request to URL Query. +func (r AnnounceRequest) ToQuery() (vs url.Values) { + vs = make(url.Values, 9) + vs.Set("info_hash", r.InfoHash.BytesString()) + vs.Set("peer_id", r.PeerID.BytesString()) + vs.Set("uploaded", strconv.FormatInt(r.Uploaded, 10)) + vs.Set("downloaded", strconv.FormatInt(r.Downloaded, 10)) + vs.Set("left", strconv.FormatInt(r.Left, 10)) + + if r.IP != "" { + vs.Set("ip", r.IP) + } + if r.Event > 0 { + vs.Set("event", strconv.FormatInt(int64(r.Event), 10)) + } + if r.Port > 0 { + vs.Set("port", strconv.FormatUint(uint64(r.Port), 10)) + } + if r.NumWant > 0 { + vs.Set("numwant", strconv.FormatUint(uint64(r.NumWant), 10)) + } + if r.Key != 0 { + vs.Set("key", strconv.FormatInt(int64(r.Key), 10)) + } + + // BEP 23 + if r.Compact { + vs.Set("compact", "1") + } else { + vs.Set("compact", "0") + } + + return +} + +// FromQuery converts URL Query to itself. +func (r *AnnounceRequest) FromQuery(vs url.Values) (err error) { + if err = r.InfoHash.FromString(vs.Get("info_hash")); err != nil { + return + } + + if err = r.PeerID.FromString(vs.Get("peer_id")); err != nil { + return + } + + v, err := strconv.ParseInt(vs.Get("uploaded"), 10, 64) + if err != nil { + return + } + r.Uploaded = v + + v, err = strconv.ParseInt(vs.Get("downloaded"), 10, 64) + if err != nil { + return + } + r.Downloaded = v + + v, err = strconv.ParseInt(vs.Get("left"), 10, 64) + if err != nil { + return + } + r.Left = v + + if s := vs.Get("event"); s != "" { + v, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return err + } + r.Event = uint32(v) + } + + if s := vs.Get("port"); s != "" { + v, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return err + } + r.Port = uint16(v) + } + + if s := vs.Get("numwant"); s != "" { + v, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return err + } + r.NumWant = int32(v) + } + + if s := vs.Get("key"); s != "" { + v, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return err + } + r.Key = int32(v) + } + + r.IP = vs.Get("ip") + switch vs.Get("compact") { + case "1": + r.Compact = true + case "0": + r.Compact = false + } + + return +} + +// AnnounceResponse is a announce response. +type AnnounceResponse struct { + FailureReason string `bencode:"failure reason,omitempty"` + + // Interval is the seconds the downloader should wait before next rerequest. + Interval uint32 `bencode:"interval,omitempty"` // BEP 3 + + // Peers is the list of the peers. + Peers Peers `bencode:"peers,omitempty"` // BEP 3, BEP 23 + + // Peers6 is only used for ipv6 in the compact case. + Peers6 Peers6 `bencode:"peers6,omitempty"` // BEP 7 + + // Where's this specified? + // Mentioned at https://wiki.theory.org/index.php/BitTorrentSpecification. + + // Complete is the number of peers with the entire file. + Complete uint32 `bencode:"complete,omitempty"` + // Incomplete is the number of non-seeder peers. + Incomplete uint32 `bencode:"incomplete,omitempty"` + // TrackerID is that the client should send back on its next announcements. + // If absent and a previous announce sent a tracker id, + // do not discard the old value; keep using it. + TrackerID string `bencode:"tracker id,omitempty"` +} + +// ScrapeResponseResult is the result of the scraped file. +type ScrapeResponseResult struct { + // Complete is the number of active peers that have completed downloading. + Complete uint32 `bencode:"complete"` // BEP 48 + + // Incomplete is the number of active peers that have not completed downloading. + Incomplete uint32 `bencode:"incomplete"` // BEP 48 + + // The number of peers that have ever completed downloading. + Downloaded uint32 `bencode:"downloaded"` // BEP 48 +} + +// ScrapeResponse represents a Scrape response. +// +// BEP 48 +type ScrapeResponse struct { + FailureReason string `bencode:"failure_reason,omitempty"` + + Files map[metainfo.Hash]ScrapeResponseResult `bencode:"files,omitempty"` +} + +// DecodeFrom reads the []byte data from r and decodes them to sr by bencode. +// +// r may be the body of the request from the http client. +func (sr *ScrapeResponse) DecodeFrom(r io.Reader) (err error) { + return bencode.NewDecoder(r).Decode(sr) +} + +// EncodeTo encodes the response to []byte by bencode and write the result into w. +// +// w may be http.ResponseWriter. +func (sr ScrapeResponse) EncodeTo(w io.Writer) (err error) { + return bencode.NewEncoder(w).Encode(sr) +} + +// TrackerClient represents a tracker client based on HTTP/HTTPS. +type TrackerClient struct { + AnnounceURL string + ScrapeURL string +} + +// NewTrackerClient returns a new HTTPTrackerClient. +// +// scrapeURL may be empty, which will replace the "announce" in announceURL +// with "scrape" to generate the scrapeURL. +func NewTrackerClient(announceURL, scrapeURL string) *TrackerClient { + if scrapeURL == "" { + scrapeURL = strings.Replace(announceURL, "announce", "scrape", -1) + } + return &TrackerClient{AnnounceURL: announceURL, ScrapeURL: scrapeURL} +} + +func (t *TrackerClient) send(u string, vs url.Values, r interface{}) (err error) { + sym := "?" + if strings.IndexByte(u, '?') > 0 { + sym = "&" + } + + resp, err := http.Get(u + sym + vs.Encode()) + if err != nil { + return + } + defer resp.Body.Close() + return bencode.NewDecoder(resp.Body).Decode(r) +} + +// Announce sends a Announce request to the tracker. +func (t *TrackerClient) Announce(req AnnounceRequest) (resp AnnounceResponse, err error) { + err = t.send(t.AnnounceURL, req.ToQuery(), &resp) + return +} + +// Scrape sends a Scrape request to the tracker. +func (t *TrackerClient) Scrape(infohashes []metainfo.Hash) (resp ScrapeResponse, err error) { + hs := make([]string, len(infohashes)) + for i, h := range infohashes { + hs[i] = h.BytesString() + } + err = t.send(t.ScrapeURL, url.Values{"info_hash": hs}, &resp) + return +} diff --git a/tracker/httptracker/http_peer.go b/tracker/httptracker/http_peer.go new file mode 100644 index 0000000..2da9cd6 --- /dev/null +++ b/tracker/httptracker/http_peer.go @@ -0,0 +1,192 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httptracker + +import ( + "bytes" + "encoding/binary" + "errors" + "net" + + "github.com/xgfone/bt/bencode" + "github.com/xgfone/bt/metainfo" +) + +var errInvalidPeer = errors.New("invalid peer information format") + +// Peer is a tracker peer. +type Peer struct { + // ID is the peer's self-selected ID. + ID string `bencode:"peer id"` // BEP 3 + + // IP is the IP address or dns name. + IP string `bencode:"ip"` // BEP 3 + Port uint16 `bencode:"port"` // BEP 3 +} + +// Addresses returns the list of the addresses that the peer listens on. +func (p Peer) Addresses() (addrs []metainfo.Address, err error) { + if ip := net.ParseIP(p.IP); len(ip) != 0 { + return []metainfo.Address{{IP: ip, Port: p.Port}}, nil + } + + ips, err := net.LookupIP(p.IP) + if _len := len(ips); err == nil && len(ips) != 0 { + addrs = make([]metainfo.Address, _len) + for i, ip := range ips { + addrs[i] = metainfo.Address{IP: ip, Port: p.Port} + } + } + + return +} + +// Peers is a set of the peers. +type Peers []Peer + +// UnmarshalBencode implements the interface bencode.Unmarshaler. +func (ps *Peers) UnmarshalBencode(b []byte) (err error) { + var v interface{} + if err = bencode.DecodeBytes(b, &v); err != nil { + return + } + + switch vs := v.(type) { + case string: // BEP 23 + _len := len(vs) + if _len%6 != 0 { + return metainfo.ErrInvalidAddr + } + + peers := make(Peers, 0, _len/6) + for i := 0; i < _len; i += 6 { + var addr metainfo.Address + if err = addr.UnmarshalBinary([]byte(vs[i : i+6])); err != nil { + return + } + peers = append(peers, Peer{IP: addr.IP.String(), Port: addr.Port}) + } + + *ps = peers + case []interface{}: // BEP 3 + peers := make(Peers, len(vs)) + for i, p := range vs { + m, ok := p.(map[string]interface{}) + if !ok { + return errInvalidPeer + } + + pid, ok := m["peer id"].(string) + if !ok { + return errInvalidPeer + } + + ip, ok := m["ip"].(string) + if !ok { + return errInvalidPeer + } + + port, ok := m["port"].(int64) + if !ok { + return errInvalidPeer + } + + peers[i] = Peer{ID: pid, IP: ip, Port: uint16(port)} + } + *ps = peers + default: + return errInvalidPeer + } + return +} + +// MarshalBencode implements the interface bencode.Marshaler. +func (ps Peers) MarshalBencode() (b []byte, err error) { + for _, p := range ps { + if p.ID == "" { + return ps.marshalCompactBencode() // BEP 23 + } + } + + // BEP 3 + buf := bytes.NewBuffer(make([]byte, 0, 64*len(ps))) + buf.WriteByte('l') + for _, p := range ps { + if err = bencode.NewEncoder(buf).Encode(p); err != nil { + return + } + } + buf.WriteByte('e') + b = buf.Bytes() + return +} + +func (ps Peers) marshalCompactBencode() (b []byte, err error) { + buf := bytes.NewBuffer(make([]byte, 0, 6*len(ps))) + for _, peer := range ps { + ip := net.ParseIP(peer.IP).To4() + if len(ip) == 0 { + return nil, errInvalidPeer + } + buf.Write(ip[:]) + binary.Write(buf, binary.BigEndian, peer.Port) + } + return bencode.EncodeBytes(buf.Bytes()) +} + +// Peers6 is a set of the peers for IPv6 in the compact case. +// +// BEP 7 +type Peers6 []Peer + +// UnmarshalBencode implements the interface bencode.Unmarshaler. +func (ps *Peers6) UnmarshalBencode(b []byte) (err error) { + var s string + if err = bencode.DecodeBytes(b, &s); err != nil { + return + } + + _len := len(s) + if _len%18 != 0 { + return metainfo.ErrInvalidAddr + } + + peers := make(Peers6, 0, _len/18) + for i := 0; i < _len; i += 18 { + var addr metainfo.Address + if err = addr.UnmarshalBinary([]byte(s[i : i+18])); err != nil { + return + } + peers = append(peers, Peer{IP: addr.IP.String(), Port: addr.Port}) + } + + *ps = peers + return +} + +// MarshalBencode implements the interface bencode.Marshaler. +func (ps Peers6) MarshalBencode() (b []byte, err error) { + buf := bytes.NewBuffer(make([]byte, 0, 18*len(ps))) + for _, peer := range ps { + ip := net.ParseIP(peer.IP).To16() + if len(ip) == 0 { + return nil, errInvalidPeer + } + + buf.Write(ip[:]) + binary.Write(buf, binary.BigEndian, peer.Port) + } + return bencode.EncodeBytes(buf.Bytes()) +} diff --git a/tracker/httptracker/http_peer_test.go b/tracker/httptracker/http_peer_test.go new file mode 100644 index 0000000..fe91bd7 --- /dev/null +++ b/tracker/httptracker/http_peer_test.go @@ -0,0 +1,75 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httptracker + +import ( + "reflect" + "testing" +) + +func TestPeers(t *testing.T) { + peers := Peers{ + {ID: "123", IP: "1.1.1.1", Port: 80}, + {ID: "456", IP: "2.2.2.2", Port: 81}, + } + + b, err := peers.MarshalBencode() + if err != nil { + t.Fatal(err) + } + + var ps Peers + if err = ps.UnmarshalBencode(b); err != nil { + t.Fatal(err) + } else if !reflect.DeepEqual(ps, peers) { + t.Errorf("%v != %v", ps, peers) + } + + /// For BEP 23 + peers = Peers{ + {IP: "1.1.1.1", Port: 80}, + {IP: "2.2.2.2", Port: 81}, + } + + b, err = peers.MarshalBencode() + if err != nil { + t.Fatal(err) + } + + if err = ps.UnmarshalBencode(b); err != nil { + t.Fatal(err) + } else if !reflect.DeepEqual(ps, peers) { + t.Errorf("%v != %v", ps, peers) + } +} + +func TestPeers6(t *testing.T) { + peers := Peers6{ + {IP: "fe80::5054:ff:fef0:1ab", Port: 80}, + {IP: "fe80::5054:ff:fe29:205d", Port: 81}, + } + + b, err := peers.MarshalBencode() + if err != nil { + t.Fatal(err) + } + + var ps Peers6 + if err = ps.UnmarshalBencode(b); err != nil { + t.Fatal(err) + } else if !reflect.DeepEqual(ps, peers) { + t.Errorf("%v != %v", ps, peers) + } +} diff --git a/tracker/httptracker/http_test.go b/tracker/httptracker/http_test.go new file mode 100644 index 0000000..fbb0feb --- /dev/null +++ b/tracker/httptracker/http_test.go @@ -0,0 +1,72 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httptracker + +import ( + "reflect" + "testing" + + "github.com/xgfone/bt/metainfo" +) + +func TestHTTPAnnounceRequest(t *testing.T) { + infohash := metainfo.NewRandomHash() + peerid := metainfo.NewRandomHash() + v1 := AnnounceRequest{ + InfoHash: infohash, + PeerID: peerid, + Uploaded: 789, + Downloaded: 456, + Left: 123, + Port: 80, + Event: 123, + Compact: true, + } + vs := v1.ToQuery() + + var v2 AnnounceRequest + if err := v2.FromQuery(vs); err != nil { + t.Fatal(err) + } + + if v2.InfoHash != infohash { + t.Error(v2.InfoHash) + } + if v2.PeerID != peerid { + t.Error(v2.PeerID) + } + if v2.Uploaded != 789 { + t.Error(v2.Uploaded) + } + if v2.Downloaded != 456 { + t.Error(v2.Downloaded) + } + if v2.Left != 123 { + t.Error(v2.Left) + } + if v2.Port != 80 { + t.Error(v2.Port) + } + if v2.Event != 123 { + t.Error(v2.Event) + } + if !v2.Compact { + t.Error(v2.Compact) + } + + if !reflect.DeepEqual(v1, v2) { + t.Errorf("%v != %v", v1, v2) + } +} diff --git a/tracker/tracker.go b/tracker/tracker.go new file mode 100644 index 0000000..8eda329 --- /dev/null +++ b/tracker/tracker.go @@ -0,0 +1,266 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package tracker supplies some common type interfaces of the BT tracker +// protocol. +package tracker + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "net" + "net/url" + + "github.com/xgfone/bt/metainfo" + "github.com/xgfone/bt/tracker/httptracker" + "github.com/xgfone/bt/tracker/udptracker" +) + +// Predefine some announce events. +// +// BEP 3 +const ( + None uint32 = iota + Completed // The local peer just completed the torrent. + Started // The local peer has just resumed this torrent. + Stopped // The local peer is leaving the swarm. +) + +// AnnounceRequest is the common Announce request. +// +// BEP 3, 15 +type AnnounceRequest struct { + InfoHash metainfo.Hash + PeerID metainfo.Hash + + Uploaded int64 + Downloaded int64 + Left int64 + Event uint32 + + IP net.IP + Key int32 + NumWant int32 // -1 for default + Port uint16 +} + +// ToHTTPAnnounceRequest creates a new httptracker.AnnounceRequest from itself. +func (ar AnnounceRequest) ToHTTPAnnounceRequest() httptracker.AnnounceRequest { + var ip string + if len(ar.IP) != 0 { + ip = ar.IP.String() + } + + return httptracker.AnnounceRequest{ + InfoHash: ar.InfoHash, + PeerID: ar.PeerID, + Uploaded: ar.Uploaded, + Downloaded: ar.Downloaded, + Left: ar.Left, + Port: ar.Port, + IP: ip, + Event: ar.Event, + NumWant: ar.NumWant, + Key: ar.Key, + } +} + +// ToUDPAnnounceRequest creates a new udptracker.AnnounceRequest from itself. +func (ar AnnounceRequest) ToUDPAnnounceRequest() udptracker.AnnounceRequest { + return udptracker.AnnounceRequest{ + InfoHash: ar.InfoHash, + PeerID: ar.PeerID, + Downloaded: ar.Downloaded, + Left: ar.Left, + Uploaded: ar.Uploaded, + Event: ar.Event, + IP: ar.IP, + Key: ar.Key, + NumWant: ar.NumWant, + Port: ar.Port, + } +} + +// AnnounceResponse is a common Announce response. +// +// BEP 3, 15 +type AnnounceResponse struct { + Interval uint32 + Leechers uint32 + Seeders uint32 + Addresses []metainfo.Address +} + +// FromHTTPAnnounceResponse sets itself from r. +func (ar *AnnounceResponse) FromHTTPAnnounceResponse(r httptracker.AnnounceResponse) { + ar.Interval = r.Interval + ar.Leechers = r.Incomplete + ar.Seeders = r.Complete + ar.Addresses = make([]metainfo.Address, 0, len(r.Peers)+len(r.Peers6)) + for _, peer := range r.Peers { + addrs, _ := peer.Addresses() + ar.Addresses = append(ar.Addresses, addrs...) + } + for _, peer := range r.Peers6 { + addrs, _ := peer.Addresses() + ar.Addresses = append(ar.Addresses, addrs...) + } +} + +// FromUDPAnnounceResponse sets itself from r. +func (ar *AnnounceResponse) FromUDPAnnounceResponse(r udptracker.AnnounceResponse) { + ar.Interval = r.Interval + ar.Leechers = r.Leechers + ar.Seeders = r.Seeders + ar.Addresses = r.Addresses +} + +// ScrapeResponseResult is a commont Scrape response result. +type ScrapeResponseResult struct { + // Seeders is the number of active peers that have completed downloading. + Seeders uint32 `bencode:"complete"` // BEP 15, 48 + + // Leechers is the number of active peers that have not completed downloading. + Leechers uint32 `bencode:"incomplete"` // BEP 15, 48 + + // Completed is the total number of peers that have ever completed downloading. + Completed uint32 `bencode:"downloaded"` // BEP 15, 48 +} + +// EncodeTo encodes the response to buf. +func (r ScrapeResponseResult) EncodeTo(buf *bytes.Buffer) { + binary.Write(buf, binary.BigEndian, r.Seeders) + binary.Write(buf, binary.BigEndian, r.Completed) + binary.Write(buf, binary.BigEndian, r.Leechers) +} + +// DecodeFrom decodes the response from b. +func (r *ScrapeResponseResult) DecodeFrom(b []byte) { + r.Seeders = binary.BigEndian.Uint32(b[:4]) + r.Completed = binary.BigEndian.Uint32(b[4:8]) + r.Leechers = binary.BigEndian.Uint32(b[8:12]) +} + +// ScrapeResponse is a commont Scrape response. +type ScrapeResponse map[metainfo.Hash]ScrapeResponseResult + +// FromHTTPScrapeResponse sets itself from r. +func (sr ScrapeResponse) FromHTTPScrapeResponse(r httptracker.ScrapeResponse) { + for k, v := range r.Files { + sr[k] = ScrapeResponseResult{ + Seeders: v.Complete, + Leechers: v.Incomplete, + Completed: v.Downloaded, + } + } +} + +// FromUDPScrapeResponse sets itself from hs and r. +func (sr ScrapeResponse) FromUDPScrapeResponse(hs []metainfo.Hash, + r []udptracker.ScrapeResponse) { + klen := len(hs) + if _len := len(r); _len < klen { + klen = _len + } + + for i := 0; i < klen; i++ { + sr[hs[i]] = ScrapeResponseResult{ + Seeders: r[i].Seeders, + Leechers: r[i].Leechers, + Completed: r[i].Completed, + } + } +} + +// Client is the interface of BT tracker client. +type Client interface { + Announce(AnnounceRequest) (AnnounceResponse, error) + Scrape([]metainfo.Hash) (ScrapeResponse, error) +} + +// NewClient returns a new Client. +func NewClient(connURL string) (c Client, err error) { + u, err := url.Parse(connURL) + if err == nil { + switch u.Scheme { + case "http", "https": + c = &tclient{http: httptracker.NewTrackerClient(connURL, "")} + case "udp", "udp4", "udp6": + var utc *udptracker.TrackerClient + utc, err = udptracker.NewTrackerClientByDial(u.Scheme, u.Host) + if err == nil { + var e []udptracker.Extension + if p := u.RequestURI(); p != "" { + e = []udptracker.Extension{udptracker.NewURLData([]byte(p))} + } + c = &tclient{exts: e, udp: utc} + } + default: + err = fmt.Errorf("unknown url scheme '%s'", u.Scheme) + } + } + return +} + +type tclient struct { + http *httptracker.TrackerClient // BEP 3 + udp *udptracker.TrackerClient // BEP 15 + exts []udptracker.Extension // BEP 41 +} + +func (c *tclient) Announce(req AnnounceRequest) (resp AnnounceResponse, err error) { + if c.http != nil { + var r httptracker.AnnounceResponse + if r, err = c.http.Announce(req.ToHTTPAnnounceRequest()); err != nil { + return + } else if r.FailureReason != "" { + err = errors.New(r.FailureReason) + return + } + resp.FromHTTPAnnounceResponse(r) + return + } + + r := req.ToUDPAnnounceRequest() + r.Exts = c.exts + rs, err := c.udp.Announce(r) + if err == nil { + resp.FromUDPAnnounceResponse(rs) + } + return +} + +func (c *tclient) Scrape(hs []metainfo.Hash) (resp ScrapeResponse, err error) { + if c.http != nil { + var r httptracker.ScrapeResponse + if r, err = c.http.Scrape(hs); err != nil { + return + } else if r.FailureReason != "" { + err = errors.New(r.FailureReason) + return + } + resp = make(ScrapeResponse, len(r.Files)) + resp.FromHTTPScrapeResponse(r) + return + } + + r, err := c.udp.Scrape(hs) + if err == nil { + resp = make(ScrapeResponse, len(r)) + resp.FromUDPScrapeResponse(hs, r) + } + return +} diff --git a/tracker/tracker_test.go b/tracker/tracker_test.go new file mode 100644 index 0000000..81455b8 --- /dev/null +++ b/tracker/tracker_test.go @@ -0,0 +1,121 @@ +package tracker + +import ( + "errors" + "fmt" + "log" + "net" + "time" + + "github.com/xgfone/bt/metainfo" + "github.com/xgfone/bt/tracker/udptracker" +) + +type testHandler struct{} + +func (testHandler) OnConnect(raddr *net.UDPAddr) (err error) { return } +func (testHandler) OnAnnounce(raddr *net.UDPAddr, req udptracker.AnnounceRequest) ( + r udptracker.AnnounceResponse, err error) { + if req.Port != 80 { + err = errors.New("port is not 80") + return + } + + if len(req.Exts) > 0 { + for i, ext := range req.Exts { + switch ext.Type { + case udptracker.URLData: + fmt.Printf("Extensions[%d]: URLData(%s)\n", i, string(ext.Data)) + default: + fmt.Printf("Extensions[%d]: %s\n", i, ext.Type.String()) + } + } + } + + r = udptracker.AnnounceResponse{ + Interval: 1, + Leechers: 2, + Seeders: 3, + Addresses: []metainfo.Address{{IP: net.ParseIP("127.0.0.1"), Port: 8001}}, + } + return +} +func (testHandler) OnScrap(raddr *net.UDPAddr, infohashes []metainfo.Hash) ( + rs []udptracker.ScrapeResponse, err error) { + rs = make([]udptracker.ScrapeResponse, len(infohashes)) + for i := range infohashes { + rs[i] = udptracker.ScrapeResponse{ + Seeders: uint32(i)*10 + 1, + Leechers: uint32(i)*10 + 2, + Completed: uint32(i)*10 + 3, + } + } + return +} + +func ExampleClient() { + // Start the UDP tracker server + sconn, err := net.ListenPacket("udp4", "127.0.0.1:8001") + if err != nil { + log.Fatal(err) + } + server := udptracker.NewTrackerServer(sconn, testHandler{}) + defer server.Close() + go server.Run() + + // Wait for the server to be started + time.Sleep(time.Second) + + // Create a client and dial to the UDP tracker server. + client, err := NewClient("udp://127.0.0.1:8001/path?a=1&b=2") + if err != nil { + log.Fatal(err) + } + + // Send the ANNOUNCE request to the UDP tracker server, + // and get the ANNOUNCE response. + req := AnnounceRequest{IP: net.ParseIP("127.0.0.1"), Port: 80} + resp, err := client.Announce(req) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("Interval: %d\n", resp.Interval) + fmt.Printf("Leechers: %d\n", resp.Leechers) + fmt.Printf("Seeders: %d\n", resp.Seeders) + for i, addr := range resp.Addresses { + fmt.Printf("Address[%d].IP: %s\n", i, addr.IP.String()) + fmt.Printf("Address[%d].Port: %d\n", i, addr.Port) + } + + // Send the SCRAPE request to the UDP tracker server, + // and get the SCRAPE respsone. + h1 := metainfo.Hash{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} + h2 := metainfo.Hash{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2} + rs, err := client.Scrape([]metainfo.Hash{h1, h2}) + if err != nil { + log.Fatal(err) + } else if len(rs) != 2 { + log.Fatalf("%+v", rs) + } + + for i, r := range rs { + fmt.Printf("%s.Seeders: %d\n", i.HexString(), r.Seeders) + fmt.Printf("%s.Leechers: %d\n", i.HexString(), r.Leechers) + fmt.Printf("%s.Completed: %d\n", i.HexString(), r.Completed) + } + + // Output: + // Extensions[0]: URLData(/path?a=1&b=2) + // Interval: 1 + // Leechers: 2 + // Seeders: 3 + // Address[0].IP: 127.0.0.1 + // Address[0].Port: 8001 + // 0101010101010101010101010101010101010101.Seeders: 1 + // 0101010101010101010101010101010101010101.Leechers: 2 + // 0101010101010101010101010101010101010101.Completed: 3 + // 0202020202020202020202020202020202020202.Seeders: 11 + // 0202020202020202020202020202020202020202.Leechers: 12 + // 0202020202020202020202020202020202020202.Completed: 13 +} diff --git a/tracker/udptracker/udp.go b/tracker/udptracker/udp.go new file mode 100644 index 0000000..90fef57 --- /dev/null +++ b/tracker/udptracker/udp.go @@ -0,0 +1,271 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package udptracker implements the tracker protocol based on UDP. +// +// You can use the package to implement a UDP tracker server to track the +// information that other peers upload or download the file, or to create +// a UDP tracker client to communicate with the UDP tracker server. +package udptracker + +import ( + "bytes" + "encoding/binary" + "fmt" + "net" + + "github.com/xgfone/bt/metainfo" +) + +// ProtocolID is magic constant for the udp tracker connection. +// +// BEP 15 +const ProtocolID = uint64(0x41727101980) + +// Predefine some actions. +// +// BEP 15 +const ( + ActionConnect = uint32(0) + ActionAnnounce = uint32(1) + ActionScrape = uint32(2) + ActionError = uint32(3) +) + +// AnnounceRequest represents the announce request used by UDP tracker. +// +// BEP 15 +type AnnounceRequest struct { + InfoHash metainfo.Hash + PeerID metainfo.Hash + + Downloaded int64 + Left int64 + Uploaded int64 + Event uint32 + + IP net.IP + Key int32 + NumWant int32 // -1 for default + Port uint16 + + Exts []Extension // BEP 41 +} + +// DecodeFrom decodes the request from b. +func (r *AnnounceRequest) DecodeFrom(b []byte, ipv4 bool) { + r.InfoHash = metainfo.NewHash(b[0:20]) + r.PeerID = metainfo.NewHash(b[20:40]) + r.Downloaded = int64(binary.BigEndian.Uint64(b[40:48])) + r.Left = int64(binary.BigEndian.Uint64(b[48:56])) + r.Uploaded = int64(binary.BigEndian.Uint64(b[56:64])) + r.Event = binary.BigEndian.Uint32(b[64:68]) + + if ipv4 { + r.IP = make(net.IP, net.IPv4len) + copy(r.IP, b[68:72]) + b = b[72:] + } else { + r.IP = make(net.IP, net.IPv6len) + copy(r.IP, b[68:84]) + b = b[84:] + } + + r.Key = int32(binary.BigEndian.Uint32(b[0:4])) + r.NumWant = int32(binary.BigEndian.Uint32(b[4:8])) + r.Port = binary.BigEndian.Uint16(b[8:10]) + + b = b[10:] + for len(b) > 0 { + var ext Extension + parsed := ext.DecodeFrom(b) + r.Exts = append(r.Exts, ext) + b = b[parsed:] + } +} + +// EncodeTo encodes the request to buf. +func (r AnnounceRequest) EncodeTo(buf *bytes.Buffer) { + buf.Grow(82) + buf.Write(r.InfoHash[:]) + buf.Write(r.PeerID[:]) + + binary.Write(buf, binary.BigEndian, r.Downloaded) + binary.Write(buf, binary.BigEndian, r.Left) + binary.Write(buf, binary.BigEndian, r.Uploaded) + binary.Write(buf, binary.BigEndian, r.Event) + + if ip := r.IP.To4(); ip != nil { + buf.Write(ip[:]) + } else { + buf.Write(r.IP[:]) + } + + binary.Write(buf, binary.BigEndian, r.Key) + binary.Write(buf, binary.BigEndian, r.NumWant) + binary.Write(buf, binary.BigEndian, r.Port) + + for _, ext := range r.Exts { + ext.EncodeTo(buf) + } +} + +// AnnounceResponse represents the announce response used by UDP tracker. +// +// BEP 15 +type AnnounceResponse struct { + Interval uint32 + Leechers uint32 + Seeders uint32 + Addresses []metainfo.Address +} + +// EncodeTo encodes the response to buf. +func (r AnnounceResponse) EncodeTo(buf *bytes.Buffer) { + buf.Grow(12 + len(r.Addresses)*18) + binary.Write(buf, binary.BigEndian, r.Interval) + binary.Write(buf, binary.BigEndian, r.Leechers) + binary.Write(buf, binary.BigEndian, r.Seeders) + for _, addr := range r.Addresses { + if ip := addr.IP.To4(); ip != nil { + buf.Write(ip[:]) + } else { + buf.Write(addr.IP[:]) + } + binary.Write(buf, binary.BigEndian, addr.Port) + } +} + +// DecodeFrom decodes the response from b. +func (r *AnnounceResponse) DecodeFrom(b []byte, ipv4 bool) { + r.Interval = binary.BigEndian.Uint32(b[:4]) + r.Leechers = binary.BigEndian.Uint32(b[4:8]) + r.Seeders = binary.BigEndian.Uint32(b[8:12]) + + b = b[12:] + iplen := net.IPv6len + if ipv4 { + iplen = net.IPv4len + } + + _len := len(b) + step := iplen + 2 + r.Addresses = make([]metainfo.Address, 0, _len/step) + for i := step; i <= _len; i += step { + ip := make(net.IP, iplen) + copy(ip, b[i-step:i-2]) + port := binary.BigEndian.Uint16(b[i-2 : i]) + r.Addresses = append(r.Addresses, metainfo.Address{IP: ip, Port: port}) + } +} + +// ScrapeResponse represents the UDP SCRAPE response. +// +// BEP 15 +type ScrapeResponse struct { + Seeders uint32 + Leechers uint32 + Completed uint32 +} + +// EncodeTo encodes the response to buf. +func (r ScrapeResponse) EncodeTo(buf *bytes.Buffer) { + binary.Write(buf, binary.BigEndian, r.Seeders) + binary.Write(buf, binary.BigEndian, r.Completed) + binary.Write(buf, binary.BigEndian, r.Leechers) +} + +// DecodeFrom decodes the response from b. +func (r *ScrapeResponse) DecodeFrom(b []byte) { + r.Seeders = binary.BigEndian.Uint32(b[:4]) + r.Completed = binary.BigEndian.Uint32(b[4:8]) + r.Leechers = binary.BigEndian.Uint32(b[8:12]) +} + +// Predefine some UDP extension types. +// +// BEP 41 +const ( + EndOfOptions ExtensionType = iota + Nop + URLData +) + +// NewEndOfOptions returns a new EndOfOptions UDP extension. +func NewEndOfOptions() Extension { return Extension{Type: EndOfOptions} } + +// NewNop returns a new Nop UDP extension. +func NewNop() Extension { return Extension{Type: Nop} } + +// NewURLData returns a new URLData UDP extension. +func NewURLData(data []byte) Extension { + return Extension{Type: URLData, Length: uint8(len(data)), Data: data} +} + +// ExtensionType represents the type of UDP extension. +type ExtensionType uint8 + +func (et ExtensionType) String() string { + switch et { + case EndOfOptions: + return "EndOfOptions" + case Nop: + return "Nop" + case URLData: + return "URLData" + default: + return fmt.Sprintf("ExtensionType(%d)", et) + } +} + +// Extension represent the extension used by the UDP ANNOUNCE request. +// +// BEP 41 +type Extension struct { + Type ExtensionType + Length uint8 + Data []byte +} + +// EncodeTo encodes the response to buf. +func (e Extension) EncodeTo(buf *bytes.Buffer) { + if _len := uint8(len(e.Data)); e.Length == 0 && _len != 0 { + e.Length = _len + } else if _len != e.Length { + panic("the length of data is inconsistent") + } + + buf.WriteByte(byte(e.Type)) + buf.WriteByte(e.Length) + buf.Write(e.Data) +} + +// DecodeFrom decodes the response from b. +func (e *Extension) DecodeFrom(b []byte) (parsed int) { + switch len(b) { + case 0: + case 1: + e.Type = ExtensionType(b[0]) + parsed = 1 + default: + e.Type = ExtensionType(b[0]) + e.Length = b[1] + parsed = 2 + if e.Length > 0 { + parsed += int(e.Length) + e.Data = b[2:parsed] + } + } + return +} diff --git a/tracker/udptracker/udp_client.go b/tracker/udptracker/udp_client.go new file mode 100644 index 0000000..4e24803 --- /dev/null +++ b/tracker/udptracker/udp_client.go @@ -0,0 +1,289 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package udptracker + +import ( + "bytes" + "encoding/binary" + "errors" + "io" + "net" + "strings" + "sync/atomic" + "time" + + "github.com/xgfone/bt/metainfo" +) + +// NewTrackerClientByDial returns a new TrackerClient by dialing. +func NewTrackerClientByDial(network, address string, c ...TrackerClientConfig) ( + *TrackerClient, error) { + conn, err := net.Dial(network, address) + if err != nil { + return nil, err + } + + return NewTrackerClient(conn.(*net.UDPConn), c...), nil +} + +// NewTrackerClient returns a new TrackerClient. +func NewTrackerClient(conn *net.UDPConn, c ...TrackerClientConfig) *TrackerClient { + var conf TrackerClientConfig + conf.set(c...) + ipv4 := strings.Contains(conn.LocalAddr().String(), ".") + return &TrackerClient{conn: conn, conf: conf, ipv4: ipv4} +} + +// TrackerClientConfig is used to configure the TrackerClient. +type TrackerClientConfig struct { + // ReadTimeout is used to receive the response. + ReadTimeout time.Duration // Default: 5s + MaxBufSize int // Default: 2048 +} + +func (c *TrackerClientConfig) set(conf ...TrackerClientConfig) { + if len(conf) > 0 { + *c = conf[0] + } + + if c.MaxBufSize <= 0 { + c.MaxBufSize = 2048 + } + if c.ReadTimeout <= 0 { + c.ReadTimeout = time.Second * 5 + } +} + +// TrackerClient is a tracker client based on UDP. +// +// Notice: the request is synchronized, that's, the last request is not returned, +// the next request must not be sent. +// +// BEP 15 +type TrackerClient struct { + ipv4 bool + conf TrackerClientConfig + conn *net.UDPConn + last time.Time + cid uint64 + tid uint32 +} + +// Close closes the UDP tracker client. +func (utc *TrackerClient) Close() { utc.conn.Close() } + +func (utc *TrackerClient) readResp(b []byte) (int, error) { + utc.conn.SetReadDeadline(time.Now().Add(utc.conf.ReadTimeout)) + return utc.conn.Read(b) +} + +func (utc *TrackerClient) getTranID() uint32 { + return atomic.AddUint32(&utc.tid, 1) +} + +func (utc *TrackerClient) parseError(b []byte) (tid uint32, reason string) { + tid = binary.BigEndian.Uint32(b[:4]) + reason = string(b[4:]) + return +} + +func (utc *TrackerClient) send(b []byte) (err error) { + n, err := utc.conn.Write(b) + if err == nil && n < len(b) { + err = io.ErrShortWrite + } + return +} + +func (utc *TrackerClient) connect() (err error) { + tid := utc.getTranID() + buf := bytes.NewBuffer(make([]byte, 0, 16)) + binary.Write(buf, binary.BigEndian, ProtocolID) + binary.Write(buf, binary.BigEndian, ActionConnect) + binary.Write(buf, binary.BigEndian, tid) + if err = utc.send(buf.Bytes()); err != nil { + return + } + + data := make([]byte, 32) + n, err := utc.readResp(data) + if err != nil { + return + } else if n < 8 { + return io.ErrShortBuffer + } + + data = data[:n] + switch binary.BigEndian.Uint32(data[:4]) { + case ActionConnect: + case ActionError: + _, reason := utc.parseError(data[4:]) + return errors.New(reason) + default: + return errors.New("tracker response not connect action") + } + + if n < 16 { + return io.ErrShortBuffer + } + + if binary.BigEndian.Uint32(data[4:8]) != tid { + return errors.New("invalid transaction id") + } + + utc.cid = binary.BigEndian.Uint64(data[8:16]) + utc.last = time.Now() + return +} + +func (utc *TrackerClient) getConnectionID() (cid uint64, err error) { + cid = utc.cid + if time.Now().Sub(utc.last) > time.Minute { + if err = utc.connect(); err == nil { + cid = utc.cid + } + } + return +} + +func (utc *TrackerClient) announce(req AnnounceRequest) (r AnnounceResponse, err error) { + cid, err := utc.getConnectionID() + if err != nil { + return + } + + tid := utc.getTranID() + buf := bytes.NewBuffer(make([]byte, 0, 110)) + binary.Write(buf, binary.BigEndian, cid) + binary.Write(buf, binary.BigEndian, ActionAnnounce) + binary.Write(buf, binary.BigEndian, tid) + req.EncodeTo(buf) + b := buf.Bytes() + if err = utc.send(b); err != nil { + return + } + + data := make([]byte, utc.conf.MaxBufSize) + n, err := utc.readResp(data) + if err != nil { + return + } else if n < 8 { + err = io.ErrShortBuffer + return + } + + data = data[:n] + switch binary.BigEndian.Uint32(data[:4]) { + case ActionAnnounce: + case ActionError: + _, reason := utc.parseError(data[4:]) + err = errors.New(reason) + return + default: + err = errors.New("tracker response not connect action") + return + } + + if n < 16 { + err = io.ErrShortBuffer + return + } + + if binary.BigEndian.Uint32(data[4:8]) != tid { + err = errors.New("invalid transaction id") + return + } + + r.DecodeFrom(data[8:], utc.ipv4) + return +} + +// Announce sends a Announce request to the tracker. +// +// Notice: +// 1. if it does not connect to the UDP tracker server, it will connect to it, +// then send the ANNOUNCE request. +// 2. If returning an error, you should retry it. +// See http://www.bittorrent.org/beps/bep_0015.html#time-outs +func (utc *TrackerClient) Announce(r AnnounceRequest) (AnnounceResponse, error) { + return utc.announce(r) +} + +func (utc *TrackerClient) scrape(infohashes []metainfo.Hash) (rs []ScrapeResponse, err error) { + cid, err := utc.getConnectionID() + if err != nil { + return + } + + tid := utc.getTranID() + buf := bytes.NewBuffer(make([]byte, 0, 16+len(infohashes)*20)) + binary.Write(buf, binary.BigEndian, cid) + binary.Write(buf, binary.BigEndian, ActionScrape) + binary.Write(buf, binary.BigEndian, tid) + for _, h := range infohashes { + buf.Write(h[:]) + } + if err = utc.send(buf.Bytes()); err != nil { + return + } + + data := make([]byte, utc.conf.MaxBufSize) + n, err := utc.readResp(data) + if err != nil { + return + } else if n < 8 { + err = io.ErrShortBuffer + return + } + + data = data[:n] + switch binary.BigEndian.Uint32(data[:4]) { + case ActionScrape: + case ActionError: + _, reason := utc.parseError(data[4:]) + err = errors.New(reason) + return + default: + err = errors.New("tracker response not connect action") + return + } + + if binary.BigEndian.Uint32(data[4:8]) != tid { + err = errors.New("invalid transaction id") + return + } + + data = data[8:] + _len := len(data) + rs = make([]ScrapeResponse, 0, _len/12) + for i := 12; i <= _len; i += 12 { + var r ScrapeResponse + r.DecodeFrom(data[i-12 : i]) + rs = append(rs, r) + } + + return +} + +// Scrape sends a Scrape request to the tracker. +// +// Notice: +// 1. if it does not connect to the UDP tracker server, it will connect to it, +// then send the ANNOUNCE request. +// 2. If returning an error, you should retry it. +// See http://www.bittorrent.org/beps/bep_0015.html#time-outs +func (utc *TrackerClient) Scrape(hs []metainfo.Hash) ([]ScrapeResponse, error) { + return utc.scrape(hs) +} diff --git a/tracker/udptracker/udp_server.go b/tracker/udptracker/udp_server.go new file mode 100644 index 0000000..e87e8d5 --- /dev/null +++ b/tracker/udptracker/udp_server.go @@ -0,0 +1,281 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package udptracker + +import ( + "bytes" + "encoding/binary" + "log" + "net" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/xgfone/bt/metainfo" +) + +// TrackerServerHandler is used to handle the request from the client. +type TrackerServerHandler interface { + // OnConnect is used to check whether to make the connection or not. + OnConnect(raddr *net.UDPAddr) (err error) + OnAnnounce(raddr *net.UDPAddr, req AnnounceRequest) (AnnounceResponse, error) + OnScrap(raddr *net.UDPAddr, infohashes []metainfo.Hash) ([]ScrapeResponse, error) +} + +func encodeResponseHeader(buf *bytes.Buffer, action, tid uint32) { + binary.Write(buf, binary.BigEndian, action) + binary.Write(buf, binary.BigEndian, tid) +} + +type wrappedPeerAddr struct { + Addr *net.UDPAddr + Time time.Time +} + +// TrackerServerConfig is used to configure the TrackerServer. +type TrackerServerConfig struct { + MaxBufSize int // Default: 2048 + ErrorLog func(format string, args ...interface{}) // Default: log.Printf +} + +func (c *TrackerServerConfig) setDefault() { + if c.MaxBufSize <= 0 { + c.MaxBufSize = 2048 + } + if c.ErrorLog == nil { + c.ErrorLog = log.Printf + } +} + +// TrackerServer is a tracker server based on UDP. +type TrackerServer struct { + conn net.PacketConn + conf TrackerServerConfig + handler TrackerServerHandler + bufpool sync.Pool + + cid uint64 + exit chan struct{} + lock sync.RWMutex + conns map[uint64]wrappedPeerAddr +} + +// NewTrackerServer returns a new TrackerServer. +func NewTrackerServer(c net.PacketConn, h TrackerServerHandler, + config ...TrackerServerConfig) *TrackerServer { + var conf TrackerServerConfig + if len(config) > 0 { + conf = config[0] + } + conf.setDefault() + + s := &TrackerServer{ + conf: conf, + conn: c, + handler: h, + exit: make(chan struct{}), + conns: make(map[uint64]wrappedPeerAddr, 128), + } + s.bufpool.New = func() interface{} { return make([]byte, conf.MaxBufSize) } + + return s +} + +// Close closes the tracker server. +func (uts *TrackerServer) Close() { + select { + case <-uts.exit: + default: + close(uts.exit) + uts.conn.Close() + } +} + +func (uts *TrackerServer) cleanConnectionID(interval time.Duration) { + tick := time.NewTicker(interval) + defer tick.Stop() + for { + select { + case <-uts.exit: + return + case now := <-tick.C: + uts.lock.RLock() + for cid, wa := range uts.conns { + if now.Sub(wa.Time) > interval { + delete(uts.conns, cid) + } + } + uts.lock.RUnlock() + } + } +} + +// Run starts the tracker server. +func (uts *TrackerServer) Run() { + go uts.cleanConnectionID(time.Minute * 2) + for { + buf := uts.bufpool.Get().([]byte) + n, raddr, err := uts.conn.ReadFrom(buf) + if err != nil { + if !strings.Contains(err.Error(), "closed") { + uts.conf.ErrorLog("failed to read udp tracker request: %s", err) + } + return + } else if n < 16 { + continue + } + go uts.handleRequest(raddr.(*net.UDPAddr), buf, n) + } +} + +func (uts *TrackerServer) handleRequest(raddr *net.UDPAddr, buf []byte, n int) { + defer uts.bufpool.Put(buf) + uts.handlePacket(raddr, buf[:n]) +} + +func (uts *TrackerServer) send(raddr *net.UDPAddr, b []byte) { + n, err := uts.conn.WriteTo(b, raddr) + if err != nil { + uts.conf.ErrorLog("fail to send the udp tracker response to '%s': %s", + raddr.String(), err) + } else if n < len(b) { + uts.conf.ErrorLog("too short udp tracker response sent to '%s'", raddr.String()) + } +} + +func (uts *TrackerServer) getConnectionID() uint64 { + return atomic.AddUint64(&uts.cid, 1) +} + +func (uts *TrackerServer) addConnection(cid uint64, raddr *net.UDPAddr) { + now := time.Now() + uts.lock.Lock() + uts.conns[cid] = wrappedPeerAddr{Addr: raddr, Time: now} + uts.lock.Unlock() +} + +func (uts *TrackerServer) checkConnection(cid uint64, raddr *net.UDPAddr) (ok bool) { + uts.lock.RLock() + if w, _ok := uts.conns[cid]; _ok && w.Addr.Port == raddr.Port && + bytes.Equal(w.Addr.IP, raddr.IP) { + ok = true + } + uts.lock.RUnlock() + return +} + +func (uts *TrackerServer) sendError(raddr *net.UDPAddr, tid uint32, reason string) { + buf := bytes.NewBuffer(make([]byte, 0, 8+len(reason))) + encodeResponseHeader(buf, ActionError, tid) + buf.WriteString(reason) + uts.send(raddr, buf.Bytes()) +} + +func (uts *TrackerServer) sendConnResp(raddr *net.UDPAddr, tid uint32, cid uint64) { + buf := bytes.NewBuffer(make([]byte, 0, 16)) + encodeResponseHeader(buf, ActionConnect, tid) + binary.Write(buf, binary.BigEndian, cid) + uts.send(raddr, buf.Bytes()) +} + +func (uts *TrackerServer) sendAnnounceResp(raddr *net.UDPAddr, tid uint32, + resp AnnounceResponse) { + buf := bytes.NewBuffer(make([]byte, 0, 8+12+len(resp.Addresses)*18)) + encodeResponseHeader(buf, ActionAnnounce, tid) + resp.EncodeTo(buf) + uts.send(raddr, buf.Bytes()) +} + +func (uts *TrackerServer) sendScrapResp(raddr *net.UDPAddr, tid uint32, + rs []ScrapeResponse) { + buf := bytes.NewBuffer(make([]byte, 0, 8+len(rs)*12)) + encodeResponseHeader(buf, ActionScrape, tid) + for _, r := range rs { + r.EncodeTo(buf) + } + uts.send(raddr, buf.Bytes()) +} + +func (uts *TrackerServer) handlePacket(raddr *net.UDPAddr, b []byte) { + cid := binary.BigEndian.Uint64(b[:8]) + action := binary.BigEndian.Uint32(b[8:12]) + tid := binary.BigEndian.Uint32(b[12:16]) + b = b[16:] + + // Handle the connection request. + if cid == ProtocolID && action == ActionConnect { + if err := uts.handler.OnConnect(raddr); err != nil { + uts.sendError(raddr, tid, err.Error()) + return + } + + cid := uts.getConnectionID() + uts.addConnection(cid, raddr) + uts.sendConnResp(raddr, tid, cid) + return + } + + // Check whether the request is connected. + if !uts.checkConnection(cid, raddr) { + uts.sendError(raddr, tid, "connection is expired") + return + } + + switch action { + case ActionAnnounce: + var req AnnounceRequest + if raddr.IP.To4() != nil { // For ipv4 + if len(b) < 82 { + uts.sendError(raddr, tid, "invalid announce request") + return + } + req.DecodeFrom(b, true) + } else { // For ipv6 + if len(b) < 94 { + uts.sendError(raddr, tid, "invalid announce request") + return + } + req.DecodeFrom(b, false) + } + + resp, err := uts.handler.OnAnnounce(raddr, req) + if err != nil { + uts.sendError(raddr, tid, err.Error()) + } else { + uts.sendAnnounceResp(raddr, tid, resp) + } + case ActionScrape: + _len := len(b) + infohashes := make([]metainfo.Hash, 0, _len/20) + for i, _len := 20, len(b); i <= _len; i += 20 { + infohashes = append(infohashes, metainfo.NewHash(b[i-20:i])) + } + + if len(infohashes) == 0 { + uts.sendError(raddr, tid, "no infohash") + return + } + + resps, err := uts.handler.OnScrap(raddr, infohashes) + if err != nil { + uts.sendError(raddr, tid, err.Error()) + } else { + uts.sendScrapResp(raddr, tid, resps) + } + default: + uts.sendError(raddr, tid, "unkwnown action") + } +} diff --git a/tracker/udptracker/udp_test.go b/tracker/udptracker/udp_test.go new file mode 100644 index 0000000..e109417 --- /dev/null +++ b/tracker/udptracker/udp_test.go @@ -0,0 +1,135 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package udptracker + +import ( + "errors" + "fmt" + "log" + "net" + "time" + + "github.com/xgfone/bt/metainfo" +) + +type testHandler struct{} + +func (testHandler) OnConnect(raddr *net.UDPAddr) (err error) { return } +func (testHandler) OnAnnounce(raddr *net.UDPAddr, req AnnounceRequest) ( + r AnnounceResponse, err error) { + if req.Port != 80 { + err = errors.New("port is not 80") + return + } + + if len(req.Exts) > 0 { + for i, ext := range req.Exts { + switch ext.Type { + case URLData: + fmt.Printf("Extensions[%d]: URLData(%s)\n", i, string(ext.Data)) + default: + fmt.Printf("Extensions[%d]: %s\n", i, ext.Type.String()) + } + } + } + + r = AnnounceResponse{ + Interval: 1, + Leechers: 2, + Seeders: 3, + Addresses: []metainfo.Address{{IP: net.ParseIP("127.0.0.1"), Port: 8001}}, + } + return +} +func (testHandler) OnScrap(raddr *net.UDPAddr, infohashes []metainfo.Hash) ( + rs []ScrapeResponse, err error) { + rs = make([]ScrapeResponse, len(infohashes)) + for i := range infohashes { + rs[i] = ScrapeResponse{ + Seeders: uint32(i)*10 + 1, + Leechers: uint32(i)*10 + 2, + Completed: uint32(i)*10 + 3, + } + } + return +} + +func ExampleTrackerClient() { + // Start the UDP tracker server + sconn, err := net.ListenPacket("udp4", "127.0.0.1:8001") + if err != nil { + log.Fatal(err) + } + server := NewTrackerServer(sconn, testHandler{}) + defer server.Close() + go server.Run() + + // Wait for the server to be started + time.Sleep(time.Second) + + // Create a client and dial to the UDP tracker server. + client, err := NewTrackerClientByDial("udp4", "127.0.0.1:8001") + if err != nil { + log.Fatal(err) + } + + // Send the ANNOUNCE request to the UDP tracker server, + // and get the ANNOUNCE response. + exts := []Extension{NewURLData([]byte("data")), NewNop()} + req := AnnounceRequest{IP: net.ParseIP("127.0.0.1"), Port: 80, Exts: exts} + resp, err := client.Announce(req) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("Interval: %d\n", resp.Interval) + fmt.Printf("Leechers: %d\n", resp.Leechers) + fmt.Printf("Seeders: %d\n", resp.Seeders) + for i, addr := range resp.Addresses { + fmt.Printf("Address[%d].IP: %s\n", i, addr.IP.String()) + fmt.Printf("Address[%d].Port: %d\n", i, addr.Port) + } + + // Send the SCRAPE request to the UDP tracker server, + // and get the SCRAPE respsone. + hs := []metainfo.Hash{metainfo.NewRandomHash(), metainfo.NewRandomHash()} + rs, err := client.Scrape(hs) + if err != nil { + log.Fatal(err) + } else if len(rs) != 2 { + log.Fatalf("%+v", rs) + } + + for i, r := range rs { + fmt.Printf("%d.Seeders: %d\n", i, r.Seeders) + fmt.Printf("%d.Leechers: %d\n", i, r.Leechers) + fmt.Printf("%d.Completed: %d\n", i, r.Completed) + } + + // Output: + // Extensions[0]: URLData(data) + // Extensions[1]: Nop + // Interval: 1 + // Leechers: 2 + // Seeders: 3 + // Address[0].IP: 127.0.0.1 + // Address[0].Port: 8001 + // 0.Seeders: 1 + // 0.Leechers: 2 + // 0.Completed: 3 + // 1.Seeders: 11 + // 1.Leechers: 12 + // 1.Completed: 13 +} diff --git a/utils/bool.go b/utils/bool.go new file mode 100644 index 0000000..2bbcdd2 --- /dev/null +++ b/utils/bool.go @@ -0,0 +1,50 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "sync/atomic" +) + +// Bool is used to implement a atomic bool. +type Bool struct { + v uint32 +} + +// NewBool returns a new Bool with the initialized value. +func NewBool(t bool) (b Bool) { + if t { + b.v = 1 + } + return +} + +// Get returns the bool value. +func (b *Bool) Get() bool { return atomic.LoadUint32(&b.v) == 1 } + +// SetTrue sets the bool value to true. +func (b *Bool) SetTrue() { atomic.StoreUint32(&b.v, 1) } + +// SetFalse sets the bool value to false. +func (b *Bool) SetFalse() { atomic.StoreUint32(&b.v, 0) } + +// Set sets the bool value to t. +func (b *Bool) Set(t bool) { + if t { + b.SetTrue() + } else { + b.SetFalse() + } +} diff --git a/utils/doc.go b/utils/doc.go new file mode 100644 index 0000000..6103847 --- /dev/null +++ b/utils/doc.go @@ -0,0 +1,16 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package utils supplies some convenient functions. +package utils diff --git a/utils/io.go b/utils/io.go new file mode 100644 index 0000000..f91b063 --- /dev/null +++ b/utils/io.go @@ -0,0 +1,30 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import "io" + +// CopyNBuffer is the same as io.CopyN, but uses the given buf as the buffer. +func CopyNBuffer(dst io.Writer, src io.Reader, n int64, buf []byte) (written int64, err error) { + written, err = io.CopyBuffer(dst, io.LimitReader(src, n), buf) + if written == n { + return n, nil + } else if written < n && err == nil { + // src stopped early; must have been EOF. + err = io.EOF + } + + return +} diff --git a/utils/slice.go b/utils/slice.go new file mode 100644 index 0000000..53ce3a6 --- /dev/null +++ b/utils/slice.go @@ -0,0 +1,26 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +// InStringSlice reports whether s is in ss. +func InStringSlice(ss []string, s string) bool { + for _, v := range ss { + if v == s { + return true + } + } + + return false +} diff --git a/utils/slice_test.go b/utils/slice_test.go new file mode 100644 index 0000000..d67327f --- /dev/null +++ b/utils/slice_test.go @@ -0,0 +1,29 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "testing" +) + +func TestInStringSlice(t *testing.T) { + if !InStringSlice([]string{"a", "b"}, "a") { + t.Fail() + } + + if InStringSlice([]string{"a", "b"}, "z") { + t.Fail() + } +} diff --git a/utils/string.go b/utils/string.go new file mode 100644 index 0000000..19865fa --- /dev/null +++ b/utils/string.go @@ -0,0 +1,24 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import "crypto/rand" + +// RandomString generates a size-length string randomly. +func RandomString(size int) string { + bs := make([]byte, size) + rand.Read(bs) + return string(bs) +}