downstream: process BOUNCER BIND in downstreamConn.welcome

This allows clients to send BOUNCER BIND before SASL auth, or to
use BOUNCER BIND with PASS.
This commit is contained in:
Simon Ser 2022-03-21 15:02:54 +01:00
parent 2c691d012d
commit b3425ba1a3

View File

@ -289,7 +289,10 @@ type downstreamSASL struct {
type downstreamRegistration struct {
nick string
username string
password string
password string // from PASS
networkName string
networkID int64
}
type downstreamConn struct {
@ -301,7 +304,6 @@ type downstreamConn struct {
user *user
nick string
nickCM string
networkName string
clientName string
realname string
hostname string
@ -776,21 +778,7 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir
return err
}
var match *network
for _, net := range dc.user.networks {
if net.ID == id {
match = net
break
}
}
if match == nil {
return ircError{&irc.Message{
Command: "FAIL",
Params: []string{"BOUNCER", "INVALID_NETID", idStr, "Unknown network ID"},
}}
}
dc.networkName = match.GetName()
dc.registration.networkID = id
}
default:
dc.logger.Printf("unhandled message: %v", msg)
@ -1243,7 +1231,7 @@ func (dc *downstreamConn) authenticate(ctx context.Context, username, password s
return fmt.Errorf("user not active")
}
dc.clientName = clientName
dc.networkName = networkName
dc.registration.networkName = networkName
return nil
}
@ -1295,9 +1283,10 @@ func (dc *downstreamConn) register(ctx context.Context) error {
Params: []string{dc.nick, "Client name mismatch in usernames"},
}}
}
if dc.networkName == "" {
dc.networkName = fallbackNetworkName
} else if fallbackNetworkName != "" && dc.networkName != fallbackNetworkName {
if dc.registration.networkName == "" {
dc.registration.networkName = fallbackNetworkName
} else if fallbackNetworkName != "" && dc.registration.networkName != fallbackNetworkName {
return ircError{&irc.Message{
Command: irc.ERR_ERRONEUSNICKNAME,
Params: []string{dc.nick, "Network name mismatch in usernames"},
@ -1305,31 +1294,41 @@ func (dc *downstreamConn) register(ctx context.Context) error {
}
dc.registered = true
dc.registration = nil
dc.logger.Printf("registration complete for user %q", dc.user.Username)
return nil
}
func (dc *downstreamConn) loadNetwork(ctx context.Context) error {
if dc.networkName == "*" {
if id := dc.registration.networkID; id != 0 {
network := dc.user.getNetworkByID(id)
if network == nil {
return ircError{&irc.Message{
Command: "FAIL",
Params: []string{"BOUNCER", "INVALID_NETID", fmt.Sprintf("%v", id), "Unknown network ID"},
}}
}
dc.network = network
return nil
}
if dc.registration.networkName == "*" {
if !dc.srv.Config().MultiUpstream {
return ircError{&irc.Message{
Command: irc.ERR_PASSWDMISMATCH,
Params: []string{dc.nick, fmt.Sprintf("Multi-upstream mode is disabled on this server")},
}}
}
dc.networkName = ""
dc.isMultiUpstream = true
return nil
}
if dc.networkName == "" {
if dc.registration.networkName == "" {
return nil
}
network := dc.user.getNetwork(dc.networkName)
network := dc.user.getNetwork(dc.registration.networkName)
if network == nil {
addr := dc.networkName
addr := dc.registration.networkName
if !strings.ContainsRune(addr, ':') {
addr = addr + ":6697"
}
@ -1339,7 +1338,7 @@ func (dc *downstreamConn) loadNetwork(ctx context.Context) error {
dc.logger.Printf("failed to connect to %q: %v", addr, err)
return ircError{&irc.Message{
Command: irc.ERR_PASSWDMISMATCH,
Params: []string{dc.nick, fmt.Sprintf("Failed to connect to %q", dc.networkName)},
Params: []string{dc.nick, fmt.Sprintf("Failed to connect to %q", dc.registration.networkName)},
}}
}
@ -1360,10 +1359,10 @@ func (dc *downstreamConn) loadNetwork(ctx context.Context) error {
}}
}
dc.logger.Printf("auto-saving network %q", dc.networkName)
dc.logger.Printf("auto-saving network %q", dc.registration.networkName)
var err error
network, err = dc.user.createNetwork(ctx, &Network{
Addr: dc.networkName,
Addr: dc.registration.networkName,
Nick: nick,
Enabled: true,
})
@ -1391,6 +1390,8 @@ func (dc *downstreamConn) welcome(ctx context.Context) error {
return err
}
dc.registration = nil
dc.updateSupportedCaps()
if uc := dc.upstream(); uc != nil {