Use a lock to protect conn.{closed,outgoing}

Unfortunately, I don't think there's a good way to implement net.Conn
semantics on top of channels. The Close and SendMessage methods should
gracefully fail without panicking if the connection is already closed.
Using only channels leads to race conditions.

We could remove the lock if Close and SendMessage are only called from a
single goroutine. However that's not the case right now.

Closes: https://todo.sr.ht/~emersion/soju/55
This commit is contained in:
Simon Ser 2020-04-30 10:35:02 +02:00
parent e7e4311160
commit e9cebb6fe3
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48

31
conn.go
View File

@ -3,6 +3,7 @@ package soju
import ( import (
"fmt" "fmt"
"net" "net"
"sync"
"time" "time"
"gopkg.in/irc.v3" "gopkg.in/irc.v3"
@ -24,8 +25,10 @@ type conn struct {
irc *irc.Conn irc *irc.Conn
srv *Server srv *Server
logger Logger logger Logger
lock sync.Mutex
outgoing chan<- *irc.Message outgoing chan<- *irc.Message
closed chan struct{} closed bool
} }
func newConn(srv *Server, netConn net.Conn, logger Logger) *conn { func newConn(srv *Server, netConn net.Conn, logger Logger) *conn {
@ -38,7 +41,6 @@ func newConn(srv *Server, netConn net.Conn, logger Logger) *conn {
srv: srv, srv: srv,
outgoing: outgoing, outgoing: outgoing,
logger: logger, logger: logger,
closed: make(chan struct{}),
} }
go func() { go func() {
@ -68,20 +70,21 @@ func newConn(srv *Server, netConn net.Conn, logger Logger) *conn {
} }
func (c *conn) isClosed() bool { func (c *conn) isClosed() bool {
select { c.lock.Lock()
case <-c.closed: defer c.lock.Unlock()
return true return c.closed
default:
return false
}
} }
// Close closes the connection. It is safe to call from any goroutine. // Close closes the connection. It is safe to call from any goroutine.
func (c *conn) Close() error { func (c *conn) Close() error {
if c.isClosed() { c.lock.Lock()
defer c.lock.Unlock()
if c.closed {
return fmt.Errorf("connection already closed") return fmt.Errorf("connection already closed")
} }
close(c.closed)
c.closed = true
close(c.outgoing) close(c.outgoing)
return nil return nil
} }
@ -101,8 +104,14 @@ func (c *conn) ReadMessage() (*irc.Message, error) {
// SendMessage queues a new outgoing message. It is safe to call from any // SendMessage queues a new outgoing message. It is safe to call from any
// goroutine. // goroutine.
//
// If the connection is closed before the message is sent, SendMessage silently
// drops the message.
func (c *conn) SendMessage(msg *irc.Message) { func (c *conn) SendMessage(msg *irc.Message) {
if c.isClosed() { c.lock.Lock()
defer c.lock.Unlock()
if c.closed {
return return
} }
c.outgoing <- msg c.outgoing <- msg