downstream: ignore nickname during connection registration

Just force-set the nickname and completely disregard what the client
sets during connection registration. Clients must discover their
effective nickname via RPL_WELCOME.
This commit is contained in:
Simon Ser 2022-03-21 14:37:45 +01:00
parent 2ac9bd9c94
commit bed50c10ce
1 changed files with 47 additions and 35 deletions

View File

@ -286,6 +286,12 @@ type downstreamSASL struct {
pendingResp bytes.Buffer pendingResp bytes.Buffer
} }
type downstreamRegistration struct {
nick string
username string
password string
}
type downstreamConn struct { type downstreamConn struct {
conn conn
@ -295,16 +301,16 @@ type downstreamConn struct {
user *user user *user
nick string nick string
nickCM string nickCM string
rawUsername string
networkName string networkName string
clientName string clientName string
realname string realname string
hostname string hostname string
account string // RPL_LOGGEDIN/OUT state account string // RPL_LOGGEDIN/OUT state
password string // empty after authentication
network *network // can be nil network *network // can be nil
isMultiUpstream bool isMultiUpstream bool
registration *downstreamRegistration // nil after RPL_WELCOME
negotiatingCaps bool negotiatingCaps bool
capVersion int capVersion int
caps capRegistry caps capRegistry
@ -320,12 +326,13 @@ func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn {
logger := &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", remoteAddr)} logger := &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", remoteAddr)}
options := connOptions{Logger: logger} options := connOptions{Logger: logger}
dc := &downstreamConn{ dc := &downstreamConn{
conn: *newConn(srv, ic, &options), conn: *newConn(srv, ic, &options),
id: id, id: id,
nick: "*", nick: "*",
nickCM: "*", nickCM: "*",
caps: newCapRegistry(), caps: newCapRegistry(),
monitored: newCasemapMap(0), monitored: newCasemapMap(0),
registration: new(downstreamRegistration),
} }
dc.monitored.SetCasemapping(casemapASCII) dc.monitored.SetCasemapping(casemapASCII)
dc.hostname = remoteAddr dc.hostname = remoteAddr
@ -702,31 +709,15 @@ func (dc *downstreamConn) handleMessage(ctx context.Context, msg *irc.Message) e
func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *irc.Message) error { func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *irc.Message) error {
switch msg.Command { switch msg.Command {
case "NICK": case "NICK":
var nick string if err := parseMessageParams(msg, &dc.registration.nick); err != nil {
if err := parseMessageParams(msg, &nick); err != nil {
return err return err
} }
if nick == "" || strings.ContainsAny(nick, illegalNickChars) {
return ircError{&irc.Message{
Command: irc.ERR_ERRONEUSNICKNAME,
Params: []string{dc.nick, nick, "Nickname contains illegal characters"},
}}
}
nickCM := casemapASCII(nick)
if nickCM == serviceNickCM {
return ircError{&irc.Message{
Command: irc.ERR_NICKNAMEINUSE,
Params: []string{dc.nick, nick, "Nickname reserved for bouncer service"},
}}
}
dc.nick = nick
dc.nickCM = nickCM
case "USER": case "USER":
if err := parseMessageParams(msg, &dc.rawUsername, nil, nil, &dc.realname); err != nil { if err := parseMessageParams(msg, &dc.registration.username, nil, nil, nil); err != nil {
return err return err
} }
case "PASS": case "PASS":
if err := parseMessageParams(msg, &dc.password); err != nil { if err := parseMessageParams(msg, &dc.registration.password); err != nil {
return err return err
} }
case "CAP": case "CAP":
@ -805,7 +796,7 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir
dc.logger.Printf("unhandled message: %v", msg) dc.logger.Printf("unhandled message: %v", msg)
return newUnknownCommandError(msg.Command) return newUnknownCommandError(msg.Command)
} }
if dc.rawUsername != "" && dc.nick != "*" && !dc.negotiatingCaps { if dc.registration.nick != "" && dc.registration.username != "" && !dc.negotiatingCaps {
return dc.register(ctx) return dc.register(ctx)
} }
return nil return nil
@ -1269,8 +1260,8 @@ func (dc *downstreamConn) register(ctx context.Context) error {
}) })
} }
password := dc.password password := dc.registration.password
dc.password = "" dc.registration.password = ""
if dc.user == nil { if dc.user == nil {
if password == "" { if password == "" {
if dc.caps.IsEnabled("sasl") { if dc.caps.IsEnabled("sasl") {
@ -1286,8 +1277,8 @@ func (dc *downstreamConn) register(ctx context.Context) error {
} }
} }
if err := dc.authenticate(ctx, dc.rawUsername, password); err != nil { if err := dc.authenticate(ctx, dc.registration.username, password); err != nil {
dc.logger.Printf("PASS authentication error for user %q: %v", dc.rawUsername, err) dc.logger.Printf("PASS authentication error for user %q: %v", dc.registration.username, err)
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: irc.ERR_PASSWDMISMATCH, Command: irc.ERR_PASSWDMISMATCH,
Params: []string{dc.nick, authErrorReason(err)}, Params: []string{dc.nick, authErrorReason(err)},
@ -1295,7 +1286,7 @@ func (dc *downstreamConn) register(ctx context.Context) error {
} }
} }
_, fallbackClientName, fallbackNetworkName := unmarshalUsername(dc.rawUsername) _, fallbackClientName, fallbackNetworkName := unmarshalUsername(dc.registration.username)
if dc.clientName == "" { if dc.clientName == "" {
dc.clientName = fallbackClientName dc.clientName = fallbackClientName
} else if fallbackClientName != "" && dc.clientName != fallbackClientName { } else if fallbackClientName != "" && dc.clientName != fallbackClientName {
@ -1314,6 +1305,7 @@ func (dc *downstreamConn) register(ctx context.Context) error {
} }
dc.registered = true dc.registered = true
dc.registration = nil
dc.logger.Printf("registration complete for user %q", dc.user.Username) dc.logger.Printf("registration complete for user %q", dc.user.Username)
return nil return nil
} }
@ -1342,7 +1334,19 @@ func (dc *downstreamConn) loadNetwork(ctx context.Context) error {
// Some clients only allow specifying the nickname (and use the // Some clients only allow specifying the nickname (and use the
// nickname as a username too). Strip the network name from the // nickname as a username too). Strip the network name from the
// nickname when auto-saving networks. // nickname when auto-saving networks.
nick, _, _ := unmarshalUsername(dc.nick) nick, _, _ := unmarshalUsername(dc.registration.nick)
if nick == "" || strings.ContainsAny(nick, illegalNickChars) {
return ircError{&irc.Message{
Command: irc.ERR_ERRONEUSNICKNAME,
Params: []string{dc.nick, dc.registration.nick, "Nickname contains illegal characters"},
}}
}
if casemapASCII(nick) == serviceNickCM {
return ircError{&irc.Message{
Command: irc.ERR_NICKNAMEINUSE,
Params: []string{dc.nick, dc.registration.nick, "Nickname reserved for bouncer service"},
}}
}
dc.logger.Printf("auto-saving network %q", dc.networkName) dc.logger.Printf("auto-saving network %q", dc.networkName)
var err error var err error
@ -1388,6 +1392,15 @@ func (dc *downstreamConn) welcome(ctx context.Context) error {
dc.updateSupportedCaps() dc.updateSupportedCaps()
if uc := dc.upstream(); uc != nil {
dc.nick = uc.nick
} else if dc.network != nil {
dc.nick = GetNick(&dc.user.User, &dc.network.Network)
} else {
dc.nick = dc.user.Username
}
dc.nickCM = casemapASCII(dc.nick)
isupport := []string{ isupport := []string{
fmt.Sprintf("CHATHISTORY=%v", chatHistoryLimit), fmt.Sprintf("CHATHISTORY=%v", chatHistoryLimit),
"CASEMAPPING=ascii", "CASEMAPPING=ascii",
@ -1452,7 +1465,6 @@ func (dc *downstreamConn) welcome(ctx context.Context) error {
}) })
} }
dc.updateNick()
dc.updateRealname() dc.updateRealname()
dc.updateAccount() dc.updateAccount()