Add upstreamConn.caps

Instead of adding one field per capability, let's just have a map, just
like downstreamConn.
This commit is contained in:
Simon Ser 2020-04-29 19:45:37 +02:00
parent 8445979956
commit c4655f1492
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
2 changed files with 8 additions and 14 deletions

View File

@ -598,7 +598,7 @@ func (dc *downstreamConn) unsetSupportedCap(name string) {
func (dc *downstreamConn) updateSupportedCaps() { func (dc *downstreamConn) updateSupportedCaps() {
awayNotifySupported := true awayNotifySupported := true
dc.forEachUpstream(func(uc *upstreamConn) { dc.forEachUpstream(func(uc *upstreamConn) {
awayNotifySupported = awayNotifySupported && uc.awayNotifySupported awayNotifySupported = awayNotifySupported && uc.caps["away-notify"]
}) })
if awayNotifySupported { if awayNotifySupported {

View File

@ -47,13 +47,10 @@ type upstreamConn struct {
modes userModes modes userModes
channels map[string]*upstreamChannel channels map[string]*upstreamChannel
supportedCaps map[string]string supportedCaps map[string]string
caps map[string]bool
batches map[string]batch batches map[string]batch
away bool away bool
nextLabelID uint64
tagsSupported bool
awayNotifySupported bool
labelsSupported bool
nextLabelID uint64
saslClient sasl.Client saslClient sasl.Client
saslStarted bool saslStarted bool
@ -111,6 +108,7 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
user: network.user, user: network.user,
channels: make(map[string]*upstreamChannel), channels: make(map[string]*upstreamChannel),
supportedCaps: make(map[string]string), supportedCaps: make(map[string]string),
caps: make(map[string]bool),
batches: make(map[string]batch), batches: make(map[string]batch),
availableChannelTypes: stdChannelTypes, availableChannelTypes: stdChannelTypes,
availableChannelModes: stdChannelModes, availableChannelModes: stdChannelModes,
@ -1282,6 +1280,8 @@ func (uc *upstreamConn) requestSASL() bool {
} }
func (uc *upstreamConn) handleCapAck(name string, ok bool) error { func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
uc.caps[name] = ok
switch name { switch name {
case "sasl": case "sasl":
if !ok { if !ok {
@ -1302,13 +1302,7 @@ func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
Command: "AUTHENTICATE", Command: "AUTHENTICATE",
Params: []string{auth.Mechanism}, Params: []string{auth.Mechanism},
}) })
case "message-tags": case "message-tags", "labeled-response", "away-notify", "batch", "server-time":
uc.tagsSupported = ok
case "labeled-response":
uc.labelsSupported = ok
case "away-notify":
uc.awayNotifySupported = ok
case "batch", "server-time":
// Nothing to do // Nothing to do
default: default:
uc.logger.Printf("received CAP ACK/NAK for a cap we don't support: %v", name) uc.logger.Printf("received CAP ACK/NAK for a cap we don't support: %v", name)
@ -1332,7 +1326,7 @@ func (uc *upstreamConn) readMessages(ch chan<- event) error {
} }
func (uc *upstreamConn) SendMessageLabeled(downstreamID uint64, msg *irc.Message) { func (uc *upstreamConn) SendMessageLabeled(downstreamID uint64, msg *irc.Message) {
if uc.labelsSupported { if uc.caps["labeled-response"] {
if msg.Tags == nil { if msg.Tags == nil {
msg.Tags = make(map[string]irc.TagValue) msg.Tags = make(map[string]irc.TagValue)
} }