From e19f8aaba4ea271fbac94e73f142a8bbc953d6e4 Mon Sep 17 00:00:00 2001 From: delthas Date: Mon, 23 Mar 2020 03:21:43 +0100 Subject: [PATCH] Add upstream labeled-response capability support --- irc.go | 1 + upstream.go | 56 +++++++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 55 insertions(+), 2 deletions(-) diff --git a/irc.go b/irc.go index 4623472..1ea6ed7 100644 --- a/irc.go +++ b/irc.go @@ -198,4 +198,5 @@ type batch struct { Type string Params []string Outer *batch // if not-nil, this batch is nested in Outer + Label string } diff --git a/upstream.go b/upstream.go index a84f438..49ccbaf 100644 --- a/upstream.go +++ b/upstream.go @@ -3,6 +3,7 @@ package soju import ( "crypto/tls" "encoding/base64" + "errors" "fmt" "io" "net" @@ -51,7 +52,9 @@ type upstreamConn struct { caps map[string]string batches map[string]batch - tagsSupported bool + tagsSupported bool + labelsSupported bool + nextLabelId uint64 saslClient sasl.Client saslStarted bool @@ -127,6 +130,15 @@ func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) { }) } +func (uc *upstreamConn) forEachDownstreamById(id uint64, f func(*downstreamConn)) { + uc.forEachDownstream(func(dc *downstreamConn) { + if id != 0 && id != dc.id { + return + } + f(dc) + }) +} + func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) { ch, ok := uc.channels[name] if !ok { @@ -152,6 +164,11 @@ func (uc *upstreamConn) parseMembershipPrefix(s string) (membership *membership, } func (uc *upstreamConn) handleMessage(msg *irc.Message) error { + var label string + if l, ok := msg.GetTag("label"); ok { + label = l + } + var msgBatch *batch if batchName, ok := msg.GetTag("batch"); ok { b, ok := uc.batches[batchName] @@ -159,6 +176,21 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { return fmt.Errorf("unexpected batch reference: batch was not defined: %q", batchName) } msgBatch = &b + if label == "" { + label = msgBatch.Label + } + } + + var downstreamId uint64 = 0 + if label != "" { + var labelOffset uint64 + n, err := fmt.Sscanf(label, "sd-%d-%d", &downstreamId, &labelOffset) + if err == nil && n < 2 { + err = errors.New("not enough arguments") + } + if err != nil { + return fmt.Errorf("unexpected message label: invalid downstream reference for label %q: %v", label, err) + } } switch msg.Command { @@ -204,7 +236,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { } requestCaps := make([]string, 0, 16) - for _, c := range []string{"message-tags", "batch"} { + for _, c := range []string{"message-tags", "batch", "labeled-response"} { if _, ok := uc.caps[c]; ok { requestCaps = append(requestCaps, c) } @@ -431,10 +463,15 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { if err := parseMessageParams(msg, nil, &batchType); err != nil { return err } + label := label + if label == "" && msgBatch != nil { + label = msgBatch.Label + } uc.batches[tag] = batch{ Type: batchType, Params: msg.Params[2:], Outer: msgBatch, + Label: label, } } else if strings.HasPrefix(tag, "-") { tag = tag[1:] @@ -949,6 +986,8 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { }) case "TAGMSG": // TODO: relay to downstream connections that accept message-tags + case "ACK": + // Ignore case irc.RPL_YOURHOST, irc.RPL_CREATED: // Ignore case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME: @@ -1047,6 +1086,8 @@ func (uc *upstreamConn) handleCapAck(name string, ok bool) error { }) case "message-tags": uc.tagsSupported = ok + case "labeled-response": + uc.labelsSupported = ok } return nil } @@ -1073,3 +1114,14 @@ func (uc *upstreamConn) readMessages(ch chan<- upstreamIncomingMessage) error { func (uc *upstreamConn) SendMessage(msg *irc.Message) { uc.outgoing <- msg } + +func (uc *upstreamConn) SendMessageLabeled(dc *downstreamConn, msg *irc.Message) { + if uc.labelsSupported { + if msg.Tags == nil { + msg.Tags = make(map[string]irc.TagValue) + } + msg.Tags["label"] = irc.TagValue(fmt.Sprintf("sd-%d-%d", dc.id, uc.nextLabelId)) + uc.nextLabelId++ + } + uc.SendMessage(msg) +}