Add context to downstreamConn.handleMessageUnregistered

This commit is contained in:
Simon Ser 2021-11-17 12:29:23 +01:00
parent 06ce0b8da9
commit e459dcdb76

View File

@ -628,6 +628,9 @@ func (dc *downstreamConn) handleMessage(msg *irc.Message) error {
} }
func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error { func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
ctx, cancel := context.WithTimeout(context.TODO(), handleDownstreamMessageTimeout)
defer cancel()
switch msg.Command { switch msg.Command {
case "NICK": case "NICK":
var nick string var nick string
@ -697,7 +700,10 @@ func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
switch mech { switch mech {
case "PLAIN": case "PLAIN":
dc.saslServer = sasl.NewPlainServer(sasl.PlainAuthenticator(func(identity, username, password string) error { dc.saslServer = sasl.NewPlainServer(sasl.PlainAuthenticator(func(identity, username, password string) error {
return dc.authenticate(username, password) // TODO: we can't use the command context here, because it
// gets cancelled once the command handler returns. SASL
// might take multiple AUTHENTICATE commands to complete.
return dc.authenticate(context.TODO(), username, password)
})) }))
default: default:
return ircError{&irc.Message{ return ircError{&irc.Message{
@ -805,7 +811,7 @@ func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
return newUnknownCommandError(msg.Command) return newUnknownCommandError(msg.Command)
} }
if dc.rawUsername != "" && dc.nick != "" && !dc.negotiatingCaps { if dc.rawUsername != "" && dc.nick != "" && !dc.negotiatingCaps {
return dc.register() return dc.register(ctx)
} }
return nil return nil
} }
@ -1068,10 +1074,10 @@ func unmarshalUsername(rawUsername string) (username, client, network string) {
return username, client, network return username, client, network
} }
func (dc *downstreamConn) authenticate(username, password string) error { func (dc *downstreamConn) authenticate(ctx context.Context, username, password string) error {
username, clientName, networkName := unmarshalUsername(username) username, clientName, networkName := unmarshalUsername(username)
u, err := dc.srv.db.GetUser(context.TODO(), username) u, err := dc.srv.db.GetUser(ctx, username)
if err != nil { if err != nil {
dc.logger.Printf("failed authentication for %q: user not found: %v", username, err) dc.logger.Printf("failed authentication for %q: user not found: %v", username, err)
return errAuthFailed return errAuthFailed
@ -1098,7 +1104,7 @@ func (dc *downstreamConn) authenticate(username, password string) error {
return nil return nil
} }
func (dc *downstreamConn) register() error { func (dc *downstreamConn) register(ctx context.Context) error {
if dc.registered { if dc.registered {
return fmt.Errorf("tried to register twice") return fmt.Errorf("tried to register twice")
} }
@ -1106,7 +1112,7 @@ func (dc *downstreamConn) register() error {
password := dc.password password := dc.password
dc.password = "" dc.password = ""
if dc.user == nil { if dc.user == nil {
if err := dc.authenticate(dc.rawUsername, password); err != nil { if err := dc.authenticate(ctx, dc.rawUsername, password); err != nil {
return err return err
} }
} }