From 2a0696b6bb13f4666100c1c031e024b80df4ed00 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Fri, 3 Apr 2020 16:34:11 +0200 Subject: [PATCH] Introduce conn for common connection logic This centralizes the common upstream & downstream bits. --- conn.go | 109 ++++++++++++++++++++++++++++++++++++++++++++++++++ downstream.go | 79 ++++-------------------------------- server.go | 13 ------ upstream.go | 75 +++------------------------------- 4 files changed, 122 insertions(+), 154 deletions(-) create mode 100644 conn.go diff --git a/conn.go b/conn.go new file mode 100644 index 0000000..89f0811 --- /dev/null +++ b/conn.go @@ -0,0 +1,109 @@ +package soju + +import ( + "fmt" + "net" + "time" + + "gopkg.in/irc.v3" +) + +func setKeepAlive(c net.Conn) error { + tcpConn, ok := c.(*net.TCPConn) + if !ok { + return fmt.Errorf("cannot enable keep-alive on a non-TCP connection") + } + if err := tcpConn.SetKeepAlive(true); err != nil { + return err + } + return tcpConn.SetKeepAlivePeriod(keepAlivePeriod) +} + +type conn struct { + net net.Conn + irc *irc.Conn + srv *Server + logger Logger + outgoing chan<- *irc.Message + closed chan struct{} +} + +func newConn(srv *Server, netConn net.Conn, logger Logger) *conn { + setKeepAlive(netConn) + + outgoing := make(chan *irc.Message, 64) + c := &conn{ + net: netConn, + irc: irc.NewConn(netConn), + srv: srv, + outgoing: outgoing, + logger: logger, + closed: make(chan struct{}), + } + + go func() { + for msg := range outgoing { + if c.srv.Debug { + c.logger.Printf("sent: %v", msg) + } + c.net.SetWriteDeadline(time.Now().Add(writeTimeout)) + if err := c.irc.WriteMessage(msg); err != nil { + c.logger.Printf("failed to write message: %v", err) + break + } + } + if err := c.net.Close(); err != nil { + c.logger.Printf("failed to close connection: %v", err) + } else { + c.logger.Printf("connection closed") + } + // Drain the outgoing channel to prevent SendMessage from blocking + for range outgoing { + // This space is intentionally left blank + } + }() + + c.logger.Printf("new connection") + return c +} + +func (c *conn) isClosed() bool { + select { + case <-c.closed: + return true + default: + return false + } +} + +// Close closes the connection. It is safe to call from any goroutine. +func (c *conn) Close() error { + if c.isClosed() { + return fmt.Errorf("connection already closed") + } + close(c.closed) + close(c.outgoing) + return nil +} + +func (c *conn) ReadMessage() (*irc.Message, error) { + msg, err := c.irc.ReadMessage() + if err != nil { + return nil, err + } + + if c.srv.Debug { + c.logger.Printf("received: %v", msg) + } + + return msg, nil +} + +// SendMessage queues a new outgoing message. It is safe to call from any +// goroutine. +func (c *conn) SendMessage(msg *irc.Message) { + if c.isClosed() { + return + } + c.outgoing <- msg +} diff --git a/downstream.go b/downstream.go index 96deb05..17b1d80 100644 --- a/downstream.go +++ b/downstream.go @@ -52,13 +52,9 @@ var errAuthFailed = ircError{&irc.Message{ }} type downstreamConn struct { - id uint64 - net net.Conn - irc *irc.Conn - srv *Server - logger Logger - outgoing chan<- *irc.Message - closed chan struct{} + conn + + id uint64 registered bool user *user @@ -84,15 +80,10 @@ type downstreamConn struct { } func newDownstreamConn(srv *Server, netConn net.Conn, id uint64) *downstreamConn { - outgoing := make(chan *irc.Message, 64) + logger := &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())} dc := &downstreamConn{ + conn: *newConn(srv, netConn, logger), id: id, - net: netConn, - irc: irc.NewConn(netConn), - srv: srv, - logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())}, - outgoing: outgoing, - closed: make(chan struct{}), ringConsumers: make(map[*network]*RingConsumer), caps: make(map[string]bool), ourMessages: make(map[*irc.Message]struct{}), @@ -101,30 +92,6 @@ func newDownstreamConn(srv *Server, netConn net.Conn, id uint64) *downstreamConn if host, _, err := net.SplitHostPort(dc.hostname); err == nil { dc.hostname = host } - - go func() { - for msg := range outgoing { - if dc.srv.Debug { - dc.logger.Printf("sent: %v", msg) - } - dc.net.SetWriteDeadline(time.Now().Add(writeTimeout)) - if err := dc.irc.WriteMessage(msg); err != nil { - dc.logger.Printf("failed to write message: %v", err) - break - } - } - if err := dc.net.Close(); err != nil { - dc.logger.Printf("failed to close connection: %v", err) - } else { - dc.logger.Printf("connection closed") - } - // Drain the outgoing channel to prevent SendMessage from blocking - for range outgoing { - // This space is intentionally left blank - } - }() - - dc.logger.Printf("new connection") return dc } @@ -227,56 +194,24 @@ func (dc *downstreamConn) marshalUserPrefix(uc *upstreamConn, prefix *irc.Prefix } } -func (dc *downstreamConn) isClosed() bool { - select { - case <-dc.closed: - return true - default: - return false - } -} - func (dc *downstreamConn) readMessages(ch chan<- event) error { for { - msg, err := dc.irc.ReadMessage() + msg, err := dc.ReadMessage() if err == io.EOF { break } else if err != nil { return fmt.Errorf("failed to read IRC command: %v", err) } - if dc.srv.Debug { - dc.logger.Printf("received: %v", msg) - } - ch <- eventDownstreamMessage{msg, dc} } return nil } -func (dc *downstreamConn) writeMessages() error { - return nil -} - -// Close closes the connection. It is safe to call from any goroutine. -func (dc *downstreamConn) Close() error { - if dc.isClosed() { - return fmt.Errorf("downstream connection already closed") - } - close(dc.closed) - close(dc.outgoing) - return nil -} - -// SendMessage queues a new outgoing message. It is safe to call from any -// goroutine. func (dc *downstreamConn) SendMessage(msg *irc.Message) { - if dc.isClosed() { - return - } // TODO: strip tags if the client doesn't support them (see runNetwork) - dc.outgoing <- msg + dc.conn.SendMessage(msg) } func (dc *downstreamConn) handleMessage(msg *irc.Message) error { diff --git a/server.go b/server.go index 0389c26..135c4d2 100644 --- a/server.go +++ b/server.go @@ -16,17 +16,6 @@ var retryConnectMinDelay = time.Minute var connectTimeout = 15 * time.Second var writeTimeout = 10 * time.Second -func setKeepAlive(c net.Conn) error { - tcpConn, ok := c.(*net.TCPConn) - if !ok { - return fmt.Errorf("cannot enable keep-alive on a non-TCP connection") - } - if err := tcpConn.SetKeepAlive(true); err != nil { - return err - } - return tcpConn.SetKeepAlivePeriod(keepAlivePeriod) -} - type Logger interface { Print(v ...interface{}) Printf(format string, v ...interface{}) @@ -109,8 +98,6 @@ func (s *Server) Serve(ln net.Listener) error { return fmt.Errorf("failed to accept connection: %v", err) } - setKeepAlive(netConn) - dc := newDownstreamConn(s, netConn, nextDownstreamID) nextDownstreamID++ go func() { diff --git a/upstream.go b/upstream.go index 501d769..24d2173 100644 --- a/upstream.go +++ b/upstream.go @@ -31,14 +31,10 @@ type upstreamChannel struct { } type upstreamConn struct { - network *network - logger Logger - net net.Conn - irc *irc.Conn - srv *Server - user *user - outgoing chan<- *irc.Message - closed chan struct{} + conn + + network *network + user *user serverName string availableUserModes string @@ -90,18 +86,10 @@ func connectToUpstream(network *network) (*upstreamConn, error) { return nil, fmt.Errorf("failed to dial %q: %v", addr, err) } - setKeepAlive(netConn) - - outgoing := make(chan *irc.Message, 64) uc := &upstreamConn{ + conn: *newConn(network.user.srv, netConn, logger), network: network, - logger: logger, - net: netConn, - irc: irc.NewConn(netConn), - srv: network.user.srv, user: network.user, - outgoing: outgoing, - closed: make(chan struct{}), channels: make(map[string]*upstreamChannel), caps: make(map[string]string), batches: make(map[string]batch), @@ -112,50 +100,9 @@ func connectToUpstream(network *network) (*upstreamConn, error) { logs: make(map[string]entityLog), } - go func() { - for msg := range outgoing { - if uc.srv.Debug { - uc.logger.Printf("sent: %v", msg) - } - uc.net.SetWriteDeadline(time.Now().Add(writeTimeout)) - if err := uc.irc.WriteMessage(msg); err != nil { - uc.logger.Printf("failed to write message: %v", err) - break - } - } - if err := uc.net.Close(); err != nil { - uc.logger.Printf("failed to close connection: %v", err) - } else { - uc.logger.Printf("connection closed") - } - // Drain the outgoing channel to prevent SendMessage from blocking - for range outgoing { - // This space is intentionally left blank - } - }() - return uc, nil } -func (uc *upstreamConn) isClosed() bool { - select { - case <-uc.closed: - return true - default: - return false - } -} - -// Close closes the connection. It is safe to call from any goroutine. -func (uc *upstreamConn) Close() error { - if uc.isClosed() { - return fmt.Errorf("upstream connection already closed") - } - close(uc.closed) - close(uc.outgoing) - return nil -} - func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) { uc.user.forEachDownstream(func(dc *downstreamConn) { if dc.network != nil && dc.network != uc.network { @@ -1409,29 +1356,19 @@ func (uc *upstreamConn) handleCapAck(name string, ok bool) error { func (uc *upstreamConn) readMessages(ch chan<- event) error { for { - msg, err := uc.irc.ReadMessage() + msg, err := uc.ReadMessage() if err == io.EOF { break } else if err != nil { return fmt.Errorf("failed to read IRC command: %v", err) } - if uc.srv.Debug { - uc.logger.Printf("received: %v", msg) - } - ch <- eventUpstreamMessage{msg, uc} } return nil } -// SendMessage queues a new outgoing message. It is safe to call from any -// goroutine. -func (uc *upstreamConn) SendMessage(msg *irc.Message) { - uc.outgoing <- msg -} - func (uc *upstreamConn) SendMessageLabeled(downstreamID uint64, msg *irc.Message) { if uc.labelsSupported { if msg.Tags == nil {