diff --git a/filter.go b/filter.go index 0f067cb..1d8eb4e 100644 --- a/filter.go +++ b/filter.go @@ -36,7 +36,7 @@ func (c *ConnFilter) Read(b []byte) (n int, err error) { // NewConnFilter creates a new ConnFilter that replaces occurrences of target strings with replacement strings in the data read from the connection. // It returns an error if the lengths of target and replacement slices are not equal. -func NewConnFilter(parentConn net.Conn, targets []string, replacements []string) (net.Conn, error) { +func NewConnFilter(parentConn net.Conn, targets, replacements []string) (net.Conn, error) { if len(targets) != len(replacements) { return nil, ErrInvalidFilter } diff --git a/functionFilter.go b/functionFilter.go new file mode 100644 index 0000000..57e2544 --- /dev/null +++ b/functionFilter.go @@ -0,0 +1,56 @@ +package filter + +import ( + "errors" + "net" +) + +var ErrInvalidFunctionFilter = errors.New("invalid Function filter") + +func noopReadFilter(b []byte) ([]byte, error) { + return b, nil +} + +func noopWriteFilter(b []byte) ([]byte, error) { + return b, nil +} + +type FunctionConnFilter struct { + net.Conn + ReadFilter func(b []byte) ([]byte, error) + WriteFilter func(b []byte) ([]byte, error) +} + +var ex net.Conn = &FunctionConnFilter{} + +// Write modifies the bytes according to c.Filter and writes the result to the underlying connection +func (c *FunctionConnFilter) Write(b []byte) (n int, err error) { + b2, err := c.WriteFilter(b) + if err != nil { + return len(b), err + } + return c.Conn.Write(b2) +} + +// Read reads data from the underlying connection and modifies the bytes according to c.Filter +func (c *FunctionConnFilter) Read(b []byte) (n int, err error) { + n, err = c.Conn.Read(b) + if err != nil { + return n, err + } + b2, err := c.ReadFilter(b) + if err != nil { + return n, err + } + copy(b, b2) + return len(b2), nil +} + +// NewFunctionConnFilter creates a new FunctionConnFilter that has the powerful ability to rewrite any byte that comes across the net.Conn with a user-defined function. By default a no-op function. +func NewFunctionConnFilter(parentConn net.Conn, Function string) (net.Conn, error) { + return &FunctionConnFilter{ + Conn: parentConn, + ReadFilter: noopReadFilter, + WriteFilter: noopWriteFilter, + }, nil +} diff --git a/http/httpinspector_test.go b/http/httpinspector_test.go index 75a1ea1..590ea99 100644 --- a/http/httpinspector_test.go +++ b/http/httpinspector_test.go @@ -52,7 +52,6 @@ func TestInspector(t *testing.T) { t.Error("request modification not applied") } }) - } // Mock implementations for testing diff --git a/irc/example/example.go b/irc/example/example.go new file mode 100644 index 0000000..b8978f7 --- /dev/null +++ b/irc/example/example.go @@ -0,0 +1,44 @@ +package main + +import ( + "log" + "net" + + ircinspector "github.com/go-i2p/go-connfilter/irc" +) + +func main() { + listener, err := net.Listen("tcp", ":6667") + if err != nil { + log.Fatal(err) + } + + inspector := ircinspector.New(listener, ircinspector.Config{ + OnMessage: func(msg *ircinspector.Message) error { + log.Printf("Received message: %s", msg.Raw) + return nil + }, + OnNumeric: func(numeric int, msg *ircinspector.Message) error { + log.Printf("Received numeric response: %d", numeric) + return nil + }, + }) + + inspector.AddFilter(ircinspector.Filter{ + Command: "PRIVMSG", + Channel: "#mychannel", + Callback: func(msg *ircinspector.Message) error { + msg.Trailing = "[modified] " + msg.Trailing + return nil + }, + }) + + for { + conn, err := inspector.Accept() + if err != nil { + log.Printf("Accept error: %v", err) + continue + } + go ircinspector.HandleConnection(conn) + } +} diff --git a/irc/ircinspector.go b/irc/ircinspector.go new file mode 100644 index 0000000..4f063b3 --- /dev/null +++ b/irc/ircinspector.go @@ -0,0 +1,178 @@ +package ircinspector + +import ( + "bufio" + "fmt" + "log" + "net" + "strings" +) + +type defaultLogger struct{} + +// Debug implements Logger. +func (d *defaultLogger) Debug(format string, args ...interface{}) { + log.Printf("DBG:"+format, args) +} + +// Error implements Logger. +func (d *defaultLogger) Error(format string, args ...interface{}) { + log.Printf("ERR:"+format, args) +} + +// New creates a new IRC inspector wrapping an existing listener +func New(listener net.Listener, config Config) *Inspector { + if config.Logger == nil { + config.Logger = &defaultLogger{} + } + + return &Inspector{ + listener: listener, + config: config, + filters: make([]Filter, 0), + } +} + +// Accept implements net.Listener Accept method +func (i *Inspector) Accept() (net.Conn, error) { + conn, err := i.listener.Accept() + if err != nil { + return nil, err + } + + return &ircConn{ + Conn: conn, + inspector: i, + }, nil +} + +// Close implements net.Listener Close method +func (i *Inspector) Close() error { + return i.listener.Close() +} + +// Addr implements net.Listener Addr method +func (i *Inspector) Addr() net.Addr { + return i.listener.Addr() +} + +type ircConn struct { + net.Conn + inspector *Inspector + reader *bufio.Reader + writer *bufio.Writer +} + +func (c *ircConn) Read(b []byte) (n int, err error) { + if c.reader == nil { + c.reader = bufio.NewReader(c.Conn) + } + + line, err := c.reader.ReadString('\n') + if err != nil { + return 0, err + } + + msg, err := parseMessage(line) + if err != nil { + c.inspector.config.Logger.Error("parse error: %v", err) + copy(b, line) + return len(line), nil + } + + if err := c.inspector.processMessage(msg); err != nil { + c.inspector.config.Logger.Error("process error: %v", err) + } + + modified := msg.String() + copy(b, modified) + return len(modified), nil +} + +func (c *ircConn) Write(b []byte) (n int, err error) { + if c.writer == nil { + c.writer = bufio.NewWriter(c.Conn) + } + + msg, err := parseMessage(string(b)) + if err != nil { + return c.writer.Write(b) + } + + if err := c.inspector.processMessage(msg); err != nil { + c.inspector.config.Logger.Error("process error: %v", err) + } + + return c.writer.Write([]byte(msg.String())) +} + +func parseMessage(raw string) (*Message, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil, fmt.Errorf("empty message") + } + + msg := &Message{Raw: raw} + + if raw[0] == ':' { + parts := strings.SplitN(raw[1:], " ", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid message format") + } + msg.Prefix = parts[0] + raw = parts[1] + } + + parts := strings.SplitN(raw, " :", 2) + if len(parts) > 1 { + msg.Trailing = parts[1] + } + + words := strings.Fields(parts[0]) + if len(words) == 0 { + return nil, fmt.Errorf("no command found") + } + + msg.Command = words[0] + if len(words) > 1 { + msg.Params = words[1:] + } + + return msg, nil +} + +func (i *Inspector) AddFilter(filter Filter) { + i.mu.Lock() + defer i.mu.Unlock() + i.filters = append(i.filters, filter) +} + +func (i *Inspector) processMessage(msg *Message) error { + i.mu.RLock() + defer i.mu.RUnlock() + + // Process global message handler + if i.config.OnMessage != nil { + if err := i.config.OnMessage(msg); err != nil { + return err + } + } + + // Process numeric responses + if numeric, err := parseNumeric(msg.Command); err == nil && i.config.OnNumeric != nil { + if err := i.config.OnNumeric(numeric, msg); err != nil { + return err + } + } + + // Process filters + for _, filter := range i.filters { + if matchesFilter(msg, filter) { + if err := filter.Callback(msg); err != nil { + return err + } + } + } + + return nil +} diff --git a/irc/types.go b/irc/types.go new file mode 100644 index 0000000..424d332 --- /dev/null +++ b/irc/types.go @@ -0,0 +1,44 @@ +package ircinspector + +import ( + "net" + "sync" +) + +// Message represents a parsed IRC message +type Message struct { + Raw string + Prefix string + Command string + Params []string + Trailing string +} + +// Filter defines criteria for message filtering +type Filter struct { + Command string + Channel string + Prefix string + Callback func(*Message) error +} + +// Config contains inspector configuration +type Config struct { + OnMessage func(*Message) error + OnNumeric func(int, *Message) error + Logger Logger +} + +// Logger interface for customizable logging +type Logger interface { + Debug(format string, args ...interface{}) + Error(format string, args ...interface{}) +} + +// Inspector implements the net.Listener interface with IRC inspection +type Inspector struct { + listener net.Listener + config Config + filters []Filter + mu sync.RWMutex +} diff --git a/regexFilter.go b/regexFilter.go new file mode 100644 index 0000000..b471b0c --- /dev/null +++ b/regexFilter.go @@ -0,0 +1,34 @@ +package filter + +import ( + "bytes" + "errors" + "net" +) + +var ErrInvalidRegexFilter = errors.New("invalid regex filter") + +type RegexConnFilter struct { + net.Conn + match string +} + +// Read reads data from the underlying connection and replaces all occurrences of target regex +// with empty strings. The modified data is then copied back to the provided buffer. +func (c *RegexConnFilter) Read(b []byte) (n int, err error) { + n, err = c.Conn.Read(b) + if err != nil { + return n, err + } + // Replace all occurrences of regex `match` with nothing + var buffer bytes.Buffer + return buffer.Len(), nil +} + +// NewRegexConnFilter creates a new RegexConnFilter that replaces occurrences of target regex with empty strings in the data read from the connection. +// It returns an error if the lengths of target and replacement slices are not equal. +func NewRegexConnFilter(parentConn net.Conn, regex string) (net.Conn, error) { + return &RegexConnFilter{ + Conn: parentConn, + }, nil +}