Use capRegistry for upstreamConn

This commit is contained in:
Simon Ser 2022-03-14 19:24:39 +01:00
parent 74fd506fef
commit 6e094b1099
2 changed files with 28 additions and 31 deletions

View File

@ -1090,7 +1090,7 @@ func (dc *downstreamConn) updateSupportedCaps() {
} }
dc.forEachUpstream(func(uc *upstreamConn) { dc.forEachUpstream(func(uc *upstreamConn) {
for cap, supported := range supportedCaps { for cap, supported := range supportedCaps {
supportedCaps[cap] = supported && uc.caps[cap] supportedCaps[cap] = supported && uc.caps.IsEnabled(cap)
} }
}) })
@ -1108,10 +1108,10 @@ func (dc *downstreamConn) updateSupportedCaps() {
dc.unsetSupportedCap("sasl") dc.unsetSupportedCap("sasl")
} }
if uc := dc.upstream(); uc != nil && uc.caps["draft/account-registration"] { if uc := dc.upstream(); uc != nil && uc.caps.IsEnabled("draft/account-registration") {
// Strip "before-connect", because we require downstreams to be fully // Strip "before-connect", because we require downstreams to be fully
// connected before attempting account registration. // connected before attempting account registration.
values := strings.Split(uc.supportedCaps["draft/account-registration"], ",") values := strings.Split(uc.caps.Available["draft/account-registration"], ",")
for i, v := range values { for i, v := range values {
if v == "before-connect" { if v == "before-connect" {
values = append(values[:i], values[i+1:]...) values = append(values[:i], values[i+1:]...)
@ -1742,7 +1742,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
dc.forEachNetwork(func(n *network) { dc.forEachNetwork(func(n *network) {
// We only need to call updateNetwork for upstreams that don't // We only need to call updateNetwork for upstreams that don't
// support setname // support setname
if uc := n.conn; uc != nil && uc.caps["setname"] { if uc := n.conn; uc != nil && uc.caps.IsEnabled("setname") {
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "SETNAME", Command: "SETNAME",
Params: []string{realname}, Params: []string{realname},
@ -2443,7 +2443,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
if err != nil { if err != nil {
return err return err
} }
if _, ok := uc.caps["message-tags"]; !ok { if !uc.caps.IsEnabled("message-tags") {
continue continue
} }
@ -2500,7 +2500,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
// Post-connection-registration AUTHENTICATE is unsupported in // Post-connection-registration AUTHENTICATE is unsupported in
// multi-upstream mode, or if the upstream doesn't support SASL // multi-upstream mode, or if the upstream doesn't support SASL
uc := dc.upstream() uc := dc.upstream()
if uc == nil || !uc.caps["sasl"] { if uc == nil || !uc.caps.IsEnabled("sasl") {
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: irc.ERR_SASLFAIL, Command: irc.ERR_SASLFAIL,
Params: []string{dc.nick, "Upstream network authentication not supported"}, Params: []string{dc.nick, "Upstream network authentication not supported"},
@ -2537,7 +2537,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
} }
uc := dc.upstream() uc := dc.upstream()
if uc == nil || !uc.caps["draft/account-registration"] { if uc == nil || !uc.caps.IsEnabled("draft/account-registration") {
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: "FAIL", Command: "FAIL",
Params: []string{msg.Command, "TEMPORARILY_UNAVAILABLE", "*", "Upstream network account registration not supported"}, Params: []string{msg.Command, "TEMPORARILY_UNAVAILABLE", "*", "Upstream network account registration not supported"},

View File

@ -117,20 +117,19 @@ type upstreamConn struct {
availableMemberships []membership availableMemberships []membership
isupport map[string]*string isupport map[string]*string
registered bool registered bool
nick string nick string
nickCM string nickCM string
username string username string
realname string realname string
modes userModes modes userModes
channels upstreamChannelCasemapMap channels upstreamChannelCasemapMap
supportedCaps map[string]string caps capRegistry
caps map[string]bool batches map[string]batch
batches map[string]batch away bool
away bool account string
account string nextLabelID uint64
nextLabelID uint64 monitored monitorCasemapMap
monitored monitorCasemapMap
saslClient sasl.Client saslClient sasl.Client
saslStarted bool saslStarted bool
@ -241,8 +240,7 @@ func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, er
network: network, network: network,
user: network.user, user: network.user,
channels: upstreamChannelCasemapMap{newCasemapMap(0)}, channels: upstreamChannelCasemapMap{newCasemapMap(0)},
supportedCaps: make(map[string]string), caps: newCapRegistry(),
caps: make(map[string]bool),
batches: make(map[string]batch), batches: make(map[string]batch),
availableChannelTypes: stdChannelTypes, availableChannelTypes: stdChannelTypes,
availableChannelModes: stdChannelModes, availableChannelModes: stdChannelModes,
@ -563,8 +561,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
caps := strings.Fields(subParams[0]) caps := strings.Fields(subParams[0])
for _, c := range caps { for _, c := range caps {
delete(uc.supportedCaps, c) uc.caps.Del(c)
delete(uc.caps, c)
} }
if uc.registered { if uc.registered {
@ -1824,14 +1821,14 @@ func (uc *upstreamConn) handleSupportedCaps(capsStr string) {
if len(kv) == 2 { if len(kv) == 2 {
v = kv[1] v = kv[1]
} }
uc.supportedCaps[k] = v uc.caps.Available[k] = v
} }
} }
func (uc *upstreamConn) requestCaps() { func (uc *upstreamConn) requestCaps() {
var requestCaps []string var requestCaps []string
for c := range permanentUpstreamCaps { for c := range permanentUpstreamCaps {
if _, ok := uc.supportedCaps[c]; ok && !uc.caps[c] { if uc.caps.IsAvailable(c) && !uc.caps.IsEnabled(c) {
requestCaps = append(requestCaps, c) requestCaps = append(requestCaps, c)
} }
} }
@ -1847,7 +1844,7 @@ func (uc *upstreamConn) requestCaps() {
} }
func (uc *upstreamConn) supportsSASL(mech string) bool { func (uc *upstreamConn) supportsSASL(mech string) bool {
v, ok := uc.supportedCaps["sasl"] v, ok := uc.caps.Available["sasl"]
if !ok { if !ok {
return false return false
} }
@ -1873,7 +1870,7 @@ func (uc *upstreamConn) requestSASL() bool {
} }
func (uc *upstreamConn) handleCapAck(ctx context.Context, name string, ok bool) error { func (uc *upstreamConn) handleCapAck(ctx context.Context, name string, ok bool) error {
uc.caps[name] = ok uc.caps.SetEnabled(name, ok)
switch name { switch name {
case "sasl": case "sasl":
@ -1998,7 +1995,7 @@ func (uc *upstreamConn) readMessages(ch chan<- event) error {
} }
func (uc *upstreamConn) SendMessage(ctx context.Context, msg *irc.Message) { func (uc *upstreamConn) SendMessage(ctx context.Context, msg *irc.Message) {
if !uc.caps["message-tags"] { if !uc.caps.IsEnabled("message-tags") {
msg = msg.Copy() msg = msg.Copy()
msg.Tags = nil msg.Tags = nil
} }
@ -2008,7 +2005,7 @@ func (uc *upstreamConn) SendMessage(ctx context.Context, msg *irc.Message) {
} }
func (uc *upstreamConn) SendMessageLabeled(ctx context.Context, downstreamID uint64, msg *irc.Message) { func (uc *upstreamConn) SendMessageLabeled(ctx context.Context, downstreamID uint64, msg *irc.Message) {
if uc.caps["labeled-response"] { if uc.caps.IsEnabled("labeled-response") {
if msg.Tags == nil { if msg.Tags == nil {
msg.Tags = make(map[string]irc.TagValue) msg.Tags = make(map[string]irc.TagValue)
} }