From 2a569c3b27e1fadfa27f45cd591985fab45cb603 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Thu, 30 Apr 2020 15:27:41 +0200 Subject: [PATCH] Add upstream cap-notify support --- upstream.go | 91 +++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 64 insertions(+), 27 deletions(-) diff --git a/upstream.go b/upstream.go index 56ca018..cf49c57 100644 --- a/upstream.go +++ b/upstream.go @@ -298,40 +298,16 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { if len(subParams) < 1 { return newNeedMoreParamsError(msg.Command) } - caps := strings.Fields(subParams[len(subParams)-1]) + caps := subParams[len(subParams)-1] more := len(subParams) >= 2 && msg.Params[len(subParams)-2] == "*" - for _, s := range caps { - kv := strings.SplitN(s, "=", 2) - k := strings.ToLower(kv[0]) - var v string - if len(kv) == 2 { - v = kv[1] - } - uc.supportedCaps[k] = v - } + uc.handleSupportedCaps(caps) if more { break // wait to receive all capabilities } - requestCaps := make([]string, 0, 16) - for _, c := range []string{"message-tags", "batch", "labeled-response", "server-time", "away-notify"} { - if _, ok := uc.supportedCaps[c]; ok { - requestCaps = append(requestCaps, c) - } - } - - if uc.requestSASL() { - requestCaps = append(requestCaps, "sasl") - } - - if len(requestCaps) > 0 { - uc.SendMessage(&irc.Message{ - Command: "CAP", - Params: []string{"REQ", strings.Join(requestCaps, " ")}, - }) - } + uc.requestCaps() if uc.requestSASL() { break // we'll send CAP END after authentication is completed @@ -359,6 +335,34 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { Params: []string{"END"}, }) } + + if uc.registered { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateSupportedCaps() + }) + } + case "NEW": + if len(subParams) < 1 { + return newNeedMoreParamsError(msg.Command) + } + uc.handleSupportedCaps(subParams[0]) + uc.requestCaps() + case "DEL": + if len(subParams) < 1 { + return newNeedMoreParamsError(msg.Command) + } + caps := strings.Fields(subParams[0]) + + for _, c := range caps { + delete(uc.supportedCaps, c) + delete(uc.caps, c) + } + + if uc.registered { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateSupportedCaps() + }) + } default: uc.logger.Printf("unhandled message: %v", msg) } @@ -1190,6 +1194,39 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { return nil } +func (uc *upstreamConn) handleSupportedCaps(capsStr string) { + caps := strings.Fields(capsStr) + for _, s := range caps { + kv := strings.SplitN(s, "=", 2) + k := strings.ToLower(kv[0]) + var v string + if len(kv) == 2 { + v = kv[1] + } + uc.supportedCaps[k] = v + } +} + +func (uc *upstreamConn) requestCaps() { + var requestCaps []string + for _, c := range []string{"message-tags", "batch", "labeled-response", "server-time", "away-notify"} { + if _, ok := uc.supportedCaps[c]; ok && !uc.caps[c] { + requestCaps = append(requestCaps, c) + } + } + + if uc.requestSASL() && !uc.caps["sasl"] { + requestCaps = append(requestCaps, "sasl") + } + + if len(requestCaps) > 0 { + uc.SendMessage(&irc.Message{ + Command: "CAP", + Params: []string{"REQ", strings.Join(requestCaps, " ")}, + }) + } +} + func splitSpace(s string) []string { return strings.FieldsFunc(s, func(r rune) bool { return r == ' '