Consume ring messages outside of writer goroutine

This fixes out-of-order JOIN and PRIVMSG messages.

Closes: https://todo.sr.ht/~emersion/soju/36
This commit is contained in:
Simon Ser 2020-03-31 18:16:54 +02:00
parent 1023c2ebfc
commit d748ff269c
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48

View File

@ -51,20 +51,14 @@ var errAuthFailed = ircError{&irc.Message{
Params: []string{"*", "Invalid username or password"}, Params: []string{"*", "Invalid username or password"},
}} }}
type ringMessage struct {
consumer *RingConsumer
upstreamConn *upstreamConn
}
type downstreamConn struct { type downstreamConn struct {
id uint64 id uint64
net net.Conn net net.Conn
irc *irc.Conn irc *irc.Conn
srv *Server srv *Server
logger Logger logger Logger
outgoing chan *irc.Message outgoing chan *irc.Message
ringMessages chan ringMessage closed chan struct{}
closed chan struct{}
registered bool registered bool
user *user user *user
@ -89,16 +83,15 @@ type downstreamConn struct {
func newDownstreamConn(srv *Server, netConn net.Conn, id uint64) *downstreamConn { func newDownstreamConn(srv *Server, netConn net.Conn, id uint64) *downstreamConn {
dc := &downstreamConn{ dc := &downstreamConn{
id: id, id: id,
net: netConn, net: netConn,
irc: irc.NewConn(netConn), irc: irc.NewConn(netConn),
srv: srv, srv: srv,
logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())}, logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())},
outgoing: make(chan *irc.Message, 64), outgoing: make(chan *irc.Message, 64),
ringMessages: make(chan ringMessage), closed: make(chan struct{}),
closed: make(chan struct{}), caps: make(map[string]bool),
caps: make(map[string]bool), ourMessages: make(map[*irc.Message]struct{}),
ourMessages: make(map[*irc.Message]struct{}),
} }
dc.hostname = netConn.RemoteAddr().String() dc.hostname = netConn.RemoteAddr().String()
if host, _, err := net.SplitHostPort(dc.hostname); err == nil { if host, _, err := net.SplitHostPort(dc.hostname); err == nil {
@ -257,42 +250,6 @@ func (dc *downstreamConn) writeMessages() error {
dc.logger.Printf("sent: %v", msg) dc.logger.Printf("sent: %v", msg)
} }
err = dc.irc.WriteMessage(msg) err = dc.irc.WriteMessage(msg)
case ringMessage := <-dc.ringMessages:
consumer, uc := ringMessage.consumer, ringMessage.upstreamConn
for {
msg := consumer.Peek()
if msg == nil {
break
}
dc.lock.Lock()
_, ours := dc.ourMessages[msg]
delete(dc.ourMessages, msg)
dc.lock.Unlock()
if ours {
// The message comes from our connection, don't echo it
// back
consumer.Consume()
continue
}
msg = msg.Copy()
switch msg.Command {
case "PRIVMSG":
msg.Prefix = dc.marshalUserPrefix(uc, msg.Prefix)
msg.Params[0] = dc.marshalEntity(uc, msg.Params[0])
default:
panic("expected to consume a PRIVMSG message")
}
if dc.srv.Debug {
dc.logger.Printf("sent: %v", msg)
}
err = dc.irc.WriteMessage(msg)
if err != nil {
break
}
consumer.Consume()
}
case <-dc.closed: case <-dc.closed:
closed = true closed = true
} }
@ -774,7 +731,36 @@ func (dc *downstreamConn) runNetwork(net *network, loadHistory bool) {
dc.logger.Printf("ignoring messages for upstream %q: upstream is disconnected", net.Addr) dc.logger.Printf("ignoring messages for upstream %q: upstream is disconnected", net.Addr)
break break
} }
dc.ringMessages <- ringMessage{consumer, uc}
for {
msg := consumer.Peek()
if msg == nil {
break
}
dc.lock.Lock()
_, ours := dc.ourMessages[msg]
delete(dc.ourMessages, msg)
dc.lock.Unlock()
if ours {
// The message comes from our connection, don't echo it
// back
consumer.Consume()
continue
}
msg = msg.Copy()
switch msg.Command {
case "PRIVMSG":
msg.Prefix = dc.marshalUserPrefix(uc, msg.Prefix)
msg.Params[0] = dc.marshalEntity(uc, msg.Params[0])
default:
panic("expected to consume a PRIVMSG message")
}
dc.SendMessage(msg)
consumer.Consume()
}
case <-dc.closed: case <-dc.closed:
closed = true closed = true
} }