database: batch msg inserts

This commit takes insert query compilation and transaction creation out
of the critical loop for migrating message logs. I have tested with
the sqlite backend, and a speedup of approximately 40x has been achieved
for log migration.
This commit is contained in:
Calvin Lee 2023-07-09 17:28:00 -07:00 committed by Simon Ser
parent 8a2a9706f7
commit d137c69131
5 changed files with 120 additions and 77 deletions

View File

@ -12,6 +12,8 @@ import (
"strings" "strings"
"time" "time"
"gopkg.in/irc.v4"
"git.sr.ht/~emersion/soju/database" "git.sr.ht/~emersion/soju/database"
"git.sr.ht/~emersion/soju/msgstore" "git.sr.ht/~emersion/soju/msgstore"
"git.sr.ht/~emersion/soju/msgstore/znclog" "git.sr.ht/~emersion/soju/msgstore/znclog"
@ -88,6 +90,7 @@ func migrateNetwork(ctx context.Context, db database.Database, user *database.Us
return fmt.Errorf("unable to open entry: %s", entryPath) return fmt.Errorf("unable to open entry: %s", entryPath)
} }
sc := bufio.NewScanner(entry) sc := bufio.NewScanner(entry)
var msgs []*irc.Message
for sc.Scan() { for sc.Scan() {
msg, _, err := znclog.UnmarshalLine(sc.Text(), user, network, target, ref, true) msg, _, err := znclog.UnmarshalLine(sc.Text(), user, network, target, ref, true)
if err != nil { if err != nil {
@ -95,14 +98,15 @@ func migrateNetwork(ctx context.Context, db database.Database, user *database.Us
} else if msg == nil { } else if msg == nil {
continue continue
} }
_, err = db.StoreMessage(ctx, network.ID, target, msg) msgs = append(msgs, msg)
if err != nil {
return fmt.Errorf("unable to store message: %s: %s: %v", entryPath, sc.Text(), err)
}
} }
if sc.Err() != nil { if sc.Err() != nil {
return fmt.Errorf("unable to parse entry: %s: %v", entryPath, sc.Err()) return fmt.Errorf("unable to parse entry: %s: %v", entryPath, sc.Err())
} }
_, err = db.StoreMessages(ctx, network.ID, target, msgs)
if err != nil {
return fmt.Errorf("unable to store messages: %s: %s: %v", entryPath, sc.Text(), err)
}
entry.Close() entry.Close()
} }
} }

View File

@ -60,7 +60,7 @@ type Database interface {
DeleteWebPushSubscription(ctx context.Context, id int64) error DeleteWebPushSubscription(ctx context.Context, id int64) error
GetMessageLastID(ctx context.Context, networkID int64, name string) (int64, error) GetMessageLastID(ctx context.Context, networkID int64, name string) (int64, error)
StoreMessage(ctx context.Context, networkID int64, name string, msg *irc.Message) (int64, error) StoreMessages(ctx context.Context, networkID int64, name string, msgs []*irc.Message) ([]int64, error)
ListMessageLastPerTarget(ctx context.Context, networkID int64, options *MessageOptions) ([]MessageTarget, error) ListMessageLastPerTarget(ctx context.Context, networkID int64, options *MessageOptions) ([]MessageTarget, error)
ListMessages(ctx context.Context, networkID int64, name string, options *MessageOptions) ([]*irc.Message, error) ListMessages(ctx context.Context, networkID int64, name string, options *MessageOptions) ([]*irc.Message, error)
} }

View File

@ -928,16 +928,49 @@ func (db *PostgresDB) GetMessageLastID(ctx context.Context, networkID int64, nam
return msgID, nil return msgID, nil
} }
func (db *PostgresDB) StoreMessage(ctx context.Context, networkID int64, name string, msg *irc.Message) (int64, error) { func (db *PostgresDB) StoreMessages(ctx context.Context, networkID int64, name string, msgs []*irc.Message) ([]int64, error) {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) if len(msgs) == 0 {
return nil, nil
}
ctx, cancel := context.WithTimeout(ctx, time.Duration(len(msgs))*sqliteQueryTimeout)
defer cancel() defer cancel()
tx, err := db.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
_, err = tx.ExecContext(ctx, `
INSERT INTO "MessageTarget" (network, target)
VALUES ($1, $2)
ON CONFLICT DO NOTHING`,
networkID,
name,
)
if err != nil {
return nil, err
}
insertStmt, err := tx.PrepareContext(ctx, `
INSERT INTO "Message" (target, raw, time, sender, text)
SELECT id, $1, $2, $3, $4
FROM "MessageTarget" as t
WHERE network = $5 AND target = $6
RETURNING id`)
if err != nil {
return nil, err
}
ids := make([]int64, len(msgs))
for i, msg := range msgs {
var t time.Time var t time.Time
if tag, ok := msg.Tags["time"]; ok { if tag, ok := msg.Tags["time"]; ok {
var err error var err error
t, err = time.Parse(xirc.ServerTimeLayout, tag) t, err = time.Parse(xirc.ServerTimeLayout, tag)
if err != nil { if err != nil {
return 0, fmt.Errorf("failed to parse message time tag: %v", err) return nil, fmt.Errorf("failed to parse message time tag: %w", err)
} }
} else { } else {
t = time.Now() t = time.Now()
@ -952,35 +985,21 @@ func (db *PostgresDB) StoreMessage(ctx context.Context, networkID int64, name st
} }
} }
_, err := db.db.ExecContext(ctx, ` err = insertStmt.QueryRowContext(ctx,
INSERT INTO "MessageTarget" (network, target)
VALUES ($1, $2)
ON CONFLICT DO NOTHING`,
networkID,
name,
)
if err != nil {
return 0, err
}
var id int64
err = db.db.QueryRowContext(ctx, `
INSERT INTO "Message" (target, raw, time, sender, text)
SELECT id, $1, $2, $3, $4
FROM "MessageTarget" as t
WHERE network = $5 AND target = $6
RETURNING id`,
msg.String(), msg.String(),
t, t,
msg.Name, msg.Name,
text, text,
networkID, networkID,
name, name,
).Scan(&id) ).Scan(&ids[i])
if err != nil { if err != nil {
return 0, err return nil, err
} }
return id, nil }
err = tx.Commit()
return ids, err
} }
func (db *PostgresDB) ListMessageLastPerTarget(ctx context.Context, networkID int64, options *MessageOptions) ([]MessageTarget, error) { func (db *PostgresDB) ListMessageLastPerTarget(ctx context.Context, networkID int64, options *MessageOptions) ([]MessageTarget, error) {

View File

@ -1192,16 +1192,48 @@ func (db *SqliteDB) GetMessageLastID(ctx context.Context, networkID int64, name
return msgID, nil return msgID, nil
} }
func (db *SqliteDB) StoreMessage(ctx context.Context, networkID int64, name string, msg *irc.Message) (int64, error) { func (db *SqliteDB) StoreMessages(ctx context.Context, networkID int64, name string, msgs []*irc.Message) ([]int64, error) {
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) if len(msgs) == 0 {
return nil, nil
}
ctx, cancel := context.WithTimeout(ctx, time.Duration(len(msgs))*sqliteQueryTimeout)
defer cancel() defer cancel()
tx, err := db.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
res, err := tx.ExecContext(ctx, `
INSERT INTO MessageTarget(network, target)
VALUES (:network, :target)
ON CONFLICT DO NOTHING`,
sql.Named("network", networkID),
sql.Named("target", name),
)
if err != nil {
return nil, err
}
insertStmt, err := tx.PrepareContext(ctx, `
INSERT INTO Message(target, raw, time, sender, text)
SELECT id, :raw, :time, :sender, :text
FROM MessageTarget as t
WHERE network = :network AND target = :target`)
if err != nil {
return nil, err
}
ids := make([]int64, len(msgs))
for i, msg := range msgs {
var t time.Time var t time.Time
if tag, ok := msg.Tags["time"]; ok { if tag, ok := msg.Tags["time"]; ok {
var err error var err error
t, err = time.Parse(xirc.ServerTimeLayout, tag) t, err = time.Parse(xirc.ServerTimeLayout, tag)
if err != nil { if err != nil {
return 0, fmt.Errorf("failed to parse message time tag: %v", err) return nil, fmt.Errorf("failed to parse message time tag: %w", err)
} }
} else { } else {
t = time.Now() t = time.Now()
@ -1216,22 +1248,7 @@ func (db *SqliteDB) StoreMessage(ctx context.Context, networkID int64, name stri
} }
} }
res, err := db.db.ExecContext(ctx, ` res, err = insertStmt.ExecContext(ctx,
INSERT INTO MessageTarget(network, target)
VALUES (:network, :target)
ON CONFLICT DO NOTHING`,
sql.Named("network", networkID),
sql.Named("target", name),
)
if err != nil {
return 0, err
}
res, err = db.db.ExecContext(ctx, `
INSERT INTO Message(target, raw, time, sender, text)
SELECT id, :raw, :time, :sender, :text
FROM MessageTarget as t
WHERE network = :network AND target = :target`,
sql.Named("network", networkID), sql.Named("network", networkID),
sql.Named("target", name), sql.Named("target", name),
sql.Named("raw", msg.String()), sql.Named("raw", msg.String()),
@ -1240,13 +1257,16 @@ func (db *SqliteDB) StoreMessage(ctx context.Context, networkID int64, name stri
sql.Named("text", text), sql.Named("text", text),
) )
if err != nil { if err != nil {
return 0, err return nil, err
} }
id, err := res.LastInsertId() ids[i], err = res.LastInsertId()
if err != nil { if err != nil {
return 0, err return nil, err
} }
return id, nil }
err = tx.Commit()
return ids, err
} }
func (db *SqliteDB) ListMessageLastPerTarget(ctx context.Context, networkID int64, options *MessageOptions) ([]MessageTarget, error) { func (db *SqliteDB) ListMessageLastPerTarget(ctx context.Context, networkID int64, options *MessageOptions) ([]MessageTarget, error) {

View File

@ -81,11 +81,11 @@ func (ms *dbMessageStore) LoadLatestID(ctx context.Context, id string, options *
} }
func (ms *dbMessageStore) Append(network *database.Network, entity string, msg *irc.Message) (string, error) { func (ms *dbMessageStore) Append(network *database.Network, entity string, msg *irc.Message) (string, error) {
id, err := ms.db.StoreMessage(context.TODO(), network.ID, entity, msg) ids, err := ms.db.StoreMessages(context.TODO(), network.ID, entity, []*irc.Message{msg})
if err != nil { if err != nil {
return "", err return "", err
} }
return formatDBMsgID(network.ID, entity, id), nil return formatDBMsgID(network.ID, entity, ids[0]), nil
} }
func (ms *dbMessageStore) ListTargets(ctx context.Context, network *database.Network, start, end time.Time, limit int, events bool) ([]ChatHistoryTarget, error) { func (ms *dbMessageStore) ListTargets(ctx context.Context, network *database.Network, start, end time.Time, limit int, events bool) ([]ChatHistoryTarget, error) {