diff --git a/downstream.go b/downstream.go index d106c3d..3166b1a 100644 --- a/downstream.go +++ b/downstream.go @@ -660,7 +660,9 @@ func (dc *downstreamConn) register() error { dc.username = dc.user.Username dc.logger.Printf("registration complete for user %q", dc.username) - firstDownstream := dc.user.addDownstream(dc) + dc.user.lock.Lock() + firstDownstream := len(dc.user.downstreamConns) == 0 + dc.user.lock.Unlock() dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), diff --git a/server.go b/server.go index 31589a9..3666c53 100644 --- a/server.go +++ b/server.go @@ -119,6 +119,7 @@ func (s *Server) Serve(ln net.Listener) error { if err := dc.runUntilRegistered(); err != nil { dc.logger.Print(err) } else { + dc.user.events <- eventDownstreamConnected{dc} if err := dc.readMessages(dc.user.events); err != nil { dc.logger.Print(err) } diff --git a/user.go b/user.go index 0d2b253..11f972c 100644 --- a/user.go +++ b/user.go @@ -19,6 +19,10 @@ type eventDownstreamMessage struct { dc *downstreamConn } +type eventDownstreamConnected struct { + dc *downstreamConn +} + type network struct { Network user *user @@ -160,6 +164,11 @@ func (u *user) run() { if err := uc.handleMessage(msg); err != nil { uc.logger.Printf("failed to handle message %q: %v", msg, err) } + case eventDownstreamConnected: + dc := e.dc + u.lock.Lock() + u.downstreamConns = append(u.downstreamConns, dc) + u.lock.Unlock() case eventDownstreamMessage: msg, dc := e.msg, e.dc if dc.isClosed() { @@ -180,14 +189,6 @@ func (u *user) run() { } } -func (u *user) addDownstream(dc *downstreamConn) (first bool) { - u.lock.Lock() - first = len(dc.user.downstreamConns) == 0 - u.downstreamConns = append(u.downstreamConns, dc) - u.lock.Unlock() - return first -} - func (u *user) removeDownstream(dc *downstreamConn) { u.lock.Lock() for i := range u.downstreamConns {