diff --git a/downstream.go b/downstream.go index f3e0fca..1496ce2 100644 --- a/downstream.go +++ b/downstream.go @@ -71,6 +71,8 @@ type downstreamConn struct { password string // empty after authentication network *network // can be nil + ringConsumers map[*network]*RingConsumer + negociatingCaps bool capVersion int caps map[string]bool @@ -83,15 +85,16 @@ type downstreamConn struct { func newDownstreamConn(srv *Server, netConn net.Conn, id uint64) *downstreamConn { dc := &downstreamConn{ - id: id, - net: netConn, - irc: irc.NewConn(netConn), - srv: srv, - logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())}, - outgoing: make(chan *irc.Message, 64), - closed: make(chan struct{}), - caps: make(map[string]bool), - ourMessages: make(map[*irc.Message]struct{}), + id: id, + net: netConn, + irc: irc.NewConn(netConn), + srv: srv, + logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())}, + outgoing: make(chan *irc.Message, 64), + closed: make(chan struct{}), + ringConsumers: make(map[*network]*RingConsumer), + caps: make(map[string]bool), + ourMessages: make(map[*irc.Message]struct{}), } dc.hostname = netConn.RemoteAddr().String() if host, _, err := net.SplitHostPort(dc.hostname); err == nil { @@ -722,9 +725,7 @@ func (dc *downstreamConn) runNetwork(net *network, loadHistory bool) { var seqPtr *uint64 if loadHistory { - net.lock.Lock() seq, ok := net.history[dc.clientName] - net.lock.Unlock() if ok { seqPtr = &seq } @@ -735,79 +736,63 @@ func (dc *downstreamConn) runNetwork(net *network, loadHistory bool) { serverTimeEnabled := dc.caps["server-time"] consumer, ch := net.ring.NewConsumer(seqPtr) + + if _, ok := dc.ringConsumers[net]; ok { + panic("network has been added twice") + } + dc.ringConsumers[net] = consumer + go func() { - for { - var closed bool - select { - case _, ok := <-ch: - if !ok { - closed = true + for range ch { + uc := net.upstream() + if uc == nil { + dc.logger.Printf("ignoring messages for upstream %q: upstream is disconnected", net.Addr) + continue + } + + for { + msg := consumer.Peek() + if msg == nil { break } - uc := net.upstream() - if uc == nil { - dc.logger.Printf("ignoring messages for upstream %q: upstream is disconnected", net.Addr) - break + dc.lock.Lock() + _, ours := dc.ourMessages[msg] + delete(dc.ourMessages, msg) + dc.lock.Unlock() + if ours { + // The message comes from our connection, don't echo it + // back + consumer.Consume() + continue } - for { - msg := consumer.Peek() - if msg == nil { - break - } + msg = msg.Copy() + switch msg.Command { + case "PRIVMSG": + msg.Prefix = dc.marshalUserPrefix(uc, msg.Prefix) + msg.Params[0] = dc.marshalEntity(uc, msg.Params[0]) + default: + panic("expected to consume a PRIVMSG message") + } - dc.lock.Lock() - _, ours := dc.ourMessages[msg] - delete(dc.ourMessages, msg) - dc.lock.Unlock() - if ours { - // The message comes from our connection, don't echo it - // back - consumer.Consume() - continue - } - - msg = msg.Copy() - switch msg.Command { - case "PRIVMSG": - msg.Prefix = dc.marshalUserPrefix(uc, msg.Prefix) - msg.Params[0] = dc.marshalEntity(uc, msg.Params[0]) - default: - panic("expected to consume a PRIVMSG message") - } - - if !msgTagsEnabled { - for name := range msg.Tags { - supported := false - switch name { - case "time": - supported = serverTimeEnabled - } - if !supported { - delete(msg.Tags, name) - } + if !msgTagsEnabled { + for name := range msg.Tags { + supported := false + switch name { + case "time": + supported = serverTimeEnabled + } + if !supported { + delete(msg.Tags, name) } } - - dc.SendMessage(msg) - consumer.Consume() } - case <-dc.closed: - closed = true - } - if closed { - break + + dc.SendMessage(msg) + consumer.Consume() } } - - // TODO: close the consumer from the user goroutine, so we don't need - // that net.history lock - seq := consumer.Close() - - net.lock.Lock() - net.history[dc.clientName] = seq - net.lock.Unlock() }() } diff --git a/user.go b/user.go index d335e52..af4d97f 100644 --- a/user.go +++ b/user.go @@ -41,9 +41,10 @@ type network struct { ring *Ring stopped chan struct{} - lock sync.Mutex - conn *upstreamConn history map[string]uint64 + + lock sync.Mutex + conn *upstreamConn } func newNetwork(user *user, record *Network) *network { @@ -235,6 +236,12 @@ func (u *user) run() { }) case eventDownstreamDisconnected: dc := e.dc + + for net, rc := range dc.ringConsumers { + seq := rc.Close() + net.history[dc.clientName] = seq + } + for i := range u.downstreamConns { if u.downstreamConns[i] == dc { u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)