Add context to downstreamConn.handleMessageUnregistered
This commit is contained in:
parent
06ce0b8da9
commit
e459dcdb76
@ -628,6 +628,9 @@ func (dc *downstreamConn) handleMessage(msg *irc.Message) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
|
func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
|
||||||
|
ctx, cancel := context.WithTimeout(context.TODO(), handleDownstreamMessageTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
switch msg.Command {
|
switch msg.Command {
|
||||||
case "NICK":
|
case "NICK":
|
||||||
var nick string
|
var nick string
|
||||||
@ -697,7 +700,10 @@ func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
|
|||||||
switch mech {
|
switch mech {
|
||||||
case "PLAIN":
|
case "PLAIN":
|
||||||
dc.saslServer = sasl.NewPlainServer(sasl.PlainAuthenticator(func(identity, username, password string) error {
|
dc.saslServer = sasl.NewPlainServer(sasl.PlainAuthenticator(func(identity, username, password string) error {
|
||||||
return dc.authenticate(username, password)
|
// TODO: we can't use the command context here, because it
|
||||||
|
// gets cancelled once the command handler returns. SASL
|
||||||
|
// might take multiple AUTHENTICATE commands to complete.
|
||||||
|
return dc.authenticate(context.TODO(), username, password)
|
||||||
}))
|
}))
|
||||||
default:
|
default:
|
||||||
return ircError{&irc.Message{
|
return ircError{&irc.Message{
|
||||||
@ -805,7 +811,7 @@ func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
|
|||||||
return newUnknownCommandError(msg.Command)
|
return newUnknownCommandError(msg.Command)
|
||||||
}
|
}
|
||||||
if dc.rawUsername != "" && dc.nick != "" && !dc.negotiatingCaps {
|
if dc.rawUsername != "" && dc.nick != "" && !dc.negotiatingCaps {
|
||||||
return dc.register()
|
return dc.register(ctx)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -1068,10 +1074,10 @@ func unmarshalUsername(rawUsername string) (username, client, network string) {
|
|||||||
return username, client, network
|
return username, client, network
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dc *downstreamConn) authenticate(username, password string) error {
|
func (dc *downstreamConn) authenticate(ctx context.Context, username, password string) error {
|
||||||
username, clientName, networkName := unmarshalUsername(username)
|
username, clientName, networkName := unmarshalUsername(username)
|
||||||
|
|
||||||
u, err := dc.srv.db.GetUser(context.TODO(), username)
|
u, err := dc.srv.db.GetUser(ctx, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
dc.logger.Printf("failed authentication for %q: user not found: %v", username, err)
|
dc.logger.Printf("failed authentication for %q: user not found: %v", username, err)
|
||||||
return errAuthFailed
|
return errAuthFailed
|
||||||
@ -1098,7 +1104,7 @@ func (dc *downstreamConn) authenticate(username, password string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dc *downstreamConn) register() error {
|
func (dc *downstreamConn) register(ctx context.Context) error {
|
||||||
if dc.registered {
|
if dc.registered {
|
||||||
return fmt.Errorf("tried to register twice")
|
return fmt.Errorf("tried to register twice")
|
||||||
}
|
}
|
||||||
@ -1106,7 +1112,7 @@ func (dc *downstreamConn) register() error {
|
|||||||
password := dc.password
|
password := dc.password
|
||||||
dc.password = ""
|
dc.password = ""
|
||||||
if dc.user == nil {
|
if dc.user == nil {
|
||||||
if err := dc.authenticate(dc.rawUsername, password); err != nil {
|
if err := dc.authenticate(ctx, dc.rawUsername, password); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user