diff --git a/downstream.go b/downstream.go index 3221360..d945805 100644 --- a/downstream.go +++ b/downstream.go @@ -53,12 +53,12 @@ var errAuthFailed = ircError{&irc.Message{ // permanentDownstreamCaps is the list of always-supported downstream // capabilities. var permanentDownstreamCaps = map[string]string{ - "batch": "", - "cap-notify": "", + "batch": "", + "cap-notify": "", "echo-message": "", "message-tags": "", - "sasl": "PLAIN", - "server-time": "", + "sasl": "PLAIN", + "server-time": "", } type downstreamConn struct { @@ -88,10 +88,10 @@ type downstreamConn struct { func newDownstreamConn(srv *Server, netConn net.Conn, id uint64) *downstreamConn { logger := &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())} dc := &downstreamConn{ - conn: *newConn(srv, netConn, logger), - id: id, + conn: *newConn(srv, netConn, logger), + id: id, supportedCaps: make(map[string]string), - caps: make(map[string]bool), + caps: make(map[string]bool), } dc.hostname = netConn.RemoteAddr().String() if host, _, err := net.SplitHostPort(dc.hostname); err == nil { @@ -458,7 +458,7 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error { caps := make([]string, 0, len(dc.supportedCaps)) for k, v := range dc.supportedCaps { if dc.capVersion >= 302 && v != "" { - caps = append(caps, k + "=" + v) + caps = append(caps, k+"="+v) } else { caps = append(caps, k) } @@ -595,6 +595,19 @@ func (dc *downstreamConn) unsetSupportedCap(name string) { }) } +func (dc *downstreamConn) updateSupportedCaps() { + awayNotifySupported := true + dc.forEachUpstream(func(uc *upstreamConn) { + awayNotifySupported = awayNotifySupported && uc.awayNotifySupported + }) + + if awayNotifySupported { + dc.setSupportedCap("away-notify", "") + } else { + dc.unsetSupportedCap("away-notify") + } +} + func sanityCheckServer(addr string) error { dialer := net.Dialer{Timeout: 30 * time.Second} conn, err := tls.DialWithDialer(&dialer, "tcp", addr, nil) diff --git a/upstream.go b/upstream.go index 7167359..8a5f9a6 100644 --- a/upstream.go +++ b/upstream.go @@ -46,13 +46,14 @@ type upstreamConn struct { realname string modes userModes channels map[string]*upstreamChannel - caps map[string]string + caps map[string]string // available capabilities batches map[string]batch away bool - tagsSupported bool - labelsSupported bool - nextLabelID uint64 + tagsSupported bool + awayNotifySupported bool + labelsSupported bool + nextLabelID uint64 saslClient sasl.Client saslStarted bool @@ -317,7 +318,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { } requestCaps := make([]string, 0, 16) - for _, c := range []string{"message-tags", "batch", "labeled-response", "server-time"} { + for _, c := range []string{"message-tags", "batch", "labeled-response", "server-time", "away-notify"} { if _, ok := uc.caps[c]; ok { requestCaps = append(requestCaps, c) } @@ -450,6 +451,10 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { uc.registered = true uc.logger.Printf("connection registered") + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateSupportedCaps() + }) + for _, ch := range uc.network.channels { params := []string{ch.Name} if ch.Key != "" { @@ -1148,6 +1153,21 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { Params: []string{dc.nick, dc.marshalEntity(uc.network, nick), reason}, }) }) + case "AWAY": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + uc.forEachDownstream(func(dc *downstreamConn) { + if !dc.caps["away-notify"] { + return + } + dc.SendMessage(&irc.Message{ + Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix), + Command: "AWAY", + Params: msg.Params, + }) + }) case "TAGMSG": // TODO: relay to downstream connections that accept message-tags case "ACK": @@ -1262,7 +1282,6 @@ func (uc *upstreamConn) requestSASL() bool { } func (uc *upstreamConn) handleCapAck(name string, ok bool) error { - auth := &uc.network.SASL switch name { case "sasl": if !ok { @@ -1270,6 +1289,7 @@ func (uc *upstreamConn) handleCapAck(name string, ok bool) error { return nil } + auth := &uc.network.SASL switch auth.Mechanism { case "PLAIN": uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username) @@ -1286,6 +1306,8 @@ func (uc *upstreamConn) handleCapAck(name string, ok bool) error { uc.tagsSupported = ok case "labeled-response": uc.labelsSupported = ok + case "away-notify": + uc.awayNotifySupported = ok case "batch", "server-time": // Nothing to do default: diff --git a/user.go b/user.go index d504796..508128c 100644 --- a/user.go +++ b/user.go @@ -255,6 +255,7 @@ func (u *user) run() { uc.updateAway() uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateSupportedCaps() sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName())) }) uc.network.lastError = nil @@ -271,6 +272,10 @@ func (u *user) run() { uc.endPendingLISTs(true) + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateSupportedCaps() + }) + if uc.network.lastError == nil { uc.forEachDownstream(func(dc *downstreamConn) { sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName())) @@ -314,6 +319,8 @@ func (u *user) run() { u.forEachUpstream(func(uc *upstreamConn) { uc.updateAway() }) + + dc.updateSupportedCaps() case eventDownstreamDisconnected: dc := e.dc