Merge branch 'main' into 'master'

Main

See merge request idk/bt!2
This commit is contained in:
idk
2023-05-14 20:07:13 +00:00
59 changed files with 1997 additions and 1691 deletions

View File

@ -4,7 +4,7 @@ env:
GO111MODULE: on GO111MODULE: on
jobs: jobs:
build: build:
runs-on: ubuntu-18.04 runs-on: ubuntu-22.04
name: Go ${{ matrix.go }} name: Go ${{ matrix.go }}
strategy: strategy:
matrix: matrix:
@ -17,10 +17,12 @@ jobs:
- '1.16' - '1.16'
- '1.17' - '1.17'
- '1.18' - '1.18'
- '1.19'
- '1.20'
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v3
- name: Setup Go - name: Setup Go
uses: actions/setup-go@v2 uses: actions/setup-go@v3
with: with:
go-version: ${{ matrix.go }} go-version: ${{ matrix.go }}
- run: go test -race ./... - run: go test -race -cover ./...

11
.gitignore vendored
View File

@ -33,11 +33,16 @@ test_*
# log # log
*.log *.log
# Mac vendor/
.DS_Store
# VS Code # VS Code
.vscode/ .vscode/
debug
debug_test
# Unix hidden files # Mac
.DS_Store
# hidden files
.* .*
_*

228
README.md
View File

@ -1,24 +1,23 @@
# BT - Another Implementation For Golang [![Build Status](https://github.com/xgfone/bt/actions/workflows/go.yml/badge.svg)](https://github.com/xgfone/bt/actions/workflows/go.yml) [![GoDoc](https://pkg.go.dev/badge/github.com/xgfone/bt)](https://pkg.go.dev/github.com/xgfone/bt) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg?style=flat-square)](https://raw.githubusercontent.com/xgfone/bt/master/LICENSE) # BT - Another Implementation For Golang [![Build Status](https://github.com/xgfone/go-bt/actions/workflows/go.yml/badge.svg)](https://github.com/xgfone/go-bt/actions/workflows/go.yml) [![GoDoc](https://pkg.go.dev/badge/github.com/xgfone/go-bt)](https://pkg.go.dev/github.com/xgfone/go-bt) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg?style=flat-square)](https://raw.githubusercontent.com/xgfone/go-bt/master/LICENSE)
A pure golang implementation of [BitTorrent](http://bittorrent.org/beps/bep_0000.html) library, which is inspired by [dht](https://github.com/shiyanhui/dht) and [torrent](https://github.com/anacrolix/torrent). A pure golang implementation of [BitTorrent](http://bittorrent.org/beps/bep_0000.html) library, which is inspired by [dht](https://github.com/shiyanhui/dht) and [torrent](https://github.com/anacrolix/torrent).
## Install ## Install
```shell
$ go get -u github.com/xgfone/bt
```
```shell
$ go get -u github.com/xgfone/go-bt
```
## Features ## Features
- Support `Go1.9+`. - Support `Go1.11+`.
- Support IPv4/IPv6. - Support IPv4/IPv6.
- Multi-BEPs implementation. - Multi-BEPs implementation.
- Pure Go implementation without `CGO`. - Pure Go implementation without `CGO`.
- Only library without any denpendencies. For the command tools, see [bttools](https://github.com/xgfone/bttools). - Only library without any denpendencies. For the command tools, see [bttools](https://github.com/xgfone/bttools).
## The Implemented Specifications ## The Implemented Specifications
- [x] [**BEP 03:** The BitTorrent Protocol Specification](http://bittorrent.org/beps/bep_0003.html) - [x] [**BEP 03:** The BitTorrent Protocol Specification](http://bittorrent.org/beps/bep_0003.html)
- [x] [**BEP 05:** DHT Protocol](http://bittorrent.org/beps/bep_0005.html) - [x] [**BEP 05:** DHT Protocol](http://bittorrent.org/beps/bep_0005.html)
- [x] [**BEP 06:** Fast Extension](http://bittorrent.org/beps/bep_0006.html) - [x] [**BEP 06:** Fast Extension](http://bittorrent.org/beps/bep_0006.html)
@ -37,219 +36,6 @@ $ go get -u github.com/xgfone/bt
- [ ] [**BEP 44:** Storing arbitrary data in the DHT](http://bittorrent.org/beps/bep_0044.html) - [ ] [**BEP 44:** Storing arbitrary data in the DHT](http://bittorrent.org/beps/bep_0044.html)
- [x] [**BEP 48:** Tracker Protocol Extension: Scrape](http://bittorrent.org/beps/bep_0048.html) - [x] [**BEP 48:** Tracker Protocol Extension: Scrape](http://bittorrent.org/beps/bep_0048.html)
## Example ## Example
See [godoc](https://pkg.go.dev/github.com/xgfone/bt) or [bttools](https://github.com/xgfone/bttools).
### Example 1: Download the file from the remote peer See [godoc](https://pkg.go.dev/github.com/xgfone/go-bt) or [bttools](https://github.com/xgfone/bttools).
```go
package main
import (
"context"
"flag"
"fmt"
"io"
"log"
"net/url"
"os"
"time"
"github.com/xgfone/bt/downloader"
"github.com/xgfone/bt/metainfo"
pp "github.com/xgfone/bt/peerprotocol"
"github.com/xgfone/bt/tracker"
)
var peeraddr string
func init() {
flag.StringVar(&peeraddr, "peeraddr", "", "The address of the peer storing the file.")
}
func getPeersFromTrackers(id, infohash metainfo.Hash, trackers []string) (peers []string) {
c, cancel := context.WithTimeout(context.Background(), time.Second*3)
defer cancel()
resp := tracker.GetPeers(c, id, infohash, trackers)
for r := range resp {
for _, addr := range r.Resp.Addresses {
addrs := addr.String()
nonexist := true
for _, peer := range peers {
if peer == addrs {
nonexist = false
break
}
}
if nonexist {
peers = append(peers, addrs)
}
}
}
return
}
func main() {
flag.Parse()
torrentfile := os.Args[1]
mi, err := metainfo.LoadFromFile(torrentfile)
if err != nil {
fmt.Println(err)
os.Exit(1)
}
id := metainfo.NewRandomHash()
infohash := mi.InfoHash()
info, err := mi.Info()
if err != nil {
fmt.Println(err)
os.Exit(1)
}
var peers []string
if peeraddr != "" {
peers = []string{peeraddr}
} else {
// Get the peers from the trackers in the torrent file.
trackers := mi.Announces().Unique()
if len(trackers) == 0 {
fmt.Println("no trackers")
return
}
peers = getPeersFromTrackers(id, infohash, trackers)
if len(peers) == 0 {
fmt.Println("no peers")
return
}
}
// We save the downloaded file to the current directory.
w := metainfo.NewWriter("", info, 0)
defer w.Close()
// We don't request the blocks from the remote peers concurrently,
// and it is only an example. But you can do it concurrently.
dm := newDownloadManager(w, info)
for peerslen := len(peers); peerslen > 0 && !dm.IsFinished(); {
peerslen--
peer := peers[peerslen]
peers = peers[:peerslen]
downloadFileFromPeer(peer, id, infohash, dm)
}
}
func downloadFileFromPeer(peer string, id, infohash metainfo.Hash, dm *downloadManager) {
pc, err := pp.NewPeerConnByDial(peer, id, infohash, time.Second*3)
if err != nil {
log.Printf("fail to dial '%s'", peer)
return
}
defer pc.Close()
dm.doing = false
pc.Timeout = time.Second * 10
if err = pc.Handshake(); err != nil {
log.Printf("fail to handshake with '%s': %s", peer, err)
return
}
info := dm.writer.Info()
bdh := downloader.NewBlockDownloadHandler(info, dm.OnBlock, dm.RequestBlock)
if err = bdh.OnHandShake(pc); err != nil {
log.Printf("handshake error with '%s': %s", peer, err)
return
}
var msg pp.Message
for !dm.IsFinished() {
switch msg, err = pc.ReadMsg(); err {
case nil:
switch err = pc.HandleMessage(msg, bdh); err {
case nil, pp.ErrChoked:
default:
log.Printf("fail to handle the msg from '%s': %s", peer, err)
return
}
case io.EOF:
log.Printf("got EOF from '%s'", peer)
return
default:
log.Printf("fail to read the msg from '%s': %s", peer, err)
return
}
}
}
func newDownloadManager(w metainfo.Writer, info metainfo.Info) *downloadManager {
length := info.Piece(0).Length()
return &downloadManager{writer: w, plength: length}
}
type downloadManager struct {
writer metainfo.Writer
pindex uint32
poffset uint32
plength int64
doing bool
}
func (dm *downloadManager) IsFinished() bool {
if dm.pindex >= uint32(dm.writer.Info().CountPieces()) {
return true
}
return false
}
func (dm *downloadManager) OnBlock(index, offset uint32, b []byte) (err error) {
if dm.pindex != index {
return fmt.Errorf("inconsistent piece: old=%d, new=%d", dm.pindex, index)
} else if dm.poffset != offset {
return fmt.Errorf("inconsistent offset for piece '%d': old=%d, new=%d",
index, dm.poffset, offset)
}
dm.doing = false
n, err := dm.writer.WriteBlock(index, offset, b)
if err == nil {
dm.poffset = offset + uint32(n)
dm.plength -= int64(n)
}
return
}
func (dm *downloadManager) RequestBlock(pc *pp.PeerConn) (err error) {
if dm.doing {
return
}
if dm.plength <= 0 {
dm.pindex++
if dm.IsFinished() {
return
}
dm.poffset = 0
dm.plength = dm.writer.Info().Piece(int(dm.pindex)).Length()
}
index := dm.pindex
begin := dm.poffset
length := uint32(downloader.BlockSize)
if length > uint32(dm.plength) {
length = uint32(dm.plength)
}
log.Printf("Request Block from '%s': index=%d, offset=%d, length=%d",
pc.RemoteAddr().String(), index, begin, length)
if err = pc.SendRequest(index, begin, length); err == nil {
dm.doing = true
}
return
}
```

View File

@ -17,6 +17,8 @@ package dht
import ( import (
"sync" "sync"
"time" "time"
"github.com/xgfone/go-bt/krpc"
) )
// Blacklist is used to manage the ip blacklist. // Blacklist is used to manage the ip blacklist.
@ -24,25 +26,24 @@ import (
// Notice: The implementation should clear the address existed for long time. // Notice: The implementation should clear the address existed for long time.
type Blacklist interface { type Blacklist interface {
// In reports whether the address, ip and port, is in the blacklist. // In reports whether the address, ip and port, is in the blacklist.
In(ip string, port int) bool In(krpc.Addr) bool
// If port is equal to 0, it should ignore port and only use ip when matching. // If port is equal to 0, it should ignore port and only use ip when matching.
Add(ip string, port int) Add(krpc.Addr)
// If port is equal to 0, it should delete the address by only the ip. // If port is equal to 0, it should delete the address by only the ip.
Del(ip string, port int) Del(krpc.Addr)
// Close is used to notice the implementation to release the underlying // Close is used to notice the implementation to release the underlying resource.
// resource.
Close() Close()
} }
type noopBlacklist struct{} type noopBlacklist struct{}
func (nbl noopBlacklist) In(ip string, port int) bool { return false } func (nbl noopBlacklist) In(krpc.Addr) bool { return false }
func (nbl noopBlacklist) Add(ip string, port int) {} func (nbl noopBlacklist) Add(krpc.Addr) {}
func (nbl noopBlacklist) Del(ip string, port int) {} func (nbl noopBlacklist) Del(krpc.Addr) {}
func (nbl noopBlacklist) Close() {} func (nbl noopBlacklist) Close() {}
// NewNoopBlacklist returns a no-op Blacklist. // NewNoopBlacklist returns a no-op Blacklist.
func NewNoopBlacklist() Blacklist { return noopBlacklist{} } func NewNoopBlacklist() Blacklist { return noopBlacklist{} }
@ -57,14 +58,14 @@ type logBlacklist struct {
logf func(string, ...interface{}) logf func(string, ...interface{})
} }
func (dbl logBlacklist) Add(ip string, port int) { func (l logBlacklist) Add(addr krpc.Addr) {
dbl.logf("add the blacklist: ip=%s, port=%d", ip, port) l.logf("add the addr '%s' into the blacklist", addr.String())
dbl.Blacklist.Add(ip, port) l.Blacklist.Add(addr)
} }
func (dbl logBlacklist) Del(ip string, port int) { func (l logBlacklist) Del(addr krpc.Addr) {
dbl.logf("delete the blacklist: ip=%s, port=%d", ip, port) l.logf("delete the addr '%s' from the blacklist", addr.String())
dbl.Blacklist.Del(ip, port) l.Blacklist.Del(addr)
} }
/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> /// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
@ -85,7 +86,7 @@ func NewMemoryBlacklist(maxnum int, duration time.Duration) Blacklist {
type wrappedPort struct { type wrappedPort struct {
Time time.Time Time time.Time
Enable bool Enable bool
Ports map[int]struct{} Ports map[uint16]struct{}
} }
type blacklist struct { type blacklist struct {
@ -123,11 +124,11 @@ func (bl *blacklist) Close() {
} }
// In reports whether the address, ip and port, is in the blacklist. // In reports whether the address, ip and port, is in the blacklist.
func (bl *blacklist) In(ip string, port int) (yes bool) { func (bl *blacklist) In(addr krpc.Addr) (yes bool) {
bl.lock.RLock() bl.lock.RLock()
if wp, ok := bl.ips[ip]; ok { if wp, ok := bl.ips[addr.IP.String()]; ok {
if wp.Enable { if wp.Enable {
_, yes = wp.Ports[port] _, yes = wp.Ports[addr.Port]
} else { } else {
yes = true yes = true
} }
@ -136,7 +137,8 @@ func (bl *blacklist) In(ip string, port int) (yes bool) {
return return
} }
func (bl *blacklist) Add(ip string, port int) { func (bl *blacklist) Add(addr krpc.Addr) {
ip := addr.IP.String()
bl.lock.Lock() bl.lock.Lock()
wp, ok := bl.ips[ip] wp, ok := bl.ips[ip]
if !ok { if !ok {
@ -149,30 +151,31 @@ func (bl *blacklist) Add(ip string, port int) {
bl.ips[ip] = wp bl.ips[ip] = wp
} }
if port < 1 { if addr.Port < 1 {
wp.Enable = false wp.Enable = false
wp.Ports = nil wp.Ports = nil
} else if wp.Ports == nil { } else if wp.Ports == nil {
wp.Ports = map[int]struct{}{port: {}} wp.Ports = map[uint16]struct{}{addr.Port: {}}
} else { } else {
wp.Ports[port] = struct{}{} wp.Ports[addr.Port] = struct{}{}
} }
wp.Time = time.Now() wp.Time = time.Now()
bl.lock.Unlock() bl.lock.Unlock()
} }
func (bl *blacklist) Del(ip string, port int) { func (bl *blacklist) Del(addr krpc.Addr) {
ip := addr.IP.String()
bl.lock.Lock() bl.lock.Lock()
if wp, ok := bl.ips[ip]; ok { if wp, ok := bl.ips[ip]; ok {
if port < 1 { if addr.Port < 1 {
delete(bl.ips, ip) delete(bl.ips, ip)
} else if wp.Enable { } else if wp.Enable {
switch len(wp.Ports) { switch len(wp.Ports) {
case 0, 1: case 0, 1:
delete(bl.ips, ip) delete(bl.ips, ip)
default: default:
delete(wp.Ports, port) delete(wp.Ports, addr.Port)
wp.Time = time.Now() wp.Time = time.Now()
} }
} }

View File

@ -15,8 +15,11 @@
package dht package dht
import ( import (
"net"
"testing" "testing"
"time" "time"
"github.com/xgfone/go-bt/krpc"
) )
func (bl *blacklist) portsLen() (n int) { func (bl *blacklist) portsLen() (n int) {
@ -42,12 +45,12 @@ func TestMemoryBlacklist(t *testing.T) {
bl := NewMemoryBlacklist(3, time.Second).(*blacklist) bl := NewMemoryBlacklist(3, time.Second).(*blacklist)
defer bl.Close() defer bl.Close()
bl.Add("1.1.1.1", 123) bl.Add(krpc.NewAddr(net.ParseIP("1.1.1.1"), 123))
bl.Add("1.1.1.1", 456) bl.Add(krpc.NewAddr(net.ParseIP("1.1.1.1"), 456))
bl.Add("1.1.1.1", 789) bl.Add(krpc.NewAddr(net.ParseIP("1.1.1.1"), 789))
bl.Add("2.2.2.2", 111) bl.Add(krpc.NewAddr(net.ParseIP("2.2.2.2"), 111))
bl.Add("3.3.3.3", 0) bl.Add(krpc.NewAddr(net.ParseIP("3.3.3.3"), 0))
bl.Add("4.4.4.4", 222) bl.Add(krpc.NewAddr(net.ParseIP("4.4.4.4"), 222))
ips := bl.getIPs() ips := bl.getIPs()
if len(ips) != 3 { if len(ips) != 3 {
@ -66,15 +69,15 @@ func TestMemoryBlacklist(t *testing.T) {
t.Errorf("expect port num 4, but got %d", n) t.Errorf("expect port num 4, but got %d", n)
} }
if bl.In("1.1.1.1", 111) || !bl.In("1.1.1.1", 123) { if bl.In(krpc.NewAddr(net.ParseIP("1.1.1.1"), 111)) || !bl.In(krpc.NewAddr(net.ParseIP("1.1.1.1"), 123)) {
t.Fail() t.Fail()
} }
if !bl.In("3.3.3.3", 111) || bl.In("4.4.4.4", 222) { if !bl.In(krpc.NewAddr(net.ParseIP("3.3.3.3"), 111)) || bl.In(krpc.NewAddr(net.ParseIP("4.4.4.4"), 222)) {
t.Fail() t.Fail()
} }
bl.Del("3.3.3.3", 0) bl.Del(krpc.NewAddr(net.ParseIP("3.3.3.3"), 0))
if bl.In("3.3.3.3", 111) { if bl.In(krpc.NewAddr(net.ParseIP("3.3.3.3"), 111)) {
t.Fail() t.Fail()
} }
} }

View File

@ -25,9 +25,9 @@ import (
"sync" "sync"
"time" "time"
"github.com/xgfone/bt/bencode" "github.com/xgfone/go-bt/bencode"
"github.com/xgfone/bt/krpc" "github.com/xgfone/go-bt/krpc"
"github.com/xgfone/bt/metainfo" "github.com/xgfone/go-bt/metainfo"
) )
const ( const (
@ -53,7 +53,7 @@ type Result struct {
// Addr is the address of the peer where the request is sent to. // Addr is the address of the peer where the request is sent to.
// //
// Notice: it may be nil for "get_peers" request. // Notice: it may be nil for "get_peers" request.
Addr *net.UDPAddr Addr net.Addr
// For Error // For Error
Code int // 0 represents the success. Code int // 0 represents the success.
@ -61,7 +61,7 @@ type Result struct {
Timeout bool // Timeout indicates whether the response is timeout. Timeout bool // Timeout indicates whether the response is timeout.
// The list of the address of the peers returned by GetPeers. // The list of the address of the peers returned by GetPeers.
Peers []metainfo.Address Peers []string
} }
// Config is used to configure the DHT server. // Config is used to configure the DHT server.
@ -134,18 +134,24 @@ type Config struct {
// The default is log.Printf. // The default is log.Printf.
ErrorLog func(format string, args ...interface{}) ErrorLog func(format string, args ...interface{})
// They is used to convert the address between net.Addr and krpc.Addr.
//
// For the default, it asserts net.Addr to *net.UDPAddr.
GetNetAddr func(krpc.Addr) net.Addr
GetKrpcAddr func(net.Addr) krpc.Addr
// OnSearch is called when someone searches the torrent infohash, // OnSearch is called when someone searches the torrent infohash,
// that's, the "get_peers" query. // that's, the "get_peers" query.
// //
// The default callback does noting. // The default callback does noting.
OnSearch func(infohash string, ip net.IP, port uint16) OnSearch func(infohash string, addr krpc.Addr)
// OnTorrent is called when someone has the torrent infohash // OnTorrent is called when someone has the torrent infohash
// or someone has just downloaded the torrent infohash, // or someone has just downloaded the torrent infohash,
// that's, the "get_peers" response or "announce_peer" query. // that's, the "get_peers" response or "announce_peer" query.
// //
// The default callback does noting. // The default callback does noting.
OnTorrent func(infohash string, ip net.IP, port uint16) OnTorrent func(infohash string, addr string)
// HandleInMessage is used to intercept the incoming DHT message. // HandleInMessage is used to intercept the incoming DHT message.
// For example, you can debug the message as the log. // For example, you can debug the message as the log.
@ -153,23 +159,36 @@ type Config struct {
// Return true if going on handling by the default. Or return false. // Return true if going on handling by the default. Or return false.
// //
// The default is nil. // The default is nil.
HandleInMessage func(*net.UDPAddr, *krpc.Message) bool HandleInMessage func(net.Addr, *krpc.Message) (handled bool)
// HandleOutMessage is used to intercept the outgoing DHT message. // HandleOutMessage is used to intercept the outgoing DHT message.
// For example, you can debug the message as the log. // For example, you can debug the message as the log.
// //
// Return (false, nil) if going on handling by the default. // Return (true, nil) if not going on handling by the default.
// //
// The default is nil. // The default is nil.
HandleOutMessage func(*net.UDPAddr, *krpc.Message) (wrote bool, err error) HandleOutMessage func(net.Addr, *krpc.Message) (wrote bool, err error)
} }
func (c Config) in(*net.UDPAddr, *krpc.Message) bool { return true } func (c Config) in(net.Addr, *krpc.Message) bool { return false }
func (c Config) out(*net.UDPAddr, *krpc.Message) (bool, error) { return false, nil } func (c Config) out(net.Addr, *krpc.Message) (bool, error) { return false, nil }
func (c *Config) set(conf ...Config) { func (c Config) onSearch(string, krpc.Addr) {}
if len(conf) > 0 { func (c Config) onTorrent(string, string) {}
*c = conf[0]
func (c Config) k2nAddr(a krpc.Addr) net.Addr {
if a.Orig != nil {
return a.Orig
}
return a.UDPAddr()
}
func (c Config) n2kAddr(a net.Addr) krpc.Addr {
return krpc.NewAddrFromUDPAddr(a.(*net.UDPAddr))
}
func (c *Config) set(conf *Config) {
if conf != nil {
*c = *conf
} }
if c.K <= 0 { if c.K <= 0 {
@ -197,10 +216,16 @@ func (c *Config) set(conf ...Config) {
c.RespTimeout = time.Second * 10 c.RespTimeout = time.Second * 10
} }
if c.OnSearch == nil { if c.OnSearch == nil {
c.OnSearch = func(string, net.IP, uint16) {} c.OnSearch = c.onSearch
} }
if c.OnTorrent == nil { if c.OnTorrent == nil {
c.OnTorrent = func(string, net.IP, uint16) {} c.OnTorrent = c.onTorrent
}
if c.GetNetAddr == nil {
c.GetNetAddr = c.k2nAddr
}
if c.GetKrpcAddr == nil {
c.GetKrpcAddr = c.n2kAddr
} }
if c.HandleInMessage == nil { if c.HandleInMessage == nil {
c.HandleInMessage = c.in c.HandleInMessage = c.in
@ -215,6 +240,7 @@ type Server struct {
conf Config conf Config
exit chan struct{} exit chan struct{}
conn net.PacketConn conn net.PacketConn
lock sync.Mutex
once sync.Once once sync.Once
ipv4 bool ipv4 bool
@ -230,9 +256,9 @@ type Server struct {
} }
// NewServer returns a new DHT server. // NewServer returns a new DHT server.
func NewServer(conn net.PacketConn, config ...Config) *Server { func NewServer(conn net.PacketConn, config *Config) *Server {
var conf Config var conf Config
conf.set(config...) conf.set(config)
if len(conf.IPProtocols) == 0 { if len(conf.IPProtocols) == 0 {
host, _, err := net.SplitHostPort(conn.LocalAddr().String()) host, _, err := net.SplitHostPort(conn.LocalAddr().String())
@ -278,7 +304,6 @@ func NewServer(conn net.PacketConn, config ...Config) *Server {
if s.peerManager == nil { if s.peerManager == nil {
s.peerManager = s.tokenPeerManager s.peerManager = s.tokenPeerManager
} }
return s return s
} }
@ -292,14 +317,14 @@ func (s *Server) Bootstrap(addrs []string) {
if (s.ipv4 && s.routingTable4.Len() == 0) || if (s.ipv4 && s.routingTable4.Len() == 0) ||
(s.ipv6 && s.routingTable6.Len() == 0) { (s.ipv6 && s.routingTable6.Len() == 0) {
for _, addr := range addrs { for _, addr := range addrs {
as, err := metainfo.NewAddressesFromString(addr) ipports, err := krpc.ParseAddrs(addr)
if err != nil { if err != nil {
s.conf.ErrorLog(err.Error()) s.conf.ErrorLog(err.Error())
continue continue
} }
for _, a := range as { for _, ipport := range ipports {
if isIPv6(a.IP) { if isIPv6(ipport.IP) {
if !s.ipv6 { if !s.ipv6 {
continue continue
} }
@ -307,8 +332,8 @@ func (s *Server) Bootstrap(addrs []string) {
continue continue
} }
if err = s.FindNode(a.UDPAddr(), s.conf.ID); err != nil { if err = s.FindNode(ipport, s.conf.ID); err != nil {
s.conf.ErrorLog(`fail to bootstrap '%s': %s`, a.String(), err) s.conf.ErrorLog(`fail to bootstrap '%s': %s`, ipport.String(), err)
} }
} }
} }
@ -324,12 +349,12 @@ func (s *Server) Node6Num() int { return s.routingTable6.Len() }
// AddNode adds the node into the routing table. // AddNode adds the node into the routing table.
// //
// The returned value: // The returned value:
// NodeAdded: The node is added successfully.
// NodeNotAdded: The node is not added and is discarded.
// NodeExistAndUpdated: The node has existed, and its status has been updated.
// NodeExistAndChanged: The node has existed, but the address is inconsistent.
// The current node will be discarded.
// //
// NodeAdded: The node is added successfully.
// NodeNotAdded: The node is not added and is discarded.
// NodeExistAndUpdated: The node has existed, and its status has been updated.
// NodeExistAndChanged: The node has existed, but the address is inconsistent.
// The current node will be discarded.
func (s *Server) AddNode(node krpc.Node) int { func (s *Server) AddNode(node krpc.Node) int {
// For IPv6 // For IPv6
if isIPv6(node.Addr.IP) { if isIPv6(node.Addr.IP) {
@ -347,13 +372,13 @@ func (s *Server) AddNode(node krpc.Node) int {
return NodeNotAdded return NodeNotAdded
} }
func (s *Server) addNode(a *net.UDPAddr, id metainfo.Hash, ro bool) (r int) { func (s *Server) addNode(kaddr krpc.Addr, id metainfo.Hash, ro bool) (r int) {
if ro { // BEP 43 if ro { // BEP 43
return NodeNotAdded return NodeNotAdded
} }
if r = s.AddNode(krpc.NewNodeByUDPAddr(id, a)); r == NodeExistAndChanged { if r = s.AddNode(krpc.NewNode(id, kaddr)); r == NodeExistAndChanged {
s.conf.Blacklist.Add(a.IP.String(), a.Port) s.conf.Blacklist.Add(kaddr)
} }
return return
@ -401,11 +426,15 @@ func (s *Server) Run() {
return return
} }
s.handlePacket(raddr.(*net.UDPAddr), buf[:n]) s.handlePacket(raddr, buf[:n])
} }
} }
func (s *Server) isDisabled(raddr *net.UDPAddr) bool { func (s *Server) isDisabled(raddr krpc.Addr) bool {
if len(raddr.IP) == 0 {
return false
}
if isIPv6(raddr.IP) { if isIPv6(raddr.IP) {
if !s.ipv6 { if !s.ipv6 {
return true return true
@ -417,13 +446,16 @@ func (s *Server) isDisabled(raddr *net.UDPAddr) bool {
} }
// HandlePacket handles the incoming DHT message. // HandlePacket handles the incoming DHT message.
func (s *Server) handlePacket(raddr *net.UDPAddr, data []byte) { func (s *Server) handlePacket(raddr net.Addr, data []byte) {
if s.isDisabled(raddr) { kaddr := s.conf.GetKrpcAddr(raddr)
kaddr.Orig = raddr
if s.isDisabled(kaddr) {
return return
} }
// Check whether the raddr is in the ip blacklist. If yes, discard it. // Check whether the raddr is in the ip blacklist. If yes, discard it.
if s.conf.Blacklist.In(raddr.IP.String(), raddr.Port) { if s.conf.Blacklist.In(kaddr) {
return return
} }
@ -436,43 +468,46 @@ func (s *Server) handlePacket(raddr *net.UDPAddr, data []byte) {
return return
} }
// TODO: Should we use a task pool?? // (xgf): Should we use a task pool??
go s.handleMessage(raddr, msg) go s.handleMessage(kaddr, msg)
} }
func (s *Server) handleMessage(raddr *net.UDPAddr, m krpc.Message) { func (s *Server) handleMessage(kaddr krpc.Addr, m krpc.Message) {
if !s.conf.HandleInMessage(raddr, &m) { if s.conf.HandleInMessage(kaddr.Orig, &m) {
return return
} }
switch m.Y { switch m.Y {
case "q": case "q":
if !m.A.ID.IsZero() { if !m.A.ID.IsZero() {
r := s.addNode(raddr, m.A.ID, m.RO) r := s.addNode(kaddr, m.A.ID, m.RO)
if r != NodeExistAndChanged && !s.conf.ReadOnly { // BEP 43 if r != NodeExistAndChanged && !s.conf.ReadOnly { // BEP 43
s.handleQuery(raddr, m) s.handleQuery(kaddr, m)
} }
} }
case "r": case "r":
if !m.R.ID.IsZero() { if !m.R.ID.IsZero() {
if s.addNode(raddr, m.R.ID, m.RO) == NodeExistAndChanged { if s.addNode(kaddr, m.R.ID, m.RO) == NodeExistAndChanged {
return return
} }
if t := s.transactionManager.PopTransaction(m.T, raddr); t != nil { if t := s.transactionManager.PopTransaction(m.T, kaddr); t != nil {
t.OnResponse(t, raddr, m) t.OnResponse(t, kaddr, m)
} }
} }
case "e": case "e":
if t := s.transactionManager.PopTransaction(m.T, raddr); t != nil { if t := s.transactionManager.PopTransaction(m.T, kaddr); t != nil {
t.OnError(t, m.E.Code, m.E.Reason) t.OnError(t, m.E.Code, m.E.Reason)
} }
default: default:
s.conf.ErrorLog("unknown dht message type '%s'", m.Y) s.conf.ErrorLog("unknown dht message type '%s'", m.Y)
} }
} }
func (s *Server) handleQuery(raddr *net.UDPAddr, m krpc.Message) { func (s *Server) handleQuery(raddr krpc.Addr, m krpc.Message) {
switch m.Q { switch m.Q {
case queryMethodPing: case queryMethodPing:
s.reply(raddr, m.T, krpc.ResponseResult{}) s.reply(raddr, m.T, krpc.ResponseResult{})
@ -540,35 +575,39 @@ func (s *Server) handleQuery(raddr *net.UDPAddr, m krpc.Message) {
r.Token = s.tokenManager.Token(raddr) r.Token = s.tokenManager.Token(raddr)
s.reply(raddr, m.T, r) s.reply(raddr, m.T, r)
s.conf.OnSearch(m.A.InfoHash.HexString(), raddr.IP, uint16(raddr.Port)) s.conf.OnSearch(m.A.InfoHash.HexString(), raddr)
case queryMethodAnnouncePeer: case queryMethodAnnouncePeer:
if s.tokenManager.Check(raddr, m.A.Token) { if s.tokenManager.Check(raddr, m.A.Token) {
return return
} }
s.reply(raddr, m.T, krpc.ResponseResult{}) s.reply(raddr, m.T, krpc.ResponseResult{})
s.conf.OnTorrent(m.A.InfoHash.HexString(), raddr.IP, m.A.GetPort(raddr.Port)) s.conf.OnTorrent(m.A.InfoHash.HexString(), raddr.String())
default: default:
s.sendError(raddr, m.T, "unknown query method", krpc.ErrorCodeMethodUnknown) s.sendError(raddr, m.T, "unknown query method", krpc.ErrorCodeMethodUnknown)
} }
} }
func (s *Server) send(raddr *net.UDPAddr, m krpc.Message) (wrote bool, err error) { func (s *Server) send(kaddr krpc.Addr, m krpc.Message) (wrote bool, err error) {
// // TODO: Should we check the ip blacklist?? // // (xgf): Should we check the ip blacklist??
// if s.conf.Blacklist.In(raddr.IP.String(), raddr.Port) { // if s.conf.Blacklist.In(kaddr) {
// return // return
// } // }
m.RO = s.conf.ReadOnly // BEP 43 m.RO = s.conf.ReadOnly // BEP 43
if wrote, err = s.conf.HandleOutMessage(raddr, &m); !wrote && err == nil { kaddr.Orig = s.conf.GetNetAddr(kaddr)
wrote, err = s._send(raddr, m)
s.lock.Lock()
defer s.lock.Unlock()
if wrote, err = s.conf.HandleOutMessage(kaddr.Orig, &m); !wrote && err == nil {
wrote, err = s._send(kaddr, m)
} }
return return
} }
func (s *Server) _send(raddr *net.UDPAddr, m krpc.Message) (wrote bool, err error) { func (s *Server) _send(kaddr krpc.Addr, m krpc.Message) (wrote bool, err error) {
if m.T == "" || m.Y == "" { if m.T == "" || m.Y == "" {
panic(`DHT message "t" or "y" must not be empty`) panic(`DHT message "t" or "y" must not be empty`)
} }
@ -579,10 +618,10 @@ func (s *Server) _send(raddr *net.UDPAddr, m krpc.Message) (wrote bool, err erro
panic(err) panic(err)
} }
n, err := s.conn.WriteTo(buf.Bytes(), raddr) n, err := s.conn.WriteTo(buf.Bytes(), kaddr.Orig)
if err != nil { if err != nil {
err = fmt.Errorf("error writing %d bytes to %s: %s", buf.Len(), raddr, err) err = fmt.Errorf("error writing %d bytes to %v: %s", buf.Len(), kaddr.Orig, err)
s.conf.Blacklist.Add(raddr.IP.String(), 0) s.conf.Blacklist.Add(kaddr)
return return
} }
@ -594,13 +633,13 @@ func (s *Server) _send(raddr *net.UDPAddr, m krpc.Message) (wrote bool, err erro
return return
} }
func (s *Server) sendError(raddr *net.UDPAddr, tid, reason string, code int) { func (s *Server) sendError(raddr krpc.Addr, tid, reason string, code int) {
if _, err := s.send(raddr, krpc.NewErrorMsg(tid, code, reason)); err != nil { if _, err := s.send(raddr, krpc.NewErrorMsg(tid, code, reason)); err != nil {
s.conf.ErrorLog("error replying to %s: %s", raddr.String(), err.Error()) s.conf.ErrorLog("error replying to %s: %s", raddr.String(), err.Error())
} }
} }
func (s *Server) reply(raddr *net.UDPAddr, tid string, r krpc.ResponseResult) { func (s *Server) reply(raddr krpc.Addr, tid string, r krpc.ResponseResult) {
r.ID = s.conf.ID r.ID = s.conf.ID
if _, err := s.send(raddr, krpc.NewResponseMsg(tid, r)); err != nil { if _, err := s.send(raddr, krpc.NewResponseMsg(tid, r)); err != nil {
s.conf.ErrorLog("error replying to %s: %s", raddr.String(), err.Error()) s.conf.ErrorLog("error replying to %s: %s", raddr.String(), err.Error())
@ -631,7 +670,7 @@ func (s *Server) onError(t *transaction, code int, reason string) {
} }
func (s *Server) onTimeout(t *transaction) { func (s *Server) onTimeout(t *transaction) {
// TODO: Should we use a task pool?? // (xgf): Should we use a task pool??
t.Done(Result{Timeout: true}) t.Done(Result{Timeout: true})
var qid string var qid string
@ -645,11 +684,19 @@ func (s *Server) onTimeout(t *transaction) {
t.ID, s.conf.ID, t.Query, qid, s.conn.LocalAddr(), t.Addr.String()) t.ID, s.conf.ID, t.Query, qid, s.conn.LocalAddr(), t.Addr.String())
} }
func (s *Server) onPingResp(t *transaction, a *net.UDPAddr, m krpc.Message) { // Ping sends a PING query to addr, and the callback function cb will be called
// when the response or error is returned, or it's timeout.
func (s *Server) Ping(addr krpc.Addr, cb ...func(Result)) (err error) {
t := newTransaction(s, addr, queryMethodPing, krpc.QueryArg{}, cb...)
t.OnResponse = s.onPingResp
return s.request(t)
}
func (s *Server) onPingResp(t *transaction, a krpc.Addr, m krpc.Message) {
t.Done(Result{}) t.Done(Result{})
} }
func (s *Server) onGetPeersResp(t *transaction, a *net.UDPAddr, m krpc.Message) { func (s *Server) onGetPeersResp(t *transaction, a krpc.Addr, m krpc.Message) {
// Store the response node with the token. // Store the response node with the token.
if m.R.Token != "" { if m.R.Token != "" {
s.tokenPeerManager.Set(m.R.ID, a, m.R.Token) s.tokenPeerManager.Set(m.R.ID, a, m.R.Token)
@ -659,7 +706,7 @@ func (s *Server) onGetPeersResp(t *transaction, a *net.UDPAddr, m krpc.Message)
if len(m.R.Values) > 0 { if len(m.R.Values) > 0 {
t.Done(Result{Peers: m.R.Values}) t.Done(Result{Peers: m.R.Values})
for _, addr := range m.R.Values { for _, addr := range m.R.Values {
s.conf.OnTorrent(t.Arg.InfoHash.HexString(), addr.IP, addr.Port) s.conf.OnTorrent(t.Arg.InfoHash.HexString(), addr)
} }
return return
} }
@ -710,10 +757,9 @@ func (s *Server) onGetPeersResp(t *transaction, a *net.UDPAddr, m krpc.Message)
} }
} }
func (s *Server) getPeers(info metainfo.Hash, addr metainfo.Address, depth int, func (s *Server) getPeers(info metainfo.Hash, addr krpc.Addr, depth int, ids metainfo.Hashes, cb ...func(Result)) {
ids metainfo.Hashes, cb ...func(Result)) {
arg := krpc.QueryArg{InfoHash: info, Wants: s.want} arg := krpc.QueryArg{InfoHash: info, Wants: s.want}
t := newTransaction(s, addr.UDPAddr(), queryMethodGetPeers, arg, cb...) t := newTransaction(s, addr, queryMethodGetPeers, arg, cb...)
t.OnResponse = s.onGetPeersResp t.OnResponse = s.onGetPeersResp
t.Depth = depth t.Depth = depth
t.Visited = ids t.Visited = ids
@ -722,14 +768,6 @@ func (s *Server) getPeers(info metainfo.Hash, addr metainfo.Address, depth int,
} }
} }
// Ping sends a PING query to addr, and the callback function cb will be called
// when the response or error is returned, or it's timeout.
func (s *Server) Ping(addr *net.UDPAddr, cb ...func(Result)) (err error) {
t := newTransaction(s, addr, queryMethodPing, krpc.QueryArg{}, cb...)
t.OnResponse = s.onPingResp
return s.request(t)
}
// GetPeers searches the peer storing the torrent by the infohash of the torrent, // GetPeers searches the peer storing the torrent by the infohash of the torrent,
// which will search it recursively until some peers are returned or it reaches // which will search it recursively until some peers are returned or it reaches
// the maximun depth, that's, ServerConfig.SearchDepth. // the maximun depth, that's, ServerConfig.SearchDepth.
@ -783,16 +821,15 @@ func (s *Server) AnnouncePeer(infohash metainfo.Hash, port uint16, impliedPort b
sentNodes := make([]krpc.Node, 0, len(nodes)) sentNodes := make([]krpc.Node, 0, len(nodes))
for _, node := range nodes { for _, node := range nodes {
addr := node.Addr.UDPAddr() token := s.tokenPeerManager.Get(infohash, node.Addr)
token := s.tokenPeerManager.Get(infohash, addr)
if token == "" { if token == "" {
continue continue
} }
arg := krpc.QueryArg{ImpliedPort: impliedPort, InfoHash: infohash, Port: port, Token: token} arg := krpc.QueryArg{ImpliedPort: impliedPort, InfoHash: infohash, Port: port, Token: token}
t := newTransaction(s, addr, queryMethodAnnouncePeer, arg) t := newTransaction(s, node.Addr, queryMethodAnnouncePeer, arg)
if err := s.request(t); err != nil { if err := s.request(t); err != nil {
s.conf.ErrorLog("fail to send query message to '%s': %s", addr.String(), err) s.conf.ErrorLog("fail to send query message to '%s': %s", node.Addr.String(), err)
} else { } else {
sentNodes = append(sentNodes, node) sentNodes = append(sentNodes, node)
} }
@ -804,7 +841,7 @@ func (s *Server) AnnouncePeer(infohash metainfo.Hash, port uint16, impliedPort b
// FindNode sends the "find_node" query to the addr to find the target node. // FindNode sends the "find_node" query to the addr to find the target node.
// //
// Notice: In general, it's used to bootstrap the routing table. // Notice: In general, it's used to bootstrap the routing table.
func (s *Server) FindNode(addr *net.UDPAddr, target metainfo.Hash) error { func (s *Server) FindNode(addr krpc.Addr, target metainfo.Hash) error {
if target.IsZero() { if target.IsZero() {
panic("the target is ZERO") panic("the target is ZERO")
} }
@ -812,8 +849,7 @@ func (s *Server) FindNode(addr *net.UDPAddr, target metainfo.Hash) error {
return s.findNode(target, addr, s.conf.SearchDepth, nil) return s.findNode(target, addr, s.conf.SearchDepth, nil)
} }
func (s *Server) findNode(target metainfo.Hash, addr *net.UDPAddr, depth int, func (s *Server) findNode(target metainfo.Hash, addr krpc.Addr, depth int, ids metainfo.Hashes) error {
ids metainfo.Hashes) error {
arg := krpc.QueryArg{Target: target, Wants: s.want} arg := krpc.QueryArg{Target: target, Wants: s.want}
t := newTransaction(s, addr, queryMethodFindNode, arg) t := newTransaction(s, addr, queryMethodFindNode, arg)
t.OnResponse = s.onFindNodeResp t.OnResponse = s.onFindNodeResp
@ -821,7 +857,7 @@ func (s *Server) findNode(target metainfo.Hash, addr *net.UDPAddr, depth int,
return s.request(t) return s.request(t)
} }
func (s *Server) onFindNodeResp(t *transaction, a *net.UDPAddr, m krpc.Message) { func (s *Server) onFindNodeResp(t *transaction, _ krpc.Addr, m krpc.Message) {
t.Done(Result{}) t.Done(Result{})
// Search the target node recursively. // Search the target node recursively.
@ -863,7 +899,7 @@ func (s *Server) onFindNodeResp(t *transaction, a *net.UDPAddr, m krpc.Message)
} }
for _, node := range nodes { for _, node := range nodes {
err := s.findNode(t.Arg.Target, node.Addr.UDPAddr(), t.Depth, ids) err := s.findNode(t.Arg.Target, node.Addr, t.Depth, ids)
if err != nil { if err != nil {
s.conf.ErrorLog(`fail to send "find_node" query to '%s': %s`, s.conf.ErrorLog(`fail to send "find_node" query to '%s': %s`,
node.Addr.String(), err) node.Addr.String(), err)

View File

@ -17,39 +17,33 @@ package dht
import ( import (
"fmt" "fmt"
"net" "net"
"strconv"
"sync" "sync"
"time" "time"
"github.com/xgfone/bt/metainfo" "github.com/xgfone/go-bt/internal/helper"
"github.com/xgfone/go-bt/krpc"
"github.com/xgfone/go-bt/metainfo"
) )
type testPeerManager struct { type testPeerManager struct {
lock sync.RWMutex lock sync.RWMutex
peers map[metainfo.Hash][]metainfo.Address peers map[metainfo.Hash][]string
} }
func newTestPeerManager() *testPeerManager { func newTestPeerManager() *testPeerManager {
return &testPeerManager{peers: make(map[metainfo.Hash][]metainfo.Address)} return &testPeerManager{peers: make(map[metainfo.Hash][]string)}
} }
func (pm *testPeerManager) AddPeer(infohash metainfo.Hash, addr metainfo.Address) { func (pm *testPeerManager) AddPeer(infohash metainfo.Hash, addr string) {
pm.lock.Lock() pm.lock.Lock()
var exist bool defer pm.lock.Unlock()
for _, orig := range pm.peers[infohash] {
if orig.Equal(addr) { if !helper.ContainsString(pm.peers[infohash], addr) {
exist = true
break
}
}
if !exist {
pm.peers[infohash] = append(pm.peers[infohash], addr) pm.peers[infohash] = append(pm.peers[infohash], addr)
} }
pm.lock.Unlock()
} }
func (pm *testPeerManager) GetPeers(infohash metainfo.Hash, maxnum int, func (pm *testPeerManager) GetPeers(infohash metainfo.Hash, maxnum int, ipv6 bool) (addrs []string) {
ipv6 bool) (addrs []metainfo.Address) {
// We only supports IPv4, so ignore the ipv6 argument. // We only supports IPv4, so ignore the ipv6 argument.
pm.lock.RLock() pm.lock.RLock()
_addrs := pm.peers[infohash] _addrs := pm.peers[infohash]
@ -63,13 +57,11 @@ func (pm *testPeerManager) GetPeers(infohash metainfo.Hash, maxnum int,
return return
} }
func onSearch(infohash string, ip net.IP, port uint16) { func onSearch(infohash string, addr krpc.Addr) {
addr := net.JoinHostPort(ip.String(), strconv.FormatUint(uint64(port), 10)) fmt.Printf("%s is searching %s\n", addr.String(), infohash)
fmt.Printf("%s is searching %s\n", addr, infohash)
} }
func onTorrent(infohash string, ip net.IP, port uint16) { func onTorrent(infohash string, addr string) {
addr := net.JoinHostPort(ip.String(), strconv.FormatUint(uint64(port), 10))
fmt.Printf("%s has downloaded %s\n", addr, infohash) fmt.Printf("%s has downloaded %s\n", addr, infohash)
} }
@ -77,7 +69,7 @@ func newDHTServer(id metainfo.Hash, addr string, pm PeerManager) (s *Server, err
conn, err := net.ListenPacket("udp", addr) conn, err := net.ListenPacket("udp", addr)
if err == nil { if err == nil {
c := Config{ID: id, PeerManager: pm, OnSearch: onSearch, OnTorrent: onTorrent} c := Config{ID: id, PeerManager: pm, OnSearch: onSearch, OnTorrent: onTorrent}
s = NewServer(conn, c) s = NewServer(conn, &c)
} }
return return
} }
@ -140,7 +132,7 @@ func ExampleServer() {
fmt.Printf("%s: no peers for %s\n", r.Addr.String(), infohash) fmt.Printf("%s: no peers for %s\n", r.Addr.String(), infohash)
} else { } else {
for _, peer := range r.Peers { for _, peer := range r.Peers {
fmt.Printf("%s: %s\n", infohash, peer.String()) fmt.Printf("%s: %s\n", infohash, peer)
} }
} }
}) })
@ -149,7 +141,7 @@ func ExampleServer() {
time.Sleep(time.Second * 2) time.Sleep(time.Second * 2)
// Add the peer to let the DHT server1 has the peer. // Add the peer to let the DHT server1 has the peer.
pm.AddPeer(infohash, metainfo.NewAddress(net.ParseIP("127.0.0.1"), 9001)) pm.AddPeer(infohash, "127.0.0.1:9001")
// Search the torrent infohash again, but from DHT server2, // Search the torrent infohash again, but from DHT server2,
// which will search the DHT server1 recursively. // which will search the DHT server1 recursively.
@ -158,7 +150,7 @@ func ExampleServer() {
fmt.Printf("%s: no peers for %s\n", r.Addr.String(), infohash) fmt.Printf("%s: no peers for %s\n", r.Addr.String(), infohash)
} else { } else {
for _, peer := range r.Peers { for _, peer := range r.Peers {
fmt.Printf("%s: %s\n", infohash, peer.String()) fmt.Printf("%s: %s\n", infohash, peer)
} }
} }
}) })

View File

@ -15,23 +15,24 @@
package dht package dht
import ( import (
"net"
"sync" "sync"
"time" "time"
"github.com/xgfone/bt/metainfo" "github.com/xgfone/go-bt/krpc"
"github.com/xgfone/go-bt/metainfo"
) )
// PeerManager is used to manage the peers. // PeerManager is used to manage the peers.
type PeerManager interface { type PeerManager interface {
// If ipv6 is true, only return ipv6 addresses. Or return ipv4 addresses. // If ipv6 is true, only return ipv6 addresses. Or return ipv4 addresses.
GetPeers(infohash metainfo.Hash, maxnum int, ipv6 bool) []metainfo.Address GetPeers(infohash metainfo.Hash, maxnum int, ipv6 bool) []string
} }
var _ PeerManager = new(tokenPeerManager)
type peer struct { type peer struct {
ID metainfo.Hash ID metainfo.Hash
IP net.IP Addr krpc.Addr
Port uint16
Token string Token string
Time time.Time Time time.Time
} }
@ -75,7 +76,7 @@ func (tpm *tokenPeerManager) Start(interval time.Duration) {
} }
} }
func (tpm *tokenPeerManager) Set(id metainfo.Hash, addr *net.UDPAddr, token string) { func (tpm *tokenPeerManager) Set(id metainfo.Hash, addr krpc.Addr, token string) {
addrkey := addr.String() addrkey := addr.String()
tpm.lock.Lock() tpm.lock.Lock()
peers, ok := tpm.peers[id] peers, ok := tpm.peers[id]
@ -83,17 +84,11 @@ func (tpm *tokenPeerManager) Set(id metainfo.Hash, addr *net.UDPAddr, token stri
peers = make(map[string]peer, 4) peers = make(map[string]peer, 4)
tpm.peers[id] = peers tpm.peers[id] = peers
} }
peers[addrkey] = peer{ peers[addrkey] = peer{ID: id, Addr: addr, Token: token, Time: time.Now()}
ID: id,
IP: addr.IP,
Port: uint16(addr.Port),
Token: token,
Time: time.Now(),
}
tpm.lock.Unlock() tpm.lock.Unlock()
} }
func (tpm *tokenPeerManager) Get(id metainfo.Hash, addr *net.UDPAddr) (token string) { func (tpm *tokenPeerManager) Get(id metainfo.Hash, addr krpc.Addr) (token string) {
addrkey := addr.String() addrkey := addr.String()
tpm.lock.RLock() tpm.lock.RLock()
if peers, ok := tpm.peers[id]; ok { if peers, ok := tpm.peers[id]; ok {
@ -113,9 +108,8 @@ func (tpm *tokenPeerManager) Stop() {
} }
} }
func (tpm *tokenPeerManager) GetPeers(infohash metainfo.Hash, maxnum int, func (tpm *tokenPeerManager) GetPeers(infohash metainfo.Hash, maxnum int, ipv6 bool) (addrs []string) {
ipv6 bool) (addrs []metainfo.Address) { addrs = make([]string, 0, maxnum)
addrs = make([]metainfo.Address, 0, maxnum)
tpm.lock.RLock() tpm.lock.RLock()
if peers, ok := tpm.peers[infohash]; ok { if peers, ok := tpm.peers[infohash]; ok {
for _, peer := range peers { for _, peer := range peers {
@ -124,13 +118,13 @@ func (tpm *tokenPeerManager) GetPeers(infohash metainfo.Hash, maxnum int,
} }
if ipv6 { // For IPv6 if ipv6 { // For IPv6
if isIPv6(peer.IP) { if isIPv6(peer.Addr.IP) {
maxnum-- maxnum--
addrs = append(addrs, metainfo.NewAddress(peer.IP, peer.Port)) addrs = append(addrs, peer.Addr.String())
} }
} else if !isIPv6(peer.IP) { // For IPv4 } else if !isIPv6(peer.Addr.IP) { // For IPv4
maxnum-- maxnum--
addrs = append(addrs, metainfo.NewAddress(peer.IP, peer.Port)) addrs = append(addrs, peer.Addr.String())
} }
} }
} }

View File

@ -19,8 +19,8 @@ import (
"sync" "sync"
"time" "time"
"github.com/xgfone/bt/krpc" "github.com/xgfone/go-bt/krpc"
"github.com/xgfone/bt/metainfo" "github.com/xgfone/go-bt/metainfo"
) )
const bktlen = 160 const bktlen = 160
@ -155,12 +155,12 @@ func (rt *routingTable) Stop() {
// AddNode adds the node into the routing table. // AddNode adds the node into the routing table.
// //
// The returned value: // The returned value:
// NodeAdded: The node is added successfully.
// NodeNotAdded: The node is not added and is discarded.
// NodeExistAndUpdated: The node has existed, and its status has been updated.
// NodeExistAndChanged: The node has existed, but the address is inconsistent.
// The current node will be discarded.
// //
// NodeAdded: The node is added successfully.
// NodeNotAdded: The node is not added and is discarded.
// NodeExistAndUpdated: The node has existed, and its status has been updated.
// NodeExistAndChanged: The node has existed, but the address is inconsistent.
// The current node will be discarded.
func (rt *routingTable) AddNode(n krpc.Node) (r int) { func (rt *routingTable) AddNode(n krpc.Node) (r int) {
if n.ID == rt.root { // Don't add itself. if n.ID == rt.root { // Don't add itself.
return NodeNotAdded return NodeNotAdded
@ -264,7 +264,7 @@ func (b *bucket) AddNode(n krpc.Node, now time.Time) (status int) {
return return
} }
// // TODO: Should we replace the old one?? // // (xgf): Should we replace the old one??
// b.UpdateLastChangedTime(now) // b.UpdateLastChangedTime(now)
// copy(b.Nodes[i:], b.Nodes[i+1:]) // copy(b.Nodes[i:], b.Nodes[i+1:])
// b.Nodes[len(b.Nodes)-1] = newWrappedNode(b, n, now) // b.Nodes[len(b.Nodes)-1] = newWrappedNode(b, n, now)
@ -314,7 +314,7 @@ func (b *bucket) CheckAllNodes(now time.Time) {
case nodeStatusGood: case nodeStatusGood:
case nodeStatusDubious: case nodeStatusDubious:
// Try to send the PING query to the dubious node to check whether it is alive. // Try to send the PING query to the dubious node to check whether it is alive.
if err := b.table.s.Ping(node.Node.Addr.UDPAddr()); err != nil { if err := b.table.s.Ping(node.Node.Addr); err != nil {
b.table.s.conf.ErrorLog("fail to ping '%s': %s", node.Node.String(), err) b.table.s.conf.ErrorLog("fail to ping '%s': %s", node.Node.String(), err)
} }
case nodeStatusBad: case nodeStatusBad:

View File

@ -17,8 +17,8 @@ package dht
import ( import (
"time" "time"
"github.com/xgfone/bt/krpc" "github.com/xgfone/go-bt/krpc"
"github.com/xgfone/bt/metainfo" "github.com/xgfone/go-bt/metainfo"
) )
// RoutingTableNode represents the node with last changed time in the routing table. // RoutingTableNode represents the node with last changed time in the routing table.

View File

@ -15,16 +15,16 @@
package dht package dht
import ( import (
"net"
"sync" "sync"
"time" "time"
"github.com/xgfone/bt/utils" "github.com/xgfone/go-bt/internal/helper"
"github.com/xgfone/go-bt/krpc"
) )
// TokenManager is used to manage and validate the token. // TokenManager is used to manage and validate the token.
// //
// TODO: Should we allocate the different token for each node?? // (xgf): Should we allocate the different token for each node??
type tokenManager struct { type tokenManager struct {
lock sync.RWMutex lock sync.RWMutex
last string last string
@ -34,12 +34,12 @@ type tokenManager struct {
} }
func newTokenManager() *tokenManager { func newTokenManager() *tokenManager {
token := utils.RandomString(8) token := helper.RandomString(8)
return &tokenManager{last: token, new: token, exit: make(chan struct{})} return &tokenManager{last: token, new: token, exit: make(chan struct{})}
} }
func (tm *tokenManager) updateToken() { func (tm *tokenManager) updateToken() {
token := utils.RandomString(8) token := helper.RandomString(8)
tm.lock.Lock() tm.lock.Lock()
tm.last, tm.new = tm.new, token tm.last, tm.new = tm.new, token
tm.lock.Unlock() tm.lock.Unlock()
@ -83,7 +83,7 @@ func (tm *tokenManager) Stop() {
} }
// Token allocates a token for a node addr and returns the token. // Token allocates a token for a node addr and returns the token.
func (tm *tokenManager) Token(addr *net.UDPAddr) (token string) { func (tm *tokenManager) Token(addr krpc.Addr) (token string) {
addrs := addr.String() addrs := addr.String()
tm.lock.RLock() tm.lock.RLock()
token = tm.new token = tm.new
@ -94,7 +94,7 @@ func (tm *tokenManager) Token(addr *net.UDPAddr) (token string) {
// Check checks whether the token associated with the node addr is valid, // Check checks whether the token associated with the node addr is valid,
// that's, it's not expired. // that's, it's not expired.
func (tm *tokenManager) Check(addr *net.UDPAddr, token string) (ok bool) { func (tm *tokenManager) Check(addr krpc.Addr, token string) (ok bool) {
tm.lock.RLock() tm.lock.RLock()
last, new := tm.last, tm.new last, new := tm.last, tm.new
tm.lock.RUnlock() tm.lock.RUnlock()

View File

@ -15,21 +15,20 @@
package dht package dht
import ( import (
"net"
"strconv" "strconv"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/xgfone/bt/krpc" "github.com/xgfone/go-bt/krpc"
"github.com/xgfone/bt/metainfo" "github.com/xgfone/go-bt/metainfo"
) )
type transaction struct { type transaction struct {
ID string ID string
Query string Query string
Arg krpc.QueryArg Arg krpc.QueryArg
Addr *net.UDPAddr Addr krpc.Addr
Time time.Time Time time.Time
Depth int Depth int
@ -37,7 +36,7 @@ type transaction struct {
Callback func(Result) Callback func(Result)
OnError func(t *transaction, code int, reason string) OnError func(t *transaction, code int, reason string)
OnTimeout func(t *transaction) OnTimeout func(t *transaction)
OnResponse func(t *transaction, radd *net.UDPAddr, msg krpc.Message) OnResponse func(t *transaction, radd krpc.Addr, msg krpc.Message)
} }
func (t *transaction) Done(r Result) { func (t *transaction) Done(r Result) {
@ -47,8 +46,8 @@ func (t *transaction) Done(r Result) {
} }
} }
func noopResponse(*transaction, *net.UDPAddr, krpc.Message) {} func noopResponse(*transaction, krpc.Addr, krpc.Message) {}
func newTransaction(s *Server, a *net.UDPAddr, q string, qa krpc.QueryArg, func newTransaction(s *Server, a krpc.Addr, q string, qa krpc.QueryArg,
callback ...func(Result)) *transaction { callback ...func(Result)) *transaction {
var cb func(Result) var cb func(Result)
if len(callback) > 0 { if len(callback) > 0 {
@ -141,7 +140,7 @@ func (tm *transactionManager) DeleteTransaction(t *transaction) {
// and the peer address. // and the peer address.
// //
// Return nil if there is no the transaction. // Return nil if there is no the transaction.
func (tm *transactionManager) PopTransaction(tid string, addr *net.UDPAddr) (t *transaction) { func (tm *transactionManager) PopTransaction(tid string, addr krpc.Addr) (t *transaction) {
key := transactionkey{id: tid, addr: addr.String()} key := transactionkey{id: tid, addr: addr.String()}
tm.lock.Lock() tm.lock.Lock()
if t = tm.trans[key]; t != nil { if t = tm.trans[key]; t != nil {

View File

@ -14,10 +14,7 @@
package downloader package downloader
import ( import pp "github.com/xgfone/go-bt/peerprotocol"
"github.com/xgfone/bt/metainfo"
pp "github.com/xgfone/bt/peerprotocol"
)
// BlockDownloadHandler is used to downloads the files in the torrent file. // BlockDownloadHandler is used to downloads the files in the torrent file.
type BlockDownloadHandler struct { type BlockDownloadHandler struct {
@ -25,51 +22,45 @@ type BlockDownloadHandler struct {
pp.NoopBep3Handler pp.NoopBep3Handler
pp.NoopBep6Handler pp.NoopBep6Handler
Info metainfo.Info // Required OnBlock func(index, offset uint32, b []byte) error
OnBlock func(index, offset uint32, b []byte) error // Required ReqBlock func(c *pp.PeerConn) error
RequestBlock func(c *pp.PeerConn) error // Required PieceNum int
} }
// NewBlockDownloadHandler returns a new BlockDownloadHandler. // NewBlockDownloadHandler returns a new BlockDownloadHandler.
func NewBlockDownloadHandler(info metainfo.Info, func NewBlockDownloadHandler(pieceNum int, reqBlock func(c *pp.PeerConn) error,
onBlock func(pieceIndex, pieceOffset uint32, b []byte) error, onBlock func(pieceIndex, pieceOffset uint32, b []byte) error) BlockDownloadHandler {
requestBlock func(c *pp.PeerConn) error) BlockDownloadHandler {
return BlockDownloadHandler{ return BlockDownloadHandler{
Info: info, OnBlock: onBlock,
OnBlock: onBlock, ReqBlock: reqBlock,
RequestBlock: requestBlock, PieceNum: pieceNum,
} }
} }
// OnHandShake implements the interface Handler#OnHandShake.
//
// Notice: it uses the field Data to store the inner data, you mustn't override
// it.
func (fd BlockDownloadHandler) OnHandShake(c *pp.PeerConn) (err error) {
if err = c.SetUnchoked(); err == nil {
err = c.SetInterested()
}
return
}
/// --------------------------------------------------------------------------- /// ---------------------------------------------------------------------------
/// BEP 3 /// BEP 3
func (fd BlockDownloadHandler) request(pc *pp.PeerConn) (err error) { func (fd BlockDownloadHandler) request(pc *pp.PeerConn) (err error) {
if fd.ReqBlock == nil {
return nil
}
if pc.PeerChoked { if pc.PeerChoked {
err = pp.ErrChoked err = pp.ErrChoked
} else { } else {
err = fd.RequestBlock(pc) err = fd.ReqBlock(pc)
} }
return return
} }
// Piece implements the interface Bep3Handler#Piece. // Piece implements the interface Bep3Handler#Piece.
func (fd BlockDownloadHandler) Piece(c *pp.PeerConn, i, b uint32, p []byte) (err error) { func (fd BlockDownloadHandler) Piece(c *pp.PeerConn, i, b uint32, p []byte) error {
if err = fd.OnBlock(i, b, p); err == nil { if fd.OnBlock != nil {
err = fd.request(c) if err := fd.OnBlock(i, b, p); err != nil {
return err
}
} }
return return fd.request(c)
} }
// Unchoke implements the interface Bep3Handler#Unchoke. // Unchoke implements the interface Bep3Handler#Unchoke.
@ -88,7 +79,9 @@ func (fd BlockDownloadHandler) Have(pc *pp.PeerConn, index uint32) (err error) {
// HaveAll implements the interface Bep6Handler#HaveAll. // HaveAll implements the interface Bep6Handler#HaveAll.
func (fd BlockDownloadHandler) HaveAll(pc *pp.PeerConn) (err error) { func (fd BlockDownloadHandler) HaveAll(pc *pp.PeerConn) (err error) {
pc.BitField = pp.NewBitField(fd.Info.CountPieces(), true) if fd.PieceNum > 0 {
pc.BitField = pp.NewBitField(fd.PieceNum, true)
}
return return
} }

View File

@ -1,4 +1,4 @@
// Copyright 2020 xgfone // Copyright 2023 xgfone
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -12,5 +12,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// Package utils supplies some convenient functions. // Package downloader is used to download the torrent or the real file
package utils // from the peer node by the peer wire protocol.
package downloader

View File

@ -1,4 +1,4 @@
// Copyright 2020 xgfone // Copyright 2020~2023 xgfone
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -26,8 +26,8 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/xgfone/bt/metainfo" "github.com/xgfone/go-bt/metainfo"
pp "github.com/xgfone/bt/peerprotocol" pp "github.com/xgfone/go-bt/peerprotocol"
) )
// BlockSize is the size of a block of the piece. // BlockSize is the size of a block of the piece.
@ -35,9 +35,9 @@ const BlockSize = 16384 // 16KiB.
// Request is used to send a download request. // Request is used to send a download request.
type request struct { type request struct {
Host string Host string
Port uint16 Port uint16
PeerID metainfo.Hash // PeerID metainfo.Hash
InfoHash metainfo.Hash InfoHash metainfo.Hash
} }
@ -57,6 +57,11 @@ type TorrentDownloaderConfig struct {
// The default is a random id. // The default is a random id.
ID metainfo.Hash ID metainfo.Hash
// The size of a block of the piece.
//
// Default: 16384
BlockSize uint64
// WorkerNum is the number of the worker downloading the torrent concurrently. // WorkerNum is the number of the worker downloading the torrent concurrently.
// //
// The default is 128. // The default is 128.
@ -71,11 +76,14 @@ type TorrentDownloaderConfig struct {
ErrorLog func(format string, args ...interface{}) ErrorLog func(format string, args ...interface{})
} }
func (c *TorrentDownloaderConfig) set(conf ...TorrentDownloaderConfig) { func (c *TorrentDownloaderConfig) set(conf *TorrentDownloaderConfig) {
if len(conf) > 0 { if conf != nil {
*c = conf[0] *c = *conf
} }
if c.BlockSize <= 0 {
c.BlockSize = BlockSize
}
if c.WorkerNum <= 0 { if c.WorkerNum <= 0 {
c.WorkerNum = 128 c.WorkerNum = 128
} }
@ -102,16 +110,15 @@ type TorrentDownloader struct {
// NewTorrentDownloader returns a new TorrentDownloader. // NewTorrentDownloader returns a new TorrentDownloader.
// //
// If id is ZERO, it is reset to a random id. workerNum is 128 by default. // If id is ZERO, it is reset to a random id. workerNum is 128 by default.
func NewTorrentDownloader(c ...TorrentDownloaderConfig) *TorrentDownloader { func NewTorrentDownloader(c *TorrentDownloaderConfig) *TorrentDownloader {
var conf TorrentDownloaderConfig var conf TorrentDownloaderConfig
conf.set(c...) conf.set(c)
d := &TorrentDownloader{ d := &TorrentDownloader{
conf: conf, conf: conf,
exit: make(chan struct{}), exit: make(chan struct{}),
requests: make(chan request, conf.WorkerNum), requests: make(chan request, conf.WorkerNum),
responses: make(chan TorrentResponse, 1024), responses: make(chan TorrentResponse, 1024),
ehmsg: pp.ExtendedHandshakeMsg{ ehmsg: pp.ExtendedHandshakeMsg{
M: map[string]uint8{pp.ExtendedMessageNameMetadata: 1}, M: map[string]uint8{pp.ExtendedMessageNameMetadata: 1},
}, },
@ -130,6 +137,9 @@ func NewTorrentDownloader(c ...TorrentDownloaderConfig) *TorrentDownloader {
// Notice: the remote peer must support the "ut_metadata" extenstion. // Notice: the remote peer must support the "ut_metadata" extenstion.
// Or downloading fails. // Or downloading fails.
func (d *TorrentDownloader) Request(host string, port uint16, infohash metainfo.Hash) { func (d *TorrentDownloader) Request(host string, port uint16, infohash metainfo.Hash) {
if infohash.IsZero() {
panic("infohash is ZERO")
}
d.requests <- request{Host: host, Port: port, InfoHash: infohash} d.requests <- request{Host: host, Port: port, InfoHash: infohash}
} }
@ -168,7 +178,7 @@ func (d *TorrentDownloader) worker() {
case <-d.exit: case <-d.exit:
return return
case r := <-d.requests: case r := <-d.requests:
if err := d.download(r.Host, r.Port, r.PeerID, r.InfoHash); err != nil { if err := d.download(r.Host, r.Port, r.InfoHash); err != nil {
d.conf.ErrorLog("fail to download the torrent '%s': %s", d.conf.ErrorLog("fail to download the torrent '%s': %s",
r.InfoHash.HexString(), err) r.InfoHash.HexString(), err)
} }
@ -176,8 +186,7 @@ func (d *TorrentDownloader) worker() {
} }
} }
func (d *TorrentDownloader) download(host string, port uint16, func (d *TorrentDownloader) download(host string, port uint16, infohash metainfo.Hash) (err error) {
peerID, infohash metainfo.Hash) (err error) {
addr := net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10)) addr := net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10))
conn, err := pp.NewPeerConnByDial(addr, d.conf.ID, infohash, d.conf.DialTimeout) conn, err := pp.NewPeerConnByDial(addr, d.conf.ID, infohash, d.conf.DialTimeout)
if err != nil { if err != nil {
@ -190,8 +199,6 @@ func (d *TorrentDownloader) download(host string, port uint16,
return return
} else if !conn.PeerExtBits.IsSupportExtended() { } else if !conn.PeerExtBits.IsSupportExtended() {
return fmt.Errorf("the remote peer '%s' does not support Extended", addr) return fmt.Errorf("the remote peer '%s' does not support Extended", addr)
} else if !peerID.IsZero() && peerID != conn.PeerID {
return fmt.Errorf("inconsistent peer id '%s'", conn.PeerID.HexString())
} }
if err = conn.SendExtHandshakeMsg(d.ehmsg); err != nil { if err = conn.SendExtHandshakeMsg(d.ehmsg); err != nil {
@ -200,7 +207,7 @@ func (d *TorrentDownloader) download(host string, port uint16,
var pieces [][]byte var pieces [][]byte
var piecesNum int var piecesNum int
var metadataSize int var metadataSize uint64
var utmetadataID uint8 var utmetadataID uint8
var msg pp.Message var msg pp.Message
@ -247,13 +254,14 @@ func (d *TorrentDownloader) download(host string, port uint16,
} }
metadataSize = ehmsg.MetadataSize metadataSize = ehmsg.MetadataSize
piecesNum = metadataSize / BlockSize piecesNum = int(metadataSize / d.conf.BlockSize)
if metadataSize%BlockSize != 0 { if metadataSize%d.conf.BlockSize != 0 {
piecesNum++ piecesNum++
} }
pieces = make([][]byte, piecesNum) pieces = make([][]byte, piecesNum)
go d.requestPieces(conn, utmetadataID, piecesNum) go d.requestPieces(conn, utmetadataID, piecesNum)
case 1: case 1:
if pieces == nil { if pieces == nil {
return return
@ -269,8 +277,8 @@ func (d *TorrentDownloader) download(host string, port uint16,
} }
pieceLen := len(utmsg.Data) pieceLen := len(utmsg.Data)
if (utmsg.Piece != piecesNum-1 && pieceLen != BlockSize) || if (utmsg.Piece != piecesNum-1 && pieceLen != int(d.conf.BlockSize)) ||
(utmsg.Piece == piecesNum-1 && pieceLen != metadataSize%BlockSize) { (utmsg.Piece == piecesNum-1 && pieceLen != int(metadataSize%d.conf.BlockSize)) {
return return
} }
pieces[utmsg.Piece] = utmsg.Data pieces[utmsg.Piece] = utmsg.Data

124
downloader/torrent_test.go Normal file
View File

@ -0,0 +1,124 @@
// Copyright 2023 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package downloader
import (
"context"
"errors"
"fmt"
"time"
"github.com/xgfone/go-bt/metainfo"
"github.com/xgfone/go-bt/peerprotocol"
pp "github.com/xgfone/go-bt/peerprotocol"
)
type bep10Handler struct {
peerprotocol.NoopHandler // For implementing peerprotocol.Handler.
infodata string
}
func (h bep10Handler) OnExtHandShake(c *peerprotocol.PeerConn) error {
if _, ok := c.ExtendedHandshakeMsg.M[peerprotocol.ExtendedMessageNameMetadata]; !ok {
return errors.New("missing the extension 'ut_metadata'")
}
return c.SendExtHandshakeMsg(peerprotocol.ExtendedHandshakeMsg{
M: map[string]uint8{pp.ExtendedMessageNameMetadata: 2},
MetadataSize: uint64(len(h.infodata)),
})
}
func (h bep10Handler) OnPayload(c *peerprotocol.PeerConn, extid uint8, extdata []byte) error {
if extid != 2 {
return fmt.Errorf("unknown extension id %d", extid)
}
var reqmsg peerprotocol.UtMetadataExtendedMsg
if err := reqmsg.DecodeFromPayload(extdata); err != nil {
return err
}
if reqmsg.MsgType != peerprotocol.UtMetadataExtendedMsgTypeRequest {
return errors.New("unsupported ut_metadata extension type")
}
startIndex := reqmsg.Piece * BlockSize
endIndex := startIndex + BlockSize
if totalSize := len(h.infodata); totalSize < endIndex {
endIndex = totalSize
}
respmsg := peerprotocol.UtMetadataExtendedMsg{
MsgType: peerprotocol.UtMetadataExtendedMsgTypeData,
Piece: reqmsg.Piece,
Data: []byte(h.infodata[startIndex:endIndex]),
}
data, err := respmsg.EncodeToBytes()
if err != nil {
return err
}
peerextid := c.ExtendedHandshakeMsg.M[peerprotocol.ExtendedMessageNameMetadata]
return c.SendExtMsg(peerextid, data)
}
func ExampleTorrentDownloader() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
handler := bep10Handler{infodata: "1234567890"}
infohash := metainfo.NewHashFromString(handler.infodata)
// Start the torrent server.
var serverConfig peerprotocol.Config
serverConfig.ExtBits.Set(peerprotocol.ExtensionBitExtended)
server, err := peerprotocol.NewServerByListen("tcp", "127.0.0.1:9010", metainfo.NewRandomHash(), handler, &serverConfig)
if err != nil {
fmt.Println(err)
return
}
defer server.Close()
go server.Run() // Start the torrent server.
time.Sleep(time.Millisecond * 100) // Wait that the torrent server finishes to start.
// Start the torrent downloader.
downloaderConfig := &TorrentDownloaderConfig{WorkerNum: 3, DialTimeout: time.Second}
downloader := NewTorrentDownloader(downloaderConfig)
go func() {
for {
select {
case <-ctx.Done():
return
case result := <-downloader.Response():
fmt.Println(string(result.InfoBytes))
}
}
}()
// Start to download the torrent.
downloader.Request("127.0.0.1", 9010, infohash)
// Wait to finish the test.
time.Sleep(time.Second)
cancel()
time.Sleep(time.Millisecond * 50)
// Output:
// 1234567890
}

2
go.mod
View File

@ -1,3 +1,3 @@
module github.com/xgfone/bt module github.com/xgfone/go-bt
go 1.11 go 1.11

View File

@ -1,4 +1,4 @@
// Copyright 2020 xgfone // Copyright 2023 xgfone
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -12,18 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package utils package helper
import "io" import "io"
// CopyNBuffer is the same as io.CopyN, but uses the given buf as the buffer. // CopyNBuffer is the same as io.CopyN, but uses the given buf as the buffer.
//
// If buf is nil or empty, it will make a new one with 2048.
func CopyNBuffer(dst io.Writer, src io.Reader, n int64, buf []byte) (written int64, err error) { func CopyNBuffer(dst io.Writer, src io.Reader, n int64, buf []byte) (written int64, err error) {
if len(buf) == 0 {
buf = make([]byte, 2048)
}
written, err = io.CopyBuffer(dst, io.LimitReader(src, n), buf) written, err = io.CopyBuffer(dst, io.LimitReader(src, n), buf)
if written == n { if written == n {
return n, nil return n, nil

View File

@ -1,4 +1,4 @@
// Copyright 2020 xgfone // Copyright 2023 xgfone
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -12,15 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package utils package helper
// InStringSlice reports whether s is in ss. // ContainsString reports whether s is in ss.
func InStringSlice(ss []string, s string) bool { func ContainsString(ss []string, s string) bool {
for _, v := range ss { for _, v := range ss {
if v == s { if v == s {
return true return true
} }
} }
return false return false
} }

View File

@ -1,4 +1,4 @@
// Copyright 2020 xgfone // Copyright 2023 xgfone
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -12,18 +12,16 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package utils package helper
import ( import "testing"
"testing"
)
func TestInStringSlice(t *testing.T) { func TestContainsString(t *testing.T) {
if !InStringSlice([]string{"a", "b"}, "a") { if !ContainsString([]string{"a", "b"}, "a") {
t.Fail() t.Fail()
} }
if InStringSlice([]string{"a", "b"}, "z") { if ContainsString([]string{"a", "b"}, "z") {
t.Fail() t.Fail()
} }
} }

View File

@ -1,4 +1,4 @@
// Copyright 2020 xgfone // Copyright 2023 xgfone
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -12,13 +12,20 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package utils package helper
import "crypto/rand" import (
crand "crypto/rand"
"math/rand"
)
// RandomString generates a size-length string randomly. // RandomString generates a size-length string randomly.
func RandomString(size int) string { func RandomString(size int) string {
bs := make([]byte, size) bs := make([]byte, size)
rand.Read(bs) if n, _ := crand.Read(bs); n < size {
for ; n < size; n++ {
bs[n] = byte(rand.Intn(256))
}
}
return string(bs) return string(bs)
} }

312
krpc/addr.go Normal file
View File

@ -0,0 +1,312 @@
// Copyright 2023 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package krpc
import (
"bytes"
"encoding"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"strconv"
"github.com/xgfone/go-bt/bencode"
)
// Addr represents an address based on ip and port,
// which implements "Compact IP-address/port info".
//
// See http://bittorrent.org/beps/bep_0005.html.
type Addr struct {
IP net.IP // For IPv4, its length must be 4.
Port uint16
// The original network address, which is only used by the DHT server.
Orig net.Addr
}
// ParseAddrs parses the address from the string s with the format "IP:PORT".
func ParseAddrs(s string) (addrs []Addr, err error) {
_ip, _port, err := net.SplitHostPort(s)
if err != nil {
return
}
port, err := strconv.ParseUint(_port, 10, 16)
if err != nil {
return
}
ip := net.ParseIP(_ip)
if ip != nil {
if ipv4 := ip.To4(); ipv4 != nil {
ip = ipv4
}
return []Addr{NewAddr(ip, uint16(port))}, nil
}
ips, err := net.LookupIP(_ip)
if err != nil {
return nil, err
}
addrs = make([]Addr, len(ips))
for i, ip := range ips {
if ipv4 := ip.To4(); ipv4 != nil {
ip = ipv4
}
addrs[i] = NewAddr(ip, uint16(port))
}
return
}
// NewAddr returns a new Addr with ip and port.
func NewAddr(ip net.IP, port uint16) Addr {
return Addr{IP: ip, Port: port}
}
// NewAddrFromUDPAddr converts *net.UDPAddr to a new Addr.
func NewAddrFromUDPAddr(ua *net.UDPAddr) Addr {
return Addr{IP: ua.IP, Port: uint16(ua.Port), Orig: ua}
}
// Valid reports whether the addr is valid.
func (a Addr) Valid() bool {
return len(a.IP) > 0 && a.Port > 0
}
// Equal reports whether a is equal to o.
func (a Addr) Equal(o Addr) bool {
return a.Port == o.Port && a.IP.Equal(o.IP)
}
// UDPAddr converts itself to *net.UDPAddr.
func (a Addr) UDPAddr() *net.UDPAddr {
return &net.UDPAddr{IP: a.IP, Port: int(a.Port)}
}
var _ net.Addr = Addr{}
// Network implements the interface net.Addr#Network.
func (a Addr) Network() string {
return "krpc"
}
func (a Addr) String() string {
if a.Port == 0 {
return a.IP.String()
}
return net.JoinHostPort(a.IP.String(), strconv.FormatUint(uint64(a.Port), 10))
}
// WriteBinary is the same as MarshalBinary, but writes the result into w
// instead of returning.
func (a Addr) WriteBinary(w io.Writer) (n int, err error) {
if n, err = w.Write(a.IP); err == nil {
if err = binary.Write(w, binary.BigEndian, a.Port); err == nil {
n += 2
}
}
return
}
var (
_ encoding.BinaryMarshaler = new(Addr)
_ encoding.BinaryUnmarshaler = new(Addr)
)
// MarshalBinary implements the interface encoding.BinaryMarshaler,
func (a Addr) MarshalBinary() (data []byte, err error) {
buf := bytes.NewBuffer(nil)
buf.Grow(18)
if _, err = a.WriteBinary(buf); err == nil {
data = buf.Bytes()
}
return
}
// UnmarshalBinary implements the interface encoding.BinaryUnmarshaler.
func (a *Addr) UnmarshalBinary(data []byte) error {
_len := len(data) - 2
switch _len {
case net.IPv4len, net.IPv6len:
default:
return errors.New("invalid compact ip-address/port info")
}
a.IP = make(net.IP, _len)
copy(a.IP, data[:_len])
a.Port = binary.BigEndian.Uint16(data[_len:])
return nil
}
var (
_ bencode.Marshaler = new(Addr)
_ bencode.Unmarshaler = new(Addr)
)
// MarshalBencode implements the interface bencode.Marshaler.
func (a Addr) MarshalBencode() (b []byte, err error) {
if b, err = a.MarshalBinary(); err == nil {
b, err = bencode.EncodeBytes(b)
}
return
}
// UnmarshalBencode implements the interface bencode.Unmarshaler.
func (a *Addr) UnmarshalBencode(b []byte) (err error) {
var data []byte
if err = bencode.DecodeBytes(b, &data); err == nil {
err = a.UnmarshalBinary(data)
}
return
}
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
var (
_ bencode.Marshaler = new(CompactIPv4Addrs)
_ bencode.Unmarshaler = new(CompactIPv4Addrs)
_ encoding.BinaryMarshaler = new(CompactIPv4Addrs)
_ encoding.BinaryUnmarshaler = new(CompactIPv4Addrs)
)
// CompactIPv4Addrs is a set of IPv4 Addrs.
type CompactIPv4Addrs []Addr
// MarshalBinary implements the interface encoding.BinaryMarshaler.
func (cas CompactIPv4Addrs) MarshalBinary() ([]byte, error) {
buf := bytes.NewBuffer(nil)
buf.Grow(6 * len(cas))
for _, addr := range cas {
if addr.IP = addr.IP.To4(); len(addr.IP) == 0 {
continue
}
if n, err := addr.WriteBinary(buf); err != nil {
return nil, err
} else if n != 6 {
panic(fmt.Errorf("CompactIPv4Nodes: the invalid node info length '%d'", n))
}
}
return buf.Bytes(), nil
}
// UnmarshalBinary implements the interface encoding.BinaryUnmarshaler.
func (cas *CompactIPv4Addrs) UnmarshalBinary(b []byte) (err error) {
_len := len(b)
if _len%6 != 0 {
return fmt.Errorf("CompactIPv4Addrs: invalid addr info length '%d'", _len)
}
addrs := make(CompactIPv4Addrs, 0, _len/6)
for i := 0; i < _len; i += 6 {
var addr Addr
if err = addr.UnmarshalBinary(b[i : i+6]); err != nil {
return
}
addrs = append(addrs, addr)
}
*cas = addrs
return
}
// MarshalBencode implements the interface bencode.Marshaler.
func (cas CompactIPv4Addrs) MarshalBencode() (b []byte, err error) {
if b, err = cas.MarshalBinary(); err == nil {
b, err = bencode.EncodeBytes(b)
}
return
}
// UnmarshalBencode implements the interface bencode.Unmarshaler.
func (cas *CompactIPv4Addrs) UnmarshalBencode(b []byte) (err error) {
var data []byte
if err = bencode.DecodeBytes(b, &data); err == nil {
err = cas.UnmarshalBinary(data)
}
return
}
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
var (
_ bencode.Marshaler = new(CompactIPv6Addrs)
_ bencode.Unmarshaler = new(CompactIPv6Addrs)
_ encoding.BinaryMarshaler = new(CompactIPv6Addrs)
_ encoding.BinaryUnmarshaler = new(CompactIPv6Addrs)
)
// CompactIPv6Addrs is a set of IPv6 Addrs.
type CompactIPv6Addrs []Addr
// MarshalBinary implements the interface encoding.BinaryMarshaler.
func (cas CompactIPv6Addrs) MarshalBinary() ([]byte, error) {
buf := bytes.NewBuffer(nil)
buf.Grow(18 * len(cas))
for _, addr := range cas {
if addr.IP = addr.IP.To4(); len(addr.IP) == 0 {
continue
}
if n, err := addr.WriteBinary(buf); err != nil {
return nil, err
} else if n != 18 {
panic(fmt.Errorf("CompactIPv4Nodes: the invalid node info length '%d'", n))
}
}
return buf.Bytes(), nil
}
// UnmarshalBinary implements the interface encoding.BinaryUnmarshaler.
func (cas *CompactIPv6Addrs) UnmarshalBinary(b []byte) (err error) {
_len := len(b)
if _len%18 != 0 {
return fmt.Errorf("CompactIPv4Addrs: invalid addr info length '%d'", _len)
}
addrs := make(CompactIPv6Addrs, 0, _len/18)
for i := 0; i < _len; i += 18 {
var addr Addr
if err = addr.UnmarshalBinary(b[i : i+18]); err != nil {
return
}
addrs = append(addrs, addr)
}
*cas = addrs
return
}
// MarshalBencode implements the interface bencode.Marshaler.
func (cas CompactIPv6Addrs) MarshalBencode() (b []byte, err error) {
if b, err = cas.MarshalBinary(); err == nil {
b, err = bencode.EncodeBytes(b)
}
return
}
// UnmarshalBencode implements the interface bencode.Unmarshaler.
func (cas *CompactIPv6Addrs) UnmarshalBencode(b []byte) (err error) {
var data []byte
if err = bencode.DecodeBytes(b, &data); err == nil {
err = cas.UnmarshalBinary(data)
}
return
}

51
krpc/addr_test.go Normal file
View File

@ -0,0 +1,51 @@
// Copyright 2023 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package krpc
import (
"bytes"
"net"
"testing"
"github.com/xgfone/go-bt/bencode"
)
func TestAddr(t *testing.T) {
addrs := []Addr{
{IP: net.ParseIP("172.16.1.1").To4(), Port: 123},
{IP: net.ParseIP("192.168.1.1").To4(), Port: 456},
}
expect := "l6:\xac\x10\x01\x01\x00\x7b6:\xc0\xa8\x01\x01\x01\xc8e"
buf := new(bytes.Buffer)
if err := bencode.NewEncoder(buf).Encode(addrs); err != nil {
t.Error(err)
} else if result := buf.String(); result != expect {
t.Errorf("expect %s, but got %x\n", expect, result)
}
var raddrs []Addr
if err := bencode.DecodeString(expect, &raddrs); err != nil {
t.Error(err)
} else if len(raddrs) != len(addrs) {
t.Errorf("expect addrs length %d, but got %d\n", len(addrs), len(raddrs))
} else {
for i, addr := range addrs {
if !addr.Equal(raddrs[i]) {
t.Errorf("%d: expect addr %v, but got %v\n", i, addr, raddrs[i])
}
}
}
}

View File

@ -1,4 +1,4 @@
// Copyright 2020 xgfone // Copyright 2020~2023 xgfone
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -18,8 +18,8 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"github.com/xgfone/bt/bencode" "github.com/xgfone/go-bt/bencode"
"github.com/xgfone/bt/metainfo" "github.com/xgfone/go-bt/metainfo"
) )
// Predefine some error code. // Predefine some error code.
@ -272,14 +272,14 @@ type ResponseResult struct {
// of the requested ipv4 target. // of the requested ipv4 target.
// //
// find_node // find_node
Nodes CompactIPv4Node `bencode:"nodes,omitempty"` // BEP 5 Nodes CompactIPv4Nodes `bencode:"nodes,omitempty"` // BEP 5
// Nodes6 is a string containing the compact node information for the list // Nodes6 is a string containing the compact node information for the list
// of the ipv6 target node, or the K(8) closest good nodes in routing table // of the ipv6 target node, or the K(8) closest good nodes in routing table
// of the requested ipv6 target. // of the requested ipv6 target.
// //
// find_node // find_node
Nodes6 CompactIPv6Node `bencode:"nodes6,omitempty"` // BEP 32 Nodes6 CompactIPv6Nodes `bencode:"nodes6,omitempty"` // BEP 32
// Token is used for future "announce_peer". // Token is used for future "announce_peer".
// //
@ -288,157 +288,14 @@ type ResponseResult struct {
// Values is a list of the torrent peers. // Values is a list of the torrent peers.
// //
// For each element, in general, it is a compact IP-address/port info and
// may be decoded to Addr, for example,
//
// addrs := make([]Addr, len(values))
// for i, v := range values {
// addrs[i].UnmarshalBinary([]byte(v))
// }
//
// get_peers // get_peers
Values CompactAddresses `bencode:"values,omitempty"` // BEP 5 Values []string `bencode:"values,omitempty"` // BEP 5
}
/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// CompactAddresses represents a group of the compact addresses.
type CompactAddresses []metainfo.Address
// MarshalBinary implements the interface binary.BinaryMarshaler.
func (cas CompactAddresses) MarshalBinary() ([]byte, error) {
ss := make([]string, len(cas))
for i, addr := range cas {
data, err := addr.MarshalBinary()
if err != nil {
return nil, err
}
ss[i] = string(data)
}
return bencode.EncodeBytes(ss)
}
// UnmarshalBinary implements the interface binary.BinaryUnmarshaler.
func (cas *CompactAddresses) UnmarshalBinary(b []byte) (err error) {
var ss []string
if err = bencode.DecodeBytes(b, &ss); err == nil {
addrs := make(CompactAddresses, len(ss))
for i, s := range ss {
var addr metainfo.Address
if err = addr.UnmarshalBinary([]byte(s)); err != nil {
return
}
addrs[i] = addr
}
}
return
}
/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// CompactIPv4Node is a set of IPv4 Nodes.
type CompactIPv4Node []Node
// MarshalBinary implements the interface binary.BinaryMarshaler.
func (cn CompactIPv4Node) MarshalBinary() ([]byte, error) {
buf := bytes.NewBuffer(nil)
buf.Grow(26 * len(cn))
for _, ni := range cn {
if ni.Addr.IP = ni.Addr.IP.To4(); len(ni.Addr.IP) == 0 {
continue
}
if n, err := ni.WriteBinary(buf); err != nil {
return nil, err
} else if n != 26 {
panic(fmt.Errorf("CompactIPv4NodeInfo: the invalid NodeInfo length '%d'", n))
}
}
return buf.Bytes(), nil
}
// UnmarshalBinary implements the interface binary.BinaryUnmarshaler.
func (cn *CompactIPv4Node) UnmarshalBinary(b []byte) (err error) {
_len := len(b)
if _len%26 != 0 {
return fmt.Errorf("CompactIPv4NodeInfo: invalid bytes length '%d'", _len)
}
nis := make([]Node, 0, _len/26)
for i := 0; i < _len; i += 26 {
var ni Node
if err = ni.UnmarshalBinary(b[i : i+26]); err != nil {
return
}
nis = append(nis, ni)
}
*cn = nis
return
}
// MarshalBencode implements the interface bencode.Marshaler.
func (cn CompactIPv4Node) MarshalBencode() (b []byte, err error) {
if b, err = cn.MarshalBinary(); err == nil {
b, err = bencode.EncodeBytes(b)
}
return
}
// UnmarshalBencode implements the interface bencode.Unmarshaler.
func (cn *CompactIPv4Node) UnmarshalBencode(b []byte) (err error) {
var s string
if err = bencode.DecodeBytes(b, &s); err == nil {
err = cn.UnmarshalBinary([]byte(s))
}
return
}
/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// CompactIPv6Node is a set of IPv6 Nodes.
type CompactIPv6Node []Node
// MarshalBinary implements the interface binary.BinaryMarshaler.
func (cn CompactIPv6Node) MarshalBinary() ([]byte, error) {
buf := bytes.NewBuffer(nil)
buf.Grow(38 * len(cn))
for _, ni := range cn {
ni.Addr.IP = ni.Addr.IP.To16()
if n, err := ni.WriteBinary(buf); err != nil {
return nil, err
} else if n != 38 {
panic(fmt.Errorf("CompactIPv6NodeInfo: the invalid NodeInfo length '%d'", n))
}
}
return buf.Bytes(), nil
}
// UnmarshalBinary implements the interface binary.BinaryUnmarshaler.
func (cn *CompactIPv6Node) UnmarshalBinary(b []byte) (err error) {
_len := len(b)
if _len%38 != 0 {
return fmt.Errorf("CompactIPv6NodeInfo: invalid bytes length '%d'", _len)
}
nis := make([]Node, 0, _len/38)
for i := 0; i < _len; i += 38 {
var ni Node
if err = ni.UnmarshalBinary(b[i : i+38]); err != nil {
return
}
nis = append(nis, ni)
}
*cn = nis
return
}
// MarshalBencode implements the interface bencode.Marshaler.
func (cn CompactIPv6Node) MarshalBencode() (b []byte, err error) {
if b, err = cn.MarshalBinary(); err == nil {
b, err = bencode.EncodeBytes(b)
}
return
}
// UnmarshalBencode implements the interface bencode.Unmarshaler.
func (cn *CompactIPv6Node) UnmarshalBencode(b []byte) (err error) {
var s string
if err = bencode.DecodeBytes(b, &s); err == nil {
err = cn.UnmarshalBinary([]byte(s))
}
return
} }

View File

@ -17,7 +17,7 @@ package krpc
import ( import (
"testing" "testing"
"github.com/xgfone/bt/bencode" "github.com/xgfone/go-bt/bencode"
) )
func TestMessage(t *testing.T) { func TestMessage(t *testing.T) {

View File

@ -1,4 +1,4 @@
// Copyright 2020 xgfone // Copyright 2020~2023 xgfone
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -16,29 +16,23 @@ package krpc
import ( import (
"bytes" "bytes"
"encoding"
"fmt" "fmt"
"io" "io"
"net"
"github.com/xgfone/bt/metainfo" "github.com/xgfone/go-bt/bencode"
"github.com/xgfone/go-bt/metainfo"
) )
// Node represents a node information. // Node represents a node information.
type Node struct { type Node struct {
ID metainfo.Hash ID metainfo.Hash
Addr metainfo.Address Addr Addr
} }
// NewNode returns a new Node. // NewNode returns a new Node.
func NewNode(id metainfo.Hash, ip net.IP, port int) Node { func NewNode(id metainfo.Hash, addr Addr) Node {
return Node{ID: id, Addr: metainfo.NewAddress(ip, uint16(port))} return Node{ID: id, Addr: addr}
}
// NewNodeByUDPAddr returns a new Node with the id and the UDP address.
func NewNodeByUDPAddr(id metainfo.Hash, addr *net.UDPAddr) (n Node) {
n.ID = id
n.Addr.FromUDPAddr(addr)
return
} }
func (n Node) String() string { func (n Node) String() string {
@ -63,17 +57,28 @@ func (n Node) WriteBinary(w io.Writer) (m int, err error) {
return return
} }
// MarshalBinary implements the interface binary.BinaryMarshaler. var (
_ encoding.BinaryMarshaler = new(Node)
_ encoding.BinaryUnmarshaler = new(Node)
)
// MarshalBinary implements the interface encoding.BinaryMarshaler,
// which implements "Compact node info".
//
// See http://bittorrent.org/beps/bep_0005.html.
func (n Node) MarshalBinary() (data []byte, err error) { func (n Node) MarshalBinary() (data []byte, err error) {
buf := bytes.NewBuffer(nil) buf := bytes.NewBuffer(nil)
buf.Grow(48) buf.Grow(40)
if _, err = n.WriteBinary(buf); err == nil { if _, err = n.WriteBinary(buf); err == nil {
data = buf.Bytes() data = buf.Bytes()
} }
return return
} }
// UnmarshalBinary implements the interface binary.BinaryUnmarshaler. // UnmarshalBinary implements the interface encoding.BinaryUnmarshaler,
// which implements "Compact node info".
//
// See http://bittorrent.org/beps/bep_0005.html.
func (n *Node) UnmarshalBinary(b []byte) error { func (n *Node) UnmarshalBinary(b []byte) error {
if len(b) < 26 { if len(b) < 26 {
return io.ErrShortBuffer return io.ErrShortBuffer
@ -82,3 +87,133 @@ func (n *Node) UnmarshalBinary(b []byte) error {
copy(n.ID[:], b[:20]) copy(n.ID[:], b[:20])
return n.Addr.UnmarshalBinary(b[20:]) return n.Addr.UnmarshalBinary(b[20:])
} }
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
var (
_ bencode.Marshaler = new(CompactIPv4Nodes)
_ bencode.Unmarshaler = new(CompactIPv4Nodes)
_ encoding.BinaryMarshaler = new(CompactIPv4Nodes)
_ encoding.BinaryUnmarshaler = new(CompactIPv4Nodes)
)
// CompactIPv4Nodes is a set of IPv4 Nodes.
type CompactIPv4Nodes []Node
// MarshalBinary implements the interface encoding.BinaryMarshaler.
func (cns CompactIPv4Nodes) MarshalBinary() ([]byte, error) {
buf := bytes.NewBuffer(nil)
buf.Grow(26 * len(cns))
for _, ni := range cns {
if ni.Addr.IP = ni.Addr.IP.To4(); len(ni.Addr.IP) == 0 {
continue
}
if n, err := ni.WriteBinary(buf); err != nil {
return nil, err
} else if n != 26 {
panic(fmt.Errorf("CompactIPv4Nodes: the invalid node info length '%d'", n))
}
}
return buf.Bytes(), nil
}
// UnmarshalBinary implements the interface encoding.BinaryUnmarshaler.
func (cns *CompactIPv4Nodes) UnmarshalBinary(b []byte) (err error) {
_len := len(b)
if _len%26 != 0 {
return fmt.Errorf("CompactIPv4Nodes: invalid node info length '%d'", _len)
}
nis := make([]Node, 0, _len/26)
for i := 0; i < _len; i += 26 {
var ni Node
if err = ni.UnmarshalBinary(b[i : i+26]); err != nil {
return
}
nis = append(nis, ni)
}
*cns = nis
return
}
// MarshalBencode implements the interface bencode.Marshaler.
func (cns CompactIPv4Nodes) MarshalBencode() (b []byte, err error) {
if b, err = cns.MarshalBinary(); err == nil {
b, err = bencode.EncodeBytes(b)
}
return
}
// UnmarshalBencode implements the interface bencode.Unmarshaler.
func (cns *CompactIPv4Nodes) UnmarshalBencode(b []byte) (err error) {
var data []byte
if err = bencode.DecodeBytes(b, &data); err == nil {
err = cns.UnmarshalBinary(data)
}
return
}
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
var (
_ bencode.Marshaler = new(CompactIPv6Nodes)
_ bencode.Unmarshaler = new(CompactIPv6Nodes)
_ encoding.BinaryMarshaler = new(CompactIPv6Nodes)
_ encoding.BinaryUnmarshaler = new(CompactIPv6Nodes)
)
// CompactIPv6Nodes is a set of IPv6 Nodes.
type CompactIPv6Nodes []Node
// MarshalBinary implements the interface encoding.BinaryMarshaler.
func (cns CompactIPv6Nodes) MarshalBinary() ([]byte, error) {
buf := bytes.NewBuffer(nil)
buf.Grow(38 * len(cns))
for _, ni := range cns {
ni.Addr.IP = ni.Addr.IP.To16()
if n, err := ni.WriteBinary(buf); err != nil {
return nil, err
} else if n != 38 {
panic(fmt.Errorf("CompactIPv6Nodes: the invalid node info length '%d'", n))
}
}
return buf.Bytes(), nil
}
// UnmarshalBinary implements the interface encoding.BinaryUnmarshaler.
func (cns *CompactIPv6Nodes) UnmarshalBinary(b []byte) (err error) {
_len := len(b)
if _len%38 != 0 {
return fmt.Errorf("CompactIPv6Nodes: invalid node info length '%d'", _len)
}
nis := make([]Node, 0, _len/38)
for i := 0; i < _len; i += 38 {
var ni Node
if err = ni.UnmarshalBinary(b[i : i+38]); err != nil {
return
}
nis = append(nis, ni)
}
*cns = nis
return
}
// MarshalBencode implements the interface bencode.Marshaler.
func (cns CompactIPv6Nodes) MarshalBencode() (b []byte, err error) {
if b, err = cns.MarshalBinary(); err == nil {
b, err = bencode.EncodeBytes(b)
}
return
}
// UnmarshalBencode implements the interface bencode.Unmarshaler.
func (cns *CompactIPv6Nodes) UnmarshalBencode(b []byte) (err error) {
var data []byte
if err = bencode.DecodeBytes(b, &data); err == nil {
err = cns.UnmarshalBinary(data)
}
return
}

270
metainfo/addr_compact.go Normal file
View File

@ -0,0 +1,270 @@
// Copyright 2023 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package metainfo
import (
"bytes"
"encoding"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"strconv"
"github.com/xgfone/go-bt/bencode"
)
// CompactAddr represents an address based on ip and port,
// which implements "Compact IP-address/port info".
//
// See http://bittorrent.org/beps/bep_0005.html.
type CompactAddr struct {
IP net.IP // For IPv4, its length must be 4.
Port uint16
}
// NewCompactAddr returns a new compact Addr with ip and port.
func NewCompactAddr(ip net.IP, port uint16) CompactAddr {
return CompactAddr{IP: ip, Port: port}
}
// NewCompactAddrFromUDPAddr converts *net.UDPAddr to a new CompactAddr.
func NewCompactAddrFromUDPAddr(ua *net.UDPAddr) CompactAddr {
return CompactAddr{IP: ua.IP, Port: uint16(ua.Port)}
}
// Valid reports whether the addr is valid.
func (a CompactAddr) Valid() bool {
return len(a.IP) > 0 && a.Port > 0
}
// Equal reports whether a is equal to o.
func (a CompactAddr) Equal(o CompactAddr) bool {
return a.Port == o.Port && a.IP.Equal(o.IP)
}
// UDPAddr converts itself to *net.Addr.
func (a CompactAddr) UDPAddr() *net.UDPAddr {
return &net.UDPAddr{IP: a.IP, Port: int(a.Port)}
}
var _ net.Addr = CompactAddr{}
// Network implements the interface net.Addr#Network.
func (a CompactAddr) Network() string {
return "krpc"
}
func (a CompactAddr) String() string {
if a.Port == 0 {
return a.IP.String()
}
return net.JoinHostPort(a.IP.String(), strconv.FormatUint(uint64(a.Port), 10))
}
// WriteBinary is the same as MarshalBinary, but writes the result into w
// instead of returning.
func (a CompactAddr) WriteBinary(w io.Writer) (n int, err error) {
if n, err = w.Write(a.IP); err == nil {
if err = binary.Write(w, binary.BigEndian, a.Port); err == nil {
n += 2
}
}
return
}
var (
_ encoding.BinaryMarshaler = new(CompactAddr)
_ encoding.BinaryUnmarshaler = new(CompactAddr)
)
// MarshalBinary implements the interface encoding.BinaryMarshaler,
func (a CompactAddr) MarshalBinary() (data []byte, err error) {
buf := bytes.NewBuffer(nil)
buf.Grow(18)
if _, err = a.WriteBinary(buf); err == nil {
data = buf.Bytes()
}
return
}
// UnmarshalBinary implements the interface encoding.BinaryUnmarshaler.
func (a *CompactAddr) UnmarshalBinary(data []byte) error {
_len := len(data) - 2
switch _len {
case net.IPv4len, net.IPv6len:
default:
return errors.New("invalid compact ip-address/port info")
}
a.IP = make(net.IP, _len)
copy(a.IP, data[:_len])
a.Port = binary.BigEndian.Uint16(data[_len:])
return nil
}
var (
_ bencode.Marshaler = new(CompactAddr)
_ bencode.Unmarshaler = new(CompactAddr)
)
// MarshalBencode implements the interface bencode.Marshaler.
func (a CompactAddr) MarshalBencode() (b []byte, err error) {
if b, err = a.MarshalBinary(); err == nil {
b, err = bencode.EncodeBytes(b)
}
return
}
// UnmarshalBencode implements the interface bencode.Unmarshaler.
func (a *CompactAddr) UnmarshalBencode(b []byte) (err error) {
var data []byte
if err = bencode.DecodeBytes(b, &data); err == nil {
err = a.UnmarshalBinary(data)
}
return
}
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
var (
_ bencode.Marshaler = new(CompactIPv4Addrs)
_ bencode.Unmarshaler = new(CompactIPv4Addrs)
_ encoding.BinaryMarshaler = new(CompactIPv4Addrs)
_ encoding.BinaryUnmarshaler = new(CompactIPv4Addrs)
)
// CompactIPv4Addrs is a set of IPv4 Addrs.
type CompactIPv4Addrs []CompactAddr
// MarshalBinary implements the interface encoding.BinaryMarshaler.
func (cas CompactIPv4Addrs) MarshalBinary() ([]byte, error) {
buf := bytes.NewBuffer(nil)
buf.Grow(6 * len(cas))
for _, addr := range cas {
if addr.IP = addr.IP.To4(); len(addr.IP) == 0 {
continue
}
if n, err := addr.WriteBinary(buf); err != nil {
return nil, err
} else if n != 6 {
panic(fmt.Errorf("CompactIPv4Nodes: the invalid node info length '%d'", n))
}
}
return buf.Bytes(), nil
}
// UnmarshalBinary implements the interface encoding.BinaryUnmarshaler.
func (cas *CompactIPv4Addrs) UnmarshalBinary(b []byte) (err error) {
_len := len(b)
if _len%6 != 0 {
return fmt.Errorf("CompactIPv4Addrs: invalid addr info length '%d'", _len)
}
addrs := make(CompactIPv4Addrs, 0, _len/6)
for i := 0; i < _len; i += 6 {
var addr CompactAddr
if err = addr.UnmarshalBinary(b[i : i+6]); err != nil {
return
}
addrs = append(addrs, addr)
}
*cas = addrs
return
}
// MarshalBencode implements the interface bencode.Marshaler.
func (cas CompactIPv4Addrs) MarshalBencode() (b []byte, err error) {
if b, err = cas.MarshalBinary(); err == nil {
b, err = bencode.EncodeBytes(b)
}
return
}
// UnmarshalBencode implements the interface bencode.Unmarshaler.
func (cas *CompactIPv4Addrs) UnmarshalBencode(b []byte) (err error) {
var data []byte
if err = bencode.DecodeBytes(b, &data); err == nil {
err = cas.UnmarshalBinary(data)
}
return
}
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
var (
_ bencode.Marshaler = new(CompactIPv6Addrs)
_ bencode.Unmarshaler = new(CompactIPv6Addrs)
_ encoding.BinaryMarshaler = new(CompactIPv6Addrs)
_ encoding.BinaryUnmarshaler = new(CompactIPv6Addrs)
)
// CompactIPv6Addrs is a set of IPv6 Addrs.
type CompactIPv6Addrs []CompactAddr
// MarshalBinary implements the interface encoding.BinaryMarshaler.
func (cas CompactIPv6Addrs) MarshalBinary() ([]byte, error) {
buf := bytes.NewBuffer(nil)
buf.Grow(18 * len(cas))
for _, addr := range cas {
addr.IP = addr.IP.To16()
if n, err := addr.WriteBinary(buf); err != nil {
return nil, err
} else if n != 18 {
panic(fmt.Errorf("CompactIPv4Nodes: the invalid node info length '%d'", n))
}
}
return buf.Bytes(), nil
}
// UnmarshalBinary implements the interface encoding.BinaryUnmarshaler.
func (cas *CompactIPv6Addrs) UnmarshalBinary(b []byte) (err error) {
_len := len(b)
if _len%18 != 0 {
return fmt.Errorf("CompactIPv4Addrs: invalid addr info length '%d'", _len)
}
addrs := make(CompactIPv6Addrs, 0, _len/18)
for i := 0; i < _len; i += 18 {
var addr CompactAddr
if err = addr.UnmarshalBinary(b[i : i+18]); err != nil {
return
}
addrs = append(addrs, addr)
}
*cas = addrs
return
}
// MarshalBencode implements the interface bencode.Marshaler.
func (cas CompactIPv6Addrs) MarshalBencode() (b []byte, err error) {
if b, err = cas.MarshalBinary(); err == nil {
b, err = bencode.EncodeBytes(b)
}
return
}
// UnmarshalBencode implements the interface bencode.Unmarshaler.
func (cas *CompactIPv6Addrs) UnmarshalBencode(b []byte) (err error) {
var data []byte
if err = bencode.DecodeBytes(b, &data); err == nil {
err = cas.UnmarshalBinary(data)
}
return
}

View File

@ -0,0 +1,51 @@
// Copyright 2023 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package metainfo
import (
"bytes"
"net"
"testing"
"github.com/xgfone/go-bt/bencode"
)
func TestCompactAddr(t *testing.T) {
addrs := []CompactAddr{
{IP: net.ParseIP("172.16.1.1").To4(), Port: 123},
{IP: net.ParseIP("192.168.1.1").To4(), Port: 456},
}
expect := "l6:\xac\x10\x01\x01\x00\x7b6:\xc0\xa8\x01\x01\x01\xc8e"
buf := new(bytes.Buffer)
if err := bencode.NewEncoder(buf).Encode(addrs); err != nil {
t.Error(err)
} else if result := buf.String(); result != expect {
t.Errorf("expect %s, but got %x\n", expect, result)
}
var raddrs []CompactAddr
if err := bencode.DecodeString(expect, &raddrs); err != nil {
t.Error(err)
} else if len(raddrs) != len(addrs) {
t.Errorf("expect addrs length %d, but got %d\n", len(addrs), len(raddrs))
} else {
for i, addr := range addrs {
if !addr.Equal(raddrs[i]) {
t.Errorf("%d: expect addr %v, but got %v\n", i, addr, raddrs[i])
}
}
}
}

115
metainfo/addr_host.go Normal file
View File

@ -0,0 +1,115 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package metainfo
import (
"bytes"
"fmt"
"net"
"strconv"
"github.com/xgfone/go-bt/bencode"
)
// HostAddr represents an address based on host and port.
type HostAddr struct {
Host string
Port uint16
}
// NewHostAddr returns a new host Addr.
func NewHostAddr(host string, port uint16) HostAddr {
return HostAddr{Host: host, Port: port}
}
// ParseHostAddr parses a string s to Addr.
func ParseHostAddr(s string) (HostAddr, error) {
host, port, err := net.SplitHostPort(s)
if err != nil {
return HostAddr{}, err
}
_port, err := strconv.ParseUint(port, 10, 16)
if err != nil {
return HostAddr{}, err
}
return NewHostAddr(host, uint16(_port)), nil
}
func (a HostAddr) String() string {
if a.Port == 0 {
return a.Host
}
return net.JoinHostPort(a.Host, strconv.FormatUint(uint64(a.Port), 10))
}
// Equal reports whether a is equal to o.
func (a HostAddr) Equal(o HostAddr) bool {
return a.Port == o.Port && a.Host == o.Host
}
var (
_ bencode.Marshaler = new(HostAddr)
_ bencode.Unmarshaler = new(HostAddr)
)
// MarshalBencode implements the interface bencode.Marshaler.
func (a HostAddr) MarshalBencode() (b []byte, err error) {
buf := bytes.NewBuffer(nil)
buf.Grow(64)
err = bencode.NewEncoder(buf).Encode([]interface{}{a.Host, a.Port})
if err == nil {
b = buf.Bytes()
}
return
}
// UnmarshalBencode implements the interface bencode.Unmarshaler.
func (a *HostAddr) UnmarshalBencode(b []byte) (err error) {
var iface interface{}
if err = bencode.NewDecoder(bytes.NewBuffer(b)).Decode(&iface); err != nil {
return
}
switch v := iface.(type) {
case string:
*a, err = ParseHostAddr(v)
case []interface{}:
err = a.decode(v)
default:
err = fmt.Errorf("unsupported type: %T", iface)
}
return
}
func (a *HostAddr) decode(vs []interface{}) (err error) {
defer func() {
switch e := recover().(type) {
case nil:
case error:
err = e
default:
err = fmt.Errorf("%v", e)
}
}()
a.Host = vs[0].(string)
a.Port = uint16(vs[1].(int64))
return
}

View File

@ -16,33 +16,33 @@ package metainfo
import ( import (
"testing" "testing"
"github.com/xgfone/go-bt/bencode"
) )
func TestAddress(t *testing.T) { func TestAddress(t *testing.T) {
var addr1 Address addrs := []HostAddr{
if err := addr1.FromString("1.2.3.4:1234"); err != nil { {Host: "1.2.3.4", Port: 123},
{Host: "www.example.com", Port: 456},
}
expect := `ll7:1.2.3.4i123eel15:www.example.comi456eee`
if result, err := bencode.EncodeString(addrs); err != nil {
t.Error(err) t.Error(err)
return } else if result != expect {
t.Errorf("expect %s, but got %s\n", expect, result)
} }
data, err := addr1.MarshalBencode() var raddrs []HostAddr
if err != nil { if err := bencode.DecodeString(expect, &raddrs); err != nil {
t.Error(err) t.Error(err)
return } else if len(raddrs) != len(addrs) {
} else if s := string(data); s != `l7:1.2.3.4i1234ee` { t.Errorf("expect addrs length %d, but got %d\n", len(addrs), len(raddrs))
t.Errorf(`expected 'l7:1.2.3.4i1234ee', but got '%s'`, s) } else {
} for i, addr := range addrs {
if !addr.Equal(raddrs[i]) {
var addr2 Address t.Errorf("%d: expect %v, but got %v\n", i, addr, raddrs[i])
if err = addr2.UnmarshalBencode(data); err != nil { }
t.Error(err) }
} else if addr2.String() != `1.2.3.4:1234` {
t.Errorf("expected '1.2.3.4:1234', but got '%s'", addr2)
}
if data, err = addr2.MarshalBinary(); err != nil {
t.Error(err)
} else if s := string(data); s != "\x01\x02\x03\x04\x04\xd2" {
t.Errorf(`expected '\x01\x02\x03\x04\x04\xd2', but got '%#x'`, s)
} }
} }

View File

@ -1,354 +0,0 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package metainfo
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"net"
"strconv"
"github.com/xgfone/bt/bencode"
)
// ErrInvalidAddr is returned when the compact address is invalid.
var ErrInvalidAddr = fmt.Errorf("invalid compact information of ip and port")
// Address represents a client/server listening on a UDP port implementing
// the DHT protocol.
type Address struct {
IP net.IP // For IPv4, its length must be 4.
Port uint16
}
// NewAddress returns a new Address.
func NewAddress(ip net.IP, port uint16) Address {
if ipv4 := ip.To4(); len(ipv4) > 0 {
ip = ipv4
}
return Address{IP: ip, Port: port}
}
// NewAddressFromString returns a new Address by the address string.
func NewAddressFromString(s string) (addr Address, err error) {
err = addr.FromString(s)
return
}
// NewAddressesFromString returns a list of Addresses by the address string.
func NewAddressesFromString(s string) (addrs []Address, err error) {
shost, sport, err := net.SplitHostPort(s)
if err != nil {
return nil, fmt.Errorf("invalid address '%s': %s", s, err)
}
var port uint16
if sport != "" {
v, err := strconv.ParseUint(sport, 10, 16)
if err != nil {
return nil, fmt.Errorf("invalid address '%s': %s", s, err)
}
port = uint16(v)
}
ips, err := net.LookupIP(shost)
if err != nil {
return nil, fmt.Errorf("fail to lookup the domain '%s': %s", shost, err)
}
addrs = make([]Address, len(ips))
for i, ip := range ips {
if ipv4 := ip.To4(); len(ipv4) != 0 {
addrs[i] = Address{IP: ipv4, Port: port}
} else {
addrs[i] = Address{IP: ip, Port: port}
}
}
return
}
// FromString parses and sets the ip from the string addr.
func (a *Address) FromString(addr string) (err error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return fmt.Errorf("invalid address '%s': %s", addr, err)
}
if port != "" {
v, err := strconv.ParseUint(port, 10, 16)
if err != nil {
return fmt.Errorf("invalid address '%s': %s", addr, err)
}
a.Port = uint16(v)
}
ips, err := net.LookupIP(host)
if err != nil {
return fmt.Errorf("fail to lookup the domain '%s': %s", host, err)
} else if len(ips) == 0 {
return fmt.Errorf("the domain '%s' has no ips", host)
}
a.IP = ips[0]
if ip := a.IP.To4(); len(ip) > 0 {
a.IP = ip
}
return
}
// FromUDPAddr sets the ip from net.UDPAddr.
func (a *Address) FromUDPAddr(ua *net.UDPAddr) {
a.Port = uint16(ua.Port)
a.IP = ua.IP
if ipv4 := a.IP.To4(); len(ipv4) != 0 {
a.IP = ipv4
}
}
// UDPAddr creates a new net.UDPAddr.
func (a Address) UDPAddr() *net.UDPAddr {
return &net.UDPAddr{
IP: a.IP,
Port: int(a.Port),
}
}
func (a Address) String() string {
if a.Port == 0 {
return a.IP.String()
}
return net.JoinHostPort(a.IP.String(), strconv.FormatUint(uint64(a.Port), 10))
}
// Equal reports whether n is equal to o, which is equal to
// n.HasIPAndPort(o.IP, o.Port)
func (a Address) Equal(o Address) bool {
return a.Port == o.Port && a.IP.Equal(o.IP)
}
// HasIPAndPort reports whether the current node has the ip and the port.
func (a Address) HasIPAndPort(ip net.IP, port uint16) bool {
return port == a.Port && a.IP.Equal(ip)
}
// WriteBinary is the same as MarshalBinary, but writes the result into w
// instead of returning.
func (a Address) WriteBinary(w io.Writer) (m int, err error) {
if m, err = w.Write(a.IP); err == nil {
if err = binary.Write(w, binary.BigEndian, a.Port); err == nil {
m += 2
}
}
return
}
// UnmarshalBinary implements the interface binary.BinaryUnmarshaler.
func (a *Address) UnmarshalBinary(b []byte) (err error) {
_len := len(b) - 2
switch _len {
case net.IPv4len, net.IPv6len:
default:
return ErrInvalidAddr
}
a.IP = make(net.IP, _len)
copy(a.IP, b[:_len])
a.Port = binary.BigEndian.Uint16(b[_len:])
return
}
// MarshalBinary implements the interface binary.BinaryMarshaler.
func (a Address) MarshalBinary() (data []byte, err error) {
buf := bytes.NewBuffer(nil)
buf.Grow(20)
if _, err = a.WriteBinary(buf); err == nil {
data = buf.Bytes()
}
return
}
func (a *Address) decode(vs []interface{}) (err error) {
defer func() {
switch e := recover().(type) {
case nil:
case error:
err = e
default:
err = fmt.Errorf("%v", e)
}
}()
host := vs[0].(string)
if a.IP = net.ParseIP(host); len(a.IP) == 0 {
return ErrInvalidAddr
} else if ip := a.IP.To4(); len(ip) != 0 {
a.IP = ip
}
a.Port = uint16(vs[1].(int64))
return
}
// UnmarshalBencode implements the interface bencode.Unmarshaler.
func (a *Address) UnmarshalBencode(b []byte) (err error) {
var iface interface{}
if err = bencode.NewDecoder(bytes.NewBuffer(b)).Decode(&iface); err != nil {
return
}
switch v := iface.(type) {
case string:
err = a.FromString(v)
case []interface{}:
err = a.decode(v)
default:
err = fmt.Errorf("unsupported type: %T", iface)
}
return
}
// MarshalBencode implements the interface bencode.Marshaler.
func (a Address) MarshalBencode() (b []byte, err error) {
buf := bytes.NewBuffer(nil)
buf.Grow(32)
err = bencode.NewEncoder(buf).Encode([]interface{}{a.IP.String(), a.Port})
if err == nil {
b = buf.Bytes()
}
return
}
/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// HostAddress is the same as the Address, but the host part may be
// either a domain or a ip.
type HostAddress struct {
Host string
Port uint16
}
// NewHostAddress returns a new host addrress.
func NewHostAddress(host string, port uint16) HostAddress {
return HostAddress{Host: host, Port: port}
}
// NewHostAddressFromString returns a new host address by the string.
func NewHostAddressFromString(s string) (addr HostAddress, err error) {
err = addr.FromString(s)
return
}
// FromString parses and sets the host from the string addr.
func (a *HostAddress) FromString(addr string) (err error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return fmt.Errorf("invalid address '%s': %s", addr, err)
} else if host == "" {
return fmt.Errorf("invalid address '%s': missing host", addr)
}
if port != "" {
v, err := strconv.ParseUint(port, 10, 16)
if err != nil {
return fmt.Errorf("invalid address '%s': %s", addr, err)
}
a.Port = uint16(v)
}
a.Host = host
return
}
func (a HostAddress) String() string {
if a.Port == 0 {
return a.Host
}
return net.JoinHostPort(a.Host, strconv.FormatUint(uint64(a.Port), 10))
}
// Addresses parses the host address to a list of Addresses.
func (a HostAddress) Addresses() (addrs []Address, err error) {
if ip := net.ParseIP(a.Host); len(ip) != 0 {
return []Address{NewAddress(ip, a.Port)}, nil
}
ips, err := net.LookupIP(a.Host)
if err != nil {
err = fmt.Errorf("fail to lookup the domain '%s': %s", a.Host, err)
} else {
addrs = make([]Address, len(ips))
for i, ip := range ips {
addrs[i] = NewAddress(ip, a.Port)
}
}
return
}
// Equal reports whether a is equal to o.
func (a HostAddress) Equal(o HostAddress) bool {
return a.Port == o.Port && a.Host == o.Host
}
func (a *HostAddress) decode(vs []interface{}) (err error) {
defer func() {
switch e := recover().(type) {
case nil:
case error:
err = e
default:
err = fmt.Errorf("%v", e)
}
}()
a.Host = vs[0].(string)
a.Port = uint16(vs[1].(int64))
return
}
// UnmarshalBencode implements the interface bencode.Unmarshaler.
func (a *HostAddress) UnmarshalBencode(b []byte) (err error) {
var iface interface{}
if err = bencode.NewDecoder(bytes.NewBuffer(b)).Decode(&iface); err != nil {
return
}
switch v := iface.(type) {
case string:
err = a.FromString(v)
case []interface{}:
err = a.decode(v)
default:
err = fmt.Errorf("unsupported type: %T", iface)
}
return
}
// MarshalBencode implements the interface bencode.Marshaler.
func (a HostAddress) MarshalBencode() (b []byte, err error) {
buf := bytes.NewBuffer(nil)
buf.Grow(32)
err = bencode.NewEncoder(buf).Encode([]interface{}{a.Host, a.Port})
if err == nil {
b = buf.Bytes()
}
return
}

View File

@ -14,9 +14,7 @@
package metainfo package metainfo
import ( import "path/filepath"
"path/filepath"
)
// File represents a file in the multi-file case. // File represents a file in the multi-file case.
type File struct { type File struct {

View File

@ -89,10 +89,10 @@ func TestNewInfoFromFilePath(t *testing.T) {
t.Errorf("invalid info %+v\n", info) t.Errorf("invalid info %+v\n", info)
} }
info, err = NewInfoFromFilePath("../../bt", PieceSize256KB) info, err = NewInfoFromFilePath("../../go-bt", PieceSize256KB)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if info.Name != "bt" || info.Files == nil || info.Length > 0 { } else if info.Name != "go-bt" || info.Files == nil || info.Length > 0 {
t.Errorf("invalid info %+v\n", info) t.Errorf("invalid info %+v\n", info)
} }
} }

View File

@ -24,7 +24,7 @@ import (
"fmt" "fmt"
"io" "io"
"github.com/xgfone/bt/bencode" "github.com/xgfone/go-bt/bencode"
) )
var zeroHash Hash var zeroHash Hash
@ -145,10 +145,12 @@ func (h *Hash) FromString(s string) (err error) {
copy(h[:], bs) copy(h[:], bs)
} }
default: default:
err = fmt.Errorf("hash string has bad length: %d", len(s)) hasher := sha1.New()
hasher.Write([]byte(s))
copy(h[:], hasher.Sum(nil))
} }
return nil return
} }
// FromHexString resets the info hash from the hex string. // FromHexString resets the info hash from the hex string.

View File

@ -25,16 +25,15 @@ const xtPrefix = "urn:btih:"
// Peers returns the list of the addresses of the peers. // Peers returns the list of the addresses of the peers.
// //
// See BEP 9 // See BEP 9
func (m Magnet) Peers() (peers []HostAddress, err error) { func (m Magnet) Peers() (peers []HostAddr, err error) {
vs := m.Params["x.pe"] vs := m.Params["x.pe"]
peers = make([]HostAddress, 0, len(vs)) peers = make([]HostAddr, 0, len(vs))
for _, v := range vs { for _, v := range vs {
if v != "" { if v != "" {
var addr HostAddress addr, err := ParseHostAddr(v)
if err = addr.FromString(v); err != nil { if err != nil {
return return nil, err
} }
peers = append(peers, addr) peers = append(peers, addr)
} }
} }

View File

@ -20,8 +20,8 @@ import (
"os" "os"
"strings" "strings"
"github.com/xgfone/bt/bencode" "github.com/xgfone/go-bt/bencode"
"github.com/xgfone/bt/utils" "github.com/xgfone/go-bt/internal/helper"
) )
// Bytes is the []byte type. // Bytes is the []byte type.
@ -35,7 +35,7 @@ func (al AnnounceList) Unique() (announces []string) {
announces = make([]string, 0, len(al)) announces = make([]string, 0, len(al))
for _, tier := range al { for _, tier := range al {
for _, v := range tier { for _, v := range tier {
if v != "" && !utils.InStringSlice(announces, v) { if v != "" && !helper.ContainsString(announces, v) {
announces = append(announces, v) announces = append(announces, v)
} }
} }
@ -93,11 +93,11 @@ func (us *URLList) UnmarshalBencode(b []byte) (err error) {
// MetaInfo represents the .torrent file. // MetaInfo represents the .torrent file.
type MetaInfo struct { type MetaInfo struct {
InfoBytes Bytes `bencode:"info"` // BEP 3 InfoBytes Bytes `bencode:"info"` // BEP 3
Announce string `bencode:"announce,omitempty"` // BEP 3 Announce string `bencode:"announce,omitempty"` // BEP 3, Single Tracker
AnnounceList AnnounceList `bencode:"announce-list,omitempty"` // BEP 12 AnnounceList AnnounceList `bencode:"announce-list,omitempty"` // BEP 12, Multi-Tracker
Nodes []HostAddress `bencode:"nodes,omitempty"` // BEP 5 Nodes []HostAddr `bencode:"nodes,omitempty"` // BEP 5, DHT
URLList URLList `bencode:"url-list,omitempty"` // BEP 19 URLList URLList `bencode:"url-list,omitempty"` // BEP 19, WebSeed
// Where's this specified? // Where's this specified?
// Mentioned at https://wiki.theory.org/index.php/BitTorrentSpecification. // Mentioned at https://wiki.theory.org/index.php/BitTorrentSpecification.

View File

@ -21,9 +21,12 @@ import (
"io" "io"
"sort" "sort"
"github.com/xgfone/bt/utils" "github.com/xgfone/go-bt/internal/helper"
) )
// BlockSize is the default size of a piece block.
const BlockSize = 16 * 1024 // 2^14 = 16KB
// Predefine some sizes of the pieces. // Predefine some sizes of the pieces.
const ( const (
PieceSize256KB = 1024 * 256 PieceSize256KB = 1024 * 256
@ -69,7 +72,7 @@ func GeneratePieces(r io.Reader, pieceLength int64) (hs Hashes, err error) {
buf := make([]byte, pieceLength) buf := make([]byte, pieceLength)
for { for {
h := sha1.New() h := sha1.New()
written, err := utils.CopyNBuffer(h, r, pieceLength, buf) written, err := helper.CopyNBuffer(h, r, pieceLength, buf)
if written > 0 { if written > 0 {
hs = append(hs, NewHash(h.Sum(nil))) hs = append(hs, NewHash(h.Sum(nil)))
} }
@ -92,7 +95,7 @@ func writeFiles(w io.Writer, files []File, open func(File) (io.ReadCloser, error
return fmt.Errorf("error opening %s: %s", file, err) return fmt.Errorf("error opening %s: %s", file, err)
} }
n, err := utils.CopyNBuffer(w, r, file.Length, buf) n, err := helper.CopyNBuffer(w, r, file.Length, buf)
r.Close() r.Close()
if n != file.Length { if n != file.Length {

View File

@ -86,7 +86,7 @@ func (w *writer) Close() error {
return nil return nil
} }
// WriteBlock writes a data block. // WriteBlock writes a block data.
func (w *writer) WriteBlock(pieceIndex, pieceOffset uint32, p []byte) (int, error) { func (w *writer) WriteBlock(pieceIndex, pieceOffset uint32, p []byte) (int, error) {
return w.WriteAt(p, w.info.PieceOffset(pieceIndex, pieceOffset)) return w.WriteAt(p, w.info.PieceOffset(pieceIndex, pieceOffset))
} }

View File

@ -19,7 +19,7 @@ import (
"errors" "errors"
"net" "net"
"github.com/xgfone/bt/bencode" "github.com/xgfone/go-bt/bencode"
) )
var errInvalidIP = errors.New("invalid ipv4 or ipv6") var errInvalidIP = errors.New("invalid ipv4 or ipv6")
@ -86,7 +86,7 @@ type ExtendedHandshakeMsg struct {
// M is the type of map[ExtendedMessageName]ExtendedMessageID. // M is the type of map[ExtendedMessageName]ExtendedMessageID.
M map[string]uint8 `bencode:"m"` // BEP 10 M map[string]uint8 `bencode:"m"` // BEP 10
V string `bencode:"v,omitempty"` // BEP 10 V string `bencode:"v,omitempty"` // BEP 10
Reqq int `bencode:"reqq,omitempty"` // BEP 10 Reqq uint64 `bencode:"reqq,omitempty"` // BEP 10. The default in in libtorrent is 250.
// Port is the local client port, which is redundant and no need // Port is the local client port, which is redundant and no need
// for the receiving side of the connection to send this. // for the receiving side of the connection to send this.
@ -95,7 +95,7 @@ type ExtendedHandshakeMsg struct {
IPv4 CompactIP `bencode:"ipv4,omitempty"` // BEP 10 IPv4 CompactIP `bencode:"ipv4,omitempty"` // BEP 10
YourIP CompactIP `bencode:"yourip,omitempty"` // BEP 10 YourIP CompactIP `bencode:"yourip,omitempty"` // BEP 10
MetadataSize int `bencode:"metadata_size,omitempty"` // BEP 9 MetadataSize uint64 `bencode:"metadata_size,omitempty"` // BEP 9
} }
// Decode decodes the extended handshake message from b. // Decode decodes the extended handshake message from b.
@ -118,7 +118,7 @@ type UtMetadataExtendedMsg struct {
Piece int `bencode:"piece"` // BEP 9 Piece int `bencode:"piece"` // BEP 9
// They are only used by "data" type // They are only used by "data" type
TotalSize int `bencode:"total_size,omitempty"` // BEP 9 TotalSize uint64 `bencode:"total_size,omitempty"` // BEP 9
Data []byte `bencode:"-"` Data []byte `bencode:"-"`
} }
@ -139,10 +139,9 @@ func (um UtMetadataExtendedMsg) EncodeToPayload(buf *bytes.Buffer) (err error) {
// EncodeToBytes is equal to // EncodeToBytes is equal to
// //
// buf := new(bytes.Buffer) // buf := new(bytes.Buffer)
// err = um.EncodeToPayload(buf) // err = um.EncodeToPayload(buf)
// return buf.Bytes(), err // return buf.Bytes(), err
//
func (um UtMetadataExtendedMsg) EncodeToBytes() (b []byte, err error) { func (um UtMetadataExtendedMsg) EncodeToBytes() (b []byte, err error) {
buf := bytes.NewBuffer(make([]byte, 0, 128)) buf := bytes.NewBuffer(make([]byte, 0, 128))
if err = um.EncodeToPayload(buf); err == nil { if err = um.EncodeToPayload(buf); err == nil {

View File

@ -19,16 +19,17 @@ import (
"encoding/binary" "encoding/binary"
"net" "net"
"github.com/xgfone/bt/metainfo" "github.com/xgfone/go-bt/metainfo"
) )
// GenerateAllowedFastSet generates some allowed fast set of the torrent file. // GenerateAllowedFastSet generates some allowed fast set of the torrent file.
// //
// Argument: // Argument:
// set: generated piece set, the length of which is the number to be generated. //
// sz: the number of pieces in torrent. // set: generated piece set, the length of which is the number to be generated.
// ip: the of the remote peer of the connection. // sz: the number of pieces in torrent.
// infohash: infohash of torrent. // ip: the of the remote peer of the connection.
// infohash: infohash of torrent.
// //
// BEP 6 // BEP 6
func GenerateAllowedFastSet(set []uint32, sz uint32, ip net.IP, infohash metainfo.Hash) { func GenerateAllowedFastSet(set []uint32, sz uint32, ip net.IP, infohash metainfo.Hash) {

View File

@ -18,7 +18,7 @@ import (
"net" "net"
"testing" "testing"
"github.com/xgfone/bt/metainfo" "github.com/xgfone/go-bt/metainfo"
) )
func TestGenerateAllowedFastSet(t *testing.T) { func TestGenerateAllowedFastSet(t *testing.T) {

View File

@ -19,7 +19,7 @@ import (
"fmt" "fmt"
"io" "io"
"github.com/xgfone/bt/metainfo" "github.com/xgfone/go-bt/metainfo"
) )
var errInvalidProtocolHeader = fmt.Errorf("unexpected peer protocol header string") var errInvalidProtocolHeader = fmt.Errorf("unexpected peer protocol header string")
@ -44,12 +44,12 @@ func (eb ExtensionBits) String() string {
} }
// Set sets the bit to 1, that's, to set it to be on. // Set sets the bit to 1, that's, to set it to be on.
func (eb ExtensionBits) Set(bit uint) { func (eb *ExtensionBits) Set(bit uint) {
eb[7-bit/8] |= 1 << (bit % 8) eb[7-bit/8] |= 1 << (bit % 8)
} }
// Unset sets the bit to 0, that's, to set it to be off. // Unset sets the bit to 0, that's, to set it to be off.
func (eb ExtensionBits) Unset(bit uint) { func (eb *ExtensionBits) Unset(bit uint) {
eb[7-bit/8] &^= 1 << (bit % 8) eb[7-bit/8] &^= 1 << (bit % 8)
} }

View File

@ -125,18 +125,27 @@ func (bf BitField) Unsets() (pieces Pieces) {
return return
} }
// CanSet reports whether the index can be set.
func (bf BitField) CanSet(index uint32) bool {
return (int(index) / 8) < len(bf)
}
// Set sets the bit of the piece to 1 by its index. // Set sets the bit of the piece to 1 by its index.
func (bf BitField) Set(index uint32) { func (bf BitField) Set(index uint32) (ok bool) {
if i := int(index) / 8; i < len(bf) { if i := int(index) / 8; i < len(bf) {
bf[i] |= (1 << byte(7-index%8)) bf[i] |= (1 << byte(7-index%8))
ok = true
} }
return
} }
// Unset sets the bit of the piece to 0 by its index. // Unset sets the bit of the piece to 0 by its index.
func (bf BitField) Unset(index uint32) { func (bf BitField) Unset(index uint32) (ok bool) {
if i := int(index) / 8; i < len(bf) { if i := int(index) / 8; i < len(bf) {
bf[i] &^= (1 << byte(7-index%8)) bf[i] &^= (1 << byte(7-index%8))
ok = true
} }
return
} }
// IsSet reports whether the bit of the piece is set to 1. // IsSet reports whether the bit of the piece is set to 1.
@ -335,7 +344,9 @@ func (m Message) Encode(buf *bytes.Buffer) (err error) {
if !m.Keepalive { if !m.Keepalive {
if err = buf.WriteByte(byte(m.Type)); err != nil { if err = buf.WriteByte(byte(m.Type)); err != nil {
return return
} else if err = m.marshalBinaryType(buf); err != nil { }
if err = m.marshalBinaryType(buf); err != nil {
return return
} }
@ -376,7 +387,7 @@ func (m Message) marshalBinaryType(buf *bytes.Buffer) (err error) {
} }
_, err = buf.Write(m.Piece) _, err = buf.Write(m.Piece)
case MTypeExtended: case MTypeExtended:
if err = buf.WriteByte(byte(m.ExtendedID)); err != nil { if err = buf.WriteByte(byte(m.ExtendedID)); err == nil {
_, err = buf.Write(m.ExtendedPayload) _, err = buf.Write(m.ExtendedPayload)
} }
case MTypePort: case MTypePort:

View File

@ -14,9 +14,7 @@
package peerprotocol package peerprotocol
import ( import "testing"
"testing"
)
func TestBitField(t *testing.T) { func TestBitField(t *testing.T) {
bf := NewBitFieldFromBools([]bool{ bf := NewBitFieldFromBools([]bool{

View File

@ -21,18 +21,18 @@ import (
"net" "net"
"time" "time"
"github.com/xgfone/bt/bencode" "github.com/xgfone/go-bt/bencode"
"github.com/xgfone/bt/metainfo" "github.com/xgfone/go-bt/metainfo"
) )
// Predefine some errors about extension support. // Predefine some errors about extension support.
var ( var (
ErrChoked = fmt.Errorf("choked") ErrChoked = fmt.Errorf("choked")
ErrNotFirstMsg = fmt.Errorf("not the first message")
ErrNotSupportDHT = fmt.Errorf("not support DHT extension") ErrNotSupportDHT = fmt.Errorf("not support DHT extension")
ErrNotSupportFast = fmt.Errorf("not support Fast extension") ErrNotSupportFast = fmt.Errorf("not support Fast extension")
ErrNotSupportExtended = fmt.Errorf("not support Extended extension") ErrNotSupportExtended = fmt.Errorf("not support Extended extension")
ErrSecondExtHandshake = fmt.Errorf("second extended handshake") ErrSecondExtHandshake = fmt.Errorf("second extended handshake")
ErrNoExtMessageID = fmt.Errorf("no extended message id")
ErrNoExtHandshake = fmt.Errorf("no extended handshake") ErrNoExtHandshake = fmt.Errorf("no extended handshake")
) )
@ -79,6 +79,38 @@ type Bep10Handler interface {
OnPayload(conn *PeerConn, extid uint8, payload []byte) error OnPayload(conn *PeerConn, extid uint8, payload []byte) error
} }
// ConnStage represents the stage of connection to the peer.
type ConnStage int
// IsConnected reports whether the connection stage is connected.
func (s ConnStage) IsConnected() bool { return s == ConnStageConnected }
// IsHandshook reports whether the connection stage is handshook.
func (s ConnStage) IsHandshook() bool { return s == ConnStageHandshook }
// IsMessage reports whether the connection stage is message.
func (s ConnStage) IsMessage() bool { return s == ConnStageMessage }
func (s ConnStage) String() string {
switch s {
case ConnStageConnected:
return "connected"
case ConnStageHandshook:
return "handshook"
case ConnStageMessage:
return "message"
default:
return fmt.Sprintf("ConnStage(%d)", s)
}
}
// Pre-define some connection stages.
const (
ConnStageConnected ConnStage = iota
ConnStageHandshook
ConnStageMessage
)
// PeerConn is used to manage the connection to the peer. // PeerConn is used to manage the connection to the peer.
type PeerConn struct { type PeerConn struct {
net.Conn net.Conn
@ -90,6 +122,7 @@ type PeerConn struct {
PeerID metainfo.Hash // The ID of the remote peer. PeerID metainfo.Hash // The ID of the remote peer.
PeerExtBits ExtensionBits // The extension bits of the remote peer. PeerExtBits ExtensionBits // The extension bits of the remote peer.
PeerStage ConnStage // The connection stage of peer.
// These two states is controlled by the local client peer. // These two states is controlled by the local client peer.
// //
@ -122,9 +155,13 @@ type PeerConn struct {
PeerInterested bool PeerInterested bool
// Timeout is used to control the timeout of reading/writing the message. // Timeout is used to control the timeout of reading/writing the message.
// If WriteTimeout or ReadTimeout is ZERO, try to use Timeout instead.
//
// //
// The default is 0, which represents no timeout. // The default is 0, which represents no timeout.
Timeout time.Duration WriteTimeout time.Duration
ReadTimeout time.Duration
Timeout time.Duration
// MaxLength is used to limit the maximum number of the message body. // MaxLength is used to limit the maximum number of the message body.
// //
@ -147,9 +184,6 @@ type PeerConn struct {
// //
// Optional. // Optional.
OnWriteMsg func(pc *PeerConn, m Message) error OnWriteMsg func(pc *PeerConn, m Message) error
notFirstMsg bool
extHandshake bool
} }
// NewPeerConn returns a new PeerConn. // NewPeerConn returns a new PeerConn.
@ -158,6 +192,7 @@ type PeerConn struct {
// for the peer server, but not for the peer client. // for the peer server, but not for the peer client.
func NewPeerConn(conn net.Conn, id, infohash metainfo.Hash) *PeerConn { func NewPeerConn(conn net.Conn, id, infohash metainfo.Hash) *PeerConn {
return &PeerConn{ return &PeerConn{
PeerStage: ConnStageConnected,
Conn: conn, Conn: conn,
ID: id, ID: id,
InfoHash: infohash, InfoHash: infohash,
@ -168,8 +203,8 @@ func NewPeerConn(conn net.Conn, id, infohash metainfo.Hash) *PeerConn {
} }
// NewPeerConnByDial returns a new PeerConn by dialing to addr with the "tcp" network. // NewPeerConnByDial returns a new PeerConn by dialing to addr with the "tcp" network.
func NewPeerConnByDial(addr string, id, infohash metainfo.Hash, timeout time.Duration) (pc *PeerConn, err error) { func NewPeerConnByDial(addr string, id, infohash metainfo.Hash, dialTimeout time.Duration) (pc *PeerConn, err error) {
conn, err := net.DialTimeout("tcp", addr, timeout) conn, err := net.DialTimeout("tcp", addr, dialTimeout)
if err == nil { if err == nil {
pc = NewPeerConn(conn, id, infohash) pc = NewPeerConn(conn, id, infohash)
} }
@ -177,17 +212,53 @@ func NewPeerConnByDial(addr string, id, infohash metainfo.Hash, timeout time.Dur
} }
func (pc *PeerConn) setReadTimeout() { func (pc *PeerConn) setReadTimeout() {
if pc.Timeout > 0 { switch {
pc.Conn.SetReadDeadline(time.Now().Add(pc.Timeout)) case pc.ReadTimeout > 0:
pc.SetReadTimeout(pc.ReadTimeout)
case pc.Timeout > 0:
pc.SetReadTimeout(pc.Timeout)
} }
} }
func (pc *PeerConn) setWriteTimeout() { func (pc *PeerConn) setWriteTimeout() {
if pc.Timeout > 0 { switch {
pc.Conn.SetWriteDeadline(time.Now().Add(pc.Timeout)) case pc.WriteTimeout > 0:
pc.SetWriteTimeout(pc.WriteTimeout)
case pc.Timeout > 0:
pc.SetWriteTimeout(pc.Timeout)
} }
} }
// SetTimeout is a convenient method to set timeout of the read&write operation.
//
// If timeout is ZERO or negative, clear the read&write timeout.
func (pc *PeerConn) SetTimeout(timeout time.Duration) error {
if timeout > 0 {
pc.Conn.SetDeadline(time.Now().Add(timeout))
}
return pc.Conn.SetDeadline(time.Time{})
}
// SetReadTimeout is a convenient method to set timeout of the read operation.
//
// If timeout is ZERO or negative, clear the read timeout.
func (pc *PeerConn) SetReadTimeout(timeout time.Duration) error {
if timeout > 0 {
pc.Conn.SetReadDeadline(time.Now().Add(timeout))
}
return pc.Conn.SetReadDeadline(time.Time{})
}
// SetWriteTimeout is a convenient method to set timeout of the write operation.
//
// If timeout is ZERO or negative, clear the write timeout.
func (pc *PeerConn) SetWriteTimeout(timeout time.Duration) error {
if timeout > 0 {
pc.Conn.SetWriteDeadline(time.Now().Add(timeout))
}
return pc.Conn.SetWriteDeadline(time.Time{})
}
// SetChoked sets the Choked state of the local client peer. // SetChoked sets the Choked state of the local client peer.
// //
// Notice: if the current state is not Choked, it will send a Choked message // Notice: if the current state is not Choked, it will send a Choked message
@ -244,21 +315,29 @@ func (pc *PeerConn) SetNotInterested() (err error) {
// //
// BEP 3 // BEP 3
func (pc *PeerConn) Handshake() error { func (pc *PeerConn) Handshake() error {
if !pc.PeerStage.IsConnected() {
return fmt.Errorf("the peer connection stage '%s' is not connected", pc.PeerStage)
}
m := HandshakeMsg{ExtensionBits: pc.ExtBits, PeerID: pc.ID, InfoHash: pc.InfoHash} m := HandshakeMsg{ExtensionBits: pc.ExtBits, PeerID: pc.ID, InfoHash: pc.InfoHash}
pc.setReadTimeout() pc.setReadTimeout()
rhm, err := Handshake(pc.Conn, m) rhm, err := Handshake(pc.Conn, m)
if err == nil { if err != nil {
pc.PeerID = rhm.PeerID return err
pc.PeerExtBits = rhm.ExtensionBits
if pc.InfoHash.IsZero() {
pc.InfoHash = rhm.InfoHash
} else if pc.InfoHash != rhm.InfoHash {
return fmt.Errorf("inconsistent infohash: local(%s)=%s, remote(%s)=%s",
pc.Conn.LocalAddr().String(), pc.InfoHash.String(),
pc.Conn.RemoteAddr().String(), rhm.InfoHash.String())
}
} }
return err
pc.PeerID = rhm.PeerID
pc.PeerExtBits = rhm.ExtensionBits
if pc.InfoHash.IsZero() {
pc.InfoHash = rhm.InfoHash
} else if pc.InfoHash != rhm.InfoHash {
return fmt.Errorf("inconsistent infohash: local(%s)=%s, remote(%s)=%s",
pc.Conn.LocalAddr().String(), pc.InfoHash.String(),
pc.Conn.RemoteAddr().String(), rhm.InfoHash.String())
}
pc.PeerStage = ConnStageHandshook
return nil
} }
// ReadMsg reads the message. // ReadMsg reads the message.
@ -267,6 +346,22 @@ func (pc *PeerConn) Handshake() error {
func (pc *PeerConn) ReadMsg() (m Message, err error) { func (pc *PeerConn) ReadMsg() (m Message, err error) {
pc.setReadTimeout() pc.setReadTimeout()
err = m.Decode(pc.Conn, pc.MaxLength) err = m.Decode(pc.Conn, pc.MaxLength)
if err != nil {
return
}
switch m.Type {
case MTypeBitField, MTypeHaveAll, MTypeHaveNone:
if !pc.PeerStage.IsHandshook() {
err = fmt.Errorf("%s is not first message after handshake", m.Type.String())
return
}
}
if pc.PeerStage.IsHandshook() {
pc.PeerStage = ConnStageMessage
}
return return
} }
@ -444,14 +539,19 @@ func (pc *PeerConn) SendExtMsg(extID uint8, payload []byte) error {
}) })
} }
// PeerHasPiece reports whether the peer has the piece.
func (pc *PeerConn) PeerHasPiece(index uint32) bool {
return pc.BitField.IsSet(index)
}
// HandleMessage calls the method of the handler to handle the message. // HandleMessage calls the method of the handler to handle the message.
// //
// If handler has also implemented the interfaces Bep3Handler, Bep5Handler, // If handler has also implemented the interfaces Bep3Handler, Bep5Handler,
// Bep6Handler or Bep10Handler, their methods will be called instead of // Bep6Handler or Bep10Handler, their methods will be called instead of
// Handler.OnMessage for the corresponding type message. // Handler.OnMessage for the corresponding type message.
func (pc *PeerConn) HandleMessage(msg Message, handler Handler) (err error) { func (pc *PeerConn) HandleMessage(msg Message, handler Handler) error {
if msg.Keepalive { if msg.Keepalive {
return return nil
} }
switch msg.Type { switch msg.Type {
@ -459,167 +559,143 @@ func (pc *PeerConn) HandleMessage(msg Message, handler Handler) (err error) {
case MTypeChoke: case MTypeChoke:
pc.PeerChoked = true pc.PeerChoked = true
if h, ok := handler.(Bep3Handler); ok { if h, ok := handler.(Bep3Handler); ok {
err = h.Choke(pc) return h.Choke(pc)
} else {
err = handler.OnMessage(pc, msg)
} }
case MTypeUnchoke: case MTypeUnchoke:
pc.PeerChoked = false pc.PeerChoked = false
if h, ok := handler.(Bep3Handler); ok { if h, ok := handler.(Bep3Handler); ok {
err = h.Unchoke(pc) return h.Unchoke(pc)
} else {
err = handler.OnMessage(pc, msg)
} }
case MTypeInterested: case MTypeInterested:
pc.PeerInterested = true pc.PeerInterested = true
if h, ok := handler.(Bep3Handler); ok { if h, ok := handler.(Bep3Handler); ok {
err = h.Interested(pc) return h.Interested(pc)
} else {
err = handler.OnMessage(pc, msg)
} }
case MTypeNotInterested: case MTypeNotInterested:
pc.PeerInterested = false pc.PeerInterested = false
if h, ok := handler.(Bep3Handler); ok { if h, ok := handler.(Bep3Handler); ok {
err = h.NotInterested(pc) return h.NotInterested(pc)
} else {
err = handler.OnMessage(pc, msg)
} }
case MTypeHave: case MTypeHave:
pc.BitField.Set(msg.Index)
if h, ok := handler.(Bep3Handler); ok { if h, ok := handler.(Bep3Handler); ok {
err = h.Have(pc, msg.Index) return h.Have(pc, msg.Index)
} else {
err = handler.OnMessage(pc, msg)
} }
case MTypeBitField: case MTypeBitField:
if pc.notFirstMsg { pc.BitField = msg.BitField
err = ErrNotFirstMsg if h, ok := handler.(Bep3Handler); ok {
} else { return h.BitField(pc, msg.BitField)
pc.BitField = msg.BitField
if h, ok := handler.(Bep3Handler); ok {
err = h.BitField(pc, msg.BitField)
} else {
err = handler.OnMessage(pc, msg)
}
} }
case MTypeRequest: case MTypeRequest:
if h, ok := handler.(Bep3Handler); ok { if h, ok := handler.(Bep3Handler); ok {
err = h.Request(pc, msg.Index, msg.Begin, msg.Length) return h.Request(pc, msg.Index, msg.Begin, msg.Length)
} else {
err = handler.OnMessage(pc, msg)
} }
case MTypePiece: case MTypePiece:
if h, ok := handler.(Bep3Handler); ok { if h, ok := handler.(Bep3Handler); ok {
err = h.Piece(pc, msg.Index, msg.Begin, msg.Piece) return h.Piece(pc, msg.Index, msg.Begin, msg.Piece)
} else {
err = handler.OnMessage(pc, msg)
} }
case MTypeCancel: case MTypeCancel:
if h, ok := handler.(Bep3Handler); ok { if h, ok := handler.(Bep3Handler); ok {
err = h.Cancel(pc, msg.Index, msg.Begin, msg.Length) return h.Cancel(pc, msg.Index, msg.Begin, msg.Length)
} else {
err = handler.OnMessage(pc, msg)
} }
// BEP 5 - DHT Protocol // BEP 5 - DHT Protocol
case MTypePort: case MTypePort:
if !pc.ExtBits.IsSupportDHT() { if !pc.ExtBits.IsSupportDHT() {
err = ErrNotSupportDHT return ErrNotSupportDHT
} else if h, ok := handler.(Bep5Handler); ok { } else if h, ok := handler.(Bep5Handler); ok {
err = h.Port(pc, msg.Port) return h.Port(pc, msg.Port)
} else {
err = handler.OnMessage(pc, msg)
} }
// BEP 6 - Fast Extension // BEP 6 - Fast Extension
case MTypeSuggest: case MTypeSuggest:
if !pc.ExtBits.IsSupportFast() { if !pc.ExtBits.IsSupportFast() {
err = ErrNotSupportFast return ErrNotSupportFast
} else {
pc.Suggests = pc.Suggests.Append(msg.Index)
if h, ok := handler.(Bep6Handler); ok {
err = h.Suggest(pc, msg.Index)
} else {
err = handler.OnMessage(pc, msg)
}
} }
pc.Suggests = pc.Suggests.Append(msg.Index)
if h, ok := handler.(Bep6Handler); ok {
return h.Suggest(pc, msg.Index)
}
case MTypeHaveAll: case MTypeHaveAll:
if pc.notFirstMsg { if !pc.ExtBits.IsSupportFast() {
err = ErrNotFirstMsg return ErrNotSupportFast
} else if !pc.ExtBits.IsSupportFast() {
err = ErrNotSupportFast
} else if h, ok := handler.(Bep6Handler); ok {
err = h.HaveAll(pc)
} else {
err = handler.OnMessage(pc, msg)
} }
if h, ok := handler.(Bep6Handler); ok {
return h.HaveAll(pc)
}
case MTypeHaveNone: case MTypeHaveNone:
if pc.notFirstMsg { if !pc.ExtBits.IsSupportFast() {
err = ErrNotFirstMsg return ErrNotSupportFast
} else if !pc.ExtBits.IsSupportFast() {
err = ErrNotSupportFast
} else if h, ok := handler.(Bep6Handler); ok {
err = h.HaveNone(pc)
} else {
err = handler.OnMessage(pc, msg)
} }
if h, ok := handler.(Bep6Handler); ok {
return h.HaveNone(pc)
}
case MTypeReject: case MTypeReject:
if !pc.ExtBits.IsSupportFast() { if !pc.ExtBits.IsSupportFast() {
err = ErrNotSupportFast return ErrNotSupportFast
} else if h, ok := handler.(Bep6Handler); ok { } else if h, ok := handler.(Bep6Handler); ok {
err = h.Reject(pc, msg.Index, msg.Begin, msg.Length) return h.Reject(pc, msg.Index, msg.Begin, msg.Length)
} else {
err = handler.OnMessage(pc, msg)
} }
case MTypeAllowedFast: case MTypeAllowedFast:
if !pc.ExtBits.IsSupportFast() { if !pc.ExtBits.IsSupportFast() {
err = ErrNotSupportFast return ErrNotSupportFast
} else { }
pc.Fasts = pc.Fasts.Append(msg.Index)
if h, ok := handler.(Bep6Handler); ok { pc.Fasts = pc.Fasts.Append(msg.Index)
err = h.AllowedFast(pc, msg.Index) if h, ok := handler.(Bep6Handler); ok {
} else { return h.AllowedFast(pc, msg.Index)
err = handler.OnMessage(pc, msg)
}
} }
// BEP 10 - Extension Protocol // BEP 10 - Extension Protocol
case MTypeExtended: case MTypeExtended:
if !pc.ExtBits.IsSupportExtended() { if !pc.ExtBits.IsSupportExtended() {
err = ErrNotSupportExtended return ErrNotSupportExtended
} else if h, ok := handler.(Bep10Handler); ok { } else if h, ok := handler.(Bep10Handler); ok {
err = pc.handleExtMsg(h, msg) return pc.handleExtMsg(h, msg)
} else {
err = handler.OnMessage(pc, msg)
} }
// Other
default: default:
err = handler.OnMessage(pc, msg) // (xgf): Do something??
} }
if !pc.notFirstMsg { return handler.OnMessage(pc, msg)
pc.notFirstMsg = true
}
return
} }
func (pc *PeerConn) handleExtMsg(h Bep10Handler, m Message) (err error) { func (pc *PeerConn) handleExtMsg(h Bep10Handler, m Message) (err error) {
// For Extended Message Handshake
if m.ExtendedID == ExtendedIDHandshake { if m.ExtendedID == ExtendedIDHandshake {
if pc.extHandshake { if len(pc.ExtendedHandshakeMsg.M) > 0 { // The extended handshake has done.
return ErrSecondExtHandshake return ErrSecondExtHandshake
} }
pc.extHandshake = true
err = bencode.DecodeBytes(m.ExtendedPayload, &pc.ExtendedHandshakeMsg) err = bencode.DecodeBytes(m.ExtendedPayload, &pc.ExtendedHandshakeMsg)
if err == nil { if err != nil {
err = h.OnExtHandShake(pc) return
} }
} else if pc.extHandshake {
err = h.OnPayload(pc, m.ExtendedID, m.ExtendedPayload) if len(pc.ExtendedHandshakeMsg.M) == 0 { // The extended message ids must exist.
} else { return ErrNoExtMessageID
err = ErrNoExtHandshake }
return h.OnExtHandShake(pc)
} }
return // For Extended Message
if len(pc.ExtendedHandshakeMsg.M) == 0 {
return ErrNoExtHandshake
}
return h.OnPayload(pc, m.ExtendedID, m.ExtendedPayload)
} }

View File

@ -14,9 +14,7 @@
package peerprotocol package peerprotocol
import ( import "fmt"
"fmt"
)
// ProtocolHeader is the BT protocal prefix. // ProtocolHeader is the BT protocal prefix.
// //

View File

@ -22,7 +22,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/xgfone/bt/metainfo" "github.com/xgfone/go-bt/metainfo"
) )
// Handler is used to handle the incoming peer connection. // Handler is used to handle the incoming peer connection.
@ -65,9 +65,9 @@ type Config struct {
HandleMessage func(pc *PeerConn, msg Message, handler Handler) error HandleMessage func(pc *PeerConn, msg Message, handler Handler) error
} }
func (c *Config) set(conf ...Config) { func (c *Config) set(conf *Config) {
if len(conf) > 0 { if conf != nil {
*c = conf[0] *c = *conf
} }
if c.MaxLength == 0 { if c.MaxLength == 0 {
@ -93,33 +93,35 @@ type Server struct {
} }
// NewServerByListen returns a new Server by listening on the address. // NewServerByListen returns a new Server by listening on the address.
func NewServerByListen(network, address string, id metainfo.Hash, h Handler, func NewServerByListen(network, address string, id metainfo.Hash, h Handler, c *Config) (*Server, error) {
c ...Config) (*Server, error) {
ln, err := net.Listen(network, address) ln, err := net.Listen(network, address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewServer(ln, id, h, c...), nil return NewServer(ln, id, h, c), nil
} }
// NewServer returns a new Server. // NewServer returns a new Server.
func NewServer(ln net.Listener, id metainfo.Hash, h Handler, c ...Config) *Server { func NewServer(ln net.Listener, id metainfo.Hash, h Handler, c *Config) *Server {
if id.IsZero() { if id.IsZero() {
panic("the peer node id must not be empty") panic("the peer node id must not be empty")
} }
var conf Config var conf Config
conf.set(c...) conf.set(c)
return &Server{Listener: ln, ID: id, Handler: h, Config: conf} return &Server{Listener: ln, ID: id, Handler: h, Config: conf}
} }
// Run starts the peer protocol server. // Run starts the peer protocol server.
func (s *Server) Run() { func (s *Server) Run() {
s.Config.set() s.Config.set(nil)
for { for {
conn, err := s.Listener.Accept() conn, err := s.Listener.Accept()
if err != nil { if err != nil {
s.Config.ErrorLog("fail to accept new connection: %s", err) if !strings.Contains(err.Error(), "closed") {
s.Config.ErrorLog("fail to accept new connection: %s", err)
}
return
} }
go s.handleConn(conn) go s.handleConn(conn)
} }

View File

@ -1,4 +1,4 @@
// Copyright 2020 xgfone // Copyright 2020~2023 xgfone
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -16,81 +16,20 @@ package tracker
import ( import (
"context" "context"
"net/url"
"sync"
"github.com/xgfone/bt/metainfo" "github.com/xgfone/go-bt/metainfo"
) )
// GetPeersResult represents the result of getting the peers from the tracker. // GetPeers gets the peers from the tracker.
type GetPeersResult struct { func GetPeers(ctx context.Context, tracker string, nodeID, infoHash metainfo.Hash, totalLength int64) (AnnounceResponse, error) {
Error error // nil stands for success. Or, for failure. client, err := NewClient(tracker, nodeID, nil)
Tracker string if err != nil {
Resp AnnounceResponse return AnnounceResponse{}, err
} }
// GetPeers gets the peers from the trackers. return client.Announce(ctx, AnnounceRequest{
// Left: totalLength,
// Notice: the returned chan will be closed when all the requests end. InfoHash: infoHash,
func GetPeers(ctx context.Context, id, infohash metainfo.Hash, trackers []string) []GetPeersResult { Port: 6881,
if len(trackers) == 0 { })
return nil
}
for i, t := range trackers {
if u, err := url.Parse(t); err == nil && u.Path == "" {
u.Path = "/announce"
trackers[i] = u.String()
}
}
_len := len(trackers)
wlen := _len
if wlen > 10 {
wlen = 10
}
reqs := make(chan string, wlen)
go func() {
for i := 0; i < _len; i++ {
reqs <- trackers[i]
}
}()
wg := new(sync.WaitGroup)
wg.Add(_len)
var lock sync.Mutex
results := make([]GetPeersResult, 0, _len)
for i := 0; i < wlen; i++ {
go func() {
for tracker := range reqs {
resp, err := getPeers(ctx, wg, tracker, id, infohash)
lock.Lock()
results = append(results, GetPeersResult{
Tracker: tracker,
Error: err,
Resp: resp,
})
lock.Unlock()
}
}()
}
wg.Wait()
close(reqs)
return results
}
func getPeers(ctx context.Context, wg *sync.WaitGroup, tracker string,
nodeID, infoHash metainfo.Hash) (resp AnnounceResponse, err error) {
defer wg.Done()
client, err := NewClient(tracker, ClientConfig{ID: nodeID})
if err == nil {
resp, err = client.Announce(ctx, AnnounceRequest{InfoHash: infoHash})
}
return
} }

View File

@ -28,8 +28,8 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/xgfone/bt/bencode" "github.com/xgfone/go-bt/bencode"
"github.com/xgfone/bt/metainfo" "github.com/xgfone/go-bt/metainfo"
) )
// AnnounceRequest is the tracker announce requests. // AnnounceRequest is the tracker announce requests.
@ -101,7 +101,7 @@ type AnnounceRequest struct {
// ToQuery converts the Request to URL Query. // ToQuery converts the Request to URL Query.
func (r AnnounceRequest) ToQuery() (vs url.Values) { func (r AnnounceRequest) ToQuery() (vs url.Values) {
vs = make(url.Values, 9) vs = make(url.Values, 10)
vs.Set("info_hash", r.InfoHash.BytesString()) vs.Set("info_hash", r.InfoHash.BytesString())
vs.Set("peer_id", r.PeerID.BytesString()) vs.Set("peer_id", r.PeerID.BytesString())
vs.Set("uploaded", strconv.FormatInt(r.Uploaded, 10)) vs.Set("uploaded", strconv.FormatInt(r.Uploaded, 10))
@ -278,11 +278,15 @@ type Client struct {
// //
// scrapeURL may be empty, which will replace the "announce" in announceURL // scrapeURL may be empty, which will replace the "announce" in announceURL
// with "scrape" to generate the scrapeURL. // with "scrape" to generate the scrapeURL.
func NewClient(announceURL, scrapeURL string) *Client { func NewClient(id metainfo.Hash, announceURL, scrapeURL string) *Client {
if scrapeURL == "" { if scrapeURL == "" {
scrapeURL = strings.Replace(announceURL, "announce", "scrape", -1) scrapeURL = strings.Replace(announceURL, "announce", "scrape", -1)
} }
id := metainfo.NewRandomHash()
if id.IsZero() {
id = metainfo.NewRandomHash()
}
return &Client{AnnounceURL: announceURL, ScrapeURL: scrapeURL, ID: id} return &Client{AnnounceURL: announceURL, ScrapeURL: scrapeURL, ID: id}
} }
@ -310,7 +314,7 @@ func (t *Client) send(c context.Context, u string, vs url.Values, r interface{})
resp, err = t.Client.Do(req) resp, err = t.Client.Do(req)
} }
if resp.Body != nil { if resp != nil {
defer resp.Body.Close() defer resp.Body.Close()
} }
@ -323,14 +327,11 @@ func (t *Client) send(c context.Context, u string, vs url.Values, r interface{})
// Announce sends a Announce request to the tracker. // Announce sends a Announce request to the tracker.
func (t *Client) Announce(c context.Context, req AnnounceRequest) (resp AnnounceResponse, err error) { func (t *Client) Announce(c context.Context, req AnnounceRequest) (resp AnnounceResponse, err error) {
if req.PeerID.IsZero() { if req.InfoHash.IsZero() {
if t.ID.IsZero() { panic("infohash is ZERO")
req.PeerID = metainfo.NewRandomHash()
} else {
req.PeerID = t.ID
}
} }
req.PeerID = t.ID
err = t.send(c, t.AnnounceURL, req.ToQuery(), &resp) err = t.send(c, t.AnnounceURL, req.ToQuery(), &resp)
return return
} }

View File

@ -15,43 +15,26 @@
package httptracker package httptracker
import ( import (
"bytes"
"encoding/binary"
"errors" "errors"
"net" "net"
"github.com/xgfone/bt/bencode" "github.com/xgfone/go-bt/bencode"
"github.com/xgfone/bt/metainfo" "github.com/xgfone/go-bt/metainfo"
) )
var errInvalidPeer = errors.New("invalid peer information format") var errInvalidPeer = errors.New("invalid bt peer information format")
// Peer is a tracker peer. // Peer is a tracker peer.
type Peer struct { type Peer struct {
// ID is the peer's self-selected ID. ID string `bencode:"peer id"` // BEP 3, the peer's self-selected ID.
ID string `bencode:"peer id"` // BEP 3 IP string `bencode:"ip"` // BEP 3, an IP address or dns name.
Port uint16 `bencode:"port"` // BEP 3
// IP is the IP address or dns name.
IP string `bencode:"ip"` // BEP 3
Port uint16 `bencode:"port"` // BEP 3
} }
// Addresses returns the list of the addresses that the peer listens on. var (
func (p Peer) Addresses() (addrs []metainfo.Address, err error) { _ bencode.Marshaler = new(Peers)
if ip := net.ParseIP(p.IP); len(ip) != 0 { _ bencode.Unmarshaler = new(Peers)
return []metainfo.Address{{IP: ip, Port: p.Port}}, nil )
}
ips, err := net.LookupIP(p.IP)
if _len := len(ips); err == nil && len(ips) != 0 {
addrs = make([]metainfo.Address, _len)
for i, ip := range ips {
addrs[i] = metainfo.Address{IP: ip, Port: p.Port}
}
}
return
}
// Peers is a set of the peers. // Peers is a set of the peers.
type Peers []Peer type Peers []Peer
@ -65,21 +48,17 @@ func (ps *Peers) UnmarshalBencode(b []byte) (err error) {
switch vs := v.(type) { switch vs := v.(type) {
case string: // BEP 23 case string: // BEP 23
_len := len(vs) var addrs metainfo.CompactIPv4Addrs
if _len%6 != 0 { if err = addrs.UnmarshalBinary([]byte(vs)); err != nil {
return metainfo.ErrInvalidAddr return err
} }
peers := make(Peers, 0, _len/6) peers := make(Peers, len(addrs))
for i := 0; i < _len; i += 6 { for i, addr := range addrs {
var addr metainfo.Address peers[i] = Peer{IP: addr.IP.String(), Port: addr.Port}
if err = addr.UnmarshalBinary([]byte(vs[i : i+6])); err != nil {
return
}
peers = append(peers, Peer{IP: addr.IP.String(), Port: addr.Port})
} }
*ps = peers *ps = peers
case []interface{}: // BEP 3 case []interface{}: // BEP 3
peers := make(Peers, len(vs)) peers := make(Peers, len(vs))
for i, p := range vs { for i, p := range vs {
@ -106,6 +85,7 @@ func (ps *Peers) UnmarshalBencode(b []byte) (err error) {
peers[i] = Peer{ID: pid, IP: ip, Port: uint16(port)} peers[i] = Peer{ID: pid, IP: ip, Port: uint16(port)}
} }
*ps = peers *ps = peers
default: default:
return errInvalidPeer return errInvalidPeer
} }
@ -114,38 +94,32 @@ func (ps *Peers) UnmarshalBencode(b []byte) (err error) {
// MarshalBencode implements the interface bencode.Marshaler. // MarshalBencode implements the interface bencode.Marshaler.
func (ps Peers) MarshalBencode() (b []byte, err error) { func (ps Peers) MarshalBencode() (b []byte, err error) {
for _, p := range ps { // BEP 23
if p.ID == "" { if b, err = ps.marshalCompactBencode(); err == nil {
return ps.marshalCompactBencode() // BEP 23 return
}
} }
// BEP 3 // BEP 3
buf := bytes.NewBuffer(make([]byte, 0, 64*len(ps))) return bencode.EncodeBytes([]Peer(ps))
buf.WriteByte('l')
for _, p := range ps {
if err = bencode.NewEncoder(buf).Encode(p); err != nil {
return
}
}
buf.WriteByte('e')
b = buf.Bytes()
return
} }
func (ps Peers) marshalCompactBencode() (b []byte, err error) { func (ps Peers) marshalCompactBencode() (b []byte, err error) {
buf := bytes.NewBuffer(make([]byte, 0, 6*len(ps))) addrs := make(metainfo.CompactIPv4Addrs, len(ps))
for _, peer := range ps { for i, p := range ps {
ip := net.ParseIP(peer.IP).To4() ip := net.ParseIP(p.IP).To4()
if len(ip) == 0 { if ip == nil {
return nil, errInvalidPeer return nil, errInvalidPeer
} }
buf.Write(ip[:]) addrs[i] = metainfo.CompactAddr{IP: ip, Port: p.Port}
binary.Write(buf, binary.BigEndian, peer.Port)
} }
return bencode.EncodeBytes(buf.Bytes()) return addrs.MarshalBencode()
} }
var (
_ bencode.Marshaler = new(Peers6)
_ bencode.Unmarshaler = new(Peers6)
)
// Peers6 is a set of the peers for IPv6 in the compact case. // Peers6 is a set of the peers for IPv6 in the compact case.
// //
// BEP 7 // BEP 7
@ -158,35 +132,29 @@ func (ps *Peers6) UnmarshalBencode(b []byte) (err error) {
return return
} }
_len := len(s) var addrs metainfo.CompactIPv6Addrs
if _len%18 != 0 { if err = addrs.UnmarshalBinary([]byte(s)); err != nil {
return metainfo.ErrInvalidAddr return err
} }
peers := make(Peers6, 0, _len/18) peers := make(Peers6, len(addrs))
for i := 0; i < _len; i += 18 { for i, addr := range addrs {
var addr metainfo.Address peers[i] = Peer{IP: addr.IP.String(), Port: addr.Port}
if err = addr.UnmarshalBinary([]byte(s[i : i+18])); err != nil {
return
}
peers = append(peers, Peer{IP: addr.IP.String(), Port: addr.Port})
} }
*ps = peers *ps = peers
return return
} }
// MarshalBencode implements the interface bencode.Marshaler. // MarshalBencode implements the interface bencode.Marshaler.
func (ps Peers6) MarshalBencode() (b []byte, err error) { func (ps Peers6) MarshalBencode() (b []byte, err error) {
buf := bytes.NewBuffer(make([]byte, 0, 18*len(ps))) addrs := make(metainfo.CompactIPv6Addrs, len(ps))
for _, peer := range ps { for i, p := range ps {
ip := net.ParseIP(peer.IP).To16() ip := net.ParseIP(p.IP)
if len(ip) == 0 { if ip == nil {
return nil, errInvalidPeer return nil, errInvalidPeer
} }
addrs[i] = metainfo.CompactAddr{IP: ip, Port: p.Port}
buf.Write(ip[:])
binary.Write(buf, binary.BigEndian, peer.Port)
} }
return bencode.EncodeBytes(buf.Bytes()) return addrs.MarshalBencode()
} }

View File

@ -21,8 +21,8 @@ import (
func TestPeers(t *testing.T) { func TestPeers(t *testing.T) {
peers := Peers{ peers := Peers{
{ID: "123", IP: "1.1.1.1", Port: 80}, {IP: "1.1.1.1", Port: 80},
{ID: "456", IP: "2.2.2.2", Port: 81}, {IP: "2.2.2.2", Port: 81},
} }
b, err := peers.MarshalBencode() b, err := peers.MarshalBencode()
@ -71,5 +71,6 @@ func TestPeers6(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} else if !reflect.DeepEqual(ps, peers) { } else if !reflect.DeepEqual(ps, peers) {
t.Errorf("%v != %v", ps, peers) t.Errorf("%v != %v", ps, peers)
t.Error(string(b))
} }
} }

View File

@ -18,7 +18,7 @@ import (
"reflect" "reflect"
"testing" "testing"
"github.com/xgfone/bt/metainfo" "github.com/xgfone/go-bt/metainfo"
) )
func TestHTTPAnnounceRequest(t *testing.T) { func TestHTTPAnnounceRequest(t *testing.T) {

View File

@ -1,4 +1,4 @@
// Copyright 2020 xgfone // Copyright 2020~2023 xgfone
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -26,9 +26,9 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"github.com/xgfone/bt/metainfo" "github.com/xgfone/go-bt/metainfo"
"github.com/xgfone/bt/tracker/httptracker" "github.com/xgfone/go-bt/tracker/httptracker"
"github.com/xgfone/bt/tracker/udptracker" "github.com/xgfone/go-bt/tracker/udptracker"
) )
// Predefine some announce events. // Predefine some announce events.
@ -55,7 +55,7 @@ type AnnounceRequest struct {
IP net.IP // Optional IP net.IP // Optional
Key int32 // Optional Key int32 // Optional
NumWant int32 // Optional, BEP 15: -1 for default. But we use 0 as default. NumWant int32 // Optional
Port uint16 // Optional Port uint16 // Optional
} }
@ -77,6 +77,7 @@ func (ar AnnounceRequest) ToHTTPAnnounceRequest() httptracker.AnnounceRequest {
Event: ar.Event, Event: ar.Event,
NumWant: ar.NumWant, NumWant: ar.NumWant,
Key: ar.Key, Key: ar.Key,
Compact: true,
} }
} }
@ -89,7 +90,6 @@ func (ar AnnounceRequest) ToUDPAnnounceRequest() udptracker.AnnounceRequest {
Left: ar.Left, Left: ar.Left,
Uploaded: ar.Uploaded, Uploaded: ar.Uploaded,
Event: ar.Event, Event: ar.Event,
IP: ar.IP,
Key: ar.Key, Key: ar.Key,
NumWant: ar.NumWant, NumWant: ar.NumWant,
Port: ar.Port, Port: ar.Port,
@ -100,10 +100,10 @@ func (ar AnnounceRequest) ToUDPAnnounceRequest() udptracker.AnnounceRequest {
// //
// BEP 3, 15 // BEP 3, 15
type AnnounceResponse struct { type AnnounceResponse struct {
Interval uint32 Interval uint32 // Reflush Interval
Leechers uint32 Leechers uint32 // Incomplete
Seeders uint32 Seeders uint32 // Complete
Addresses []metainfo.Address Addresses []metainfo.HostAddr
} }
// FromHTTPAnnounceResponse sets itself from r. // FromHTTPAnnounceResponse sets itself from r.
@ -111,14 +111,12 @@ func (ar *AnnounceResponse) FromHTTPAnnounceResponse(r httptracker.AnnounceRespo
ar.Interval = r.Interval ar.Interval = r.Interval
ar.Leechers = r.Incomplete ar.Leechers = r.Incomplete
ar.Seeders = r.Complete ar.Seeders = r.Complete
ar.Addresses = make([]metainfo.Address, 0, len(r.Peers)+len(r.Peers6)) ar.Addresses = make([]metainfo.HostAddr, 0, len(r.Peers)+len(r.Peers6))
for _, peer := range r.Peers { for _, p := range r.Peers {
addrs, _ := peer.Addresses() ar.Addresses = append(ar.Addresses, metainfo.NewHostAddr(p.IP, p.Port))
ar.Addresses = append(ar.Addresses, addrs...)
} }
for _, peer := range r.Peers6 { for _, p := range r.Peers6 {
addrs, _ := peer.Addresses() ar.Addresses = append(ar.Addresses, metainfo.NewHostAddr(p.IP, p.Port))
ar.Addresses = append(ar.Addresses, addrs...)
} }
} }
@ -127,7 +125,11 @@ func (ar *AnnounceResponse) FromUDPAnnounceResponse(r udptracker.AnnounceRespons
ar.Interval = r.Interval ar.Interval = r.Interval
ar.Leechers = r.Leechers ar.Leechers = r.Leechers
ar.Seeders = r.Seeders ar.Seeders = r.Seeders
ar.Addresses = r.Addresses
ar.Addresses = make([]metainfo.HostAddr, len(r.Addresses))
for i, a := range r.Addresses {
ar.Addresses[i] = metainfo.NewHostAddr(a.IP.String(), a.Port)
}
} }
// ScrapeResponseResult is a commont Scrape response result. // ScrapeResponseResult is a commont Scrape response result.
@ -171,8 +173,7 @@ func (sr ScrapeResponse) FromHTTPScrapeResponse(r httptracker.ScrapeResponse) {
} }
// FromUDPScrapeResponse sets itself from hs and r. // FromUDPScrapeResponse sets itself from hs and r.
func (sr ScrapeResponse) FromUDPScrapeResponse(hs []metainfo.Hash, func (sr ScrapeResponse) FromUDPScrapeResponse(hs []metainfo.Hash, r []udptracker.ScrapeResponse) {
r []udptracker.ScrapeResponse) {
klen := len(hs) klen := len(hs)
if _len := len(r); _len < klen { if _len := len(r); _len < klen {
klen = _len klen = _len
@ -195,48 +196,37 @@ type Client interface {
Close() error Close() error
} }
// ClientConfig is used to configure the defalut client implementation.
type ClientConfig struct {
// The ID of the local client peer.
ID metainfo.Hash
// The http client used only the tracker client is based on HTTP.
HTTPClient *http.Client
}
// NewClient returns a new Client. // NewClient returns a new Client.
func NewClient(connURL string, conf ...ClientConfig) (c Client, err error) { //
var config ClientConfig // If id is ZERO, use a random hash instead.
if len(conf) > 0 { // If client is nil, use http.DefaultClient instead for the http tracker.
config = conf[0] func NewClient(connURL string, id metainfo.Hash, client *http.Client) (c Client, err error) {
}
u, err := url.Parse(connURL) u, err := url.Parse(connURL)
if err == nil { if err != nil {
switch u.Scheme { return
case "http", "https":
tracker := httptracker.NewClient(connURL, "")
if !config.ID.IsZero() {
tracker.ID = config.ID
}
c = &tclient{url: connURL, http: tracker}
case "udp", "udp4", "udp6":
var utc *udptracker.Client
config := udptracker.ClientConfig{ID: config.ID}
utc, err = udptracker.NewClientByDial(u.Scheme, u.Host, config)
if err == nil {
var e []udptracker.Extension
if p := u.RequestURI(); p != "" {
e = []udptracker.Extension{udptracker.NewURLData([]byte(p))}
}
c = &tclient{url: connURL, exts: e, udp: utc}
}
default:
err = fmt.Errorf("unknown url scheme '%s'", u.Scheme)
}
} }
return
tclient := &tclient{url: connURL}
switch u.Scheme {
case "http", "https":
tclient.http = httptracker.NewClient(id, connURL, "")
tclient.http.Client = client
case "udp", "udp4", "udp6":
tclient.udp, err = udptracker.NewClientByDial(u.Scheme, u.Host, id)
if err != nil {
return
}
if p := u.RequestURI(); p != "" {
tclient.exts = []udptracker.Extension{udptracker.NewURLData([]byte(p))}
}
default:
err = fmt.Errorf("unknown url scheme '%s'", u.Scheme)
}
return tclient, nil
} }
type tclient struct { type tclient struct {

View File

@ -1,4 +1,4 @@
// Copyright 2020 xgfone // Copyright 2020~2023 xgfone
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -22,15 +22,14 @@ import (
"net" "net"
"time" "time"
"github.com/xgfone/bt/metainfo" "github.com/xgfone/go-bt/metainfo"
"github.com/xgfone/bt/tracker/udptracker" "github.com/xgfone/go-bt/tracker/udptracker"
) )
type testHandler struct{} type testHandler struct{}
func (testHandler) OnConnect(raddr *net.UDPAddr) (err error) { return } func (testHandler) OnConnect(raddr *net.UDPAddr) (err error) { return }
func (testHandler) OnAnnounce(raddr *net.UDPAddr, req udptracker.AnnounceRequest) ( func (testHandler) OnAnnounce(raddr *net.UDPAddr, req udptracker.AnnounceRequest) (r udptracker.AnnounceResponse, err error) {
r udptracker.AnnounceResponse, err error) {
if req.Port != 80 { if req.Port != 80 {
err = errors.New("port is not 80") err = errors.New("port is not 80")
return return
@ -51,12 +50,11 @@ func (testHandler) OnAnnounce(raddr *net.UDPAddr, req udptracker.AnnounceRequest
Interval: 1, Interval: 1,
Leechers: 2, Leechers: 2,
Seeders: 3, Seeders: 3,
Addresses: []metainfo.Address{{IP: net.ParseIP("127.0.0.1"), Port: 8000}}, Addresses: []metainfo.CompactAddr{{IP: net.ParseIP("127.0.0.1"), Port: 8000}},
} }
return return
} }
func (testHandler) OnScrap(raddr *net.UDPAddr, infohashes []metainfo.Hash) ( func (testHandler) OnScrap(raddr *net.UDPAddr, infohashes []metainfo.Hash) (rs []udptracker.ScrapeResponse, err error) {
rs []udptracker.ScrapeResponse, err error) {
rs = make([]udptracker.ScrapeResponse, len(infohashes)) rs = make([]udptracker.ScrapeResponse, len(infohashes))
for i := range infohashes { for i := range infohashes {
rs[i] = udptracker.ScrapeResponse{ rs[i] = udptracker.ScrapeResponse{
@ -74,7 +72,7 @@ func ExampleClient() {
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
server := udptracker.NewServer(sconn, testHandler{}) server := udptracker.NewServer(sconn, testHandler{}, 0)
defer server.Close() defer server.Close()
go server.Run() go server.Run()
@ -82,14 +80,14 @@ func ExampleClient() {
time.Sleep(time.Second) time.Sleep(time.Second)
// Create a client and dial to the UDP tracker server. // Create a client and dial to the UDP tracker server.
client, err := NewClient("udp://127.0.0.1:8000/path?a=1&b=2") client, err := NewClient("udp://127.0.0.1:8000/path?a=1&b=2", metainfo.Hash{}, nil)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
// Send the ANNOUNCE request to the UDP tracker server, // Send the ANNOUNCE request to the UDP tracker server,
// and get the ANNOUNCE response. // and get the ANNOUNCE response.
req := AnnounceRequest{IP: net.ParseIP("127.0.0.1"), Port: 80} req := AnnounceRequest{InfoHash: metainfo.NewRandomHash(), IP: net.ParseIP("127.0.0.1"), Port: 80}
resp, err := client.Announce(context.Background(), req) resp, err := client.Announce(context.Background(), req)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
@ -99,7 +97,7 @@ func ExampleClient() {
fmt.Printf("Leechers: %d\n", resp.Leechers) fmt.Printf("Leechers: %d\n", resp.Leechers)
fmt.Printf("Seeders: %d\n", resp.Seeders) fmt.Printf("Seeders: %d\n", resp.Seeders)
for i, addr := range resp.Addresses { for i, addr := range resp.Addresses {
fmt.Printf("Address[%d].IP: %s\n", i, addr.IP.String()) fmt.Printf("Address[%d].IP: %s\n", i, addr.Host)
fmt.Printf("Address[%d].Port: %d\n", i, addr.Port) fmt.Printf("Address[%d].Port: %d\n", i, addr.Port)
} }

View File

@ -1,4 +1,4 @@
// Copyright 2020 xgfone // Copyright 2020~2023 xgfone
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -25,9 +25,11 @@ import (
"fmt" "fmt"
"net" "net"
"github.com/xgfone/bt/metainfo" "github.com/xgfone/go-bt/metainfo"
) )
const maxBufSize = 2048
// ProtocolID is magic constant for the udp tracker connection. // ProtocolID is magic constant for the udp tracker connection.
// //
// BEP 15 // BEP 15
@ -55,16 +57,15 @@ type AnnounceRequest struct {
Uploaded int64 Uploaded int64
Event uint32 Event uint32
IP net.IP
Key int32 Key int32
NumWant int32 // -1 for default NumWant int32 // -1 for default and use -1 instead if 0
Port uint16 Port uint16
Exts []Extension // BEP 41 Exts []Extension // BEP 41
} }
// DecodeFrom decodes the request from b. // DecodeFrom decodes the request from b.
func (r *AnnounceRequest) DecodeFrom(b []byte, ipv4 bool) { func (r *AnnounceRequest) DecodeFrom(b []byte) {
r.InfoHash = metainfo.NewHash(b[0:20]) r.InfoHash = metainfo.NewHash(b[0:20])
r.PeerID = metainfo.NewHash(b[20:40]) r.PeerID = metainfo.NewHash(b[20:40])
r.Downloaded = int64(binary.BigEndian.Uint64(b[40:48])) r.Downloaded = int64(binary.BigEndian.Uint64(b[40:48]))
@ -72,21 +73,13 @@ func (r *AnnounceRequest) DecodeFrom(b []byte, ipv4 bool) {
r.Uploaded = int64(binary.BigEndian.Uint64(b[56:64])) r.Uploaded = int64(binary.BigEndian.Uint64(b[56:64]))
r.Event = binary.BigEndian.Uint32(b[64:68]) r.Event = binary.BigEndian.Uint32(b[64:68])
if ipv4 { // ignore b[68:72] // 4 bytes
r.IP = make(net.IP, net.IPv4len)
copy(r.IP, b[68:72])
b = b[72:]
} else {
r.IP = make(net.IP, net.IPv6len)
copy(r.IP, b[68:84])
b = b[84:]
}
r.Key = int32(binary.BigEndian.Uint32(b[0:4])) r.Key = int32(binary.BigEndian.Uint32(b[72:76]))
r.NumWant = int32(binary.BigEndian.Uint32(b[4:8])) r.NumWant = int32(binary.BigEndian.Uint32(b[76:80]))
r.Port = binary.BigEndian.Uint16(b[8:10]) r.Port = binary.BigEndian.Uint16(b[80:82])
b = b[10:] b = b[82:]
for len(b) > 0 { for len(b) > 0 {
var ext Extension var ext Extension
parsed := ext.DecodeFrom(b) parsed := ext.DecodeFrom(b)
@ -97,26 +90,25 @@ func (r *AnnounceRequest) DecodeFrom(b []byte, ipv4 bool) {
// EncodeTo encodes the request to buf. // EncodeTo encodes the request to buf.
func (r AnnounceRequest) EncodeTo(buf *bytes.Buffer) { func (r AnnounceRequest) EncodeTo(buf *bytes.Buffer) {
buf.Grow(82) if r.NumWant <= 0 {
buf.Write(r.InfoHash[:]) r.NumWant = -1
buf.Write(r.PeerID[:])
binary.Write(buf, binary.BigEndian, r.Downloaded)
binary.Write(buf, binary.BigEndian, r.Left)
binary.Write(buf, binary.BigEndian, r.Uploaded)
binary.Write(buf, binary.BigEndian, r.Event)
if ip := r.IP.To4(); ip != nil {
buf.Write(ip[:])
} else {
buf.Write(r.IP[:])
} }
binary.Write(buf, binary.BigEndian, r.Key) buf.Grow(82)
binary.Write(buf, binary.BigEndian, r.NumWant) buf.Write(r.InfoHash[:]) // 20: 16 - 36
binary.Write(buf, binary.BigEndian, r.Port) buf.Write(r.PeerID[:]) // 20: 36 - 56
for _, ext := range r.Exts { binary.Write(buf, binary.BigEndian, r.Downloaded) // 8: 56 - 64
binary.Write(buf, binary.BigEndian, r.Left) // 8: 64 - 72
binary.Write(buf, binary.BigEndian, r.Uploaded) // 8: 72 - 80
binary.Write(buf, binary.BigEndian, r.Event) // 4: 80 - 84
binary.Write(buf, binary.BigEndian, uint32(0)) // 4: 84 - 88
binary.Write(buf, binary.BigEndian, r.Key) // 4: 88 - 92
binary.Write(buf, binary.BigEndian, r.NumWant) // 4: 92 - 96
binary.Write(buf, binary.BigEndian, r.Port) // 2: 96 - 98
for _, ext := range r.Exts { // N: 98 -
ext.EncodeTo(buf) ext.EncodeTo(buf)
} }
} }
@ -128,32 +120,35 @@ type AnnounceResponse struct {
Interval uint32 Interval uint32
Leechers uint32 Leechers uint32
Seeders uint32 Seeders uint32
Addresses []metainfo.Address Addresses []metainfo.CompactAddr
} }
// EncodeTo encodes the response to buf. // EncodeTo encodes the response to buf.
func (r AnnounceResponse) EncodeTo(buf *bytes.Buffer) { func (r AnnounceResponse) EncodeTo(buf *bytes.Buffer, ipv4 bool) {
buf.Grow(12 + len(r.Addresses)*18) buf.Grow(12 + len(r.Addresses)*18)
binary.Write(buf, binary.BigEndian, r.Interval) binary.Write(buf, binary.BigEndian, r.Interval)
binary.Write(buf, binary.BigEndian, r.Leechers) binary.Write(buf, binary.BigEndian, r.Leechers)
binary.Write(buf, binary.BigEndian, r.Seeders) binary.Write(buf, binary.BigEndian, r.Seeders)
for _, addr := range r.Addresses { for i, addr := range r.Addresses {
if ip := addr.IP.To4(); ip != nil { if ipv4 {
buf.Write(ip[:]) addr.IP = addr.IP.To4()
} else { } else {
buf.Write(addr.IP[:]) addr.IP = addr.IP.To16()
} }
binary.Write(buf, binary.BigEndian, addr.Port) if len(addr.IP) == 0 {
panic(fmt.Errorf("invalid ip '%s'", r.Addresses[i].IP.String()))
}
addr.WriteBinary(buf)
} }
} }
// DecodeFrom decodes the response from b. // DecodeFrom decodes the response from b.
func (r *AnnounceResponse) DecodeFrom(b []byte, ipv4 bool) { func (r *AnnounceResponse) DecodeFrom(b []byte, ipv4 bool) {
r.Interval = binary.BigEndian.Uint32(b[:4]) r.Interval = binary.BigEndian.Uint32(b[:4]) // 4: 8 - 12
r.Leechers = binary.BigEndian.Uint32(b[4:8]) r.Leechers = binary.BigEndian.Uint32(b[4:8]) // 4: 12 - 16
r.Seeders = binary.BigEndian.Uint32(b[8:12]) r.Seeders = binary.BigEndian.Uint32(b[8:12]) // 4: 16 - 20
b = b[12:] b = b[12:] // N*(6|18): 20 -
iplen := net.IPv6len iplen := net.IPv6len
if ipv4 { if ipv4 {
iplen = net.IPv4len iplen = net.IPv4len
@ -161,12 +156,13 @@ func (r *AnnounceResponse) DecodeFrom(b []byte, ipv4 bool) {
_len := len(b) _len := len(b)
step := iplen + 2 step := iplen + 2
r.Addresses = make([]metainfo.Address, 0, _len/step) r.Addresses = make([]metainfo.CompactAddr, 0, _len/step)
for i := step; i <= _len; i += step { for i := step; i <= _len; i += step {
ip := make(net.IP, iplen) var addr metainfo.CompactAddr
copy(ip, b[i-step:i-2]) if err := addr.UnmarshalBinary(b[i-step : i]); err != nil {
port := binary.BigEndian.Uint16(b[i-2 : i]) panic(err)
r.Addresses = append(r.Addresses, metainfo.Address{IP: ip, Port: port}) }
r.Addresses = append(r.Addresses, addr)
} }
} }

View File

@ -1,4 +1,4 @@
// Copyright 2020 xgfone // Copyright 2020~2023 xgfone
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -19,49 +19,38 @@ import (
"context" "context"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt"
"io" "io"
"net" "net"
"strings" "strings"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/xgfone/bt/metainfo" "github.com/xgfone/go-bt/metainfo"
) )
// NewClientByDial returns a new Client by dialing. // NewClientByDial returns a new Client by dialing.
func NewClientByDial(network, address string, c ...ClientConfig) (*Client, error) { func NewClientByDial(network, address string, id metainfo.Hash) (*Client, error) {
conn, err := net.Dial(network, address) conn, err := net.Dial(network, address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewClient(conn.(*net.UDPConn), id), nil
return NewClient(conn.(*net.UDPConn), c...), nil
} }
// NewClient returns a new Client. // NewClient returns a new Client.
func NewClient(conn *net.UDPConn, c ...ClientConfig) *Client { func NewClient(conn *net.UDPConn, id metainfo.Hash) *Client {
var conf ClientConfig
conf.set(c...)
ipv4 := strings.Contains(conn.LocalAddr().String(), ".") ipv4 := strings.Contains(conn.LocalAddr().String(), ".")
return &Client{conn: conn, conf: conf, ipv4: ipv4} if id.IsZero() {
} id = metainfo.NewRandomHash()
// ClientConfig is used to configure the Client.
type ClientConfig struct {
ID metainfo.Hash
MaxBufSize int // Default: 2048
}
func (c *ClientConfig) set(conf ...ClientConfig) {
if len(conf) > 0 {
*c = conf[0]
} }
if c.ID.IsZero() { return &Client{
c.ID = metainfo.NewRandomHash() MaxBufSize: maxBufSize,
}
if c.MaxBufSize <= 0 { conn: conn,
c.MaxBufSize = 2048 ipv4: ipv4,
id: id,
} }
} }
@ -72,12 +61,14 @@ func (c *ClientConfig) set(conf ...ClientConfig) {
// //
// BEP 15 // BEP 15
type Client struct { type Client struct {
ipv4 bool MaxBufSize int // Default: 2048
conf ClientConfig
id metainfo.Hash
conn *net.UDPConn conn *net.UDPConn
last time.Time last time.Time
cid uint64 cid uint64
tid uint32 tid uint32
ipv4 bool
} }
// Close closes the UDP tracker client. // Close closes the UDP tracker client.
@ -140,21 +131,21 @@ func (utc *Client) connect(ctx context.Context) (err error) {
} }
data = data[:n] data = data[:n]
switch binary.BigEndian.Uint32(data[:4]) { switch action := binary.BigEndian.Uint32(data[:4]); action {
case ActionConnect: case ActionConnect:
case ActionError: case ActionError:
_, reason := utc.parseError(data[4:]) _, reason := utc.parseError(data[4:])
return errors.New(reason) return errors.New(reason)
default: default:
return errors.New("tracker response not connect action") return fmt.Errorf("tracker response is not connect action: %d", action)
} }
if n < 16 { if n < 16 {
return io.ErrShortBuffer return io.ErrShortBuffer
} }
if binary.BigEndian.Uint32(data[4:8]) != tid { if _tid := binary.BigEndian.Uint32(data[4:8]); _tid != tid {
return errors.New("invalid transaction id") return fmt.Errorf("expect transaction id %d, but got %d", tid, _tid)
} }
utc.cid = binary.BigEndian.Uint64(data[8:16]) utc.cid = binary.BigEndian.Uint64(data[8:16])
@ -172,8 +163,7 @@ func (utc *Client) getConnectionID(ctx context.Context) (cid uint64, err error)
return return
} }
func (utc *Client) announce(ctx context.Context, req AnnounceRequest) ( func (utc *Client) announce(ctx context.Context, req AnnounceRequest) (r AnnounceResponse, err error) {
r AnnounceResponse, err error) {
cid, err := utc.getConnectionID(ctx) cid, err := utc.getConnectionID(ctx)
if err != nil { if err != nil {
return return
@ -181,16 +171,21 @@ func (utc *Client) announce(ctx context.Context, req AnnounceRequest) (
tid := utc.getTranID() tid := utc.getTranID()
buf := bytes.NewBuffer(make([]byte, 0, 110)) buf := bytes.NewBuffer(make([]byte, 0, 110))
binary.Write(buf, binary.BigEndian, cid) binary.Write(buf, binary.BigEndian, cid) // 8: 0 - 8
binary.Write(buf, binary.BigEndian, ActionAnnounce) binary.Write(buf, binary.BigEndian, ActionAnnounce) // 4: 8 - 12
binary.Write(buf, binary.BigEndian, tid) binary.Write(buf, binary.BigEndian, tid) // 4: 12 - 16
req.EncodeTo(buf) req.EncodeTo(buf)
b := buf.Bytes() b := buf.Bytes()
if err = utc.send(b); err != nil { if err = utc.send(b); err != nil {
return return
} }
data := make([]byte, utc.conf.MaxBufSize) bufSize := utc.MaxBufSize
if bufSize <= 0 {
bufSize = maxBufSize
}
data := make([]byte, bufSize)
n, err := utc.readResp(ctx, data) n, err := utc.readResp(ctx, data)
if err != nil { if err != nil {
return return
@ -200,14 +195,14 @@ func (utc *Client) announce(ctx context.Context, req AnnounceRequest) (
} }
data = data[:n] data = data[:n]
switch binary.BigEndian.Uint32(data[:4]) { switch action := binary.BigEndian.Uint32(data[:4]); action {
case ActionAnnounce: case ActionAnnounce:
case ActionError: case ActionError:
_, reason := utc.parseError(data[4:]) _, reason := utc.parseError(data[4:])
err = errors.New(reason) err = errors.New(reason)
return return
default: default:
err = errors.New("tracker response not connect action") err = fmt.Errorf("tracker response is not connect action: %d", action)
return return
} }
@ -216,8 +211,8 @@ func (utc *Client) announce(ctx context.Context, req AnnounceRequest) (
return return
} }
if binary.BigEndian.Uint32(data[4:8]) != tid { if _tid := binary.BigEndian.Uint32(data[4:8]); _tid != tid {
err = errors.New("invalid transaction id") err = fmt.Errorf("expect transaction id %d, but got %d", tid, _tid)
return return
} }
@ -228,19 +223,20 @@ func (utc *Client) announce(ctx context.Context, req AnnounceRequest) (
// Announce sends a Announce request to the tracker. // Announce sends a Announce request to the tracker.
// //
// Notice: // Notice:
// 1. if it does not connect to the UDP tracker server, it will connect to it, // 1. if it does not connect to the UDP tracker server, it will connect to it,
// then send the ANNOUNCE request. // then send the ANNOUNCE request.
// 2. If returning an error, you should retry it. // 2. If returning an error, you should retry it.
// See http://www.bittorrent.org/beps/bep_0015.html#time-outs // See http://www.bittorrent.org/beps/bep_0015.html#time-outs
func (utc *Client) Announce(c context.Context, r AnnounceRequest) (AnnounceResponse, error) { func (utc *Client) Announce(c context.Context, r AnnounceRequest) (AnnounceResponse, error) {
if r.PeerID.IsZero() { if r.InfoHash.IsZero() {
r.PeerID = utc.conf.ID panic("infohash is ZERO")
} }
r.PeerID = utc.id
return utc.announce(c, r) return utc.announce(c, r)
} }
func (utc *Client) scrape(c context.Context, ihs []metainfo.Hash) ( func (utc *Client) scrape(c context.Context, ihs []metainfo.Hash) (rs []ScrapeResponse, err error) {
rs []ScrapeResponse, err error) {
cid, err := utc.getConnectionID(c) cid, err := utc.getConnectionID(c)
if err != nil { if err != nil {
return return
@ -248,17 +244,24 @@ func (utc *Client) scrape(c context.Context, ihs []metainfo.Hash) (
tid := utc.getTranID() tid := utc.getTranID()
buf := bytes.NewBuffer(make([]byte, 0, 16+len(ihs)*20)) buf := bytes.NewBuffer(make([]byte, 0, 16+len(ihs)*20))
binary.Write(buf, binary.BigEndian, cid) binary.Write(buf, binary.BigEndian, cid) // 8: 0 - 8
binary.Write(buf, binary.BigEndian, ActionScrape) binary.Write(buf, binary.BigEndian, ActionScrape) // 4: 8 - 12
binary.Write(buf, binary.BigEndian, tid) binary.Write(buf, binary.BigEndian, tid) // 4: 12 - 16
for _, h := range ihs {
for _, h := range ihs { // 20*N: 16 -
buf.Write(h[:]) buf.Write(h[:])
} }
if err = utc.send(buf.Bytes()); err != nil { if err = utc.send(buf.Bytes()); err != nil {
return return
} }
data := make([]byte, utc.conf.MaxBufSize) bufSize := utc.MaxBufSize
if bufSize <= 0 {
bufSize = maxBufSize
}
data := make([]byte, bufSize)
n, err := utc.readResp(c, data) n, err := utc.readResp(c, data)
if err != nil { if err != nil {
return return
@ -268,19 +271,19 @@ func (utc *Client) scrape(c context.Context, ihs []metainfo.Hash) (
} }
data = data[:n] data = data[:n]
switch binary.BigEndian.Uint32(data[:4]) { switch action := binary.BigEndian.Uint32(data[:4]); action {
case ActionScrape: case ActionScrape:
case ActionError: case ActionError:
_, reason := utc.parseError(data[4:]) _, reason := utc.parseError(data[4:])
err = errors.New(reason) err = errors.New(reason)
return return
default: default:
err = errors.New("tracker response not connect action") err = fmt.Errorf("tracker response is not connect action %d", action)
return return
} }
if binary.BigEndian.Uint32(data[4:8]) != tid { if _tid := binary.BigEndian.Uint32(data[4:8]); _tid != tid {
err = errors.New("invalid transaction id") err = fmt.Errorf("expect transaction id %d, but got %d", tid, _tid)
return return
} }
@ -299,10 +302,10 @@ func (utc *Client) scrape(c context.Context, ihs []metainfo.Hash) (
// Scrape sends a Scrape request to the tracker. // Scrape sends a Scrape request to the tracker.
// //
// Notice: // Notice:
// 1. if it does not connect to the UDP tracker server, it will connect to it, // 1. if it does not connect to the UDP tracker server, it will connect to it,
// then send the ANNOUNCE request. // then send the ANNOUNCE request.
// 2. If returning an error, you should retry it. // 2. If returning an error, you should retry it.
// See http://www.bittorrent.org/beps/bep_0015.html#time-outs // See http://www.bittorrent.org/beps/bep_0015.html#time-outs
func (utc *Client) Scrape(c context.Context, hs []metainfo.Hash) ([]ScrapeResponse, error) { func (utc *Client) Scrape(c context.Context, hs []metainfo.Hash) ([]ScrapeResponse, error) {
return utc.scrape(c, hs) return utc.scrape(c, hs)
} }

View File

@ -1,4 +1,4 @@
// Copyright 2020 xgfone // Copyright 2020~2023 xgfone
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -24,7 +24,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/xgfone/bt/metainfo" "github.com/xgfone/go-bt/metainfo"
) )
// ServerHandler is used to handle the request from the client. // ServerHandler is used to handle the request from the client.
@ -45,25 +45,14 @@ type wrappedPeerAddr struct {
Time time.Time Time time.Time
} }
// ServerConfig is used to configure the Server. type buffer struct{ Data []byte }
type ServerConfig struct {
MaxBufSize int // Default: 2048
ErrorLog func(format string, args ...interface{}) // Default: log.Printf
}
func (c *ServerConfig) setDefault() {
if c.MaxBufSize <= 0 {
c.MaxBufSize = 2048
}
if c.ErrorLog == nil {
c.ErrorLog = log.Printf
}
}
// Server is a tracker server based on UDP. // Server is a tracker server based on UDP.
type Server struct { type Server struct {
// Default: log.Printf
ErrorLog func(format string, args ...interface{})
conn net.PacketConn conn net.PacketConn
conf ServerConfig
handler ServerHandler handler ServerHandler
bufpool sync.Pool bufpool sync.Pool
@ -74,23 +63,20 @@ type Server struct {
} }
// NewServer returns a new Server. // NewServer returns a new Server.
func NewServer(c net.PacketConn, h ServerHandler, sc ...ServerConfig) *Server { func NewServer(c net.PacketConn, h ServerHandler, bufSize int) *Server {
var conf ServerConfig if bufSize <= 0 {
if len(sc) > 0 { bufSize = maxBufSize
conf = sc[0]
} }
conf.setDefault()
s := &Server{ return &Server{
conf: conf,
conn: c, conn: c,
handler: h, handler: h,
exit: make(chan struct{}), exit: make(chan struct{}),
conns: make(map[uint64]wrappedPeerAddr, 128), conns: make(map[uint64]wrappedPeerAddr, 128),
bufpool: sync.Pool{New: func() interface{} {
return &buffer{Data: make([]byte, bufSize)}
}},
} }
s.bufpool.New = func() interface{} { return make([]byte, conf.MaxBufSize) }
return s
} }
// Close closes the tracker server. // Close closes the tracker server.
@ -126,11 +112,11 @@ func (uts *Server) cleanConnectionID(interval time.Duration) {
func (uts *Server) Run() { func (uts *Server) Run() {
go uts.cleanConnectionID(time.Minute * 2) go uts.cleanConnectionID(time.Minute * 2)
for { for {
buf := uts.bufpool.Get().([]byte) buf := uts.bufpool.Get().(*buffer)
n, raddr, err := uts.conn.ReadFrom(buf) n, raddr, err := uts.conn.ReadFrom(buf.Data)
if err != nil { if err != nil {
if !strings.Contains(err.Error(), "closed") { if !strings.Contains(err.Error(), "closed") {
uts.conf.ErrorLog("failed to read udp tracker request: %s", err) uts.errorf("failed to read udp tracker request: %s", err)
} }
return return
} else if n < 16 { } else if n < 16 {
@ -140,18 +126,26 @@ func (uts *Server) Run() {
} }
} }
func (uts *Server) handleRequest(raddr *net.UDPAddr, buf []byte, n int) { func (uts *Server) errorf(format string, args ...interface{}) {
if uts.ErrorLog == nil {
log.Printf(format, args...)
} else {
uts.ErrorLog(format, args...)
}
}
func (uts *Server) handleRequest(raddr *net.UDPAddr, buf *buffer, n int) {
defer uts.bufpool.Put(buf) defer uts.bufpool.Put(buf)
uts.handlePacket(raddr, buf[:n]) uts.handlePacket(raddr, buf.Data[:n])
} }
func (uts *Server) send(raddr *net.UDPAddr, b []byte) { func (uts *Server) send(raddr *net.UDPAddr, b []byte) {
n, err := uts.conn.WriteTo(b, raddr) n, err := uts.conn.WriteTo(b, raddr)
if err != nil { if err != nil {
uts.conf.ErrorLog("fail to send the udp tracker response to '%s': %s", uts.errorf("fail to send the udp tracker response to '%s': %s",
raddr.String(), err) raddr.String(), err)
} else if n < len(b) { } else if n < len(b) {
uts.conf.ErrorLog("too short udp tracker response sent to '%s'", raddr.String()) uts.errorf("too short udp tracker response sent to '%s'", raddr.String())
} }
} }
@ -190,16 +184,14 @@ func (uts *Server) sendConnResp(raddr *net.UDPAddr, tid uint32, cid uint64) {
uts.send(raddr, buf.Bytes()) uts.send(raddr, buf.Bytes())
} }
func (uts *Server) sendAnnounceResp(raddr *net.UDPAddr, tid uint32, func (uts *Server) sendAnnounceResp(raddr *net.UDPAddr, tid uint32, resp AnnounceResponse) {
resp AnnounceResponse) {
buf := bytes.NewBuffer(make([]byte, 0, 8+12+len(resp.Addresses)*18)) buf := bytes.NewBuffer(make([]byte, 0, 8+12+len(resp.Addresses)*18))
encodeResponseHeader(buf, ActionAnnounce, tid) encodeResponseHeader(buf, ActionAnnounce, tid)
resp.EncodeTo(buf) resp.EncodeTo(buf, raddr.IP.To4() != nil)
uts.send(raddr, buf.Bytes()) uts.send(raddr, buf.Bytes())
} }
func (uts *Server) sendScrapResp(raddr *net.UDPAddr, tid uint32, func (uts *Server) sendScrapResp(raddr *net.UDPAddr, tid uint32, rs []ScrapeResponse) {
rs []ScrapeResponse) {
buf := bytes.NewBuffer(make([]byte, 0, 8+len(rs)*12)) buf := bytes.NewBuffer(make([]byte, 0, 8+len(rs)*12))
encodeResponseHeader(buf, ActionScrape, tid) encodeResponseHeader(buf, ActionScrape, tid)
for _, r := range rs { for _, r := range rs {
@ -209,9 +201,9 @@ func (uts *Server) sendScrapResp(raddr *net.UDPAddr, tid uint32,
} }
func (uts *Server) handlePacket(raddr *net.UDPAddr, b []byte) { func (uts *Server) handlePacket(raddr *net.UDPAddr, b []byte) {
cid := binary.BigEndian.Uint64(b[:8]) cid := binary.BigEndian.Uint64(b[:8]) // 8: 0 - 8
action := binary.BigEndian.Uint32(b[8:12]) action := binary.BigEndian.Uint32(b[8:12]) // 4: 8 - 12
tid := binary.BigEndian.Uint32(b[12:16]) tid := binary.BigEndian.Uint32(b[12:16]) // 4: 12 - 16
b = b[16:] b = b[16:]
// Handle the connection request. // Handle the connection request.
@ -241,13 +233,13 @@ func (uts *Server) handlePacket(raddr *net.UDPAddr, b []byte) {
uts.sendError(raddr, tid, "invalid announce request") uts.sendError(raddr, tid, "invalid announce request")
return return
} }
req.DecodeFrom(b, true) req.DecodeFrom(b)
} else { // For ipv6 } else { // For ipv6
if len(b) < 94 { if len(b) < 94 {
uts.sendError(raddr, tid, "invalid announce request") uts.sendError(raddr, tid, "invalid announce request")
return return
} }
req.DecodeFrom(b, false) req.DecodeFrom(b)
} }
resp, err := uts.handler.OnAnnounce(raddr, req) resp, err := uts.handler.OnAnnounce(raddr, req)
@ -256,6 +248,7 @@ func (uts *Server) handlePacket(raddr *net.UDPAddr, b []byte) {
} else { } else {
uts.sendAnnounceResp(raddr, tid, resp) uts.sendAnnounceResp(raddr, tid, resp)
} }
case ActionScrape: case ActionScrape:
_len := len(b) _len := len(b)
infohashes := make([]metainfo.Hash, 0, _len/20) infohashes := make([]metainfo.Hash, 0, _len/20)
@ -274,6 +267,7 @@ func (uts *Server) handlePacket(raddr *net.UDPAddr, b []byte) {
} else { } else {
uts.sendScrapResp(raddr, tid, resps) uts.sendScrapResp(raddr, tid, resps)
} }
default: default:
uts.sendError(raddr, tid, "unkwnown action") uts.sendError(raddr, tid, "unkwnown action")
} }

View File

@ -1,4 +1,4 @@
// Copyright 2020 xgfone // Copyright 2020~2023 xgfone
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -22,14 +22,13 @@ import (
"net" "net"
"time" "time"
"github.com/xgfone/bt/metainfo" "github.com/xgfone/go-bt/metainfo"
) )
type testHandler struct{} type testHandler struct{}
func (testHandler) OnConnect(raddr *net.UDPAddr) (err error) { return } func (testHandler) OnConnect(raddr *net.UDPAddr) (err error) { return }
func (testHandler) OnAnnounce(raddr *net.UDPAddr, req AnnounceRequest) ( func (testHandler) OnAnnounce(raddr *net.UDPAddr, req AnnounceRequest) (r AnnounceResponse, err error) {
r AnnounceResponse, err error) {
if req.Port != 80 { if req.Port != 80 {
err = errors.New("port is not 80") err = errors.New("port is not 80")
return return
@ -50,12 +49,11 @@ func (testHandler) OnAnnounce(raddr *net.UDPAddr, req AnnounceRequest) (
Interval: 1, Interval: 1,
Leechers: 2, Leechers: 2,
Seeders: 3, Seeders: 3,
Addresses: []metainfo.Address{{IP: net.ParseIP("127.0.0.1"), Port: 8001}}, Addresses: []metainfo.CompactAddr{{IP: net.ParseIP("127.0.0.1"), Port: 8001}},
} }
return return
} }
func (testHandler) OnScrap(raddr *net.UDPAddr, infohashes []metainfo.Hash) ( func (testHandler) OnScrap(raddr *net.UDPAddr, infohashes []metainfo.Hash) (rs []ScrapeResponse, err error) {
rs []ScrapeResponse, err error) {
rs = make([]ScrapeResponse, len(infohashes)) rs = make([]ScrapeResponse, len(infohashes))
for i := range infohashes { for i := range infohashes {
rs[i] = ScrapeResponse{ rs[i] = ScrapeResponse{
@ -73,7 +71,7 @@ func ExampleClient() {
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
server := NewServer(sconn, testHandler{}) server := NewServer(sconn, testHandler{}, 0)
defer server.Close() defer server.Close()
go server.Run() go server.Run()
@ -81,7 +79,7 @@ func ExampleClient() {
time.Sleep(time.Second) time.Sleep(time.Second)
// Create a client and dial to the UDP tracker server. // Create a client and dial to the UDP tracker server.
client, err := NewClientByDial("udp4", "127.0.0.1:8001") client, err := NewClientByDial("udp4", "127.0.0.1:8001", metainfo.Hash{})
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@ -89,7 +87,7 @@ func ExampleClient() {
// Send the ANNOUNCE request to the UDP tracker server, // Send the ANNOUNCE request to the UDP tracker server,
// and get the ANNOUNCE response. // and get the ANNOUNCE response.
exts := []Extension{NewURLData([]byte("data")), NewNop()} exts := []Extension{NewURLData([]byte("data")), NewNop()}
req := AnnounceRequest{IP: net.ParseIP("127.0.0.1"), Port: 80, Exts: exts} req := AnnounceRequest{InfoHash: metainfo.NewRandomHash(), Port: 80, Exts: exts}
resp, err := client.Announce(context.Background(), req) resp, err := client.Announce(context.Background(), req)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@ -1,50 +0,0 @@
// Copyright 2020 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import (
"sync/atomic"
)
// Bool is used to implement a atomic bool.
type Bool struct {
v uint32
}
// NewBool returns a new Bool with the initialized value.
func NewBool(t bool) (b Bool) {
if t {
b.v = 1
}
return
}
// Get returns the bool value.
func (b *Bool) Get() bool { return atomic.LoadUint32(&b.v) == 1 }
// SetTrue sets the bool value to true.
func (b *Bool) SetTrue() { atomic.StoreUint32(&b.v, 1) }
// SetFalse sets the bool value to false.
func (b *Bool) SetFalse() { atomic.StoreUint32(&b.v, 0) }
// Set sets the bool value to t.
func (b *Bool) Set(t bool) {
if t {
b.SetTrue()
} else {
b.SetFalse()
}
}