From 08bb06c164e98dd7b2b16a2d93074c7fbf3938fa Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Fri, 27 Mar 2020 19:17:58 +0100 Subject: [PATCH] Nuke user.lock Split user.register into two functions, one to make sure the user is authenticated, the other to send our current state. This allows to get rid of data races by doing the second part in the user goroutine. Closes: https://todo.sr.ht/~emersion/soju/22 --- downstream.go | 106 +++++++++++++++++++++++++++++--------------------- user.go | 21 +++------- 2 files changed, 67 insertions(+), 60 deletions(-) diff --git a/downstream.go b/downstream.go index 8207c41..e7fb6d4 100644 --- a/downstream.go +++ b/downstream.go @@ -71,6 +71,7 @@ type downstreamConn struct { nick string username string rawUsername string + networkName string realname string hostname string password string // empty after authentication @@ -582,42 +583,6 @@ func unmarshalUsername(rawUsername string) (username, network string) { return username, network } -func (dc *downstreamConn) setNetwork(networkName string) error { - if networkName == "" { - return nil - } - - network := dc.user.getNetwork(networkName) - if network == nil { - addr := networkName - if !strings.ContainsRune(addr, ':') { - addr = addr + ":6697" - } - - dc.logger.Printf("trying to connect to new network %q", addr) - if err := sanityCheckServer(addr); err != nil { - dc.logger.Printf("failed to connect to %q: %v", addr, err) - return ircError{&irc.Message{ - Command: irc.ERR_PASSWDMISMATCH, - Params: []string{"*", fmt.Sprintf("Failed to connect to %q", networkName)}, - }} - } - - dc.logger.Printf("auto-saving network %q", networkName) - var err error - network, err = dc.user.createNetwork(&Network{ - Addr: networkName, - Nick: dc.nick, - }) - if err != nil { - return err - } - } - - dc.network = network - return nil -} - func (dc *downstreamConn) authenticate(username, password string) error { username, networkName := unmarshalUsername(username) @@ -634,31 +599,82 @@ func (dc *downstreamConn) authenticate(username, password string) error { } dc.user = u - - return dc.setNetwork(networkName) + dc.networkName = networkName + return nil } func (dc *downstreamConn) register() error { + if dc.registered { + return fmt.Errorf("tried to register twice") + } + password := dc.password dc.password = "" if dc.user == nil { if err := dc.authenticate(dc.rawUsername, password); err != nil { return err } - } else if dc.network == nil { - _, networkName := unmarshalUsername(dc.rawUsername) - if err := dc.setNetwork(networkName); err != nil { - return err - } + } + + if dc.networkName == "" { + _, dc.networkName = unmarshalUsername(dc.rawUsername) } dc.registered = true dc.username = dc.user.Username dc.logger.Printf("registration complete for user %q", dc.username) + return nil +} + +func (dc *downstreamConn) loadNetwork() error { + if dc.networkName == "" { + return nil + } + + network := dc.user.getNetwork(dc.networkName) + if network == nil { + addr := dc.networkName + if !strings.ContainsRune(addr, ':') { + addr = addr + ":6697" + } + + dc.logger.Printf("trying to connect to new network %q", addr) + if err := sanityCheckServer(addr); err != nil { + dc.logger.Printf("failed to connect to %q: %v", addr, err) + return ircError{&irc.Message{ + Command: irc.ERR_PASSWDMISMATCH, + Params: []string{"*", fmt.Sprintf("Failed to connect to %q", dc.networkName)}, + }} + } + + dc.logger.Printf("auto-saving network %q", dc.networkName) + var err error + network, err = dc.user.createNetwork(&Network{ + Addr: dc.networkName, + Nick: dc.nick, + }) + if err != nil { + return err + } + } + + dc.network = network + return nil +} + +func (dc *downstreamConn) welcome() error { + if dc.user == nil || !dc.registered { + panic("tried to welcome an unregistered connection") + } + + // TODO: doing this might take some time. We should do it in dc.register + // instead, but we'll potentially be adding a new network and this must be + // done in the user goroutine. + if err := dc.loadNetwork(); err != nil { + return err + } - dc.user.lock.Lock() firstDownstream := len(dc.user.downstreamConns) == 0 - dc.user.lock.Unlock() dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), diff --git a/user.go b/user.go index 384e719..804a5e9 100644 --- a/user.go +++ b/user.go @@ -91,7 +91,6 @@ type user struct { events chan event - lock sync.Mutex networks []*network downstreamConns []*downstreamConn } @@ -105,15 +104,12 @@ func newUser(srv *Server, record *User) *user { } func (u *user) forEachNetwork(f func(*network)) { - u.lock.Lock() for _, network := range u.networks { f(network) } - u.lock.Unlock() } func (u *user) forEachUpstream(f func(uc *upstreamConn)) { - u.lock.Lock() for _, network := range u.networks { uc := network.upstream() if uc == nil || !uc.registered || uc.closed { @@ -121,15 +117,12 @@ func (u *user) forEachUpstream(f func(uc *upstreamConn)) { } f(uc) } - u.lock.Unlock() } func (u *user) forEachDownstream(f func(dc *downstreamConn)) { - u.lock.Lock() for _, dc := range u.downstreamConns { f(dc) } - u.lock.Unlock() } func (u *user) getNetwork(name string) *network { @@ -148,14 +141,12 @@ func (u *user) run() { return } - u.lock.Lock() for _, record := range networks { network := newNetwork(u, &record) u.networks = append(u.networks, network) go network.run() } - u.lock.Unlock() for e := range u.events { switch e := e.(type) { @@ -170,19 +161,21 @@ func (u *user) run() { } case eventDownstreamConnected: dc := e.dc - u.lock.Lock() + + if err := dc.welcome(); err != nil { + dc.logger.Printf("failed to handle new registered connection: %v", err) + break + } + u.downstreamConns = append(u.downstreamConns, dc) - u.lock.Unlock() case eventDownstreamDisconnected: dc := e.dc - u.lock.Lock() for i := range u.downstreamConns { if u.downstreamConns[i] == dc { u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...) break } } - u.lock.Unlock() case eventDownstreamMessage: msg, dc := e.msg, e.dc if dc.isClosed() { @@ -220,9 +213,7 @@ func (u *user) createNetwork(net *Network) (*network, error) { } }) - u.lock.Lock() u.networks = append(u.networks, network) - u.lock.Unlock() go network.run() return network, nil