Make casemapMap more type-safe
In addition to a type-safe getter, also define type-safe setters and iterators. References: https://lists.sr.ht/~emersion/soju-dev/patches/32777
This commit is contained in:
parent
c8f9728ff6
commit
657e25b25c
@ -1592,14 +1592,13 @@ func (dc *downstreamConn) welcome(ctx context.Context) error {
|
||||
}
|
||||
|
||||
dc.forEachUpstream(func(uc *upstreamConn) {
|
||||
for _, entry := range uc.channels.innerMap {
|
||||
ch := entry.value.(*upstreamChannel)
|
||||
uc.channels.ForEach(func(_ string, ch *upstreamChannel) {
|
||||
if !ch.complete {
|
||||
continue
|
||||
return
|
||||
}
|
||||
record := uc.network.channels.Value(ch.Name)
|
||||
record := uc.network.channels.Get(ch.Name)
|
||||
if record != nil && record.Detached {
|
||||
continue
|
||||
return
|
||||
}
|
||||
|
||||
dc.SendMessage(&irc.Message{
|
||||
@ -1609,7 +1608,7 @@ func (dc *downstreamConn) welcome(ctx context.Context) error {
|
||||
})
|
||||
|
||||
forwardChannel(ctx, dc, ch)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
dc.forEachNetwork(func(net *network) {
|
||||
@ -1667,7 +1666,7 @@ func (dc *downstreamConn) sendTargetBacklog(ctx context.Context, net *network, t
|
||||
return
|
||||
}
|
||||
|
||||
ch := net.channels.Value(target)
|
||||
ch := net.channels.Get(target)
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, backlogTimeout)
|
||||
defer cancel()
|
||||
@ -1938,7 +1937,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
})
|
||||
}
|
||||
|
||||
ch := uc.network.channels.Value(upstreamName)
|
||||
ch := uc.network.channels.Get(upstreamName)
|
||||
if ch != nil {
|
||||
// Don't clear the channel key if there's one set
|
||||
// TODO: add a way to unset the channel key
|
||||
@ -1951,7 +1950,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
Name: upstreamName,
|
||||
Key: key,
|
||||
}
|
||||
uc.network.channels.SetValue(upstreamName, ch)
|
||||
uc.network.channels.Set(upstreamName, ch)
|
||||
}
|
||||
if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
|
||||
dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
|
||||
@ -1975,7 +1974,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
}
|
||||
|
||||
if strings.EqualFold(reason, "detach") {
|
||||
ch := uc.network.channels.Value(upstreamName)
|
||||
ch := uc.network.channels.Get(upstreamName)
|
||||
if ch != nil {
|
||||
uc.network.detach(ch)
|
||||
} else {
|
||||
@ -1983,7 +1982,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
Name: name,
|
||||
Detached: true,
|
||||
}
|
||||
uc.network.channels.SetValue(upstreamName, ch)
|
||||
uc.network.channels.Set(upstreamName, ch)
|
||||
}
|
||||
if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
|
||||
dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
|
||||
@ -2119,7 +2118,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
Params: params,
|
||||
})
|
||||
} else {
|
||||
ch := uc.channels.Value(upstreamName)
|
||||
ch := uc.channels.Get(upstreamName)
|
||||
if ch == nil {
|
||||
return ircError{&irc.Message{
|
||||
Command: irc.ERR_NOSUCHCHANNEL,
|
||||
@ -2168,7 +2167,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
Params: []string{upstreamName, topic},
|
||||
})
|
||||
} else { // getting topic
|
||||
ch := uc.channels.Value(upstreamName)
|
||||
ch := uc.channels.Get(upstreamName)
|
||||
if ch == nil {
|
||||
return ircError{&irc.Message{
|
||||
Command: irc.ERR_NOSUCHCHANNEL,
|
||||
@ -2223,7 +2222,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
return err
|
||||
}
|
||||
|
||||
ch := uc.channels.Value(upstreamName)
|
||||
ch := uc.channels.Get(upstreamName)
|
||||
if ch != nil {
|
||||
sendNames(dc, ch)
|
||||
} else {
|
||||
@ -2677,7 +2676,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
for _, target := range strings.Split(targets, ",") {
|
||||
if subcommand == "+" {
|
||||
// Hard limit, just to avoid having downstreams fill our map
|
||||
if len(dc.monitored.innerMap) >= 1000 {
|
||||
if dc.monitored.Len() >= 1000 {
|
||||
dc.SendMessage(&irc.Message{
|
||||
Prefix: dc.srv.prefix(),
|
||||
Command: irc.ERR_MONLISTFULL,
|
||||
@ -2686,7 +2685,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
continue
|
||||
}
|
||||
|
||||
dc.monitored.SetValue(target, nil)
|
||||
dc.monitored.set(target, nil)
|
||||
|
||||
if uc.network.casemap(target) == serviceNickCM {
|
||||
// BouncerServ is never tired
|
||||
@ -2700,7 +2699,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
|
||||
if uc.monitored.Has(target) {
|
||||
cmd := irc.RPL_MONOFFLINE
|
||||
if online := uc.monitored.Value(target); online {
|
||||
if online := uc.monitored.Get(target); online {
|
||||
cmd = irc.RPL_MONONLINE
|
||||
}
|
||||
|
||||
@ -2711,7 +2710,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
})
|
||||
}
|
||||
} else {
|
||||
dc.monitored.Delete(target)
|
||||
dc.monitored.Del(target)
|
||||
}
|
||||
}
|
||||
uc.updateMonitor()
|
||||
@ -2721,7 +2720,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
uc.updateMonitor()
|
||||
case "L": // list
|
||||
// TODO: be less lazy and pack the list
|
||||
for _, entry := range dc.monitored.innerMap {
|
||||
for _, entry := range dc.monitored.m {
|
||||
dc.SendMessage(&irc.Message{
|
||||
Prefix: dc.srv.prefix(),
|
||||
Command: irc.RPL_MONLIST,
|
||||
@ -2735,11 +2734,11 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
})
|
||||
case "S": // status
|
||||
// TODO: be less lazy and pack the lists
|
||||
for _, entry := range dc.monitored.innerMap {
|
||||
for _, entry := range dc.monitored.m {
|
||||
target := entry.originalKey
|
||||
|
||||
cmd := irc.RPL_MONOFFLINE
|
||||
if online := uc.monitored.Value(target); online {
|
||||
if online := uc.monitored.Get(target); online {
|
||||
cmd = irc.RPL_MONONLINE
|
||||
}
|
||||
|
||||
@ -2872,7 +2871,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
|
||||
dc.SendBatch("draft/chathistory-targets", nil, nil, func(batchRef irc.TagValue) {
|
||||
for _, target := range targets {
|
||||
if ch := network.channels.Value(target.Name); ch != nil && ch.Detached {
|
||||
if ch := network.channels.Get(target.Name); ch != nil && ch.Detached {
|
||||
continue
|
||||
}
|
||||
|
||||
@ -3329,12 +3328,10 @@ func sendNames(dc *downstreamConn, ch *upstreamChannel) {
|
||||
downstreamName := dc.marshalEntity(ch.conn.network, ch.Name)
|
||||
|
||||
var members []string
|
||||
for _, entry := range ch.Members.innerMap {
|
||||
nick := entry.originalKey
|
||||
memberships := entry.value.(*xirc.MembershipSet)
|
||||
ch.Members.ForEach(func(nick string, memberships *xirc.MembershipSet) {
|
||||
s := formatMemberPrefix(*memberships, dc) + dc.marshalEntity(ch.conn.network, nick)
|
||||
members = append(members, s)
|
||||
}
|
||||
})
|
||||
|
||||
msgs := xirc.GenerateNamesReply(dc.srv.prefix(), dc.nick, downstreamName, ch.Status, members)
|
||||
for _, msg := range msgs {
|
||||
|
130
irc.go
130
irc.go
@ -111,7 +111,7 @@ outer:
|
||||
return nil, fmt.Errorf("malformed modestring %q: missing mode argument for %c%c", modeStr, plusMinus, mode)
|
||||
}
|
||||
member := arguments[nextArgument]
|
||||
m := ch.Members.Value(member)
|
||||
m := ch.Members.Get(member)
|
||||
if m != nil {
|
||||
if plusMinus == '+' {
|
||||
m.Add(ch.conn.availableMemberships, membership)
|
||||
@ -304,7 +304,7 @@ func partialCasemap(higher casemapping, name string) string {
|
||||
}
|
||||
|
||||
type casemapMap struct {
|
||||
innerMap map[string]casemapEntry
|
||||
m map[string]casemapEntry
|
||||
casemap casemapping
|
||||
}
|
||||
|
||||
@ -315,95 +315,153 @@ type casemapEntry struct {
|
||||
|
||||
func newCasemapMap() casemapMap {
|
||||
return casemapMap{
|
||||
innerMap: make(map[string]casemapEntry),
|
||||
m: make(map[string]casemapEntry),
|
||||
casemap: casemapNone,
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *casemapMap) Has(name string) bool {
|
||||
_, ok := cm.innerMap[cm.casemap(name)]
|
||||
_, ok := cm.m[cm.casemap(name)]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (cm *casemapMap) Len() int {
|
||||
return len(cm.innerMap)
|
||||
return len(cm.m)
|
||||
}
|
||||
|
||||
func (cm *casemapMap) SetValue(name string, value interface{}) {
|
||||
nameCM := cm.casemap(name)
|
||||
entry, ok := cm.innerMap[nameCM]
|
||||
func (cm *casemapMap) get(name string) interface{} {
|
||||
entry, ok := cm.m[cm.casemap(name)]
|
||||
if !ok {
|
||||
cm.innerMap[nameCM] = casemapEntry{
|
||||
return nil
|
||||
}
|
||||
return entry.value
|
||||
}
|
||||
|
||||
func (cm *casemapMap) set(name string, value interface{}) {
|
||||
nameCM := cm.casemap(name)
|
||||
entry, ok := cm.m[nameCM]
|
||||
if !ok {
|
||||
cm.m[nameCM] = casemapEntry{
|
||||
originalKey: name,
|
||||
value: value,
|
||||
}
|
||||
return
|
||||
}
|
||||
entry.value = value
|
||||
cm.innerMap[nameCM] = entry
|
||||
cm.m[nameCM] = entry
|
||||
}
|
||||
|
||||
func (cm *casemapMap) Delete(name string) {
|
||||
delete(cm.innerMap, cm.casemap(name))
|
||||
func (cm *casemapMap) Del(name string) {
|
||||
delete(cm.m, cm.casemap(name))
|
||||
}
|
||||
|
||||
func (cm *casemapMap) SetCasemapping(newCasemap casemapping) {
|
||||
cm.casemap = newCasemap
|
||||
newInnerMap := make(map[string]casemapEntry, len(cm.innerMap))
|
||||
for _, entry := range cm.innerMap {
|
||||
newInnerMap[cm.casemap(entry.originalKey)] = entry
|
||||
m := make(map[string]casemapEntry, len(cm.m))
|
||||
for _, entry := range cm.m {
|
||||
m[cm.casemap(entry.originalKey)] = entry
|
||||
}
|
||||
cm.innerMap = newInnerMap
|
||||
cm.m = m
|
||||
}
|
||||
|
||||
type upstreamChannelCasemapMap struct{ casemapMap }
|
||||
|
||||
func (cm *upstreamChannelCasemapMap) Value(name string) *upstreamChannel {
|
||||
entry, ok := cm.innerMap[cm.casemap(name)]
|
||||
if !ok {
|
||||
func (cm *upstreamChannelCasemapMap) Get(name string) *upstreamChannel {
|
||||
if v := cm.get(name); v == nil {
|
||||
return nil
|
||||
} else {
|
||||
return v.(*upstreamChannel)
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *upstreamChannelCasemapMap) Set(name string, uch *upstreamChannel) {
|
||||
cm.set(name, uch)
|
||||
}
|
||||
|
||||
func (cm *upstreamChannelCasemapMap) ForEach(f func(string, *upstreamChannel)) {
|
||||
for _, entry := range cm.m {
|
||||
f(entry.originalKey, entry.value.(*upstreamChannel))
|
||||
}
|
||||
return entry.value.(*upstreamChannel)
|
||||
}
|
||||
|
||||
type channelCasemapMap struct{ casemapMap }
|
||||
|
||||
func (cm *channelCasemapMap) Value(name string) *database.Channel {
|
||||
entry, ok := cm.innerMap[cm.casemap(name)]
|
||||
if !ok {
|
||||
func (cm *channelCasemapMap) Get(name string) *database.Channel {
|
||||
if v := cm.get(name); v == nil {
|
||||
return nil
|
||||
} else {
|
||||
return v.(*database.Channel)
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *channelCasemapMap) Set(name string, ch *database.Channel) {
|
||||
cm.set(name, ch)
|
||||
}
|
||||
|
||||
func (cm *channelCasemapMap) ForEach(f func(string, *database.Channel)) {
|
||||
for _, entry := range cm.m {
|
||||
f(entry.originalKey, entry.value.(*database.Channel))
|
||||
}
|
||||
return entry.value.(*database.Channel)
|
||||
}
|
||||
|
||||
type membershipsCasemapMap struct{ casemapMap }
|
||||
|
||||
func (cm *membershipsCasemapMap) Value(name string) *xirc.MembershipSet {
|
||||
entry, ok := cm.innerMap[cm.casemap(name)]
|
||||
if !ok {
|
||||
func (cm *membershipsCasemapMap) Get(name string) *xirc.MembershipSet {
|
||||
if v := cm.get(name); v == nil {
|
||||
return nil
|
||||
} else {
|
||||
return v.(*xirc.MembershipSet)
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *membershipsCasemapMap) Set(name string, ms *xirc.MembershipSet) {
|
||||
cm.set(name, ms)
|
||||
}
|
||||
|
||||
func (cm *membershipsCasemapMap) ForEach(f func(string, *xirc.MembershipSet)) {
|
||||
for _, entry := range cm.m {
|
||||
f(entry.originalKey, entry.value.(*xirc.MembershipSet))
|
||||
}
|
||||
return entry.value.(*xirc.MembershipSet)
|
||||
}
|
||||
|
||||
type deliveredCasemapMap struct{ casemapMap }
|
||||
|
||||
func (cm *deliveredCasemapMap) Value(name string) deliveredClientMap {
|
||||
entry, ok := cm.innerMap[cm.casemap(name)]
|
||||
if !ok {
|
||||
func (cm *deliveredCasemapMap) Get(name string) deliveredClientMap {
|
||||
if v := cm.get(name); v == nil {
|
||||
return nil
|
||||
} else {
|
||||
return v.(deliveredClientMap)
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *deliveredCasemapMap) Set(name string, m deliveredClientMap) {
|
||||
cm.set(name, m)
|
||||
}
|
||||
|
||||
func (cm *deliveredCasemapMap) ForEach(f func(string, deliveredClientMap)) {
|
||||
for _, entry := range cm.m {
|
||||
f(entry.originalKey, entry.value.(deliveredClientMap))
|
||||
}
|
||||
return entry.value.(deliveredClientMap)
|
||||
}
|
||||
|
||||
type monitorCasemapMap struct{ casemapMap }
|
||||
|
||||
func (cm *monitorCasemapMap) Value(name string) (online bool) {
|
||||
entry, ok := cm.innerMap[cm.casemap(name)]
|
||||
if !ok {
|
||||
func (cm *monitorCasemapMap) Get(name string) (online bool) {
|
||||
if v := cm.get(name); v == nil {
|
||||
return false
|
||||
} else {
|
||||
return v.(bool)
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *monitorCasemapMap) Set(name string, online bool) {
|
||||
cm.set(name, online)
|
||||
}
|
||||
|
||||
func (cm *monitorCasemapMap) ForEach(f func(name string, online bool)) {
|
||||
for _, entry := range cm.m {
|
||||
f(entry.originalKey, entry.value.(bool))
|
||||
}
|
||||
return entry.value.(bool)
|
||||
}
|
||||
|
||||
func isWordBoundary(r rune) bool {
|
||||
|
10
service.go
10
service.go
@ -974,9 +974,9 @@ func handleServiceChannelStatus(ctx context.Context, dc *downstreamConn, params
|
||||
|
||||
sendNetwork := func(net *network) {
|
||||
var channels []*database.Channel
|
||||
for _, entry := range net.channels.innerMap {
|
||||
channels = append(channels, entry.value.(*database.Channel))
|
||||
}
|
||||
net.channels.ForEach(func(_ string, ch *database.Channel) {
|
||||
channels = append(channels, ch)
|
||||
})
|
||||
|
||||
sort.Slice(channels, func(i, j int) bool {
|
||||
return strings.ReplaceAll(channels[i].Name, "#", "") <
|
||||
@ -986,7 +986,7 @@ func handleServiceChannelStatus(ctx context.Context, dc *downstreamConn, params
|
||||
for _, ch := range channels {
|
||||
var uch *upstreamChannel
|
||||
if net.conn != nil {
|
||||
uch = net.conn.channels.Value(ch.Name)
|
||||
uch = net.conn.channels.Get(ch.Name)
|
||||
}
|
||||
|
||||
name := ch.Name
|
||||
@ -1109,7 +1109,7 @@ func handleServiceChannelUpdate(ctx context.Context, dc *downstreamConn, params
|
||||
return fmt.Errorf("unknown channel %q", name)
|
||||
}
|
||||
|
||||
ch := uc.network.channels.Value(upstreamName)
|
||||
ch := uc.network.channels.Get(upstreamName)
|
||||
if ch == nil {
|
||||
return fmt.Errorf("unknown channel %q", name)
|
||||
}
|
||||
|
79
upstream.go
79
upstream.go
@ -292,7 +292,7 @@ func (uc *upstreamConn) downstreamByID(id uint64) *downstreamConn {
|
||||
}
|
||||
|
||||
func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
|
||||
ch := uc.channels.Value(name)
|
||||
ch := uc.channels.Get(name)
|
||||
if ch == nil {
|
||||
return nil, fmt.Errorf("unknown channel %q", name)
|
||||
}
|
||||
@ -513,7 +513,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
|
||||
self := uc.isOurNick(msg.Prefix.Name)
|
||||
|
||||
ch := uc.network.channels.Value(target)
|
||||
ch := uc.network.channels.Get(target)
|
||||
if ch != nil && msg.Command != "TAGMSG" && !self {
|
||||
if ch.Detached {
|
||||
uc.handleDetachedMessage(ctx, ch, msg)
|
||||
@ -757,11 +757,10 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
|
||||
if uc.network.channels.Len() > 0 {
|
||||
var channels, keys []string
|
||||
for _, entry := range uc.network.channels.innerMap {
|
||||
ch := entry.value.(*database.Channel)
|
||||
uc.network.channels.ForEach(func(_ string, ch *database.Channel) {
|
||||
channels = append(channels, ch.Name)
|
||||
keys = append(keys, ch.Key)
|
||||
}
|
||||
})
|
||||
|
||||
for _, msg := range xirc.GenerateJoin(channels, keys) {
|
||||
uc.SendMessage(ctx, msg)
|
||||
@ -918,15 +917,14 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
uc.nickCM = uc.network.casemap(uc.nick)
|
||||
}
|
||||
|
||||
for _, entry := range uc.channels.innerMap {
|
||||
ch := entry.value.(*upstreamChannel)
|
||||
memberships := ch.Members.Value(msg.Prefix.Name)
|
||||
uc.channels.ForEach(func(_ string, ch *upstreamChannel) {
|
||||
memberships := ch.Members.Get(msg.Prefix.Name)
|
||||
if memberships != nil {
|
||||
ch.Members.Delete(msg.Prefix.Name)
|
||||
ch.Members.SetValue(newNick, memberships)
|
||||
ch.Members.Del(msg.Prefix.Name)
|
||||
ch.Members.Set(newNick, memberships)
|
||||
uc.appendLog(ch.Name, msg)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if !me {
|
||||
uc.forEachDownstream(func(dc *downstreamConn) {
|
||||
@ -995,7 +993,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
uc.logger.Printf("joined channel %q", ch)
|
||||
members := membershipsCasemapMap{newCasemapMap()}
|
||||
members.casemap = uc.network.casemap
|
||||
uc.channels.SetValue(ch, &upstreamChannel{
|
||||
uc.channels.Set(ch, &upstreamChannel{
|
||||
Name: ch,
|
||||
conn: uc,
|
||||
Members: members,
|
||||
@ -1011,7 +1009,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ch.Members.SetValue(msg.Prefix.Name, &xirc.MembershipSet{})
|
||||
ch.Members.Set(msg.Prefix.Name, &xirc.MembershipSet{})
|
||||
}
|
||||
|
||||
chMsg := msg.Copy()
|
||||
@ -1027,9 +1025,8 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
for _, ch := range strings.Split(channels, ",") {
|
||||
if uc.isOurNick(msg.Prefix.Name) {
|
||||
uc.logger.Printf("parted channel %q", ch)
|
||||
uch := uc.channels.Value(ch)
|
||||
if uch != nil {
|
||||
uc.channels.Delete(ch)
|
||||
if uch := uc.channels.Get(ch); uch != nil {
|
||||
uc.channels.Del(ch)
|
||||
uch.updateAutoDetach(0)
|
||||
}
|
||||
} else {
|
||||
@ -1037,7 +1034,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ch.Members.Delete(msg.Prefix.Name)
|
||||
ch.Members.Del(msg.Prefix.Name)
|
||||
}
|
||||
|
||||
chMsg := msg.Copy()
|
||||
@ -1052,13 +1049,13 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
|
||||
if uc.isOurNick(user) {
|
||||
uc.logger.Printf("kicked from channel %q by %s", channel, msg.Prefix.Name)
|
||||
uc.channels.Delete(channel)
|
||||
uc.channels.Del(channel)
|
||||
} else {
|
||||
ch, err := uc.getChannel(channel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ch.Members.Delete(user)
|
||||
ch.Members.Del(user)
|
||||
}
|
||||
|
||||
uc.produce(channel, msg, 0)
|
||||
@ -1067,14 +1064,12 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
uc.logger.Printf("quit")
|
||||
}
|
||||
|
||||
for _, entry := range uc.channels.innerMap {
|
||||
ch := entry.value.(*upstreamChannel)
|
||||
uc.channels.ForEach(func(_ string, ch *upstreamChannel) {
|
||||
if ch.Members.Has(msg.Prefix.Name) {
|
||||
ch.Members.Delete(msg.Prefix.Name)
|
||||
|
||||
ch.Members.Del(msg.Prefix.Name)
|
||||
uc.appendLog(ch.Name, msg)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if msg.Prefix.Name != uc.nick {
|
||||
uc.forEachDownstream(func(dc *downstreamConn) {
|
||||
@ -1147,7 +1142,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
|
||||
uc.appendLog(ch.Name, msg)
|
||||
|
||||
c := uc.network.channels.Value(name)
|
||||
c := uc.network.channels.Get(name)
|
||||
if c == nil || !c.Detached {
|
||||
uc.forEachDownstream(func(dc *downstreamConn) {
|
||||
params := make([]string, len(msg.Params))
|
||||
@ -1211,7 +1206,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
return err
|
||||
}
|
||||
|
||||
c := uc.network.channels.Value(channel)
|
||||
c := uc.network.channels.Get(channel)
|
||||
if firstMode && (c == nil || !c.Detached) {
|
||||
modeStr, modeParams := ch.modes.Format()
|
||||
|
||||
@ -1240,7 +1235,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
firstCreationTime := ch.creationTime == ""
|
||||
ch.creationTime = creationTime
|
||||
|
||||
c := uc.network.channels.Value(channel)
|
||||
c := uc.network.channels.Get(channel)
|
||||
if firstCreationTime && (c == nil || !c.Detached) {
|
||||
uc.forEachDownstream(func(dc *downstreamConn) {
|
||||
dc.SendMessage(&irc.Message{
|
||||
@ -1269,7 +1264,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
}
|
||||
ch.TopicTime = time.Unix(sec, 0)
|
||||
|
||||
c := uc.network.channels.Value(channel)
|
||||
c := uc.network.channels.Get(channel)
|
||||
if firstTopicWhoTime && (c == nil || !c.Detached) {
|
||||
uc.forEachDownstream(func(dc *downstreamConn) {
|
||||
topicWho := dc.marshalUserPrefix(uc.network, ch.TopicWho)
|
||||
@ -1322,7 +1317,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
return err
|
||||
}
|
||||
|
||||
ch := uc.channels.Value(name)
|
||||
ch := uc.channels.Get(name)
|
||||
if ch == nil {
|
||||
// NAMES on a channel we have not joined, forward to downstream
|
||||
uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
|
||||
@ -1351,7 +1346,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
|
||||
for _, s := range splitSpace(members) {
|
||||
memberships, nick := uc.parseMembershipPrefix(s)
|
||||
ch.Members.SetValue(nick, memberships)
|
||||
ch.Members.Set(nick, &memberships)
|
||||
}
|
||||
case irc.RPL_ENDOFNAMES:
|
||||
var name string
|
||||
@ -1359,7 +1354,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
return err
|
||||
}
|
||||
|
||||
ch := uc.channels.Value(name)
|
||||
ch := uc.channels.Get(name)
|
||||
if ch == nil {
|
||||
// NAMES on a channel we have not joined, forward to downstream
|
||||
uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
|
||||
@ -1379,7 +1374,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
}
|
||||
ch.complete = true
|
||||
|
||||
c := uc.network.channels.Value(name)
|
||||
c := uc.network.channels.Get(name)
|
||||
if c == nil || !c.Detached {
|
||||
uc.forEachDownstream(func(dc *downstreamConn) {
|
||||
forwardChannel(ctx, dc, ch)
|
||||
@ -1542,7 +1537,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
online := msg.Command == irc.RPL_MONONLINE
|
||||
for _, target := range targets {
|
||||
prefix := irc.ParsePrefix(target)
|
||||
uc.monitored.SetValue(prefix.Name, online)
|
||||
uc.monitored.Set(prefix.Name, online)
|
||||
}
|
||||
|
||||
// Check if the nick we want is now free
|
||||
@ -2112,7 +2107,7 @@ func (uc *upstreamConn) produce(target string, msg *irc.Message, originID uint64
|
||||
}
|
||||
|
||||
// Don't forward messages if it's a detached channel
|
||||
ch := uc.network.channels.Value(target)
|
||||
ch := uc.network.channels.Get(target)
|
||||
detached := ch != nil && ch.Detached
|
||||
|
||||
uc.forEachDownstream(func(dc *downstreamConn) {
|
||||
@ -2148,11 +2143,11 @@ func (uc *upstreamConn) updateAway() {
|
||||
}
|
||||
|
||||
func (uc *upstreamConn) updateChannelAutoDetach(name string) {
|
||||
uch := uc.channels.Value(name)
|
||||
uch := uc.channels.Get(name)
|
||||
if uch == nil {
|
||||
return
|
||||
}
|
||||
ch := uc.network.channels.Value(name)
|
||||
ch := uc.network.channels.Get(name)
|
||||
if ch == nil || ch.Detached {
|
||||
return
|
||||
}
|
||||
@ -2170,7 +2165,7 @@ func (uc *upstreamConn) updateMonitor() {
|
||||
var addList []string
|
||||
seen := make(map[string]struct{})
|
||||
uc.forEachDownstream(func(dc *downstreamConn) {
|
||||
for _, entry := range dc.monitored.innerMap {
|
||||
for _, entry := range dc.monitored.m {
|
||||
targetCM := uc.network.casemap(entry.originalKey)
|
||||
if targetCM == serviceNickCM {
|
||||
continue
|
||||
@ -2195,13 +2190,13 @@ func (uc *upstreamConn) updateMonitor() {
|
||||
|
||||
removeAll := true
|
||||
var removeList []string
|
||||
for targetCM, entry := range uc.monitored.innerMap {
|
||||
if _, ok := seen[targetCM]; ok {
|
||||
uc.monitored.ForEach(func(nick string, online bool) {
|
||||
if _, ok := seen[uc.network.casemap(nick)]; ok {
|
||||
removeAll = false
|
||||
} else {
|
||||
removeList = append(removeList, entry.originalKey)
|
||||
}
|
||||
removeList = append(removeList, nick)
|
||||
}
|
||||
})
|
||||
|
||||
// TODO: better handle the case where len(uc.monitored) + len(addList)
|
||||
// exceeds the limit, probably by immediately sending ERR_MONLISTFULL?
|
||||
@ -2221,6 +2216,6 @@ func (uc *upstreamConn) updateMonitor() {
|
||||
}
|
||||
|
||||
for _, target := range removeList {
|
||||
uc.monitored.Delete(target)
|
||||
uc.monitored.Del(target)
|
||||
}
|
||||
}
|
||||
|
48
user.go
48
user.go
@ -85,11 +85,11 @@ func newDeliveredStore() deliveredStore {
|
||||
}
|
||||
|
||||
func (ds deliveredStore) HasTarget(target string) bool {
|
||||
return ds.m.Value(target) != nil
|
||||
return ds.m.Get(target) != nil
|
||||
}
|
||||
|
||||
func (ds deliveredStore) LoadID(target, clientName string) string {
|
||||
clients := ds.m.Value(target)
|
||||
clients := ds.m.Get(target)
|
||||
if clients == nil {
|
||||
return ""
|
||||
}
|
||||
@ -97,28 +97,27 @@ func (ds deliveredStore) LoadID(target, clientName string) string {
|
||||
}
|
||||
|
||||
func (ds deliveredStore) StoreID(target, clientName, msgID string) {
|
||||
clients := ds.m.Value(target)
|
||||
clients := ds.m.Get(target)
|
||||
if clients == nil {
|
||||
clients = make(deliveredClientMap)
|
||||
ds.m.SetValue(target, clients)
|
||||
ds.m.Set(target, clients)
|
||||
}
|
||||
clients[clientName] = msgID
|
||||
}
|
||||
|
||||
func (ds deliveredStore) ForEachTarget(f func(target string)) {
|
||||
for _, entry := range ds.m.innerMap {
|
||||
f(entry.originalKey)
|
||||
}
|
||||
ds.m.ForEach(func(name string, _ deliveredClientMap) {
|
||||
f(name)
|
||||
})
|
||||
}
|
||||
|
||||
func (ds deliveredStore) ForEachClient(f func(clientName string)) {
|
||||
clients := make(map[string]struct{})
|
||||
for _, entry := range ds.m.innerMap {
|
||||
delivered := entry.value.(deliveredClientMap)
|
||||
ds.m.ForEach(func(name string, delivered deliveredClientMap) {
|
||||
for clientName := range delivered {
|
||||
clients[clientName] = struct{}{}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
for clientName := range clients {
|
||||
f(clientName)
|
||||
@ -144,7 +143,7 @@ func newNetwork(user *user, record *database.Network, channels []database.Channe
|
||||
m := channelCasemapMap{newCasemapMap()}
|
||||
for _, ch := range channels {
|
||||
ch := ch
|
||||
m.SetValue(ch.Name, &ch)
|
||||
m.Set(ch.Name, &ch)
|
||||
}
|
||||
|
||||
return &network{
|
||||
@ -300,7 +299,7 @@ func (net *network) detach(ch *database.Channel) {
|
||||
}
|
||||
|
||||
if net.conn != nil {
|
||||
uch := net.conn.channels.Value(ch.Name)
|
||||
uch := net.conn.channels.Get(ch.Name)
|
||||
if uch != nil {
|
||||
uch.updateAutoDetach(0)
|
||||
}
|
||||
@ -328,7 +327,7 @@ func (net *network) attach(ctx context.Context, ch *database.Channel) {
|
||||
|
||||
var uch *upstreamChannel
|
||||
if net.conn != nil {
|
||||
uch = net.conn.channels.Value(ch.Name)
|
||||
uch = net.conn.channels.Get(ch.Name)
|
||||
|
||||
net.conn.updateChannelAutoDetach(ch.Name)
|
||||
}
|
||||
@ -351,12 +350,12 @@ func (net *network) attach(ctx context.Context, ch *database.Channel) {
|
||||
}
|
||||
|
||||
func (net *network) deleteChannel(ctx context.Context, name string) error {
|
||||
ch := net.channels.Value(name)
|
||||
ch := net.channels.Get(name)
|
||||
if ch == nil {
|
||||
return fmt.Errorf("unknown channel %q", name)
|
||||
}
|
||||
if net.conn != nil {
|
||||
uch := net.conn.channels.Value(ch.Name)
|
||||
uch := net.conn.channels.Get(ch.Name)
|
||||
if uch != nil {
|
||||
uch.updateAutoDetach(0)
|
||||
}
|
||||
@ -365,7 +364,7 @@ func (net *network) deleteChannel(ctx context.Context, name string) error {
|
||||
if err := net.user.srv.db.DeleteChannel(ctx, ch.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
net.channels.Delete(name)
|
||||
net.channels.Del(name)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -375,10 +374,9 @@ func (net *network) updateCasemapping(newCasemap casemapping) {
|
||||
net.delivered.m.SetCasemapping(newCasemap)
|
||||
if uc := net.conn; uc != nil {
|
||||
uc.channels.SetCasemapping(newCasemap)
|
||||
for _, entry := range uc.channels.innerMap {
|
||||
uch := entry.value.(*upstreamChannel)
|
||||
uc.channels.ForEach(func(_ string, uch *upstreamChannel) {
|
||||
uch.Members.SetCasemapping(newCasemap)
|
||||
}
|
||||
})
|
||||
uc.monitored.SetCasemapping(newCasemap)
|
||||
}
|
||||
net.forEachDownstream(func(dc *downstreamConn) {
|
||||
@ -623,7 +621,7 @@ func (u *user) run() {
|
||||
}
|
||||
case eventChannelDetach:
|
||||
uc, name := e.uc, e.name
|
||||
c := uc.network.channels.Value(name)
|
||||
c := uc.network.channels.Get(name)
|
||||
if c == nil || c.Detached {
|
||||
continue
|
||||
}
|
||||
@ -746,10 +744,9 @@ func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
|
||||
|
||||
uc.abortPendingCommands()
|
||||
|
||||
for _, entry := range uc.channels.innerMap {
|
||||
uch := entry.value.(*upstreamChannel)
|
||||
uc.channels.ForEach(func(_ string, uch *upstreamChannel) {
|
||||
uch.updateAutoDetach(0)
|
||||
}
|
||||
})
|
||||
|
||||
uc.forEachDownstream(func(dc *downstreamConn) {
|
||||
dc.updateSupportedCaps()
|
||||
@ -924,10 +921,9 @@ func (u *user) updateNetwork(ctx context.Context, record *database.Network) (*ne
|
||||
// Most network changes require us to re-connect to the upstream server
|
||||
|
||||
channels := make([]database.Channel, 0, network.channels.Len())
|
||||
for _, entry := range network.channels.innerMap {
|
||||
ch := entry.value.(*database.Channel)
|
||||
network.channels.ForEach(func(_ string, ch *database.Channel) {
|
||||
channels = append(channels, *ch)
|
||||
}
|
||||
})
|
||||
|
||||
updatedNetwork := newNetwork(u, record, channels)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user