Use capRegistry for upstreamConn

This commit is contained in:
Simon Ser 2022-03-14 19:24:39 +01:00
parent 74fd506fef
commit 6e094b1099
2 changed files with 28 additions and 31 deletions

View File

@ -1090,7 +1090,7 @@ func (dc *downstreamConn) updateSupportedCaps() {
} }
dc.forEachUpstream(func(uc *upstreamConn) { dc.forEachUpstream(func(uc *upstreamConn) {
for cap, supported := range supportedCaps { for cap, supported := range supportedCaps {
supportedCaps[cap] = supported && uc.caps[cap] supportedCaps[cap] = supported && uc.caps.IsEnabled(cap)
} }
}) })
@ -1108,10 +1108,10 @@ func (dc *downstreamConn) updateSupportedCaps() {
dc.unsetSupportedCap("sasl") dc.unsetSupportedCap("sasl")
} }
if uc := dc.upstream(); uc != nil && uc.caps["draft/account-registration"] { if uc := dc.upstream(); uc != nil && uc.caps.IsEnabled("draft/account-registration") {
// Strip "before-connect", because we require downstreams to be fully // Strip "before-connect", because we require downstreams to be fully
// connected before attempting account registration. // connected before attempting account registration.
values := strings.Split(uc.supportedCaps["draft/account-registration"], ",") values := strings.Split(uc.caps.Available["draft/account-registration"], ",")
for i, v := range values { for i, v := range values {
if v == "before-connect" { if v == "before-connect" {
values = append(values[:i], values[i+1:]...) values = append(values[:i], values[i+1:]...)
@ -1742,7 +1742,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
dc.forEachNetwork(func(n *network) { dc.forEachNetwork(func(n *network) {
// We only need to call updateNetwork for upstreams that don't // We only need to call updateNetwork for upstreams that don't
// support setname // support setname
if uc := n.conn; uc != nil && uc.caps["setname"] { if uc := n.conn; uc != nil && uc.caps.IsEnabled("setname") {
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "SETNAME", Command: "SETNAME",
Params: []string{realname}, Params: []string{realname},
@ -2443,7 +2443,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
if err != nil { if err != nil {
return err return err
} }
if _, ok := uc.caps["message-tags"]; !ok { if !uc.caps.IsEnabled("message-tags") {
continue continue
} }
@ -2500,7 +2500,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
// Post-connection-registration AUTHENTICATE is unsupported in // Post-connection-registration AUTHENTICATE is unsupported in
// multi-upstream mode, or if the upstream doesn't support SASL // multi-upstream mode, or if the upstream doesn't support SASL
uc := dc.upstream() uc := dc.upstream()
if uc == nil || !uc.caps["sasl"] { if uc == nil || !uc.caps.IsEnabled("sasl") {
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: irc.ERR_SASLFAIL, Command: irc.ERR_SASLFAIL,
Params: []string{dc.nick, "Upstream network authentication not supported"}, Params: []string{dc.nick, "Upstream network authentication not supported"},
@ -2537,7 +2537,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
} }
uc := dc.upstream() uc := dc.upstream()
if uc == nil || !uc.caps["draft/account-registration"] { if uc == nil || !uc.caps.IsEnabled("draft/account-registration") {
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: "FAIL", Command: "FAIL",
Params: []string{msg.Command, "TEMPORARILY_UNAVAILABLE", "*", "Upstream network account registration not supported"}, Params: []string{msg.Command, "TEMPORARILY_UNAVAILABLE", "*", "Upstream network account registration not supported"},

View File

@ -124,8 +124,7 @@ type upstreamConn struct {
realname string realname string
modes userModes modes userModes
channels upstreamChannelCasemapMap channels upstreamChannelCasemapMap
supportedCaps map[string]string caps capRegistry
caps map[string]bool
batches map[string]batch batches map[string]batch
away bool away bool
account string account string
@ -241,8 +240,7 @@ func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, er
network: network, network: network,
user: network.user, user: network.user,
channels: upstreamChannelCasemapMap{newCasemapMap(0)}, channels: upstreamChannelCasemapMap{newCasemapMap(0)},
supportedCaps: make(map[string]string), caps: newCapRegistry(),
caps: make(map[string]bool),
batches: make(map[string]batch), batches: make(map[string]batch),
availableChannelTypes: stdChannelTypes, availableChannelTypes: stdChannelTypes,
availableChannelModes: stdChannelModes, availableChannelModes: stdChannelModes,
@ -563,8 +561,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
caps := strings.Fields(subParams[0]) caps := strings.Fields(subParams[0])
for _, c := range caps { for _, c := range caps {
delete(uc.supportedCaps, c) uc.caps.Del(c)
delete(uc.caps, c)
} }
if uc.registered { if uc.registered {
@ -1824,14 +1821,14 @@ func (uc *upstreamConn) handleSupportedCaps(capsStr string) {
if len(kv) == 2 { if len(kv) == 2 {
v = kv[1] v = kv[1]
} }
uc.supportedCaps[k] = v uc.caps.Available[k] = v
} }
} }
func (uc *upstreamConn) requestCaps() { func (uc *upstreamConn) requestCaps() {
var requestCaps []string var requestCaps []string
for c := range permanentUpstreamCaps { for c := range permanentUpstreamCaps {
if _, ok := uc.supportedCaps[c]; ok && !uc.caps[c] { if uc.caps.IsAvailable(c) && !uc.caps.IsEnabled(c) {
requestCaps = append(requestCaps, c) requestCaps = append(requestCaps, c)
} }
} }
@ -1847,7 +1844,7 @@ func (uc *upstreamConn) requestCaps() {
} }
func (uc *upstreamConn) supportsSASL(mech string) bool { func (uc *upstreamConn) supportsSASL(mech string) bool {
v, ok := uc.supportedCaps["sasl"] v, ok := uc.caps.Available["sasl"]
if !ok { if !ok {
return false return false
} }
@ -1873,7 +1870,7 @@ func (uc *upstreamConn) requestSASL() bool {
} }
func (uc *upstreamConn) handleCapAck(ctx context.Context, name string, ok bool) error { func (uc *upstreamConn) handleCapAck(ctx context.Context, name string, ok bool) error {
uc.caps[name] = ok uc.caps.SetEnabled(name, ok)
switch name { switch name {
case "sasl": case "sasl":
@ -1998,7 +1995,7 @@ func (uc *upstreamConn) readMessages(ch chan<- event) error {
} }
func (uc *upstreamConn) SendMessage(ctx context.Context, msg *irc.Message) { func (uc *upstreamConn) SendMessage(ctx context.Context, msg *irc.Message) {
if !uc.caps["message-tags"] { if !uc.caps.IsEnabled("message-tags") {
msg = msg.Copy() msg = msg.Copy()
msg.Tags = nil msg.Tags = nil
} }
@ -2008,7 +2005,7 @@ func (uc *upstreamConn) SendMessage(ctx context.Context, msg *irc.Message) {
} }
func (uc *upstreamConn) SendMessageLabeled(ctx context.Context, downstreamID uint64, msg *irc.Message) { func (uc *upstreamConn) SendMessageLabeled(ctx context.Context, downstreamID uint64, msg *irc.Message) {
if uc.caps["labeled-response"] { if uc.caps.IsEnabled("labeled-response") {
if msg.Tags == nil { if msg.Tags == nil {
msg.Tags = make(map[string]irc.TagValue) msg.Tags = make(map[string]irc.TagValue)
} }