From 6e094b1099dc3b1f6e40232524ef9566a5f724cd Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Mon, 14 Mar 2022 19:24:39 +0100 Subject: [PATCH] Use capRegistry for upstreamConn --- downstream.go | 14 +++++++------- upstream.go | 45 +++++++++++++++++++++------------------------ 2 files changed, 28 insertions(+), 31 deletions(-) diff --git a/downstream.go b/downstream.go index 7a41f96..09262d9 100644 --- a/downstream.go +++ b/downstream.go @@ -1090,7 +1090,7 @@ func (dc *downstreamConn) updateSupportedCaps() { } dc.forEachUpstream(func(uc *upstreamConn) { for cap, supported := range supportedCaps { - supportedCaps[cap] = supported && uc.caps[cap] + supportedCaps[cap] = supported && uc.caps.IsEnabled(cap) } }) @@ -1108,10 +1108,10 @@ func (dc *downstreamConn) updateSupportedCaps() { dc.unsetSupportedCap("sasl") } - if uc := dc.upstream(); uc != nil && uc.caps["draft/account-registration"] { + if uc := dc.upstream(); uc != nil && uc.caps.IsEnabled("draft/account-registration") { // Strip "before-connect", because we require downstreams to be fully // connected before attempting account registration. - values := strings.Split(uc.supportedCaps["draft/account-registration"], ",") + values := strings.Split(uc.caps.Available["draft/account-registration"], ",") for i, v := range values { if v == "before-connect" { values = append(values[:i], values[i+1:]...) @@ -1742,7 +1742,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. dc.forEachNetwork(func(n *network) { // We only need to call updateNetwork for upstreams that don't // support setname - if uc := n.conn; uc != nil && uc.caps["setname"] { + if uc := n.conn; uc != nil && uc.caps.IsEnabled("setname") { uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ Command: "SETNAME", Params: []string{realname}, @@ -2443,7 +2443,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. if err != nil { return err } - if _, ok := uc.caps["message-tags"]; !ok { + if !uc.caps.IsEnabled("message-tags") { continue } @@ -2500,7 +2500,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. // Post-connection-registration AUTHENTICATE is unsupported in // multi-upstream mode, or if the upstream doesn't support SASL uc := dc.upstream() - if uc == nil || !uc.caps["sasl"] { + if uc == nil || !uc.caps.IsEnabled("sasl") { return ircError{&irc.Message{ Command: irc.ERR_SASLFAIL, Params: []string{dc.nick, "Upstream network authentication not supported"}, @@ -2537,7 +2537,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. } uc := dc.upstream() - if uc == nil || !uc.caps["draft/account-registration"] { + if uc == nil || !uc.caps.IsEnabled("draft/account-registration") { return ircError{&irc.Message{ Command: "FAIL", Params: []string{msg.Command, "TEMPORARILY_UNAVAILABLE", "*", "Upstream network account registration not supported"}, diff --git a/upstream.go b/upstream.go index 5d70a32..1a8e066 100644 --- a/upstream.go +++ b/upstream.go @@ -117,20 +117,19 @@ type upstreamConn struct { availableMemberships []membership isupport map[string]*string - registered bool - nick string - nickCM string - username string - realname string - modes userModes - channels upstreamChannelCasemapMap - supportedCaps map[string]string - caps map[string]bool - batches map[string]batch - away bool - account string - nextLabelID uint64 - monitored monitorCasemapMap + registered bool + nick string + nickCM string + username string + realname string + modes userModes + channels upstreamChannelCasemapMap + caps capRegistry + batches map[string]batch + away bool + account string + nextLabelID uint64 + monitored monitorCasemapMap saslClient sasl.Client saslStarted bool @@ -241,8 +240,7 @@ func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, er network: network, user: network.user, channels: upstreamChannelCasemapMap{newCasemapMap(0)}, - supportedCaps: make(map[string]string), - caps: make(map[string]bool), + caps: newCapRegistry(), batches: make(map[string]batch), availableChannelTypes: stdChannelTypes, availableChannelModes: stdChannelModes, @@ -563,8 +561,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err caps := strings.Fields(subParams[0]) for _, c := range caps { - delete(uc.supportedCaps, c) - delete(uc.caps, c) + uc.caps.Del(c) } if uc.registered { @@ -1824,14 +1821,14 @@ func (uc *upstreamConn) handleSupportedCaps(capsStr string) { if len(kv) == 2 { v = kv[1] } - uc.supportedCaps[k] = v + uc.caps.Available[k] = v } } func (uc *upstreamConn) requestCaps() { var requestCaps []string for c := range permanentUpstreamCaps { - if _, ok := uc.supportedCaps[c]; ok && !uc.caps[c] { + if uc.caps.IsAvailable(c) && !uc.caps.IsEnabled(c) { requestCaps = append(requestCaps, c) } } @@ -1847,7 +1844,7 @@ func (uc *upstreamConn) requestCaps() { } func (uc *upstreamConn) supportsSASL(mech string) bool { - v, ok := uc.supportedCaps["sasl"] + v, ok := uc.caps.Available["sasl"] if !ok { return false } @@ -1873,7 +1870,7 @@ func (uc *upstreamConn) requestSASL() bool { } func (uc *upstreamConn) handleCapAck(ctx context.Context, name string, ok bool) error { - uc.caps[name] = ok + uc.caps.SetEnabled(name, ok) switch name { case "sasl": @@ -1998,7 +1995,7 @@ func (uc *upstreamConn) readMessages(ch chan<- event) error { } func (uc *upstreamConn) SendMessage(ctx context.Context, msg *irc.Message) { - if !uc.caps["message-tags"] { + if !uc.caps.IsEnabled("message-tags") { msg = msg.Copy() msg.Tags = nil } @@ -2008,7 +2005,7 @@ func (uc *upstreamConn) SendMessage(ctx context.Context, msg *irc.Message) { } func (uc *upstreamConn) SendMessageLabeled(ctx context.Context, downstreamID uint64, msg *irc.Message) { - if uc.caps["labeled-response"] { + if uc.caps.IsEnabled("labeled-response") { if msg.Tags == nil { msg.Tags = make(map[string]irc.TagValue) }