Add a database store for messages

This adds a new config option, `logs db`, which enables storing chat
logs in the soju database.

Regular store options, CHATHISTORY options, and SEARCH operations are
supported, like the fs logs backend.

Messages are stored in a new table, Message. In order to track the list
of targets we have messages for in an optimized manner, another database
is used: MessageTarget.

All new requests are backend by indexes so should be fast even with
hundreds of thousands of messages.

A contrib script is provided for migrating existing logs fs chat logs to
the database. It can be run with eg:

  go run ./contrib/migrate-logs/ logs/ sqlite3:soju.db

Co-authored-by: Simon Ser <contact@emersion.fr>
This commit is contained in:
delthas 2022-12-11 00:01:16 +01:00 committed by Simon Ser
parent 47f0dd5b3f
commit 1ccc7ce6d2
16 changed files with 937 additions and 15 deletions

View File

@ -92,6 +92,7 @@ func loadConfig() (*config.Server, *soju.Config, error) {
cfg := &soju.Config{ cfg := &soju.Config{
Hostname: raw.Hostname, Hostname: raw.Hostname,
Title: raw.Title, Title: raw.Title,
LogDriver: raw.MsgStore.Driver,
LogPath: raw.MsgStore.Source, LogPath: raw.MsgStore.Source,
HTTPOrigins: raw.HTTPOrigins, HTTPOrigins: raw.HTTPOrigins,
AcceptProxyIPs: raw.AcceptProxyIPs, AcceptProxyIPs: raw.AcceptProxyIPs,

View File

@ -150,8 +150,7 @@ func parse(cfg scfg.Block) (*Server, error) {
return nil, err return nil, err
} }
switch srv.MsgStore.Driver { switch srv.MsgStore.Driver {
case "memory": case "memory", "db":
srv.MsgStore.Source = ""
case "fs": case "fs":
if err := d.ParseParams(nil, &srv.MsgStore.Source); err != nil { if err := d.ParseParams(nil, &srv.MsgStore.Source); err != nil {
return nil, err return nil, err

View File

@ -0,0 +1,148 @@
package main
import (
"bufio"
"context"
"flag"
"fmt"
"log"
"os"
"path/filepath"
"sort"
"strings"
"time"
"git.sr.ht/~emersion/soju/database"
"git.sr.ht/~emersion/soju/msgstore"
)
const usage = `usage: migrate-logs <source logs> <destination database>
Migrates existing Soju logs stored on disk to a Soju database. Database is specified
in the format of "driver:source" where driver is sqlite3 or postgres and source
is the string that would be in the Soju config file.
Options:
-help Show this help message
`
var logRoot string
func init() {
flag.Usage = func() {
fmt.Fprint(flag.CommandLine.Output(), usage)
}
}
func migrateNetwork(ctx context.Context, db database.Database, user *database.User, network *database.Network) error {
log.Printf("Migrating logs for network: %s\n", network.Name)
rootPath := filepath.Join(logRoot, msgstore.EscapeFilename(user.Username), msgstore.EscapeFilename(network.GetName()))
root, err := os.Open(rootPath)
if os.IsNotExist(err) {
return nil
}
if err != nil {
return fmt.Errorf("unable to open network folder: %s", rootPath)
}
// The returned targets are escaped, and there is no way to un-escape
// TODO: switch to ReadDir (Go 1.16+)
targets, err := root.Readdirnames(0)
root.Close()
if err != nil {
return fmt.Errorf("unable to read network folder: %s", rootPath)
}
for _, target := range targets {
log.Printf("Migrating logs for target: %s\n", target)
// target is already escaped here
targetPath := filepath.Join(rootPath, target)
targetDir, err := os.Open(targetPath)
if err != nil {
return fmt.Errorf("unable to open target folder: %s", targetPath)
}
entryNames, err := targetDir.Readdirnames(0)
targetDir.Close()
if err != nil {
return fmt.Errorf("unable to read target folder: %s", targetPath)
}
sort.Strings(entryNames)
for _, entryName := range entryNames {
entryPath := filepath.Join(targetPath, entryName)
var year, month, day int
_, err := fmt.Sscanf(entryName, "%04d-%02d-%02d.log", &year, &month, &day)
if err != nil {
return fmt.Errorf("invalid entry name: %s", entryName)
}
ref := time.Date(year, time.Month(month), day, 0, 0, 0, 0, time.UTC)
entry, err := os.Open(entryPath)
if err != nil {
return fmt.Errorf("unable to open entry: %s", entryPath)
}
sc := bufio.NewScanner(entry)
for sc.Scan() {
msg, _, err := msgstore.FSParseMessage(sc.Text(), user, network, target, ref, true)
if err != nil {
return fmt.Errorf("unable to parse entry: %s: %s", entryPath, sc.Text())
} else if msg == nil {
continue
}
_, err = db.StoreMessage(ctx, network.ID, target, msg)
if err != nil {
return fmt.Errorf("unable to store message: %s: %s: %v", entryPath, sc.Text(), err)
}
}
if sc.Err() != nil {
return fmt.Errorf("unable to parse entry: %s: %v", entryPath, sc.Err())
}
entry.Close()
}
}
return nil
}
func main() {
flag.Parse()
ctx := context.Background()
logRoot = flag.Arg(0)
dbParams := strings.Split(flag.Arg(1), ":")
if len(dbParams) != 2 {
log.Fatalf("database not properly specified: %s", flag.Arg(1))
}
db, err := database.Open(dbParams[0], dbParams[1])
if err != nil {
log.Fatalf("failed to open database: %v", err)
}
defer db.Close()
users, err := db.ListUsers(ctx)
if err != nil {
log.Fatalf("unable to get users: %v", err)
}
for _, user := range users {
log.Printf("Migrating logs for user: %s\n", user.Username)
networks, err := db.ListNetworks(ctx, user.ID)
if err != nil {
log.Fatalf("unable to get networks for user: #%d %s", user.ID, user.Username)
}
for _, network := range networks {
if err := migrateNetwork(ctx, db, &user, &network); err != nil {
log.Fatalf("migrating %v: %v", network.Name, err)
}
}
}
}

View File

@ -10,8 +10,25 @@ import (
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gopkg.in/irc.v4"
) )
type MessageTarget struct {
Name string
LatestMessage time.Time
}
type MessageOptions struct {
AfterID int64
AfterTime time.Time
BeforeTime time.Time
Limit int
Events bool
Sender string
Text string
TakeLast bool
}
type Database interface { type Database interface {
Close() error Close() error
Stats(ctx context.Context) (*DatabaseStats, error) Stats(ctx context.Context) (*DatabaseStats, error)
@ -41,6 +58,11 @@ type Database interface {
ListWebPushSubscriptions(ctx context.Context, userID, networkID int64) ([]WebPushSubscription, error) ListWebPushSubscriptions(ctx context.Context, userID, networkID int64) ([]WebPushSubscription, error)
StoreWebPushSubscription(ctx context.Context, userID, networkID int64, sub *WebPushSubscription) error StoreWebPushSubscription(ctx context.Context, userID, networkID int64, sub *WebPushSubscription) error
DeleteWebPushSubscription(ctx context.Context, id int64) error DeleteWebPushSubscription(ctx context.Context, id int64) error
GetMessageLastID(ctx context.Context, networkID int64, name string) (int64, error)
StoreMessage(ctx context.Context, networkID int64, name string, msg *irc.Message) (int64, error)
ListMessageLastPerTarget(ctx context.Context, networkID int64, options *MessageOptions) ([]MessageTarget, error)
ListMessages(ctx context.Context, networkID int64, name string, options *MessageOptions) ([]*irc.Message, error)
} }
type MetricsCollectorDatabase interface { type MetricsCollectorDatabase interface {

View File

@ -9,9 +9,11 @@ import (
"strings" "strings"
"time" "time"
"git.sr.ht/~emersion/soju/xirc"
_ "github.com/lib/pq" _ "github.com/lib/pq"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
promcollectors "github.com/prometheus/client_golang/prometheus/collectors" promcollectors "github.com/prometheus/client_golang/prometheus/collectors"
"gopkg.in/irc.v4"
) )
const postgresQueryTimeout = 5 * time.Second const postgresQueryTimeout = 5 * time.Second
@ -112,6 +114,30 @@ CREATE TABLE "WebPushSubscription" (
key_p256dh TEXT, key_p256dh TEXT,
UNIQUE(network, endpoint) UNIQUE(network, endpoint)
); );
CREATE TABLE "MessageTarget" (
id SERIAL PRIMARY KEY,
network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
target TEXT NOT NULL,
UNIQUE(network, target)
);
CREATE TEXT SEARCH DICTIONARY "search_simple_dictionary" (
TEMPLATE = pg_catalog.simple
);
CREATE TEXT SEARCH CONFIGURATION "search_simple" ( COPY = pg_catalog.simple );
ALTER TEXT SEARCH CONFIGURATION "search_simple" ALTER MAPPING FOR asciiword, asciihword, hword_asciipart, hword, hword_part, word WITH "search_simple_dictionary";
CREATE TABLE "Message" (
id SERIAL PRIMARY KEY,
target INTEGER NOT NULL REFERENCES "MessageTarget"(id) ON DELETE CASCADE,
raw TEXT NOT NULL,
time TIMESTAMP WITH TIME ZONE NOT NULL,
sender TEXT NOT NULL,
text TEXT,
text_search tsvector GENERATED ALWAYS AS (to_tsvector('search_simple', text)) STORED
);
CREATE INDEX "MessageIndex" ON "Message" (target, time);
CREATE INDEX "MessageSearchIndex" ON "Message" USING GIN (text_search);
` `
var postgresMigrations = []string{ var postgresMigrations = []string{
@ -173,6 +199,30 @@ var postgresMigrations = []string{
`ALTER TABLE "User" ADD COLUMN created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now()`, `ALTER TABLE "User" ADD COLUMN created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now()`,
`ALTER TABLE "User" ADD COLUMN enabled BOOLEAN NOT NULL DEFAULT TRUE`, `ALTER TABLE "User" ADD COLUMN enabled BOOLEAN NOT NULL DEFAULT TRUE`,
`ALTER TABLE "User" ADD COLUMN downstream_interacted_at TIMESTAMP WITH TIME ZONE`, `ALTER TABLE "User" ADD COLUMN downstream_interacted_at TIMESTAMP WITH TIME ZONE`,
`
CREATE TABLE "MessageTarget" (
id SERIAL PRIMARY KEY,
network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
target TEXT NOT NULL,
UNIQUE(network, target)
);
CREATE TEXT SEARCH DICTIONARY "search_simple_dictionary" (
TEMPLATE = pg_catalog.simple
);
CREATE TEXT SEARCH CONFIGURATION "search_simple" ( COPY = pg_catalog.simple );
ALTER TEXT SEARCH CONFIGURATION "search_simple" ALTER MAPPING FOR asciiword, asciihword, hword_asciipart, hword, hword_part, word WITH "search_simple_dictionary";
CREATE TABLE "Message" (
id SERIAL PRIMARY KEY,
target INTEGER NOT NULL REFERENCES "MessageTarget"(id) ON DELETE CASCADE,
raw TEXT NOT NULL,
time TIMESTAMP WITH TIME ZONE NOT NULL,
sender TEXT NOT NULL,
text TEXT,
text_search tsvector GENERATED ALWAYS AS (to_tsvector('search_simple', text)) STORED
);
CREATE INDEX "MessageIndex" ON "Message" (target, time);
CREATE INDEX "MessageSearchIndex" ON "Message" USING GIN (text_search);
`,
} }
type PostgresDB struct { type PostgresDB struct {
@ -847,6 +897,229 @@ func (db *PostgresDB) DeleteWebPushSubscription(ctx context.Context, id int64) e
return err return err
} }
func (db *PostgresDB) GetMessageLastID(ctx context.Context, networkID int64, name string) (int64, error) {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
var msgID int64
row := db.db.QueryRowContext(ctx, `
SELECT m.id FROM "Message" AS m, "MessageTarget" as t
WHERE t.network = $1 AND t.target = $2 AND m.target = t.id
ORDER BY m.time DESC LIMIT 1`,
networkID,
name,
)
if err := row.Scan(&msgID); err != nil {
if err == sql.ErrNoRows {
return 0, nil
}
return 0, err
}
return msgID, nil
}
func (db *PostgresDB) StoreMessage(ctx context.Context, networkID int64, name string, msg *irc.Message) (int64, error) {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
var t time.Time
if tag, ok := msg.Tags["time"]; ok {
var err error
t, err = time.Parse(xirc.ServerTimeLayout, tag)
if err != nil {
return 0, fmt.Errorf("failed to parse message time tag: %v", err)
}
} else {
t = time.Now()
}
var text sql.NullString
switch msg.Command {
case "PRIVMSG", "NOTICE":
if len(msg.Params) > 1 {
text.Valid = true
text.String = msg.Params[1]
}
}
_, err := db.db.ExecContext(ctx, `
INSERT INTO "MessageTarget" (network, target)
VALUES ($1, $2)
ON CONFLICT DO NOTHING`,
networkID,
name,
)
if err != nil {
return 0, err
}
var id int64
err = db.db.QueryRowContext(ctx, `
INSERT INTO "Message" (target, raw, time, sender, text)
SELECT id, $1, $2, $3, $4
FROM "MessageTarget" as t
WHERE network = $5 AND target = $6
RETURNING id`,
msg.String(),
t,
msg.Name,
text,
networkID,
name,
).Scan(&id)
if err != nil {
return 0, err
}
return id, nil
}
func (db *PostgresDB) ListMessageLastPerTarget(ctx context.Context, networkID int64, options *MessageOptions) ([]MessageTarget, error) {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
parameters := []interface{}{
networkID,
}
query := `
SELECT t.target, MAX(m.time) AS latest
FROM "Message" m, "MessageTarget" t
WHERE m.target = t.id AND t.network = $1
`
if !options.Events {
query += `AND m.text IS NOT NULL `
}
query += `
GROUP BY t.target
HAVING true
`
if !options.AfterTime.IsZero() {
// compares time strings by lexicographical order
parameters = append(parameters, options.AfterTime)
query += fmt.Sprintf(`AND MAX(m.time) > $%d `, len(parameters))
}
if !options.BeforeTime.IsZero() {
// compares time strings by lexicographical order
parameters = append(parameters, options.BeforeTime)
query += fmt.Sprintf(`AND MAX(m.time) < $%d `, len(parameters))
}
if options.TakeLast {
query += `ORDER BY latest DESC `
} else {
query += `ORDER BY latest ASC `
}
parameters = append(parameters, options.Limit)
query += fmt.Sprintf(`LIMIT $%d`, len(parameters))
rows, err := db.db.QueryContext(ctx, query, parameters...)
if err != nil {
return nil, err
}
defer rows.Close()
var l []MessageTarget
for rows.Next() {
var mt MessageTarget
if err := rows.Scan(&mt.Name, &mt.LatestMessage); err != nil {
return nil, err
}
l = append(l, mt)
}
if err := rows.Err(); err != nil {
return nil, err
}
if options.TakeLast {
// We ordered by DESC to limit to the last lines.
// Reverse the list to order by ASC these last lines.
for i, j := 0, len(l)-1; i < j; i, j = i+1, j-1 {
l[i], l[j] = l[j], l[i]
}
}
return l, nil
}
func (db *PostgresDB) ListMessages(ctx context.Context, networkID int64, name string, options *MessageOptions) ([]*irc.Message, error) {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
parameters := []interface{}{
networkID,
name,
}
query := `
SELECT m.raw
FROM "Message" AS m, "MessageTarget" AS t
WHERE m.target = t.id AND t.network = $1 AND t.target = $2 `
if options.AfterID > 0 {
parameters = append(parameters, options.AfterID)
query += fmt.Sprintf(`AND m.id > $%d `, len(parameters))
}
if !options.AfterTime.IsZero() {
// compares time strings by lexicographical order
parameters = append(parameters, options.AfterTime)
query += fmt.Sprintf(`AND m.time > $%d `, len(parameters))
}
if !options.BeforeTime.IsZero() {
// compares time strings by lexicographical order
parameters = append(parameters, options.BeforeTime)
query += fmt.Sprintf(`AND m.time < $%d `, len(parameters))
}
if options.Sender != "" {
parameters = append(parameters, options.Sender)
query += fmt.Sprintf(`AND m.sender = $%d `, len(parameters))
}
if options.Text != "" {
parameters = append(parameters, options.Text)
query += fmt.Sprintf(`AND text_search @@ plainto_tsquery('search_simple', $%d) `, len(parameters))
}
if !options.Events {
query += `AND m.text IS NOT NULL `
}
if options.TakeLast {
query += `ORDER BY m.time DESC `
} else {
query += `ORDER BY m.time ASC `
}
parameters = append(parameters, options.Limit)
query += fmt.Sprintf(`LIMIT $%d`, len(parameters))
rows, err := db.db.QueryContext(ctx, query, parameters...)
if err != nil {
return nil, err
}
defer rows.Close()
var l []*irc.Message
for rows.Next() {
var raw string
if err := rows.Scan(&raw); err != nil {
return nil, err
}
msg, err := irc.ParseMessage(raw)
if err != nil {
return nil, err
}
l = append(l, msg)
}
if err := rows.Err(); err != nil {
return nil, err
}
if options.TakeLast {
// We ordered by DESC to limit to the last lines.
// Reverse the list to order by ASC these last lines.
for i, j := 0, len(l)-1; i < j; i, j = i+1, j-1 {
l[i], l[j] = l[j], l[i]
}
}
return l, nil
}
var postgresNetworksTotalDesc = prometheus.NewDesc("soju_networks_total", "Number of networks", []string{"hostname"}, nil) var postgresNetworksTotalDesc = prometheus.NewDesc("soju_networks_total", "Number of networks", []string{"hostname"}, nil)
type postgresMetricsCollector struct { type postgresMetricsCollector struct {

View File

@ -1,4 +1,5 @@
//go:build !nosqlite //go:build !nosqlite
// +build !nosqlite
package database package database
@ -11,8 +12,10 @@ import (
"strings" "strings"
"time" "time"
"git.sr.ht/~emersion/soju/xirc"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
promcollectors "github.com/prometheus/client_golang/prometheus/collectors" promcollectors "github.com/prometheus/client_golang/prometheus/collectors"
"gopkg.in/irc.v4"
) )
const SqliteEnabled = true const SqliteEnabled = true
@ -146,6 +149,41 @@ CREATE TABLE WebPushSubscription (
FOREIGN KEY(network) REFERENCES Network(id), FOREIGN KEY(network) REFERENCES Network(id),
UNIQUE(network, endpoint) UNIQUE(network, endpoint)
); );
CREATE TABLE Message (
id INTEGER PRIMARY KEY,
target INTEGER NOT NULL,
raw TEXT NOT NULL,
time TEXT NOT NULL,
sender TEXT NOT NULL,
text TEXT,
FOREIGN KEY(target) REFERENCES MessageTarget(id)
);
CREATE INDEX MessageIndex ON Message(target, time);
CREATE TABLE MessageTarget (
id INTEGER PRIMARY KEY,
network INTEGER NOT NULL,
target TEXT NOT NULL,
FOREIGN KEY(network) REFERENCES Network(id),
UNIQUE(network, target)
);
CREATE VIRTUAL TABLE MessageFTS USING fts5 (
text,
content=Message,
content_rowid=id
);
CREATE TRIGGER MessageFTSInsert AFTER INSERT ON Message BEGIN
INSERT INTO MessageFTS(rowid, text) VALUES (new.id, new.text);
END;
CREATE TRIGGER MessageFTSDelete AFTER DELETE ON Message BEGIN
INSERT INTO MessageFTS(MessageFTS, rowid, text) VALUES ('delete', old.id, old.text);
END;
CREATE TRIGGER MessageFTSUpdate AFTER UPDATE ON Message BEGIN
INSERT INTO MessageFTS(MessageFTS, rowid, text) VALUES ('delete', old.id, old.text);
INSERT INTO MessageFTS(rowid, text) VALUES (new.id, new.text);
END;
` `
var sqliteMigrations = []string{ var sqliteMigrations = []string{
@ -293,6 +331,42 @@ var sqliteMigrations = []string{
`, `,
"ALTER TABLE User ADD COLUMN enabled INTEGER NOT NULL DEFAULT 1", "ALTER TABLE User ADD COLUMN enabled INTEGER NOT NULL DEFAULT 1",
"ALTER TABLE User ADD COLUMN downstream_interacted_at TEXT;", "ALTER TABLE User ADD COLUMN downstream_interacted_at TEXT;",
`
CREATE TABLE Message (
id INTEGER PRIMARY KEY,
target INTEGER NOT NULL,
raw TEXT NOT NULL,
time TEXT NOT NULL,
sender TEXT NOT NULL,
text TEXT,
FOREIGN KEY(target) REFERENCES MessageTarget(id)
);
CREATE INDEX MessageIndex ON Message(target, time);
CREATE TABLE MessageTarget (
id INTEGER PRIMARY KEY,
network INTEGER NOT NULL,
target TEXT NOT NULL,
FOREIGN KEY(network) REFERENCES Network(id),
UNIQUE(network, target)
);
CREATE VIRTUAL TABLE MessageFTS USING fts5 (
text,
content=Message,
content_rowid=id
);
CREATE TRIGGER MessageFTSInsert AFTER INSERT ON Message BEGIN
INSERT INTO MessageFTS(rowid, text) VALUES (new.id, new.text);
END;
CREATE TRIGGER MessageFTSDelete AFTER DELETE ON Message BEGIN
INSERT INTO MessageFTS(MessageFTS, rowid, text) VALUES ('delete', old.id, old.text);
END;
CREATE TRIGGER MessageFTSUpdate AFTER UPDATE ON Message BEGIN
INSERT INTO MessageFTS(MessageFTS, rowid, text) VALUES ('delete', old.id, old.text);
INSERT INTO MessageFTS(rowid, text) VALUES (new.id, new.text);
END;
`,
} }
type SqliteDB struct { type SqliteDB struct {
@ -697,6 +771,16 @@ func (db *SqliteDB) DeleteNetwork(ctx context.Context, id int64) error {
} }
defer tx.Rollback() defer tx.Rollback()
_, err = tx.ExecContext(ctx, "DELETE FROM Message WHERE target IN (SELECT id FROM MessageTarget WHERE network = ?)", id)
if err != nil {
return err
}
_, err = tx.ExecContext(ctx, "DELETE FROM MessageTarget WHERE network = ?", id)
if err != nil {
return err
}
_, err = tx.ExecContext(ctx, "DELETE FROM WebPushSubscription WHERE network = ?", id) _, err = tx.ExecContext(ctx, "DELETE FROM WebPushSubscription WHERE network = ?", id)
if err != nil { if err != nil {
return err return err
@ -1054,3 +1138,232 @@ func (db *SqliteDB) DeleteWebPushSubscription(ctx context.Context, id int64) err
_, err := db.db.ExecContext(ctx, "DELETE FROM WebPushSubscription WHERE id = ?", id) _, err := db.db.ExecContext(ctx, "DELETE FROM WebPushSubscription WHERE id = ?", id)
return err return err
} }
func (db *SqliteDB) GetMessageLastID(ctx context.Context, networkID int64, name string) (int64, error) {
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
var msgID int64
row := db.db.QueryRowContext(ctx, `
SELECT m.id FROM Message AS m, MessageTarget AS t
WHERE t.network = :network AND t.target = :target AND m.target = t.id
ORDER BY m.time DESC LIMIT 1`,
sql.Named("network", networkID),
sql.Named("target", name),
)
if err := row.Scan(&msgID); err != nil {
if err == sql.ErrNoRows {
return 0, nil
}
return 0, err
}
return msgID, nil
}
func (db *SqliteDB) StoreMessage(ctx context.Context, networkID int64, name string, msg *irc.Message) (int64, error) {
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
var t time.Time
if tag, ok := msg.Tags["time"]; ok {
var err error
t, err = time.Parse(xirc.ServerTimeLayout, tag)
if err != nil {
return 0, fmt.Errorf("failed to parse message time tag: %v", err)
}
} else {
t = time.Now()
}
var text sql.NullString
switch msg.Command {
case "PRIVMSG", "NOTICE":
if len(msg.Params) > 1 {
text.Valid = true
text.String = msg.Params[1]
}
}
res, err := db.db.ExecContext(ctx, `
INSERT INTO MessageTarget(network, target)
VALUES (:network, :target)
ON CONFLICT DO NOTHING`,
sql.Named("network", networkID),
sql.Named("target", name),
)
if err != nil {
return 0, err
}
res, err = db.db.ExecContext(ctx, `
INSERT INTO Message(target, raw, time, sender, text)
SELECT id, :raw, :time, :sender, :text
FROM MessageTarget as t
WHERE network = :network AND target = :target`,
sql.Named("network", networkID),
sql.Named("target", name),
sql.Named("raw", msg.String()),
sql.Named("time", sqliteTime{t}),
sql.Named("sender", msg.Name),
sql.Named("text", text),
)
if err != nil {
return 0, err
}
id, err := res.LastInsertId()
if err != nil {
return 0, err
}
return id, nil
}
func (db *SqliteDB) ListMessageLastPerTarget(ctx context.Context, networkID int64, options *MessageOptions) ([]MessageTarget, error) {
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
innerQuery := `
SELECT time
FROM Message
WHERE target = MessageTarget.id `
if !options.Events {
innerQuery += `AND text IS NOT NULL `
}
innerQuery += `
ORDER BY time DESC
LIMIT 1
`
query := `
SELECT target, (` + innerQuery + `) latest
FROM MessageTarget
WHERE network = :network `
if !options.AfterTime.IsZero() {
// compares time strings by lexicographical order
query += `AND latest > :after `
}
if !options.BeforeTime.IsZero() {
// compares time strings by lexicographical order
query += `AND latest < :before `
}
if options.TakeLast {
query += `ORDER BY latest DESC `
} else {
query += `ORDER BY latest ASC `
}
query += `LIMIT :limit`
rows, err := db.db.QueryContext(ctx, query,
sql.Named("network", networkID),
sql.Named("after", sqliteTime{options.AfterTime}),
sql.Named("before", sqliteTime{options.BeforeTime}),
sql.Named("limit", options.Limit),
)
if err != nil {
return nil, err
}
defer rows.Close()
var l []MessageTarget
for rows.Next() {
var mt MessageTarget
var ts sqliteTime
if err := rows.Scan(&mt.Name, &ts); err != nil {
return nil, err
}
mt.LatestMessage = ts.Time
l = append(l, mt)
}
if err := rows.Err(); err != nil {
return nil, err
}
if options.TakeLast {
// We ordered by DESC to limit to the last lines.
// Reverse the list to order by ASC these last lines.
for i, j := 0, len(l)-1; i < j; i, j = i+1, j-1 {
l[i], l[j] = l[j], l[i]
}
}
return l, nil
}
func (db *SqliteDB) ListMessages(ctx context.Context, networkID int64, name string, options *MessageOptions) ([]*irc.Message, error) {
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
query := `
SELECT m.raw
FROM Message AS m, MessageTarget AS t
WHERE m.target = t.id AND t.network = :network AND t.target = :target `
if options.AfterID > 0 {
query += `AND m.id > :afterID `
}
if !options.AfterTime.IsZero() {
// compares time strings by lexicographical order
query += `AND m.time > :after `
}
if !options.BeforeTime.IsZero() {
// compares time strings by lexicographical order
query += `AND m.time < :before `
}
if options.Sender != "" {
query += `AND m.sender = :sender `
}
if options.Text != "" {
query += `AND m.id IN (SELECT ROWID FROM MessageFTS WHERE MessageFTS MATCH :text) `
}
if !options.Events {
query += `AND m.text IS NOT NULL `
}
if options.TakeLast {
query += `ORDER BY m.time DESC `
} else {
query += `ORDER BY m.time ASC `
}
query += `LIMIT :limit`
rows, err := db.db.QueryContext(ctx, query,
sql.Named("network", networkID),
sql.Named("target", name),
sql.Named("afterID", options.AfterID),
sql.Named("after", sqliteTime{options.AfterTime}),
sql.Named("before", sqliteTime{options.BeforeTime}),
sql.Named("sender", options.Sender),
sql.Named("text", options.Text),
sql.Named("limit", options.Limit),
)
if err != nil {
return nil, err
}
defer rows.Close()
var l []*irc.Message
for rows.Next() {
var raw string
if err := rows.Scan(&raw); err != nil {
return nil, err
}
msg, err := irc.ParseMessage(raw)
if err != nil {
return nil, err
}
l = append(l, msg)
}
if err := rows.Err(); err != nil {
return nil, err
}
if options.TakeLast {
// We ordered by DESC to limit to the last lines.
// Reverse the list to order by ASC these last lines.
for i, j := 0, len(l)-1; i < j; i, j = i+1, j-1 {
l[i], l[j] = l[j], l[i]
}
}
return l, nil
}

View File

@ -3,6 +3,7 @@
package database package database
import ( import (
_ "git.sr.ht/~emersion/go-sqlite3-fts5"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
) )

View File

@ -137,6 +137,7 @@ The following directives are supported:
- _memory_ stores messages in memory. - _memory_ stores messages in memory.
- _fs_ stores messages on disk, in the same format as ZNC. _source_ is - _fs_ stores messages on disk, in the same format as ZNC. _source_ is
required and is the root directory path for the database. required and is the root directory path for the database.
- _db_ stores messages in the database.
(_log_ is a deprecated alias for this directive.) (_log_ is a deprecated alias for this directive.)

View File

@ -401,7 +401,8 @@ func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn {
// TODO: this is racy, we should only enable chathistory after // TODO: this is racy, we should only enable chathistory after
// authentication and then check that user.msgStore implements // authentication and then check that user.msgStore implements
// chatHistoryMessageStore // chatHistoryMessageStore
if srv.Config().LogPath != "" { switch srv.Config().LogDriver {
case "fs", "db":
dc.caps.Available["draft/chathistory"] = "" dc.caps.Available["draft/chathistory"] = ""
dc.caps.Available["soju.im/search"] = "" dc.caps.Available["soju.im/search"] = ""
} }

1
go.mod
View File

@ -4,6 +4,7 @@ go 1.15
require ( require (
git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99 git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99
git.sr.ht/~emersion/go-sqlite3-fts5 v0.0.0-20230217131031-f2c8767594fc
git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9 git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9
github.com/SherClockHolmes/webpush-go v1.2.0 github.com/SherClockHolmes/webpush-go v1.2.0
github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead

2
go.sum
View File

@ -33,6 +33,8 @@ cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99 h1:1s8n5uisqkR+BzPgaum6xxIjKmzGrTykJdh+Y3f5Xao= git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99 h1:1s8n5uisqkR+BzPgaum6xxIjKmzGrTykJdh+Y3f5Xao=
git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99/go.mod h1:t+Ww6SR24yYnXzEWiNlOY0AFo5E9B73X++10lrSpp4U= git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99/go.mod h1:t+Ww6SR24yYnXzEWiNlOY0AFo5E9B73X++10lrSpp4U=
git.sr.ht/~emersion/go-sqlite3-fts5 v0.0.0-20230217131031-f2c8767594fc h1:+y3OijpLl4rgbFsqMBmYUTCsGCkxQUWpWaqfS8j9Ygc=
git.sr.ht/~emersion/go-sqlite3-fts5 v0.0.0-20230217131031-f2c8767594fc/go.mod h1:PCl1xjl7iC6x35TKKubKRyo/3TT0dGI66jyNI6vmYnU=
git.sr.ht/~sircmpwn/getopt v0.0.0-20191230200459-23622cc906b3/go.mod h1:wMEGFFFNuPos7vHmWXfszqImLppbc0wEhh6JBfJIUgw= git.sr.ht/~sircmpwn/getopt v0.0.0-20191230200459-23622cc906b3/go.mod h1:wMEGFFFNuPos7vHmWXfszqImLppbc0wEhh6JBfJIUgw=
git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9 h1:Ahny8Ud1LjVMMAlt8utUFKhhxJtwBAualvsbc/Sk7cE= git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9 h1:Ahny8Ud1LjVMMAlt8utUFKhhxJtwBAualvsbc/Sk7cE=
git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9/go.mod h1:BVJwbDfVjCjoFiKrhkei6NdGcZYpkDkdyCdg1ukytRA= git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9/go.mod h1:BVJwbDfVjCjoFiKrhkei6NdGcZYpkDkdyCdg1ukytRA=

151
msgstore/db.go Normal file
View File

@ -0,0 +1,151 @@
package msgstore
import (
"context"
"time"
"git.sr.ht/~emersion/soju/database"
"git.sr.ht/~sircmpwn/go-bare"
"gopkg.in/irc.v4"
)
type dbMsgID struct {
ID bare.Uint
}
func (dbMsgID) msgIDType() msgIDType {
return msgIDDB
}
func parseDBMsgID(s string) (msgID int64, err error) {
var id dbMsgID
_, _, err = ParseMsgID(s, &id)
if err != nil {
return 0, err
}
return int64(id.ID), nil
}
func formatDBMsgID(netID int64, target string, msgID int64) string {
id := dbMsgID{bare.Uint(msgID)}
return formatMsgID(netID, target, &id)
}
// dbMessageStore is a persistent store for IRC messages, that
// stores messages in the soju database.
type dbMessageStore struct {
db database.Database
}
var (
_ Store = (*dbMessageStore)(nil)
_ ChatHistoryStore = (*dbMessageStore)(nil)
_ SearchStore = (*dbMessageStore)(nil)
)
func NewDBStore(db database.Database) *dbMessageStore {
return &dbMessageStore{
db: db,
}
}
func (ms *dbMessageStore) Close() error {
return nil
}
func (ms *dbMessageStore) LastMsgID(network *database.Network, entity string, t time.Time) (string, error) {
// TODO: what should we do with t?
id, err := ms.db.GetMessageLastID(context.TODO(), network.ID, entity)
if err != nil {
return "", err
}
return formatDBMsgID(network.ID, entity, id), nil
}
func (ms *dbMessageStore) LoadLatestID(ctx context.Context, id string, options *LoadMessageOptions) ([]*irc.Message, error) {
msgID, err := parseDBMsgID(id)
if err != nil {
return nil, err
}
l, err := ms.db.ListMessages(ctx, options.Network.ID, options.Entity, &database.MessageOptions{
AfterID: msgID,
Limit: options.Limit,
TakeLast: true,
})
if err != nil {
return nil, err
}
return l, nil
}
func (ms *dbMessageStore) Append(network *database.Network, entity string, msg *irc.Message) (string, error) {
id, err := ms.db.StoreMessage(context.TODO(), network.ID, entity, msg)
if err != nil {
return "", err
}
return formatDBMsgID(network.ID, entity, id), nil
}
func (ms *dbMessageStore) ListTargets(ctx context.Context, network *database.Network, start, end time.Time, limit int, events bool) ([]ChatHistoryTarget, error) {
l, err := ms.db.ListMessageLastPerTarget(ctx, network.ID, &database.MessageOptions{
AfterTime: start,
BeforeTime: end,
Limit: limit,
Events: events,
})
if err != nil {
return nil, err
}
targets := make([]ChatHistoryTarget, len(l))
for i, v := range l {
targets[i] = ChatHistoryTarget{
Name: v.Name,
LatestMessage: v.LatestMessage,
}
}
return targets, nil
}
func (ms *dbMessageStore) LoadBeforeTime(ctx context.Context, start, end time.Time, options *LoadMessageOptions) ([]*irc.Message, error) {
l, err := ms.db.ListMessages(ctx, options.Network.ID, options.Entity, &database.MessageOptions{
AfterTime: end,
BeforeTime: start,
Limit: options.Limit,
Events: options.Events,
TakeLast: true,
})
if err != nil {
return nil, err
}
return l, nil
}
func (ms *dbMessageStore) LoadAfterTime(ctx context.Context, start, end time.Time, options *LoadMessageOptions) ([]*irc.Message, error) {
l, err := ms.db.ListMessages(ctx, options.Network.ID, options.Entity, &database.MessageOptions{
AfterTime: start,
BeforeTime: end,
Limit: options.Limit,
Events: options.Events,
})
if err != nil {
return nil, err
}
return l, nil
}
func (ms *dbMessageStore) Search(ctx context.Context, network *database.Network, options *SearchMessageOptions) ([]*irc.Message, error) {
l, err := ms.db.ListMessages(ctx, network.ID, options.In, &database.MessageOptions{
AfterTime: options.Start,
BeforeTime: options.End,
Limit: options.Limit,
Sender: options.From,
Text: options.Text,
TakeLast: true,
})
if err != nil {
return nil, err
}
return l, nil
}

View File

@ -23,7 +23,7 @@ const (
fsMessageStoreMaxTries = 100 fsMessageStoreMaxTries = 100
) )
func escapeFilename(unsafe string) (safe string) { func EscapeFilename(unsafe string) (safe string) {
if unsafe == "." { if unsafe == "." {
return "-" return "-"
} else if unsafe == ".." { } else if unsafe == ".." {
@ -103,7 +103,7 @@ func IsFSStore(store Store) bool {
func NewFSStore(root string, user *database.User) *fsMessageStore { func NewFSStore(root string, user *database.User) *fsMessageStore {
return &fsMessageStore{ return &fsMessageStore{
root: filepath.Join(root, escapeFilename(user.Username)), root: filepath.Join(root, EscapeFilename(user.Username)),
user: user, user: user,
files: make(map[string]*fsMessageStoreFile), files: make(map[string]*fsMessageStoreFile),
} }
@ -112,7 +112,7 @@ func NewFSStore(root string, user *database.User) *fsMessageStore {
func (ms *fsMessageStore) logPath(network *database.Network, entity string, t time.Time) string { func (ms *fsMessageStore) logPath(network *database.Network, entity string, t time.Time) string {
year, month, day := t.Date() year, month, day := t.Date()
filename := fmt.Sprintf("%04d-%02d-%02d.log", year, month, day) filename := fmt.Sprintf("%04d-%02d-%02d.log", year, month, day)
return filepath.Join(ms.root, escapeFilename(network.GetName()), escapeFilename(entity), filename) return filepath.Join(ms.root, EscapeFilename(network.GetName()), EscapeFilename(entity), filename)
} }
// nextMsgID queries the message ID for the next message to be written to f. // nextMsgID queries the message ID for the next message to be written to f.
@ -265,6 +265,10 @@ func formatMessage(msg *irc.Message) string {
} }
func (ms *fsMessageStore) parseMessage(line string, network *database.Network, entity string, ref time.Time, events bool) (*irc.Message, time.Time, error) { func (ms *fsMessageStore) parseMessage(line string, network *database.Network, entity string, ref time.Time, events bool) (*irc.Message, time.Time, error) {
return FSParseMessage(line, ms.user, network, entity, ref, events)
}
func FSParseMessage(line string, user *database.User, network *database.Network, entity string, ref time.Time, events bool) (*irc.Message, time.Time, error) {
var hour, minute, second int var hour, minute, second int
_, err := fmt.Sscanf(line, "[%02d:%02d:%02d] ", &hour, &minute, &second) _, err := fmt.Sscanf(line, "[%02d:%02d:%02d] ", &hour, &minute, &second)
if err != nil { if err != nil {
@ -391,7 +395,7 @@ func (ms *fsMessageStore) parseMessage(line string, network *database.Network, e
// our nickname in the logs, so grab it from the network settings. // our nickname in the logs, so grab it from the network settings.
// Not very accurate since this may not match our nick at the time // Not very accurate since this may not match our nick at the time
// the message was received, but we can't do a lot better. // the message was received, but we can't do a lot better.
entity = database.GetNick(ms.user, network) entity = database.GetNick(user, network)
} }
params = []string{entity, text} params = []string{entity, text}
} }
@ -634,7 +638,7 @@ func (ms *fsMessageStore) LoadLatestID(ctx context.Context, id string, options *
func (ms *fsMessageStore) ListTargets(ctx context.Context, network *database.Network, start, end time.Time, limit int, events bool) ([]ChatHistoryTarget, error) { func (ms *fsMessageStore) ListTargets(ctx context.Context, network *database.Network, start, end time.Time, limit int, events bool) ([]ChatHistoryTarget, error) {
start = start.In(time.Local) start = start.In(time.Local)
end = end.In(time.Local) end = end.In(time.Local)
rootPath := filepath.Join(ms.root, escapeFilename(network.GetName())) rootPath := filepath.Join(ms.root, EscapeFilename(network.GetName()))
root, err := os.Open(rootPath) root, err := os.Open(rootPath)
if os.IsNotExist(err) { if os.IsNotExist(err) {
return nil, nil return nil, nil
@ -713,7 +717,7 @@ func (ms *fsMessageStore) ListTargets(ctx context.Context, network *database.Net
func (ms *fsMessageStore) Search(ctx context.Context, network *database.Network, opts *SearchMessageOptions) ([]*irc.Message, error) { func (ms *fsMessageStore) Search(ctx context.Context, network *database.Network, opts *SearchMessageOptions) ([]*irc.Message, error) {
text := strings.ToLower(opts.Text) text := strings.ToLower(opts.Text)
selector := func(m *irc.Message) bool { selector := func(m *irc.Message) bool {
if opts.From != "" && m.User != opts.From { if opts.From != "" && m.Name != opts.From {
return false return false
} }
if text != "" && !strings.Contains(strings.ToLower(m.Params[1]), text) { if text != "" && !strings.Contains(strings.ToLower(m.Params[1]), text) {
@ -734,8 +738,8 @@ func (ms *fsMessageStore) Search(ctx context.Context, network *database.Network,
} }
func (ms *fsMessageStore) RenameNetwork(oldNet, newNet *database.Network) error { func (ms *fsMessageStore) RenameNetwork(oldNet, newNet *database.Network) error {
oldDir := filepath.Join(ms.root, escapeFilename(oldNet.GetName())) oldDir := filepath.Join(ms.root, EscapeFilename(oldNet.GetName()))
newDir := filepath.Join(ms.root, escapeFilename(newNet.GetName())) newDir := filepath.Join(ms.root, EscapeFilename(newNet.GetName()))
// Avoid loosing data by overwriting an existing directory // Avoid loosing data by overwriting an existing directory
if _, err := os.Stat(newDir); err == nil { if _, err := os.Stat(newDir); err == nil {
return fmt.Errorf("destination %q already exists", newDir) return fmt.Errorf("destination %q already exists", newDir)

View File

@ -52,7 +52,7 @@ type ChatHistoryStore interface {
// end is before start. // end is before start.
// If events is false, only PRIVMSG/NOTICE messages are considered. // If events is false, only PRIVMSG/NOTICE messages are considered.
LoadBeforeTime(ctx context.Context, start, end time.Time, options *LoadMessageOptions) ([]*irc.Message, error) LoadBeforeTime(ctx context.Context, start, end time.Time, options *LoadMessageOptions) ([]*irc.Message, error)
// LoadBeforeTime loads up to limit messages after start up to end. The // LoadAfterTime loads up to limit messages after start up to end. The
// returned messages must be between and excluding the provided bounds. // returned messages must be between and excluding the provided bounds.
// end is after start. // end is after start.
// If events is false, only PRIVMSG/NOTICE messages are considered. // If events is false, only PRIVMSG/NOTICE messages are considered.
@ -90,6 +90,7 @@ const (
msgIDNone msgIDType = iota msgIDNone msgIDType = iota
msgIDMemory msgIDMemory
msgIDFS msgIDFS
msgIDDB
) )
const msgIDVersion uint = 0 const msgIDVersion uint = 0

View File

@ -138,6 +138,7 @@ func (ln *retryListener) Accept() (net.Conn, error) {
type Config struct { type Config struct {
Hostname string Hostname string
Title string Title string
LogDriver string
LogPath string LogPath string
HTTPOrigins []string HTTPOrigins []string
AcceptProxyIPs config.IPSet AcceptProxyIPs config.IPSet

View File

@ -516,9 +516,12 @@ func newUser(srv *Server, record *database.User) *user {
logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)} logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)}
var msgStore msgstore.Store var msgStore msgstore.Store
if logPath := srv.Config().LogPath; logPath != "" { switch srv.Config().LogDriver {
msgStore = msgstore.NewFSStore(logPath, record) case "fs":
} else { msgStore = msgstore.NewFSStore(srv.Config().LogPath, record)
case "db":
msgStore = msgstore.NewDBStore(srv.db)
case "memory":
msgStore = msgstore.NewMemoryStore() msgStore = msgstore.NewMemoryStore()
} }