diff --git a/conn.go b/conn.go index 89f0811..f13038d 100644 --- a/conn.go +++ b/conn.go @@ -3,6 +3,7 @@ package soju import ( "fmt" "net" + "sync" "time" "gopkg.in/irc.v3" @@ -20,12 +21,14 @@ func setKeepAlive(c net.Conn) error { } type conn struct { - net net.Conn - irc *irc.Conn - srv *Server - logger Logger + net net.Conn + irc *irc.Conn + srv *Server + logger Logger + + lock sync.Mutex outgoing chan<- *irc.Message - closed chan struct{} + closed bool } func newConn(srv *Server, netConn net.Conn, logger Logger) *conn { @@ -38,7 +41,6 @@ func newConn(srv *Server, netConn net.Conn, logger Logger) *conn { srv: srv, outgoing: outgoing, logger: logger, - closed: make(chan struct{}), } go func() { @@ -68,20 +70,21 @@ func newConn(srv *Server, netConn net.Conn, logger Logger) *conn { } func (c *conn) isClosed() bool { - select { - case <-c.closed: - return true - default: - return false - } + c.lock.Lock() + defer c.lock.Unlock() + return c.closed } // Close closes the connection. It is safe to call from any goroutine. func (c *conn) Close() error { - if c.isClosed() { + c.lock.Lock() + defer c.lock.Unlock() + + if c.closed { return fmt.Errorf("connection already closed") } - close(c.closed) + + c.closed = true close(c.outgoing) return nil } @@ -101,8 +104,14 @@ func (c *conn) ReadMessage() (*irc.Message, error) { // SendMessage queues a new outgoing message. It is safe to call from any // goroutine. +// +// If the connection is closed before the message is sent, SendMessage silently +// drops the message. func (c *conn) SendMessage(msg *irc.Message) { - if c.isClosed() { + c.lock.Lock() + defer c.lock.Unlock() + + if c.closed { return } c.outgoing <- msg