diff --git a/downstream.go b/downstream.go index ceda5fd..6be292e 100644 --- a/downstream.go +++ b/downstream.go @@ -628,6 +628,9 @@ func (dc *downstreamConn) handleMessage(msg *irc.Message) error { } func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error { + ctx, cancel := context.WithTimeout(context.TODO(), handleDownstreamMessageTimeout) + defer cancel() + switch msg.Command { case "NICK": var nick string @@ -697,7 +700,10 @@ func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error { switch mech { case "PLAIN": dc.saslServer = sasl.NewPlainServer(sasl.PlainAuthenticator(func(identity, username, password string) error { - return dc.authenticate(username, password) + // TODO: we can't use the command context here, because it + // gets cancelled once the command handler returns. SASL + // might take multiple AUTHENTICATE commands to complete. + return dc.authenticate(context.TODO(), username, password) })) default: return ircError{&irc.Message{ @@ -805,7 +811,7 @@ func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error { return newUnknownCommandError(msg.Command) } if dc.rawUsername != "" && dc.nick != "" && !dc.negotiatingCaps { - return dc.register() + return dc.register(ctx) } return nil } @@ -1068,10 +1074,10 @@ func unmarshalUsername(rawUsername string) (username, client, network string) { return username, client, network } -func (dc *downstreamConn) authenticate(username, password string) error { +func (dc *downstreamConn) authenticate(ctx context.Context, username, password string) error { username, clientName, networkName := unmarshalUsername(username) - u, err := dc.srv.db.GetUser(context.TODO(), username) + u, err := dc.srv.db.GetUser(ctx, username) if err != nil { dc.logger.Printf("failed authentication for %q: user not found: %v", username, err) return errAuthFailed @@ -1098,7 +1104,7 @@ func (dc *downstreamConn) authenticate(username, password string) error { return nil } -func (dc *downstreamConn) register() error { +func (dc *downstreamConn) register(ctx context.Context) error { if dc.registered { return fmt.Errorf("tried to register twice") } @@ -1106,7 +1112,7 @@ func (dc *downstreamConn) register() error { password := dc.password dc.password = "" if dc.user == nil { - if err := dc.authenticate(dc.rawUsername, password); err != nil { + if err := dc.authenticate(ctx, dc.rawUsername, password); err != nil { return err } }