diff --git a/conn.go b/conn.go index 82ee4eb..576d308 100644 --- a/conn.go +++ b/conn.go @@ -172,6 +172,7 @@ type conn struct { lock sync.Mutex outgoing chan<- *irc.Message closed bool + closedCh chan struct{} } func newConn(srv *Server, ic ircConn, options *connOptions) *conn { @@ -181,6 +182,7 @@ func newConn(srv *Server, ic ircConn, options *connOptions) *conn { srv: srv, outgoing: outgoing, logger: options.Logger, + closedCh: make(chan struct{}), } go func() { @@ -237,6 +239,7 @@ func (c *conn) Close() error { err := c.conn.Close() c.closed = true close(c.outgoing) + close(c.closedCh) return err } @@ -277,3 +280,28 @@ func (c *conn) RemoteAddr() net.Addr { func (c *conn) LocalAddr() net.Addr { return c.conn.LocalAddr() } + +// NewContext returns a copy of the parent context with a new Done channel. The +// returned context's Done channel is closed when the connection is closed, +// when the returned cancel function is called, or when the parent context's +// Done channel is closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete. +func (c *conn) NewContext(parent context.Context) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(parent) + + go func() { + defer cancel() + + select { + case <-ctx.Done(): + // The parent context has been cancelled, or the caller has called + // cancel() + case <-c.closedCh: + // The connection has been closed + } + }() + + return ctx, cancel +} diff --git a/downstream.go b/downstream.go index eb31a61..cce1305 100644 --- a/downstream.go +++ b/downstream.go @@ -615,7 +615,10 @@ func (dc *downstreamConn) marshalMessage(msg *irc.Message, net *network) *irc.Me } func (dc *downstreamConn) handleMessage(msg *irc.Message) error { - ctx, cancel := context.WithTimeout(context.TODO(), handleDownstreamMessageTimeout) + ctx, cancel := dc.conn.NewContext(context.TODO()) + defer cancel() + + ctx, cancel = context.WithTimeout(ctx, handleDownstreamMessageTimeout) defer cancel() switch msg.Command {