diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 558a168..c332dbd 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -4,7 +4,7 @@ env: GO111MODULE: on jobs: build: - runs-on: ubuntu-18.04 + runs-on: ubuntu-22.04 name: Go ${{ matrix.go }} strategy: matrix: @@ -17,10 +17,12 @@ jobs: - '1.16' - '1.17' - '1.18' + - '1.19' + - '1.20' steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Setup Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} - run: go test -race ./... diff --git a/README.md b/README.md index 67db40f..a44b31b 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,6 @@ import ( "fmt" "io" "log" - "net/url" "os" "time" @@ -73,7 +72,7 @@ func getPeersFromTrackers(id, infohash metainfo.Hash, trackers []string) (peers defer cancel() resp := tracker.GetPeers(c, id, infohash, trackers) - for r := range resp { + for _, r := range resp { for _, addr := range r.Resp.Addresses { addrs := addr.String() nonexist := true diff --git a/dht/blacklist.go b/dht/blacklist.go index d31056e..d8cefff 100644 --- a/dht/blacklist.go +++ b/dht/blacklist.go @@ -17,6 +17,8 @@ package dht import ( "sync" "time" + + "github.com/xgfone/bt/krpc" ) // Blacklist is used to manage the ip blacklist. @@ -24,25 +26,24 @@ import ( // Notice: The implementation should clear the address existed for long time. type Blacklist interface { // In reports whether the address, ip and port, is in the blacklist. - In(ip string, port int) bool + In(krpc.Addr) bool // If port is equal to 0, it should ignore port and only use ip when matching. - Add(ip string, port int) + Add(krpc.Addr) // If port is equal to 0, it should delete the address by only the ip. - Del(ip string, port int) + Del(krpc.Addr) - // Close is used to notice the implementation to release the underlying - // resource. + // Close is used to notice the implementation to release the underlying resource. Close() } type noopBlacklist struct{} -func (nbl noopBlacklist) In(ip string, port int) bool { return false } -func (nbl noopBlacklist) Add(ip string, port int) {} -func (nbl noopBlacklist) Del(ip string, port int) {} -func (nbl noopBlacklist) Close() {} +func (nbl noopBlacklist) In(krpc.Addr) bool { return false } +func (nbl noopBlacklist) Add(krpc.Addr) {} +func (nbl noopBlacklist) Del(krpc.Addr) {} +func (nbl noopBlacklist) Close() {} // NewNoopBlacklist returns a no-op Blacklist. func NewNoopBlacklist() Blacklist { return noopBlacklist{} } @@ -57,14 +58,14 @@ type logBlacklist struct { logf func(string, ...interface{}) } -func (dbl logBlacklist) Add(ip string, port int) { - dbl.logf("add the blacklist: ip=%s, port=%d", ip, port) - dbl.Blacklist.Add(ip, port) +func (l logBlacklist) Add(addr krpc.Addr) { + l.logf("add the addr '%s' into the blacklist", addr.String()) + l.Blacklist.Add(addr) } -func (dbl logBlacklist) Del(ip string, port int) { - dbl.logf("delete the blacklist: ip=%s, port=%d", ip, port) - dbl.Blacklist.Del(ip, port) +func (l logBlacklist) Del(addr krpc.Addr) { + l.logf("delete the addr '%s' from the blacklist", addr.String()) + l.Blacklist.Del(addr) } /// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> @@ -85,7 +86,7 @@ func NewMemoryBlacklist(maxnum int, duration time.Duration) Blacklist { type wrappedPort struct { Time time.Time Enable bool - Ports map[int]struct{} + Ports map[uint16]struct{} } type blacklist struct { @@ -123,11 +124,11 @@ func (bl *blacklist) Close() { } // In reports whether the address, ip and port, is in the blacklist. -func (bl *blacklist) In(ip string, port int) (yes bool) { +func (bl *blacklist) In(addr krpc.Addr) (yes bool) { bl.lock.RLock() - if wp, ok := bl.ips[ip]; ok { + if wp, ok := bl.ips[addr.IP.String()]; ok { if wp.Enable { - _, yes = wp.Ports[port] + _, yes = wp.Ports[addr.Port] } else { yes = true } @@ -136,7 +137,8 @@ func (bl *blacklist) In(ip string, port int) (yes bool) { return } -func (bl *blacklist) Add(ip string, port int) { +func (bl *blacklist) Add(addr krpc.Addr) { + ip := addr.IP.String() bl.lock.Lock() wp, ok := bl.ips[ip] if !ok { @@ -149,30 +151,31 @@ func (bl *blacklist) Add(ip string, port int) { bl.ips[ip] = wp } - if port < 1 { + if addr.Port < 1 { wp.Enable = false wp.Ports = nil } else if wp.Ports == nil { - wp.Ports = map[int]struct{}{port: struct{}{}} + wp.Ports = map[uint16]struct{}{addr.Port: {}} } else { - wp.Ports[port] = struct{}{} + wp.Ports[addr.Port] = struct{}{} } wp.Time = time.Now() bl.lock.Unlock() } -func (bl *blacklist) Del(ip string, port int) { +func (bl *blacklist) Del(addr krpc.Addr) { + ip := addr.IP.String() bl.lock.Lock() if wp, ok := bl.ips[ip]; ok { - if port < 1 { + if addr.Port < 1 { delete(bl.ips, ip) } else if wp.Enable { switch len(wp.Ports) { case 0, 1: delete(bl.ips, ip) default: - delete(wp.Ports, port) + delete(wp.Ports, addr.Port) wp.Time = time.Now() } } diff --git a/dht/blacklist_test.go b/dht/blacklist_test.go index ca1cb99..12b3091 100644 --- a/dht/blacklist_test.go +++ b/dht/blacklist_test.go @@ -15,8 +15,11 @@ package dht import ( + "net" "testing" "time" + + "github.com/xgfone/bt/krpc" ) func (bl *blacklist) portsLen() (n int) { @@ -42,12 +45,12 @@ func TestMemoryBlacklist(t *testing.T) { bl := NewMemoryBlacklist(3, time.Second).(*blacklist) defer bl.Close() - bl.Add("1.1.1.1", 123) - bl.Add("1.1.1.1", 456) - bl.Add("1.1.1.1", 789) - bl.Add("2.2.2.2", 111) - bl.Add("3.3.3.3", 0) - bl.Add("4.4.4.4", 222) + bl.Add(krpc.NewAddr(net.ParseIP("1.1.1.1"), 123)) + bl.Add(krpc.NewAddr(net.ParseIP("1.1.1.1"), 456)) + bl.Add(krpc.NewAddr(net.ParseIP("1.1.1.1"), 789)) + bl.Add(krpc.NewAddr(net.ParseIP("2.2.2.2"), 111)) + bl.Add(krpc.NewAddr(net.ParseIP("3.3.3.3"), 0)) + bl.Add(krpc.NewAddr(net.ParseIP("4.4.4.4"), 222)) ips := bl.getIPs() if len(ips) != 3 { @@ -66,15 +69,15 @@ func TestMemoryBlacklist(t *testing.T) { t.Errorf("expect port num 4, but got %d", n) } - if bl.In("1.1.1.1", 111) || !bl.In("1.1.1.1", 123) { + if bl.In(krpc.NewAddr(net.ParseIP("1.1.1.1"), 111)) || !bl.In(krpc.NewAddr(net.ParseIP("1.1.1.1"), 123)) { t.Fail() } - if !bl.In("3.3.3.3", 111) || bl.In("4.4.4.4", 222) { + if !bl.In(krpc.NewAddr(net.ParseIP("3.3.3.3"), 111)) || bl.In(krpc.NewAddr(net.ParseIP("4.4.4.4"), 222)) { t.Fail() } - bl.Del("3.3.3.3", 0) - if bl.In("3.3.3.3", 111) { + bl.Del(krpc.NewAddr(net.ParseIP("3.3.3.3"), 0)) + if bl.In(krpc.NewAddr(net.ParseIP("3.3.3.3"), 111)) { t.Fail() } } diff --git a/dht/dht_server.go b/dht/dht_server.go index eebd289..165d62f 100644 --- a/dht/dht_server.go +++ b/dht/dht_server.go @@ -53,7 +53,7 @@ type Result struct { // Addr is the address of the peer where the request is sent to. // // Notice: it may be nil for "get_peers" request. - Addr *net.UDPAddr + Addr net.Addr // For Error Code int // 0 represents the success. @@ -61,7 +61,7 @@ type Result struct { Timeout bool // Timeout indicates whether the response is timeout. // The list of the address of the peers returned by GetPeers. - Peers []metainfo.Address + Peers []string } // Config is used to configure the DHT server. @@ -134,18 +134,24 @@ type Config struct { // The default is log.Printf. ErrorLog func(format string, args ...interface{}) + // They is used to convert the address between net.Addr and krpc.Addr. + // + // For the default, it asserts net.Addr to *net.UDPAddr. + GetNetAddr func(krpc.Addr) net.Addr + GetKrpcAddr func(net.Addr) krpc.Addr + // OnSearch is called when someone searches the torrent infohash, // that's, the "get_peers" query. // // The default callback does noting. - OnSearch func(infohash string, ip net.IP, port uint16) + OnSearch func(infohash string, addr krpc.Addr) // OnTorrent is called when someone has the torrent infohash // or someone has just downloaded the torrent infohash, // that's, the "get_peers" response or "announce_peer" query. // // The default callback does noting. - OnTorrent func(infohash string, ip net.IP, port uint16) + OnTorrent func(infohash string, addr string) // HandleInMessage is used to intercept the incoming DHT message. // For example, you can debug the message as the log. @@ -153,23 +159,36 @@ type Config struct { // Return true if going on handling by the default. Or return false. // // The default is nil. - HandleInMessage func(*net.UDPAddr, *krpc.Message) bool + HandleInMessage func(net.Addr, *krpc.Message) (handled bool) // HandleOutMessage is used to intercept the outgoing DHT message. // For example, you can debug the message as the log. // - // Return (false, nil) if going on handling by the default. + // Return (true, nil) if not going on handling by the default. // // The default is nil. - HandleOutMessage func(*net.UDPAddr, *krpc.Message) (wrote bool, err error) + HandleOutMessage func(net.Addr, *krpc.Message) (wrote bool, err error) } -func (c Config) in(*net.UDPAddr, *krpc.Message) bool { return true } -func (c Config) out(*net.UDPAddr, *krpc.Message) (bool, error) { return false, nil } +func (c Config) in(net.Addr, *krpc.Message) bool { return false } +func (c Config) out(net.Addr, *krpc.Message) (bool, error) { return false, nil } -func (c *Config) set(conf ...Config) { - if len(conf) > 0 { - *c = conf[0] +func (c Config) onSearch(string, krpc.Addr) {} +func (c Config) onTorrent(string, string) {} + +func (c Config) k2nAddr(a krpc.Addr) net.Addr { + if a.Orig != nil { + return a.Orig + } + return a.UDPAddr() +} +func (c Config) n2kAddr(a net.Addr) krpc.Addr { + return krpc.NewAddrFromUDPAddr(a.(*net.UDPAddr)) +} + +func (c *Config) set(conf *Config) { + if conf != nil { + *c = *conf } if c.K <= 0 { @@ -197,10 +216,16 @@ func (c *Config) set(conf ...Config) { c.RespTimeout = time.Second * 10 } if c.OnSearch == nil { - c.OnSearch = func(string, net.IP, uint16) {} + c.OnSearch = c.onSearch } if c.OnTorrent == nil { - c.OnTorrent = func(string, net.IP, uint16) {} + c.OnTorrent = c.onTorrent + } + if c.GetNetAddr == nil { + c.GetNetAddr = c.k2nAddr + } + if c.GetKrpcAddr == nil { + c.GetKrpcAddr = c.n2kAddr } if c.HandleInMessage == nil { c.HandleInMessage = c.in @@ -215,6 +240,7 @@ type Server struct { conf Config exit chan struct{} conn net.PacketConn + lock sync.Mutex once sync.Once ipv4 bool @@ -230,9 +256,9 @@ type Server struct { } // NewServer returns a new DHT server. -func NewServer(conn net.PacketConn, config ...Config) *Server { +func NewServer(conn net.PacketConn, config *Config) *Server { var conf Config - conf.set(config...) + conf.set(config) if len(conf.IPProtocols) == 0 { host, _, err := net.SplitHostPort(conn.LocalAddr().String()) @@ -278,7 +304,6 @@ func NewServer(conn net.PacketConn, config ...Config) *Server { if s.peerManager == nil { s.peerManager = s.tokenPeerManager } - return s } @@ -292,14 +317,14 @@ func (s *Server) Bootstrap(addrs []string) { if (s.ipv4 && s.routingTable4.Len() == 0) || (s.ipv6 && s.routingTable6.Len() == 0) { for _, addr := range addrs { - as, err := metainfo.NewAddressesFromString(addr) + ipports, err := krpc.ParseAddrs(addr) if err != nil { s.conf.ErrorLog(err.Error()) continue } - for _, a := range as { - if isIPv6(a.IP) { + for _, ipport := range ipports { + if isIPv6(ipport.IP) { if !s.ipv6 { continue } @@ -307,8 +332,8 @@ func (s *Server) Bootstrap(addrs []string) { continue } - if err = s.FindNode(a.UDPAddr(), s.conf.ID); err != nil { - s.conf.ErrorLog(`fail to bootstrap '%s': %s`, a.String(), err) + if err = s.FindNode(ipport, s.conf.ID); err != nil { + s.conf.ErrorLog(`fail to bootstrap '%s': %s`, ipport.String(), err) } } } @@ -324,12 +349,12 @@ func (s *Server) Node6Num() int { return s.routingTable6.Len() } // AddNode adds the node into the routing table. // // The returned value: -// NodeAdded: The node is added successfully. -// NodeNotAdded: The node is not added and is discarded. -// NodeExistAndUpdated: The node has existed, and its status has been updated. -// NodeExistAndChanged: The node has existed, but the address is inconsistent. -// The current node will be discarded. // +// NodeAdded: The node is added successfully. +// NodeNotAdded: The node is not added and is discarded. +// NodeExistAndUpdated: The node has existed, and its status has been updated. +// NodeExistAndChanged: The node has existed, but the address is inconsistent. +// The current node will be discarded. func (s *Server) AddNode(node krpc.Node) int { // For IPv6 if isIPv6(node.Addr.IP) { @@ -347,13 +372,13 @@ func (s *Server) AddNode(node krpc.Node) int { return NodeNotAdded } -func (s *Server) addNode(a *net.UDPAddr, id metainfo.Hash, ro bool) (r int) { +func (s *Server) addNode(kaddr krpc.Addr, id metainfo.Hash, ro bool) (r int) { if ro { // BEP 43 return NodeNotAdded } - if r = s.AddNode(krpc.NewNodeByUDPAddr(id, a)); r == NodeExistAndChanged { - s.conf.Blacklist.Add(a.IP.String(), a.Port) + if r = s.AddNode(krpc.NewNode(id, kaddr)); r == NodeExistAndChanged { + s.conf.Blacklist.Add(kaddr) } return @@ -401,11 +426,15 @@ func (s *Server) Run() { return } - s.handlePacket(raddr.(*net.UDPAddr), buf[:n]) + s.handlePacket(raddr, buf[:n]) } } -func (s *Server) isDisabled(raddr *net.UDPAddr) bool { +func (s *Server) isDisabled(raddr krpc.Addr) bool { + if len(raddr.IP) == 0 { + return false + } + if isIPv6(raddr.IP) { if !s.ipv6 { return true @@ -417,13 +446,16 @@ func (s *Server) isDisabled(raddr *net.UDPAddr) bool { } // HandlePacket handles the incoming DHT message. -func (s *Server) handlePacket(raddr *net.UDPAddr, data []byte) { - if s.isDisabled(raddr) { +func (s *Server) handlePacket(raddr net.Addr, data []byte) { + kaddr := s.conf.GetKrpcAddr(raddr) + kaddr.Orig = raddr + + if s.isDisabled(kaddr) { return } // Check whether the raddr is in the ip blacklist. If yes, discard it. - if s.conf.Blacklist.In(raddr.IP.String(), raddr.Port) { + if s.conf.Blacklist.In(kaddr) { return } @@ -436,46 +468,50 @@ func (s *Server) handlePacket(raddr *net.UDPAddr, data []byte) { return } - // TODO: Should we use a task pool?? - go s.handleMessage(raddr, msg) + // (xgf): Should we use a task pool?? + go s.handleMessage(kaddr, msg) } -func (s *Server) handleMessage(raddr *net.UDPAddr, m krpc.Message) { - if !s.conf.HandleInMessage(raddr, &m) { +func (s *Server) handleMessage(kaddr krpc.Addr, m krpc.Message) { + if s.conf.HandleInMessage(kaddr.Orig, &m) { return } switch m.Y { case "q": if !m.A.ID.IsZero() { - r := s.addNode(raddr, m.A.ID, m.RO) + r := s.addNode(kaddr, m.A.ID, m.RO) if r != NodeExistAndChanged && !s.conf.ReadOnly { // BEP 43 - s.handleQuery(raddr, m) + s.handleQuery(kaddr, m) } } + case "r": if !m.R.ID.IsZero() { - if s.addNode(raddr, m.R.ID, m.RO) == NodeExistAndChanged { + if s.addNode(kaddr, m.R.ID, m.RO) == NodeExistAndChanged { return } - if t := s.transactionManager.PopTransaction(m.T, raddr); t != nil { - t.OnResponse(t, raddr, m) + if t := s.transactionManager.PopTransaction(m.T, kaddr); t != nil { + t.OnResponse(t, kaddr, m) } } + case "e": - if t := s.transactionManager.PopTransaction(m.T, raddr); t != nil { + if t := s.transactionManager.PopTransaction(m.T, kaddr); t != nil { t.OnError(t, m.E.Code, m.E.Reason) } + default: s.conf.ErrorLog("unknown dht message type '%s'", m.Y) } } -func (s *Server) handleQuery(raddr *net.UDPAddr, m krpc.Message) { +func (s *Server) handleQuery(raddr krpc.Addr, m krpc.Message) { switch m.Q { case queryMethodPing: s.reply(raddr, m.T, krpc.ResponseResult{}) + case queryMethodFindNode: // See BEP 32 var r krpc.ResponseResult n4 := m.A.ContainsWant(krpc.WantNodes) @@ -495,6 +531,7 @@ func (s *Server) handleQuery(raddr *net.UDPAddr, m krpc.Message) { } } s.reply(raddr, m.T, r) + case queryMethodGetPeers: // See BEP 32 n4 := m.A.ContainsWant(krpc.WantNodes) n6 := m.A.ContainsWant(krpc.WantNodes6) @@ -538,33 +575,39 @@ func (s *Server) handleQuery(raddr *net.UDPAddr, m krpc.Message) { r.Token = s.tokenManager.Token(raddr) s.reply(raddr, m.T, r) - s.conf.OnSearch(m.A.InfoHash.HexString(), raddr.IP, uint16(raddr.Port)) + s.conf.OnSearch(m.A.InfoHash.HexString(), raddr) + case queryMethodAnnouncePeer: if s.tokenManager.Check(raddr, m.A.Token) { return } s.reply(raddr, m.T, krpc.ResponseResult{}) - s.conf.OnTorrent(m.A.InfoHash.HexString(), raddr.IP, m.A.GetPort(raddr.Port)) + s.conf.OnTorrent(m.A.InfoHash.HexString(), raddr.String()) + default: s.sendError(raddr, m.T, "unknown query method", krpc.ErrorCodeMethodUnknown) } } -func (s *Server) send(raddr *net.UDPAddr, m krpc.Message) (wrote bool, err error) { - // // TODO: Should we check the ip blacklist?? - // if s.conf.Blacklist.In(raddr.IP.String(), raddr.Port) { +func (s *Server) send(kaddr krpc.Addr, m krpc.Message) (wrote bool, err error) { + // // (xgf): Should we check the ip blacklist?? + // if s.conf.Blacklist.In(kaddr) { // return // } m.RO = s.conf.ReadOnly // BEP 43 - if wrote, err = s.conf.HandleOutMessage(raddr, &m); !wrote && err == nil { - wrote, err = s._send(raddr, m) + kaddr.Orig = s.conf.GetNetAddr(kaddr) + + s.lock.Lock() + defer s.lock.Unlock() + if wrote, err = s.conf.HandleOutMessage(kaddr.Orig, &m); !wrote && err == nil { + wrote, err = s._send(kaddr, m) } return } -func (s *Server) _send(raddr *net.UDPAddr, m krpc.Message) (wrote bool, err error) { +func (s *Server) _send(kaddr krpc.Addr, m krpc.Message) (wrote bool, err error) { if m.T == "" || m.Y == "" { panic(`DHT message "t" or "y" must not be empty`) } @@ -575,10 +618,10 @@ func (s *Server) _send(raddr *net.UDPAddr, m krpc.Message) (wrote bool, err erro panic(err) } - n, err := s.conn.WriteTo(buf.Bytes(), raddr) + n, err := s.conn.WriteTo(buf.Bytes(), kaddr.Orig) if err != nil { - err = fmt.Errorf("error writing %d bytes to %s: %s", buf.Len(), raddr, err) - s.conf.Blacklist.Add(raddr.IP.String(), 0) + err = fmt.Errorf("error writing %d bytes to %v: %s", buf.Len(), kaddr.Orig, err) + s.conf.Blacklist.Add(kaddr) return } @@ -590,13 +633,13 @@ func (s *Server) _send(raddr *net.UDPAddr, m krpc.Message) (wrote bool, err erro return } -func (s *Server) sendError(raddr *net.UDPAddr, tid, reason string, code int) { +func (s *Server) sendError(raddr krpc.Addr, tid, reason string, code int) { if _, err := s.send(raddr, krpc.NewErrorMsg(tid, code, reason)); err != nil { s.conf.ErrorLog("error replying to %s: %s", raddr.String(), err.Error()) } } -func (s *Server) reply(raddr *net.UDPAddr, tid string, r krpc.ResponseResult) { +func (s *Server) reply(raddr krpc.Addr, tid string, r krpc.ResponseResult) { r.ID = s.conf.ID if _, err := s.send(raddr, krpc.NewResponseMsg(tid, r)); err != nil { s.conf.ErrorLog("error replying to %s: %s", raddr.String(), err.Error()) @@ -627,7 +670,7 @@ func (s *Server) onError(t *transaction, code int, reason string) { } func (s *Server) onTimeout(t *transaction) { - // TODO: Should we use a task pool?? + // (xgf): Should we use a task pool?? t.Done(Result{Timeout: true}) var qid string @@ -641,11 +684,19 @@ func (s *Server) onTimeout(t *transaction) { t.ID, s.conf.ID, t.Query, qid, s.conn.LocalAddr(), t.Addr.String()) } -func (s *Server) onPingResp(t *transaction, a *net.UDPAddr, m krpc.Message) { +// Ping sends a PING query to addr, and the callback function cb will be called +// when the response or error is returned, or it's timeout. +func (s *Server) Ping(addr krpc.Addr, cb ...func(Result)) (err error) { + t := newTransaction(s, addr, queryMethodPing, krpc.QueryArg{}, cb...) + t.OnResponse = s.onPingResp + return s.request(t) +} + +func (s *Server) onPingResp(t *transaction, a krpc.Addr, m krpc.Message) { t.Done(Result{}) } -func (s *Server) onGetPeersResp(t *transaction, a *net.UDPAddr, m krpc.Message) { +func (s *Server) onGetPeersResp(t *transaction, a krpc.Addr, m krpc.Message) { // Store the response node with the token. if m.R.Token != "" { s.tokenPeerManager.Set(m.R.ID, a, m.R.Token) @@ -655,7 +706,7 @@ func (s *Server) onGetPeersResp(t *transaction, a *net.UDPAddr, m krpc.Message) if len(m.R.Values) > 0 { t.Done(Result{Peers: m.R.Values}) for _, addr := range m.R.Values { - s.conf.OnTorrent(t.Arg.InfoHash.HexString(), addr.IP, addr.Port) + s.conf.OnTorrent(t.Arg.InfoHash.HexString(), addr) } return } @@ -706,10 +757,9 @@ func (s *Server) onGetPeersResp(t *transaction, a *net.UDPAddr, m krpc.Message) } } -func (s *Server) getPeers(info metainfo.Hash, addr metainfo.Address, depth int, - ids metainfo.Hashes, cb ...func(Result)) { +func (s *Server) getPeers(info metainfo.Hash, addr krpc.Addr, depth int, ids metainfo.Hashes, cb ...func(Result)) { arg := krpc.QueryArg{InfoHash: info, Wants: s.want} - t := newTransaction(s, addr.UDPAddr(), queryMethodGetPeers, arg, cb...) + t := newTransaction(s, addr, queryMethodGetPeers, arg, cb...) t.OnResponse = s.onGetPeersResp t.Depth = depth t.Visited = ids @@ -718,14 +768,6 @@ func (s *Server) getPeers(info metainfo.Hash, addr metainfo.Address, depth int, } } -// Ping sends a PING query to addr, and the callback function cb will be called -// when the response or error is returned, or it's timeout. -func (s *Server) Ping(addr *net.UDPAddr, cb ...func(Result)) (err error) { - t := newTransaction(s, addr, queryMethodPing, krpc.QueryArg{}, cb...) - t.OnResponse = s.onPingResp - return s.request(t) -} - // GetPeers searches the peer storing the torrent by the infohash of the torrent, // which will search it recursively until some peers are returned or it reaches // the maximun depth, that's, ServerConfig.SearchDepth. @@ -760,7 +802,6 @@ func (s *Server) GetPeers(infohash metainfo.Hash, cb ...func(Result)) { for _, node := range nodes { s.getPeers(infohash, node.Addr, s.conf.SearchDepth, ids, cb...) } - } // AnnouncePeer announces the torrent infohash to the K closest nodes, @@ -780,16 +821,15 @@ func (s *Server) AnnouncePeer(infohash metainfo.Hash, port uint16, impliedPort b sentNodes := make([]krpc.Node, 0, len(nodes)) for _, node := range nodes { - addr := node.Addr.UDPAddr() - token := s.tokenPeerManager.Get(infohash, addr) + token := s.tokenPeerManager.Get(infohash, node.Addr) if token == "" { continue } arg := krpc.QueryArg{ImpliedPort: impliedPort, InfoHash: infohash, Port: port, Token: token} - t := newTransaction(s, addr, queryMethodAnnouncePeer, arg) + t := newTransaction(s, node.Addr, queryMethodAnnouncePeer, arg) if err := s.request(t); err != nil { - s.conf.ErrorLog("fail to send query message to '%s': %s", addr.String(), err) + s.conf.ErrorLog("fail to send query message to '%s': %s", node.Addr.String(), err) } else { sentNodes = append(sentNodes, node) } @@ -801,7 +841,7 @@ func (s *Server) AnnouncePeer(infohash metainfo.Hash, port uint16, impliedPort b // FindNode sends the "find_node" query to the addr to find the target node. // // Notice: In general, it's used to bootstrap the routing table. -func (s *Server) FindNode(addr *net.UDPAddr, target metainfo.Hash) error { +func (s *Server) FindNode(addr krpc.Addr, target metainfo.Hash) error { if target.IsZero() { panic("the target is ZERO") } @@ -809,8 +849,7 @@ func (s *Server) FindNode(addr *net.UDPAddr, target metainfo.Hash) error { return s.findNode(target, addr, s.conf.SearchDepth, nil) } -func (s *Server) findNode(target metainfo.Hash, addr *net.UDPAddr, depth int, - ids metainfo.Hashes) error { +func (s *Server) findNode(target metainfo.Hash, addr krpc.Addr, depth int, ids metainfo.Hashes) error { arg := krpc.QueryArg{Target: target, Wants: s.want} t := newTransaction(s, addr, queryMethodFindNode, arg) t.OnResponse = s.onFindNodeResp @@ -818,7 +857,7 @@ func (s *Server) findNode(target metainfo.Hash, addr *net.UDPAddr, depth int, return s.request(t) } -func (s *Server) onFindNodeResp(t *transaction, a *net.UDPAddr, m krpc.Message) { +func (s *Server) onFindNodeResp(t *transaction, _ krpc.Addr, m krpc.Message) { t.Done(Result{}) // Search the target node recursively. @@ -860,7 +899,7 @@ func (s *Server) onFindNodeResp(t *transaction, a *net.UDPAddr, m krpc.Message) } for _, node := range nodes { - err := s.findNode(t.Arg.Target, node.Addr.UDPAddr(), t.Depth, ids) + err := s.findNode(t.Arg.Target, node.Addr, t.Depth, ids) if err != nil { s.conf.ErrorLog(`fail to send "find_node" query to '%s': %s`, node.Addr.String(), err) diff --git a/dht/dht_server_test.go b/dht/dht_server_test.go index cd21de0..3375704 100644 --- a/dht/dht_server_test.go +++ b/dht/dht_server_test.go @@ -17,39 +17,33 @@ package dht import ( "fmt" "net" - "strconv" "sync" "time" + "github.com/xgfone/bt/internal/helper" + "github.com/xgfone/bt/krpc" "github.com/xgfone/bt/metainfo" ) type testPeerManager struct { lock sync.RWMutex - peers map[metainfo.Hash][]metainfo.Address + peers map[metainfo.Hash][]string } func newTestPeerManager() *testPeerManager { - return &testPeerManager{peers: make(map[metainfo.Hash][]metainfo.Address)} + return &testPeerManager{peers: make(map[metainfo.Hash][]string)} } -func (pm *testPeerManager) AddPeer(infohash metainfo.Hash, addr metainfo.Address) { +func (pm *testPeerManager) AddPeer(infohash metainfo.Hash, addr string) { pm.lock.Lock() - var exist bool - for _, orig := range pm.peers[infohash] { - if orig.Equal(addr) { - exist = true - break - } - } - if !exist { + defer pm.lock.Unlock() + + if !helper.ContainsString(pm.peers[infohash], addr) { pm.peers[infohash] = append(pm.peers[infohash], addr) } - pm.lock.Unlock() } -func (pm *testPeerManager) GetPeers(infohash metainfo.Hash, maxnum int, - ipv6 bool) (addrs []metainfo.Address) { +func (pm *testPeerManager) GetPeers(infohash metainfo.Hash, maxnum int, ipv6 bool) (addrs []string) { // We only supports IPv4, so ignore the ipv6 argument. pm.lock.RLock() _addrs := pm.peers[infohash] @@ -63,13 +57,11 @@ func (pm *testPeerManager) GetPeers(infohash metainfo.Hash, maxnum int, return } -func onSearch(infohash string, ip net.IP, port uint16) { - addr := net.JoinHostPort(ip.String(), strconv.FormatUint(uint64(port), 10)) - fmt.Printf("%s is searching %s\n", addr, infohash) +func onSearch(infohash string, addr krpc.Addr) { + fmt.Printf("%s is searching %s\n", addr.String(), infohash) } -func onTorrent(infohash string, ip net.IP, port uint16) { - addr := net.JoinHostPort(ip.String(), strconv.FormatUint(uint64(port), 10)) +func onTorrent(infohash string, addr string) { fmt.Printf("%s has downloaded %s\n", addr, infohash) } @@ -77,7 +69,7 @@ func newDHTServer(id metainfo.Hash, addr string, pm PeerManager) (s *Server, err conn, err := net.ListenPacket("udp", addr) if err == nil { c := Config{ID: id, PeerManager: pm, OnSearch: onSearch, OnTorrent: onTorrent} - s = NewServer(conn, c) + s = NewServer(conn, &c) } return } @@ -140,7 +132,7 @@ func ExampleServer() { fmt.Printf("no peers for %s\n", infohash) } else { for _, peer := range r.Peers { - fmt.Printf("%s: %s\n", infohash, peer.String()) + fmt.Printf("%s: %s\n", infohash, peer) } } }) @@ -149,7 +141,7 @@ func ExampleServer() { time.Sleep(time.Second * 2) // Add the peer to let the DHT server1 has the peer. - pm.AddPeer(infohash, metainfo.NewAddress(net.ParseIP("127.0.0.1"), 9001)) + pm.AddPeer(infohash, "127.0.0.1:9001") // Search the torrent infohash again, but from DHT server2, // which will search the DHT server1 recursively. @@ -158,7 +150,7 @@ func ExampleServer() { fmt.Printf("no peers for %s\n", infohash) } else { for _, peer := range r.Peers { - fmt.Printf("%s: %s\n", infohash, peer.String()) + fmt.Printf("%s: %s\n", infohash, peer) } } }) diff --git a/dht/peer_manager.go b/dht/peer_manager.go index 03d5006..ad38533 100644 --- a/dht/peer_manager.go +++ b/dht/peer_manager.go @@ -15,23 +15,24 @@ package dht import ( - "net" "sync" "time" + "github.com/xgfone/bt/krpc" "github.com/xgfone/bt/metainfo" ) // PeerManager is used to manage the peers. type PeerManager interface { // If ipv6 is true, only return ipv6 addresses. Or return ipv4 addresses. - GetPeers(infohash metainfo.Hash, maxnum int, ipv6 bool) []metainfo.Address + GetPeers(infohash metainfo.Hash, maxnum int, ipv6 bool) []string } +var _ PeerManager = new(tokenPeerManager) + type peer struct { ID metainfo.Hash - IP net.IP - Port uint16 + Addr krpc.Addr Token string Time time.Time } @@ -75,7 +76,7 @@ func (tpm *tokenPeerManager) Start(interval time.Duration) { } } -func (tpm *tokenPeerManager) Set(id metainfo.Hash, addr *net.UDPAddr, token string) { +func (tpm *tokenPeerManager) Set(id metainfo.Hash, addr krpc.Addr, token string) { addrkey := addr.String() tpm.lock.Lock() peers, ok := tpm.peers[id] @@ -83,17 +84,11 @@ func (tpm *tokenPeerManager) Set(id metainfo.Hash, addr *net.UDPAddr, token stri peers = make(map[string]peer, 4) tpm.peers[id] = peers } - peers[addrkey] = peer{ - ID: id, - IP: addr.IP, - Port: uint16(addr.Port), - Token: token, - Time: time.Now(), - } + peers[addrkey] = peer{ID: id, Addr: addr, Token: token, Time: time.Now()} tpm.lock.Unlock() } -func (tpm *tokenPeerManager) Get(id metainfo.Hash, addr *net.UDPAddr) (token string) { +func (tpm *tokenPeerManager) Get(id metainfo.Hash, addr krpc.Addr) (token string) { addrkey := addr.String() tpm.lock.RLock() if peers, ok := tpm.peers[id]; ok { @@ -113,9 +108,8 @@ func (tpm *tokenPeerManager) Stop() { } } -func (tpm *tokenPeerManager) GetPeers(infohash metainfo.Hash, maxnum int, - ipv6 bool) (addrs []metainfo.Address) { - addrs = make([]metainfo.Address, 0, maxnum) +func (tpm *tokenPeerManager) GetPeers(infohash metainfo.Hash, maxnum int, ipv6 bool) (addrs []string) { + addrs = make([]string, 0, maxnum) tpm.lock.RLock() if peers, ok := tpm.peers[infohash]; ok { for _, peer := range peers { @@ -124,13 +118,13 @@ func (tpm *tokenPeerManager) GetPeers(infohash metainfo.Hash, maxnum int, } if ipv6 { // For IPv6 - if isIPv6(peer.IP) { + if isIPv6(peer.Addr.IP) { maxnum-- - addrs = append(addrs, metainfo.NewAddress(peer.IP, peer.Port)) + addrs = append(addrs, peer.Addr.String()) } - } else if !isIPv6(peer.IP) { // For IPv4 + } else if !isIPv6(peer.Addr.IP) { // For IPv4 maxnum-- - addrs = append(addrs, metainfo.NewAddress(peer.IP, peer.Port)) + addrs = append(addrs, peer.Addr.String()) } } } diff --git a/dht/routing_table.go b/dht/routing_table.go index ab31ca4..46c3952 100644 --- a/dht/routing_table.go +++ b/dht/routing_table.go @@ -155,12 +155,12 @@ func (rt *routingTable) Stop() { // AddNode adds the node into the routing table. // // The returned value: -// NodeAdded: The node is added successfully. -// NodeNotAdded: The node is not added and is discarded. -// NodeExistAndUpdated: The node has existed, and its status has been updated. -// NodeExistAndChanged: The node has existed, but the address is inconsistent. -// The current node will be discarded. // +// NodeAdded: The node is added successfully. +// NodeNotAdded: The node is not added and is discarded. +// NodeExistAndUpdated: The node has existed, and its status has been updated. +// NodeExistAndChanged: The node has existed, but the address is inconsistent. +// The current node will be discarded. func (rt *routingTable) AddNode(n krpc.Node) (r int) { if n.ID == rt.root { // Don't add itself. return NodeNotAdded @@ -264,7 +264,7 @@ func (b *bucket) AddNode(n krpc.Node, now time.Time) (status int) { return } - // // TODO: Should we replace the old one?? + // // (xgf): Should we replace the old one?? // b.UpdateLastChangedTime(now) // copy(b.Nodes[i:], b.Nodes[i+1:]) // b.Nodes[len(b.Nodes)-1] = newWrappedNode(b, n, now) @@ -314,7 +314,7 @@ func (b *bucket) CheckAllNodes(now time.Time) { case nodeStatusGood: case nodeStatusDubious: // Try to send the PING query to the dubious node to check whether it is alive. - if err := b.table.s.Ping(node.Node.Addr.UDPAddr()); err != nil { + if err := b.table.s.Ping(node.Node.Addr); err != nil { b.table.s.conf.ErrorLog("fail to ping '%s': %s", node.Node.String(), err) } case nodeStatusBad: diff --git a/dht/token_manager.go b/dht/token_manager.go index 8e6b6a9..0f923c1 100644 --- a/dht/token_manager.go +++ b/dht/token_manager.go @@ -15,16 +15,16 @@ package dht import ( - "net" "sync" "time" - "github.com/xgfone/bt/utils" + "github.com/xgfone/bt/internal/helper" + "github.com/xgfone/bt/krpc" ) // TokenManager is used to manage and validate the token. // -// TODO: Should we allocate the different token for each node?? +// (xgf): Should we allocate the different token for each node?? type tokenManager struct { lock sync.RWMutex last string @@ -34,12 +34,12 @@ type tokenManager struct { } func newTokenManager() *tokenManager { - token := utils.RandomString(8) + token := helper.RandomString(8) return &tokenManager{last: token, new: token, exit: make(chan struct{})} } func (tm *tokenManager) updateToken() { - token := utils.RandomString(8) + token := helper.RandomString(8) tm.lock.Lock() tm.last, tm.new = tm.new, token tm.lock.Unlock() @@ -83,7 +83,7 @@ func (tm *tokenManager) Stop() { } // Token allocates a token for a node addr and returns the token. -func (tm *tokenManager) Token(addr *net.UDPAddr) (token string) { +func (tm *tokenManager) Token(addr krpc.Addr) (token string) { addrs := addr.String() tm.lock.RLock() token = tm.new @@ -94,7 +94,7 @@ func (tm *tokenManager) Token(addr *net.UDPAddr) (token string) { // Check checks whether the token associated with the node addr is valid, // that's, it's not expired. -func (tm *tokenManager) Check(addr *net.UDPAddr, token string) (ok bool) { +func (tm *tokenManager) Check(addr krpc.Addr, token string) (ok bool) { tm.lock.RLock() last, new := tm.last, tm.new tm.lock.RUnlock() diff --git a/dht/transaction_manager.go b/dht/transaction_manager.go index 5a05a44..253006a 100644 --- a/dht/transaction_manager.go +++ b/dht/transaction_manager.go @@ -15,7 +15,6 @@ package dht import ( - "net" "strconv" "sync" "sync/atomic" @@ -29,7 +28,7 @@ type transaction struct { ID string Query string Arg krpc.QueryArg - Addr *net.UDPAddr + Addr krpc.Addr Time time.Time Depth int @@ -37,7 +36,7 @@ type transaction struct { Callback func(Result) OnError func(t *transaction, code int, reason string) OnTimeout func(t *transaction) - OnResponse func(t *transaction, radd *net.UDPAddr, msg krpc.Message) + OnResponse func(t *transaction, radd krpc.Addr, msg krpc.Message) } func (t *transaction) Done(r Result) { @@ -47,8 +46,8 @@ func (t *transaction) Done(r Result) { } } -func noopResponse(*transaction, *net.UDPAddr, krpc.Message) {} -func newTransaction(s *Server, a *net.UDPAddr, q string, qa krpc.QueryArg, +func noopResponse(*transaction, krpc.Addr, krpc.Message) {} +func newTransaction(s *Server, a krpc.Addr, q string, qa krpc.QueryArg, callback ...func(Result)) *transaction { var cb func(Result) if len(callback) > 0 { @@ -141,7 +140,7 @@ func (tm *transactionManager) DeleteTransaction(t *transaction) { // and the peer address. // // Return nil if there is no the transaction. -func (tm *transactionManager) PopTransaction(tid string, addr *net.UDPAddr) (t *transaction) { +func (tm *transactionManager) PopTransaction(tid string, addr krpc.Addr) (t *transaction) { key := transactionkey{id: tid, addr: addr.String()} tm.lock.Lock() if t = tm.trans[key]; t != nil { diff --git a/utils/doc.go b/downloader/doc.go similarity index 77% rename from utils/doc.go rename to downloader/doc.go index 6103847..09ca683 100644 --- a/utils/doc.go +++ b/downloader/doc.go @@ -1,4 +1,4 @@ -// Copyright 2020 xgfone +// Copyright 2023 xgfone // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,5 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package utils supplies some convenient functions. -package utils +// Package downloader is used to download the torrent or the real file +// from the peer node by the peer wire protocol. +package downloader diff --git a/downloader/torrent.go b/downloader/torrent.go index 92c5b59..f93c797 100644 --- a/downloader/torrent.go +++ b/downloader/torrent.go @@ -1,4 +1,4 @@ -// Copyright 2020 xgfone +// Copyright 2020~2023 xgfone // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -35,9 +35,9 @@ const BlockSize = 16384 // 16KiB. // Request is used to send a download request. type request struct { - Host string - Port uint16 - PeerID metainfo.Hash + Host string + Port uint16 + // PeerID metainfo.Hash InfoHash metainfo.Hash } @@ -57,6 +57,11 @@ type TorrentDownloaderConfig struct { // The default is a random id. ID metainfo.Hash + // The size of a block of the piece. + // + // Default: 16384 + BlockSize uint64 + // WorkerNum is the number of the worker downloading the torrent concurrently. // // The default is 128. @@ -71,11 +76,14 @@ type TorrentDownloaderConfig struct { ErrorLog func(format string, args ...interface{}) } -func (c *TorrentDownloaderConfig) set(conf ...TorrentDownloaderConfig) { - if len(conf) > 0 { - *c = conf[0] +func (c *TorrentDownloaderConfig) set(conf *TorrentDownloaderConfig) { + if conf != nil { + *c = *conf } + if c.BlockSize <= 0 { + c.BlockSize = BlockSize + } if c.WorkerNum <= 0 { c.WorkerNum = 128 } @@ -102,16 +110,15 @@ type TorrentDownloader struct { // NewTorrentDownloader returns a new TorrentDownloader. // // If id is ZERO, it is reset to a random id. workerNum is 128 by default. -func NewTorrentDownloader(c ...TorrentDownloaderConfig) *TorrentDownloader { +func NewTorrentDownloader(c *TorrentDownloaderConfig) *TorrentDownloader { var conf TorrentDownloaderConfig - conf.set(c...) + conf.set(c) d := &TorrentDownloader{ conf: conf, exit: make(chan struct{}), requests: make(chan request, conf.WorkerNum), responses: make(chan TorrentResponse, 1024), - ehmsg: pp.ExtendedHandshakeMsg{ M: map[string]uint8{pp.ExtendedMessageNameMetadata: 1}, }, @@ -130,6 +137,9 @@ func NewTorrentDownloader(c ...TorrentDownloaderConfig) *TorrentDownloader { // Notice: the remote peer must support the "ut_metadata" extenstion. // Or downloading fails. func (d *TorrentDownloader) Request(host string, port uint16, infohash metainfo.Hash) { + if infohash.IsZero() { + panic("infohash is ZERO") + } d.requests <- request{Host: host, Port: port, InfoHash: infohash} } @@ -168,7 +178,7 @@ func (d *TorrentDownloader) worker() { case <-d.exit: return case r := <-d.requests: - if err := d.download(r.Host, r.Port, r.PeerID, r.InfoHash); err != nil { + if err := d.download(r.Host, r.Port, r.InfoHash); err != nil { d.conf.ErrorLog("fail to download the torrent '%s': %s", r.InfoHash.HexString(), err) } @@ -176,8 +186,7 @@ func (d *TorrentDownloader) worker() { } } -func (d *TorrentDownloader) download(host string, port uint16, - peerID, infohash metainfo.Hash) (err error) { +func (d *TorrentDownloader) download(host string, port uint16, infohash metainfo.Hash) (err error) { addr := net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10)) conn, err := pp.NewPeerConnByDial(addr, d.conf.ID, infohash, d.conf.DialTimeout) if err != nil { @@ -190,8 +199,6 @@ func (d *TorrentDownloader) download(host string, port uint16, return } else if !conn.PeerExtBits.IsSupportExtended() { return fmt.Errorf("the remote peer '%s' does not support Extended", addr) - } else if !peerID.IsZero() && peerID != conn.PeerID { - return fmt.Errorf("inconsistent peer id '%s'", conn.PeerID.HexString()) } if err = conn.SendExtHandshakeMsg(d.ehmsg); err != nil { @@ -200,7 +207,7 @@ func (d *TorrentDownloader) download(host string, port uint16, var pieces [][]byte var piecesNum int - var metadataSize int + var metadataSize uint64 var utmetadataID uint8 var msg pp.Message @@ -247,13 +254,14 @@ func (d *TorrentDownloader) download(host string, port uint16, } metadataSize = ehmsg.MetadataSize - piecesNum = metadataSize / BlockSize - if metadataSize%BlockSize != 0 { + piecesNum = int(metadataSize / d.conf.BlockSize) + if metadataSize%d.conf.BlockSize != 0 { piecesNum++ } pieces = make([][]byte, piecesNum) go d.requestPieces(conn, utmetadataID, piecesNum) + case 1: if pieces == nil { return @@ -269,8 +277,8 @@ func (d *TorrentDownloader) download(host string, port uint16, } pieceLen := len(utmsg.Data) - if (utmsg.Piece != piecesNum-1 && pieceLen != BlockSize) || - (utmsg.Piece == piecesNum-1 && pieceLen != metadataSize%BlockSize) { + if (utmsg.Piece != piecesNum-1 && pieceLen != int(d.conf.BlockSize)) || + (utmsg.Piece == piecesNum-1 && pieceLen != int(metadataSize%d.conf.BlockSize)) { return } pieces[utmsg.Piece] = utmsg.Data diff --git a/downloader/torrent_test.go b/downloader/torrent_test.go new file mode 100644 index 0000000..143c4ff --- /dev/null +++ b/downloader/torrent_test.go @@ -0,0 +1,124 @@ +// Copyright 2023 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package downloader + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/xgfone/bt/metainfo" + "github.com/xgfone/bt/peerprotocol" + pp "github.com/xgfone/bt/peerprotocol" +) + +type bep10Handler struct { + peerprotocol.NoopHandler // For implementing peerprotocol.Handler. + + infodata string +} + +func (h bep10Handler) OnExtHandShake(c *peerprotocol.PeerConn) error { + if _, ok := c.ExtendedHandshakeMsg.M[peerprotocol.ExtendedMessageNameMetadata]; !ok { + return errors.New("missing the extension 'ut_metadata'") + } + + return c.SendExtHandshakeMsg(peerprotocol.ExtendedHandshakeMsg{ + M: map[string]uint8{pp.ExtendedMessageNameMetadata: 2}, + MetadataSize: uint64(len(h.infodata)), + }) +} + +func (h bep10Handler) OnPayload(c *peerprotocol.PeerConn, extid uint8, extdata []byte) error { + if extid != 2 { + return fmt.Errorf("unknown extension id %d", extid) + } + + var reqmsg peerprotocol.UtMetadataExtendedMsg + if err := reqmsg.DecodeFromPayload(extdata); err != nil { + return err + } + + if reqmsg.MsgType != peerprotocol.UtMetadataExtendedMsgTypeRequest { + return errors.New("unsupported ut_metadata extension type") + } + + startIndex := reqmsg.Piece * BlockSize + endIndex := startIndex + BlockSize + if totalSize := len(h.infodata); totalSize < endIndex { + endIndex = totalSize + } + + respmsg := peerprotocol.UtMetadataExtendedMsg{ + MsgType: peerprotocol.UtMetadataExtendedMsgTypeData, + Piece: reqmsg.Piece, + Data: []byte(h.infodata[startIndex:endIndex]), + } + + data, err := respmsg.EncodeToBytes() + if err != nil { + return err + } + + peerextid := c.ExtendedHandshakeMsg.M[peerprotocol.ExtendedMessageNameMetadata] + return c.SendExtMsg(peerextid, data) +} + +func ExampleTorrentDownloader() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + handler := bep10Handler{infodata: "1234567890"} + infohash := metainfo.NewHashFromString(handler.infodata) + + // Start the torrent server. + var serverConfig peerprotocol.Config + serverConfig.ExtBits.Set(peerprotocol.ExtensionBitExtended) + server, err := peerprotocol.NewServerByListen("tcp", "127.0.0.1:9010", metainfo.NewRandomHash(), handler, &serverConfig) + if err != nil { + fmt.Println(err) + return + } + defer server.Close() + + go server.Run() // Start the torrent server. + time.Sleep(time.Millisecond * 100) // Wait that the torrent server finishes to start. + + // Start the torrent downloader. + downloaderConfig := &TorrentDownloaderConfig{WorkerNum: 3, DialTimeout: time.Second} + downloader := NewTorrentDownloader(downloaderConfig) + go func() { + for { + select { + case <-ctx.Done(): + return + case result := <-downloader.Response(): + fmt.Println(string(result.InfoBytes)) + } + } + }() + + // Start to download the torrent. + downloader.Request("127.0.0.1", 9010, infohash) + + // Wait to finish the test. + time.Sleep(time.Second) + cancel() + time.Sleep(time.Millisecond * 50) + + // Output: + // 1234567890 +} diff --git a/utils/io.go b/internal/helper/io.go similarity index 86% rename from utils/io.go rename to internal/helper/io.go index 2098c1a..fa11c4c 100644 --- a/utils/io.go +++ b/internal/helper/io.go @@ -1,4 +1,4 @@ -// Copyright 2020 xgfone +// Copyright 2023 xgfone // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,18 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -package utils +package helper import "io" // CopyNBuffer is the same as io.CopyN, but uses the given buf as the buffer. -// -// If buf is nil or empty, it will make a new one with 2048. func CopyNBuffer(dst io.Writer, src io.Reader, n int64, buf []byte) (written int64, err error) { - if len(buf) == 0 { - buf = make([]byte, 2048) - } - written, err = io.CopyBuffer(dst, io.LimitReader(src, n), buf) if written == n { return n, nil diff --git a/utils/slice.go b/internal/helper/slice.go similarity index 82% rename from utils/slice.go rename to internal/helper/slice.go index 53ce3a6..23c3206 100644 --- a/utils/slice.go +++ b/internal/helper/slice.go @@ -1,4 +1,4 @@ -// Copyright 2020 xgfone +// Copyright 2023 xgfone // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,15 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -package utils +package helper -// InStringSlice reports whether s is in ss. -func InStringSlice(ss []string, s string) bool { +// ContainsString reports whether s is in ss. +func ContainsString(ss []string, s string) bool { for _, v := range ss { if v == s { return true } } - return false } diff --git a/utils/slice_test.go b/internal/helper/slice_test.go similarity index 75% rename from utils/slice_test.go rename to internal/helper/slice_test.go index d67327f..4d5db7b 100644 --- a/utils/slice_test.go +++ b/internal/helper/slice_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 xgfone +// Copyright 2023 xgfone // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,18 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -package utils +package helper -import ( - "testing" -) +import "testing" -func TestInStringSlice(t *testing.T) { - if !InStringSlice([]string{"a", "b"}, "a") { +func TestContainsString(t *testing.T) { + if !ContainsString([]string{"a", "b"}, "a") { t.Fail() } - if InStringSlice([]string{"a", "b"}, "z") { + if ContainsString([]string{"a", "b"}, "z") { t.Fail() } } diff --git a/utils/string.go b/internal/helper/string.go similarity index 78% rename from utils/string.go rename to internal/helper/string.go index 19865fa..ff82817 100644 --- a/utils/string.go +++ b/internal/helper/string.go @@ -1,4 +1,4 @@ -// Copyright 2020 xgfone +// Copyright 2023 xgfone // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,13 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -package utils +package helper -import "crypto/rand" +import ( + crand "crypto/rand" + "math/rand" +) // RandomString generates a size-length string randomly. func RandomString(size int) string { bs := make([]byte, size) - rand.Read(bs) + if n, _ := crand.Read(bs); n < size { + for ; n < size; n++ { + bs[n] = byte(rand.Intn(256)) + } + } return string(bs) } diff --git a/krpc/addr.go b/krpc/addr.go new file mode 100644 index 0000000..fc48e36 --- /dev/null +++ b/krpc/addr.go @@ -0,0 +1,312 @@ +// Copyright 2023 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package krpc + +import ( + "bytes" + "encoding" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "strconv" + + "github.com/xgfone/bt/bencode" +) + +// Addr represents an address based on ip and port, +// which implements "Compact IP-address/port info". +// +// See http://bittorrent.org/beps/bep_0005.html. +type Addr struct { + IP net.IP // For IPv4, its length must be 4. + Port uint16 + + // The original network address, which is only used by the DHT server. + Orig net.Addr +} + +// ParseAddrs parses the address from the string s with the format "IP:PORT". +func ParseAddrs(s string) (addrs []Addr, err error) { + _ip, _port, err := net.SplitHostPort(s) + if err != nil { + return + } + + port, err := strconv.ParseUint(_port, 10, 16) + if err != nil { + return + } + + ip := net.ParseIP(_ip) + if ip != nil { + if ipv4 := ip.To4(); ipv4 != nil { + ip = ipv4 + } + return []Addr{NewAddr(ip, uint16(port))}, nil + } + + ips, err := net.LookupIP(_ip) + if err != nil { + return nil, err + } + + addrs = make([]Addr, len(ips)) + for i, ip := range ips { + if ipv4 := ip.To4(); ipv4 != nil { + ip = ipv4 + } + addrs[i] = NewAddr(ip, uint16(port)) + } + + return +} + +// NewAddr returns a new Addr with ip and port. +func NewAddr(ip net.IP, port uint16) Addr { + return Addr{IP: ip, Port: port} +} + +// NewAddrFromUDPAddr converts *net.UDPAddr to a new Addr. +func NewAddrFromUDPAddr(ua *net.UDPAddr) Addr { + return Addr{IP: ua.IP, Port: uint16(ua.Port), Orig: ua} +} + +// Valid reports whether the addr is valid. +func (a Addr) Valid() bool { + return len(a.IP) > 0 && a.Port > 0 +} + +// Equal reports whether a is equal to o. +func (a Addr) Equal(o Addr) bool { + return a.Port == o.Port && a.IP.Equal(o.IP) +} + +// UDPAddr converts itself to *net.UDPAddr. +func (a Addr) UDPAddr() *net.UDPAddr { + return &net.UDPAddr{IP: a.IP, Port: int(a.Port)} +} + +var _ net.Addr = Addr{} + +// Network implements the interface net.Addr#Network. +func (a Addr) Network() string { + return "krpc" +} + +func (a Addr) String() string { + if a.Port == 0 { + return a.IP.String() + } + return net.JoinHostPort(a.IP.String(), strconv.FormatUint(uint64(a.Port), 10)) +} + +// WriteBinary is the same as MarshalBinary, but writes the result into w +// instead of returning. +func (a Addr) WriteBinary(w io.Writer) (n int, err error) { + if n, err = w.Write(a.IP); err == nil { + if err = binary.Write(w, binary.BigEndian, a.Port); err == nil { + n += 2 + } + } + return +} + +var ( + _ encoding.BinaryMarshaler = new(Addr) + _ encoding.BinaryUnmarshaler = new(Addr) +) + +// MarshalBinary implements the interface encoding.BinaryMarshaler, +func (a Addr) MarshalBinary() (data []byte, err error) { + buf := bytes.NewBuffer(nil) + buf.Grow(18) + if _, err = a.WriteBinary(buf); err == nil { + data = buf.Bytes() + } + return +} + +// UnmarshalBinary implements the interface encoding.BinaryUnmarshaler. +func (a *Addr) UnmarshalBinary(data []byte) error { + _len := len(data) - 2 + switch _len { + case net.IPv4len, net.IPv6len: + default: + return errors.New("invalid compact ip-address/port info") + } + + a.IP = make(net.IP, _len) + copy(a.IP, data[:_len]) + a.Port = binary.BigEndian.Uint16(data[_len:]) + return nil +} + +var ( + _ bencode.Marshaler = new(Addr) + _ bencode.Unmarshaler = new(Addr) +) + +// MarshalBencode implements the interface bencode.Marshaler. +func (a Addr) MarshalBencode() (b []byte, err error) { + if b, err = a.MarshalBinary(); err == nil { + b, err = bencode.EncodeBytes(b) + } + return +} + +// UnmarshalBencode implements the interface bencode.Unmarshaler. +func (a *Addr) UnmarshalBencode(b []byte) (err error) { + var data []byte + if err = bencode.DecodeBytes(b, &data); err == nil { + err = a.UnmarshalBinary(data) + } + return +} + +// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + +var ( + _ bencode.Marshaler = new(CompactIPv4Addrs) + _ bencode.Unmarshaler = new(CompactIPv4Addrs) + _ encoding.BinaryMarshaler = new(CompactIPv4Addrs) + _ encoding.BinaryUnmarshaler = new(CompactIPv4Addrs) +) + +// CompactIPv4Addrs is a set of IPv4 Addrs. +type CompactIPv4Addrs []Addr + +// MarshalBinary implements the interface encoding.BinaryMarshaler. +func (cas CompactIPv4Addrs) MarshalBinary() ([]byte, error) { + buf := bytes.NewBuffer(nil) + buf.Grow(6 * len(cas)) + for _, addr := range cas { + if addr.IP = addr.IP.To4(); len(addr.IP) == 0 { + continue + } + + if n, err := addr.WriteBinary(buf); err != nil { + return nil, err + } else if n != 6 { + panic(fmt.Errorf("CompactIPv4Nodes: the invalid node info length '%d'", n)) + } + } + return buf.Bytes(), nil +} + +// UnmarshalBinary implements the interface encoding.BinaryUnmarshaler. +func (cas *CompactIPv4Addrs) UnmarshalBinary(b []byte) (err error) { + _len := len(b) + if _len%6 != 0 { + return fmt.Errorf("CompactIPv4Addrs: invalid addr info length '%d'", _len) + } + + addrs := make(CompactIPv4Addrs, 0, _len/6) + for i := 0; i < _len; i += 6 { + var addr Addr + if err = addr.UnmarshalBinary(b[i : i+6]); err != nil { + return + } + addrs = append(addrs, addr) + } + + *cas = addrs + return +} + +// MarshalBencode implements the interface bencode.Marshaler. +func (cas CompactIPv4Addrs) MarshalBencode() (b []byte, err error) { + if b, err = cas.MarshalBinary(); err == nil { + b, err = bencode.EncodeBytes(b) + } + return +} + +// UnmarshalBencode implements the interface bencode.Unmarshaler. +func (cas *CompactIPv4Addrs) UnmarshalBencode(b []byte) (err error) { + var data []byte + if err = bencode.DecodeBytes(b, &data); err == nil { + err = cas.UnmarshalBinary(data) + } + return +} + +// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + +var ( + _ bencode.Marshaler = new(CompactIPv6Addrs) + _ bencode.Unmarshaler = new(CompactIPv6Addrs) + _ encoding.BinaryMarshaler = new(CompactIPv6Addrs) + _ encoding.BinaryUnmarshaler = new(CompactIPv6Addrs) +) + +// CompactIPv6Addrs is a set of IPv6 Addrs. +type CompactIPv6Addrs []Addr + +// MarshalBinary implements the interface encoding.BinaryMarshaler. +func (cas CompactIPv6Addrs) MarshalBinary() ([]byte, error) { + buf := bytes.NewBuffer(nil) + buf.Grow(18 * len(cas)) + for _, addr := range cas { + if addr.IP = addr.IP.To4(); len(addr.IP) == 0 { + continue + } + + if n, err := addr.WriteBinary(buf); err != nil { + return nil, err + } else if n != 18 { + panic(fmt.Errorf("CompactIPv4Nodes: the invalid node info length '%d'", n)) + } + } + return buf.Bytes(), nil +} + +// UnmarshalBinary implements the interface encoding.BinaryUnmarshaler. +func (cas *CompactIPv6Addrs) UnmarshalBinary(b []byte) (err error) { + _len := len(b) + if _len%18 != 0 { + return fmt.Errorf("CompactIPv4Addrs: invalid addr info length '%d'", _len) + } + + addrs := make(CompactIPv6Addrs, 0, _len/18) + for i := 0; i < _len; i += 18 { + var addr Addr + if err = addr.UnmarshalBinary(b[i : i+18]); err != nil { + return + } + addrs = append(addrs, addr) + } + + *cas = addrs + return +} + +// MarshalBencode implements the interface bencode.Marshaler. +func (cas CompactIPv6Addrs) MarshalBencode() (b []byte, err error) { + if b, err = cas.MarshalBinary(); err == nil { + b, err = bencode.EncodeBytes(b) + } + return +} + +// UnmarshalBencode implements the interface bencode.Unmarshaler. +func (cas *CompactIPv6Addrs) UnmarshalBencode(b []byte) (err error) { + var data []byte + if err = bencode.DecodeBytes(b, &data); err == nil { + err = cas.UnmarshalBinary(data) + } + return +} diff --git a/krpc/addr_test.go b/krpc/addr_test.go new file mode 100644 index 0000000..a830a0c --- /dev/null +++ b/krpc/addr_test.go @@ -0,0 +1,51 @@ +// Copyright 2023 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package krpc + +import ( + "bytes" + "net" + "testing" + + "github.com/xgfone/bt/bencode" +) + +func TestAddr(t *testing.T) { + addrs := []Addr{ + {IP: net.ParseIP("172.16.1.1").To4(), Port: 123}, + {IP: net.ParseIP("192.168.1.1").To4(), Port: 456}, + } + expect := "l6:\xac\x10\x01\x01\x00\x7b6:\xc0\xa8\x01\x01\x01\xc8e" + + buf := new(bytes.Buffer) + if err := bencode.NewEncoder(buf).Encode(addrs); err != nil { + t.Error(err) + } else if result := buf.String(); result != expect { + t.Errorf("expect %s, but got %x\n", expect, result) + } + + var raddrs []Addr + if err := bencode.DecodeString(expect, &raddrs); err != nil { + t.Error(err) + } else if len(raddrs) != len(addrs) { + t.Errorf("expect addrs length %d, but got %d\n", len(addrs), len(raddrs)) + } else { + for i, addr := range addrs { + if !addr.Equal(raddrs[i]) { + t.Errorf("%d: expect addr %v, but got %v\n", i, addr, raddrs[i]) + } + } + } +} diff --git a/krpc/message.go b/krpc/message.go index 5585262..1de9c43 100644 --- a/krpc/message.go +++ b/krpc/message.go @@ -1,4 +1,4 @@ -// Copyright 2020 xgfone +// Copyright 2020~2023 xgfone // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -272,14 +272,14 @@ type ResponseResult struct { // of the requested ipv4 target. // // find_node - Nodes CompactIPv4Node `bencode:"nodes,omitempty"` // BEP 5 + Nodes CompactIPv4Nodes `bencode:"nodes,omitempty"` // BEP 5 // Nodes6 is a string containing the compact node information for the list // of the ipv6 target node, or the K(8) closest good nodes in routing table // of the requested ipv6 target. // // find_node - Nodes6 CompactIPv6Node `bencode:"nodes6,omitempty"` // BEP 32 + Nodes6 CompactIPv6Nodes `bencode:"nodes6,omitempty"` // BEP 32 // Token is used for future "announce_peer". // @@ -288,157 +288,14 @@ type ResponseResult struct { // Values is a list of the torrent peers. // + // For each element, in general, it is a compact IP-address/port info and + // may be decoded to Addr, for example, + // + // addrs := make([]Addr, len(values)) + // for i, v := range values { + // addrs[i].UnmarshalBinary([]byte(v)) + // } + // // get_peers - Values CompactAddresses `bencode:"values,omitempty"` // BEP 5 -} - -/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> - -// CompactAddresses represents a group of the compact addresses. -type CompactAddresses []metainfo.Address - -// MarshalBinary implements the interface binary.BinaryMarshaler. -func (cas CompactAddresses) MarshalBinary() ([]byte, error) { - ss := make([]string, len(cas)) - for i, addr := range cas { - data, err := addr.MarshalBinary() - if err != nil { - return nil, err - } - ss[i] = string(data) - } - return bencode.EncodeBytes(ss) -} - -// UnmarshalBinary implements the interface binary.BinaryUnmarshaler. -func (cas *CompactAddresses) UnmarshalBinary(b []byte) (err error) { - var ss []string - if err = bencode.DecodeBytes(b, &ss); err == nil { - addrs := make(CompactAddresses, len(ss)) - for i, s := range ss { - var addr metainfo.Address - if err = addr.UnmarshalBinary([]byte(s)); err != nil { - return - } - addrs[i] = addr - } - } - - return -} - -/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> - -// CompactIPv4Node is a set of IPv4 Nodes. -type CompactIPv4Node []Node - -// MarshalBinary implements the interface binary.BinaryMarshaler. -func (cn CompactIPv4Node) MarshalBinary() ([]byte, error) { - buf := bytes.NewBuffer(nil) - buf.Grow(26 * len(cn)) - for _, ni := range cn { - if ni.Addr.IP = ni.Addr.IP.To4(); len(ni.Addr.IP) == 0 { - continue - } - if n, err := ni.WriteBinary(buf); err != nil { - return nil, err - } else if n != 26 { - panic(fmt.Errorf("CompactIPv4NodeInfo: the invalid NodeInfo length '%d'", n)) - } - } - return buf.Bytes(), nil -} - -// UnmarshalBinary implements the interface binary.BinaryUnmarshaler. -func (cn *CompactIPv4Node) UnmarshalBinary(b []byte) (err error) { - _len := len(b) - if _len%26 != 0 { - return fmt.Errorf("CompactIPv4NodeInfo: invalid bytes length '%d'", _len) - } - - nis := make([]Node, 0, _len/26) - for i := 0; i < _len; i += 26 { - var ni Node - if err = ni.UnmarshalBinary(b[i : i+26]); err != nil { - return - } - nis = append(nis, ni) - } - - *cn = nis - return -} - -// MarshalBencode implements the interface bencode.Marshaler. -func (cn CompactIPv4Node) MarshalBencode() (b []byte, err error) { - if b, err = cn.MarshalBinary(); err == nil { - b, err = bencode.EncodeBytes(b) - } - return -} - -// UnmarshalBencode implements the interface bencode.Unmarshaler. -func (cn *CompactIPv4Node) UnmarshalBencode(b []byte) (err error) { - var s string - if err = bencode.DecodeBytes(b, &s); err == nil { - err = cn.UnmarshalBinary([]byte(s)) - } - return -} - -/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> - -// CompactIPv6Node is a set of IPv6 Nodes. -type CompactIPv6Node []Node - -// MarshalBinary implements the interface binary.BinaryMarshaler. -func (cn CompactIPv6Node) MarshalBinary() ([]byte, error) { - buf := bytes.NewBuffer(nil) - buf.Grow(38 * len(cn)) - for _, ni := range cn { - ni.Addr.IP = ni.Addr.IP.To16() - if n, err := ni.WriteBinary(buf); err != nil { - return nil, err - } else if n != 38 { - panic(fmt.Errorf("CompactIPv6NodeInfo: the invalid NodeInfo length '%d'", n)) - } - } - return buf.Bytes(), nil -} - -// UnmarshalBinary implements the interface binary.BinaryUnmarshaler. -func (cn *CompactIPv6Node) UnmarshalBinary(b []byte) (err error) { - _len := len(b) - if _len%38 != 0 { - return fmt.Errorf("CompactIPv6NodeInfo: invalid bytes length '%d'", _len) - } - - nis := make([]Node, 0, _len/38) - for i := 0; i < _len; i += 38 { - var ni Node - if err = ni.UnmarshalBinary(b[i : i+38]); err != nil { - return - } - nis = append(nis, ni) - } - - *cn = nis - return -} - -// MarshalBencode implements the interface bencode.Marshaler. -func (cn CompactIPv6Node) MarshalBencode() (b []byte, err error) { - if b, err = cn.MarshalBinary(); err == nil { - b, err = bencode.EncodeBytes(b) - } - return -} - -// UnmarshalBencode implements the interface bencode.Unmarshaler. -func (cn *CompactIPv6Node) UnmarshalBencode(b []byte) (err error) { - var s string - if err = bencode.DecodeBytes(b, &s); err == nil { - err = cn.UnmarshalBinary([]byte(s)) - } - return + Values []string `bencode:"values,omitempty"` // BEP 5 } diff --git a/krpc/node.go b/krpc/node.go index ce83ddc..7d2f11f 100644 --- a/krpc/node.go +++ b/krpc/node.go @@ -1,4 +1,4 @@ -// Copyright 2020 xgfone +// Copyright 2020~2023 xgfone // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,29 +16,23 @@ package krpc import ( "bytes" + "encoding" "fmt" "io" - "net" + "github.com/xgfone/bt/bencode" "github.com/xgfone/bt/metainfo" ) // Node represents a node information. type Node struct { ID metainfo.Hash - Addr metainfo.Address + Addr Addr } // NewNode returns a new Node. -func NewNode(id metainfo.Hash, ip net.IP, port int) Node { - return Node{ID: id, Addr: metainfo.NewAddress(ip, uint16(port))} -} - -// NewNodeByUDPAddr returns a new Node with the id and the UDP address. -func NewNodeByUDPAddr(id metainfo.Hash, addr *net.UDPAddr) (n Node) { - n.ID = id - n.Addr.FromUDPAddr(addr) - return +func NewNode(id metainfo.Hash, addr Addr) Node { + return Node{ID: id, Addr: addr} } func (n Node) String() string { @@ -63,17 +57,28 @@ func (n Node) WriteBinary(w io.Writer) (m int, err error) { return } -// MarshalBinary implements the interface binary.BinaryMarshaler. +var ( + _ encoding.BinaryMarshaler = new(Node) + _ encoding.BinaryUnmarshaler = new(Node) +) + +// MarshalBinary implements the interface encoding.BinaryMarshaler, +// which implements "Compact node info". +// +// See http://bittorrent.org/beps/bep_0005.html. func (n Node) MarshalBinary() (data []byte, err error) { buf := bytes.NewBuffer(nil) - buf.Grow(48) + buf.Grow(40) if _, err = n.WriteBinary(buf); err == nil { data = buf.Bytes() } return } -// UnmarshalBinary implements the interface binary.BinaryUnmarshaler. +// UnmarshalBinary implements the interface encoding.BinaryUnmarshaler, +// which implements "Compact node info". +// +// See http://bittorrent.org/beps/bep_0005.html. func (n *Node) UnmarshalBinary(b []byte) error { if len(b) < 26 { return io.ErrShortBuffer @@ -82,3 +87,133 @@ func (n *Node) UnmarshalBinary(b []byte) error { copy(n.ID[:], b[:20]) return n.Addr.UnmarshalBinary(b[20:]) } + +// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + +var ( + _ bencode.Marshaler = new(CompactIPv4Nodes) + _ bencode.Unmarshaler = new(CompactIPv4Nodes) + _ encoding.BinaryMarshaler = new(CompactIPv4Nodes) + _ encoding.BinaryUnmarshaler = new(CompactIPv4Nodes) +) + +// CompactIPv4Nodes is a set of IPv4 Nodes. +type CompactIPv4Nodes []Node + +// MarshalBinary implements the interface encoding.BinaryMarshaler. +func (cns CompactIPv4Nodes) MarshalBinary() ([]byte, error) { + buf := bytes.NewBuffer(nil) + buf.Grow(26 * len(cns)) + for _, ni := range cns { + if ni.Addr.IP = ni.Addr.IP.To4(); len(ni.Addr.IP) == 0 { + continue + } + if n, err := ni.WriteBinary(buf); err != nil { + return nil, err + } else if n != 26 { + panic(fmt.Errorf("CompactIPv4Nodes: the invalid node info length '%d'", n)) + } + } + return buf.Bytes(), nil +} + +// UnmarshalBinary implements the interface encoding.BinaryUnmarshaler. +func (cns *CompactIPv4Nodes) UnmarshalBinary(b []byte) (err error) { + _len := len(b) + if _len%26 != 0 { + return fmt.Errorf("CompactIPv4Nodes: invalid node info length '%d'", _len) + } + + nis := make([]Node, 0, _len/26) + for i := 0; i < _len; i += 26 { + var ni Node + if err = ni.UnmarshalBinary(b[i : i+26]); err != nil { + return + } + nis = append(nis, ni) + } + + *cns = nis + return +} + +// MarshalBencode implements the interface bencode.Marshaler. +func (cns CompactIPv4Nodes) MarshalBencode() (b []byte, err error) { + if b, err = cns.MarshalBinary(); err == nil { + b, err = bencode.EncodeBytes(b) + } + return +} + +// UnmarshalBencode implements the interface bencode.Unmarshaler. +func (cns *CompactIPv4Nodes) UnmarshalBencode(b []byte) (err error) { + var data []byte + if err = bencode.DecodeBytes(b, &data); err == nil { + err = cns.UnmarshalBinary(data) + } + return +} + +// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + +var ( + _ bencode.Marshaler = new(CompactIPv6Nodes) + _ bencode.Unmarshaler = new(CompactIPv6Nodes) + _ encoding.BinaryMarshaler = new(CompactIPv6Nodes) + _ encoding.BinaryUnmarshaler = new(CompactIPv6Nodes) +) + +// CompactIPv6Nodes is a set of IPv6 Nodes. +type CompactIPv6Nodes []Node + +// MarshalBinary implements the interface encoding.BinaryMarshaler. +func (cns CompactIPv6Nodes) MarshalBinary() ([]byte, error) { + buf := bytes.NewBuffer(nil) + buf.Grow(38 * len(cns)) + for _, ni := range cns { + ni.Addr.IP = ni.Addr.IP.To16() + if n, err := ni.WriteBinary(buf); err != nil { + return nil, err + } else if n != 38 { + panic(fmt.Errorf("CompactIPv6Nodes: the invalid node info length '%d'", n)) + } + } + return buf.Bytes(), nil +} + +// UnmarshalBinary implements the interface encoding.BinaryUnmarshaler. +func (cns *CompactIPv6Nodes) UnmarshalBinary(b []byte) (err error) { + _len := len(b) + if _len%38 != 0 { + return fmt.Errorf("CompactIPv6Nodes: invalid node info length '%d'", _len) + } + + nis := make([]Node, 0, _len/38) + for i := 0; i < _len; i += 38 { + var ni Node + if err = ni.UnmarshalBinary(b[i : i+38]); err != nil { + return + } + nis = append(nis, ni) + } + + *cns = nis + return +} + +// MarshalBencode implements the interface bencode.Marshaler. +func (cns CompactIPv6Nodes) MarshalBencode() (b []byte, err error) { + if b, err = cns.MarshalBinary(); err == nil { + b, err = bencode.EncodeBytes(b) + } + return +} + +// UnmarshalBencode implements the interface bencode.Unmarshaler. +func (cns *CompactIPv6Nodes) UnmarshalBencode(b []byte) (err error) { + var data []byte + if err = bencode.DecodeBytes(b, &data); err == nil { + err = cns.UnmarshalBinary(data) + } + return +} diff --git a/metainfo/addr_compact.go b/metainfo/addr_compact.go new file mode 100644 index 0000000..e13e672 --- /dev/null +++ b/metainfo/addr_compact.go @@ -0,0 +1,270 @@ +// Copyright 2023 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metainfo + +import ( + "bytes" + "encoding" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "strconv" + + "github.com/xgfone/bt/bencode" +) + +// CompactAddr represents an address based on ip and port, +// which implements "Compact IP-address/port info". +// +// See http://bittorrent.org/beps/bep_0005.html. +type CompactAddr struct { + IP net.IP // For IPv4, its length must be 4. + Port uint16 +} + +// NewCompactAddr returns a new compact Addr with ip and port. +func NewCompactAddr(ip net.IP, port uint16) CompactAddr { + return CompactAddr{IP: ip, Port: port} +} + +// NewCompactAddrFromUDPAddr converts *net.UDPAddr to a new CompactAddr. +func NewCompactAddrFromUDPAddr(ua *net.UDPAddr) CompactAddr { + return CompactAddr{IP: ua.IP, Port: uint16(ua.Port)} +} + +// Valid reports whether the addr is valid. +func (a CompactAddr) Valid() bool { + return len(a.IP) > 0 && a.Port > 0 +} + +// Equal reports whether a is equal to o. +func (a CompactAddr) Equal(o CompactAddr) bool { + return a.Port == o.Port && a.IP.Equal(o.IP) +} + +// UDPAddr converts itself to *net.Addr. +func (a CompactAddr) UDPAddr() *net.UDPAddr { + return &net.UDPAddr{IP: a.IP, Port: int(a.Port)} +} + +var _ net.Addr = CompactAddr{} + +// Network implements the interface net.Addr#Network. +func (a CompactAddr) Network() string { + return "krpc" +} + +func (a CompactAddr) String() string { + if a.Port == 0 { + return a.IP.String() + } + return net.JoinHostPort(a.IP.String(), strconv.FormatUint(uint64(a.Port), 10)) +} + +// WriteBinary is the same as MarshalBinary, but writes the result into w +// instead of returning. +func (a CompactAddr) WriteBinary(w io.Writer) (n int, err error) { + if n, err = w.Write(a.IP); err == nil { + if err = binary.Write(w, binary.BigEndian, a.Port); err == nil { + n += 2 + } + } + return +} + +var ( + _ encoding.BinaryMarshaler = new(CompactAddr) + _ encoding.BinaryUnmarshaler = new(CompactAddr) +) + +// MarshalBinary implements the interface encoding.BinaryMarshaler, +func (a CompactAddr) MarshalBinary() (data []byte, err error) { + buf := bytes.NewBuffer(nil) + buf.Grow(18) + if _, err = a.WriteBinary(buf); err == nil { + data = buf.Bytes() + } + return +} + +// UnmarshalBinary implements the interface encoding.BinaryUnmarshaler. +func (a *CompactAddr) UnmarshalBinary(data []byte) error { + _len := len(data) - 2 + switch _len { + case net.IPv4len, net.IPv6len: + default: + return errors.New("invalid compact ip-address/port info") + } + + a.IP = make(net.IP, _len) + copy(a.IP, data[:_len]) + a.Port = binary.BigEndian.Uint16(data[_len:]) + return nil +} + +var ( + _ bencode.Marshaler = new(CompactAddr) + _ bencode.Unmarshaler = new(CompactAddr) +) + +// MarshalBencode implements the interface bencode.Marshaler. +func (a CompactAddr) MarshalBencode() (b []byte, err error) { + if b, err = a.MarshalBinary(); err == nil { + b, err = bencode.EncodeBytes(b) + } + return +} + +// UnmarshalBencode implements the interface bencode.Unmarshaler. +func (a *CompactAddr) UnmarshalBencode(b []byte) (err error) { + var data []byte + if err = bencode.DecodeBytes(b, &data); err == nil { + err = a.UnmarshalBinary(data) + } + return +} + +// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + +var ( + _ bencode.Marshaler = new(CompactIPv4Addrs) + _ bencode.Unmarshaler = new(CompactIPv4Addrs) + _ encoding.BinaryMarshaler = new(CompactIPv4Addrs) + _ encoding.BinaryUnmarshaler = new(CompactIPv4Addrs) +) + +// CompactIPv4Addrs is a set of IPv4 Addrs. +type CompactIPv4Addrs []CompactAddr + +// MarshalBinary implements the interface encoding.BinaryMarshaler. +func (cas CompactIPv4Addrs) MarshalBinary() ([]byte, error) { + buf := bytes.NewBuffer(nil) + buf.Grow(6 * len(cas)) + for _, addr := range cas { + if addr.IP = addr.IP.To4(); len(addr.IP) == 0 { + continue + } + + if n, err := addr.WriteBinary(buf); err != nil { + return nil, err + } else if n != 6 { + panic(fmt.Errorf("CompactIPv4Nodes: the invalid node info length '%d'", n)) + } + } + return buf.Bytes(), nil +} + +// UnmarshalBinary implements the interface encoding.BinaryUnmarshaler. +func (cas *CompactIPv4Addrs) UnmarshalBinary(b []byte) (err error) { + _len := len(b) + if _len%6 != 0 { + return fmt.Errorf("CompactIPv4Addrs: invalid addr info length '%d'", _len) + } + + addrs := make(CompactIPv4Addrs, 0, _len/6) + for i := 0; i < _len; i += 6 { + var addr CompactAddr + if err = addr.UnmarshalBinary(b[i : i+6]); err != nil { + return + } + addrs = append(addrs, addr) + } + + *cas = addrs + return +} + +// MarshalBencode implements the interface bencode.Marshaler. +func (cas CompactIPv4Addrs) MarshalBencode() (b []byte, err error) { + if b, err = cas.MarshalBinary(); err == nil { + b, err = bencode.EncodeBytes(b) + } + return +} + +// UnmarshalBencode implements the interface bencode.Unmarshaler. +func (cas *CompactIPv4Addrs) UnmarshalBencode(b []byte) (err error) { + var data []byte + if err = bencode.DecodeBytes(b, &data); err == nil { + err = cas.UnmarshalBinary(data) + } + return +} + +// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + +var ( + _ bencode.Marshaler = new(CompactIPv6Addrs) + _ bencode.Unmarshaler = new(CompactIPv6Addrs) + _ encoding.BinaryMarshaler = new(CompactIPv6Addrs) + _ encoding.BinaryUnmarshaler = new(CompactIPv6Addrs) +) + +// CompactIPv6Addrs is a set of IPv6 Addrs. +type CompactIPv6Addrs []CompactAddr + +// MarshalBinary implements the interface encoding.BinaryMarshaler. +func (cas CompactIPv6Addrs) MarshalBinary() ([]byte, error) { + buf := bytes.NewBuffer(nil) + buf.Grow(18 * len(cas)) + for _, addr := range cas { + addr.IP = addr.IP.To16() + if n, err := addr.WriteBinary(buf); err != nil { + return nil, err + } else if n != 18 { + panic(fmt.Errorf("CompactIPv4Nodes: the invalid node info length '%d'", n)) + } + } + return buf.Bytes(), nil +} + +// UnmarshalBinary implements the interface encoding.BinaryUnmarshaler. +func (cas *CompactIPv6Addrs) UnmarshalBinary(b []byte) (err error) { + _len := len(b) + if _len%18 != 0 { + return fmt.Errorf("CompactIPv4Addrs: invalid addr info length '%d'", _len) + } + + addrs := make(CompactIPv6Addrs, 0, _len/18) + for i := 0; i < _len; i += 18 { + var addr CompactAddr + if err = addr.UnmarshalBinary(b[i : i+18]); err != nil { + return + } + addrs = append(addrs, addr) + } + + *cas = addrs + return +} + +// MarshalBencode implements the interface bencode.Marshaler. +func (cas CompactIPv6Addrs) MarshalBencode() (b []byte, err error) { + if b, err = cas.MarshalBinary(); err == nil { + b, err = bencode.EncodeBytes(b) + } + return +} + +// UnmarshalBencode implements the interface bencode.Unmarshaler. +func (cas *CompactIPv6Addrs) UnmarshalBencode(b []byte) (err error) { + var data []byte + if err = bencode.DecodeBytes(b, &data); err == nil { + err = cas.UnmarshalBinary(data) + } + return +} diff --git a/metainfo/addr_compact_test.go b/metainfo/addr_compact_test.go new file mode 100644 index 0000000..7b58197 --- /dev/null +++ b/metainfo/addr_compact_test.go @@ -0,0 +1,51 @@ +// Copyright 2023 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metainfo + +import ( + "bytes" + "net" + "testing" + + "github.com/xgfone/bt/bencode" +) + +func TestCompactAddr(t *testing.T) { + addrs := []CompactAddr{ + {IP: net.ParseIP("172.16.1.1").To4(), Port: 123}, + {IP: net.ParseIP("192.168.1.1").To4(), Port: 456}, + } + expect := "l6:\xac\x10\x01\x01\x00\x7b6:\xc0\xa8\x01\x01\x01\xc8e" + + buf := new(bytes.Buffer) + if err := bencode.NewEncoder(buf).Encode(addrs); err != nil { + t.Error(err) + } else if result := buf.String(); result != expect { + t.Errorf("expect %s, but got %x\n", expect, result) + } + + var raddrs []CompactAddr + if err := bencode.DecodeString(expect, &raddrs); err != nil { + t.Error(err) + } else if len(raddrs) != len(addrs) { + t.Errorf("expect addrs length %d, but got %d\n", len(addrs), len(raddrs)) + } else { + for i, addr := range addrs { + if !addr.Equal(raddrs[i]) { + t.Errorf("%d: expect addr %v, but got %v\n", i, addr, raddrs[i]) + } + } + } +} diff --git a/metainfo/addr_host.go b/metainfo/addr_host.go new file mode 100644 index 0000000..7f5edd0 --- /dev/null +++ b/metainfo/addr_host.go @@ -0,0 +1,115 @@ +// Copyright 2020 xgfone +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metainfo + +import ( + "bytes" + "fmt" + "net" + "strconv" + + "github.com/xgfone/bt/bencode" +) + +// HostAddr represents an address based on host and port. +type HostAddr struct { + Host string + Port uint16 +} + +// NewHostAddr returns a new host Addr. +func NewHostAddr(host string, port uint16) HostAddr { + return HostAddr{Host: host, Port: port} +} + +// ParseHostAddr parses a string s to Addr. +func ParseHostAddr(s string) (HostAddr, error) { + host, port, err := net.SplitHostPort(s) + if err != nil { + return HostAddr{}, err + } + + _port, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return HostAddr{}, err + } + + return NewHostAddr(host, uint16(_port)), nil +} + +func (a HostAddr) String() string { + if a.Port == 0 { + return a.Host + } + return net.JoinHostPort(a.Host, strconv.FormatUint(uint64(a.Port), 10)) +} + +// Equal reports whether a is equal to o. +func (a HostAddr) Equal(o HostAddr) bool { + return a.Port == o.Port && a.Host == o.Host +} + +var ( + _ bencode.Marshaler = new(HostAddr) + _ bencode.Unmarshaler = new(HostAddr) +) + +// MarshalBencode implements the interface bencode.Marshaler. +func (a HostAddr) MarshalBencode() (b []byte, err error) { + buf := bytes.NewBuffer(nil) + buf.Grow(64) + err = bencode.NewEncoder(buf).Encode([]interface{}{a.Host, a.Port}) + if err == nil { + b = buf.Bytes() + } + return +} + +// UnmarshalBencode implements the interface bencode.Unmarshaler. +func (a *HostAddr) UnmarshalBencode(b []byte) (err error) { + var iface interface{} + if err = bencode.NewDecoder(bytes.NewBuffer(b)).Decode(&iface); err != nil { + return + } + + switch v := iface.(type) { + case string: + *a, err = ParseHostAddr(v) + + case []interface{}: + err = a.decode(v) + + default: + err = fmt.Errorf("unsupported type: %T", iface) + } + + return +} + +func (a *HostAddr) decode(vs []interface{}) (err error) { + defer func() { + switch e := recover().(type) { + case nil: + case error: + err = e + default: + err = fmt.Errorf("%v", e) + } + }() + + a.Host = vs[0].(string) + a.Port = uint16(vs[1].(int64)) + return +} diff --git a/metainfo/address_test.go b/metainfo/addr_host_test.go similarity index 51% rename from metainfo/address_test.go rename to metainfo/addr_host_test.go index 929794e..7467df2 100644 --- a/metainfo/address_test.go +++ b/metainfo/addr_host_test.go @@ -16,33 +16,33 @@ package metainfo import ( "testing" + + "github.com/xgfone/bt/bencode" ) func TestAddress(t *testing.T) { - var addr1 Address - if err := addr1.FromString("1.2.3.4:1234"); err != nil { + addrs := []HostAddr{ + {Host: "1.2.3.4", Port: 123}, + {Host: "www.example.com", Port: 456}, + } + expect := `ll7:1.2.3.4i123eel15:www.example.comi456eee` + + if result, err := bencode.EncodeString(addrs); err != nil { t.Error(err) - return + } else if result != expect { + t.Errorf("expect %s, but got %s\n", expect, result) } - data, err := addr1.MarshalBencode() - if err != nil { + var raddrs []HostAddr + if err := bencode.DecodeString(expect, &raddrs); err != nil { t.Error(err) - return - } else if s := string(data); s != `l7:1.2.3.4i1234ee` { - t.Errorf(`expected 'l7:1.2.3.4i1234ee', but got '%s'`, s) - } - - var addr2 Address - if err = addr2.UnmarshalBencode(data); err != nil { - t.Error(err) - } else if addr2.String() != `1.2.3.4:1234` { - t.Errorf("expected '1.2.3.4:1234', but got '%s'", addr2) - } - - if data, err = addr2.MarshalBinary(); err != nil { - t.Error(err) - } else if s := string(data); s != "\x01\x02\x03\x04\x04\xd2" { - t.Errorf(`expected '\x01\x02\x03\x04\x04\xd2', but got '%#x'`, s) + } else if len(raddrs) != len(addrs) { + t.Errorf("expect addrs length %d, but got %d\n", len(addrs), len(raddrs)) + } else { + for i, addr := range addrs { + if !addr.Equal(raddrs[i]) { + t.Errorf("%d: expect %v, but got %v\n", i, addr, raddrs[i]) + } + } } } diff --git a/metainfo/address.go b/metainfo/address.go deleted file mode 100644 index d564c41..0000000 --- a/metainfo/address.go +++ /dev/null @@ -1,354 +0,0 @@ -// Copyright 2020 xgfone -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package metainfo - -import ( - "bytes" - "encoding/binary" - "fmt" - "io" - "net" - "strconv" - - "github.com/xgfone/bt/bencode" -) - -// ErrInvalidAddr is returned when the compact address is invalid. -var ErrInvalidAddr = fmt.Errorf("invalid compact information of ip and port") - -// Address represents a client/server listening on a UDP port implementing -// the DHT protocol. -type Address struct { - IP net.IP // For IPv4, its length must be 4. - Port uint16 -} - -// NewAddress returns a new Address. -func NewAddress(ip net.IP, port uint16) Address { - if ipv4 := ip.To4(); len(ipv4) > 0 { - ip = ipv4 - } - return Address{IP: ip, Port: port} -} - -// NewAddressFromString returns a new Address by the address string. -func NewAddressFromString(s string) (addr Address, err error) { - err = addr.FromString(s) - return -} - -// NewAddressesFromString returns a list of Addresses by the address string. -func NewAddressesFromString(s string) (addrs []Address, err error) { - shost, sport, err := net.SplitHostPort(s) - if err != nil { - return nil, fmt.Errorf("invalid address '%s': %s", s, err) - } - - var port uint16 - if sport != "" { - v, err := strconv.ParseUint(sport, 10, 16) - if err != nil { - return nil, fmt.Errorf("invalid address '%s': %s", s, err) - } - port = uint16(v) - } - - ips, err := net.LookupIP(shost) - if err != nil { - return nil, fmt.Errorf("fail to lookup the domain '%s': %s", shost, err) - } - - addrs = make([]Address, len(ips)) - for i, ip := range ips { - if ipv4 := ip.To4(); len(ipv4) != 0 { - addrs[i] = Address{IP: ipv4, Port: port} - } else { - addrs[i] = Address{IP: ip, Port: port} - } - } - - return -} - -// FromString parses and sets the ip from the string addr. -func (a *Address) FromString(addr string) (err error) { - host, port, err := net.SplitHostPort(addr) - if err != nil { - return fmt.Errorf("invalid address '%s': %s", addr, err) - } - - if port != "" { - v, err := strconv.ParseUint(port, 10, 16) - if err != nil { - return fmt.Errorf("invalid address '%s': %s", addr, err) - } - a.Port = uint16(v) - } - - ips, err := net.LookupIP(host) - if err != nil { - return fmt.Errorf("fail to lookup the domain '%s': %s", host, err) - } else if len(ips) == 0 { - return fmt.Errorf("the domain '%s' has no ips", host) - } - - a.IP = ips[0] - if ip := a.IP.To4(); len(ip) > 0 { - a.IP = ip - } - - return -} - -// FromUDPAddr sets the ip from net.UDPAddr. -func (a *Address) FromUDPAddr(ua *net.UDPAddr) { - a.Port = uint16(ua.Port) - a.IP = ua.IP - if ipv4 := a.IP.To4(); len(ipv4) != 0 { - a.IP = ipv4 - } -} - -// UDPAddr creates a new net.UDPAddr. -func (a Address) UDPAddr() *net.UDPAddr { - return &net.UDPAddr{ - IP: a.IP, - Port: int(a.Port), - } -} - -func (a Address) String() string { - if a.Port == 0 { - return a.IP.String() - } - return net.JoinHostPort(a.IP.String(), strconv.FormatUint(uint64(a.Port), 10)) -} - -// Equal reports whether n is equal to o, which is equal to -// n.HasIPAndPort(o.IP, o.Port) -func (a Address) Equal(o Address) bool { - return a.Port == o.Port && a.IP.Equal(o.IP) -} - -// HasIPAndPort reports whether the current node has the ip and the port. -func (a Address) HasIPAndPort(ip net.IP, port uint16) bool { - return port == a.Port && a.IP.Equal(ip) -} - -// WriteBinary is the same as MarshalBinary, but writes the result into w -// instead of returning. -func (a Address) WriteBinary(w io.Writer) (m int, err error) { - if m, err = w.Write(a.IP); err == nil { - if err = binary.Write(w, binary.BigEndian, a.Port); err == nil { - m += 2 - } - } - return -} - -// UnmarshalBinary implements the interface binary.BinaryUnmarshaler. -func (a *Address) UnmarshalBinary(b []byte) (err error) { - _len := len(b) - 2 - switch _len { - case net.IPv4len, net.IPv6len: - default: - return ErrInvalidAddr - } - - a.IP = make(net.IP, _len) - copy(a.IP, b[:_len]) - a.Port = binary.BigEndian.Uint16(b[_len:]) - return -} - -// MarshalBinary implements the interface binary.BinaryMarshaler. -func (a Address) MarshalBinary() (data []byte, err error) { - buf := bytes.NewBuffer(nil) - buf.Grow(20) - if _, err = a.WriteBinary(buf); err == nil { - data = buf.Bytes() - } - return -} - -func (a *Address) decode(vs []interface{}) (err error) { - defer func() { - switch e := recover().(type) { - case nil: - case error: - err = e - default: - err = fmt.Errorf("%v", e) - } - }() - - host := vs[0].(string) - if a.IP = net.ParseIP(host); len(a.IP) == 0 { - return ErrInvalidAddr - } else if ip := a.IP.To4(); len(ip) != 0 { - a.IP = ip - } - - a.Port = uint16(vs[1].(int64)) - return -} - -// UnmarshalBencode implements the interface bencode.Unmarshaler. -func (a *Address) UnmarshalBencode(b []byte) (err error) { - var iface interface{} - if err = bencode.NewDecoder(bytes.NewBuffer(b)).Decode(&iface); err != nil { - return - } - - switch v := iface.(type) { - case string: - err = a.FromString(v) - case []interface{}: - err = a.decode(v) - default: - err = fmt.Errorf("unsupported type: %T", iface) - } - - return -} - -// MarshalBencode implements the interface bencode.Marshaler. -func (a Address) MarshalBencode() (b []byte, err error) { - buf := bytes.NewBuffer(nil) - buf.Grow(32) - err = bencode.NewEncoder(buf).Encode([]interface{}{a.IP.String(), a.Port}) - if err == nil { - b = buf.Bytes() - } - return -} - -/// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> - -// HostAddress is the same as the Address, but the host part may be -// either a domain or a ip. -type HostAddress struct { - Host string - Port uint16 -} - -// NewHostAddress returns a new host addrress. -func NewHostAddress(host string, port uint16) HostAddress { - return HostAddress{Host: host, Port: port} -} - -// NewHostAddressFromString returns a new host address by the string. -func NewHostAddressFromString(s string) (addr HostAddress, err error) { - err = addr.FromString(s) - return -} - -// FromString parses and sets the host from the string addr. -func (a *HostAddress) FromString(addr string) (err error) { - host, port, err := net.SplitHostPort(addr) - if err != nil { - return fmt.Errorf("invalid address '%s': %s", addr, err) - } else if host == "" { - return fmt.Errorf("invalid address '%s': missing host", addr) - } - - if port != "" { - v, err := strconv.ParseUint(port, 10, 16) - if err != nil { - return fmt.Errorf("invalid address '%s': %s", addr, err) - } - a.Port = uint16(v) - } - - a.Host = host - return -} - -func (a HostAddress) String() string { - if a.Port == 0 { - return a.Host - } - return net.JoinHostPort(a.Host, strconv.FormatUint(uint64(a.Port), 10)) -} - -// Addresses parses the host address to a list of Addresses. -func (a HostAddress) Addresses() (addrs []Address, err error) { - if ip := net.ParseIP(a.Host); len(ip) != 0 { - return []Address{NewAddress(ip, a.Port)}, nil - } - - ips, err := net.LookupIP(a.Host) - if err != nil { - err = fmt.Errorf("fail to lookup the domain '%s': %s", a.Host, err) - } else { - addrs = make([]Address, len(ips)) - for i, ip := range ips { - addrs[i] = NewAddress(ip, a.Port) - } - } - - return -} - -// Equal reports whether a is equal to o. -func (a HostAddress) Equal(o HostAddress) bool { - return a.Port == o.Port && a.Host == o.Host -} - -func (a *HostAddress) decode(vs []interface{}) (err error) { - defer func() { - switch e := recover().(type) { - case nil: - case error: - err = e - default: - err = fmt.Errorf("%v", e) - } - }() - - a.Host = vs[0].(string) - a.Port = uint16(vs[1].(int64)) - return -} - -// UnmarshalBencode implements the interface bencode.Unmarshaler. -func (a *HostAddress) UnmarshalBencode(b []byte) (err error) { - var iface interface{} - if err = bencode.NewDecoder(bytes.NewBuffer(b)).Decode(&iface); err != nil { - return - } - - switch v := iface.(type) { - case string: - err = a.FromString(v) - case []interface{}: - err = a.decode(v) - default: - err = fmt.Errorf("unsupported type: %T", iface) - } - - return -} - -// MarshalBencode implements the interface bencode.Marshaler. -func (a HostAddress) MarshalBencode() (b []byte, err error) { - buf := bytes.NewBuffer(nil) - buf.Grow(32) - err = bencode.NewEncoder(buf).Encode([]interface{}{a.Host, a.Port}) - if err == nil { - b = buf.Bytes() - } - return -} diff --git a/metainfo/file.go b/metainfo/file.go index be22aff..3ed50df 100644 --- a/metainfo/file.go +++ b/metainfo/file.go @@ -14,9 +14,7 @@ package metainfo -import ( - "path/filepath" -) +import "path/filepath" // File represents a file in the multi-file case. type File struct { diff --git a/metainfo/infohash.go b/metainfo/infohash.go index 5eab121..6977f81 100644 --- a/metainfo/infohash.go +++ b/metainfo/infohash.go @@ -145,10 +145,12 @@ func (h *Hash) FromString(s string) (err error) { copy(h[:], bs) } default: - err = fmt.Errorf("hash string has bad length: %d", len(s)) + hasher := sha1.New() + hasher.Write([]byte(s)) + copy(h[:], hasher.Sum(nil)) } - return nil + return } // FromHexString resets the info hash from the hex string. diff --git a/metainfo/magnet.go b/metainfo/magnet.go index 6fe1810..8cf802e 100644 --- a/metainfo/magnet.go +++ b/metainfo/magnet.go @@ -25,16 +25,15 @@ const xtPrefix = "urn:btih:" // Peers returns the list of the addresses of the peers. // // See BEP 9 -func (m Magnet) Peers() (peers []HostAddress, err error) { +func (m Magnet) Peers() (peers []HostAddr, err error) { vs := m.Params["x.pe"] - peers = make([]HostAddress, 0, len(vs)) + peers = make([]HostAddr, 0, len(vs)) for _, v := range vs { if v != "" { - var addr HostAddress - if err = addr.FromString(v); err != nil { - return + addr, err := ParseHostAddr(v) + if err != nil { + return nil, err } - peers = append(peers, addr) } } diff --git a/metainfo/metainfo.go b/metainfo/metainfo.go index 67c8614..d4a4d17 100644 --- a/metainfo/metainfo.go +++ b/metainfo/metainfo.go @@ -21,7 +21,7 @@ import ( "strings" "github.com/xgfone/bt/bencode" - "github.com/xgfone/bt/utils" + "github.com/xgfone/bt/internal/helper" ) // Bytes is the []byte type. @@ -35,7 +35,7 @@ func (al AnnounceList) Unique() (announces []string) { announces = make([]string, 0, len(al)) for _, tier := range al { for _, v := range tier { - if v != "" && !utils.InStringSlice(announces, v) { + if v != "" && !helper.ContainsString(announces, v) { announces = append(announces, v) } } @@ -93,11 +93,11 @@ func (us *URLList) UnmarshalBencode(b []byte) (err error) { // MetaInfo represents the .torrent file. type MetaInfo struct { - InfoBytes Bytes `bencode:"info"` // BEP 3 - Announce string `bencode:"announce,omitempty"` // BEP 3 - AnnounceList AnnounceList `bencode:"announce-list,omitempty"` // BEP 12 - Nodes []HostAddress `bencode:"nodes,omitempty"` // BEP 5 - URLList URLList `bencode:"url-list,omitempty"` // BEP 19 + InfoBytes Bytes `bencode:"info"` // BEP 3 + Announce string `bencode:"announce,omitempty"` // BEP 3 + AnnounceList AnnounceList `bencode:"announce-list,omitempty"` // BEP 12 + Nodes []HostAddr `bencode:"nodes,omitempty"` // BEP 5 + URLList URLList `bencode:"url-list,omitempty"` // BEP 19 // Where's this specified? // Mentioned at https://wiki.theory.org/index.php/BitTorrentSpecification. diff --git a/metainfo/piece.go b/metainfo/piece.go index 3f1520e..a2519bd 100644 --- a/metainfo/piece.go +++ b/metainfo/piece.go @@ -21,7 +21,7 @@ import ( "io" "sort" - "github.com/xgfone/bt/utils" + "github.com/xgfone/bt/internal/helper" ) // Predefine some sizes of the pieces. @@ -69,7 +69,7 @@ func GeneratePieces(r io.Reader, pieceLength int64) (hs Hashes, err error) { buf := make([]byte, pieceLength) for { h := sha1.New() - written, err := utils.CopyNBuffer(h, r, pieceLength, buf) + written, err := helper.CopyNBuffer(h, r, pieceLength, buf) if written > 0 { hs = append(hs, NewHash(h.Sum(nil))) } @@ -92,7 +92,7 @@ func writeFiles(w io.Writer, files []File, open func(File) (io.ReadCloser, error return fmt.Errorf("error opening %s: %s", file, err) } - n, err := utils.CopyNBuffer(w, r, file.Length, buf) + n, err := helper.CopyNBuffer(w, r, file.Length, buf) r.Close() if n != file.Length { diff --git a/peerprotocol/extension.go b/peerprotocol/extension.go index 7ab53ff..d70c0b6 100644 --- a/peerprotocol/extension.go +++ b/peerprotocol/extension.go @@ -86,7 +86,7 @@ type ExtendedHandshakeMsg struct { // M is the type of map[ExtendedMessageName]ExtendedMessageID. M map[string]uint8 `bencode:"m"` // BEP 10 V string `bencode:"v,omitempty"` // BEP 10 - Reqq int `bencode:"reqq,omitempty"` // BEP 10 + Reqq uint64 `bencode:"reqq,omitempty"` // BEP 10. The default in in libtorrent is 250. // Port is the local client port, which is redundant and no need // for the receiving side of the connection to send this. @@ -95,7 +95,7 @@ type ExtendedHandshakeMsg struct { IPv4 CompactIP `bencode:"ipv4,omitempty"` // BEP 10 YourIP CompactIP `bencode:"yourip,omitempty"` // BEP 10 - MetadataSize int `bencode:"metadata_size,omitempty"` // BEP 9 + MetadataSize uint64 `bencode:"metadata_size,omitempty"` // BEP 9 } // Decode decodes the extended handshake message from b. @@ -118,7 +118,7 @@ type UtMetadataExtendedMsg struct { Piece int `bencode:"piece"` // BEP 9 // They are only used by "data" type - TotalSize int `bencode:"total_size,omitempty"` // BEP 9 + TotalSize uint64 `bencode:"total_size,omitempty"` // BEP 9 Data []byte `bencode:"-"` } @@ -139,10 +139,9 @@ func (um UtMetadataExtendedMsg) EncodeToPayload(buf *bytes.Buffer) (err error) { // EncodeToBytes is equal to // -// buf := new(bytes.Buffer) -// err = um.EncodeToPayload(buf) -// return buf.Bytes(), err -// +// buf := new(bytes.Buffer) +// err = um.EncodeToPayload(buf) +// return buf.Bytes(), err func (um UtMetadataExtendedMsg) EncodeToBytes() (b []byte, err error) { buf := bytes.NewBuffer(make([]byte, 0, 128)) if err = um.EncodeToPayload(buf); err == nil { diff --git a/peerprotocol/handshake.go b/peerprotocol/handshake.go index 63f4e2c..0a515cf 100644 --- a/peerprotocol/handshake.go +++ b/peerprotocol/handshake.go @@ -44,12 +44,12 @@ func (eb ExtensionBits) String() string { } // Set sets the bit to 1, that's, to set it to be on. -func (eb ExtensionBits) Set(bit uint) { +func (eb *ExtensionBits) Set(bit uint) { eb[7-bit/8] |= 1 << (bit % 8) } // Unset sets the bit to 0, that's, to set it to be off. -func (eb ExtensionBits) Unset(bit uint) { +func (eb *ExtensionBits) Unset(bit uint) { eb[7-bit/8] &^= 1 << (bit % 8) } diff --git a/peerprotocol/message.go b/peerprotocol/message.go index 1fba09c..55a4f38 100644 --- a/peerprotocol/message.go +++ b/peerprotocol/message.go @@ -335,7 +335,9 @@ func (m Message) Encode(buf *bytes.Buffer) (err error) { if !m.Keepalive { if err = buf.WriteByte(byte(m.Type)); err != nil { return - } else if err = m.marshalBinaryType(buf); err != nil { + } + + if err = m.marshalBinaryType(buf); err != nil { return } @@ -376,7 +378,7 @@ func (m Message) marshalBinaryType(buf *bytes.Buffer) (err error) { } _, err = buf.Write(m.Piece) case MTypeExtended: - if err = buf.WriteByte(byte(m.ExtendedID)); err != nil { + if err = buf.WriteByte(byte(m.ExtendedID)); err == nil { _, err = buf.Write(m.ExtendedPayload) } case MTypePort: diff --git a/peerprotocol/protocol.go b/peerprotocol/protocol.go index 5696820..8366e70 100644 --- a/peerprotocol/protocol.go +++ b/peerprotocol/protocol.go @@ -14,9 +14,7 @@ package peerprotocol -import ( - "fmt" -) +import "fmt" // ProtocolHeader is the BT protocal prefix. // diff --git a/peerprotocol/server.go b/peerprotocol/server.go index b624439..3eff272 100644 --- a/peerprotocol/server.go +++ b/peerprotocol/server.go @@ -65,9 +65,9 @@ type Config struct { HandleMessage func(pc *PeerConn, msg Message, handler Handler) error } -func (c *Config) set(conf ...Config) { - if len(conf) > 0 { - *c = conf[0] +func (c *Config) set(conf *Config) { + if conf != nil { + *c = *conf } if c.MaxLength == 0 { @@ -93,33 +93,35 @@ type Server struct { } // NewServerByListen returns a new Server by listening on the address. -func NewServerByListen(network, address string, id metainfo.Hash, h Handler, - c ...Config) (*Server, error) { +func NewServerByListen(network, address string, id metainfo.Hash, h Handler, c *Config) (*Server, error) { ln, err := net.Listen(network, address) if err != nil { return nil, err } - return NewServer(ln, id, h, c...), nil + return NewServer(ln, id, h, c), nil } // NewServer returns a new Server. -func NewServer(ln net.Listener, id metainfo.Hash, h Handler, c ...Config) *Server { +func NewServer(ln net.Listener, id metainfo.Hash, h Handler, c *Config) *Server { if id.IsZero() { panic("the peer node id must not be empty") } var conf Config - conf.set(c...) + conf.set(c) return &Server{Listener: ln, ID: id, Handler: h, Config: conf} } // Run starts the peer protocol server. func (s *Server) Run() { - s.Config.set() + s.Config.set(nil) for { conn, err := s.Listener.Accept() if err != nil { - s.Config.ErrorLog("fail to accept new connection: %s", err) + if !strings.Contains(err.Error(), "closed") { + s.Config.ErrorLog("fail to accept new connection: %s", err) + } + return } go s.handleConn(conn) } diff --git a/tracker/get_peers.go b/tracker/get_peers.go index d6c1889..06d648d 100644 --- a/tracker/get_peers.go +++ b/tracker/get_peers.go @@ -1,4 +1,4 @@ -// Copyright 2020 xgfone +// Copyright 2020~2023 xgfone // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -32,7 +32,7 @@ type GetPeersResult struct { // GetPeers gets the peers from the trackers. // // Notice: the returned chan will be closed when all the requests end. -func GetPeers(ctx context.Context, id, infohash metainfo.Hash, trackers []string) []GetPeersResult { +func GetPeers(ctx context.Context, nodeid, infohash metainfo.Hash, trackers []string) []GetPeersResult { if len(trackers) == 0 { return nil } @@ -65,7 +65,7 @@ func GetPeers(ctx context.Context, id, infohash metainfo.Hash, trackers []string for i := 0; i < wlen; i++ { go func() { for tracker := range reqs { - resp, err := getPeers(ctx, wg, tracker, id, infohash) + resp, err := getPeers(ctx, wg, tracker, nodeid, infohash) lock.Lock() results = append(results, GetPeersResult{ Tracker: tracker, @@ -87,7 +87,7 @@ func getPeers(ctx context.Context, wg *sync.WaitGroup, tracker string, nodeID, infoHash metainfo.Hash) (resp AnnounceResponse, err error) { defer wg.Done() - client, err := NewClient(tracker, ClientConfig{ID: nodeID}) + client, err := NewClient(tracker, nodeID, nil) if err == nil { resp, err = client.Announce(ctx, AnnounceRequest{InfoHash: infoHash}) } diff --git a/tracker/httptracker/http.go b/tracker/httptracker/http.go index 1f7d61e..1bf9cfa 100644 --- a/tracker/httptracker/http.go +++ b/tracker/httptracker/http.go @@ -101,7 +101,7 @@ type AnnounceRequest struct { // ToQuery converts the Request to URL Query. func (r AnnounceRequest) ToQuery() (vs url.Values) { - vs = make(url.Values, 9) + vs = make(url.Values, 10) vs.Set("info_hash", r.InfoHash.BytesString()) vs.Set("peer_id", r.PeerID.BytesString()) vs.Set("uploaded", strconv.FormatInt(r.Uploaded, 10)) @@ -278,11 +278,15 @@ type Client struct { // // scrapeURL may be empty, which will replace the "announce" in announceURL // with "scrape" to generate the scrapeURL. -func NewClient(announceURL, scrapeURL string) *Client { +func NewClient(id metainfo.Hash, announceURL, scrapeURL string) *Client { if scrapeURL == "" { scrapeURL = strings.Replace(announceURL, "announce", "scrape", -1) } - id := metainfo.NewRandomHash() + + if id.IsZero() { + id = metainfo.NewRandomHash() + } + return &Client{AnnounceURL: announceURL, ScrapeURL: scrapeURL, ID: id} } @@ -310,7 +314,7 @@ func (t *Client) send(c context.Context, u string, vs url.Values, r interface{}) resp, err = t.Client.Do(req) } - if resp.Body != nil { + if resp != nil { defer resp.Body.Close() } @@ -323,14 +327,11 @@ func (t *Client) send(c context.Context, u string, vs url.Values, r interface{}) // Announce sends a Announce request to the tracker. func (t *Client) Announce(c context.Context, req AnnounceRequest) (resp AnnounceResponse, err error) { - if req.PeerID.IsZero() { - if t.ID.IsZero() { - req.PeerID = metainfo.NewRandomHash() - } else { - req.PeerID = t.ID - } + if req.InfoHash.IsZero() { + panic("infohash is ZERO") } + req.PeerID = t.ID err = t.send(c, t.AnnounceURL, req.ToQuery(), &resp) return } diff --git a/tracker/httptracker/http_peer.go b/tracker/httptracker/http_peer.go index 2da9cd6..db8f32c 100644 --- a/tracker/httptracker/http_peer.go +++ b/tracker/httptracker/http_peer.go @@ -15,8 +15,6 @@ package httptracker import ( - "bytes" - "encoding/binary" "errors" "net" @@ -24,34 +22,19 @@ import ( "github.com/xgfone/bt/metainfo" ) -var errInvalidPeer = errors.New("invalid peer information format") +var errInvalidPeer = errors.New("invalid bt peer information format") // Peer is a tracker peer. type Peer struct { - // ID is the peer's self-selected ID. - ID string `bencode:"peer id"` // BEP 3 - - // IP is the IP address or dns name. - IP string `bencode:"ip"` // BEP 3 - Port uint16 `bencode:"port"` // BEP 3 + ID string `bencode:"peer id"` // BEP 3, the peer's self-selected ID. + IP string `bencode:"ip"` // BEP 3, an IP address or dns name. + Port uint16 `bencode:"port"` // BEP 3 } -// Addresses returns the list of the addresses that the peer listens on. -func (p Peer) Addresses() (addrs []metainfo.Address, err error) { - if ip := net.ParseIP(p.IP); len(ip) != 0 { - return []metainfo.Address{{IP: ip, Port: p.Port}}, nil - } - - ips, err := net.LookupIP(p.IP) - if _len := len(ips); err == nil && len(ips) != 0 { - addrs = make([]metainfo.Address, _len) - for i, ip := range ips { - addrs[i] = metainfo.Address{IP: ip, Port: p.Port} - } - } - - return -} +var ( + _ bencode.Marshaler = new(Peers) + _ bencode.Unmarshaler = new(Peers) +) // Peers is a set of the peers. type Peers []Peer @@ -65,21 +48,17 @@ func (ps *Peers) UnmarshalBencode(b []byte) (err error) { switch vs := v.(type) { case string: // BEP 23 - _len := len(vs) - if _len%6 != 0 { - return metainfo.ErrInvalidAddr + var addrs metainfo.CompactIPv4Addrs + if err = addrs.UnmarshalBinary([]byte(vs)); err != nil { + return err } - peers := make(Peers, 0, _len/6) - for i := 0; i < _len; i += 6 { - var addr metainfo.Address - if err = addr.UnmarshalBinary([]byte(vs[i : i+6])); err != nil { - return - } - peers = append(peers, Peer{IP: addr.IP.String(), Port: addr.Port}) + peers := make(Peers, len(addrs)) + for i, addr := range addrs { + peers[i] = Peer{IP: addr.IP.String(), Port: addr.Port} } - *ps = peers + case []interface{}: // BEP 3 peers := make(Peers, len(vs)) for i, p := range vs { @@ -106,6 +85,7 @@ func (ps *Peers) UnmarshalBencode(b []byte) (err error) { peers[i] = Peer{ID: pid, IP: ip, Port: uint16(port)} } *ps = peers + default: return errInvalidPeer } @@ -114,38 +94,32 @@ func (ps *Peers) UnmarshalBencode(b []byte) (err error) { // MarshalBencode implements the interface bencode.Marshaler. func (ps Peers) MarshalBencode() (b []byte, err error) { - for _, p := range ps { - if p.ID == "" { - return ps.marshalCompactBencode() // BEP 23 - } + // BEP 23 + if b, err = ps.marshalCompactBencode(); err == nil { + return } // BEP 3 - buf := bytes.NewBuffer(make([]byte, 0, 64*len(ps))) - buf.WriteByte('l') - for _, p := range ps { - if err = bencode.NewEncoder(buf).Encode(p); err != nil { - return - } - } - buf.WriteByte('e') - b = buf.Bytes() - return + return bencode.EncodeBytes([]Peer(ps)) } func (ps Peers) marshalCompactBencode() (b []byte, err error) { - buf := bytes.NewBuffer(make([]byte, 0, 6*len(ps))) - for _, peer := range ps { - ip := net.ParseIP(peer.IP).To4() - if len(ip) == 0 { + addrs := make(metainfo.CompactIPv4Addrs, len(ps)) + for i, p := range ps { + ip := net.ParseIP(p.IP).To4() + if ip == nil { return nil, errInvalidPeer } - buf.Write(ip[:]) - binary.Write(buf, binary.BigEndian, peer.Port) + addrs[i] = metainfo.CompactAddr{IP: ip, Port: p.Port} } - return bencode.EncodeBytes(buf.Bytes()) + return addrs.MarshalBencode() } +var ( + _ bencode.Marshaler = new(Peers6) + _ bencode.Unmarshaler = new(Peers6) +) + // Peers6 is a set of the peers for IPv6 in the compact case. // // BEP 7 @@ -158,35 +132,29 @@ func (ps *Peers6) UnmarshalBencode(b []byte) (err error) { return } - _len := len(s) - if _len%18 != 0 { - return metainfo.ErrInvalidAddr + var addrs metainfo.CompactIPv6Addrs + if err = addrs.UnmarshalBinary([]byte(s)); err != nil { + return err } - peers := make(Peers6, 0, _len/18) - for i := 0; i < _len; i += 18 { - var addr metainfo.Address - if err = addr.UnmarshalBinary([]byte(s[i : i+18])); err != nil { - return - } - peers = append(peers, Peer{IP: addr.IP.String(), Port: addr.Port}) + peers := make(Peers6, len(addrs)) + for i, addr := range addrs { + peers[i] = Peer{IP: addr.IP.String(), Port: addr.Port} } - *ps = peers + return } // MarshalBencode implements the interface bencode.Marshaler. func (ps Peers6) MarshalBencode() (b []byte, err error) { - buf := bytes.NewBuffer(make([]byte, 0, 18*len(ps))) - for _, peer := range ps { - ip := net.ParseIP(peer.IP).To16() - if len(ip) == 0 { + addrs := make(metainfo.CompactIPv6Addrs, len(ps)) + for i, p := range ps { + ip := net.ParseIP(p.IP) + if ip == nil { return nil, errInvalidPeer } - - buf.Write(ip[:]) - binary.Write(buf, binary.BigEndian, peer.Port) + addrs[i] = metainfo.CompactAddr{IP: ip, Port: p.Port} } - return bencode.EncodeBytes(buf.Bytes()) + return addrs.MarshalBencode() } diff --git a/tracker/httptracker/http_peer_test.go b/tracker/httptracker/http_peer_test.go index fe91bd7..bff128c 100644 --- a/tracker/httptracker/http_peer_test.go +++ b/tracker/httptracker/http_peer_test.go @@ -21,8 +21,8 @@ import ( func TestPeers(t *testing.T) { peers := Peers{ - {ID: "123", IP: "1.1.1.1", Port: 80}, - {ID: "456", IP: "2.2.2.2", Port: 81}, + {IP: "1.1.1.1", Port: 80}, + {IP: "2.2.2.2", Port: 81}, } b, err := peers.MarshalBencode() @@ -71,5 +71,6 @@ func TestPeers6(t *testing.T) { t.Fatal(err) } else if !reflect.DeepEqual(ps, peers) { t.Errorf("%v != %v", ps, peers) + t.Error(string(b)) } } diff --git a/tracker/tracker.go b/tracker/tracker.go index 67b81b3..66c91d6 100644 --- a/tracker/tracker.go +++ b/tracker/tracker.go @@ -1,4 +1,4 @@ -// Copyright 2020 xgfone +// Copyright 2020~2023 xgfone // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -55,7 +55,7 @@ type AnnounceRequest struct { IP net.IP // Optional Key int32 // Optional - NumWant int32 // Optional, BEP 15: -1 for default. But we use 0 as default. + NumWant int32 // Optional Port uint16 // Optional } @@ -89,7 +89,6 @@ func (ar AnnounceRequest) ToUDPAnnounceRequest() udptracker.AnnounceRequest { Left: ar.Left, Uploaded: ar.Uploaded, Event: ar.Event, - IP: ar.IP, Key: ar.Key, NumWant: ar.NumWant, Port: ar.Port, @@ -103,7 +102,7 @@ type AnnounceResponse struct { Interval uint32 Leechers uint32 Seeders uint32 - Addresses []metainfo.Address + Addresses []metainfo.HostAddr } // FromHTTPAnnounceResponse sets itself from r. @@ -111,14 +110,12 @@ func (ar *AnnounceResponse) FromHTTPAnnounceResponse(r httptracker.AnnounceRespo ar.Interval = r.Interval ar.Leechers = r.Incomplete ar.Seeders = r.Complete - ar.Addresses = make([]metainfo.Address, 0, len(r.Peers)+len(r.Peers6)) - for _, peer := range r.Peers { - addrs, _ := peer.Addresses() - ar.Addresses = append(ar.Addresses, addrs...) + ar.Addresses = make([]metainfo.HostAddr, 0, len(r.Peers)+len(r.Peers6)) + for _, p := range r.Peers { + ar.Addresses = append(ar.Addresses, metainfo.NewHostAddr(p.IP, p.Port)) } - for _, peer := range r.Peers6 { - addrs, _ := peer.Addresses() - ar.Addresses = append(ar.Addresses, addrs...) + for _, p := range r.Peers6 { + ar.Addresses = append(ar.Addresses, metainfo.NewHostAddr(p.IP, p.Port)) } } @@ -127,7 +124,11 @@ func (ar *AnnounceResponse) FromUDPAnnounceResponse(r udptracker.AnnounceRespons ar.Interval = r.Interval ar.Leechers = r.Leechers ar.Seeders = r.Seeders - ar.Addresses = r.Addresses + + ar.Addresses = make([]metainfo.HostAddr, len(r.Addresses)) + for i, a := range r.Addresses { + ar.Addresses[i] = metainfo.NewHostAddr(a.IP.String(), a.Port) + } } // ScrapeResponseResult is a commont Scrape response result. @@ -171,8 +172,7 @@ func (sr ScrapeResponse) FromHTTPScrapeResponse(r httptracker.ScrapeResponse) { } // FromUDPScrapeResponse sets itself from hs and r. -func (sr ScrapeResponse) FromUDPScrapeResponse(hs []metainfo.Hash, - r []udptracker.ScrapeResponse) { +func (sr ScrapeResponse) FromUDPScrapeResponse(hs []metainfo.Hash, r []udptracker.ScrapeResponse) { klen := len(hs) if _len := len(r); _len < klen { klen = _len @@ -195,48 +195,37 @@ type Client interface { Close() error } -// ClientConfig is used to configure the defalut client implementation. -type ClientConfig struct { - // The ID of the local client peer. - ID metainfo.Hash - - // The http client used only the tracker client is based on HTTP. - HTTPClient *http.Client -} - // NewClient returns a new Client. -func NewClient(connURL string, conf ...ClientConfig) (c Client, err error) { - var config ClientConfig - if len(conf) > 0 { - config = conf[0] - } - +// +// If id is ZERO, use a random hash instead. +// If client is nil, use http.DefaultClient instead for the http tracker. +func NewClient(connURL string, id metainfo.Hash, client *http.Client) (c Client, err error) { u, err := url.Parse(connURL) - if err == nil { - switch u.Scheme { - case "http", "https": - tracker := httptracker.NewClient(connURL, "") - if !config.ID.IsZero() { - tracker.ID = config.ID - } - c = &tclient{url: connURL, http: tracker} - - case "udp", "udp4", "udp6": - var utc *udptracker.Client - config := udptracker.ClientConfig{ID: config.ID} - utc, err = udptracker.NewClientByDial(u.Scheme, u.Host, config) - if err == nil { - var e []udptracker.Extension - if p := u.RequestURI(); p != "" { - e = []udptracker.Extension{udptracker.NewURLData([]byte(p))} - } - c = &tclient{url: connURL, exts: e, udp: utc} - } - default: - err = fmt.Errorf("unknown url scheme '%s'", u.Scheme) - } + if err != nil { + return } - return + + tclient := &tclient{url: connURL} + switch u.Scheme { + case "http", "https": + tclient.http = httptracker.NewClient(id, connURL, "") + tclient.http.Client = client + + case "udp", "udp4", "udp6": + tclient.udp, err = udptracker.NewClientByDial(u.Scheme, u.Host, id) + if err != nil { + return + } + + if p := u.RequestURI(); p != "" { + tclient.exts = []udptracker.Extension{udptracker.NewURLData([]byte(p))} + } + + default: + err = fmt.Errorf("unknown url scheme '%s'", u.Scheme) + } + + return tclient, nil } type tclient struct { diff --git a/tracker/tracker_test.go b/tracker/tracker_test.go index 13c0d3c..3c0d72e 100644 --- a/tracker/tracker_test.go +++ b/tracker/tracker_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 xgfone +// Copyright 2020~2023 xgfone // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -29,8 +29,7 @@ import ( type testHandler struct{} func (testHandler) OnConnect(raddr *net.UDPAddr) (err error) { return } -func (testHandler) OnAnnounce(raddr *net.UDPAddr, req udptracker.AnnounceRequest) ( - r udptracker.AnnounceResponse, err error) { +func (testHandler) OnAnnounce(raddr *net.UDPAddr, req udptracker.AnnounceRequest) (r udptracker.AnnounceResponse, err error) { if req.Port != 80 { err = errors.New("port is not 80") return @@ -51,12 +50,11 @@ func (testHandler) OnAnnounce(raddr *net.UDPAddr, req udptracker.AnnounceRequest Interval: 1, Leechers: 2, Seeders: 3, - Addresses: []metainfo.Address{{IP: net.ParseIP("127.0.0.1"), Port: 8000}}, + Addresses: []metainfo.CompactAddr{{IP: net.ParseIP("127.0.0.1"), Port: 8000}}, } return } -func (testHandler) OnScrap(raddr *net.UDPAddr, infohashes []metainfo.Hash) ( - rs []udptracker.ScrapeResponse, err error) { +func (testHandler) OnScrap(raddr *net.UDPAddr, infohashes []metainfo.Hash) (rs []udptracker.ScrapeResponse, err error) { rs = make([]udptracker.ScrapeResponse, len(infohashes)) for i := range infohashes { rs[i] = udptracker.ScrapeResponse{ @@ -74,7 +72,7 @@ func ExampleClient() { if err != nil { log.Fatal(err) } - server := udptracker.NewServer(sconn, testHandler{}) + server := udptracker.NewServer(sconn, testHandler{}, 0) defer server.Close() go server.Run() @@ -82,14 +80,14 @@ func ExampleClient() { time.Sleep(time.Second) // Create a client and dial to the UDP tracker server. - client, err := NewClient("udp://127.0.0.1:8000/path?a=1&b=2") + client, err := NewClient("udp://127.0.0.1:8000/path?a=1&b=2", metainfo.Hash{}, nil) if err != nil { log.Fatal(err) } // Send the ANNOUNCE request to the UDP tracker server, // and get the ANNOUNCE response. - req := AnnounceRequest{IP: net.ParseIP("127.0.0.1"), Port: 80} + req := AnnounceRequest{InfoHash: metainfo.NewRandomHash(), IP: net.ParseIP("127.0.0.1"), Port: 80} resp, err := client.Announce(context.Background(), req) if err != nil { log.Fatal(err) @@ -99,7 +97,7 @@ func ExampleClient() { fmt.Printf("Leechers: %d\n", resp.Leechers) fmt.Printf("Seeders: %d\n", resp.Seeders) for i, addr := range resp.Addresses { - fmt.Printf("Address[%d].IP: %s\n", i, addr.IP.String()) + fmt.Printf("Address[%d].IP: %s\n", i, addr.Host) fmt.Printf("Address[%d].Port: %d\n", i, addr.Port) } diff --git a/tracker/udptracker/udp.go b/tracker/udptracker/udp.go index 90fef57..a1f2ea6 100644 --- a/tracker/udptracker/udp.go +++ b/tracker/udptracker/udp.go @@ -1,4 +1,4 @@ -// Copyright 2020 xgfone +// Copyright 2020~2023 xgfone // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -28,6 +28,8 @@ import ( "github.com/xgfone/bt/metainfo" ) +const maxBufSize = 2048 + // ProtocolID is magic constant for the udp tracker connection. // // BEP 15 @@ -55,16 +57,15 @@ type AnnounceRequest struct { Uploaded int64 Event uint32 - IP net.IP Key int32 - NumWant int32 // -1 for default + NumWant int32 // -1 for default and use -1 instead if 0 Port uint16 Exts []Extension // BEP 41 } // DecodeFrom decodes the request from b. -func (r *AnnounceRequest) DecodeFrom(b []byte, ipv4 bool) { +func (r *AnnounceRequest) DecodeFrom(b []byte) { r.InfoHash = metainfo.NewHash(b[0:20]) r.PeerID = metainfo.NewHash(b[20:40]) r.Downloaded = int64(binary.BigEndian.Uint64(b[40:48])) @@ -72,21 +73,13 @@ func (r *AnnounceRequest) DecodeFrom(b []byte, ipv4 bool) { r.Uploaded = int64(binary.BigEndian.Uint64(b[56:64])) r.Event = binary.BigEndian.Uint32(b[64:68]) - if ipv4 { - r.IP = make(net.IP, net.IPv4len) - copy(r.IP, b[68:72]) - b = b[72:] - } else { - r.IP = make(net.IP, net.IPv6len) - copy(r.IP, b[68:84]) - b = b[84:] - } + // ignore b[68:72] // 4 bytes - r.Key = int32(binary.BigEndian.Uint32(b[0:4])) - r.NumWant = int32(binary.BigEndian.Uint32(b[4:8])) - r.Port = binary.BigEndian.Uint16(b[8:10]) + r.Key = int32(binary.BigEndian.Uint32(b[72:76])) + r.NumWant = int32(binary.BigEndian.Uint32(b[76:80])) + r.Port = binary.BigEndian.Uint16(b[80:82]) - b = b[10:] + b = b[82:] for len(b) > 0 { var ext Extension parsed := ext.DecodeFrom(b) @@ -97,26 +90,25 @@ func (r *AnnounceRequest) DecodeFrom(b []byte, ipv4 bool) { // EncodeTo encodes the request to buf. func (r AnnounceRequest) EncodeTo(buf *bytes.Buffer) { - buf.Grow(82) - buf.Write(r.InfoHash[:]) - buf.Write(r.PeerID[:]) - - binary.Write(buf, binary.BigEndian, r.Downloaded) - binary.Write(buf, binary.BigEndian, r.Left) - binary.Write(buf, binary.BigEndian, r.Uploaded) - binary.Write(buf, binary.BigEndian, r.Event) - - if ip := r.IP.To4(); ip != nil { - buf.Write(ip[:]) - } else { - buf.Write(r.IP[:]) + if r.NumWant <= 0 { + r.NumWant = -1 } - binary.Write(buf, binary.BigEndian, r.Key) - binary.Write(buf, binary.BigEndian, r.NumWant) - binary.Write(buf, binary.BigEndian, r.Port) + buf.Grow(82) + buf.Write(r.InfoHash[:]) // 20: 16 - 36 + buf.Write(r.PeerID[:]) // 20: 36 - 56 - for _, ext := range r.Exts { + binary.Write(buf, binary.BigEndian, r.Downloaded) // 8: 56 - 64 + binary.Write(buf, binary.BigEndian, r.Left) // 8: 64 - 72 + binary.Write(buf, binary.BigEndian, r.Uploaded) // 8: 72 - 80 + binary.Write(buf, binary.BigEndian, r.Event) // 4: 80 - 84 + binary.Write(buf, binary.BigEndian, uint32(0)) // 4: 84 - 88 + + binary.Write(buf, binary.BigEndian, r.Key) // 4: 88 - 92 + binary.Write(buf, binary.BigEndian, r.NumWant) // 4: 92 - 96 + binary.Write(buf, binary.BigEndian, r.Port) // 2: 96 - 98 + + for _, ext := range r.Exts { // N: 98 - ext.EncodeTo(buf) } } @@ -128,32 +120,35 @@ type AnnounceResponse struct { Interval uint32 Leechers uint32 Seeders uint32 - Addresses []metainfo.Address + Addresses []metainfo.CompactAddr } // EncodeTo encodes the response to buf. -func (r AnnounceResponse) EncodeTo(buf *bytes.Buffer) { +func (r AnnounceResponse) EncodeTo(buf *bytes.Buffer, ipv4 bool) { buf.Grow(12 + len(r.Addresses)*18) binary.Write(buf, binary.BigEndian, r.Interval) binary.Write(buf, binary.BigEndian, r.Leechers) binary.Write(buf, binary.BigEndian, r.Seeders) - for _, addr := range r.Addresses { - if ip := addr.IP.To4(); ip != nil { - buf.Write(ip[:]) + for i, addr := range r.Addresses { + if ipv4 { + addr.IP = addr.IP.To4() } else { - buf.Write(addr.IP[:]) + addr.IP = addr.IP.To16() } - binary.Write(buf, binary.BigEndian, addr.Port) + if len(addr.IP) == 0 { + panic(fmt.Errorf("invalid ip '%s'", r.Addresses[i].IP.String())) + } + addr.WriteBinary(buf) } } // DecodeFrom decodes the response from b. func (r *AnnounceResponse) DecodeFrom(b []byte, ipv4 bool) { - r.Interval = binary.BigEndian.Uint32(b[:4]) - r.Leechers = binary.BigEndian.Uint32(b[4:8]) - r.Seeders = binary.BigEndian.Uint32(b[8:12]) + r.Interval = binary.BigEndian.Uint32(b[:4]) // 4: 8 - 12 + r.Leechers = binary.BigEndian.Uint32(b[4:8]) // 4: 12 - 16 + r.Seeders = binary.BigEndian.Uint32(b[8:12]) // 4: 16 - 20 - b = b[12:] + b = b[12:] // N*(6|18): 20 - iplen := net.IPv6len if ipv4 { iplen = net.IPv4len @@ -161,12 +156,13 @@ func (r *AnnounceResponse) DecodeFrom(b []byte, ipv4 bool) { _len := len(b) step := iplen + 2 - r.Addresses = make([]metainfo.Address, 0, _len/step) + r.Addresses = make([]metainfo.CompactAddr, 0, _len/step) for i := step; i <= _len; i += step { - ip := make(net.IP, iplen) - copy(ip, b[i-step:i-2]) - port := binary.BigEndian.Uint16(b[i-2 : i]) - r.Addresses = append(r.Addresses, metainfo.Address{IP: ip, Port: port}) + var addr metainfo.CompactAddr + if err := addr.UnmarshalBinary(b[i-step : i]); err != nil { + panic(err) + } + r.Addresses = append(r.Addresses, addr) } } diff --git a/tracker/udptracker/udp_client.go b/tracker/udptracker/udp_client.go index 2299007..c24d0f0 100644 --- a/tracker/udptracker/udp_client.go +++ b/tracker/udptracker/udp_client.go @@ -1,4 +1,4 @@ -// Copyright 2020 xgfone +// Copyright 2020~2023 xgfone // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ import ( "context" "encoding/binary" "errors" + "fmt" "io" "net" "strings" @@ -29,39 +30,27 @@ import ( ) // NewClientByDial returns a new Client by dialing. -func NewClientByDial(network, address string, c ...ClientConfig) (*Client, error) { +func NewClientByDial(network, address string, id metainfo.Hash) (*Client, error) { conn, err := net.Dial(network, address) if err != nil { return nil, err } - - return NewClient(conn.(*net.UDPConn), c...), nil + return NewClient(conn.(*net.UDPConn), id), nil } // NewClient returns a new Client. -func NewClient(conn *net.UDPConn, c ...ClientConfig) *Client { - var conf ClientConfig - conf.set(c...) +func NewClient(conn *net.UDPConn, id metainfo.Hash) *Client { ipv4 := strings.Contains(conn.LocalAddr().String(), ".") - return &Client{conn: conn, conf: conf, ipv4: ipv4} -} - -// ClientConfig is used to configure the Client. -type ClientConfig struct { - ID metainfo.Hash - MaxBufSize int // Default: 2048 -} - -func (c *ClientConfig) set(conf ...ClientConfig) { - if len(conf) > 0 { - *c = conf[0] + if id.IsZero() { + id = metainfo.NewRandomHash() } - if c.ID.IsZero() { - c.ID = metainfo.NewRandomHash() - } - if c.MaxBufSize <= 0 { - c.MaxBufSize = 2048 + return &Client{ + MaxBufSize: maxBufSize, + + conn: conn, + ipv4: ipv4, + id: id, } } @@ -72,12 +61,14 @@ func (c *ClientConfig) set(conf ...ClientConfig) { // // BEP 15 type Client struct { - ipv4 bool - conf ClientConfig + MaxBufSize int // Default: 2048 + + id metainfo.Hash conn *net.UDPConn last time.Time cid uint64 tid uint32 + ipv4 bool } // Close closes the UDP tracker client. @@ -140,21 +131,21 @@ func (utc *Client) connect(ctx context.Context) (err error) { } data = data[:n] - switch binary.BigEndian.Uint32(data[:4]) { + switch action := binary.BigEndian.Uint32(data[:4]); action { case ActionConnect: case ActionError: _, reason := utc.parseError(data[4:]) return errors.New(reason) default: - return errors.New("tracker response not connect action") + return fmt.Errorf("tracker response is not connect action: %d", action) } if n < 16 { return io.ErrShortBuffer } - if binary.BigEndian.Uint32(data[4:8]) != tid { - return errors.New("invalid transaction id") + if _tid := binary.BigEndian.Uint32(data[4:8]); _tid != tid { + return fmt.Errorf("expect transaction id %d, but got %d", tid, _tid) } utc.cid = binary.BigEndian.Uint64(data[8:16]) @@ -172,8 +163,7 @@ func (utc *Client) getConnectionID(ctx context.Context) (cid uint64, err error) return } -func (utc *Client) announce(ctx context.Context, req AnnounceRequest) ( - r AnnounceResponse, err error) { +func (utc *Client) announce(ctx context.Context, req AnnounceRequest) (r AnnounceResponse, err error) { cid, err := utc.getConnectionID(ctx) if err != nil { return @@ -181,16 +171,21 @@ func (utc *Client) announce(ctx context.Context, req AnnounceRequest) ( tid := utc.getTranID() buf := bytes.NewBuffer(make([]byte, 0, 110)) - binary.Write(buf, binary.BigEndian, cid) - binary.Write(buf, binary.BigEndian, ActionAnnounce) - binary.Write(buf, binary.BigEndian, tid) + binary.Write(buf, binary.BigEndian, cid) // 8: 0 - 8 + binary.Write(buf, binary.BigEndian, ActionAnnounce) // 4: 8 - 12 + binary.Write(buf, binary.BigEndian, tid) // 4: 12 - 16 req.EncodeTo(buf) b := buf.Bytes() if err = utc.send(b); err != nil { return } - data := make([]byte, utc.conf.MaxBufSize) + bufSize := utc.MaxBufSize + if bufSize <= 0 { + bufSize = maxBufSize + } + + data := make([]byte, bufSize) n, err := utc.readResp(ctx, data) if err != nil { return @@ -200,14 +195,14 @@ func (utc *Client) announce(ctx context.Context, req AnnounceRequest) ( } data = data[:n] - switch binary.BigEndian.Uint32(data[:4]) { + switch action := binary.BigEndian.Uint32(data[:4]); action { case ActionAnnounce: case ActionError: _, reason := utc.parseError(data[4:]) err = errors.New(reason) return default: - err = errors.New("tracker response not connect action") + err = fmt.Errorf("tracker response is not connect action: %d", action) return } @@ -216,8 +211,8 @@ func (utc *Client) announce(ctx context.Context, req AnnounceRequest) ( return } - if binary.BigEndian.Uint32(data[4:8]) != tid { - err = errors.New("invalid transaction id") + if _tid := binary.BigEndian.Uint32(data[4:8]); _tid != tid { + err = fmt.Errorf("expect transaction id %d, but got %d", tid, _tid) return } @@ -228,19 +223,20 @@ func (utc *Client) announce(ctx context.Context, req AnnounceRequest) ( // Announce sends a Announce request to the tracker. // // Notice: -// 1. if it does not connect to the UDP tracker server, it will connect to it, -// then send the ANNOUNCE request. -// 2. If returning an error, you should retry it. -// See http://www.bittorrent.org/beps/bep_0015.html#time-outs +// 1. if it does not connect to the UDP tracker server, it will connect to it, +// then send the ANNOUNCE request. +// 2. If returning an error, you should retry it. +// See http://www.bittorrent.org/beps/bep_0015.html#time-outs func (utc *Client) Announce(c context.Context, r AnnounceRequest) (AnnounceResponse, error) { - if r.PeerID.IsZero() { - r.PeerID = utc.conf.ID + if r.InfoHash.IsZero() { + panic("infohash is ZERO") } + + r.PeerID = utc.id return utc.announce(c, r) } -func (utc *Client) scrape(c context.Context, ihs []metainfo.Hash) ( - rs []ScrapeResponse, err error) { +func (utc *Client) scrape(c context.Context, ihs []metainfo.Hash) (rs []ScrapeResponse, err error) { cid, err := utc.getConnectionID(c) if err != nil { return @@ -248,17 +244,24 @@ func (utc *Client) scrape(c context.Context, ihs []metainfo.Hash) ( tid := utc.getTranID() buf := bytes.NewBuffer(make([]byte, 0, 16+len(ihs)*20)) - binary.Write(buf, binary.BigEndian, cid) - binary.Write(buf, binary.BigEndian, ActionScrape) - binary.Write(buf, binary.BigEndian, tid) - for _, h := range ihs { + binary.Write(buf, binary.BigEndian, cid) // 8: 0 - 8 + binary.Write(buf, binary.BigEndian, ActionScrape) // 4: 8 - 12 + binary.Write(buf, binary.BigEndian, tid) // 4: 12 - 16 + + for _, h := range ihs { // 20*N: 16 - buf.Write(h[:]) } + if err = utc.send(buf.Bytes()); err != nil { return } - data := make([]byte, utc.conf.MaxBufSize) + bufSize := utc.MaxBufSize + if bufSize <= 0 { + bufSize = maxBufSize + } + + data := make([]byte, bufSize) n, err := utc.readResp(c, data) if err != nil { return @@ -268,19 +271,19 @@ func (utc *Client) scrape(c context.Context, ihs []metainfo.Hash) ( } data = data[:n] - switch binary.BigEndian.Uint32(data[:4]) { + switch action := binary.BigEndian.Uint32(data[:4]); action { case ActionScrape: case ActionError: _, reason := utc.parseError(data[4:]) err = errors.New(reason) return default: - err = errors.New("tracker response not connect action") + err = fmt.Errorf("tracker response is not connect action %d", action) return } - if binary.BigEndian.Uint32(data[4:8]) != tid { - err = errors.New("invalid transaction id") + if _tid := binary.BigEndian.Uint32(data[4:8]); _tid != tid { + err = fmt.Errorf("expect transaction id %d, but got %d", tid, _tid) return } @@ -299,10 +302,10 @@ func (utc *Client) scrape(c context.Context, ihs []metainfo.Hash) ( // Scrape sends a Scrape request to the tracker. // // Notice: -// 1. if it does not connect to the UDP tracker server, it will connect to it, -// then send the ANNOUNCE request. -// 2. If returning an error, you should retry it. -// See http://www.bittorrent.org/beps/bep_0015.html#time-outs +// 1. if it does not connect to the UDP tracker server, it will connect to it, +// then send the ANNOUNCE request. +// 2. If returning an error, you should retry it. +// See http://www.bittorrent.org/beps/bep_0015.html#time-outs func (utc *Client) Scrape(c context.Context, hs []metainfo.Hash) ([]ScrapeResponse, error) { return utc.scrape(c, hs) } diff --git a/tracker/udptracker/udp_server.go b/tracker/udptracker/udp_server.go index 548a2f2..17ba815 100644 --- a/tracker/udptracker/udp_server.go +++ b/tracker/udptracker/udp_server.go @@ -1,4 +1,4 @@ -// Copyright 2020 xgfone +// Copyright 2020~2023 xgfone // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -45,25 +45,14 @@ type wrappedPeerAddr struct { Time time.Time } -// ServerConfig is used to configure the Server. -type ServerConfig struct { - MaxBufSize int // Default: 2048 - ErrorLog func(format string, args ...interface{}) // Default: log.Printf -} - -func (c *ServerConfig) setDefault() { - if c.MaxBufSize <= 0 { - c.MaxBufSize = 2048 - } - if c.ErrorLog == nil { - c.ErrorLog = log.Printf - } -} +type buffer struct{ Data []byte } // Server is a tracker server based on UDP. type Server struct { + // Default: log.Printf + ErrorLog func(format string, args ...interface{}) + conn net.PacketConn - conf ServerConfig handler ServerHandler bufpool sync.Pool @@ -74,23 +63,20 @@ type Server struct { } // NewServer returns a new Server. -func NewServer(c net.PacketConn, h ServerHandler, sc ...ServerConfig) *Server { - var conf ServerConfig - if len(sc) > 0 { - conf = sc[0] +func NewServer(c net.PacketConn, h ServerHandler, bufSize int) *Server { + if bufSize <= 0 { + bufSize = maxBufSize } - conf.setDefault() - s := &Server{ - conf: conf, + return &Server{ conn: c, handler: h, exit: make(chan struct{}), conns: make(map[uint64]wrappedPeerAddr, 128), + bufpool: sync.Pool{New: func() interface{} { + return &buffer{Data: make([]byte, bufSize)} + }}, } - s.bufpool.New = func() interface{} { return make([]byte, conf.MaxBufSize) } - - return s } // Close closes the tracker server. @@ -126,11 +112,11 @@ func (uts *Server) cleanConnectionID(interval time.Duration) { func (uts *Server) Run() { go uts.cleanConnectionID(time.Minute * 2) for { - buf := uts.bufpool.Get().([]byte) - n, raddr, err := uts.conn.ReadFrom(buf) + buf := uts.bufpool.Get().(*buffer) + n, raddr, err := uts.conn.ReadFrom(buf.Data) if err != nil { if !strings.Contains(err.Error(), "closed") { - uts.conf.ErrorLog("failed to read udp tracker request: %s", err) + uts.errorf("failed to read udp tracker request: %s", err) } return } else if n < 16 { @@ -140,18 +126,26 @@ func (uts *Server) Run() { } } -func (uts *Server) handleRequest(raddr *net.UDPAddr, buf []byte, n int) { +func (uts *Server) errorf(format string, args ...interface{}) { + if uts.ErrorLog == nil { + log.Printf(format, args...) + } else { + uts.ErrorLog(format, args...) + } +} + +func (uts *Server) handleRequest(raddr *net.UDPAddr, buf *buffer, n int) { defer uts.bufpool.Put(buf) - uts.handlePacket(raddr, buf[:n]) + uts.handlePacket(raddr, buf.Data[:n]) } func (uts *Server) send(raddr *net.UDPAddr, b []byte) { n, err := uts.conn.WriteTo(b, raddr) if err != nil { - uts.conf.ErrorLog("fail to send the udp tracker response to '%s': %s", + uts.errorf("fail to send the udp tracker response to '%s': %s", raddr.String(), err) } else if n < len(b) { - uts.conf.ErrorLog("too short udp tracker response sent to '%s'", raddr.String()) + uts.errorf("too short udp tracker response sent to '%s'", raddr.String()) } } @@ -190,16 +184,14 @@ func (uts *Server) sendConnResp(raddr *net.UDPAddr, tid uint32, cid uint64) { uts.send(raddr, buf.Bytes()) } -func (uts *Server) sendAnnounceResp(raddr *net.UDPAddr, tid uint32, - resp AnnounceResponse) { +func (uts *Server) sendAnnounceResp(raddr *net.UDPAddr, tid uint32, resp AnnounceResponse) { buf := bytes.NewBuffer(make([]byte, 0, 8+12+len(resp.Addresses)*18)) encodeResponseHeader(buf, ActionAnnounce, tid) - resp.EncodeTo(buf) + resp.EncodeTo(buf, raddr.IP.To4() != nil) uts.send(raddr, buf.Bytes()) } -func (uts *Server) sendScrapResp(raddr *net.UDPAddr, tid uint32, - rs []ScrapeResponse) { +func (uts *Server) sendScrapResp(raddr *net.UDPAddr, tid uint32, rs []ScrapeResponse) { buf := bytes.NewBuffer(make([]byte, 0, 8+len(rs)*12)) encodeResponseHeader(buf, ActionScrape, tid) for _, r := range rs { @@ -209,9 +201,9 @@ func (uts *Server) sendScrapResp(raddr *net.UDPAddr, tid uint32, } func (uts *Server) handlePacket(raddr *net.UDPAddr, b []byte) { - cid := binary.BigEndian.Uint64(b[:8]) - action := binary.BigEndian.Uint32(b[8:12]) - tid := binary.BigEndian.Uint32(b[12:16]) + cid := binary.BigEndian.Uint64(b[:8]) // 8: 0 - 8 + action := binary.BigEndian.Uint32(b[8:12]) // 4: 8 - 12 + tid := binary.BigEndian.Uint32(b[12:16]) // 4: 12 - 16 b = b[16:] // Handle the connection request. @@ -241,13 +233,13 @@ func (uts *Server) handlePacket(raddr *net.UDPAddr, b []byte) { uts.sendError(raddr, tid, "invalid announce request") return } - req.DecodeFrom(b, true) + req.DecodeFrom(b) } else { // For ipv6 if len(b) < 94 { uts.sendError(raddr, tid, "invalid announce request") return } - req.DecodeFrom(b, false) + req.DecodeFrom(b) } resp, err := uts.handler.OnAnnounce(raddr, req) @@ -256,6 +248,7 @@ func (uts *Server) handlePacket(raddr *net.UDPAddr, b []byte) { } else { uts.sendAnnounceResp(raddr, tid, resp) } + case ActionScrape: _len := len(b) infohashes := make([]metainfo.Hash, 0, _len/20) @@ -274,6 +267,7 @@ func (uts *Server) handlePacket(raddr *net.UDPAddr, b []byte) { } else { uts.sendScrapResp(raddr, tid, resps) } + default: uts.sendError(raddr, tid, "unkwnown action") } diff --git a/tracker/udptracker/udp_test.go b/tracker/udptracker/udp_test.go index d202875..0dbec74 100644 --- a/tracker/udptracker/udp_test.go +++ b/tracker/udptracker/udp_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 xgfone +// Copyright 2020~2023 xgfone // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -28,8 +28,7 @@ import ( type testHandler struct{} func (testHandler) OnConnect(raddr *net.UDPAddr) (err error) { return } -func (testHandler) OnAnnounce(raddr *net.UDPAddr, req AnnounceRequest) ( - r AnnounceResponse, err error) { +func (testHandler) OnAnnounce(raddr *net.UDPAddr, req AnnounceRequest) (r AnnounceResponse, err error) { if req.Port != 80 { err = errors.New("port is not 80") return @@ -50,12 +49,11 @@ func (testHandler) OnAnnounce(raddr *net.UDPAddr, req AnnounceRequest) ( Interval: 1, Leechers: 2, Seeders: 3, - Addresses: []metainfo.Address{{IP: net.ParseIP("127.0.0.1"), Port: 8001}}, + Addresses: []metainfo.CompactAddr{{IP: net.ParseIP("127.0.0.1"), Port: 8001}}, } return } -func (testHandler) OnScrap(raddr *net.UDPAddr, infohashes []metainfo.Hash) ( - rs []ScrapeResponse, err error) { +func (testHandler) OnScrap(raddr *net.UDPAddr, infohashes []metainfo.Hash) (rs []ScrapeResponse, err error) { rs = make([]ScrapeResponse, len(infohashes)) for i := range infohashes { rs[i] = ScrapeResponse{ @@ -73,7 +71,7 @@ func ExampleClient() { if err != nil { log.Fatal(err) } - server := NewServer(sconn, testHandler{}) + server := NewServer(sconn, testHandler{}, 0) defer server.Close() go server.Run() @@ -81,7 +79,7 @@ func ExampleClient() { time.Sleep(time.Second) // Create a client and dial to the UDP tracker server. - client, err := NewClientByDial("udp4", "127.0.0.1:8001") + client, err := NewClientByDial("udp4", "127.0.0.1:8001", metainfo.Hash{}) if err != nil { log.Fatal(err) } @@ -89,7 +87,7 @@ func ExampleClient() { // Send the ANNOUNCE request to the UDP tracker server, // and get the ANNOUNCE response. exts := []Extension{NewURLData([]byte("data")), NewNop()} - req := AnnounceRequest{IP: net.ParseIP("127.0.0.1"), Port: 80, Exts: exts} + req := AnnounceRequest{InfoHash: metainfo.NewRandomHash(), Port: 80, Exts: exts} resp, err := client.Announce(context.Background(), req) if err != nil { log.Fatal(err) diff --git a/utils/bool.go b/utils/bool.go deleted file mode 100644 index 2bbcdd2..0000000 --- a/utils/bool.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2020 xgfone -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "sync/atomic" -) - -// Bool is used to implement a atomic bool. -type Bool struct { - v uint32 -} - -// NewBool returns a new Bool with the initialized value. -func NewBool(t bool) (b Bool) { - if t { - b.v = 1 - } - return -} - -// Get returns the bool value. -func (b *Bool) Get() bool { return atomic.LoadUint32(&b.v) == 1 } - -// SetTrue sets the bool value to true. -func (b *Bool) SetTrue() { atomic.StoreUint32(&b.v, 1) } - -// SetFalse sets the bool value to false. -func (b *Bool) SetFalse() { atomic.StoreUint32(&b.v, 0) } - -// Set sets the bool value to t. -func (b *Bool) Set(t bool) { - if t { - b.SetTrue() - } else { - b.SetFalse() - } -}