diff --git a/downstream.go b/downstream.go index eded38b..11dd036 100644 --- a/downstream.go +++ b/downstream.go @@ -244,6 +244,11 @@ var passthroughIsupport = map[string]bool{ "WHOX": true, } +type downstreamSASL struct { + server sasl.Server + plainUsername, plainPassword string +} + type downstreamConn struct { conn @@ -267,12 +272,11 @@ type downstreamConn struct { capVersion int supportedCaps map[string]string caps map[string]bool + sasl *downstreamSASL lastBatchRef uint64 monitored casemapMap - - saslServer sasl.Server } func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn { @@ -686,102 +690,28 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir return err } case "AUTHENTICATE": - if !dc.caps["sasl"] { - return ircError{&irc.Message{ - Prefix: dc.srv.prefix(), - Command: irc.ERR_SASLFAIL, - Params: []string{"*", "AUTHENTICATE requires the \"sasl\" capability to be enabled"}, - }} - } - if len(msg.Params) == 0 { - return ircError{&irc.Message{ - Prefix: dc.srv.prefix(), - Command: irc.ERR_SASLFAIL, - Params: []string{"*", "Missing AUTHENTICATE argument"}, - }} - } - - var resp []byte - if msg.Params[0] == "*" { - dc.saslServer = nil - return ircError{&irc.Message{ - Prefix: dc.srv.prefix(), - Command: irc.ERR_SASLABORTED, - Params: []string{"*", "SASL authentication aborted"}, - }} - } else if dc.saslServer == nil { - mech := strings.ToUpper(msg.Params[0]) - switch mech { - case "PLAIN": - dc.saslServer = sasl.NewPlainServer(sasl.PlainAuthenticator(func(identity, username, password string) error { - // TODO: we can't use the command context here, because it - // gets cancelled once the command handler returns. SASL - // might take multiple AUTHENTICATE commands to complete. - return dc.authenticate(context.TODO(), username, password) - })) - default: - return ircError{&irc.Message{ - Prefix: dc.srv.prefix(), - Command: irc.ERR_SASLFAIL, - Params: []string{"*", fmt.Sprintf("Unsupported SASL mechanism %q", mech)}, - }} - } - } else if msg.Params[0] == "+" { - resp = nil - } else { - // TODO: multi-line messages - var err error - resp, err = base64.StdEncoding.DecodeString(msg.Params[0]) - if err != nil { - dc.saslServer = nil - return ircError{&irc.Message{ - Prefix: dc.srv.prefix(), - Command: irc.ERR_SASLFAIL, - Params: []string{"*", "Invalid base64-encoded response"}, - }} - } - } - - challenge, done, err := dc.saslServer.Next(resp) + credentials, err := dc.handleAuthenticateCommand(msg) if err != nil { - dc.saslServer = nil - if ircErr, ok := err.(ircError); ok && ircErr.Message.Command == irc.ERR_PASSWDMISMATCH { - return ircError{&irc.Message{ - Prefix: dc.srv.prefix(), - Command: irc.ERR_SASLFAIL, - Params: []string{"*", ircErr.Message.Params[1]}, - }} - } - dc.SendMessage(&irc.Message{ + return err + } else if credentials == nil { + break + } + + if err := dc.authenticate(ctx, credentials.plainUsername, credentials.plainPassword); err != nil { + dc.logger.Printf("SASL authentication error: %v", err) + dc.endSASL(&irc.Message{ Prefix: dc.srv.prefix(), Command: irc.ERR_SASLFAIL, - Params: []string{"*", "SASL error"}, - }) - return fmt.Errorf("SASL authentication failed: %v", err) - } else if done { - dc.saslServer = nil - // Technically we should send RPL_LOGGEDIN here. However we use - // RPL_LOGGEDIN to mirror the upstream connection status. Let's see - // how many clients that breaks. See: - // https://github.com/ircv3/ircv3-specifications/pull/476 - dc.SendMessage(&irc.Message{ - Prefix: dc.srv.prefix(), - Command: irc.RPL_SASLSUCCESS, - Params: []string{dc.nick, "SASL authentication successful"}, - }) - } else { - challengeStr := "+" - if len(challenge) > 0 { - challengeStr = base64.StdEncoding.EncodeToString(challenge) - } - - // TODO: multi-line messages - dc.SendMessage(&irc.Message{ - Prefix: dc.srv.prefix(), - Command: "AUTHENTICATE", - Params: []string{challengeStr}, + Params: []string{"Authentication failed"}, }) + break } + + // Technically we should send RPL_LOGGEDIN here. However we use + // RPL_LOGGEDIN to mirror the upstream connection status. Let's + // see how many clients that breaks. See: + // https://github.com/ircv3/ircv3-specifications/pull/476 + dc.endSASL(nil) case "BOUNCER": var subcommand string if err := parseMessageParams(msg, &subcommand); err != nil { @@ -951,6 +881,107 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error { return nil } +func (dc *downstreamConn) handleAuthenticateCommand(msg *irc.Message) (result *downstreamSASL, err error) { + defer func() { + if err != nil { + dc.sasl = nil + } + }() + + if !dc.caps["sasl"] { + return nil, ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{"*", "AUTHENTICATE requires the \"sasl\" capability to be enabled"}, + }} + } + if len(msg.Params) == 0 { + return nil, ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{"*", "Missing AUTHENTICATE argument"}, + }} + } + if msg.Params[0] == "*" { + return nil, ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLABORTED, + Params: []string{"*", "SASL authentication aborted"}, + }} + } + + var resp []byte + if dc.sasl == nil { + mech := strings.ToUpper(msg.Params[0]) + var server sasl.Server + switch mech { + case "PLAIN": + server = sasl.NewPlainServer(sasl.PlainAuthenticator(func(identity, username, password string) error { + dc.sasl.plainUsername = username + dc.sasl.plainPassword = password + return nil + })) + default: + return nil, ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{"*", fmt.Sprintf("Unsupported SASL mechanism %q", mech)}, + }} + } + + dc.sasl = &downstreamSASL{server: server} + } else { + // TODO: multi-line messages + if msg.Params[0] == "+" { + resp = nil + } else if resp, err = base64.StdEncoding.DecodeString(msg.Params[0]); err != nil { + return nil, ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{"*", "Invalid base64-encoded response"}, + }} + } + } + + challenge, done, err := dc.sasl.server.Next(resp) + if err != nil { + return nil, err + } else if done { + return dc.sasl, nil + } else { + challengeStr := "+" + if len(challenge) > 0 { + challengeStr = base64.StdEncoding.EncodeToString(challenge) + } + + // TODO: multi-line messages + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "AUTHENTICATE", + Params: []string{challengeStr}, + }) + return nil, nil + } +} + +func (dc *downstreamConn) endSASL(msg *irc.Message) { + if dc.sasl == nil { + return + } + + dc.sasl = nil + + if msg != nil { + dc.SendMessage(msg) + } else { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_SASLSUCCESS, + Params: []string{dc.nick, "SASL authentication successful"}, + }) + } +} + func (dc *downstreamConn) setSupportedCap(name, value string) { prevValue, hasPrev := dc.supportedCaps[name] changed := !hasPrev || prevValue != value @@ -1141,9 +1172,8 @@ func (dc *downstreamConn) register(ctx context.Context) error { return fmt.Errorf("tried to register twice") } - if dc.saslServer != nil { - dc.saslServer = nil - dc.SendMessage(&irc.Message{ + if dc.sasl != nil { + dc.endSASL(&irc.Message{ Prefix: dc.srv.prefix(), Command: irc.ERR_SASLABORTED, Params: []string{"*", "SASL authentication aborted"}, @@ -2330,6 +2360,40 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. Command: "INVITE", Params: []string{upstreamUser, upstreamChannel}, }) + case "AUTHENTICATE": + // Post-connection-registration AUTHENTICATE is unsupported in + // multi-upstream mode, or if the upstream doesn't support SASL + uc := dc.upstream() + if uc == nil || !uc.caps["sasl"] { + return ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, "Upstream network authentication not supported"}, + }} + } + + credentials, err := dc.handleAuthenticateCommand(msg) + if err != nil { + return err + } + + if credentials != nil { + if uc.saslClient != nil { + dc.endSASL(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, "Another authentication attempt is already in progress"}, + }) + return nil + } + + uc.logger.Printf("starting post-registration SASL PLAIN authentication with username %q", credentials.plainUsername) + uc.saslClient = sasl.NewPlainClient("", credentials.plainUsername, credentials.plainPassword) + uc.enqueueCommand(dc, &irc.Message{ + Command: "AUTHENTICATE", + Params: []string{"PLAIN"}, + }) + } case "MONITOR": // MONITOR is unsupported in multi-upstream mode uc := dc.upstream() @@ -2700,23 +2764,8 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. func (dc *downstreamConn) handleNickServPRIVMSG(ctx context.Context, uc *upstreamConn, text string) { username, password, ok := parseNickServCredentials(text, uc.nick) - if !ok { - return - } - - // User may have e.g. EXTERNAL mechanism configured. We do not want to - // automatically erase the key pair or any other credentials. - if uc.network.SASL.Mechanism != "" && uc.network.SASL.Mechanism != "PLAIN" { - return - } - - dc.logger.Printf("auto-saving NickServ credentials with username %q", username) - n := uc.network - n.SASL.Mechanism = "PLAIN" - n.SASL.Plain.Username = username - n.SASL.Plain.Password = password - if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &n.Network); err != nil { - dc.logger.Printf("failed to save NickServ credentials: %v", err) + if ok { + uc.network.autoSaveSASLPlain(ctx, username, password) } } diff --git a/upstream.go b/upstream.go index 294dd2b..5d1decf 100644 --- a/upstream.go +++ b/upstream.go @@ -31,6 +31,7 @@ var permanentUpstreamCaps = map[string]bool{ "labeled-response": true, "message-tags": true, "multi-prefix": true, + "sasl": true, "server-time": true, "setname": true, @@ -293,6 +294,12 @@ func (uc *upstreamConn) endPendingCommands() { Command: irc.RPL_ENDOFWHO, Params: []string{dc.nick, mask, "End of /WHO"}, }) + case "AUTHENTICATE": + dc.endSASL(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLABORTED, + Params: []string{dc.nick, "SASL authentication aborted"}, + }) default: panic(fmt.Errorf("Unsupported pending command %q", pendingCmd.msg.Command)) } @@ -311,7 +318,7 @@ func (uc *upstreamConn) sendNextPendingCommand(cmd string) { func (uc *upstreamConn) enqueueCommand(dc *downstreamConn, msg *irc.Message) { switch msg.Command { - case "LIST", "WHO": + case "LIST", "WHO", "AUTHENTICATE": // Supported default: panic(fmt.Errorf("Unsupported pending command %q", msg.Command)) @@ -612,10 +619,20 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { uc.saslClient = nil uc.saslStarted = false - uc.SendMessage(&irc.Message{ - Command: "CAP", - Params: []string{"END"}, - }) + if dc, _ := uc.dequeueCommand("AUTHENTICATE"); dc != nil && dc.sasl != nil { + if msg.Command == irc.RPL_SASLSUCCESS { + uc.network.autoSaveSASLPlain(context.TODO(), dc.sasl.plainUsername, dc.sasl.plainPassword) + } + + dc.endSASL(msg) + } + + if !uc.registered { + uc.SendMessage(&irc.Message{ + Command: "CAP", + Params: []string{"END"}, + }) + } case irc.RPL_WELCOME: uc.registered = true uc.logger.Printf("connection registered") @@ -1704,10 +1721,6 @@ func (uc *upstreamConn) requestCaps() { } } - if uc.requestSASL() && !uc.caps["sasl"] { - requestCaps = append(requestCaps, "sasl") - } - if len(requestCaps) == 0 { return } @@ -1749,6 +1762,9 @@ func (uc *upstreamConn) handleCapAck(name string, ok bool) error { switch name { case "sasl": + if !uc.requestSASL() { + return nil + } if !ok { uc.logger.Printf("server refused to acknowledge the SASL capability") return nil diff --git a/user.go b/user.go index 2d139da..f94cd56 100644 --- a/user.go +++ b/user.go @@ -404,6 +404,22 @@ func (net *network) detachedMessageNeedsRelay(ch *Channel, msg *irc.Message) boo return ch.RelayDetached == FilterMessage || ((ch.RelayDetached == FilterHighlight || ch.RelayDetached == FilterDefault) && highlight) } +func (net *network) autoSaveSASLPlain(ctx context.Context, username, password string) { + // User may have e.g. EXTERNAL mechanism configured. We do not want to + // automatically erase the key pair or any other credentials. + if net.SASL.Mechanism != "" && net.SASL.Mechanism != "PLAIN" { + return + } + + net.logger.Printf("auto-saving SASL PLAIN credentials with username %q", username) + net.SASL.Mechanism = "PLAIN" + net.SASL.Plain.Username = username + net.SASL.Plain.Password = password + if err := net.user.srv.db.StoreNetwork(ctx, net.user.ID, &net.Network); err != nil { + net.logger.Printf("failed to save SASL PLAIN credentials: %v", err) + } +} + type user struct { User srv *Server