diff --git a/downstream.go b/downstream.go index 5796277..4f898cc 100644 --- a/downstream.go +++ b/downstream.go @@ -51,20 +51,14 @@ var errAuthFailed = ircError{&irc.Message{ Params: []string{"*", "Invalid username or password"}, }} -type ringMessage struct { - consumer *RingConsumer - upstreamConn *upstreamConn -} - type downstreamConn struct { - id uint64 - net net.Conn - irc *irc.Conn - srv *Server - logger Logger - outgoing chan *irc.Message - ringMessages chan ringMessage - closed chan struct{} + id uint64 + net net.Conn + irc *irc.Conn + srv *Server + logger Logger + outgoing chan *irc.Message + closed chan struct{} registered bool user *user @@ -89,16 +83,15 @@ type downstreamConn struct { func newDownstreamConn(srv *Server, netConn net.Conn, id uint64) *downstreamConn { dc := &downstreamConn{ - id: id, - net: netConn, - irc: irc.NewConn(netConn), - srv: srv, - logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())}, - outgoing: make(chan *irc.Message, 64), - ringMessages: make(chan ringMessage), - closed: make(chan struct{}), - caps: make(map[string]bool), - ourMessages: make(map[*irc.Message]struct{}), + id: id, + net: netConn, + irc: irc.NewConn(netConn), + srv: srv, + logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())}, + outgoing: make(chan *irc.Message, 64), + closed: make(chan struct{}), + caps: make(map[string]bool), + ourMessages: make(map[*irc.Message]struct{}), } dc.hostname = netConn.RemoteAddr().String() if host, _, err := net.SplitHostPort(dc.hostname); err == nil { @@ -257,42 +250,6 @@ func (dc *downstreamConn) writeMessages() error { dc.logger.Printf("sent: %v", msg) } err = dc.irc.WriteMessage(msg) - case ringMessage := <-dc.ringMessages: - consumer, uc := ringMessage.consumer, ringMessage.upstreamConn - for { - msg := consumer.Peek() - if msg == nil { - break - } - - dc.lock.Lock() - _, ours := dc.ourMessages[msg] - delete(dc.ourMessages, msg) - dc.lock.Unlock() - if ours { - // The message comes from our connection, don't echo it - // back - consumer.Consume() - continue - } - - msg = msg.Copy() - switch msg.Command { - case "PRIVMSG": - msg.Prefix = dc.marshalUserPrefix(uc, msg.Prefix) - msg.Params[0] = dc.marshalEntity(uc, msg.Params[0]) - default: - panic("expected to consume a PRIVMSG message") - } - if dc.srv.Debug { - dc.logger.Printf("sent: %v", msg) - } - err = dc.irc.WriteMessage(msg) - if err != nil { - break - } - consumer.Consume() - } case <-dc.closed: closed = true } @@ -774,7 +731,36 @@ func (dc *downstreamConn) runNetwork(net *network, loadHistory bool) { dc.logger.Printf("ignoring messages for upstream %q: upstream is disconnected", net.Addr) break } - dc.ringMessages <- ringMessage{consumer, uc} + + for { + msg := consumer.Peek() + if msg == nil { + break + } + + dc.lock.Lock() + _, ours := dc.ourMessages[msg] + delete(dc.ourMessages, msg) + dc.lock.Unlock() + if ours { + // The message comes from our connection, don't echo it + // back + consumer.Consume() + continue + } + + msg = msg.Copy() + switch msg.Command { + case "PRIVMSG": + msg.Prefix = dc.marshalUserPrefix(uc, msg.Prefix) + msg.Params[0] = dc.marshalEntity(uc, msg.Params[0]) + default: + panic("expected to consume a PRIVMSG message") + } + + dc.SendMessage(msg) + consumer.Consume() + } case <-dc.closed: closed = true }