diff --git a/downstream.go b/downstream.go index 23a99fa..f1ed174 100644 --- a/downstream.go +++ b/downstream.go @@ -228,6 +228,7 @@ var passthroughIsupport = map[string]bool{ "MAXLIST": true, "MAXTARGETS": true, "MODES": true, + "MONITOR": true, "NAMELEN": true, "NETWORK": true, "NICKLEN": true, @@ -264,6 +265,8 @@ type downstreamConn struct { lastBatchRef uint64 + monitored casemapMap + saslServer sasl.Server } @@ -276,6 +279,7 @@ func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn { id: id, supportedCaps: make(map[string]string), caps: make(map[string]bool), + monitored: newCasemapMap(0), } dc.hostname = remoteAddr if host, _, err := net.SplitHostPort(dc.hostname); err == nil { @@ -2253,6 +2257,89 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { Command: "INVITE", Params: []string{upstreamUser, upstreamChannel}, }) + case "MONITOR": + // MONITOR is unsupported in multi-upstream mode + uc := dc.upstream() + if uc == nil { + return newUnknownCommandError(msg.Command) + } + + var subcommand string + if err := parseMessageParams(msg, &subcommand); err != nil { + return err + } + + switch strings.ToUpper(subcommand) { + case "+", "-": + var targets string + if err := parseMessageParams(msg, nil, &targets); err != nil { + return err + } + for _, target := range strings.Split(targets, ",") { + if subcommand == "+" { + // Hard limit, just to avoid having downstreams fill our map + if len(dc.monitored.innerMap) >= 1000 { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_MONLISTFULL, + Params: []string{dc.nick, "1000", target, "Bouncer monitor list is full"}, + }) + continue + } + + dc.monitored.SetValue(target, nil) + + if uc.monitored.Has(target) { + cmd := irc.RPL_MONOFFLINE + if online := uc.monitored.Value(target); online { + cmd = irc.RPL_MONONLINE + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: cmd, + Params: []string{dc.nick, target}, + }) + } + } else { + dc.monitored.Delete(target) + } + } + uc.updateMonitor() + case "C": // clear + dc.monitored = newCasemapMap(0) + uc.updateMonitor() + case "L": // list + // TODO: be less lazy and pack the list + for _, entry := range dc.monitored.innerMap { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_MONLIST, + Params: []string{dc.nick, entry.originalKey}, + }) + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFMONLIST, + Params: []string{dc.nick, "End of MONITOR list"}, + }) + case "S": // status + // TODO: be less lazy and pack the lists + for _, entry := range dc.monitored.innerMap { + target := entry.originalKey + + cmd := irc.RPL_MONOFFLINE + if online := uc.monitored.Value(target); online { + cmd = irc.RPL_MONONLINE + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: cmd, + Params: []string{dc.nick, target}, + }) + } + } case "CHATHISTORY": var subcommand string if err := parseMessageParams(msg, &subcommand); err != nil { diff --git a/irc.go b/irc.go index ae3ff47..5355f1a 100644 --- a/irc.go +++ b/irc.go @@ -408,6 +408,36 @@ func generateMOTD(prefix *irc.Prefix, nick string, motd string) []*irc.Message { return msgs } +func generateMonitor(subcmd string, targets []string) []*irc.Message { + maxLength := maxMessageLength - len("MONITOR "+subcmd+" ") + + var msgs []*irc.Message + var buf []string + n := 0 + for _, target := range targets { + if n+len(target)+1 > maxLength { + msgs = append(msgs, &irc.Message{ + Command: "MONITOR", + Params: []string{subcmd, strings.Join(buf, ",")}, + }) + buf = buf[:0] + n = 0 + } + + buf = append(buf, target) + n += len(target) + 1 + } + + if len(buf) > 0 { + msgs = append(msgs, &irc.Message{ + Command: "MONITOR", + Params: []string{subcmd, strings.Join(buf, ",")}, + }) + } + + return msgs +} + type joinSorter struct { channels []string keys []string @@ -634,6 +664,16 @@ func (cm *deliveredCasemapMap) Value(name string) deliveredClientMap { return entry.value.(deliveredClientMap) } +type monitorCasemapMap struct{ casemapMap } + +func (cm *monitorCasemapMap) Value(name string) (online bool) { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return false + } + return entry.value.(bool) +} + func isWordBoundary(r rune) bool { switch r { case '-', '_', '|': diff --git a/upstream.go b/upstream.go index ae9843f..ee8a1c6 100644 --- a/upstream.go +++ b/upstream.go @@ -103,6 +103,7 @@ type upstreamConn struct { away bool account string nextLabelID uint64 + monitored monitorCasemapMap saslClient sasl.Client saslStarted bool @@ -209,6 +210,7 @@ func connectToUpstream(network *network) (*upstreamConn, error) { availableMemberships: stdMemberships, isupport: make(map[string]*string), pendingCmds: make(map[string][]pendingUpstreamCommand), + monitored: monitorCasemapMap{newCasemapMap(0)}, } return uc, nil } @@ -1413,6 +1415,49 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { Params: []string{dc.nick, dc.marshalEntity(uc.network, nick), dc.marshalEntity(uc.network, channel)}, }) }) + case irc.RPL_MONONLINE, irc.RPL_MONOFFLINE: + var targetsStr string + if err := parseMessageParams(msg, nil, &targetsStr); err != nil { + return err + } + targets := strings.Split(targetsStr, ",") + + online := msg.Command == irc.RPL_MONONLINE + for _, target := range targets { + prefix := irc.ParsePrefix(target) + uc.monitored.SetValue(prefix.Name, online) + } + + uc.forEachDownstream(func(dc *downstreamConn) { + for _, target := range targets { + prefix := irc.ParsePrefix(target) + if dc.monitored.Has(prefix.Name) { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: msg.Command, + Params: []string{dc.nick, target}, + }) + } + } + }) + case irc.ERR_MONLISTFULL: + var limit, targetsStr string + if err := parseMessageParams(msg, nil, &limit, &targetsStr); err != nil { + return err + } + + targets := strings.Split(targetsStr, ",") + uc.forEachDownstream(func(dc *downstreamConn) { + for _, target := range targets { + if dc.monitored.Has(target) { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: msg.Command, + Params: []string{dc.nick, limit, target}, + }) + } + } + }) case irc.RPL_AWAY: var nick, reason string if err := parseMessageParams(msg, nil, &nick, &reason); err != nil { @@ -1912,3 +1957,52 @@ func (uc *upstreamConn) updateChannelAutoDetach(name string) { } uch.updateAutoDetach(ch.DetachAfter) } + +func (uc *upstreamConn) updateMonitor() { + add := make(map[string]struct{}) + var addList []string + seen := make(map[string]struct{}) + uc.forEachDownstream(func(dc *downstreamConn) { + for targetCM := range dc.monitored.innerMap { + if !uc.monitored.Has(targetCM) { + if _, ok := add[targetCM]; !ok { + addList = append(addList, targetCM) + } + add[targetCM] = struct{}{} + } else { + seen[targetCM] = struct{}{} + } + } + }) + + removeAll := true + var removeList []string + for targetCM, entry := range uc.monitored.innerMap { + if _, ok := seen[targetCM]; ok { + removeAll = false + } else { + removeList = append(removeList, entry.originalKey) + } + } + + // TODO: better handle the case where len(uc.monitored) + len(addList) + // exceeds the limit, probably by immediately sending ERR_MONLISTFULL? + + if removeAll && len(addList) == 0 && len(removeList) > 0 { + // Optimization when the last MONITOR-aware downstream disconnects + uc.SendMessage(&irc.Message{ + Command: "MONITOR", + Params: []string{"C"}, + }) + } else { + msgs := generateMonitor("-", removeList) + msgs = append(msgs, generateMonitor("+", addList)...) + for _, msg := range msgs { + uc.SendMessage(msg) + } + } + + for _, target := range removeList { + uc.monitored.Delete(target) + } +} diff --git a/user.go b/user.go index 0de7d6d..276d527 100644 --- a/user.go +++ b/user.go @@ -342,13 +342,17 @@ func (net *network) updateCasemapping(newCasemap casemapping) { net.casemap = newCasemap net.channels.SetCasemapping(newCasemap) net.delivered.m.SetCasemapping(newCasemap) - if net.conn != nil { - net.conn.channels.SetCasemapping(newCasemap) - for _, entry := range net.conn.channels.innerMap { + if uc := net.conn; uc != nil { + uc.channels.SetCasemapping(newCasemap) + for _, entry := range uc.channels.innerMap { uch := entry.value.(*upstreamChannel) uch.Members.SetCasemapping(newCasemap) } + uc.monitored.SetCasemapping(newCasemap) } + net.forEachDownstream(func(dc *downstreamConn) { + dc.monitored.SetCasemapping(newCasemap) + }) } func (net *network) storeClientDeliveryReceipts(clientName string) { @@ -519,6 +523,7 @@ func (u *user) run() { uc.network.conn = uc uc.updateAway() + uc.updateMonitor() netIDStr := fmt.Sprintf("%v", uc.network.ID) uc.forEachDownstream(func(dc *downstreamConn) { @@ -588,6 +593,10 @@ func (u *user) run() { case eventDownstreamConnected: dc := e.dc + if dc.network != nil { + dc.monitored.SetCasemapping(dc.network.casemap) + } + if err := dc.welcome(); err != nil { dc.logger.Printf("failed to handle new registered connection: %v", err) break @@ -620,6 +629,7 @@ func (u *user) run() { u.forEachUpstream(func(uc *upstreamConn) { uc.updateAway() + uc.updateMonitor() }) case eventDownstreamMessage: msg, dc := e.msg, e.dc