update peer protocol

This commit is contained in:
xgfone
2020-06-19 20:18:58 +08:00
parent 32cc9a6738
commit ada4692b65
4 changed files with 149 additions and 32 deletions

View File

@ -20,24 +20,57 @@ import (
"fmt"
"io"
"io/ioutil"
"sort"
)
var errMessageTooLong = fmt.Errorf("the peer message is too long")
// Pieces is used to represent the list of the indexes of the pieces.
type Pieces []uint32
// Sort sorts itself.
func (ps Pieces) Sort() { sort.Sort(ps) }
func (ps Pieces) Len() int { return len(ps) }
func (ps Pieces) Less(i, j int) bool { return ps[i] < ps[j] }
func (ps Pieces) Swap(i, j int) { ps[i], ps[j] = ps[j], ps[i] }
// Append appends the piece index, sorts and returns the new index list.
func (ps Pieces) Append(index uint32) Pieces {
for _, p := range ps {
if p == index {
return ps
}
}
pieces := append(ps, index)
sort.Sort(pieces)
return pieces
}
// Remove removes the given piece index from the list and returns new list.
func (ps Pieces) Remove(index uint32) Pieces {
for i, p := range ps {
if p == index {
copy(ps[i:], ps[i+1:])
return ps[:len(ps)-1]
}
}
return ps
}
// BitField represents the bit field of the pieces.
type BitField []uint8
// NewBitField returns a new BitField to hold the pieceNum pieces.
func NewBitField(pieceNum int) BitField {
return make(BitField, (pieceNum+7)/8)
}
// NewBitFieldTrue is the same as NewBitField, but set all bits to 1.
func NewBitFieldTrue(pieceNum int) BitField {
//
// If set is set to true, it will set all the bit fields to 1.
func NewBitField(pieceNum int, set ...bool) BitField {
_len := (pieceNum + 7) / 8
bf := make(BitField, _len)
for i := 0; i < _len; i++ {
bf[i] = 0xff
if len(set) > 0 && set[0] {
for i := 0; i < _len; i++ {
bf[i] = 0xff
}
}
return bf
}
@ -68,19 +101,50 @@ func (bf BitField) Bools() []bool {
return bs
}
// Sets returns the indexes of all the pieces that are set to 1.
func (bf BitField) Sets() (pieces Pieces) {
_len := len(bf) * 8
for i := 0; i < _len; i++ {
index := uint32(i)
if bf.IsSet(index) {
pieces = append(pieces, index)
}
}
return
}
// Unsets returns the indexes of all the pieces that are set to 0.
func (bf BitField) Unsets() (pieces Pieces) {
_len := len(bf) * 8
for i := 0; i < _len; i++ {
index := uint32(i)
if !bf.IsSet(index) {
pieces = append(pieces, index)
}
}
return
}
// Set sets the bit of the piece to 1 by its index.
func (bf BitField) Set(index uint32) {
bf[index/8] |= (1 << byte(7-index%8))
if i := int(index) / 8; i < len(bf) {
bf[i] |= (1 << byte(7-index%8))
}
}
// Unset sets the bit of the piece to 0 by its index.
func (bf BitField) Unset(index uint32) {
bf[index/8] &^= (1 << byte(7-index%8))
if i := int(index) / 8; i < len(bf) {
bf[i] &^= (1 << byte(7-index%8))
}
}
// IsSet reports whether the bit of the piece is set to 1.
func (bf BitField) IsSet(index uint32) bool {
return bf[index/8]&(1<<byte(7-index%8)) != 0
func (bf BitField) IsSet(index uint32) (set bool) {
if i := int(index) / 8; i < len(bf) {
set = bf[i]&(1<<byte(7-index%8)) != 0
}
return
}
// Message is the message used by the peer protocol, which contains
@ -176,7 +240,7 @@ func (m *Message) Decode(r io.Reader, maxLength uint32) (err error) {
var length uint32
if err = binary.Read(r, binary.BigEndian, &length); err != nil {
if err != io.EOF {
err = fmt.Errorf("error reading peer message message length: %s", err)
err = fmt.Errorf("reading length error: %s", err)
}
return
}

View File

@ -14,7 +14,9 @@
package peerprotocol
import "testing"
import (
"testing"
)
func TestBitField(t *testing.T) {
bf := NewBitFieldFromBools([]bool{
@ -51,9 +53,32 @@ func TestBitField(t *testing.T) {
t.Error(10)
}
bf = NewBitFieldTrue(16)
bf = NewBitField(16, true)
if !bf.IsSet(0) || !bf.IsSet(1) || !bf.IsSet(2) || !bf.IsSet(3) ||
!bf.IsSet(4) || !bf.IsSet(5) || !bf.IsSet(6) || !bf.IsSet(7) {
t.Error(bf)
}
}
func TestPieces(t *testing.T) {
ps := Pieces{2, 3, 4, 5}
ps = ps.Append(1)
if len(ps) != 5 || ps[0] != 1 {
t.Fatal(ps)
}
ps = ps.Append(5)
if len(ps) != 5 {
t.Fatal(ps)
}
ps = ps.Remove(6)
if len(ps) != 5 {
t.Fatal(ps)
}
ps = ps.Remove(3)
if len(ps) != 4 || ps[0] != 1 || ps[1] != 2 || ps[2] != 4 || ps[3] != 5 {
t.Fatal(ps)
}
}

View File

@ -99,8 +99,8 @@ func (NoopBep6Handler) Reject(*PeerConn, uint32, uint32, uint32) error { return
// the noop interface methods.
type NoopBep10Handler struct{}
// OnHandShake implements the interface Bep10Handler#OnHandShake.
func (NoopBep10Handler) OnHandShake(*PeerConn, ExtendedHandshakeMsg) error { return nil }
// OnExtHandShake implements the interface Bep10Handler#OnExtHandShake.
func (NoopBep10Handler) OnExtHandShake(*PeerConn) error { return nil }
// OnPayload implements the interface Bep10Handler#OnPayload.
func (NoopBep10Handler) OnPayload(*PeerConn, uint8, []byte) error { return nil }

View File

@ -27,10 +27,13 @@ import (
// Predefine some errors about extension support.
var (
ErrChoked = fmt.Errorf("choked")
ErrNotFirstMsg = fmt.Errorf("not the first message")
ErrNotSupportDHT = fmt.Errorf("not support DHT extension")
ErrNotSupportFast = fmt.Errorf("not support Fast extension")
ErrNotSupportExtended = fmt.Errorf("not support Extended extension")
ErrSecondExtHandshake = fmt.Errorf("second extended handshake")
ErrNoExtHandshake = fmt.Errorf("no extended handshake")
)
// Bep3Handler is used to handle the BEP 3 type message if Handler has also
@ -72,7 +75,7 @@ type Bep6Handler interface {
//
// Notice: the server must enable the Extended extension bit.
type Bep10Handler interface {
OnHandShake(conn *PeerConn, exthmsg ExtendedHandshakeMsg) error
OnExtHandShake(conn *PeerConn) error
OnPayload(conn *PeerConn, extid uint8, payload []byte) error
}
@ -128,6 +131,14 @@ type PeerConn struct {
// The default is 0, which represents no limit.
MaxLength uint32
Fasts Pieces // The list of the indexes of the FAST pieces.
Suggests Pieces // The list of the indexes of the SUGGEST pieces.
BitField BitField // The bit field of the piece indexes.
// ExtendedHandshakeMsg is the handshake message of the extension
// from the remote peer.
ExtendedHandshakeMsg ExtendedHandshakeMsg
// Data is used to store the context data associated with the connection.
Data interface{}
@ -137,7 +148,8 @@ type PeerConn struct {
// Optional.
OnWriteMsg func(pc *PeerConn, m Message) error
notFirstMsg bool
notFirstMsg bool
extHandshake bool
}
// NewPeerConn returns a new PeerConn.
@ -481,10 +493,13 @@ func (pc *PeerConn) HandleMessage(msg Message, handler Handler) (err error) {
case MTypeBitField:
if pc.notFirstMsg {
err = ErrNotFirstMsg
} else if h, ok := handler.(Bep3Handler); ok {
err = h.BitField(pc, msg.BitField)
} else {
err = handler.OnMessage(pc, msg)
pc.BitField = msg.BitField
if h, ok := handler.(Bep3Handler); ok {
err = h.BitField(pc, msg.BitField)
} else {
err = handler.OnMessage(pc, msg)
}
}
case MTypeRequest:
if h, ok := handler.(Bep3Handler); ok {
@ -519,10 +534,13 @@ func (pc *PeerConn) HandleMessage(msg Message, handler Handler) (err error) {
case MTypeSuggest:
if !pc.ExtBits.IsSupportFast() {
err = ErrNotSupportFast
} else if h, ok := handler.(Bep6Handler); ok {
err = h.Suggest(pc, msg.Index)
} else {
err = handler.OnMessage(pc, msg)
pc.Suggests = pc.Suggests.Append(msg.Index)
if h, ok := handler.(Bep6Handler); ok {
err = h.Suggest(pc, msg.Index)
} else {
err = handler.OnMessage(pc, msg)
}
}
case MTypeHaveAll:
if pc.notFirstMsg {
@ -555,10 +573,13 @@ func (pc *PeerConn) HandleMessage(msg Message, handler Handler) (err error) {
case MTypeAllowedFast:
if !pc.ExtBits.IsSupportFast() {
err = ErrNotSupportFast
} else if h, ok := handler.(Bep6Handler); ok {
err = h.AllowedFast(pc, msg.Index)
} else {
err = handler.OnMessage(pc, msg)
pc.Fasts = pc.Fasts.Append(msg.Index)
if h, ok := handler.(Bep6Handler); ok {
err = h.AllowedFast(pc, msg.Index)
} else {
err = handler.OnMessage(pc, msg)
}
}
// BEP 10 - Extension Protocol
@ -585,12 +606,19 @@ func (pc *PeerConn) HandleMessage(msg Message, handler Handler) (err error) {
func (pc *PeerConn) handleExtMsg(h Bep10Handler, m Message) (err error) {
if m.ExtendedID == ExtendedIDHandshake {
var ehmsg ExtendedHandshakeMsg
if err = bencode.DecodeBytes(m.ExtendedPayload, &ehmsg); err == nil {
err = h.OnHandShake(pc, ehmsg)
if pc.extHandshake {
return ErrSecondExtHandshake
}
} else {
pc.extHandshake = true
err = bencode.DecodeBytes(m.ExtendedPayload, &pc.ExtendedHandshakeMsg)
if err == nil {
err = h.OnExtHandShake(pc)
}
} else if pc.extHandshake {
err = h.OnPayload(pc, m.ExtendedID, m.ExtendedPayload)
} else {
err = ErrNoExtHandshake
}
return