diff --git a/downstream.go b/downstream.go index 8207c41..e7fb6d4 100644 --- a/downstream.go +++ b/downstream.go @@ -71,6 +71,7 @@ type downstreamConn struct { nick string username string rawUsername string + networkName string realname string hostname string password string // empty after authentication @@ -582,42 +583,6 @@ func unmarshalUsername(rawUsername string) (username, network string) { return username, network } -func (dc *downstreamConn) setNetwork(networkName string) error { - if networkName == "" { - return nil - } - - network := dc.user.getNetwork(networkName) - if network == nil { - addr := networkName - if !strings.ContainsRune(addr, ':') { - addr = addr + ":6697" - } - - dc.logger.Printf("trying to connect to new network %q", addr) - if err := sanityCheckServer(addr); err != nil { - dc.logger.Printf("failed to connect to %q: %v", addr, err) - return ircError{&irc.Message{ - Command: irc.ERR_PASSWDMISMATCH, - Params: []string{"*", fmt.Sprintf("Failed to connect to %q", networkName)}, - }} - } - - dc.logger.Printf("auto-saving network %q", networkName) - var err error - network, err = dc.user.createNetwork(&Network{ - Addr: networkName, - Nick: dc.nick, - }) - if err != nil { - return err - } - } - - dc.network = network - return nil -} - func (dc *downstreamConn) authenticate(username, password string) error { username, networkName := unmarshalUsername(username) @@ -634,31 +599,82 @@ func (dc *downstreamConn) authenticate(username, password string) error { } dc.user = u - - return dc.setNetwork(networkName) + dc.networkName = networkName + return nil } func (dc *downstreamConn) register() error { + if dc.registered { + return fmt.Errorf("tried to register twice") + } + password := dc.password dc.password = "" if dc.user == nil { if err := dc.authenticate(dc.rawUsername, password); err != nil { return err } - } else if dc.network == nil { - _, networkName := unmarshalUsername(dc.rawUsername) - if err := dc.setNetwork(networkName); err != nil { - return err - } + } + + if dc.networkName == "" { + _, dc.networkName = unmarshalUsername(dc.rawUsername) } dc.registered = true dc.username = dc.user.Username dc.logger.Printf("registration complete for user %q", dc.username) + return nil +} + +func (dc *downstreamConn) loadNetwork() error { + if dc.networkName == "" { + return nil + } + + network := dc.user.getNetwork(dc.networkName) + if network == nil { + addr := dc.networkName + if !strings.ContainsRune(addr, ':') { + addr = addr + ":6697" + } + + dc.logger.Printf("trying to connect to new network %q", addr) + if err := sanityCheckServer(addr); err != nil { + dc.logger.Printf("failed to connect to %q: %v", addr, err) + return ircError{&irc.Message{ + Command: irc.ERR_PASSWDMISMATCH, + Params: []string{"*", fmt.Sprintf("Failed to connect to %q", dc.networkName)}, + }} + } + + dc.logger.Printf("auto-saving network %q", dc.networkName) + var err error + network, err = dc.user.createNetwork(&Network{ + Addr: dc.networkName, + Nick: dc.nick, + }) + if err != nil { + return err + } + } + + dc.network = network + return nil +} + +func (dc *downstreamConn) welcome() error { + if dc.user == nil || !dc.registered { + panic("tried to welcome an unregistered connection") + } + + // TODO: doing this might take some time. We should do it in dc.register + // instead, but we'll potentially be adding a new network and this must be + // done in the user goroutine. + if err := dc.loadNetwork(); err != nil { + return err + } - dc.user.lock.Lock() firstDownstream := len(dc.user.downstreamConns) == 0 - dc.user.lock.Unlock() dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), diff --git a/user.go b/user.go index 384e719..804a5e9 100644 --- a/user.go +++ b/user.go @@ -91,7 +91,6 @@ type user struct { events chan event - lock sync.Mutex networks []*network downstreamConns []*downstreamConn } @@ -105,15 +104,12 @@ func newUser(srv *Server, record *User) *user { } func (u *user) forEachNetwork(f func(*network)) { - u.lock.Lock() for _, network := range u.networks { f(network) } - u.lock.Unlock() } func (u *user) forEachUpstream(f func(uc *upstreamConn)) { - u.lock.Lock() for _, network := range u.networks { uc := network.upstream() if uc == nil || !uc.registered || uc.closed { @@ -121,15 +117,12 @@ func (u *user) forEachUpstream(f func(uc *upstreamConn)) { } f(uc) } - u.lock.Unlock() } func (u *user) forEachDownstream(f func(dc *downstreamConn)) { - u.lock.Lock() for _, dc := range u.downstreamConns { f(dc) } - u.lock.Unlock() } func (u *user) getNetwork(name string) *network { @@ -148,14 +141,12 @@ func (u *user) run() { return } - u.lock.Lock() for _, record := range networks { network := newNetwork(u, &record) u.networks = append(u.networks, network) go network.run() } - u.lock.Unlock() for e := range u.events { switch e := e.(type) { @@ -170,19 +161,21 @@ func (u *user) run() { } case eventDownstreamConnected: dc := e.dc - u.lock.Lock() + + if err := dc.welcome(); err != nil { + dc.logger.Printf("failed to handle new registered connection: %v", err) + break + } + u.downstreamConns = append(u.downstreamConns, dc) - u.lock.Unlock() case eventDownstreamDisconnected: dc := e.dc - u.lock.Lock() for i := range u.downstreamConns { if u.downstreamConns[i] == dc { u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...) break } } - u.lock.Unlock() case eventDownstreamMessage: msg, dc := e.msg, e.dc if dc.isClosed() { @@ -220,9 +213,7 @@ func (u *user) createNetwork(net *Network) (*network, error) { } }) - u.lock.Lock() u.networks = append(u.networks, network) - u.lock.Unlock() go network.run() return network, nil