downstream: process BOUNCER BIND in downstreamConn.welcome

This allows clients to send BOUNCER BIND before SASL auth, or to
use BOUNCER BIND with PASS.
This commit is contained in:
Simon Ser 2022-03-21 15:02:54 +01:00
parent 2c691d012d
commit b3425ba1a3

View File

@ -289,7 +289,10 @@ type downstreamSASL struct {
type downstreamRegistration struct { type downstreamRegistration struct {
nick string nick string
username string username string
password string password string // from PASS
networkName string
networkID int64
} }
type downstreamConn struct { type downstreamConn struct {
@ -301,7 +304,6 @@ type downstreamConn struct {
user *user user *user
nick string nick string
nickCM string nickCM string
networkName string
clientName string clientName string
realname string realname string
hostname string hostname string
@ -776,21 +778,7 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir
return err return err
} }
var match *network dc.registration.networkID = id
for _, net := range dc.user.networks {
if net.ID == id {
match = net
break
}
}
if match == nil {
return ircError{&irc.Message{
Command: "FAIL",
Params: []string{"BOUNCER", "INVALID_NETID", idStr, "Unknown network ID"},
}}
}
dc.networkName = match.GetName()
} }
default: default:
dc.logger.Printf("unhandled message: %v", msg) dc.logger.Printf("unhandled message: %v", msg)
@ -1243,7 +1231,7 @@ func (dc *downstreamConn) authenticate(ctx context.Context, username, password s
return fmt.Errorf("user not active") return fmt.Errorf("user not active")
} }
dc.clientName = clientName dc.clientName = clientName
dc.networkName = networkName dc.registration.networkName = networkName
return nil return nil
} }
@ -1295,9 +1283,10 @@ func (dc *downstreamConn) register(ctx context.Context) error {
Params: []string{dc.nick, "Client name mismatch in usernames"}, Params: []string{dc.nick, "Client name mismatch in usernames"},
}} }}
} }
if dc.networkName == "" {
dc.networkName = fallbackNetworkName if dc.registration.networkName == "" {
} else if fallbackNetworkName != "" && dc.networkName != fallbackNetworkName { dc.registration.networkName = fallbackNetworkName
} else if fallbackNetworkName != "" && dc.registration.networkName != fallbackNetworkName {
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: irc.ERR_ERRONEUSNICKNAME, Command: irc.ERR_ERRONEUSNICKNAME,
Params: []string{dc.nick, "Network name mismatch in usernames"}, Params: []string{dc.nick, "Network name mismatch in usernames"},
@ -1305,31 +1294,41 @@ func (dc *downstreamConn) register(ctx context.Context) error {
} }
dc.registered = true dc.registered = true
dc.registration = nil
dc.logger.Printf("registration complete for user %q", dc.user.Username) dc.logger.Printf("registration complete for user %q", dc.user.Username)
return nil return nil
} }
func (dc *downstreamConn) loadNetwork(ctx context.Context) error { func (dc *downstreamConn) loadNetwork(ctx context.Context) error {
if dc.networkName == "*" { if id := dc.registration.networkID; id != 0 {
network := dc.user.getNetworkByID(id)
if network == nil {
return ircError{&irc.Message{
Command: "FAIL",
Params: []string{"BOUNCER", "INVALID_NETID", fmt.Sprintf("%v", id), "Unknown network ID"},
}}
}
dc.network = network
return nil
}
if dc.registration.networkName == "*" {
if !dc.srv.Config().MultiUpstream { if !dc.srv.Config().MultiUpstream {
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: irc.ERR_PASSWDMISMATCH, Command: irc.ERR_PASSWDMISMATCH,
Params: []string{dc.nick, fmt.Sprintf("Multi-upstream mode is disabled on this server")}, Params: []string{dc.nick, fmt.Sprintf("Multi-upstream mode is disabled on this server")},
}} }}
} }
dc.networkName = ""
dc.isMultiUpstream = true dc.isMultiUpstream = true
return nil return nil
} }
if dc.networkName == "" { if dc.registration.networkName == "" {
return nil return nil
} }
network := dc.user.getNetwork(dc.networkName) network := dc.user.getNetwork(dc.registration.networkName)
if network == nil { if network == nil {
addr := dc.networkName addr := dc.registration.networkName
if !strings.ContainsRune(addr, ':') { if !strings.ContainsRune(addr, ':') {
addr = addr + ":6697" addr = addr + ":6697"
} }
@ -1339,7 +1338,7 @@ func (dc *downstreamConn) loadNetwork(ctx context.Context) error {
dc.logger.Printf("failed to connect to %q: %v", addr, err) dc.logger.Printf("failed to connect to %q: %v", addr, err)
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: irc.ERR_PASSWDMISMATCH, Command: irc.ERR_PASSWDMISMATCH,
Params: []string{dc.nick, fmt.Sprintf("Failed to connect to %q", dc.networkName)}, Params: []string{dc.nick, fmt.Sprintf("Failed to connect to %q", dc.registration.networkName)},
}} }}
} }
@ -1360,10 +1359,10 @@ func (dc *downstreamConn) loadNetwork(ctx context.Context) error {
}} }}
} }
dc.logger.Printf("auto-saving network %q", dc.networkName) dc.logger.Printf("auto-saving network %q", dc.registration.networkName)
var err error var err error
network, err = dc.user.createNetwork(ctx, &Network{ network, err = dc.user.createNetwork(ctx, &Network{
Addr: dc.networkName, Addr: dc.registration.networkName,
Nick: nick, Nick: nick,
Enabled: true, Enabled: true,
}) })
@ -1391,6 +1390,8 @@ func (dc *downstreamConn) welcome(ctx context.Context) error {
return err return err
} }
dc.registration = nil
dc.updateSupportedCaps() dc.updateSupportedCaps()
if uc := dc.upstream(); uc != nil { if uc := dc.upstream(); uc != nil {