diff --git a/downstream.go b/downstream.go index 7844b49..0552e4e 100644 --- a/downstream.go +++ b/downstream.go @@ -304,14 +304,7 @@ func (dc *downstreamConn) Close() error { } if u := dc.user; u != nil { - u.lock.Lock() - for i := range u.downstreamConns { - if u.downstreamConns[i] == dc { - u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...) - break - } - } - u.lock.Unlock() + u.removeDownstream(dc) } close(dc.closed) @@ -660,10 +653,7 @@ func (dc *downstreamConn) register() error { dc.username = dc.user.Username dc.logger.Printf("registration complete for user %q", dc.username) - dc.user.lock.Lock() - firstDownstream := len(dc.user.downstreamConns) == 0 - dc.user.downstreamConns = append(dc.user.downstreamConns, dc) - dc.user.lock.Unlock() + firstDownstream := dc.user.addDownstream(dc) dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), @@ -707,6 +697,7 @@ func (dc *downstreamConn) register() error { historyName := dc.rawUsername + // TODO: need to take dc.network into account here var seqPtr *uint64 if firstDownstream { uc.network.lock.Lock() @@ -717,6 +708,7 @@ func (dc *downstreamConn) register() error { } } + // TODO: we need to create a consumer when adding networks on-the-fly consumer, ch := uc.ring.NewConsumer(seqPtr) go func() { for { @@ -734,6 +726,7 @@ func (dc *downstreamConn) register() error { seq := consumer.Close() + // TODO: need to take dc.network into account here dc.user.lock.Lock() lastDownstream := len(dc.user.downstreamConns) == 0 dc.user.lock.Unlock() diff --git a/user.go b/user.go index 629e0a1..245664c 100644 --- a/user.go +++ b/user.go @@ -176,6 +176,25 @@ func (u *user) run() { } } +func (u *user) addDownstream(dc *downstreamConn) (first bool) { + u.lock.Lock() + first = len(dc.user.downstreamConns) == 0 + u.downstreamConns = append(u.downstreamConns, dc) + u.lock.Unlock() + return first +} + +func (u *user) removeDownstream(dc *downstreamConn) { + u.lock.Lock() + for i := range u.downstreamConns { + if u.downstreamConns[i] == dc { + u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...) + break + } + } + u.lock.Unlock() +} + func (u *user) createNetwork(net *Network) (*network, error) { network := newNetwork(u, net) err := u.srv.db.StoreNetwork(u.Username, &network.Network)