diff --git a/upstream.go b/upstream.go index 5c6c9e4..c6e34bc 100644 --- a/upstream.go +++ b/upstream.go @@ -36,6 +36,7 @@ type upstreamConn struct { srv *Server user *user outgoing chan<- *irc.Message + closed chan struct{} serverName string availableUserModes string @@ -47,7 +48,6 @@ type upstreamConn struct { nick string username string realname string - closed bool modes userModes channels map[string]*upstreamChannel caps map[string]string @@ -95,12 +95,21 @@ func connectToUpstream(network *network) (*upstreamConn, error) { } go func() { - for msg := range outgoing { - if uc.srv.Debug { - uc.logger.Printf("sent: %v", msg) + for { + var closed bool + select { + case msg := <-outgoing: + if uc.srv.Debug { + uc.logger.Printf("sent: %v", msg) + } + if err := uc.irc.WriteMessage(msg); err != nil { + uc.logger.Printf("failed to write message: %v", err) + } + case <-uc.closed: + closed = true } - if err := uc.irc.WriteMessage(msg); err != nil { - uc.logger.Printf("failed to write message: %v", err) + if closed { + break } } if err := uc.net.Close(); err != nil { @@ -113,12 +122,20 @@ func connectToUpstream(network *network) (*upstreamConn, error) { return uc, nil } +func (uc *upstreamConn) isClosed() bool { + select { + case <-uc.closed: + return true + default: + return false + } +} + func (uc *upstreamConn) Close() error { - if uc.closed { + if uc.isClosed() { return fmt.Errorf("upstream connection already closed") } - close(uc.outgoing) - uc.closed = true + close(uc.closed) return nil } diff --git a/user.go b/user.go index 804a5e9..8c19079 100644 --- a/user.go +++ b/user.go @@ -64,6 +64,9 @@ func (net *network) run() { uc.register() + // TODO: wait for the connection to be registered before adding it to + // net, otherwise messages might be sent to it while still being + // unauthenticated net.lock.Lock() net.conn = uc net.lock.Unlock() @@ -112,7 +115,7 @@ func (u *user) forEachNetwork(f func(*network)) { func (u *user) forEachUpstream(f func(uc *upstreamConn)) { for _, network := range u.networks { uc := network.upstream() - if uc == nil || !uc.registered || uc.closed { + if uc == nil || !uc.registered { continue } f(uc) @@ -152,7 +155,7 @@ func (u *user) run() { switch e := e.(type) { case eventUpstreamMessage: msg, uc := e.msg, e.uc - if uc.closed { + if uc.isClosed() { uc.logger.Printf("ignoring message on closed connection: %v", msg) break }