diff --git a/downstream.go b/downstream.go index 991fadc..b26141c 100644 --- a/downstream.go +++ b/downstream.go @@ -289,7 +289,10 @@ type downstreamSASL struct { type downstreamRegistration struct { nick string username string - password string + password string // from PASS + + networkName string + networkID int64 } type downstreamConn struct { @@ -301,7 +304,6 @@ type downstreamConn struct { user *user nick string nickCM string - networkName string clientName string realname string hostname string @@ -776,21 +778,7 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir return err } - var match *network - for _, net := range dc.user.networks { - if net.ID == id { - match = net - break - } - } - if match == nil { - return ircError{&irc.Message{ - Command: "FAIL", - Params: []string{"BOUNCER", "INVALID_NETID", idStr, "Unknown network ID"}, - }} - } - - dc.networkName = match.GetName() + dc.registration.networkID = id } default: dc.logger.Printf("unhandled message: %v", msg) @@ -1243,7 +1231,7 @@ func (dc *downstreamConn) authenticate(ctx context.Context, username, password s return fmt.Errorf("user not active") } dc.clientName = clientName - dc.networkName = networkName + dc.registration.networkName = networkName return nil } @@ -1295,9 +1283,10 @@ func (dc *downstreamConn) register(ctx context.Context) error { Params: []string{dc.nick, "Client name mismatch in usernames"}, }} } - if dc.networkName == "" { - dc.networkName = fallbackNetworkName - } else if fallbackNetworkName != "" && dc.networkName != fallbackNetworkName { + + if dc.registration.networkName == "" { + dc.registration.networkName = fallbackNetworkName + } else if fallbackNetworkName != "" && dc.registration.networkName != fallbackNetworkName { return ircError{&irc.Message{ Command: irc.ERR_ERRONEUSNICKNAME, Params: []string{dc.nick, "Network name mismatch in usernames"}, @@ -1305,31 +1294,41 @@ func (dc *downstreamConn) register(ctx context.Context) error { } dc.registered = true - dc.registration = nil dc.logger.Printf("registration complete for user %q", dc.user.Username) return nil } func (dc *downstreamConn) loadNetwork(ctx context.Context) error { - if dc.networkName == "*" { + if id := dc.registration.networkID; id != 0 { + network := dc.user.getNetworkByID(id) + if network == nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "INVALID_NETID", fmt.Sprintf("%v", id), "Unknown network ID"}, + }} + } + dc.network = network + return nil + } + + if dc.registration.networkName == "*" { if !dc.srv.Config().MultiUpstream { return ircError{&irc.Message{ Command: irc.ERR_PASSWDMISMATCH, Params: []string{dc.nick, fmt.Sprintf("Multi-upstream mode is disabled on this server")}, }} } - dc.networkName = "" dc.isMultiUpstream = true return nil } - if dc.networkName == "" { + if dc.registration.networkName == "" { return nil } - network := dc.user.getNetwork(dc.networkName) + network := dc.user.getNetwork(dc.registration.networkName) if network == nil { - addr := dc.networkName + addr := dc.registration.networkName if !strings.ContainsRune(addr, ':') { addr = addr + ":6697" } @@ -1339,7 +1338,7 @@ func (dc *downstreamConn) loadNetwork(ctx context.Context) error { dc.logger.Printf("failed to connect to %q: %v", addr, err) return ircError{&irc.Message{ Command: irc.ERR_PASSWDMISMATCH, - Params: []string{dc.nick, fmt.Sprintf("Failed to connect to %q", dc.networkName)}, + Params: []string{dc.nick, fmt.Sprintf("Failed to connect to %q", dc.registration.networkName)}, }} } @@ -1360,10 +1359,10 @@ func (dc *downstreamConn) loadNetwork(ctx context.Context) error { }} } - dc.logger.Printf("auto-saving network %q", dc.networkName) + dc.logger.Printf("auto-saving network %q", dc.registration.networkName) var err error network, err = dc.user.createNetwork(ctx, &Network{ - Addr: dc.networkName, + Addr: dc.registration.networkName, Nick: nick, Enabled: true, }) @@ -1391,6 +1390,8 @@ func (dc *downstreamConn) welcome(ctx context.Context) error { return err } + dc.registration = nil + dc.updateSupportedCaps() if uc := dc.upstream(); uc != nil {