Use capRegistry for downstreamConn

This commit is contained in:
Simon Ser 2022-03-14 19:15:35 +01:00
parent 347a4979da
commit 74fd506fef
4 changed files with 58 additions and 64 deletions

View File

@ -19,7 +19,7 @@ func forwardChannel(ctx context.Context, dc *downstreamConn, ch *upstreamChannel
sendTopic(dc, ch) sendTopic(dc, ch)
} }
if dc.caps["soju.im/read"] { if dc.caps.IsEnabled("soju.im/read") {
channelCM := ch.conn.network.casemap(ch.Name) channelCM := ch.conn.network.casemap(ch.Name)
r, err := dc.srv.db.GetReadReceipt(ctx, ch.conn.network.ID, channelCM) r, err := dc.srv.db.GetReadReceipt(ctx, ch.conn.network.ID, channelCM)
if err != nil { if err != nil {

View File

@ -307,8 +307,7 @@ type downstreamConn struct {
negotiatingCaps bool negotiatingCaps bool
capVersion int capVersion int
supportedCaps map[string]string caps capRegistry
caps map[string]bool
sasl *downstreamSASL sasl *downstreamSASL
lastBatchRef uint64 lastBatchRef uint64
@ -325,8 +324,7 @@ func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn {
id: id, id: id,
nick: "*", nick: "*",
nickCM: "*", nickCM: "*",
supportedCaps: make(map[string]string), caps: newCapRegistry(),
caps: make(map[string]bool),
monitored: newCasemapMap(0), monitored: newCasemapMap(0),
} }
dc.monitored.SetCasemapping(casemapASCII) dc.monitored.SetCasemapping(casemapASCII)
@ -335,14 +333,14 @@ func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn {
dc.hostname = host dc.hostname = host
} }
for k, v := range permanentDownstreamCaps { for k, v := range permanentDownstreamCaps {
dc.supportedCaps[k] = v dc.caps.Available[k] = v
} }
dc.supportedCaps["sasl"] = "PLAIN" dc.caps.Available["sasl"] = "PLAIN"
// TODO: this is racy, we should only enable chathistory after // TODO: this is racy, we should only enable chathistory after
// authentication and then check that user.msgStore implements // authentication and then check that user.msgStore implements
// chatHistoryMessageStore // chatHistoryMessageStore
if srv.Config().LogPath != "" { if srv.Config().LogPath != "" {
dc.supportedCaps["draft/chathistory"] = "" dc.caps.Available["draft/chathistory"] = ""
} }
return dc return dc
} }
@ -527,7 +525,7 @@ func (dc *downstreamConn) readMessages(ch chan<- event) error {
// //
// This can only called from the user goroutine. // This can only called from the user goroutine.
func (dc *downstreamConn) SendMessage(msg *irc.Message) { func (dc *downstreamConn) SendMessage(msg *irc.Message) {
if !dc.caps["message-tags"] { if !dc.caps.IsEnabled("message-tags") {
if msg.Command == "TAGMSG" { if msg.Command == "TAGMSG" {
return return
} }
@ -536,32 +534,32 @@ func (dc *downstreamConn) SendMessage(msg *irc.Message) {
supported := false supported := false
switch name { switch name {
case "time": case "time":
supported = dc.caps["server-time"] supported = dc.caps.IsEnabled("server-time")
case "account": case "account":
supported = dc.caps["account"] supported = dc.caps.IsEnabled("account")
} }
if !supported { if !supported {
delete(msg.Tags, name) delete(msg.Tags, name)
} }
} }
} }
if !dc.caps["batch"] && msg.Tags["batch"] != "" { if !dc.caps.IsEnabled("batch") && msg.Tags["batch"] != "" {
msg = msg.Copy() msg = msg.Copy()
delete(msg.Tags, "batch") delete(msg.Tags, "batch")
} }
if msg.Command == "JOIN" && !dc.caps["extended-join"] { if msg.Command == "JOIN" && !dc.caps.IsEnabled("extended-join") {
msg.Params = msg.Params[:1] msg.Params = msg.Params[:1]
} }
if msg.Command == "SETNAME" && !dc.caps["setname"] { if msg.Command == "SETNAME" && !dc.caps.IsEnabled("setname") {
return return
} }
if msg.Command == "AWAY" && !dc.caps["away-notify"] { if msg.Command == "AWAY" && !dc.caps.IsEnabled("away-notify") {
return return
} }
if msg.Command == "ACCOUNT" && !dc.caps["account-notify"] { if msg.Command == "ACCOUNT" && !dc.caps.IsEnabled("account-notify") {
return return
} }
if msg.Command == "READ" && !dc.caps["soju.im/read"] { if msg.Command == "READ" && !dc.caps.IsEnabled("soju.im/read") {
return return
} }
@ -573,7 +571,7 @@ func (dc *downstreamConn) SendBatch(typ string, params []string, tags irc.Tags,
dc.lastBatchRef++ dc.lastBatchRef++
ref := fmt.Sprintf("%v", dc.lastBatchRef) ref := fmt.Sprintf("%v", dc.lastBatchRef)
if dc.caps["batch"] { if dc.caps.IsEnabled("batch") {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Tags: tags, Tags: tags,
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
@ -584,7 +582,7 @@ func (dc *downstreamConn) SendBatch(typ string, params []string, tags irc.Tags,
f(irc.TagValue(ref)) f(irc.TagValue(ref))
if dc.caps["batch"] { if dc.caps.IsEnabled("batch") {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: "BATCH", Command: "BATCH",
@ -597,7 +595,7 @@ func (dc *downstreamConn) SendBatch(typ string, params []string, tags irc.Tags,
func (dc *downstreamConn) sendMessageWithID(msg *irc.Message, id string) { func (dc *downstreamConn) sendMessageWithID(msg *irc.Message, id string) {
dc.SendMessage(msg) dc.SendMessage(msg)
if id == "" || !dc.messageSupportsBacklog(msg) || dc.caps["draft/chathistory"] { if id == "" || !dc.messageSupportsBacklog(msg) || dc.caps.IsEnabled("draft/chathistory") {
return return
} }
@ -608,7 +606,7 @@ func (dc *downstreamConn) sendMessageWithID(msg *irc.Message, id string) {
// sending a message. This is useful e.g. for self-messages when echo-message // sending a message. This is useful e.g. for self-messages when echo-message
// isn't enabled. // isn't enabled.
func (dc *downstreamConn) advanceMessageWithID(msg *irc.Message, id string) { func (dc *downstreamConn) advanceMessageWithID(msg *irc.Message, id string) {
if id == "" || !dc.messageSupportsBacklog(msg) || dc.caps["draft/chathistory"] { if id == "" || !dc.messageSupportsBacklog(msg) || dc.caps.IsEnabled("draft/chathistory") {
return return
} }
@ -829,12 +827,12 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error {
// down the available capabilities when upstreams are // down the available capabilities when upstreams are
// known. // known.
for k, v := range needAllDownstreamCaps { for k, v := range needAllDownstreamCaps {
dc.supportedCaps[k] = v dc.caps.Available[k] = v
} }
} }
caps := make([]string, 0, len(dc.supportedCaps)) caps := make([]string, 0, len(dc.caps.Available))
for k, v := range dc.supportedCaps { for k, v := range dc.caps.Available {
if dc.capVersion >= 302 && v != "" { if dc.capVersion >= 302 && v != "" {
caps = append(caps, k+"="+v) caps = append(caps, k+"="+v)
} else { } else {
@ -851,7 +849,7 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error {
if dc.capVersion >= 302 { if dc.capVersion >= 302 {
// CAP version 302 implicitly enables cap-notify // CAP version 302 implicitly enables cap-notify
dc.caps["cap-notify"] = true dc.caps.SetEnabled("cap-notify", true)
} }
if !dc.registered { if !dc.registered {
@ -859,11 +857,9 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error {
} }
case "LIST": case "LIST":
var caps []string var caps []string
for name, enabled := range dc.caps { for name := range dc.caps.Enabled {
if enabled {
caps = append(caps, name) caps = append(caps, name)
} }
}
// TODO: multi-line replies // TODO: multi-line replies
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
@ -889,12 +885,11 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error {
name = strings.TrimPrefix(name, "-") name = strings.TrimPrefix(name, "-")
} }
if enable == dc.caps[name] { if enable == dc.caps.IsEnabled(name) {
continue continue
} }
_, ok := dc.supportedCaps[name] if !dc.caps.IsAvailable(name) {
if !ok {
ack = false ack = false
break break
} }
@ -905,7 +900,7 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error {
break break
} }
dc.caps[name] = enable dc.caps.SetEnabled(name, enable)
} }
reply := "NAK" reply := "NAK"
@ -939,7 +934,7 @@ func (dc *downstreamConn) handleAuthenticateCommand(msg *irc.Message) (result *d
} }
}() }()
if !dc.caps["sasl"] { if !dc.caps.IsEnabled("sasl") {
return nil, ircError{&irc.Message{ return nil, ircError{&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.ERR_SASLFAIL, Command: irc.ERR_SASLFAIL,
@ -1053,11 +1048,11 @@ func (dc *downstreamConn) endSASL(msg *irc.Message) {
} }
func (dc *downstreamConn) setSupportedCap(name, value string) { func (dc *downstreamConn) setSupportedCap(name, value string) {
prevValue, hasPrev := dc.supportedCaps[name] prevValue, hasPrev := dc.caps.Available[name]
changed := !hasPrev || prevValue != value changed := !hasPrev || prevValue != value
dc.supportedCaps[name] = value dc.caps.Available[name] = value
if !dc.caps["cap-notify"] || !changed { if !dc.caps.IsEnabled("cap-notify") || !changed {
return return
} }
@ -1074,11 +1069,10 @@ func (dc *downstreamConn) setSupportedCap(name, value string) {
} }
func (dc *downstreamConn) unsetSupportedCap(name string) { func (dc *downstreamConn) unsetSupportedCap(name string) {
_, hasPrev := dc.supportedCaps[name] hasPrev := dc.caps.IsAvailable(name)
delete(dc.supportedCaps, name) dc.caps.Del(name)
delete(dc.caps, name)
if !dc.caps["cap-notify"] || !hasPrev { if !dc.caps.IsEnabled("cap-notify") || !hasPrev {
return return
} }
@ -1149,7 +1143,7 @@ func (dc *downstreamConn) updateNick() {
} }
func (dc *downstreamConn) updateRealname() { func (dc *downstreamConn) updateRealname() {
if uc := dc.upstream(); uc != nil && uc.realname != dc.realname && dc.caps["setname"] { if uc := dc.upstream(); uc != nil && uc.realname != dc.realname && dc.caps.IsEnabled("setname") {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.prefix(), Prefix: dc.prefix(),
Command: "SETNAME", Command: "SETNAME",
@ -1169,7 +1163,7 @@ func (dc *downstreamConn) updateAccount() {
return return
} }
if dc.account == account || !dc.caps["sasl"] { if dc.account == account || !dc.caps.IsEnabled("sasl") {
return return
} }
@ -1272,7 +1266,7 @@ func (dc *downstreamConn) register(ctx context.Context) error {
dc.password = "" dc.password = ""
if dc.user == nil { if dc.user == nil {
if password == "" { if password == "" {
if dc.caps["sasl"] { if dc.caps.IsEnabled("sasl") {
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: "FAIL", Command: "FAIL",
Params: []string{"*", "ACCOUNT_REQUIRED", "Authentication required"}, Params: []string{"*", "ACCOUNT_REQUIRED", "Authentication required"},
@ -1374,7 +1368,7 @@ func (dc *downstreamConn) welcome(ctx context.Context) error {
return err return err
} }
if dc.network == nil && !dc.caps["soju.im/bouncer-networks"] && dc.srv.Config().MultiUpstream { if dc.network == nil && !dc.caps.IsEnabled("soju.im/bouncer-networks") && dc.srv.Config().MultiUpstream {
dc.isMultiUpstream = true dc.isMultiUpstream = true
} }
@ -1462,7 +1456,7 @@ func (dc *downstreamConn) welcome(ctx context.Context) error {
}) })
} }
if dc.caps["soju.im/bouncer-networks-notify"] { if dc.caps.IsEnabled("soju.im/bouncer-networks-notify") {
dc.SendBatch("soju.im/bouncer-networks", nil, nil, func(batchRef irc.TagValue) { dc.SendBatch("soju.im/bouncer-networks", nil, nil, func(batchRef irc.TagValue) {
for _, network := range dc.user.networks { for _, network := range dc.user.networks {
idStr := fmt.Sprintf("%v", network.ID) idStr := fmt.Sprintf("%v", network.ID)
@ -1499,7 +1493,7 @@ func (dc *downstreamConn) welcome(ctx context.Context) error {
}) })
dc.forEachNetwork(func(net *network) { dc.forEachNetwork(func(net *network) {
if dc.caps["draft/chathistory"] || dc.user.msgStore == nil { if dc.caps.IsEnabled("draft/chathistory") || dc.user.msgStore == nil {
return return
} }
@ -1549,7 +1543,7 @@ func (dc *downstreamConn) messageSupportsBacklog(msg *irc.Message) bool {
} }
func (dc *downstreamConn) sendTargetBacklog(ctx context.Context, net *network, target, msgID string) { func (dc *downstreamConn) sendTargetBacklog(ctx context.Context, net *network, target, msgID string) {
if dc.caps["draft/chathistory"] || dc.user.msgStore == nil { if dc.caps.IsEnabled("draft/chathistory") || dc.user.msgStore == nil {
return return
} }
@ -2375,7 +2369,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
} }
if msg.Command == "PRIVMSG" && casemapASCII(name) == serviceNickCM { if msg.Command == "PRIVMSG" && casemapASCII(name) == serviceNickCM {
if dc.caps["echo-message"] { if dc.caps.IsEnabled("echo-message") {
echoTags := tags.Copy() echoTags := tags.Copy()
echoTags["time"] = irc.TagValue(formatServerTime(time.Now())) echoTags["time"] = irc.TagValue(formatServerTime(time.Now()))
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
@ -2737,7 +2731,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
}} }}
} }
eventPlayback := dc.caps["draft/event-playback"] eventPlayback := dc.caps.IsEnabled("draft/event-playback")
var history []*irc.Message var history []*irc.Message
switch subcommand { switch subcommand {

View File

@ -1497,7 +1497,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
weAreInvited := uc.isOurNick(nick) weAreInvited := uc.isOurNick(nick)
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
if !weAreInvited && !dc.caps["invite-notify"] { if !weAreInvited && !dc.caps.IsEnabled("invite-notify") {
return return
} }
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
@ -2079,7 +2079,7 @@ func (uc *upstreamConn) produce(target string, msg *irc.Message, origin *downstr
detached := ch != nil && ch.Detached detached := ch != nil && ch.Detached
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
if !detached && (dc != origin || dc.caps["echo-message"]) { if !detached && (dc != origin || dc.caps.IsEnabled("echo-message")) {
dc.sendMessageWithID(dc.marshalMessage(msg, uc.network), msgID) dc.sendMessageWithID(dc.marshalMessage(msg, uc.network), msgID)
} else { } else {
dc.advanceMessageWithID(msg, msgID) dc.advanceMessageWithID(msg, msgID)

14
user.go
View File

@ -562,7 +562,7 @@ func (u *user) run() {
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
dc.updateSupportedCaps() dc.updateSupportedCaps()
if !dc.caps["soju.im/bouncer-networks"] { if !dc.caps.IsEnabled("soju.im/bouncer-networks") {
sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName())) sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName()))
} }
@ -571,7 +571,7 @@ func (u *user) run() {
dc.updateAccount() dc.updateAccount()
}) })
u.forEachDownstream(func(dc *downstreamConn) { u.forEachDownstream(func(dc *downstreamConn) {
if dc.caps["soju.im/bouncer-networks-notify"] { if dc.caps.IsEnabled("soju.im/bouncer-networks-notify") {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: "BOUNCER", Command: "BOUNCER",
@ -751,7 +751,7 @@ func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
} }
u.forEachDownstream(func(dc *downstreamConn) { u.forEachDownstream(func(dc *downstreamConn) {
if dc.caps["soju.im/bouncer-networks-notify"] { if dc.caps.IsEnabled("soju.im/bouncer-networks-notify") {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: "BOUNCER", Command: "BOUNCER",
@ -762,7 +762,7 @@ func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
if uc.network.lastError == nil { if uc.network.lastError == nil {
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
if !dc.caps["soju.im/bouncer-networks"] { if !dc.caps.IsEnabled("soju.im/bouncer-networks") {
sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName())) sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
} }
}) })
@ -872,7 +872,7 @@ func (u *user) createNetwork(ctx context.Context, record *Network) (*network, er
idStr := fmt.Sprintf("%v", network.ID) idStr := fmt.Sprintf("%v", network.ID)
attrs := getNetworkAttrs(network) attrs := getNetworkAttrs(network)
u.forEachDownstream(func(dc *downstreamConn) { u.forEachDownstream(func(dc *downstreamConn) {
if dc.caps["soju.im/bouncer-networks-notify"] { if dc.caps.IsEnabled("soju.im/bouncer-networks-notify") {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: "BOUNCER", Command: "BOUNCER",
@ -953,7 +953,7 @@ func (u *user) updateNetwork(ctx context.Context, record *Network) (*network, er
idStr := fmt.Sprintf("%v", updatedNetwork.ID) idStr := fmt.Sprintf("%v", updatedNetwork.ID)
attrs := getNetworkAttrs(updatedNetwork) attrs := getNetworkAttrs(updatedNetwork)
u.forEachDownstream(func(dc *downstreamConn) { u.forEachDownstream(func(dc *downstreamConn) {
if dc.caps["soju.im/bouncer-networks-notify"] { if dc.caps.IsEnabled("soju.im/bouncer-networks-notify") {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: "BOUNCER", Command: "BOUNCER",
@ -979,7 +979,7 @@ func (u *user) deleteNetwork(ctx context.Context, id int64) error {
idStr := fmt.Sprintf("%v", network.ID) idStr := fmt.Sprintf("%v", network.ID)
u.forEachDownstream(func(dc *downstreamConn) { u.forEachDownstream(func(dc *downstreamConn) {
if dc.caps["soju.im/bouncer-networks-notify"] { if dc.caps.IsEnabled("soju.im/bouncer-networks-notify") {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: "BOUNCER", Command: "BOUNCER",