Use capRegistry for upstreamConn

This commit is contained in:
Simon Ser 2022-03-14 19:24:39 +01:00
parent 74fd506fef
commit 6e094b1099
2 changed files with 28 additions and 31 deletions

View File

@ -1090,7 +1090,7 @@ func (dc *downstreamConn) updateSupportedCaps() {
}
dc.forEachUpstream(func(uc *upstreamConn) {
for cap, supported := range supportedCaps {
supportedCaps[cap] = supported && uc.caps[cap]
supportedCaps[cap] = supported && uc.caps.IsEnabled(cap)
}
})
@ -1108,10 +1108,10 @@ func (dc *downstreamConn) updateSupportedCaps() {
dc.unsetSupportedCap("sasl")
}
if uc := dc.upstream(); uc != nil && uc.caps["draft/account-registration"] {
if uc := dc.upstream(); uc != nil && uc.caps.IsEnabled("draft/account-registration") {
// Strip "before-connect", because we require downstreams to be fully
// connected before attempting account registration.
values := strings.Split(uc.supportedCaps["draft/account-registration"], ",")
values := strings.Split(uc.caps.Available["draft/account-registration"], ",")
for i, v := range values {
if v == "before-connect" {
values = append(values[:i], values[i+1:]...)
@ -1742,7 +1742,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
dc.forEachNetwork(func(n *network) {
// We only need to call updateNetwork for upstreams that don't
// support setname
if uc := n.conn; uc != nil && uc.caps["setname"] {
if uc := n.conn; uc != nil && uc.caps.IsEnabled("setname") {
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "SETNAME",
Params: []string{realname},
@ -2443,7 +2443,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
if err != nil {
return err
}
if _, ok := uc.caps["message-tags"]; !ok {
if !uc.caps.IsEnabled("message-tags") {
continue
}
@ -2500,7 +2500,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
// Post-connection-registration AUTHENTICATE is unsupported in
// multi-upstream mode, or if the upstream doesn't support SASL
uc := dc.upstream()
if uc == nil || !uc.caps["sasl"] {
if uc == nil || !uc.caps.IsEnabled("sasl") {
return ircError{&irc.Message{
Command: irc.ERR_SASLFAIL,
Params: []string{dc.nick, "Upstream network authentication not supported"},
@ -2537,7 +2537,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
}
uc := dc.upstream()
if uc == nil || !uc.caps["draft/account-registration"] {
if uc == nil || !uc.caps.IsEnabled("draft/account-registration") {
return ircError{&irc.Message{
Command: "FAIL",
Params: []string{msg.Command, "TEMPORARILY_UNAVAILABLE", "*", "Upstream network account registration not supported"},

View File

@ -124,8 +124,7 @@ type upstreamConn struct {
realname string
modes userModes
channels upstreamChannelCasemapMap
supportedCaps map[string]string
caps map[string]bool
caps capRegistry
batches map[string]batch
away bool
account string
@ -241,8 +240,7 @@ func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, er
network: network,
user: network.user,
channels: upstreamChannelCasemapMap{newCasemapMap(0)},
supportedCaps: make(map[string]string),
caps: make(map[string]bool),
caps: newCapRegistry(),
batches: make(map[string]batch),
availableChannelTypes: stdChannelTypes,
availableChannelModes: stdChannelModes,
@ -563,8 +561,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
caps := strings.Fields(subParams[0])
for _, c := range caps {
delete(uc.supportedCaps, c)
delete(uc.caps, c)
uc.caps.Del(c)
}
if uc.registered {
@ -1824,14 +1821,14 @@ func (uc *upstreamConn) handleSupportedCaps(capsStr string) {
if len(kv) == 2 {
v = kv[1]
}
uc.supportedCaps[k] = v
uc.caps.Available[k] = v
}
}
func (uc *upstreamConn) requestCaps() {
var requestCaps []string
for c := range permanentUpstreamCaps {
if _, ok := uc.supportedCaps[c]; ok && !uc.caps[c] {
if uc.caps.IsAvailable(c) && !uc.caps.IsEnabled(c) {
requestCaps = append(requestCaps, c)
}
}
@ -1847,7 +1844,7 @@ func (uc *upstreamConn) requestCaps() {
}
func (uc *upstreamConn) supportsSASL(mech string) bool {
v, ok := uc.supportedCaps["sasl"]
v, ok := uc.caps.Available["sasl"]
if !ok {
return false
}
@ -1873,7 +1870,7 @@ func (uc *upstreamConn) requestSASL() bool {
}
func (uc *upstreamConn) handleCapAck(ctx context.Context, name string, ok bool) error {
uc.caps[name] = ok
uc.caps.SetEnabled(name, ok)
switch name {
case "sasl":
@ -1998,7 +1995,7 @@ func (uc *upstreamConn) readMessages(ch chan<- event) error {
}
func (uc *upstreamConn) SendMessage(ctx context.Context, msg *irc.Message) {
if !uc.caps["message-tags"] {
if !uc.caps.IsEnabled("message-tags") {
msg = msg.Copy()
msg.Tags = nil
}
@ -2008,7 +2005,7 @@ func (uc *upstreamConn) SendMessage(ctx context.Context, msg *irc.Message) {
}
func (uc *upstreamConn) SendMessageLabeled(ctx context.Context, downstreamID uint64, msg *irc.Message) {
if uc.caps["labeled-response"] {
if uc.caps.IsEnabled("labeled-response") {
if msg.Tags == nil {
msg.Tags = make(map[string]irc.TagValue)
}