Make casemapMap more type-safe

In addition to a type-safe getter, also define type-safe setters
and iterators.

References: https://lists.sr.ht/~emersion/soju-dev/patches/32777
This commit is contained in:
Simon Ser 2022-06-06 09:58:39 +02:00
parent c8f9728ff6
commit 657e25b25c
5 changed files with 183 additions and 137 deletions

View File

@ -1592,14 +1592,13 @@ func (dc *downstreamConn) welcome(ctx context.Context) error {
} }
dc.forEachUpstream(func(uc *upstreamConn) { dc.forEachUpstream(func(uc *upstreamConn) {
for _, entry := range uc.channels.innerMap { uc.channels.ForEach(func(_ string, ch *upstreamChannel) {
ch := entry.value.(*upstreamChannel)
if !ch.complete { if !ch.complete {
continue return
} }
record := uc.network.channels.Value(ch.Name) record := uc.network.channels.Get(ch.Name)
if record != nil && record.Detached { if record != nil && record.Detached {
continue return
} }
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
@ -1609,7 +1608,7 @@ func (dc *downstreamConn) welcome(ctx context.Context) error {
}) })
forwardChannel(ctx, dc, ch) forwardChannel(ctx, dc, ch)
} })
}) })
dc.forEachNetwork(func(net *network) { dc.forEachNetwork(func(net *network) {
@ -1667,7 +1666,7 @@ func (dc *downstreamConn) sendTargetBacklog(ctx context.Context, net *network, t
return return
} }
ch := net.channels.Value(target) ch := net.channels.Get(target)
ctx, cancel := context.WithTimeout(ctx, backlogTimeout) ctx, cancel := context.WithTimeout(ctx, backlogTimeout)
defer cancel() defer cancel()
@ -1938,7 +1937,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
}) })
} }
ch := uc.network.channels.Value(upstreamName) ch := uc.network.channels.Get(upstreamName)
if ch != nil { if ch != nil {
// Don't clear the channel key if there's one set // Don't clear the channel key if there's one set
// TODO: add a way to unset the channel key // TODO: add a way to unset the channel key
@ -1951,7 +1950,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
Name: upstreamName, Name: upstreamName,
Key: key, Key: key,
} }
uc.network.channels.SetValue(upstreamName, ch) uc.network.channels.Set(upstreamName, ch)
} }
if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err) dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
@ -1975,7 +1974,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
} }
if strings.EqualFold(reason, "detach") { if strings.EqualFold(reason, "detach") {
ch := uc.network.channels.Value(upstreamName) ch := uc.network.channels.Get(upstreamName)
if ch != nil { if ch != nil {
uc.network.detach(ch) uc.network.detach(ch)
} else { } else {
@ -1983,7 +1982,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
Name: name, Name: name,
Detached: true, Detached: true,
} }
uc.network.channels.SetValue(upstreamName, ch) uc.network.channels.Set(upstreamName, ch)
} }
if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err) dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
@ -2119,7 +2118,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
Params: params, Params: params,
}) })
} else { } else {
ch := uc.channels.Value(upstreamName) ch := uc.channels.Get(upstreamName)
if ch == nil { if ch == nil {
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: irc.ERR_NOSUCHCHANNEL, Command: irc.ERR_NOSUCHCHANNEL,
@ -2168,7 +2167,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
Params: []string{upstreamName, topic}, Params: []string{upstreamName, topic},
}) })
} else { // getting topic } else { // getting topic
ch := uc.channels.Value(upstreamName) ch := uc.channels.Get(upstreamName)
if ch == nil { if ch == nil {
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: irc.ERR_NOSUCHCHANNEL, Command: irc.ERR_NOSUCHCHANNEL,
@ -2223,7 +2222,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
return err return err
} }
ch := uc.channels.Value(upstreamName) ch := uc.channels.Get(upstreamName)
if ch != nil { if ch != nil {
sendNames(dc, ch) sendNames(dc, ch)
} else { } else {
@ -2677,7 +2676,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
for _, target := range strings.Split(targets, ",") { for _, target := range strings.Split(targets, ",") {
if subcommand == "+" { if subcommand == "+" {
// Hard limit, just to avoid having downstreams fill our map // Hard limit, just to avoid having downstreams fill our map
if len(dc.monitored.innerMap) >= 1000 { if dc.monitored.Len() >= 1000 {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.ERR_MONLISTFULL, Command: irc.ERR_MONLISTFULL,
@ -2686,7 +2685,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
continue continue
} }
dc.monitored.SetValue(target, nil) dc.monitored.set(target, nil)
if uc.network.casemap(target) == serviceNickCM { if uc.network.casemap(target) == serviceNickCM {
// BouncerServ is never tired // BouncerServ is never tired
@ -2700,7 +2699,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
if uc.monitored.Has(target) { if uc.monitored.Has(target) {
cmd := irc.RPL_MONOFFLINE cmd := irc.RPL_MONOFFLINE
if online := uc.monitored.Value(target); online { if online := uc.monitored.Get(target); online {
cmd = irc.RPL_MONONLINE cmd = irc.RPL_MONONLINE
} }
@ -2711,7 +2710,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
}) })
} }
} else { } else {
dc.monitored.Delete(target) dc.monitored.Del(target)
} }
} }
uc.updateMonitor() uc.updateMonitor()
@ -2721,7 +2720,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
uc.updateMonitor() uc.updateMonitor()
case "L": // list case "L": // list
// TODO: be less lazy and pack the list // TODO: be less lazy and pack the list
for _, entry := range dc.monitored.innerMap { for _, entry := range dc.monitored.m {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.RPL_MONLIST, Command: irc.RPL_MONLIST,
@ -2735,11 +2734,11 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
}) })
case "S": // status case "S": // status
// TODO: be less lazy and pack the lists // TODO: be less lazy and pack the lists
for _, entry := range dc.monitored.innerMap { for _, entry := range dc.monitored.m {
target := entry.originalKey target := entry.originalKey
cmd := irc.RPL_MONOFFLINE cmd := irc.RPL_MONOFFLINE
if online := uc.monitored.Value(target); online { if online := uc.monitored.Get(target); online {
cmd = irc.RPL_MONONLINE cmd = irc.RPL_MONONLINE
} }
@ -2872,7 +2871,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
dc.SendBatch("draft/chathistory-targets", nil, nil, func(batchRef irc.TagValue) { dc.SendBatch("draft/chathistory-targets", nil, nil, func(batchRef irc.TagValue) {
for _, target := range targets { for _, target := range targets {
if ch := network.channels.Value(target.Name); ch != nil && ch.Detached { if ch := network.channels.Get(target.Name); ch != nil && ch.Detached {
continue continue
} }
@ -3329,12 +3328,10 @@ func sendNames(dc *downstreamConn, ch *upstreamChannel) {
downstreamName := dc.marshalEntity(ch.conn.network, ch.Name) downstreamName := dc.marshalEntity(ch.conn.network, ch.Name)
var members []string var members []string
for _, entry := range ch.Members.innerMap { ch.Members.ForEach(func(nick string, memberships *xirc.MembershipSet) {
nick := entry.originalKey
memberships := entry.value.(*xirc.MembershipSet)
s := formatMemberPrefix(*memberships, dc) + dc.marshalEntity(ch.conn.network, nick) s := formatMemberPrefix(*memberships, dc) + dc.marshalEntity(ch.conn.network, nick)
members = append(members, s) members = append(members, s)
} })
msgs := xirc.GenerateNamesReply(dc.srv.prefix(), dc.nick, downstreamName, ch.Status, members) msgs := xirc.GenerateNamesReply(dc.srv.prefix(), dc.nick, downstreamName, ch.Status, members)
for _, msg := range msgs { for _, msg := range msgs {

130
irc.go
View File

@ -111,7 +111,7 @@ outer:
return nil, fmt.Errorf("malformed modestring %q: missing mode argument for %c%c", modeStr, plusMinus, mode) return nil, fmt.Errorf("malformed modestring %q: missing mode argument for %c%c", modeStr, plusMinus, mode)
} }
member := arguments[nextArgument] member := arguments[nextArgument]
m := ch.Members.Value(member) m := ch.Members.Get(member)
if m != nil { if m != nil {
if plusMinus == '+' { if plusMinus == '+' {
m.Add(ch.conn.availableMemberships, membership) m.Add(ch.conn.availableMemberships, membership)
@ -304,7 +304,7 @@ func partialCasemap(higher casemapping, name string) string {
} }
type casemapMap struct { type casemapMap struct {
innerMap map[string]casemapEntry m map[string]casemapEntry
casemap casemapping casemap casemapping
} }
@ -315,95 +315,153 @@ type casemapEntry struct {
func newCasemapMap() casemapMap { func newCasemapMap() casemapMap {
return casemapMap{ return casemapMap{
innerMap: make(map[string]casemapEntry), m: make(map[string]casemapEntry),
casemap: casemapNone, casemap: casemapNone,
} }
} }
func (cm *casemapMap) Has(name string) bool { func (cm *casemapMap) Has(name string) bool {
_, ok := cm.innerMap[cm.casemap(name)] _, ok := cm.m[cm.casemap(name)]
return ok return ok
} }
func (cm *casemapMap) Len() int { func (cm *casemapMap) Len() int {
return len(cm.innerMap) return len(cm.m)
} }
func (cm *casemapMap) SetValue(name string, value interface{}) { func (cm *casemapMap) get(name string) interface{} {
nameCM := cm.casemap(name) entry, ok := cm.m[cm.casemap(name)]
entry, ok := cm.innerMap[nameCM]
if !ok { if !ok {
cm.innerMap[nameCM] = casemapEntry{ return nil
}
return entry.value
}
func (cm *casemapMap) set(name string, value interface{}) {
nameCM := cm.casemap(name)
entry, ok := cm.m[nameCM]
if !ok {
cm.m[nameCM] = casemapEntry{
originalKey: name, originalKey: name,
value: value, value: value,
} }
return return
} }
entry.value = value entry.value = value
cm.innerMap[nameCM] = entry cm.m[nameCM] = entry
} }
func (cm *casemapMap) Delete(name string) { func (cm *casemapMap) Del(name string) {
delete(cm.innerMap, cm.casemap(name)) delete(cm.m, cm.casemap(name))
} }
func (cm *casemapMap) SetCasemapping(newCasemap casemapping) { func (cm *casemapMap) SetCasemapping(newCasemap casemapping) {
cm.casemap = newCasemap cm.casemap = newCasemap
newInnerMap := make(map[string]casemapEntry, len(cm.innerMap)) m := make(map[string]casemapEntry, len(cm.m))
for _, entry := range cm.innerMap { for _, entry := range cm.m {
newInnerMap[cm.casemap(entry.originalKey)] = entry m[cm.casemap(entry.originalKey)] = entry
} }
cm.innerMap = newInnerMap cm.m = m
} }
type upstreamChannelCasemapMap struct{ casemapMap } type upstreamChannelCasemapMap struct{ casemapMap }
func (cm *upstreamChannelCasemapMap) Value(name string) *upstreamChannel { func (cm *upstreamChannelCasemapMap) Get(name string) *upstreamChannel {
entry, ok := cm.innerMap[cm.casemap(name)] if v := cm.get(name); v == nil {
if !ok {
return nil return nil
} else {
return v.(*upstreamChannel)
}
}
func (cm *upstreamChannelCasemapMap) Set(name string, uch *upstreamChannel) {
cm.set(name, uch)
}
func (cm *upstreamChannelCasemapMap) ForEach(f func(string, *upstreamChannel)) {
for _, entry := range cm.m {
f(entry.originalKey, entry.value.(*upstreamChannel))
} }
return entry.value.(*upstreamChannel)
} }
type channelCasemapMap struct{ casemapMap } type channelCasemapMap struct{ casemapMap }
func (cm *channelCasemapMap) Value(name string) *database.Channel { func (cm *channelCasemapMap) Get(name string) *database.Channel {
entry, ok := cm.innerMap[cm.casemap(name)] if v := cm.get(name); v == nil {
if !ok {
return nil return nil
} else {
return v.(*database.Channel)
}
}
func (cm *channelCasemapMap) Set(name string, ch *database.Channel) {
cm.set(name, ch)
}
func (cm *channelCasemapMap) ForEach(f func(string, *database.Channel)) {
for _, entry := range cm.m {
f(entry.originalKey, entry.value.(*database.Channel))
} }
return entry.value.(*database.Channel)
} }
type membershipsCasemapMap struct{ casemapMap } type membershipsCasemapMap struct{ casemapMap }
func (cm *membershipsCasemapMap) Value(name string) *xirc.MembershipSet { func (cm *membershipsCasemapMap) Get(name string) *xirc.MembershipSet {
entry, ok := cm.innerMap[cm.casemap(name)] if v := cm.get(name); v == nil {
if !ok {
return nil return nil
} else {
return v.(*xirc.MembershipSet)
}
}
func (cm *membershipsCasemapMap) Set(name string, ms *xirc.MembershipSet) {
cm.set(name, ms)
}
func (cm *membershipsCasemapMap) ForEach(f func(string, *xirc.MembershipSet)) {
for _, entry := range cm.m {
f(entry.originalKey, entry.value.(*xirc.MembershipSet))
} }
return entry.value.(*xirc.MembershipSet)
} }
type deliveredCasemapMap struct{ casemapMap } type deliveredCasemapMap struct{ casemapMap }
func (cm *deliveredCasemapMap) Value(name string) deliveredClientMap { func (cm *deliveredCasemapMap) Get(name string) deliveredClientMap {
entry, ok := cm.innerMap[cm.casemap(name)] if v := cm.get(name); v == nil {
if !ok {
return nil return nil
} else {
return v.(deliveredClientMap)
}
}
func (cm *deliveredCasemapMap) Set(name string, m deliveredClientMap) {
cm.set(name, m)
}
func (cm *deliveredCasemapMap) ForEach(f func(string, deliveredClientMap)) {
for _, entry := range cm.m {
f(entry.originalKey, entry.value.(deliveredClientMap))
} }
return entry.value.(deliveredClientMap)
} }
type monitorCasemapMap struct{ casemapMap } type monitorCasemapMap struct{ casemapMap }
func (cm *monitorCasemapMap) Value(name string) (online bool) { func (cm *monitorCasemapMap) Get(name string) (online bool) {
entry, ok := cm.innerMap[cm.casemap(name)] if v := cm.get(name); v == nil {
if !ok {
return false return false
} else {
return v.(bool)
}
}
func (cm *monitorCasemapMap) Set(name string, online bool) {
cm.set(name, online)
}
func (cm *monitorCasemapMap) ForEach(f func(name string, online bool)) {
for _, entry := range cm.m {
f(entry.originalKey, entry.value.(bool))
} }
return entry.value.(bool)
} }
func isWordBoundary(r rune) bool { func isWordBoundary(r rune) bool {

View File

@ -974,9 +974,9 @@ func handleServiceChannelStatus(ctx context.Context, dc *downstreamConn, params
sendNetwork := func(net *network) { sendNetwork := func(net *network) {
var channels []*database.Channel var channels []*database.Channel
for _, entry := range net.channels.innerMap { net.channels.ForEach(func(_ string, ch *database.Channel) {
channels = append(channels, entry.value.(*database.Channel)) channels = append(channels, ch)
} })
sort.Slice(channels, func(i, j int) bool { sort.Slice(channels, func(i, j int) bool {
return strings.ReplaceAll(channels[i].Name, "#", "") < return strings.ReplaceAll(channels[i].Name, "#", "") <
@ -986,7 +986,7 @@ func handleServiceChannelStatus(ctx context.Context, dc *downstreamConn, params
for _, ch := range channels { for _, ch := range channels {
var uch *upstreamChannel var uch *upstreamChannel
if net.conn != nil { if net.conn != nil {
uch = net.conn.channels.Value(ch.Name) uch = net.conn.channels.Get(ch.Name)
} }
name := ch.Name name := ch.Name
@ -1109,7 +1109,7 @@ func handleServiceChannelUpdate(ctx context.Context, dc *downstreamConn, params
return fmt.Errorf("unknown channel %q", name) return fmt.Errorf("unknown channel %q", name)
} }
ch := uc.network.channels.Value(upstreamName) ch := uc.network.channels.Get(upstreamName)
if ch == nil { if ch == nil {
return fmt.Errorf("unknown channel %q", name) return fmt.Errorf("unknown channel %q", name)
} }

View File

@ -292,7 +292,7 @@ func (uc *upstreamConn) downstreamByID(id uint64) *downstreamConn {
} }
func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) { func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
ch := uc.channels.Value(name) ch := uc.channels.Get(name)
if ch == nil { if ch == nil {
return nil, fmt.Errorf("unknown channel %q", name) return nil, fmt.Errorf("unknown channel %q", name)
} }
@ -513,7 +513,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
self := uc.isOurNick(msg.Prefix.Name) self := uc.isOurNick(msg.Prefix.Name)
ch := uc.network.channels.Value(target) ch := uc.network.channels.Get(target)
if ch != nil && msg.Command != "TAGMSG" && !self { if ch != nil && msg.Command != "TAGMSG" && !self {
if ch.Detached { if ch.Detached {
uc.handleDetachedMessage(ctx, ch, msg) uc.handleDetachedMessage(ctx, ch, msg)
@ -757,11 +757,10 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
if uc.network.channels.Len() > 0 { if uc.network.channels.Len() > 0 {
var channels, keys []string var channels, keys []string
for _, entry := range uc.network.channels.innerMap { uc.network.channels.ForEach(func(_ string, ch *database.Channel) {
ch := entry.value.(*database.Channel)
channels = append(channels, ch.Name) channels = append(channels, ch.Name)
keys = append(keys, ch.Key) keys = append(keys, ch.Key)
} })
for _, msg := range xirc.GenerateJoin(channels, keys) { for _, msg := range xirc.GenerateJoin(channels, keys) {
uc.SendMessage(ctx, msg) uc.SendMessage(ctx, msg)
@ -918,15 +917,14 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
uc.nickCM = uc.network.casemap(uc.nick) uc.nickCM = uc.network.casemap(uc.nick)
} }
for _, entry := range uc.channels.innerMap { uc.channels.ForEach(func(_ string, ch *upstreamChannel) {
ch := entry.value.(*upstreamChannel) memberships := ch.Members.Get(msg.Prefix.Name)
memberships := ch.Members.Value(msg.Prefix.Name)
if memberships != nil { if memberships != nil {
ch.Members.Delete(msg.Prefix.Name) ch.Members.Del(msg.Prefix.Name)
ch.Members.SetValue(newNick, memberships) ch.Members.Set(newNick, memberships)
uc.appendLog(ch.Name, msg) uc.appendLog(ch.Name, msg)
} }
} })
if !me { if !me {
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
@ -995,7 +993,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
uc.logger.Printf("joined channel %q", ch) uc.logger.Printf("joined channel %q", ch)
members := membershipsCasemapMap{newCasemapMap()} members := membershipsCasemapMap{newCasemapMap()}
members.casemap = uc.network.casemap members.casemap = uc.network.casemap
uc.channels.SetValue(ch, &upstreamChannel{ uc.channels.Set(ch, &upstreamChannel{
Name: ch, Name: ch,
conn: uc, conn: uc,
Members: members, Members: members,
@ -1011,7 +1009,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
if err != nil { if err != nil {
return err return err
} }
ch.Members.SetValue(msg.Prefix.Name, &xirc.MembershipSet{}) ch.Members.Set(msg.Prefix.Name, &xirc.MembershipSet{})
} }
chMsg := msg.Copy() chMsg := msg.Copy()
@ -1027,9 +1025,8 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
for _, ch := range strings.Split(channels, ",") { for _, ch := range strings.Split(channels, ",") {
if uc.isOurNick(msg.Prefix.Name) { if uc.isOurNick(msg.Prefix.Name) {
uc.logger.Printf("parted channel %q", ch) uc.logger.Printf("parted channel %q", ch)
uch := uc.channels.Value(ch) if uch := uc.channels.Get(ch); uch != nil {
if uch != nil { uc.channels.Del(ch)
uc.channels.Delete(ch)
uch.updateAutoDetach(0) uch.updateAutoDetach(0)
} }
} else { } else {
@ -1037,7 +1034,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
if err != nil { if err != nil {
return err return err
} }
ch.Members.Delete(msg.Prefix.Name) ch.Members.Del(msg.Prefix.Name)
} }
chMsg := msg.Copy() chMsg := msg.Copy()
@ -1052,13 +1049,13 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
if uc.isOurNick(user) { if uc.isOurNick(user) {
uc.logger.Printf("kicked from channel %q by %s", channel, msg.Prefix.Name) uc.logger.Printf("kicked from channel %q by %s", channel, msg.Prefix.Name)
uc.channels.Delete(channel) uc.channels.Del(channel)
} else { } else {
ch, err := uc.getChannel(channel) ch, err := uc.getChannel(channel)
if err != nil { if err != nil {
return err return err
} }
ch.Members.Delete(user) ch.Members.Del(user)
} }
uc.produce(channel, msg, 0) uc.produce(channel, msg, 0)
@ -1067,14 +1064,12 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
uc.logger.Printf("quit") uc.logger.Printf("quit")
} }
for _, entry := range uc.channels.innerMap { uc.channels.ForEach(func(_ string, ch *upstreamChannel) {
ch := entry.value.(*upstreamChannel)
if ch.Members.Has(msg.Prefix.Name) { if ch.Members.Has(msg.Prefix.Name) {
ch.Members.Delete(msg.Prefix.Name) ch.Members.Del(msg.Prefix.Name)
uc.appendLog(ch.Name, msg) uc.appendLog(ch.Name, msg)
} }
} })
if msg.Prefix.Name != uc.nick { if msg.Prefix.Name != uc.nick {
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
@ -1147,7 +1142,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
uc.appendLog(ch.Name, msg) uc.appendLog(ch.Name, msg)
c := uc.network.channels.Value(name) c := uc.network.channels.Get(name)
if c == nil || !c.Detached { if c == nil || !c.Detached {
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
params := make([]string, len(msg.Params)) params := make([]string, len(msg.Params))
@ -1211,7 +1206,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
return err return err
} }
c := uc.network.channels.Value(channel) c := uc.network.channels.Get(channel)
if firstMode && (c == nil || !c.Detached) { if firstMode && (c == nil || !c.Detached) {
modeStr, modeParams := ch.modes.Format() modeStr, modeParams := ch.modes.Format()
@ -1240,7 +1235,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
firstCreationTime := ch.creationTime == "" firstCreationTime := ch.creationTime == ""
ch.creationTime = creationTime ch.creationTime = creationTime
c := uc.network.channels.Value(channel) c := uc.network.channels.Get(channel)
if firstCreationTime && (c == nil || !c.Detached) { if firstCreationTime && (c == nil || !c.Detached) {
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
@ -1269,7 +1264,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
} }
ch.TopicTime = time.Unix(sec, 0) ch.TopicTime = time.Unix(sec, 0)
c := uc.network.channels.Value(channel) c := uc.network.channels.Get(channel)
if firstTopicWhoTime && (c == nil || !c.Detached) { if firstTopicWhoTime && (c == nil || !c.Detached) {
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
topicWho := dc.marshalUserPrefix(uc.network, ch.TopicWho) topicWho := dc.marshalUserPrefix(uc.network, ch.TopicWho)
@ -1322,7 +1317,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
return err return err
} }
ch := uc.channels.Value(name) ch := uc.channels.Get(name)
if ch == nil { if ch == nil {
// NAMES on a channel we have not joined, forward to downstream // NAMES on a channel we have not joined, forward to downstream
uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
@ -1351,7 +1346,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
for _, s := range splitSpace(members) { for _, s := range splitSpace(members) {
memberships, nick := uc.parseMembershipPrefix(s) memberships, nick := uc.parseMembershipPrefix(s)
ch.Members.SetValue(nick, memberships) ch.Members.Set(nick, &memberships)
} }
case irc.RPL_ENDOFNAMES: case irc.RPL_ENDOFNAMES:
var name string var name string
@ -1359,7 +1354,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
return err return err
} }
ch := uc.channels.Value(name) ch := uc.channels.Get(name)
if ch == nil { if ch == nil {
// NAMES on a channel we have not joined, forward to downstream // NAMES on a channel we have not joined, forward to downstream
uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
@ -1379,7 +1374,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
} }
ch.complete = true ch.complete = true
c := uc.network.channels.Value(name) c := uc.network.channels.Get(name)
if c == nil || !c.Detached { if c == nil || !c.Detached {
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
forwardChannel(ctx, dc, ch) forwardChannel(ctx, dc, ch)
@ -1542,7 +1537,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
online := msg.Command == irc.RPL_MONONLINE online := msg.Command == irc.RPL_MONONLINE
for _, target := range targets { for _, target := range targets {
prefix := irc.ParsePrefix(target) prefix := irc.ParsePrefix(target)
uc.monitored.SetValue(prefix.Name, online) uc.monitored.Set(prefix.Name, online)
} }
// Check if the nick we want is now free // Check if the nick we want is now free
@ -2112,7 +2107,7 @@ func (uc *upstreamConn) produce(target string, msg *irc.Message, originID uint64
} }
// Don't forward messages if it's a detached channel // Don't forward messages if it's a detached channel
ch := uc.network.channels.Value(target) ch := uc.network.channels.Get(target)
detached := ch != nil && ch.Detached detached := ch != nil && ch.Detached
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
@ -2148,11 +2143,11 @@ func (uc *upstreamConn) updateAway() {
} }
func (uc *upstreamConn) updateChannelAutoDetach(name string) { func (uc *upstreamConn) updateChannelAutoDetach(name string) {
uch := uc.channels.Value(name) uch := uc.channels.Get(name)
if uch == nil { if uch == nil {
return return
} }
ch := uc.network.channels.Value(name) ch := uc.network.channels.Get(name)
if ch == nil || ch.Detached { if ch == nil || ch.Detached {
return return
} }
@ -2170,7 +2165,7 @@ func (uc *upstreamConn) updateMonitor() {
var addList []string var addList []string
seen := make(map[string]struct{}) seen := make(map[string]struct{})
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
for _, entry := range dc.monitored.innerMap { for _, entry := range dc.monitored.m {
targetCM := uc.network.casemap(entry.originalKey) targetCM := uc.network.casemap(entry.originalKey)
if targetCM == serviceNickCM { if targetCM == serviceNickCM {
continue continue
@ -2195,13 +2190,13 @@ func (uc *upstreamConn) updateMonitor() {
removeAll := true removeAll := true
var removeList []string var removeList []string
for targetCM, entry := range uc.monitored.innerMap { uc.monitored.ForEach(func(nick string, online bool) {
if _, ok := seen[targetCM]; ok { if _, ok := seen[uc.network.casemap(nick)]; ok {
removeAll = false removeAll = false
} else { } else {
removeList = append(removeList, entry.originalKey) removeList = append(removeList, nick)
}
} }
})
// TODO: better handle the case where len(uc.monitored) + len(addList) // TODO: better handle the case where len(uc.monitored) + len(addList)
// exceeds the limit, probably by immediately sending ERR_MONLISTFULL? // exceeds the limit, probably by immediately sending ERR_MONLISTFULL?
@ -2221,6 +2216,6 @@ func (uc *upstreamConn) updateMonitor() {
} }
for _, target := range removeList { for _, target := range removeList {
uc.monitored.Delete(target) uc.monitored.Del(target)
} }
} }

48
user.go
View File

@ -85,11 +85,11 @@ func newDeliveredStore() deliveredStore {
} }
func (ds deliveredStore) HasTarget(target string) bool { func (ds deliveredStore) HasTarget(target string) bool {
return ds.m.Value(target) != nil return ds.m.Get(target) != nil
} }
func (ds deliveredStore) LoadID(target, clientName string) string { func (ds deliveredStore) LoadID(target, clientName string) string {
clients := ds.m.Value(target) clients := ds.m.Get(target)
if clients == nil { if clients == nil {
return "" return ""
} }
@ -97,28 +97,27 @@ func (ds deliveredStore) LoadID(target, clientName string) string {
} }
func (ds deliveredStore) StoreID(target, clientName, msgID string) { func (ds deliveredStore) StoreID(target, clientName, msgID string) {
clients := ds.m.Value(target) clients := ds.m.Get(target)
if clients == nil { if clients == nil {
clients = make(deliveredClientMap) clients = make(deliveredClientMap)
ds.m.SetValue(target, clients) ds.m.Set(target, clients)
} }
clients[clientName] = msgID clients[clientName] = msgID
} }
func (ds deliveredStore) ForEachTarget(f func(target string)) { func (ds deliveredStore) ForEachTarget(f func(target string)) {
for _, entry := range ds.m.innerMap { ds.m.ForEach(func(name string, _ deliveredClientMap) {
f(entry.originalKey) f(name)
} })
} }
func (ds deliveredStore) ForEachClient(f func(clientName string)) { func (ds deliveredStore) ForEachClient(f func(clientName string)) {
clients := make(map[string]struct{}) clients := make(map[string]struct{})
for _, entry := range ds.m.innerMap { ds.m.ForEach(func(name string, delivered deliveredClientMap) {
delivered := entry.value.(deliveredClientMap)
for clientName := range delivered { for clientName := range delivered {
clients[clientName] = struct{}{} clients[clientName] = struct{}{}
} }
} })
for clientName := range clients { for clientName := range clients {
f(clientName) f(clientName)
@ -144,7 +143,7 @@ func newNetwork(user *user, record *database.Network, channels []database.Channe
m := channelCasemapMap{newCasemapMap()} m := channelCasemapMap{newCasemapMap()}
for _, ch := range channels { for _, ch := range channels {
ch := ch ch := ch
m.SetValue(ch.Name, &ch) m.Set(ch.Name, &ch)
} }
return &network{ return &network{
@ -300,7 +299,7 @@ func (net *network) detach(ch *database.Channel) {
} }
if net.conn != nil { if net.conn != nil {
uch := net.conn.channels.Value(ch.Name) uch := net.conn.channels.Get(ch.Name)
if uch != nil { if uch != nil {
uch.updateAutoDetach(0) uch.updateAutoDetach(0)
} }
@ -328,7 +327,7 @@ func (net *network) attach(ctx context.Context, ch *database.Channel) {
var uch *upstreamChannel var uch *upstreamChannel
if net.conn != nil { if net.conn != nil {
uch = net.conn.channels.Value(ch.Name) uch = net.conn.channels.Get(ch.Name)
net.conn.updateChannelAutoDetach(ch.Name) net.conn.updateChannelAutoDetach(ch.Name)
} }
@ -351,12 +350,12 @@ func (net *network) attach(ctx context.Context, ch *database.Channel) {
} }
func (net *network) deleteChannel(ctx context.Context, name string) error { func (net *network) deleteChannel(ctx context.Context, name string) error {
ch := net.channels.Value(name) ch := net.channels.Get(name)
if ch == nil { if ch == nil {
return fmt.Errorf("unknown channel %q", name) return fmt.Errorf("unknown channel %q", name)
} }
if net.conn != nil { if net.conn != nil {
uch := net.conn.channels.Value(ch.Name) uch := net.conn.channels.Get(ch.Name)
if uch != nil { if uch != nil {
uch.updateAutoDetach(0) uch.updateAutoDetach(0)
} }
@ -365,7 +364,7 @@ func (net *network) deleteChannel(ctx context.Context, name string) error {
if err := net.user.srv.db.DeleteChannel(ctx, ch.ID); err != nil { if err := net.user.srv.db.DeleteChannel(ctx, ch.ID); err != nil {
return err return err
} }
net.channels.Delete(name) net.channels.Del(name)
return nil return nil
} }
@ -375,10 +374,9 @@ func (net *network) updateCasemapping(newCasemap casemapping) {
net.delivered.m.SetCasemapping(newCasemap) net.delivered.m.SetCasemapping(newCasemap)
if uc := net.conn; uc != nil { if uc := net.conn; uc != nil {
uc.channels.SetCasemapping(newCasemap) uc.channels.SetCasemapping(newCasemap)
for _, entry := range uc.channels.innerMap { uc.channels.ForEach(func(_ string, uch *upstreamChannel) {
uch := entry.value.(*upstreamChannel)
uch.Members.SetCasemapping(newCasemap) uch.Members.SetCasemapping(newCasemap)
} })
uc.monitored.SetCasemapping(newCasemap) uc.monitored.SetCasemapping(newCasemap)
} }
net.forEachDownstream(func(dc *downstreamConn) { net.forEachDownstream(func(dc *downstreamConn) {
@ -623,7 +621,7 @@ func (u *user) run() {
} }
case eventChannelDetach: case eventChannelDetach:
uc, name := e.uc, e.name uc, name := e.uc, e.name
c := uc.network.channels.Value(name) c := uc.network.channels.Get(name)
if c == nil || c.Detached { if c == nil || c.Detached {
continue continue
} }
@ -746,10 +744,9 @@ func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
uc.abortPendingCommands() uc.abortPendingCommands()
for _, entry := range uc.channels.innerMap { uc.channels.ForEach(func(_ string, uch *upstreamChannel) {
uch := entry.value.(*upstreamChannel)
uch.updateAutoDetach(0) uch.updateAutoDetach(0)
} })
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
dc.updateSupportedCaps() dc.updateSupportedCaps()
@ -924,10 +921,9 @@ func (u *user) updateNetwork(ctx context.Context, record *database.Network) (*ne
// Most network changes require us to re-connect to the upstream server // Most network changes require us to re-connect to the upstream server
channels := make([]database.Channel, 0, network.channels.Len()) channels := make([]database.Channel, 0, network.channels.Len())
for _, entry := range network.channels.innerMap { network.channels.ForEach(func(_ string, ch *database.Channel) {
ch := entry.value.(*database.Channel)
channels = append(channels, *ch) channels = append(channels, *ch)
} })
updatedNetwork := newNetwork(u, record, channels) updatedNetwork := newNetwork(u, record, channels)