Add support for away-notify

This makes use of cap-notify to dynamically advertise support for
away-notify. away-notify is advertised to downstream connections if all
upstreams support it.
This commit is contained in:
Simon Ser 2020-04-29 16:28:33 +02:00
parent 394f2853ad
commit 0c549d68c4
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
3 changed files with 56 additions and 14 deletions

View File

@ -458,7 +458,7 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error {
caps := make([]string, 0, len(dc.supportedCaps)) caps := make([]string, 0, len(dc.supportedCaps))
for k, v := range dc.supportedCaps { for k, v := range dc.supportedCaps {
if dc.capVersion >= 302 && v != "" { if dc.capVersion >= 302 && v != "" {
caps = append(caps, k + "=" + v) caps = append(caps, k+"="+v)
} else { } else {
caps = append(caps, k) caps = append(caps, k)
} }
@ -595,6 +595,19 @@ func (dc *downstreamConn) unsetSupportedCap(name string) {
}) })
} }
func (dc *downstreamConn) updateSupportedCaps() {
awayNotifySupported := true
dc.forEachUpstream(func(uc *upstreamConn) {
awayNotifySupported = awayNotifySupported && uc.awayNotifySupported
})
if awayNotifySupported {
dc.setSupportedCap("away-notify", "")
} else {
dc.unsetSupportedCap("away-notify")
}
}
func sanityCheckServer(addr string) error { func sanityCheckServer(addr string) error {
dialer := net.Dialer{Timeout: 30 * time.Second} dialer := net.Dialer{Timeout: 30 * time.Second}
conn, err := tls.DialWithDialer(&dialer, "tcp", addr, nil) conn, err := tls.DialWithDialer(&dialer, "tcp", addr, nil)

View File

@ -46,11 +46,12 @@ type upstreamConn struct {
realname string realname string
modes userModes modes userModes
channels map[string]*upstreamChannel channels map[string]*upstreamChannel
caps map[string]string caps map[string]string // available capabilities
batches map[string]batch batches map[string]batch
away bool away bool
tagsSupported bool tagsSupported bool
awayNotifySupported bool
labelsSupported bool labelsSupported bool
nextLabelID uint64 nextLabelID uint64
@ -317,7 +318,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
} }
requestCaps := make([]string, 0, 16) requestCaps := make([]string, 0, 16)
for _, c := range []string{"message-tags", "batch", "labeled-response", "server-time"} { for _, c := range []string{"message-tags", "batch", "labeled-response", "server-time", "away-notify"} {
if _, ok := uc.caps[c]; ok { if _, ok := uc.caps[c]; ok {
requestCaps = append(requestCaps, c) requestCaps = append(requestCaps, c)
} }
@ -450,6 +451,10 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
uc.registered = true uc.registered = true
uc.logger.Printf("connection registered") uc.logger.Printf("connection registered")
uc.forEachDownstream(func(dc *downstreamConn) {
dc.updateSupportedCaps()
})
for _, ch := range uc.network.channels { for _, ch := range uc.network.channels {
params := []string{ch.Name} params := []string{ch.Name}
if ch.Key != "" { if ch.Key != "" {
@ -1148,6 +1153,21 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
Params: []string{dc.nick, dc.marshalEntity(uc.network, nick), reason}, Params: []string{dc.nick, dc.marshalEntity(uc.network, nick), reason},
}) })
}) })
case "AWAY":
if msg.Prefix == nil {
return fmt.Errorf("expected a prefix")
}
uc.forEachDownstream(func(dc *downstreamConn) {
if !dc.caps["away-notify"] {
return
}
dc.SendMessage(&irc.Message{
Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix),
Command: "AWAY",
Params: msg.Params,
})
})
case "TAGMSG": case "TAGMSG":
// TODO: relay to downstream connections that accept message-tags // TODO: relay to downstream connections that accept message-tags
case "ACK": case "ACK":
@ -1262,7 +1282,6 @@ func (uc *upstreamConn) requestSASL() bool {
} }
func (uc *upstreamConn) handleCapAck(name string, ok bool) error { func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
auth := &uc.network.SASL
switch name { switch name {
case "sasl": case "sasl":
if !ok { if !ok {
@ -1270,6 +1289,7 @@ func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
return nil return nil
} }
auth := &uc.network.SASL
switch auth.Mechanism { switch auth.Mechanism {
case "PLAIN": case "PLAIN":
uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username) uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username)
@ -1286,6 +1306,8 @@ func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
uc.tagsSupported = ok uc.tagsSupported = ok
case "labeled-response": case "labeled-response":
uc.labelsSupported = ok uc.labelsSupported = ok
case "away-notify":
uc.awayNotifySupported = ok
case "batch", "server-time": case "batch", "server-time":
// Nothing to do // Nothing to do
default: default:

View File

@ -255,6 +255,7 @@ func (u *user) run() {
uc.updateAway() uc.updateAway()
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
dc.updateSupportedCaps()
sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName())) sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName()))
}) })
uc.network.lastError = nil uc.network.lastError = nil
@ -271,6 +272,10 @@ func (u *user) run() {
uc.endPendingLISTs(true) uc.endPendingLISTs(true)
uc.forEachDownstream(func(dc *downstreamConn) {
dc.updateSupportedCaps()
})
if uc.network.lastError == nil { if uc.network.lastError == nil {
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName())) sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
@ -314,6 +319,8 @@ func (u *user) run() {
u.forEachUpstream(func(uc *upstreamConn) { u.forEachUpstream(func(uc *upstreamConn) {
uc.updateAway() uc.updateAway()
}) })
dc.updateSupportedCaps()
case eventDownstreamDisconnected: case eventDownstreamDisconnected:
dc := e.dc dc := e.dc