diff --git a/downstream.go b/downstream.go index 228502e..4b33240 100644 --- a/downstream.go +++ b/downstream.go @@ -115,12 +115,21 @@ func (c *downstreamConn) Close() error { if c.closed { return fmt.Errorf("downstream connection already closed") } - if err := c.net.Close(); err != nil { - return err + + if u := c.user; u != nil { + u.lock.Lock() + for i := range u.downstreamConns { + if u.downstreamConns[i] == c { + u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...) + } + } + u.lock.Unlock() } + close(c.messages) c.closed = true - return nil + + return c.net.Close() } func (c *downstreamConn) handleMessage(msg *irc.Message) error { @@ -182,6 +191,10 @@ func (c *downstreamConn) register() error { c.registered = true c.user = u + u.lock.Lock() + u.downstreamConns = append(u.downstreamConns, c) + u.lock.Unlock() + c.messages <- &irc.Message{ Prefix: c.srv.prefix(), Command: irc.RPL_WELCOME, diff --git a/server.go b/server.go index ec30a38..3b90fb8 100644 --- a/server.go +++ b/server.go @@ -35,8 +35,9 @@ type user struct { username string srv *Server - lock sync.Mutex - upstreamConns []*upstreamConn + lock sync.Mutex + upstreamConns []*upstreamConn + downstreamConns []*downstreamConn } func (u *user) forEachUpstream(f func(uc *upstreamConn)) { @@ -50,6 +51,14 @@ func (u *user) forEachUpstream(f func(uc *upstreamConn)) { u.lock.Unlock() } +func (u *user) forEachDownstream(f func(dc *downstreamConn)) { + u.lock.Lock() + for _, dc := range u.downstreamConns { + f(dc) + } + u.lock.Unlock() +} + type Upstream struct { Addr string Nick string diff --git a/upstream.go b/upstream.go index b5200c0..d87a04e 100644 --- a/upstream.go +++ b/upstream.go @@ -125,11 +125,9 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error { return err } - c.srv.lock.Lock() - for _, dc := range c.srv.downstreamConns { + c.user.forEachDownstream(func(dc *downstreamConn) { dc.messages <- msg - } - c.srv.lock.Unlock() + }) } case "NOTICE": c.logger.Print(msg) @@ -174,11 +172,9 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error { } } - c.srv.lock.Lock() - for _, dc := range c.srv.downstreamConns { + c.user.forEachDownstream(func(dc *downstreamConn) { dc.messages <- msg - } - c.srv.lock.Unlock() + }) case "PART": if len(msg.Params) < 1 { return newNeedMoreParamsError(msg.Command) @@ -197,11 +193,9 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error { } } - c.srv.lock.Lock() - for _, dc := range c.srv.downstreamConns { + c.user.forEachDownstream(func(dc *downstreamConn) { dc.messages <- msg - } - c.srv.lock.Unlock() + }) case irc.RPL_TOPIC, irc.RPL_NOTOPIC: if len(msg.Params) < 3 { return newNeedMoreParamsError(msg.Command) @@ -275,17 +269,13 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error { } ch.complete = true - c.srv.lock.Lock() - for _, dc := range c.srv.downstreamConns { + c.user.forEachDownstream(func(dc *downstreamConn) { forwardChannel(dc, ch) - } - c.srv.lock.Unlock() + }) case "PRIVMSG": - c.srv.lock.Lock() - for _, dc := range c.srv.downstreamConns { + c.user.forEachDownstream(func(dc *downstreamConn) { dc.messages <- msg - } - c.srv.lock.Unlock() + }) case irc.RPL_YOURHOST, irc.RPL_CREATED: // Ignore case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME: