diff --git a/downstream.go b/downstream.go index 2a3336d..f02e384 100644 --- a/downstream.go +++ b/downstream.go @@ -361,12 +361,7 @@ func (dc *downstreamConn) ackMsgID(id string) { return } - delivered := network.delivered.Value(entity) - if delivered == nil { - return - } - - delivered[dc.clientName] = id + network.delivered.StoreID(entity, dc.clientName, id) } func (dc *downstreamConn) sendPing(msgID string) { @@ -997,24 +992,27 @@ func (dc *downstreamConn) welcome() error { } }) if firstClient { - dc.sendNetworkBacklog(net) + net.delivered.ForEachTarget(func(target string) { + dc.sendTargetBacklog(net, target) + }) } // Fast-forward history to last message - for targetCM, entry := range net.delivered.innerMap { - delivered := entry.value.(deliveredClientMap) - ch := net.channels.Value(targetCM) + net.delivered.ForEachTarget(func(target string) { + ch := net.channels.Value(target) if ch != nil && ch.Detached { - continue + return } + targetCM := net.casemap(target) lastID, err := dc.user.msgStore.LastMsgID(net, targetCM, time.Now()) if err != nil { dc.logger.Printf("failed to get last message ID: %v", err) - continue + return } - delivered[dc.clientName] = lastID - } + + net.delivered.StoreID(target, dc.clientName, lastID) + }) }) return nil @@ -1034,13 +1032,6 @@ func (dc *downstreamConn) messageSupportsHistory(msg *irc.Message) bool { return false } -func (dc *downstreamConn) sendNetworkBacklog(net *network) { - for _, entry := range net.delivered.innerMap { - target := entry.originalKey - dc.sendTargetBacklog(net, target) - } -} - func (dc *downstreamConn) sendTargetBacklog(net *network, target string) { if dc.caps["draft/chathistory"] || dc.user.msgStore == nil { return @@ -1048,12 +1039,9 @@ func (dc *downstreamConn) sendTargetBacklog(net *network, target string) { if ch := net.channels.Value(target); ch != nil && ch.Detached { return } - delivered := net.delivered.Value(target) - if delivered == nil { - return - } - lastDelivered, ok := delivered[dc.clientName] - if !ok { + + lastDelivered := net.delivered.LoadID(target, dc.clientName) + if lastDelivered == "" { return } diff --git a/upstream.go b/upstream.go index 409ab0b..82bd3b6 100644 --- a/upstream.go +++ b/upstream.go @@ -1740,9 +1740,9 @@ func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) (msgID string return "" } - delivered := uc.network.delivered.Value(entity) entityCM := uc.network.casemap(entity) - if delivered == nil { + + if !uc.network.delivered.HasTarget(entity) { // This is the first message we receive from this target. Save the last // message ID in delivery receipts, so that we can send the new message // in the backlog if an offline client reconnects. @@ -1752,11 +1752,8 @@ func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) (msgID string return "" } - delivered = make(deliveredClientMap) - uc.network.delivered.SetValue(entity, delivered) - for clientName, _ := range uc.user.clients { - delivered[clientName] = lastID + uc.network.delivered.StoreID(entity, clientName, lastID) } } diff --git a/user.go b/user.go index 4a637c5..4b7fcbf 100644 --- a/user.go +++ b/user.go @@ -57,6 +57,41 @@ type eventStop struct{} type deliveredClientMap map[string]string // client name -> msg ID +type deliveredStore struct { + m deliveredCasemapMap +} + +func newDeliveredStore() deliveredStore { + return deliveredStore{deliveredCasemapMap{newCasemapMap(0)}} +} + +func (ds deliveredStore) HasTarget(target string) bool { + return ds.m.Value(target) != nil +} + +func (ds deliveredStore) LoadID(target, clientName string) string { + clients := ds.m.Value(target) + if clients == nil { + return "" + } + return clients[clientName] +} + +func (ds deliveredStore) StoreID(target, clientName, msgID string) { + clients := ds.m.Value(target) + if clients == nil { + clients = make(deliveredClientMap) + ds.m.SetValue(target, clients) + } + clients[clientName] = msgID +} + +func (ds deliveredStore) ForEachTarget(f func(target string)) { + for _, entry := range ds.m.innerMap { + f(entry.originalKey) + } +} + type network struct { Network user *user @@ -64,7 +99,7 @@ type network struct { conn *upstreamConn channels channelCasemapMap - delivered deliveredCasemapMap + delivered deliveredStore lastError error casemap casemapping } @@ -81,7 +116,7 @@ func newNetwork(user *user, record *Network, channels []Channel) *network { user: user, stopped: make(chan struct{}), channels: m, - delivered: deliveredCasemapMap{newCasemapMap(0)}, + delivered: newDeliveredStore(), casemap: casemapRFC1459, } } @@ -253,7 +288,7 @@ func (net *network) deleteChannel(name string) error { func (net *network) updateCasemapping(newCasemap casemapping) { net.casemap = newCasemap net.channels.SetCasemapping(newCasemap) - net.delivered.SetCasemapping(newCasemap) + net.delivered.m.SetCasemapping(newCasemap) if net.conn != nil { net.conn.channels.SetCasemapping(newCasemap) for _, entry := range net.conn.channels.innerMap {