From 3397965dea809d8f563dec1fc8e662d1ef91c98b Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Wed, 1 Jul 2020 17:02:37 +0200 Subject: [PATCH] Add RemoteAddr to ircConn interface --- conn.go | 20 ++++++++++++++++++-- downstream.go | 3 ++- server.go | 8 ++++---- 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/conn.go b/conn.go index d1cca65..af384a7 100644 --- a/conn.go +++ b/conn.go @@ -20,6 +20,7 @@ type ircConn interface { Close() error SetReadDeadline(time.Time) error SetWriteDeadline(time.Time) error + RemoteAddr() net.Addr } func newNetIRCConn(c net.Conn) ircConn { @@ -33,10 +34,11 @@ func newNetIRCConn(c net.Conn) ircConn { type websocketIRCConn struct { conn *websocket.Conn readDeadline, writeDeadline time.Time + remoteAddr string } -func newWebsocketIRCConn(c *websocket.Conn) ircConn { - return websocketIRCConn{conn: c} +func newWebsocketIRCConn(c *websocket.Conn, remoteAddr string) ircConn { + return websocketIRCConn{conn: c, remoteAddr: remoteAddr} } func (wic websocketIRCConn) ReadMessage() (*irc.Message, error) { @@ -83,6 +85,20 @@ func (wic websocketIRCConn) SetWriteDeadline(t time.Time) error { return nil } +func (wic websocketIRCConn) RemoteAddr() net.Addr { + return websocketAddr(wic.remoteAddr) +} + +type websocketAddr string + +func (websocketAddr) Network() string { + return "ws" +} + +func (wa websocketAddr) String() string { + return string(wa) +} + type conn struct { conn ircConn srv *Server diff --git a/downstream.go b/downstream.go index 007e96c..9cca6c0 100644 --- a/downstream.go +++ b/downstream.go @@ -99,7 +99,8 @@ type downstreamConn struct { saslServer sasl.Server } -func newDownstreamConn(srv *Server, ic ircConn, remoteAddr string, id uint64) *downstreamConn { +func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn { + remoteAddr := ic.RemoteAddr().String() logger := &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", remoteAddr)} dc := &downstreamConn{ conn: *newConn(srv, ic, logger), diff --git a/server.go b/server.go index 0f26fb8..91a269c 100644 --- a/server.go +++ b/server.go @@ -117,9 +117,9 @@ func (s *Server) getUser(name string) *user { var lastDownstreamID uint64 = 0 -func (s *Server) handle(ic ircConn, remoteAddr string) { +func (s *Server) handle(ic ircConn) { id := atomic.AddUint64(&lastDownstreamID, 1) - dc := newDownstreamConn(s, ic, remoteAddr, id) + dc := newDownstreamConn(s, ic, id) if err := dc.runUntilRegistered(); err != nil { dc.logger.Print(err) } else { @@ -139,7 +139,7 @@ func (s *Server) Serve(ln net.Listener) error { return fmt.Errorf("failed to accept connection: %v", err) } - go s.handle(newNetIRCConn(conn), conn.RemoteAddr().String()) + go s.handle(newNetIRCConn(conn)) } } @@ -168,5 +168,5 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { remoteAddr = net.JoinHostPort(forwardedHost, forwardedPort) } - s.handle(newWebsocketIRCConn(conn), remoteAddr) + s.handle(newWebsocketIRCConn(conn, remoteAddr)) }