Add upstream labeled-response capability support

This commit is contained in:
delthas 2020-03-23 03:21:43 +01:00 committed by Simon Ser
parent d0917f0fa1
commit e19f8aaba4
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
2 changed files with 55 additions and 2 deletions

1
irc.go
View File

@ -198,4 +198,5 @@ type batch struct {
Type string Type string
Params []string Params []string
Outer *batch // if not-nil, this batch is nested in Outer Outer *batch // if not-nil, this batch is nested in Outer
Label string
} }

View File

@ -3,6 +3,7 @@ package soju
import ( import (
"crypto/tls" "crypto/tls"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -51,7 +52,9 @@ type upstreamConn struct {
caps map[string]string caps map[string]string
batches map[string]batch batches map[string]batch
tagsSupported bool tagsSupported bool
labelsSupported bool
nextLabelId uint64
saslClient sasl.Client saslClient sasl.Client
saslStarted bool saslStarted bool
@ -127,6 +130,15 @@ func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
}) })
} }
func (uc *upstreamConn) forEachDownstreamById(id uint64, f func(*downstreamConn)) {
uc.forEachDownstream(func(dc *downstreamConn) {
if id != 0 && id != dc.id {
return
}
f(dc)
})
}
func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) { func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
ch, ok := uc.channels[name] ch, ok := uc.channels[name]
if !ok { if !ok {
@ -152,6 +164,11 @@ func (uc *upstreamConn) parseMembershipPrefix(s string) (membership *membership,
} }
func (uc *upstreamConn) handleMessage(msg *irc.Message) error { func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
var label string
if l, ok := msg.GetTag("label"); ok {
label = l
}
var msgBatch *batch var msgBatch *batch
if batchName, ok := msg.GetTag("batch"); ok { if batchName, ok := msg.GetTag("batch"); ok {
b, ok := uc.batches[batchName] b, ok := uc.batches[batchName]
@ -159,6 +176,21 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
return fmt.Errorf("unexpected batch reference: batch was not defined: %q", batchName) return fmt.Errorf("unexpected batch reference: batch was not defined: %q", batchName)
} }
msgBatch = &b msgBatch = &b
if label == "" {
label = msgBatch.Label
}
}
var downstreamId uint64 = 0
if label != "" {
var labelOffset uint64
n, err := fmt.Sscanf(label, "sd-%d-%d", &downstreamId, &labelOffset)
if err == nil && n < 2 {
err = errors.New("not enough arguments")
}
if err != nil {
return fmt.Errorf("unexpected message label: invalid downstream reference for label %q: %v", label, err)
}
} }
switch msg.Command { switch msg.Command {
@ -204,7 +236,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
} }
requestCaps := make([]string, 0, 16) requestCaps := make([]string, 0, 16)
for _, c := range []string{"message-tags", "batch"} { for _, c := range []string{"message-tags", "batch", "labeled-response"} {
if _, ok := uc.caps[c]; ok { if _, ok := uc.caps[c]; ok {
requestCaps = append(requestCaps, c) requestCaps = append(requestCaps, c)
} }
@ -431,10 +463,15 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
if err := parseMessageParams(msg, nil, &batchType); err != nil { if err := parseMessageParams(msg, nil, &batchType); err != nil {
return err return err
} }
label := label
if label == "" && msgBatch != nil {
label = msgBatch.Label
}
uc.batches[tag] = batch{ uc.batches[tag] = batch{
Type: batchType, Type: batchType,
Params: msg.Params[2:], Params: msg.Params[2:],
Outer: msgBatch, Outer: msgBatch,
Label: label,
} }
} else if strings.HasPrefix(tag, "-") { } else if strings.HasPrefix(tag, "-") {
tag = tag[1:] tag = tag[1:]
@ -949,6 +986,8 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
}) })
case "TAGMSG": case "TAGMSG":
// TODO: relay to downstream connections that accept message-tags // TODO: relay to downstream connections that accept message-tags
case "ACK":
// Ignore
case irc.RPL_YOURHOST, irc.RPL_CREATED: case irc.RPL_YOURHOST, irc.RPL_CREATED:
// Ignore // Ignore
case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME: case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME:
@ -1047,6 +1086,8 @@ func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
}) })
case "message-tags": case "message-tags":
uc.tagsSupported = ok uc.tagsSupported = ok
case "labeled-response":
uc.labelsSupported = ok
} }
return nil return nil
} }
@ -1073,3 +1114,14 @@ func (uc *upstreamConn) readMessages(ch chan<- upstreamIncomingMessage) error {
func (uc *upstreamConn) SendMessage(msg *irc.Message) { func (uc *upstreamConn) SendMessage(msg *irc.Message) {
uc.outgoing <- msg uc.outgoing <- msg
} }
func (uc *upstreamConn) SendMessageLabeled(dc *downstreamConn, msg *irc.Message) {
if uc.labelsSupported {
if msg.Tags == nil {
msg.Tags = make(map[string]irc.TagValue)
}
msg.Tags["label"] = irc.TagValue(fmt.Sprintf("sd-%d-%d", dc.id, uc.nextLabelId))
uc.nextLabelId++
}
uc.SendMessage(msg)
}