Add upstream labeled-response capability support
This commit is contained in:
parent
d0917f0fa1
commit
e19f8aaba4
1
irc.go
1
irc.go
@ -198,4 +198,5 @@ type batch struct {
|
||||
Type string
|
||||
Params []string
|
||||
Outer *batch // if not-nil, this batch is nested in Outer
|
||||
Label string
|
||||
}
|
||||
|
54
upstream.go
54
upstream.go
@ -3,6 +3,7 @@ package soju
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@ -52,6 +53,8 @@ type upstreamConn struct {
|
||||
batches map[string]batch
|
||||
|
||||
tagsSupported bool
|
||||
labelsSupported bool
|
||||
nextLabelId uint64
|
||||
|
||||
saslClient sasl.Client
|
||||
saslStarted bool
|
||||
@ -127,6 +130,15 @@ func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
|
||||
})
|
||||
}
|
||||
|
||||
func (uc *upstreamConn) forEachDownstreamById(id uint64, f func(*downstreamConn)) {
|
||||
uc.forEachDownstream(func(dc *downstreamConn) {
|
||||
if id != 0 && id != dc.id {
|
||||
return
|
||||
}
|
||||
f(dc)
|
||||
})
|
||||
}
|
||||
|
||||
func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
|
||||
ch, ok := uc.channels[name]
|
||||
if !ok {
|
||||
@ -152,6 +164,11 @@ func (uc *upstreamConn) parseMembershipPrefix(s string) (membership *membership,
|
||||
}
|
||||
|
||||
func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
|
||||
var label string
|
||||
if l, ok := msg.GetTag("label"); ok {
|
||||
label = l
|
||||
}
|
||||
|
||||
var msgBatch *batch
|
||||
if batchName, ok := msg.GetTag("batch"); ok {
|
||||
b, ok := uc.batches[batchName]
|
||||
@ -159,6 +176,21 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
|
||||
return fmt.Errorf("unexpected batch reference: batch was not defined: %q", batchName)
|
||||
}
|
||||
msgBatch = &b
|
||||
if label == "" {
|
||||
label = msgBatch.Label
|
||||
}
|
||||
}
|
||||
|
||||
var downstreamId uint64 = 0
|
||||
if label != "" {
|
||||
var labelOffset uint64
|
||||
n, err := fmt.Sscanf(label, "sd-%d-%d", &downstreamId, &labelOffset)
|
||||
if err == nil && n < 2 {
|
||||
err = errors.New("not enough arguments")
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("unexpected message label: invalid downstream reference for label %q: %v", label, err)
|
||||
}
|
||||
}
|
||||
|
||||
switch msg.Command {
|
||||
@ -204,7 +236,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
|
||||
}
|
||||
|
||||
requestCaps := make([]string, 0, 16)
|
||||
for _, c := range []string{"message-tags", "batch"} {
|
||||
for _, c := range []string{"message-tags", "batch", "labeled-response"} {
|
||||
if _, ok := uc.caps[c]; ok {
|
||||
requestCaps = append(requestCaps, c)
|
||||
}
|
||||
@ -431,10 +463,15 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
|
||||
if err := parseMessageParams(msg, nil, &batchType); err != nil {
|
||||
return err
|
||||
}
|
||||
label := label
|
||||
if label == "" && msgBatch != nil {
|
||||
label = msgBatch.Label
|
||||
}
|
||||
uc.batches[tag] = batch{
|
||||
Type: batchType,
|
||||
Params: msg.Params[2:],
|
||||
Outer: msgBatch,
|
||||
Label: label,
|
||||
}
|
||||
} else if strings.HasPrefix(tag, "-") {
|
||||
tag = tag[1:]
|
||||
@ -949,6 +986,8 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
|
||||
})
|
||||
case "TAGMSG":
|
||||
// TODO: relay to downstream connections that accept message-tags
|
||||
case "ACK":
|
||||
// Ignore
|
||||
case irc.RPL_YOURHOST, irc.RPL_CREATED:
|
||||
// Ignore
|
||||
case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME:
|
||||
@ -1047,6 +1086,8 @@ func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
|
||||
})
|
||||
case "message-tags":
|
||||
uc.tagsSupported = ok
|
||||
case "labeled-response":
|
||||
uc.labelsSupported = ok
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -1073,3 +1114,14 @@ func (uc *upstreamConn) readMessages(ch chan<- upstreamIncomingMessage) error {
|
||||
func (uc *upstreamConn) SendMessage(msg *irc.Message) {
|
||||
uc.outgoing <- msg
|
||||
}
|
||||
|
||||
func (uc *upstreamConn) SendMessageLabeled(dc *downstreamConn, msg *irc.Message) {
|
||||
if uc.labelsSupported {
|
||||
if msg.Tags == nil {
|
||||
msg.Tags = make(map[string]irc.TagValue)
|
||||
}
|
||||
msg.Tags["label"] = irc.TagValue(fmt.Sprintf("sd-%d-%d", dc.id, uc.nextLabelId))
|
||||
uc.nextLabelId++
|
||||
}
|
||||
uc.SendMessage(msg)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user