From 657e25b25c47511ac8a1bb042d0bc9481a813e3b Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Mon, 6 Jun 2022 09:58:39 +0200 Subject: [PATCH] Make casemapMap more type-safe In addition to a type-safe getter, also define type-safe setters and iterators. References: https://lists.sr.ht/~emersion/soju-dev/patches/32777 --- downstream.go | 49 +++++++++--------- irc.go | 134 ++++++++++++++++++++++++++++++++++++-------------- service.go | 10 ++-- upstream.go | 79 ++++++++++++++--------------- user.go | 48 +++++++++--------- 5 files changed, 183 insertions(+), 137 deletions(-) diff --git a/downstream.go b/downstream.go index 81b4a66..bed7616 100644 --- a/downstream.go +++ b/downstream.go @@ -1592,14 +1592,13 @@ func (dc *downstreamConn) welcome(ctx context.Context) error { } dc.forEachUpstream(func(uc *upstreamConn) { - for _, entry := range uc.channels.innerMap { - ch := entry.value.(*upstreamChannel) + uc.channels.ForEach(func(_ string, ch *upstreamChannel) { if !ch.complete { - continue + return } - record := uc.network.channels.Value(ch.Name) + record := uc.network.channels.Get(ch.Name) if record != nil && record.Detached { - continue + return } dc.SendMessage(&irc.Message{ @@ -1609,7 +1608,7 @@ func (dc *downstreamConn) welcome(ctx context.Context) error { }) forwardChannel(ctx, dc, ch) - } + }) }) dc.forEachNetwork(func(net *network) { @@ -1667,7 +1666,7 @@ func (dc *downstreamConn) sendTargetBacklog(ctx context.Context, net *network, t return } - ch := net.channels.Value(target) + ch := net.channels.Get(target) ctx, cancel := context.WithTimeout(ctx, backlogTimeout) defer cancel() @@ -1938,7 +1937,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. }) } - ch := uc.network.channels.Value(upstreamName) + ch := uc.network.channels.Get(upstreamName) if ch != nil { // Don't clear the channel key if there's one set // TODO: add a way to unset the channel key @@ -1951,7 +1950,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. Name: upstreamName, Key: key, } - uc.network.channels.SetValue(upstreamName, ch) + uc.network.channels.Set(upstreamName, ch) } if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err) @@ -1975,7 +1974,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. } if strings.EqualFold(reason, "detach") { - ch := uc.network.channels.Value(upstreamName) + ch := uc.network.channels.Get(upstreamName) if ch != nil { uc.network.detach(ch) } else { @@ -1983,7 +1982,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. Name: name, Detached: true, } - uc.network.channels.SetValue(upstreamName, ch) + uc.network.channels.Set(upstreamName, ch) } if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err) @@ -2119,7 +2118,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. Params: params, }) } else { - ch := uc.channels.Value(upstreamName) + ch := uc.channels.Get(upstreamName) if ch == nil { return ircError{&irc.Message{ Command: irc.ERR_NOSUCHCHANNEL, @@ -2168,7 +2167,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. Params: []string{upstreamName, topic}, }) } else { // getting topic - ch := uc.channels.Value(upstreamName) + ch := uc.channels.Get(upstreamName) if ch == nil { return ircError{&irc.Message{ Command: irc.ERR_NOSUCHCHANNEL, @@ -2223,7 +2222,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. return err } - ch := uc.channels.Value(upstreamName) + ch := uc.channels.Get(upstreamName) if ch != nil { sendNames(dc, ch) } else { @@ -2677,7 +2676,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. for _, target := range strings.Split(targets, ",") { if subcommand == "+" { // Hard limit, just to avoid having downstreams fill our map - if len(dc.monitored.innerMap) >= 1000 { + if dc.monitored.Len() >= 1000 { dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), Command: irc.ERR_MONLISTFULL, @@ -2686,7 +2685,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. continue } - dc.monitored.SetValue(target, nil) + dc.monitored.set(target, nil) if uc.network.casemap(target) == serviceNickCM { // BouncerServ is never tired @@ -2700,7 +2699,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. if uc.monitored.Has(target) { cmd := irc.RPL_MONOFFLINE - if online := uc.monitored.Value(target); online { + if online := uc.monitored.Get(target); online { cmd = irc.RPL_MONONLINE } @@ -2711,7 +2710,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. }) } } else { - dc.monitored.Delete(target) + dc.monitored.Del(target) } } uc.updateMonitor() @@ -2721,7 +2720,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. uc.updateMonitor() case "L": // list // TODO: be less lazy and pack the list - for _, entry := range dc.monitored.innerMap { + for _, entry := range dc.monitored.m { dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_MONLIST, @@ -2735,11 +2734,11 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. }) case "S": // status // TODO: be less lazy and pack the lists - for _, entry := range dc.monitored.innerMap { + for _, entry := range dc.monitored.m { target := entry.originalKey cmd := irc.RPL_MONOFFLINE - if online := uc.monitored.Value(target); online { + if online := uc.monitored.Get(target); online { cmd = irc.RPL_MONONLINE } @@ -2872,7 +2871,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. dc.SendBatch("draft/chathistory-targets", nil, nil, func(batchRef irc.TagValue) { for _, target := range targets { - if ch := network.channels.Value(target.Name); ch != nil && ch.Detached { + if ch := network.channels.Get(target.Name); ch != nil && ch.Detached { continue } @@ -3329,12 +3328,10 @@ func sendNames(dc *downstreamConn, ch *upstreamChannel) { downstreamName := dc.marshalEntity(ch.conn.network, ch.Name) var members []string - for _, entry := range ch.Members.innerMap { - nick := entry.originalKey - memberships := entry.value.(*xirc.MembershipSet) + ch.Members.ForEach(func(nick string, memberships *xirc.MembershipSet) { s := formatMemberPrefix(*memberships, dc) + dc.marshalEntity(ch.conn.network, nick) members = append(members, s) - } + }) msgs := xirc.GenerateNamesReply(dc.srv.prefix(), dc.nick, downstreamName, ch.Status, members) for _, msg := range msgs { diff --git a/irc.go b/irc.go index 6399122..9fd6b76 100644 --- a/irc.go +++ b/irc.go @@ -111,7 +111,7 @@ outer: return nil, fmt.Errorf("malformed modestring %q: missing mode argument for %c%c", modeStr, plusMinus, mode) } member := arguments[nextArgument] - m := ch.Members.Value(member) + m := ch.Members.Get(member) if m != nil { if plusMinus == '+' { m.Add(ch.conn.availableMemberships, membership) @@ -304,8 +304,8 @@ func partialCasemap(higher casemapping, name string) string { } type casemapMap struct { - innerMap map[string]casemapEntry - casemap casemapping + m map[string]casemapEntry + casemap casemapping } type casemapEntry struct { @@ -315,95 +315,153 @@ type casemapEntry struct { func newCasemapMap() casemapMap { return casemapMap{ - innerMap: make(map[string]casemapEntry), - casemap: casemapNone, + m: make(map[string]casemapEntry), + casemap: casemapNone, } } func (cm *casemapMap) Has(name string) bool { - _, ok := cm.innerMap[cm.casemap(name)] + _, ok := cm.m[cm.casemap(name)] return ok } func (cm *casemapMap) Len() int { - return len(cm.innerMap) + return len(cm.m) } -func (cm *casemapMap) SetValue(name string, value interface{}) { - nameCM := cm.casemap(name) - entry, ok := cm.innerMap[nameCM] +func (cm *casemapMap) get(name string) interface{} { + entry, ok := cm.m[cm.casemap(name)] if !ok { - cm.innerMap[nameCM] = casemapEntry{ + return nil + } + return entry.value +} + +func (cm *casemapMap) set(name string, value interface{}) { + nameCM := cm.casemap(name) + entry, ok := cm.m[nameCM] + if !ok { + cm.m[nameCM] = casemapEntry{ originalKey: name, value: value, } return } entry.value = value - cm.innerMap[nameCM] = entry + cm.m[nameCM] = entry } -func (cm *casemapMap) Delete(name string) { - delete(cm.innerMap, cm.casemap(name)) +func (cm *casemapMap) Del(name string) { + delete(cm.m, cm.casemap(name)) } func (cm *casemapMap) SetCasemapping(newCasemap casemapping) { cm.casemap = newCasemap - newInnerMap := make(map[string]casemapEntry, len(cm.innerMap)) - for _, entry := range cm.innerMap { - newInnerMap[cm.casemap(entry.originalKey)] = entry + m := make(map[string]casemapEntry, len(cm.m)) + for _, entry := range cm.m { + m[cm.casemap(entry.originalKey)] = entry } - cm.innerMap = newInnerMap + cm.m = m } type upstreamChannelCasemapMap struct{ casemapMap } -func (cm *upstreamChannelCasemapMap) Value(name string) *upstreamChannel { - entry, ok := cm.innerMap[cm.casemap(name)] - if !ok { +func (cm *upstreamChannelCasemapMap) Get(name string) *upstreamChannel { + if v := cm.get(name); v == nil { return nil + } else { + return v.(*upstreamChannel) + } +} + +func (cm *upstreamChannelCasemapMap) Set(name string, uch *upstreamChannel) { + cm.set(name, uch) +} + +func (cm *upstreamChannelCasemapMap) ForEach(f func(string, *upstreamChannel)) { + for _, entry := range cm.m { + f(entry.originalKey, entry.value.(*upstreamChannel)) } - return entry.value.(*upstreamChannel) } type channelCasemapMap struct{ casemapMap } -func (cm *channelCasemapMap) Value(name string) *database.Channel { - entry, ok := cm.innerMap[cm.casemap(name)] - if !ok { +func (cm *channelCasemapMap) Get(name string) *database.Channel { + if v := cm.get(name); v == nil { return nil + } else { + return v.(*database.Channel) + } +} + +func (cm *channelCasemapMap) Set(name string, ch *database.Channel) { + cm.set(name, ch) +} + +func (cm *channelCasemapMap) ForEach(f func(string, *database.Channel)) { + for _, entry := range cm.m { + f(entry.originalKey, entry.value.(*database.Channel)) } - return entry.value.(*database.Channel) } type membershipsCasemapMap struct{ casemapMap } -func (cm *membershipsCasemapMap) Value(name string) *xirc.MembershipSet { - entry, ok := cm.innerMap[cm.casemap(name)] - if !ok { +func (cm *membershipsCasemapMap) Get(name string) *xirc.MembershipSet { + if v := cm.get(name); v == nil { return nil + } else { + return v.(*xirc.MembershipSet) + } +} + +func (cm *membershipsCasemapMap) Set(name string, ms *xirc.MembershipSet) { + cm.set(name, ms) +} + +func (cm *membershipsCasemapMap) ForEach(f func(string, *xirc.MembershipSet)) { + for _, entry := range cm.m { + f(entry.originalKey, entry.value.(*xirc.MembershipSet)) } - return entry.value.(*xirc.MembershipSet) } type deliveredCasemapMap struct{ casemapMap } -func (cm *deliveredCasemapMap) Value(name string) deliveredClientMap { - entry, ok := cm.innerMap[cm.casemap(name)] - if !ok { +func (cm *deliveredCasemapMap) Get(name string) deliveredClientMap { + if v := cm.get(name); v == nil { return nil + } else { + return v.(deliveredClientMap) + } +} + +func (cm *deliveredCasemapMap) Set(name string, m deliveredClientMap) { + cm.set(name, m) +} + +func (cm *deliveredCasemapMap) ForEach(f func(string, deliveredClientMap)) { + for _, entry := range cm.m { + f(entry.originalKey, entry.value.(deliveredClientMap)) } - return entry.value.(deliveredClientMap) } type monitorCasemapMap struct{ casemapMap } -func (cm *monitorCasemapMap) Value(name string) (online bool) { - entry, ok := cm.innerMap[cm.casemap(name)] - if !ok { +func (cm *monitorCasemapMap) Get(name string) (online bool) { + if v := cm.get(name); v == nil { return false + } else { + return v.(bool) + } +} + +func (cm *monitorCasemapMap) Set(name string, online bool) { + cm.set(name, online) +} + +func (cm *monitorCasemapMap) ForEach(f func(name string, online bool)) { + for _, entry := range cm.m { + f(entry.originalKey, entry.value.(bool)) } - return entry.value.(bool) } func isWordBoundary(r rune) bool { diff --git a/service.go b/service.go index 9eaf3be..d1c944c 100644 --- a/service.go +++ b/service.go @@ -974,9 +974,9 @@ func handleServiceChannelStatus(ctx context.Context, dc *downstreamConn, params sendNetwork := func(net *network) { var channels []*database.Channel - for _, entry := range net.channels.innerMap { - channels = append(channels, entry.value.(*database.Channel)) - } + net.channels.ForEach(func(_ string, ch *database.Channel) { + channels = append(channels, ch) + }) sort.Slice(channels, func(i, j int) bool { return strings.ReplaceAll(channels[i].Name, "#", "") < @@ -986,7 +986,7 @@ func handleServiceChannelStatus(ctx context.Context, dc *downstreamConn, params for _, ch := range channels { var uch *upstreamChannel if net.conn != nil { - uch = net.conn.channels.Value(ch.Name) + uch = net.conn.channels.Get(ch.Name) } name := ch.Name @@ -1109,7 +1109,7 @@ func handleServiceChannelUpdate(ctx context.Context, dc *downstreamConn, params return fmt.Errorf("unknown channel %q", name) } - ch := uc.network.channels.Value(upstreamName) + ch := uc.network.channels.Get(upstreamName) if ch == nil { return fmt.Errorf("unknown channel %q", name) } diff --git a/upstream.go b/upstream.go index db04acc..c708bc5 100644 --- a/upstream.go +++ b/upstream.go @@ -292,7 +292,7 @@ func (uc *upstreamConn) downstreamByID(id uint64) *downstreamConn { } func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) { - ch := uc.channels.Value(name) + ch := uc.channels.Get(name) if ch == nil { return nil, fmt.Errorf("unknown channel %q", name) } @@ -513,7 +513,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err self := uc.isOurNick(msg.Prefix.Name) - ch := uc.network.channels.Value(target) + ch := uc.network.channels.Get(target) if ch != nil && msg.Command != "TAGMSG" && !self { if ch.Detached { uc.handleDetachedMessage(ctx, ch, msg) @@ -757,11 +757,10 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err if uc.network.channels.Len() > 0 { var channels, keys []string - for _, entry := range uc.network.channels.innerMap { - ch := entry.value.(*database.Channel) + uc.network.channels.ForEach(func(_ string, ch *database.Channel) { channels = append(channels, ch.Name) keys = append(keys, ch.Key) - } + }) for _, msg := range xirc.GenerateJoin(channels, keys) { uc.SendMessage(ctx, msg) @@ -918,15 +917,14 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err uc.nickCM = uc.network.casemap(uc.nick) } - for _, entry := range uc.channels.innerMap { - ch := entry.value.(*upstreamChannel) - memberships := ch.Members.Value(msg.Prefix.Name) + uc.channels.ForEach(func(_ string, ch *upstreamChannel) { + memberships := ch.Members.Get(msg.Prefix.Name) if memberships != nil { - ch.Members.Delete(msg.Prefix.Name) - ch.Members.SetValue(newNick, memberships) + ch.Members.Del(msg.Prefix.Name) + ch.Members.Set(newNick, memberships) uc.appendLog(ch.Name, msg) } - } + }) if !me { uc.forEachDownstream(func(dc *downstreamConn) { @@ -995,7 +993,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err uc.logger.Printf("joined channel %q", ch) members := membershipsCasemapMap{newCasemapMap()} members.casemap = uc.network.casemap - uc.channels.SetValue(ch, &upstreamChannel{ + uc.channels.Set(ch, &upstreamChannel{ Name: ch, conn: uc, Members: members, @@ -1011,7 +1009,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err if err != nil { return err } - ch.Members.SetValue(msg.Prefix.Name, &xirc.MembershipSet{}) + ch.Members.Set(msg.Prefix.Name, &xirc.MembershipSet{}) } chMsg := msg.Copy() @@ -1027,9 +1025,8 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err for _, ch := range strings.Split(channels, ",") { if uc.isOurNick(msg.Prefix.Name) { uc.logger.Printf("parted channel %q", ch) - uch := uc.channels.Value(ch) - if uch != nil { - uc.channels.Delete(ch) + if uch := uc.channels.Get(ch); uch != nil { + uc.channels.Del(ch) uch.updateAutoDetach(0) } } else { @@ -1037,7 +1034,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err if err != nil { return err } - ch.Members.Delete(msg.Prefix.Name) + ch.Members.Del(msg.Prefix.Name) } chMsg := msg.Copy() @@ -1052,13 +1049,13 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err if uc.isOurNick(user) { uc.logger.Printf("kicked from channel %q by %s", channel, msg.Prefix.Name) - uc.channels.Delete(channel) + uc.channels.Del(channel) } else { ch, err := uc.getChannel(channel) if err != nil { return err } - ch.Members.Delete(user) + ch.Members.Del(user) } uc.produce(channel, msg, 0) @@ -1067,14 +1064,12 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err uc.logger.Printf("quit") } - for _, entry := range uc.channels.innerMap { - ch := entry.value.(*upstreamChannel) + uc.channels.ForEach(func(_ string, ch *upstreamChannel) { if ch.Members.Has(msg.Prefix.Name) { - ch.Members.Delete(msg.Prefix.Name) - + ch.Members.Del(msg.Prefix.Name) uc.appendLog(ch.Name, msg) } - } + }) if msg.Prefix.Name != uc.nick { uc.forEachDownstream(func(dc *downstreamConn) { @@ -1147,7 +1142,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err uc.appendLog(ch.Name, msg) - c := uc.network.channels.Value(name) + c := uc.network.channels.Get(name) if c == nil || !c.Detached { uc.forEachDownstream(func(dc *downstreamConn) { params := make([]string, len(msg.Params)) @@ -1211,7 +1206,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err return err } - c := uc.network.channels.Value(channel) + c := uc.network.channels.Get(channel) if firstMode && (c == nil || !c.Detached) { modeStr, modeParams := ch.modes.Format() @@ -1240,7 +1235,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err firstCreationTime := ch.creationTime == "" ch.creationTime = creationTime - c := uc.network.channels.Value(channel) + c := uc.network.channels.Get(channel) if firstCreationTime && (c == nil || !c.Detached) { uc.forEachDownstream(func(dc *downstreamConn) { dc.SendMessage(&irc.Message{ @@ -1269,7 +1264,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err } ch.TopicTime = time.Unix(sec, 0) - c := uc.network.channels.Value(channel) + c := uc.network.channels.Get(channel) if firstTopicWhoTime && (c == nil || !c.Detached) { uc.forEachDownstream(func(dc *downstreamConn) { topicWho := dc.marshalUserPrefix(uc.network, ch.TopicWho) @@ -1322,7 +1317,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err return err } - ch := uc.channels.Value(name) + ch := uc.channels.Get(name) if ch == nil { // NAMES on a channel we have not joined, forward to downstream uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { @@ -1351,7 +1346,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err for _, s := range splitSpace(members) { memberships, nick := uc.parseMembershipPrefix(s) - ch.Members.SetValue(nick, memberships) + ch.Members.Set(nick, &memberships) } case irc.RPL_ENDOFNAMES: var name string @@ -1359,7 +1354,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err return err } - ch := uc.channels.Value(name) + ch := uc.channels.Get(name) if ch == nil { // NAMES on a channel we have not joined, forward to downstream uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { @@ -1379,7 +1374,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err } ch.complete = true - c := uc.network.channels.Value(name) + c := uc.network.channels.Get(name) if c == nil || !c.Detached { uc.forEachDownstream(func(dc *downstreamConn) { forwardChannel(ctx, dc, ch) @@ -1542,7 +1537,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err online := msg.Command == irc.RPL_MONONLINE for _, target := range targets { prefix := irc.ParsePrefix(target) - uc.monitored.SetValue(prefix.Name, online) + uc.monitored.Set(prefix.Name, online) } // Check if the nick we want is now free @@ -2112,7 +2107,7 @@ func (uc *upstreamConn) produce(target string, msg *irc.Message, originID uint64 } // Don't forward messages if it's a detached channel - ch := uc.network.channels.Value(target) + ch := uc.network.channels.Get(target) detached := ch != nil && ch.Detached uc.forEachDownstream(func(dc *downstreamConn) { @@ -2148,11 +2143,11 @@ func (uc *upstreamConn) updateAway() { } func (uc *upstreamConn) updateChannelAutoDetach(name string) { - uch := uc.channels.Value(name) + uch := uc.channels.Get(name) if uch == nil { return } - ch := uc.network.channels.Value(name) + ch := uc.network.channels.Get(name) if ch == nil || ch.Detached { return } @@ -2170,7 +2165,7 @@ func (uc *upstreamConn) updateMonitor() { var addList []string seen := make(map[string]struct{}) uc.forEachDownstream(func(dc *downstreamConn) { - for _, entry := range dc.monitored.innerMap { + for _, entry := range dc.monitored.m { targetCM := uc.network.casemap(entry.originalKey) if targetCM == serviceNickCM { continue @@ -2195,13 +2190,13 @@ func (uc *upstreamConn) updateMonitor() { removeAll := true var removeList []string - for targetCM, entry := range uc.monitored.innerMap { - if _, ok := seen[targetCM]; ok { + uc.monitored.ForEach(func(nick string, online bool) { + if _, ok := seen[uc.network.casemap(nick)]; ok { removeAll = false } else { - removeList = append(removeList, entry.originalKey) + removeList = append(removeList, nick) } - } + }) // TODO: better handle the case where len(uc.monitored) + len(addList) // exceeds the limit, probably by immediately sending ERR_MONLISTFULL? @@ -2221,6 +2216,6 @@ func (uc *upstreamConn) updateMonitor() { } for _, target := range removeList { - uc.monitored.Delete(target) + uc.monitored.Del(target) } } diff --git a/user.go b/user.go index 73221bd..cac5312 100644 --- a/user.go +++ b/user.go @@ -85,11 +85,11 @@ func newDeliveredStore() deliveredStore { } func (ds deliveredStore) HasTarget(target string) bool { - return ds.m.Value(target) != nil + return ds.m.Get(target) != nil } func (ds deliveredStore) LoadID(target, clientName string) string { - clients := ds.m.Value(target) + clients := ds.m.Get(target) if clients == nil { return "" } @@ -97,28 +97,27 @@ func (ds deliveredStore) LoadID(target, clientName string) string { } func (ds deliveredStore) StoreID(target, clientName, msgID string) { - clients := ds.m.Value(target) + clients := ds.m.Get(target) if clients == nil { clients = make(deliveredClientMap) - ds.m.SetValue(target, clients) + ds.m.Set(target, clients) } clients[clientName] = msgID } func (ds deliveredStore) ForEachTarget(f func(target string)) { - for _, entry := range ds.m.innerMap { - f(entry.originalKey) - } + ds.m.ForEach(func(name string, _ deliveredClientMap) { + f(name) + }) } func (ds deliveredStore) ForEachClient(f func(clientName string)) { clients := make(map[string]struct{}) - for _, entry := range ds.m.innerMap { - delivered := entry.value.(deliveredClientMap) + ds.m.ForEach(func(name string, delivered deliveredClientMap) { for clientName := range delivered { clients[clientName] = struct{}{} } - } + }) for clientName := range clients { f(clientName) @@ -144,7 +143,7 @@ func newNetwork(user *user, record *database.Network, channels []database.Channe m := channelCasemapMap{newCasemapMap()} for _, ch := range channels { ch := ch - m.SetValue(ch.Name, &ch) + m.Set(ch.Name, &ch) } return &network{ @@ -300,7 +299,7 @@ func (net *network) detach(ch *database.Channel) { } if net.conn != nil { - uch := net.conn.channels.Value(ch.Name) + uch := net.conn.channels.Get(ch.Name) if uch != nil { uch.updateAutoDetach(0) } @@ -328,7 +327,7 @@ func (net *network) attach(ctx context.Context, ch *database.Channel) { var uch *upstreamChannel if net.conn != nil { - uch = net.conn.channels.Value(ch.Name) + uch = net.conn.channels.Get(ch.Name) net.conn.updateChannelAutoDetach(ch.Name) } @@ -351,12 +350,12 @@ func (net *network) attach(ctx context.Context, ch *database.Channel) { } func (net *network) deleteChannel(ctx context.Context, name string) error { - ch := net.channels.Value(name) + ch := net.channels.Get(name) if ch == nil { return fmt.Errorf("unknown channel %q", name) } if net.conn != nil { - uch := net.conn.channels.Value(ch.Name) + uch := net.conn.channels.Get(ch.Name) if uch != nil { uch.updateAutoDetach(0) } @@ -365,7 +364,7 @@ func (net *network) deleteChannel(ctx context.Context, name string) error { if err := net.user.srv.db.DeleteChannel(ctx, ch.ID); err != nil { return err } - net.channels.Delete(name) + net.channels.Del(name) return nil } @@ -375,10 +374,9 @@ func (net *network) updateCasemapping(newCasemap casemapping) { net.delivered.m.SetCasemapping(newCasemap) if uc := net.conn; uc != nil { uc.channels.SetCasemapping(newCasemap) - for _, entry := range uc.channels.innerMap { - uch := entry.value.(*upstreamChannel) + uc.channels.ForEach(func(_ string, uch *upstreamChannel) { uch.Members.SetCasemapping(newCasemap) - } + }) uc.monitored.SetCasemapping(newCasemap) } net.forEachDownstream(func(dc *downstreamConn) { @@ -623,7 +621,7 @@ func (u *user) run() { } case eventChannelDetach: uc, name := e.uc, e.name - c := uc.network.channels.Value(name) + c := uc.network.channels.Get(name) if c == nil || c.Detached { continue } @@ -746,10 +744,9 @@ func (u *user) handleUpstreamDisconnected(uc *upstreamConn) { uc.abortPendingCommands() - for _, entry := range uc.channels.innerMap { - uch := entry.value.(*upstreamChannel) + uc.channels.ForEach(func(_ string, uch *upstreamChannel) { uch.updateAutoDetach(0) - } + }) uc.forEachDownstream(func(dc *downstreamConn) { dc.updateSupportedCaps() @@ -924,10 +921,9 @@ func (u *user) updateNetwork(ctx context.Context, record *database.Network) (*ne // Most network changes require us to re-connect to the upstream server channels := make([]database.Channel, 0, network.channels.Len()) - for _, entry := range network.channels.innerMap { - ch := entry.value.(*database.Channel) + network.channels.ForEach(func(_ string, ch *database.Channel) { channels = append(channels, *ch) - } + }) updatedNetwork := newNetwork(u, record, channels)