Add support for the upstream echo-message capability

This adds support for upstream echo-message. This capability is
enabled when the upstream supports labeled-response.

When it is enabled, we don't echo downstream messages in the downstream
handler, but rather wait for the upstream to echo it, to produce it to
downstreams.

When it is disabled, we keep the same behaviour as before: produce the
message to all downstreams as soon as it is received from the
downstream.

In other words, the main functional difference is that when the upstream
supports labeled-response, the client will now receive an echo for its
messages when the server acknowledges them, rather than when soju acks
them.

Additionally, uc.produce was refactored to take an ID rather than a
downstream.
This commit is contained in:
delthas 2022-04-10 18:05:12 +02:00 committed by Simon Ser
parent 12577c10bb
commit abe5291b62
2 changed files with 49 additions and 33 deletions

View File

@ -2527,6 +2527,10 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
Params: upstreamParams, Params: upstreamParams,
}) })
// If the upstream supports echo message, we'll produce the message
// when it is echoed from the upstream.
// Otherwise, produce/log it here because it's the last time we'll see it.
if !uc.caps.IsEnabled("echo-message") {
echoParams := []string{upstreamName} echoParams := []string{upstreamName}
if msg.Command != "TAGMSG" { if msg.Command != "TAGMSG" {
echoParams = append(echoParams, text) echoParams = append(echoParams, text)
@ -2547,7 +2551,8 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
Command: msg.Command, Command: msg.Command,
Params: echoParams, Params: echoParams,
} }
uc.produce(upstreamName, echoMsg, dc) uc.produce(upstreamName, echoMsg, dc.id)
}
uc.updateChannelAutoDetach(upstreamName) uc.updateChannelAutoDetach(upstreamName)
} }

View File

@ -484,15 +484,17 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
} }
if msg.Prefix.User == "" && msg.Prefix.Host == "" { // server message if msg.Prefix.User == "" && msg.Prefix.Host == "" { // server message
uc.produce("", msg, nil) uc.produce("", msg, 0)
} else { // regular user message } else { // regular user message
target := entity target := entity
if uc.isOurNick(target) { if uc.isOurNick(target) {
target = msg.Prefix.Name target = msg.Prefix.Name
} }
self := uc.isOurNick(msg.Prefix.Name)
ch := uc.network.channels.Value(target) ch := uc.network.channels.Value(target)
if ch != nil && msg.Command != "TAGMSG" { if ch != nil && msg.Command != "TAGMSG" && !self {
if ch.Detached { if ch.Detached {
uc.handleDetachedMessage(ctx, ch, msg) uc.handleDetachedMessage(ctx, ch, msg)
} }
@ -503,7 +505,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
} }
} }
uc.produce(target, msg, nil) uc.produce(target, msg, downstreamID)
} }
case "CAP": case "CAP":
var subCmd string var subCmd string
@ -526,7 +528,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
break // wait to receive all capabilities break // wait to receive all capabilities
} }
uc.requestCaps(ctx) uc.updateCaps(ctx)
if uc.requestSASL() { if uc.requestSASL() {
break // we'll send CAP END after authentication is completed break // we'll send CAP END after authentication is completed
@ -563,7 +565,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
return newNeedMoreParamsError(msg.Command) return newNeedMoreParamsError(msg.Command)
} }
uc.handleSupportedCaps(subParams[0]) uc.handleSupportedCaps(subParams[0])
uc.requestCaps(ctx) uc.updateCaps(ctx)
case "DEL": case "DEL":
if len(subParams) < 1 { if len(subParams) < 1 {
return newNeedMoreParamsError(msg.Command) return newNeedMoreParamsError(msg.Command)
@ -1011,7 +1013,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
chMsg := msg.Copy() chMsg := msg.Copy()
chMsg.Params[0] = ch chMsg.Params[0] = ch
uc.produce(ch, chMsg, nil) uc.produce(ch, chMsg, 0)
} }
case "PART": case "PART":
var channels string var channels string
@ -1037,7 +1039,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
chMsg := msg.Copy() chMsg := msg.Copy()
chMsg.Params[0] = ch chMsg.Params[0] = ch
uc.produce(ch, chMsg, nil) uc.produce(ch, chMsg, 0)
} }
case "KICK": case "KICK":
var channel, user string var channel, user string
@ -1056,7 +1058,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
ch.Members.Delete(user) ch.Members.Delete(user)
} }
uc.produce(channel, msg, nil) uc.produce(channel, msg, 0)
case "QUIT": case "QUIT":
if uc.isOurNick(msg.Prefix.Name) { if uc.isOurNick(msg.Prefix.Name) {
uc.logger.Printf("quit") uc.logger.Printf("quit")
@ -1106,7 +1108,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
} else { } else {
ch.Topic = "" ch.Topic = ""
} }
uc.produce(ch.Name, msg, nil) uc.produce(ch.Name, msg, 0)
case "MODE": case "MODE":
var name, modeStr string var name, modeStr string
if err := parseMessageParams(msg, &name, &modeStr); err != nil { if err := parseMessageParams(msg, &name, &modeStr); err != nil {
@ -1851,7 +1853,7 @@ func (uc *upstreamConn) handleSupportedCaps(capsStr string) {
} }
} }
func (uc *upstreamConn) requestCaps(ctx context.Context) { func (uc *upstreamConn) updateCaps(ctx context.Context) {
var requestCaps []string var requestCaps []string
for c := range permanentUpstreamCaps { for c := range permanentUpstreamCaps {
if uc.caps.IsAvailable(c) && !uc.caps.IsEnabled(c) { if uc.caps.IsAvailable(c) && !uc.caps.IsEnabled(c) {
@ -1859,6 +1861,13 @@ func (uc *upstreamConn) requestCaps(ctx context.Context) {
} }
} }
echoMessage := uc.caps.IsAvailable("labeled-response")
if !uc.caps.IsEnabled("echo-message") && echoMessage {
requestCaps = append(requestCaps, "echo-message")
} else if uc.caps.IsEnabled("echo-message") && !echoMessage {
requestCaps = append(requestCaps, "-echo-message")
}
if len(requestCaps) == 0 { if len(requestCaps) == 0 {
return return
} }
@ -1924,6 +1933,7 @@ func (uc *upstreamConn) handleCapAck(ctx context.Context, name string, ok bool)
Command: "AUTHENTICATE", Command: "AUTHENTICATE",
Params: []string{auth.Mechanism}, Params: []string{auth.Mechanism},
}) })
case "echo-message":
default: default:
if permanentUpstreamCaps[name] { if permanentUpstreamCaps[name] {
break break
@ -2089,9 +2099,10 @@ func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) (msgID string
// produce appends a message to the logs and forwards it to connected downstream // produce appends a message to the logs and forwards it to connected downstream
// connections. // connections.
// //
// If origin is not nil and origin doesn't support echo-message, the message is // originID is the id of the downstream (origin) that sent the message. If it is not 0
// forwarded to all connections except origin. // and origin doesn't support echo-message, the message is forwarded to all
func (uc *upstreamConn) produce(target string, msg *irc.Message, origin *downstreamConn) { // connections except origin.
func (uc *upstreamConn) produce(target string, msg *irc.Message, originID uint64) {
var msgID string var msgID string
if target != "" { if target != "" {
msgID = uc.appendLog(target, msg) msgID = uc.appendLog(target, msg)
@ -2102,7 +2113,7 @@ func (uc *upstreamConn) produce(target string, msg *irc.Message, origin *downstr
detached := ch != nil && ch.Detached detached := ch != nil && ch.Detached
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
if !detached && (dc != origin || dc.caps.IsEnabled("echo-message")) { if !detached && (dc.id != originID || dc.caps.IsEnabled("echo-message")) {
dc.sendMessageWithID(dc.marshalMessage(msg, uc.network), msgID) dc.sendMessageWithID(dc.marshalMessage(msg, uc.network), msgID)
} else { } else {
dc.advanceMessageWithID(msg, msgID) dc.advanceMessageWithID(msg, msgID)