diff --git a/downstream.go b/downstream.go index a4ed984..7d0e5c8 100644 --- a/downstream.go +++ b/downstream.go @@ -1079,6 +1079,21 @@ func (dc *downstreamConn) updateSupportedCaps() { dc.unsetSupportedCap("sasl") } + if uc := dc.upstream(); uc != nil && uc.caps["draft/account-registration"] { + // Strip "before-connect", because we require downstreams to be fully + // connected before attempting account registration. + values := strings.Split(uc.supportedCaps["draft/account-registration"], ",") + for i, v := range values { + if v == "before-connect" { + values = append(values[:i], values[i+1:]...) + break + } + } + dc.setSupportedCap("draft/account-registration", strings.Join(values, ",")) + } else { + dc.unsetSupportedCap("draft/account-registration") + } + if _, ok := dc.user.msgStore.(chatHistoryMessageStore); ok && dc.network != nil { dc.setSupportedCap("draft/event-playback", "") } else { @@ -2408,7 +2423,6 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. uc := dc.upstream() if uc == nil || !uc.caps["sasl"] { return ircError{&irc.Message{ - Prefix: dc.srv.prefix(), Command: irc.ERR_SASLFAIL, Params: []string{dc.nick, "Upstream network authentication not supported"}, }} @@ -2436,6 +2450,23 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. Params: []string{"PLAIN"}, }) } + case "REGISTER", "VERIFY": + // Check number of params here, since we'll use that to save the + // credentials on command success + if (msg.Command == "REGISTER" && len(msg.Params) < 3) || (msg.Command == "VERIFY" && len(msg.Params) < 2) { + return newNeedMoreParamsError(msg.Command) + } + + uc := dc.upstream() + if uc == nil || !uc.caps["draft/account-registration"] { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{msg.Command, "TEMPORARILY_UNAVAILABLE", "*", "Upstream network account registration not supported"}, + }} + } + + uc.logger.Printf("starting %v with account name %v", msg.Command, msg.Params[0]) + uc.enqueueCommand(dc, msg) case "MONITOR": // MONITOR is unsupported in multi-upstream mode uc := dc.upstream() diff --git a/upstream.go b/upstream.go index a25b3ed..6cc02fc 100644 --- a/upstream.go +++ b/upstream.go @@ -35,7 +35,8 @@ var permanentUpstreamCaps = map[string]bool{ "server-time": true, "setname": true, - "draft/extended-monitor": true, + "draft/account-registration": true, + "draft/extended-monitor": true, } type registrationError string @@ -300,6 +301,12 @@ func (uc *upstreamConn) endPendingCommands() { Command: irc.ERR_SASLABORTED, Params: []string{dc.nick, "SASL authentication aborted"}, }) + case "REGISTER", "VERIFY": + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "FAIL", + Params: []string{pendingCmd.msg.Command, "TEMPORARILY_UNAVAILABLE", pendingCmd.msg.Params[0], "Command aborted"}, + }) default: panic(fmt.Errorf("Unsupported pending command %q", pendingCmd.msg.Command)) } @@ -318,7 +325,7 @@ func (uc *upstreamConn) sendNextPendingCommand(cmd string) { func (uc *upstreamConn) enqueueCommand(dc *downstreamConn, msg *irc.Message) { switch msg.Command { - case "LIST", "WHO", "AUTHENTICATE": + case "LIST", "WHO", "AUTHENTICATE", "REGISTER", "VERIFY": // Supported default: panic(fmt.Errorf("Unsupported pending command %q", msg.Command)) @@ -633,6 +640,21 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { Params: []string{"END"}, }) } + case "REGISTER", "VERIFY": + if dc, cmd := uc.dequeueCommand(msg.Command); dc != nil { + if msg.Command == "REGISTER" { + var account, password string + if err := parseMessageParams(msg, nil, &account); err != nil { + return err + } + if err := parseMessageParams(cmd, nil, nil, &password); err != nil { + return err + } + uc.network.autoSaveSASLPlain(context.TODO(), account, password) + } + + dc.SendMessage(msg) + } case irc.RPL_WELCOME: uc.registered = true uc.logger.Printf("connection registered") @@ -1569,11 +1591,8 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { return err } - if command == "LIST" || command == "WHO" { - dc, _ := uc.dequeueCommand(command) - if dc != nil && downstreamID == 0 { - downstreamID = dc.id - } + if dc, _ := uc.dequeueCommand(command); dc != nil && downstreamID == 0 { + downstreamID = dc.id } uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { @@ -1583,6 +1602,19 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { Params: []string{dc.nick, command, reason}, }) }) + case "FAIL": + var command string + if err := parseMessageParams(msg, &command); err != nil { + return err + } + + if dc, _ := uc.dequeueCommand(command); dc != nil && downstreamID == 0 { + downstreamID = dc.id + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + dc.SendMessage(msg) + }) case "ACK": // Ignore case irc.RPL_NOWAWAY, irc.RPL_UNAWAY: