From 74fd506fef9b796e0d75087fe3bb76815b8a308f Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Mon, 14 Mar 2022 19:15:35 +0100 Subject: [PATCH] Use capRegistry for downstreamConn --- bridge.go | 2 +- downstream.go | 102 ++++++++++++++++++++++++-------------------------- upstream.go | 4 +- user.go | 14 +++---- 4 files changed, 58 insertions(+), 64 deletions(-) diff --git a/bridge.go b/bridge.go index fda02ff..7553ac9 100644 --- a/bridge.go +++ b/bridge.go @@ -19,7 +19,7 @@ func forwardChannel(ctx context.Context, dc *downstreamConn, ch *upstreamChannel sendTopic(dc, ch) } - if dc.caps["soju.im/read"] { + if dc.caps.IsEnabled("soju.im/read") { channelCM := ch.conn.network.casemap(ch.Name) r, err := dc.srv.db.GetReadReceipt(ctx, ch.conn.network.ID, channelCM) if err != nil { diff --git a/downstream.go b/downstream.go index 0a2852f..7a41f96 100644 --- a/downstream.go +++ b/downstream.go @@ -307,8 +307,7 @@ type downstreamConn struct { negotiatingCaps bool capVersion int - supportedCaps map[string]string - caps map[string]bool + caps capRegistry sasl *downstreamSASL lastBatchRef uint64 @@ -321,13 +320,12 @@ func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn { logger := &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", remoteAddr)} options := connOptions{Logger: logger} dc := &downstreamConn{ - conn: *newConn(srv, ic, &options), - id: id, - nick: "*", - nickCM: "*", - supportedCaps: make(map[string]string), - caps: make(map[string]bool), - monitored: newCasemapMap(0), + conn: *newConn(srv, ic, &options), + id: id, + nick: "*", + nickCM: "*", + caps: newCapRegistry(), + monitored: newCasemapMap(0), } dc.monitored.SetCasemapping(casemapASCII) dc.hostname = remoteAddr @@ -335,14 +333,14 @@ func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn { dc.hostname = host } for k, v := range permanentDownstreamCaps { - dc.supportedCaps[k] = v + dc.caps.Available[k] = v } - dc.supportedCaps["sasl"] = "PLAIN" + dc.caps.Available["sasl"] = "PLAIN" // TODO: this is racy, we should only enable chathistory after // authentication and then check that user.msgStore implements // chatHistoryMessageStore if srv.Config().LogPath != "" { - dc.supportedCaps["draft/chathistory"] = "" + dc.caps.Available["draft/chathistory"] = "" } return dc } @@ -527,7 +525,7 @@ func (dc *downstreamConn) readMessages(ch chan<- event) error { // // This can only called from the user goroutine. func (dc *downstreamConn) SendMessage(msg *irc.Message) { - if !dc.caps["message-tags"] { + if !dc.caps.IsEnabled("message-tags") { if msg.Command == "TAGMSG" { return } @@ -536,32 +534,32 @@ func (dc *downstreamConn) SendMessage(msg *irc.Message) { supported := false switch name { case "time": - supported = dc.caps["server-time"] + supported = dc.caps.IsEnabled("server-time") case "account": - supported = dc.caps["account"] + supported = dc.caps.IsEnabled("account") } if !supported { delete(msg.Tags, name) } } } - if !dc.caps["batch"] && msg.Tags["batch"] != "" { + if !dc.caps.IsEnabled("batch") && msg.Tags["batch"] != "" { msg = msg.Copy() delete(msg.Tags, "batch") } - if msg.Command == "JOIN" && !dc.caps["extended-join"] { + if msg.Command == "JOIN" && !dc.caps.IsEnabled("extended-join") { msg.Params = msg.Params[:1] } - if msg.Command == "SETNAME" && !dc.caps["setname"] { + if msg.Command == "SETNAME" && !dc.caps.IsEnabled("setname") { return } - if msg.Command == "AWAY" && !dc.caps["away-notify"] { + if msg.Command == "AWAY" && !dc.caps.IsEnabled("away-notify") { return } - if msg.Command == "ACCOUNT" && !dc.caps["account-notify"] { + if msg.Command == "ACCOUNT" && !dc.caps.IsEnabled("account-notify") { return } - if msg.Command == "READ" && !dc.caps["soju.im/read"] { + if msg.Command == "READ" && !dc.caps.IsEnabled("soju.im/read") { return } @@ -573,7 +571,7 @@ func (dc *downstreamConn) SendBatch(typ string, params []string, tags irc.Tags, dc.lastBatchRef++ ref := fmt.Sprintf("%v", dc.lastBatchRef) - if dc.caps["batch"] { + if dc.caps.IsEnabled("batch") { dc.SendMessage(&irc.Message{ Tags: tags, Prefix: dc.srv.prefix(), @@ -584,7 +582,7 @@ func (dc *downstreamConn) SendBatch(typ string, params []string, tags irc.Tags, f(irc.TagValue(ref)) - if dc.caps["batch"] { + if dc.caps.IsEnabled("batch") { dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), Command: "BATCH", @@ -597,7 +595,7 @@ func (dc *downstreamConn) SendBatch(typ string, params []string, tags irc.Tags, func (dc *downstreamConn) sendMessageWithID(msg *irc.Message, id string) { dc.SendMessage(msg) - if id == "" || !dc.messageSupportsBacklog(msg) || dc.caps["draft/chathistory"] { + if id == "" || !dc.messageSupportsBacklog(msg) || dc.caps.IsEnabled("draft/chathistory") { return } @@ -608,7 +606,7 @@ func (dc *downstreamConn) sendMessageWithID(msg *irc.Message, id string) { // sending a message. This is useful e.g. for self-messages when echo-message // isn't enabled. func (dc *downstreamConn) advanceMessageWithID(msg *irc.Message, id string) { - if id == "" || !dc.messageSupportsBacklog(msg) || dc.caps["draft/chathistory"] { + if id == "" || !dc.messageSupportsBacklog(msg) || dc.caps.IsEnabled("draft/chathistory") { return } @@ -829,12 +827,12 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error { // down the available capabilities when upstreams are // known. for k, v := range needAllDownstreamCaps { - dc.supportedCaps[k] = v + dc.caps.Available[k] = v } } - caps := make([]string, 0, len(dc.supportedCaps)) - for k, v := range dc.supportedCaps { + caps := make([]string, 0, len(dc.caps.Available)) + for k, v := range dc.caps.Available { if dc.capVersion >= 302 && v != "" { caps = append(caps, k+"="+v) } else { @@ -851,7 +849,7 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error { if dc.capVersion >= 302 { // CAP version 302 implicitly enables cap-notify - dc.caps["cap-notify"] = true + dc.caps.SetEnabled("cap-notify", true) } if !dc.registered { @@ -859,10 +857,8 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error { } case "LIST": var caps []string - for name, enabled := range dc.caps { - if enabled { - caps = append(caps, name) - } + for name := range dc.caps.Enabled { + caps = append(caps, name) } // TODO: multi-line replies @@ -889,12 +885,11 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error { name = strings.TrimPrefix(name, "-") } - if enable == dc.caps[name] { + if enable == dc.caps.IsEnabled(name) { continue } - _, ok := dc.supportedCaps[name] - if !ok { + if !dc.caps.IsAvailable(name) { ack = false break } @@ -905,7 +900,7 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error { break } - dc.caps[name] = enable + dc.caps.SetEnabled(name, enable) } reply := "NAK" @@ -939,7 +934,7 @@ func (dc *downstreamConn) handleAuthenticateCommand(msg *irc.Message) (result *d } }() - if !dc.caps["sasl"] { + if !dc.caps.IsEnabled("sasl") { return nil, ircError{&irc.Message{ Prefix: dc.srv.prefix(), Command: irc.ERR_SASLFAIL, @@ -1053,11 +1048,11 @@ func (dc *downstreamConn) endSASL(msg *irc.Message) { } func (dc *downstreamConn) setSupportedCap(name, value string) { - prevValue, hasPrev := dc.supportedCaps[name] + prevValue, hasPrev := dc.caps.Available[name] changed := !hasPrev || prevValue != value - dc.supportedCaps[name] = value + dc.caps.Available[name] = value - if !dc.caps["cap-notify"] || !changed { + if !dc.caps.IsEnabled("cap-notify") || !changed { return } @@ -1074,11 +1069,10 @@ func (dc *downstreamConn) setSupportedCap(name, value string) { } func (dc *downstreamConn) unsetSupportedCap(name string) { - _, hasPrev := dc.supportedCaps[name] - delete(dc.supportedCaps, name) - delete(dc.caps, name) + hasPrev := dc.caps.IsAvailable(name) + dc.caps.Del(name) - if !dc.caps["cap-notify"] || !hasPrev { + if !dc.caps.IsEnabled("cap-notify") || !hasPrev { return } @@ -1149,7 +1143,7 @@ func (dc *downstreamConn) updateNick() { } func (dc *downstreamConn) updateRealname() { - if uc := dc.upstream(); uc != nil && uc.realname != dc.realname && dc.caps["setname"] { + if uc := dc.upstream(); uc != nil && uc.realname != dc.realname && dc.caps.IsEnabled("setname") { dc.SendMessage(&irc.Message{ Prefix: dc.prefix(), Command: "SETNAME", @@ -1169,7 +1163,7 @@ func (dc *downstreamConn) updateAccount() { return } - if dc.account == account || !dc.caps["sasl"] { + if dc.account == account || !dc.caps.IsEnabled("sasl") { return } @@ -1272,7 +1266,7 @@ func (dc *downstreamConn) register(ctx context.Context) error { dc.password = "" if dc.user == nil { if password == "" { - if dc.caps["sasl"] { + if dc.caps.IsEnabled("sasl") { return ircError{&irc.Message{ Command: "FAIL", Params: []string{"*", "ACCOUNT_REQUIRED", "Authentication required"}, @@ -1374,7 +1368,7 @@ func (dc *downstreamConn) welcome(ctx context.Context) error { return err } - if dc.network == nil && !dc.caps["soju.im/bouncer-networks"] && dc.srv.Config().MultiUpstream { + if dc.network == nil && !dc.caps.IsEnabled("soju.im/bouncer-networks") && dc.srv.Config().MultiUpstream { dc.isMultiUpstream = true } @@ -1462,7 +1456,7 @@ func (dc *downstreamConn) welcome(ctx context.Context) error { }) } - if dc.caps["soju.im/bouncer-networks-notify"] { + if dc.caps.IsEnabled("soju.im/bouncer-networks-notify") { dc.SendBatch("soju.im/bouncer-networks", nil, nil, func(batchRef irc.TagValue) { for _, network := range dc.user.networks { idStr := fmt.Sprintf("%v", network.ID) @@ -1499,7 +1493,7 @@ func (dc *downstreamConn) welcome(ctx context.Context) error { }) dc.forEachNetwork(func(net *network) { - if dc.caps["draft/chathistory"] || dc.user.msgStore == nil { + if dc.caps.IsEnabled("draft/chathistory") || dc.user.msgStore == nil { return } @@ -1549,7 +1543,7 @@ func (dc *downstreamConn) messageSupportsBacklog(msg *irc.Message) bool { } func (dc *downstreamConn) sendTargetBacklog(ctx context.Context, net *network, target, msgID string) { - if dc.caps["draft/chathistory"] || dc.user.msgStore == nil { + if dc.caps.IsEnabled("draft/chathistory") || dc.user.msgStore == nil { return } @@ -2375,7 +2369,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. } if msg.Command == "PRIVMSG" && casemapASCII(name) == serviceNickCM { - if dc.caps["echo-message"] { + if dc.caps.IsEnabled("echo-message") { echoTags := tags.Copy() echoTags["time"] = irc.TagValue(formatServerTime(time.Now())) dc.SendMessage(&irc.Message{ @@ -2737,7 +2731,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. }} } - eventPlayback := dc.caps["draft/event-playback"] + eventPlayback := dc.caps.IsEnabled("draft/event-playback") var history []*irc.Message switch subcommand { diff --git a/upstream.go b/upstream.go index 9e76514..5d70a32 100644 --- a/upstream.go +++ b/upstream.go @@ -1497,7 +1497,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err weAreInvited := uc.isOurNick(nick) uc.forEachDownstream(func(dc *downstreamConn) { - if !weAreInvited && !dc.caps["invite-notify"] { + if !weAreInvited && !dc.caps.IsEnabled("invite-notify") { return } dc.SendMessage(&irc.Message{ @@ -2079,7 +2079,7 @@ func (uc *upstreamConn) produce(target string, msg *irc.Message, origin *downstr detached := ch != nil && ch.Detached uc.forEachDownstream(func(dc *downstreamConn) { - if !detached && (dc != origin || dc.caps["echo-message"]) { + if !detached && (dc != origin || dc.caps.IsEnabled("echo-message")) { dc.sendMessageWithID(dc.marshalMessage(msg, uc.network), msgID) } else { dc.advanceMessageWithID(msg, msgID) diff --git a/user.go b/user.go index 0364ff3..b90a4f2 100644 --- a/user.go +++ b/user.go @@ -562,7 +562,7 @@ func (u *user) run() { uc.forEachDownstream(func(dc *downstreamConn) { dc.updateSupportedCaps() - if !dc.caps["soju.im/bouncer-networks"] { + if !dc.caps.IsEnabled("soju.im/bouncer-networks") { sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName())) } @@ -571,7 +571,7 @@ func (u *user) run() { dc.updateAccount() }) u.forEachDownstream(func(dc *downstreamConn) { - if dc.caps["soju.im/bouncer-networks-notify"] { + if dc.caps.IsEnabled("soju.im/bouncer-networks-notify") { dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), Command: "BOUNCER", @@ -751,7 +751,7 @@ func (u *user) handleUpstreamDisconnected(uc *upstreamConn) { } u.forEachDownstream(func(dc *downstreamConn) { - if dc.caps["soju.im/bouncer-networks-notify"] { + if dc.caps.IsEnabled("soju.im/bouncer-networks-notify") { dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), Command: "BOUNCER", @@ -762,7 +762,7 @@ func (u *user) handleUpstreamDisconnected(uc *upstreamConn) { if uc.network.lastError == nil { uc.forEachDownstream(func(dc *downstreamConn) { - if !dc.caps["soju.im/bouncer-networks"] { + if !dc.caps.IsEnabled("soju.im/bouncer-networks") { sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName())) } }) @@ -872,7 +872,7 @@ func (u *user) createNetwork(ctx context.Context, record *Network) (*network, er idStr := fmt.Sprintf("%v", network.ID) attrs := getNetworkAttrs(network) u.forEachDownstream(func(dc *downstreamConn) { - if dc.caps["soju.im/bouncer-networks-notify"] { + if dc.caps.IsEnabled("soju.im/bouncer-networks-notify") { dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), Command: "BOUNCER", @@ -953,7 +953,7 @@ func (u *user) updateNetwork(ctx context.Context, record *Network) (*network, er idStr := fmt.Sprintf("%v", updatedNetwork.ID) attrs := getNetworkAttrs(updatedNetwork) u.forEachDownstream(func(dc *downstreamConn) { - if dc.caps["soju.im/bouncer-networks-notify"] { + if dc.caps.IsEnabled("soju.im/bouncer-networks-notify") { dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), Command: "BOUNCER", @@ -979,7 +979,7 @@ func (u *user) deleteNetwork(ctx context.Context, id int64) error { idStr := fmt.Sprintf("%v", network.ID) u.forEachDownstream(func(dc *downstreamConn) { - if dc.caps["soju.im/bouncer-networks-notify"] { + if dc.caps.IsEnabled("soju.im/bouncer-networks-notify") { dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), Command: "BOUNCER",