diff --git a/downstream.go b/downstream.go index 4ba6792..456f347 100644 --- a/downstream.go +++ b/downstream.go @@ -2079,16 +2079,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. options = msg.Params[1] } - optionsParts := strings.SplitN(options, "%", 2) - // TODO: add support for WHOX flags in optionsParts[0] - var fields, whoxToken string - if len(optionsParts) == 2 { - optionsParts := strings.SplitN(optionsParts[1], ",", 2) - fields = strings.ToLower(optionsParts[0]) - if len(optionsParts) == 2 && strings.Contains(fields, "t") { - whoxToken = optionsParts[1] - } - } + fields, whoxToken := xirc.ParseWHOXOptions(options) // TODO: support mixed bouncer/upstream WHO queries maskCM := casemapASCII(mask) @@ -2157,6 +2148,29 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. return err } + // Check if we have the reply cached + if l, ok := uc.getCachedWHO(mask, fields); ok { + for _, uu := range l { + info := xirc.WHOXInfo{ + Token: whoxToken, + Username: uu.Username, + Hostname: uu.Hostname, + Server: uu.Server, + Nickname: uu.Nickname, + Flags: uu.Flags, + Account: uu.Account, + Realname: uu.Realname, + } + dc.SendMessage(xirc.GenerateWHOXReply(dc.srv.prefix(), dc.nick, fields, &info)) + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHO, + Params: []string{dc.nick, endOfWhoToken, "End of /WHO list"}, + }) + return nil + } + uc.enqueueCommand(dc, msg) case "WHOIS": if len(msg.Params) == 0 { diff --git a/irc.go b/irc.go index 464d9f8..6e10f6c 100644 --- a/irc.go +++ b/irc.go @@ -379,6 +379,20 @@ func (cm *upstreamChannelCasemapMap) ForEach(f func(*upstreamChannel)) { } } +type upstreamUserCasemapMap struct{ casemapMap } + +func (cm *upstreamUserCasemapMap) Get(name string) *upstreamUser { + if v := cm.get(name); v == nil { + return nil + } else { + return v.(*upstreamUser) + } +} + +func (cm *upstreamUserCasemapMap) Set(u *upstreamUser) { + cm.set(u.Nickname, u) +} + type channelCasemapMap struct{ casemapMap } func (cm *channelCasemapMap) Get(name string) *database.Channel { diff --git a/upstream.go b/upstream.go index 5e47dc7..0b43fb8 100644 --- a/upstream.go +++ b/upstream.go @@ -111,6 +111,69 @@ type upstreamBatch struct { Label string } +type upstreamUser struct { + Nickname string + Username string + Hostname string + Server string + Flags string + Account string + Realname string +} + +func (uu *upstreamUser) hasWHOXFields(fields string) bool { + for i := 0; i < len(fields); i++ { + ok := false + switch fields[i] { + case 'n': + ok = uu.Nickname != "" + case 'u': + ok = uu.Username != "" + case 'h': + ok = uu.Hostname != "" + case 's': + ok = uu.Server != "" + case 'f': + ok = uu.Flags != "" + case 'a': + ok = uu.Account != "" + case 'r': + ok = uu.Realname != "" + case 't', 'c', 'i', 'd', 'l', 'o': + // we return static values for those fields, so they are always available + ok = true + } + if !ok { + return false + } + } + return true +} + +func (uu *upstreamUser) updateFrom(update *upstreamUser) { + if update.Nickname != "" { + uu.Nickname = update.Nickname + } + if update.Username != "" { + uu.Username = update.Username + } + if update.Hostname != "" { + uu.Hostname = update.Hostname + } + if update.Server != "" { + uu.Server = update.Server + } + if update.Flags != "" { + uu.Flags = update.Flags + } + if update.Account != "" { + uu.Account = update.Account + } + if update.Realname != "" { + uu.Realname = update.Realname + } +} + type pendingUpstreamCommand struct { downstreamID uint64 msg *irc.Message @@ -138,6 +201,7 @@ type upstreamConn struct { hostname string modes userModes channels upstreamChannelCasemapMap + users upstreamUserCasemapMap caps xirc.CapRegistry batches map[string]upstreamBatch away bool @@ -263,6 +327,7 @@ func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, er network: network, user: network.user, channels: upstreamChannelCasemapMap{newCasemapMap()}, + users: upstreamUserCasemapMap{newCasemapMap()}, caps: xirc.NewCapRegistry(), batches: make(map[string]upstreamBatch), serverPrefix: &irc.Prefix{Name: "*"}, @@ -973,6 +1038,10 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err } }) + uc.cacheUserInfo(msg.Prefix.Name, &upstreamUser{ + Nickname: newNick, + }) + if !me { uc.forEachDownstream(func(dc *downstreamConn) { dc.SendMessage(msg) @@ -989,6 +1058,10 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err return err } + uc.cacheUserInfo(msg.Prefix.Name, &upstreamUser{ + Realname: newRealname, + }) + // TODO: consider appending this message to logs if uc.isOurNick(msg.Prefix.Name) { @@ -1035,6 +1108,30 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err return err } + uu := &upstreamUser{ + Username: msg.Prefix.User, + Hostname: msg.Prefix.Host, + } + if uc.caps.IsEnabled("away-notify") { + // we have enough info to build the user flags in a best-effort manner: + // - the H/G flag is set to Here first, will be replaced by Gone later if the user is AWAY + uu.Flags = "H" + // - the B (bot mode) flag is set if the JOIN comes from a bot + // note: we have no way to track the user bot mode after they have joined + // (we are not notified of the bot mode updates), but this is good enough. + if _, ok := msg.Tags["bot"]; ok { + if bot := uc.isupport["BOT"]; bot != nil { + uu.Flags += *bot + } + } + // TODO: add the server operator flag (`*`) if the message has an oper-tag + } + if len(msg.Params) > 2 { // extended-join + uu.Account = msg.Params[1] + uu.Realname = msg.Params[2] + } + uc.cacheUserInfo(msg.Prefix.Name, uu) + for _, ch := range strings.Split(channels, ",") { if uc.isOurNick(msg.Prefix.Name) { uc.logger.Printf("joined channel %q", ch) @@ -1075,6 +1172,11 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err if uch := uc.channels.Get(ch); uch != nil { uc.channels.Del(ch) uch.updateAutoDetach(0) + uch.Members.ForEach(func(nick string, memberships *xirc.MembershipSet) { + if !uc.shouldCacheUserInfo(nick) { + uc.users.Del(nick) + } + }) } } else { ch, err := uc.getChannel(ch) @@ -1082,6 +1184,9 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err return err } ch.Members.Del(msg.Prefix.Name) + if !uc.shouldCacheUserInfo(msg.Prefix.Name) { + uc.users.Del(msg.Prefix.Name) + } } chMsg := msg.Copy() @@ -1096,13 +1201,23 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err if uc.isOurNick(user) { uc.logger.Printf("kicked from channel %q by %s", channel, msg.Prefix.Name) - uc.channels.Del(channel) + if uch := uc.channels.Get(channel); uch != nil { + uc.channels.Del(channel) + uch.Members.ForEach(func(nick string, memberships *xirc.MembershipSet) { + if !uc.shouldCacheUserInfo(nick) { + uc.users.Del(nick) + } + }) + } } else { ch, err := uc.getChannel(channel) if err != nil { return err } ch.Members.Del(user) + if !uc.shouldCacheUserInfo(user) { + uc.users.Del(user) + } } uc.produce(channel, msg, 0) @@ -1118,6 +1233,8 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err } }) + uc.users.Del(msg.Prefix.Name) + if msg.Prefix.Name != uc.nick { uc.forEachDownstream(func(dc *downstreamConn) { dc.SendMessage(msg) @@ -1358,15 +1475,68 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err forwardChannel(ctx, dc, ch) }) } - case irc.RPL_WHOREPLY, xirc.RPL_WHOSPCRPL: + case irc.RPL_WHOREPLY: + var username, host, server, nick, flags, trailing string + if err := parseMessageParams(msg, nil, nil, &username, &host, &server, &nick, &flags, &trailing); err != nil { + return err + } + dc, cmd := uc.currentPendingCommand("WHO") if cmd == nil { - return fmt.Errorf("unexpected WHO reply %v: no matching pending WHO", msg.Command) + return fmt.Errorf("unexpected RPL_WHOREPLY: no matching pending WHO") + } else if dc == nil { + return nil + } + + parts := strings.SplitN(trailing, " ", 2) + if len(parts) != 2 { + return fmt.Errorf("malformed RPL_WHOREPLY: failed to parse real name") + } + realname := parts[1] + + dc.SendMessage(msg) + + if uc.shouldCacheUserInfo(nick) { + uc.cacheUserInfo(nick, &upstreamUser{ + Username: username, + Hostname: host, + Server: server, + Nickname: nick, + Flags: flags, + Realname: realname, + }) + } + case xirc.RPL_WHOSPCRPL: + dc, cmd := uc.currentPendingCommand("WHO") + if cmd == nil { + return fmt.Errorf("unexpected RPL_WHOSPCRPL: no matching pending WHO") } else if dc == nil { return nil } dc.SendMessage(msg) + + if len(cmd.Params) > 1 { + fields, _ := xirc.ParseWHOXOptions(cmd.Params[1]) + if strings.IndexByte(fields, 'n') < 0 { + return nil + } + info, err := xirc.ParseWHOXReply(msg, fields) + if err != nil { + return err + } + if uc.shouldCacheUserInfo(info.Nickname) { + uc.cacheUserInfo(info.Nickname, &upstreamUser{ + Nickname: info.Nickname, + Username: info.Username, + Hostname: info.Hostname, + Server: info.Server, + Flags: info.Flags, + Account: info.Account, + Realname: info.Realname, + }) + } + } case irc.RPL_ENDOFWHO: dc, cmd := uc.dequeueCommand("WHO") if cmd == nil { @@ -1490,7 +1660,32 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { dc.SendMessage(msg) }) - case "AWAY", "ACCOUNT": + case "AWAY": + // Update user flags, if we already have the flags cached + uu := uc.users.Get(msg.Prefix.Name) + if uu != nil && uu.Flags != "" { + flags := uu.Flags + if isAway := len(msg.Params) > 0; isAway { + flags = strings.ReplaceAll(flags, "H", "G") + } else { + flags = strings.ReplaceAll(flags, "G", "H") + } + uc.cacheUserInfo(msg.Prefix.Name, &upstreamUser{ + Flags: flags, + }) + } + + uc.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(msg) + }) + case "ACCOUNT": + var account string + if err := parseMessageParams(msg, &account); err != nil { + return err + } + uc.cacheUserInfo(msg.Prefix.Name, &upstreamUser{ + Account: account, + }) uc.forEachDownstream(func(dc *downstreamConn) { dc.SendMessage(msg) }) @@ -2066,6 +2261,9 @@ func (uc *upstreamConn) updateMonitor() { for _, target := range removeList { uc.monitored.Del(target) + if !uc.shouldCacheUserInfo(target) { + uc.users.Del(target) + } } } @@ -2132,3 +2330,87 @@ func (uc *upstreamConn) tryRegainNick(nick string) { }) uc.pendingRegainNick = wantNick } + +func (uc *upstreamConn) getCachedWHO(mask, fields string) (l []*upstreamUser, ok bool) { + // Non-extended WHO fields + if fields == "" { + fields = "cuhsnfdr" + } + + // Some extensions are required to keep our cached state in sync. We could + // require setname for 'r' and chghost for 'h'/'s', but servers usually + // implement a QUIT/JOIN fallback, so let's not bother. + + // TODO: Avoid storing fields we cannot keep up to date, instead of storing them + // then failing here. eg if we don't have account-notify, avoid storing the ACCOUNT + // in the first place. + if strings.IndexByte(fields, 'a') >= 0 && !uc.caps.IsEnabled("account-notify") { + return nil, false + } + if strings.IndexByte(fields, 'f') >= 0 && !uc.caps.IsEnabled("away-notify") { + return nil, false + } + + if uu := uc.users.Get(mask); uu != nil { + if uu.hasWHOXFields(fields) { + return []*upstreamUser{uu}, true + } + } else if uch := uc.channels.Get(mask); uch != nil { + l = make([]*upstreamUser, 0, uch.Members.Len()) + ok = true + uch.Members.ForEach(func(nick string, membershipSet *xirc.MembershipSet) { + if !ok { + return + } + uu := uc.users.Get(nick) + if uu == nil || !uu.hasWHOXFields(fields) { + ok = false + } else { + l = append(l, uu) + } + }) + if !ok { + return nil, false + } + return l, true + } + + return nil, false +} + +func (uc *upstreamConn) cacheUserInfo(nick string, info *upstreamUser) { + if nick == "" { + panic("cacheUserInfo called with empty nickname") + } + + uu := uc.users.Get(nick) + if uu == nil { + if info.Nickname != "" { + nick = info.Nickname + } else { + info.Nickname = nick + } + uc.users.Set(info) + } else { + uu.updateFrom(info) + if info.Nickname != "" && nick != info.Nickname { + uc.users.Del(nick) + uc.users.Set(uu) + } + } +} + +func (uc *upstreamConn) shouldCacheUserInfo(nick string) bool { + if uc.isOurNick(nick) { + return true + } + // keep the cached user info only if we MONITOR it, or we share a channel with them + if uc.monitored.Has(nick) { + return true + } + found := false + uc.channels.ForEach(func(ch *upstreamChannel) { + found = found || ch.Members.Has(nick) + }) + return found +} diff --git a/user.go b/user.go index 8b83c34..ad785cf 100644 --- a/user.go +++ b/user.go @@ -389,6 +389,7 @@ func (net *network) updateCasemapping(newCasemap casemapping) { uc.channels.ForEach(func(uch *upstreamChannel) { uch.Members.SetCasemapping(newCasemap) }) + uc.users.SetCasemapping(newCasemap) uc.monitored.SetCasemapping(newCasemap) } net.forEachDownstream(func(dc *downstreamConn) { diff --git a/xirc/whox.go b/xirc/whox.go index 4db8570..2e11813 100644 --- a/xirc/whox.go +++ b/xirc/whox.go @@ -2,6 +2,9 @@ package xirc import ( "gopkg.in/irc.v4" + + "fmt" + "strings" ) // whoxFields is the list of all WHOX field letters, by order of appearance in @@ -19,8 +22,8 @@ type WHOXInfo struct { Realname string } -func (info *WHOXInfo) get(field byte) string { - switch field { +func (info *WHOXInfo) get(k byte) string { + switch k { case 't': return info.Token case 'c': @@ -55,6 +58,27 @@ func (info *WHOXInfo) get(field byte) string { return "" } +func (info *WHOXInfo) set(k byte, v string) { + switch k { + case 't': + info.Token = v + case 'u': + info.Username = v + case 'h': + info.Hostname = v + case 's': + info.Server = v + case 'n': + info.Nickname = v + case 'f': + info.Flags = v + case 'a': + info.Account = v + case 'r': + info.Realname = v + } +} + func GenerateWHOXReply(prefix *irc.Prefix, nick, fields string, info *WHOXInfo) *irc.Message { if fields == "" { return &irc.Message{ @@ -83,3 +107,46 @@ func GenerateWHOXReply(prefix *irc.Prefix, nick, fields string, info *WHOXInfo) Params: append([]string{nick}, values...), } } + +func ParseWHOXOptions(options string) (fields, whoxToken string) { + optionsParts := strings.SplitN(options, "%", 2) + // TODO: add support for WHOX flags in optionsParts[0] + if len(optionsParts) == 2 { + optionsParts := strings.SplitN(optionsParts[1], ",", 2) + fields = strings.ToLower(optionsParts[0]) + if len(optionsParts) == 2 && strings.Contains(fields, "t") { + whoxToken = optionsParts[1] + } + } + return fields, whoxToken +} + +func ParseWHOXReply(msg *irc.Message, fields string) (*WHOXInfo, error) { + if msg.Command != RPL_WHOSPCRPL { + return nil, fmt.Errorf("invalid WHOX reply %q", msg.Command) + } else if len(msg.Params) == 0 { + return nil, fmt.Errorf("invalid RPL_WHOSPCRPL: no params") + } + + fieldSet := make(map[byte]bool) + for i := 0; i < len(fields); i++ { + fieldSet[fields[i]] = true + } + + var info WHOXInfo + values := msg.Params[1:] + for _, field := range whoxFields { + if !fieldSet[field] { + continue + } + + if len(values) == 0 { + return nil, fmt.Errorf("invalid RPL_WHOSPCRPL: missing value for field %q", string(field)) + } + + info.set(field, values[0]) + values = values[1:] + } + + return &info, nil +}