From 1e4ff49472467e1e30c897608aeddb6921dc81c7 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Wed, 10 Feb 2021 18:16:08 +0100 Subject: [PATCH] Save delivery receipts in DB This avoids loosing history on restart for clients that don't support chathistory. Closes: https://todo.sr.ht/~emersion/soju/80 --- db.go | 90 +++++++++++++++++++++++++++++++++++++++++++++++++++++ upstream.go | 4 +-- user.go | 78 ++++++++++++++++++++++++++++++++++++++++------ 3 files changed, 161 insertions(+), 11 deletions(-) diff --git a/db.go b/db.go index 85fdec9..e2c90a7 100644 --- a/db.go +++ b/db.go @@ -120,6 +120,13 @@ type Channel struct { DetachOn MessageFilter } +type DeliveryReceipt struct { + ID int64 + Target string // channel or nick + Client string + InternalMsgID string +} + const schema = ` CREATE TABLE User ( id INTEGER PRIMARY KEY, @@ -161,6 +168,16 @@ CREATE TABLE Channel ( FOREIGN KEY(network) REFERENCES Network(id), UNIQUE(network, name) ); + +CREATE TABLE DeliveryReceipt ( + id INTEGER PRIMARY KEY, + network INTEGER NOT NULL, + target VARCHAR(255) NOT NULL, + client VARCHAR(255), + internal_msgid VARCHAR(255) NOT NULL, + FOREIGN KEY(network) REFERENCES Network(id), + UNIQUE(network, target, client) +); ` var migrations = []string{ @@ -217,6 +234,17 @@ var migrations = []string{ ALTER TABLE Channel ADD COLUMN detach_after INTEGER NOT NULL DEFAULT 0; ALTER TABLE Channel ADD COLUMN detach_on INTEGER NOT NULL DEFAULT 0; `, + ` + CREATE TABLE DeliveryReceipt ( + id INTEGER PRIMARY KEY, + network INTEGER NOT NULL, + target VARCHAR(255) NOT NULL, + client VARCHAR(255), + internal_msgid VARCHAR(255) NOT NULL, + FOREIGN KEY(network) REFERENCES Network(id), + UNIQUE(network, target, client) + ); + `, } type DB struct { @@ -578,3 +606,65 @@ func (db *DB) DeleteChannel(id int64) error { _, err := db.db.Exec("DELETE FROM Channel WHERE id = ?", id) return err } + +func (db *DB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + rows, err := db.db.Query(`SELECT id, target, client, internal_msgid + FROM DeliveryReceipt + WHERE network = ?`, networkID) + if err != nil { + return nil, err + } + defer rows.Close() + + var receipts []DeliveryReceipt + for rows.Next() { + var rcpt DeliveryReceipt + var client sql.NullString + if err := rows.Scan(&rcpt.ID, &rcpt.Target, &client, &rcpt.InternalMsgID); err != nil { + return nil, err + } + rcpt.Client = client.String + receipts = append(receipts, rcpt) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return receipts, nil +} + +func (db *DB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error { + db.lock.Lock() + defer db.lock.Unlock() + + tx, err := db.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + _, err = tx.Exec("DELETE FROM DeliveryReceipt WHERE network = ? AND client = ?", + networkID, toNullString(client)) + if err != nil { + return err + } + + for i := range receipts { + rcpt := &receipts[i] + + res, err := tx.Exec("INSERT INTO DeliveryReceipt(network, target, client, internal_msgid) VALUES (?, ?, ?, ?)", + networkID, rcpt.Target, toNullString(client), rcpt.InternalMsgID) + if err != nil { + return err + } + rcpt.ID, err = res.LastInsertId() + if err != nil { + return err + } + } + + return tx.Commit() +} diff --git a/upstream.go b/upstream.go index bf933e6..f083c0b 100644 --- a/upstream.go +++ b/upstream.go @@ -1752,9 +1752,9 @@ func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) (msgID string return "" } - for clientName, _ := range uc.user.clientNames { + uc.network.delivered.ForEachClient(func(clientName string) { uc.network.delivered.StoreID(entity, clientName, lastID) - } + }) } msgID, err := uc.user.msgStore.Append(uc.network, entityCM, msg) diff --git a/user.go b/user.go index 7f9e203..490a3b0 100644 --- a/user.go +++ b/user.go @@ -92,6 +92,20 @@ func (ds deliveredStore) ForEachTarget(f func(target string)) { } } +func (ds deliveredStore) ForEachClient(f func(clientName string)) { + clients := make(map[string]struct{}) + for _, entry := range ds.m.innerMap { + delivered := entry.value.(deliveredClientMap) + for clientName := range delivered { + clients[clientName] = struct{}{} + } + } + + for clientName := range clients { + f(clientName) + } +} + type network struct { Network user *user @@ -298,6 +312,28 @@ func (net *network) updateCasemapping(newCasemap casemapping) { } } +func (net *network) storeClientDeliveryReceipts(clientName string) { + if !net.user.hasPersistentMsgStore() { + return + } + + var receipts []DeliveryReceipt + net.delivered.ForEachTarget(func(target string) { + msgID := net.delivered.LoadID(target, clientName) + if msgID == "" { + return + } + receipts = append(receipts, DeliveryReceipt{ + Target: target, + InternalMsgID: msgID, + }) + }) + + if err := net.user.srv.db.StoreClientDeliveryReceipts(net.ID, clientName, receipts); err != nil { + net.user.srv.Logger.Printf("failed to store delivery receipts for user %q, client %q, network %q: %v", net.user.Username, clientName, net.GetName(), err) + } +} + type user struct { User srv *Server @@ -308,7 +344,6 @@ type user struct { networks []*network downstreamConns []*downstreamConn msgStore messageStore - clientNames map[string]struct{} // LIST commands in progress pendingLISTs []pendingLIST @@ -329,12 +364,11 @@ func newUser(srv *Server, record *User) *user { } return &user{ - User: *record, - srv: srv, - events: make(chan event, 64), - done: make(chan struct{}), - msgStore: msgStore, - clientNames: make(map[string]struct{}), + User: *record, + srv: srv, + events: make(chan event, 64), + done: make(chan struct{}), + msgStore: msgStore, } } @@ -407,6 +441,18 @@ func (u *user) run() { network := newNetwork(u, &record, channels) u.networks = append(u.networks, network) + if u.hasPersistentMsgStore() { + receipts, err := u.srv.db.ListDeliveryReceipts(record.ID) + if err != nil { + u.srv.Logger.Printf("failed to load delivery receipts for user %q, network %q: %v", u.Username, network.GetName(), err) + return + } + + for _, rcpt := range receipts { + network.delivered.StoreID(rcpt.Target, rcpt.Client, rcpt.InternalMsgID) + } + } + go network.run() } @@ -489,8 +535,6 @@ func (u *user) run() { u.forEachUpstream(func(uc *upstreamConn) { uc.updateAway() }) - - u.clientNames[dc.clientName] = struct{}{} case eventDownstreamDisconnected: dc := e.dc @@ -501,6 +545,10 @@ func (u *user) run() { } } + dc.forEachNetwork(func(net *network) { + net.storeClientDeliveryReceipts(dc.clientName) + }) + u.forEachUpstream(func(uc *upstreamConn) { uc.updateAway() }) @@ -524,6 +572,10 @@ func (u *user) run() { }) for _, n := range u.networks { n.stop() + + n.delivered.ForEachClient(func(clientName string) { + n.storeClientDeliveryReceipts(clientName) + }) } return default: @@ -665,3 +717,11 @@ func (u *user) stop() { u.events <- eventStop{} <-u.done } + +func (u *user) hasPersistentMsgStore() bool { + if u.msgStore == nil { + return false + } + _, isMem := u.msgStore.(*memoryMessageStore) + return !isMem +}