From 3919ee2036975e5e4a237af695e22926df04285b Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Mon, 16 Mar 2020 12:44:59 +0100 Subject: [PATCH] Per-user dispatcher goroutine This allows message handlers to read upstream/downstream connection information without causing any race condition. References: https://todo.sr.ht/~emersion/soju/1 --- downstream.go | 35 +++++++++++++++++++++++------------ server.go | 8 ++++++-- upstream.go | 6 ++---- user.go | 43 ++++++++++++++++++++++++++++++++++++++++--- 4 files changed, 71 insertions(+), 21 deletions(-) diff --git a/downstream.go b/downstream.go index 4d28ab6..12ddc6b 100644 --- a/downstream.go +++ b/downstream.go @@ -191,7 +191,7 @@ func (dc *downstreamConn) isClosed() bool { } } -func (dc *downstreamConn) readMessages() error { +func (dc *downstreamConn) readMessages(ch chan<- downstreamIncomingMessage) error { dc.logger.Printf("new connection") for { @@ -206,17 +206,7 @@ func (dc *downstreamConn) readMessages() error { dc.logger.Printf("received: %v", msg) } - err = dc.handleMessage(msg) - if ircErr, ok := err.(ircError); ok { - ircErr.Message.Prefix = dc.srv.prefix() - dc.SendMessage(ircErr.Message) - } else if err != nil { - return fmt.Errorf("failed to handle IRC command %q: %v", msg.Command, err) - } - - if dc.isClosed() { - return nil - } + ch <- downstreamIncomingMessage{msg, dc} } return nil @@ -484,6 +474,27 @@ func (dc *downstreamConn) register() error { return nil } +func (dc *downstreamConn) runUntilRegistered() error { + for !dc.registered { + msg, err := dc.irc.ReadMessage() + if err == io.EOF { + break + } else if err != nil { + return fmt.Errorf("failed to read IRC command: %v", err) + } + + err = dc.handleMessage(msg) + if ircErr, ok := err.(ircError); ok { + ircErr.Message.Prefix = dc.srv.prefix() + dc.SendMessage(ircErr.Message) + } else if err != nil { + return fmt.Errorf("failed to handle IRC command %q: %v", msg, err) + } + } + + return nil +} + func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { switch msg.Command { case "USER": diff --git a/server.go b/server.go index 3fbe096..62d30da 100644 --- a/server.go +++ b/server.go @@ -114,8 +114,12 @@ func (s *Server) Serve(ln net.Listener) error { s.downstreamConns = append(s.downstreamConns, dc) s.lock.Unlock() - if err := dc.readMessages(); err != nil { - dc.logger.Printf("failed to handle messages: %v", err) + if err := dc.runUntilRegistered(); err != nil { + dc.logger.Print(err) + } else { + if err := dc.readMessages(dc.user.downstreamIncoming); err != nil { + dc.logger.Print(err) + } } dc.Close() diff --git a/upstream.go b/upstream.go index 2cc6ba8..7fc2a32 100644 --- a/upstream.go +++ b/upstream.go @@ -659,7 +659,7 @@ func (uc *upstreamConn) handleCapAck(name string, ok bool) error { return nil } -func (uc *upstreamConn) readMessages() error { +func (uc *upstreamConn) readMessages(ch chan<- upstreamIncomingMessage) error { for { msg, err := uc.irc.ReadMessage() if err == io.EOF { @@ -672,9 +672,7 @@ func (uc *upstreamConn) readMessages() error { uc.logger.Printf("received: %v", msg) } - if err := uc.handleMessage(msg); err != nil { - uc.logger.Printf("failed to handle message %q: %v", msg, err) - } + ch <- upstreamIncomingMessage{msg, uc} } return nil diff --git a/user.go b/user.go index 457958c..980ddfc 100644 --- a/user.go +++ b/user.go @@ -3,8 +3,20 @@ package soju import ( "sync" "time" + + "gopkg.in/irc.v3" ) +type upstreamIncomingMessage struct { + msg *irc.Message + uc *upstreamConn +} + +type downstreamIncomingMessage struct { + msg *irc.Message + dc *downstreamConn +} + type network struct { Network user *user @@ -40,7 +52,7 @@ func (net *network) run() { net.conn = uc net.user.lock.Unlock() - if err := uc.readMessages(); err != nil { + if err := uc.readMessages(net.user.upstreamIncoming); err != nil { uc.logger.Printf("failed to handle messages: %v", err) } uc.Close() @@ -55,6 +67,9 @@ type user struct { User srv *Server + upstreamIncoming chan upstreamIncomingMessage + downstreamIncoming chan downstreamIncomingMessage + lock sync.Mutex networks []*network downstreamConns []*downstreamConn @@ -62,8 +77,10 @@ type user struct { func newUser(srv *Server, record *User) *user { return &user{ - User: *record, - srv: srv, + User: *record, + srv: srv, + upstreamIncoming: make(chan upstreamIncomingMessage, 64), + downstreamIncoming: make(chan downstreamIncomingMessage, 64), } } @@ -119,6 +136,26 @@ func (u *user) run() { go network.run() } u.lock.Unlock() + + for { + select { + case upstreamMsg := <-u.upstreamIncoming: + msg, uc := upstreamMsg.msg, upstreamMsg.uc + if err := uc.handleMessage(msg); err != nil { + uc.logger.Printf("failed to handle message %q: %v", msg, err) + } + case downstreamMsg := <-u.downstreamIncoming: + msg, dc := downstreamMsg.msg, downstreamMsg.dc + err := dc.handleMessage(msg) + if ircErr, ok := err.(ircError); ok { + ircErr.Message.Prefix = dc.srv.prefix() + dc.SendMessage(ircErr.Message) + } else if err != nil { + dc.logger.Printf("failed to handle message %q: %v", msg, err) + dc.Close() + } + } + } } func (u *user) createNetwork(addr, nick string) (*network, error) {