From d137c69131cf655224c2b2e005436e5f74c1ca55 Mon Sep 17 00:00:00 2001 From: Calvin Lee Date: Sun, 9 Jul 2023 17:28:00 -0700 Subject: [PATCH] database: batch msg inserts This commit takes insert query compilation and transaction creation out of the critical loop for migrating message logs. I have tested with the sqlite backend, and a speedup of approximately 40x has been achieved for log migration. --- contrib/migrate-logs/main.go | 12 +++-- database/database.go | 2 +- database/postgres.go | 87 +++++++++++++++++++++------------- database/sqlite.go | 92 ++++++++++++++++++++++-------------- msgstore/db.go | 4 +- 5 files changed, 120 insertions(+), 77 deletions(-) diff --git a/contrib/migrate-logs/main.go b/contrib/migrate-logs/main.go index 42ad156..59992d5 100644 --- a/contrib/migrate-logs/main.go +++ b/contrib/migrate-logs/main.go @@ -12,6 +12,8 @@ import ( "strings" "time" + "gopkg.in/irc.v4" + "git.sr.ht/~emersion/soju/database" "git.sr.ht/~emersion/soju/msgstore" "git.sr.ht/~emersion/soju/msgstore/znclog" @@ -88,6 +90,7 @@ func migrateNetwork(ctx context.Context, db database.Database, user *database.Us return fmt.Errorf("unable to open entry: %s", entryPath) } sc := bufio.NewScanner(entry) + var msgs []*irc.Message for sc.Scan() { msg, _, err := znclog.UnmarshalLine(sc.Text(), user, network, target, ref, true) if err != nil { @@ -95,14 +98,15 @@ func migrateNetwork(ctx context.Context, db database.Database, user *database.Us } else if msg == nil { continue } - _, err = db.StoreMessage(ctx, network.ID, target, msg) - if err != nil { - return fmt.Errorf("unable to store message: %s: %s: %v", entryPath, sc.Text(), err) - } + msgs = append(msgs, msg) } if sc.Err() != nil { return fmt.Errorf("unable to parse entry: %s: %v", entryPath, sc.Err()) } + _, err = db.StoreMessages(ctx, network.ID, target, msgs) + if err != nil { + return fmt.Errorf("unable to store messages: %s: %s: %v", entryPath, sc.Text(), err) + } entry.Close() } } diff --git a/database/database.go b/database/database.go index 15c04b9..44fa60e 100644 --- a/database/database.go +++ b/database/database.go @@ -60,7 +60,7 @@ type Database interface { DeleteWebPushSubscription(ctx context.Context, id int64) error GetMessageLastID(ctx context.Context, networkID int64, name string) (int64, error) - StoreMessage(ctx context.Context, networkID int64, name string, msg *irc.Message) (int64, error) + StoreMessages(ctx context.Context, networkID int64, name string, msgs []*irc.Message) ([]int64, error) ListMessageLastPerTarget(ctx context.Context, networkID int64, options *MessageOptions) ([]MessageTarget, error) ListMessages(ctx context.Context, networkID int64, name string, options *MessageOptions) ([]*irc.Message, error) } diff --git a/database/postgres.go b/database/postgres.go index e18cc05..68672a6 100644 --- a/database/postgres.go +++ b/database/postgres.go @@ -928,31 +928,21 @@ func (db *PostgresDB) GetMessageLastID(ctx context.Context, networkID int64, nam return msgID, nil } -func (db *PostgresDB) StoreMessage(ctx context.Context, networkID int64, name string, msg *irc.Message) (int64, error) { - ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) +func (db *PostgresDB) StoreMessages(ctx context.Context, networkID int64, name string, msgs []*irc.Message) ([]int64, error) { + if len(msgs) == 0 { + return nil, nil + } + + ctx, cancel := context.WithTimeout(ctx, time.Duration(len(msgs))*sqliteQueryTimeout) defer cancel() - var t time.Time - if tag, ok := msg.Tags["time"]; ok { - var err error - t, err = time.Parse(xirc.ServerTimeLayout, tag) - if err != nil { - return 0, fmt.Errorf("failed to parse message time tag: %v", err) - } - } else { - t = time.Now() + tx, err := db.db.BeginTx(ctx, nil) + if err != nil { + return nil, err } + defer tx.Rollback() - var text sql.NullString - switch msg.Command { - case "PRIVMSG", "NOTICE": - if len(msg.Params) > 1 { - text.Valid = true - text.String = msg.Params[1] - } - } - - _, err := db.db.ExecContext(ctx, ` + _, err = tx.ExecContext(ctx, ` INSERT INTO "MessageTarget" (network, target) VALUES ($1, $2) ON CONFLICT DO NOTHING`, @@ -960,27 +950,56 @@ func (db *PostgresDB) StoreMessage(ctx context.Context, networkID int64, name st name, ) if err != nil { - return 0, err + return nil, err } - var id int64 - err = db.db.QueryRowContext(ctx, ` + insertStmt, err := tx.PrepareContext(ctx, ` INSERT INTO "Message" (target, raw, time, sender, text) SELECT id, $1, $2, $3, $4 FROM "MessageTarget" as t WHERE network = $5 AND target = $6 - RETURNING id`, - msg.String(), - t, - msg.Name, - text, - networkID, - name, - ).Scan(&id) + RETURNING id`) if err != nil { - return 0, err + return nil, err } - return id, nil + + ids := make([]int64, len(msgs)) + for i, msg := range msgs { + var t time.Time + if tag, ok := msg.Tags["time"]; ok { + var err error + t, err = time.Parse(xirc.ServerTimeLayout, tag) + if err != nil { + return nil, fmt.Errorf("failed to parse message time tag: %w", err) + } + } else { + t = time.Now() + } + + var text sql.NullString + switch msg.Command { + case "PRIVMSG", "NOTICE": + if len(msg.Params) > 1 { + text.Valid = true + text.String = msg.Params[1] + } + } + + err = insertStmt.QueryRowContext(ctx, + msg.String(), + t, + msg.Name, + text, + networkID, + name, + ).Scan(&ids[i]) + if err != nil { + return nil, err + } + } + + err = tx.Commit() + return ids, err } func (db *PostgresDB) ListMessageLastPerTarget(ctx context.Context, networkID int64, options *MessageOptions) ([]MessageTarget, error) { diff --git a/database/sqlite.go b/database/sqlite.go index 99c2f05..97ce8fa 100644 --- a/database/sqlite.go +++ b/database/sqlite.go @@ -1192,31 +1192,21 @@ func (db *SqliteDB) GetMessageLastID(ctx context.Context, networkID int64, name return msgID, nil } -func (db *SqliteDB) StoreMessage(ctx context.Context, networkID int64, name string, msg *irc.Message) (int64, error) { - ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) +func (db *SqliteDB) StoreMessages(ctx context.Context, networkID int64, name string, msgs []*irc.Message) ([]int64, error) { + if len(msgs) == 0 { + return nil, nil + } + + ctx, cancel := context.WithTimeout(ctx, time.Duration(len(msgs))*sqliteQueryTimeout) defer cancel() - var t time.Time - if tag, ok := msg.Tags["time"]; ok { - var err error - t, err = time.Parse(xirc.ServerTimeLayout, tag) - if err != nil { - return 0, fmt.Errorf("failed to parse message time tag: %v", err) - } - } else { - t = time.Now() + tx, err := db.db.BeginTx(ctx, nil) + if err != nil { + return nil, err } + defer tx.Rollback() - var text sql.NullString - switch msg.Command { - case "PRIVMSG", "NOTICE": - if len(msg.Params) > 1 { - text.Valid = true - text.String = msg.Params[1] - } - } - - res, err := db.db.ExecContext(ctx, ` + res, err := tx.ExecContext(ctx, ` INSERT INTO MessageTarget(network, target) VALUES (:network, :target) ON CONFLICT DO NOTHING`, @@ -1224,29 +1214,59 @@ func (db *SqliteDB) StoreMessage(ctx context.Context, networkID int64, name stri sql.Named("target", name), ) if err != nil { - return 0, err + return nil, err } - res, err = db.db.ExecContext(ctx, ` + insertStmt, err := tx.PrepareContext(ctx, ` INSERT INTO Message(target, raw, time, sender, text) SELECT id, :raw, :time, :sender, :text FROM MessageTarget as t - WHERE network = :network AND target = :target`, - sql.Named("network", networkID), - sql.Named("target", name), - sql.Named("raw", msg.String()), - sql.Named("time", sqliteTime{t}), - sql.Named("sender", msg.Name), - sql.Named("text", text), - ) + WHERE network = :network AND target = :target`) if err != nil { - return 0, err + return nil, err } - id, err := res.LastInsertId() - if err != nil { - return 0, err + + ids := make([]int64, len(msgs)) + for i, msg := range msgs { + var t time.Time + if tag, ok := msg.Tags["time"]; ok { + var err error + t, err = time.Parse(xirc.ServerTimeLayout, tag) + if err != nil { + return nil, fmt.Errorf("failed to parse message time tag: %w", err) + } + } else { + t = time.Now() + } + + var text sql.NullString + switch msg.Command { + case "PRIVMSG", "NOTICE": + if len(msg.Params) > 1 { + text.Valid = true + text.String = msg.Params[1] + } + } + + res, err = insertStmt.ExecContext(ctx, + sql.Named("network", networkID), + sql.Named("target", name), + sql.Named("raw", msg.String()), + sql.Named("time", sqliteTime{t}), + sql.Named("sender", msg.Name), + sql.Named("text", text), + ) + if err != nil { + return nil, err + } + ids[i], err = res.LastInsertId() + if err != nil { + return nil, err + } } - return id, nil + + err = tx.Commit() + return ids, err } func (db *SqliteDB) ListMessageLastPerTarget(ctx context.Context, networkID int64, options *MessageOptions) ([]MessageTarget, error) { diff --git a/msgstore/db.go b/msgstore/db.go index 6a9efdc..fbdb167 100644 --- a/msgstore/db.go +++ b/msgstore/db.go @@ -81,11 +81,11 @@ func (ms *dbMessageStore) LoadLatestID(ctx context.Context, id string, options * } func (ms *dbMessageStore) Append(network *database.Network, entity string, msg *irc.Message) (string, error) { - id, err := ms.db.StoreMessage(context.TODO(), network.ID, entity, msg) + ids, err := ms.db.StoreMessages(context.TODO(), network.ID, entity, []*irc.Message{msg}) if err != nil { return "", err } - return formatDBMsgID(network.ID, entity, id), nil + return formatDBMsgID(network.ID, entity, ids[0]), nil } func (ms *dbMessageStore) ListTargets(ctx context.Context, network *database.Network, start, end time.Time, limit int, events bool) ([]ChatHistoryTarget, error) {