diff --git a/common/config_test.go b/common/config_test.go new file mode 100644 index 00000000..cfc74b46 --- /dev/null +++ b/common/config_test.go @@ -0,0 +1,160 @@ +package common + +import ( + "testing" +) + +func TestSetSAMAddress_Cases(t *testing.T) { + tests := []struct { + name string + addr string + wantHost string + wantPort int + }{ + { + name: "empty address uses defaults", + addr: "", + wantHost: "127.0.0.1", + wantPort: 7656, + }, + { + name: "valid host:port", + addr: "192.168.1.1:7000", + wantHost: "192.168.1.1", + wantPort: 7000, + }, + { + name: "invalid port uses default", + addr: "localhost:99999", + wantHost: "localhost", + wantPort: 0, + }, + { + name: "just IP address", + addr: "192.168.1.1", + wantHost: "192.168.1.1", + wantPort: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &I2PConfig{} + cfg.SetSAMAddress(tt.addr) + + if cfg.SamHost != tt.wantHost { + t.Errorf("SetSAMAddress() host = %v, want %v", cfg.SamHost, tt.wantHost) + } + if cfg.SamPort != tt.wantPort { + t.Errorf("SetSAMAddress() port = %v, want %v", cfg.SamPort, tt.wantPort) + } + }) + } +} + +func TestID_Generation(t *testing.T) { + cfg := &I2PConfig{} + + // Test when TunName is empty + id1 := cfg.ID() + if len(cfg.TunName) != 12 { + t.Errorf("ID() generated name length = %v, want 12", len(cfg.TunName)) + } + + // Verify format + if id1[:3] != "ID=" { + t.Errorf("ID() format incorrect, got %v, want prefix 'ID='", id1) + } + + // Test with preset TunName + cfg.TunName = "testtunnel" + id2 := cfg.ID() + if id2 != "ID=testtunnel" { + t.Errorf("ID() = %v, want ID=testtunnel", id2) + } +} + +func TestSam_AddressFormatting(t *testing.T) { + tests := []struct { + name string + host string + port int + want string + }{ + { + name: "default values", + host: "", + port: 0, + want: "127.0.0.1:7656", + }, + { + name: "custom host and port", + host: "localhost", + port: 7000, + want: "localhost:7000", + }, + { + name: "only custom host", + host: "192.168.1.1", + port: 0, + want: "192.168.1.1:7656", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &I2PConfig{ + SamHost: tt.host, + SamPort: tt.port, + } + got := cfg.Sam() + if got != tt.want { + t.Errorf("Sam() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestLeaseSetSettings_Formatting(t *testing.T) { + tests := []struct { + name string + cfg I2PConfig + wantKey string + wantPrivKey string + wantSignKey string + }{ + { + name: "empty settings", + cfg: I2PConfig{}, + wantKey: "", + wantPrivKey: "", + wantSignKey: "", + }, + { + name: "all settings populated", + cfg: I2PConfig{ + LeaseSetKey: "testkey", + LeaseSetPrivateKey: "privkey", + LeaseSetPrivateSigningKey: "signkey", + }, + wantKey: " i2cp.leaseSetKey=testkey ", + wantPrivKey: " i2cp.leaseSetPrivateKey=privkey ", + wantSignKey: " i2cp.leaseSetPrivateSigningKey=signkey ", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + key, priv, sign := tt.cfg.LeaseSetSettings() + if key != tt.wantKey { + t.Errorf("LeaseSetSettings() key = %v, want %v", key, tt.wantKey) + } + if priv != tt.wantPrivKey { + t.Errorf("LeaseSetSettings() private key = %v, want %v", priv, tt.wantPrivKey) + } + if sign != tt.wantSignKey { + t.Errorf("LeaseSetSettings() signing key = %v, want %v", sign, tt.wantSignKey) + } + }) + } +} diff --git a/common/emit-options_test.go b/common/emit-options_test.go new file mode 100644 index 00000000..34b9bf49 --- /dev/null +++ b/common/emit-options_test.go @@ -0,0 +1,65 @@ +package common + +import "testing" + +func TestSetInQuantity(t *testing.T) { + tests := []struct { + name string + input int + wantErr bool + }{ + {"valid min", 1, false}, + {"valid max", 16, false}, + {"valid middle", 8, false}, + {"invalid zero", 0, true}, + {"invalid negative", -1, true}, + {"invalid too large", 17, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + emit := &SAMEmit{I2PConfig: I2PConfig{}} + err := SetInQuantity(tt.input)(emit) + + if (err != nil) != tt.wantErr { + t.Errorf("SetInQuantity() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr && emit.I2PConfig.InQuantity != tt.input { + t.Errorf("SetInQuantity() = %v, want %v", emit.I2PConfig.InQuantity, tt.input) + } + }) + } +} + +func TestSetOutQuantity(t *testing.T) { + tests := []struct { + name string + input int + wantErr bool + }{ + {"valid min", 1, false}, + {"valid max", 16, false}, + {"valid middle", 8, false}, + {"invalid zero", 0, true}, + {"invalid negative", -1, true}, + {"invalid too large", 17, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + emit := &SAMEmit{I2PConfig: I2PConfig{}} + err := SetOutQuantity(tt.input)(emit) + + if (err != nil) != tt.wantErr { + t.Errorf("SetOutQuantity() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr && emit.I2PConfig.OutQuantity != tt.input { + t.Errorf("SetOutQuantity() = %v, want %v", emit.I2PConfig.OutQuantity, tt.input) + } + }) + } +} diff --git a/common/sam3.go b/common/sam3.go index 2b2c09f8..0454789e 100644 --- a/common/sam3.go +++ b/common/sam3.go @@ -6,37 +6,6 @@ import ( "strings" ) -func NewSAM(address string) (*SAM, error) { - logger := log.WithField("address", address) - logger.Debug("Creating new SAM instance") - - conn, err := connectToSAM(address) - if err != nil { - return nil, err - } - defer func() { - if err != nil { - conn.Close() - } - }() - - s := &SAM{ - Conn: conn, - } - - if err = sendHelloAndValidate(conn, s); err != nil { - return nil, err - } - - s.SAMEmit.I2PConfig.SetSAMAddress(address) - - if s.SAMResolver, err = NewSAMResolver(s); err != nil { - return nil, fmt.Errorf("failed to create SAM resolver: %w", err) - } - - return s, nil -} - func connectToSAM(address string) (net.Conn, error) { conn, err := net.Dial("tcp", address) if err != nil { diff --git a/common/util.go b/common/util.go index eab441d5..6cf98192 100644 --- a/common/util.go +++ b/common/util.go @@ -37,6 +37,40 @@ func SplitHostPort(hostport string) (string, string, error) { return host, port, nil } +func ExtractPairString(input, value string) string { + log.WithFields(logrus.Fields{"input": input, "value": value}).Debug("ExtractPairString called") + parts := strings.Split(input, " ") + for _, part := range parts { + log.WithField("part", part).Debug("Checking part") + if strings.HasPrefix(part, value) { + kv := strings.SplitN(part, "=", 2) + if len(kv) == 2 { + log.WithFields(logrus.Fields{"key": kv[0], "value": kv[1]}).Debug("Pair extracted") + return kv[1] + } + } + } + log.WithFields(logrus.Fields{"input": input, "value": value}).Debug("No pair found") + return "" +} + +func ExtractPairInt(input, value string) int { + rv, err := strconv.Atoi(ExtractPairString(input, value)) + if err != nil { + log.WithFields(logrus.Fields{"input": input, "value": value}).Debug("No pair found") + return 0 + } + log.WithField("result", rv).Debug("Pair extracted and converted to int") + return rv +} + +func ExtractDest(input string) string { + log.WithField("input", input).Debug("ExtractDest called") + dest := strings.Split(input, " ")[0] + log.WithField("dest", dest).Debug("Destination extracted") + return strings.Split(input, " ")[0] +} + var randSource = rand.NewSource(time.Now().UnixNano()) var randGen = rand.New(randSource) diff --git a/common/util_test.go b/common/util_test.go new file mode 100644 index 00000000..24460080 --- /dev/null +++ b/common/util_test.go @@ -0,0 +1,182 @@ +package common + +import ( + "testing" +) + +func TestExtractDest_Cases(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + name: "simple destination", + input: "dest1234 STYLE=3", + want: "dest1234", + }, + { + name: "empty input", + input: "", + want: "", + }, + { + name: "single word", + input: "destination", + want: "destination", + }, + { + name: "multiple spaces", + input: "dest123 STYLE=3 KEY=value", + want: "dest123", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ExtractDest(tt.input) + if got != tt.want { + t.Errorf("ExtractDest(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestExtractPairString_Cases(t *testing.T) { + tests := []struct { + name string + input string + key string + want string + }{ + { + name: "simple key-value", + input: "KEY=value", + key: "KEY", + want: "value", + }, + { + name: "no matching key", + input: "OTHER=value", + key: "KEY", + want: "", + }, + { + name: "multiple pairs", + input: "FIRST=1 KEY=value LAST=3", + key: "KEY", + want: "value", + }, + { + name: "empty input", + input: "", + key: "KEY", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ExtractPairString(tt.input, tt.key) + if got != tt.want { + t.Errorf("ExtractPairString(%q, %q) = %q, want %q", + tt.input, tt.key, got, tt.want) + } + }) + } +} + +func TestExtractPairInt_Cases(t *testing.T) { + tests := []struct { + name string + input string + key string + want int + }{ + { + name: "valid integer", + input: "NUM=123", + key: "NUM", + want: 123, + }, + { + name: "invalid integer", + input: "NUM=abc", + key: "NUM", + want: 0, + }, + { + name: "no matching key", + input: "OTHER=123", + key: "NUM", + want: 0, + }, + { + name: "empty input", + input: "", + key: "NUM", + want: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ExtractPairInt(tt.input, tt.key) + if got != tt.want { + t.Errorf("ExtractPairInt(%q, %q) = %d, want %d", + tt.input, tt.key, got, tt.want) + } + }) + } +} + +func TestSplitHostPort_Cases(t *testing.T) { + tests := []struct { + name string + input string + wantHost string + wantPort string + wantErr bool + }{ + { + name: "valid hostport", + input: "localhost:1234", + wantHost: "localhost", + wantPort: "1234", + wantErr: false, + }, + { + name: "missing port", + input: "localhost", + wantHost: "localhost", + wantPort: "0", + wantErr: false, + }, + { + name: "empty input", + input: "", + wantHost: "", + wantPort: "0", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + host, port, err := SplitHostPort(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("SplitHostPort(%q) error = %v, wantErr %v", + tt.input, err, tt.wantErr) + return + } + if host != tt.wantHost { + t.Errorf("SplitHostPort(%q) host = %q, want %q", + tt.input, host, tt.wantHost) + } + if port != tt.wantPort { + t.Errorf("SplitHostPort(%q) port = %q, want %q", + tt.input, port, tt.wantPort) + } + }) + } +} diff --git a/go.mod b/go.mod index a12ab78d..882bff4b 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,12 @@ go 1.23.5 require ( github.com/go-i2p/i2pkeys v0.33.92 github.com/sirupsen/logrus v1.9.3 + github.com/stretchr/testify v1.7.0 ) -require golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 // indirect +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 // indirect + gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect +) diff --git a/go.sum b/go.sum index 538ff81f..bb9f975a 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,7 @@ github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5Cc github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/stream/listener.go b/stream/listener.go index b5cfdb6f..09e7022c 100644 --- a/stream/listener.go +++ b/stream/listener.go @@ -5,7 +5,6 @@ import ( "errors" "io" "net" - "strconv" "strings" "github.com/sirupsen/logrus" @@ -38,39 +37,6 @@ func (l *StreamListener) Accept() (net.Conn, error) { return l.AcceptI2P() } -func ExtractPairString(input, value string) string { - log.WithFields(logrus.Fields{"input": input, "value": value}).Debug("ExtractPairString called") - parts := strings.Split(input, " ") - for _, part := range parts { - if strings.HasPrefix(part, value) { - kv := strings.SplitN(input, "=", 2) - if len(kv) == 2 { - log.WithFields(logrus.Fields{"key": kv[0], "value": kv[1]}).Debug("Pair extracted") - return kv[1] - } - } - } - log.WithFields(logrus.Fields{"input": input, "value": value}).Debug("No pair found") - return "" -} - -func ExtractPairInt(input, value string) int { - rv, err := strconv.Atoi(ExtractPairString(input, value)) - if err != nil { - log.WithFields(logrus.Fields{"input": input, "value": value}).Debug("No pair found") - return 0 - } - log.WithField("result", rv).Debug("Pair extracted and converted to int") - return rv -} - -func ExtractDest(input string) string { - log.WithField("input", input).Debug("ExtractDest called") - dest := strings.Split(input, " ")[0] - log.WithField("dest", dest).Debug("Destination extracted") - return strings.Split(input, " ")[0] -} - // accept a new inbound connection func (l *StreamListener) AcceptI2P() (*StreamConn, error) { log.Debug("StreamListener.AcceptI2P() called") @@ -100,9 +66,9 @@ func (l *StreamListener) AcceptI2P() (*StreamConn, error) { // we gud read destination line destline, err := rd.ReadString(10) if err == nil { - dest := ExtractDest(destline) - l.session.Fromport = ExtractPairString(destline, "FROM_PORT") - l.session.Toport = ExtractPairString(destline, "TO_PORT") + dest := common.ExtractDest(destline) + l.session.Fromport = common.ExtractPairString(destline, "FROM_PORT") + l.session.Toport = common.ExtractPairString(destline, "TO_PORT") // return wrapped connection dest = strings.Trim(dest, "\n") log.WithFields(logrus.Fields{