soju/database/postgres.go
delthas e510f7a461 Add support for explicit PostgreSQL schema prefixes for tests
PostgreSQL tests use pg_temp only. pg_temp is never searched for FTS
objects, so creating then altering an FTS configuration will not work
because PostgreSQL will not be able to find the FTS configuration it
just created.

Instead, we explicitly refer to the FTS objects with their full name
including their prefix, which makes PostgreSQL able to find the object.

This is only needed for tests.

See: https://stackoverflow.com/a/31095452/2347617
See: https://www.postgresql.org/message-id/15191.1208975632@sss.pgh.pa.us
2023-02-17 14:13:43 +01:00

1177 lines
33 KiB
Go

package database
import (
"context"
"database/sql"
"errors"
"fmt"
"math"
"strings"
"time"
"git.sr.ht/~emersion/soju/xirc"
_ "github.com/lib/pq"
"github.com/prometheus/client_golang/prometheus"
promcollectors "github.com/prometheus/client_golang/prometheus/collectors"
"gopkg.in/irc.v4"
)
// postgresQueryTimeout is the deadline that the PostgresDB query methods
// apply to their context (via context.WithTimeout) before talking to the
// server.
const postgresQueryTimeout = 5 * time.Second
// postgresConfigSchema creates the "Config" table holding the schema
// version. The CHECK constraint pins id to 1, so the table can hold at most
// one row. The statement uses IF NOT EXISTS because upgrade() executes it on
// every call.
const postgresConfigSchema = `
CREATE TABLE IF NOT EXISTS "Config" (
id SMALLINT PRIMARY KEY,
version INTEGER NOT NULL,
CHECK(id = 1)
);
`
// postgresSchema is the full schema, executed by upgrade() when no schema
// version has been recorded yet (version 0) instead of replaying
// postgresMigrations.
//
// Occurrences of @SCHEMA_PREFIX@ are substituted by PostgresDB.template
// before the script is executed.
const postgresSchema = `
CREATE TABLE "User" (
id SERIAL PRIMARY KEY,
username VARCHAR(255) NOT NULL UNIQUE,
password VARCHAR(255),
admin BOOLEAN NOT NULL DEFAULT FALSE,
nick VARCHAR(255),
realname VARCHAR(255),
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(),
enabled BOOLEAN NOT NULL DEFAULT TRUE,
downstream_interacted_at TIMESTAMP WITH TIME ZONE
);
CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL');
CREATE TABLE "Network" (
id SERIAL PRIMARY KEY,
name VARCHAR(255),
"user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE,
addr VARCHAR(255) NOT NULL,
nick VARCHAR(255),
username VARCHAR(255),
realname VARCHAR(255),
certfp TEXT,
pass VARCHAR(255),
connect_commands VARCHAR(1023),
sasl_mechanism sasl_mechanism,
sasl_plain_username VARCHAR(255),
sasl_plain_password VARCHAR(255),
sasl_external_cert BYTEA,
sasl_external_key BYTEA,
auto_away BOOLEAN NOT NULL DEFAULT TRUE,
enabled BOOLEAN NOT NULL DEFAULT TRUE,
UNIQUE("user", addr, nick),
UNIQUE("user", name)
);
CREATE TABLE "Channel" (
id SERIAL PRIMARY KEY,
network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
name VARCHAR(255) NOT NULL,
key VARCHAR(255),
detached BOOLEAN NOT NULL DEFAULT FALSE,
detached_internal_msgid VARCHAR(255),
relay_detached INTEGER NOT NULL DEFAULT 0,
reattach_on INTEGER NOT NULL DEFAULT 0,
detach_after INTEGER NOT NULL DEFAULT 0,
detach_on INTEGER NOT NULL DEFAULT 0,
UNIQUE(network, name)
);
CREATE TABLE "DeliveryReceipt" (
id SERIAL PRIMARY KEY,
network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
target VARCHAR(255) NOT NULL,
client VARCHAR(255) NOT NULL DEFAULT '',
internal_msgid VARCHAR(255) NOT NULL,
UNIQUE(network, target, client)
);
CREATE TABLE "ReadReceipt" (
id SERIAL PRIMARY KEY,
network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
target VARCHAR(255) NOT NULL,
timestamp TIMESTAMP WITH TIME ZONE NOT NULL,
UNIQUE(network, target)
);
CREATE TABLE "WebPushConfig" (
id SERIAL PRIMARY KEY,
created_at TIMESTAMP WITH TIME ZONE NOT NULL,
vapid_key_public TEXT NOT NULL,
vapid_key_private TEXT NOT NULL,
UNIQUE(vapid_key_public)
);
CREATE TABLE "WebPushSubscription" (
id SERIAL PRIMARY KEY,
created_at TIMESTAMP WITH TIME ZONE NOT NULL,
updated_at TIMESTAMP WITH TIME ZONE NOT NULL,
"user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE,
network INTEGER REFERENCES "Network"(id) ON DELETE CASCADE,
endpoint TEXT NOT NULL,
key_vapid TEXT,
key_auth TEXT,
key_p256dh TEXT,
UNIQUE(network, endpoint)
);
CREATE TABLE "MessageTarget" (
id SERIAL PRIMARY KEY,
network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
target TEXT NOT NULL,
UNIQUE(network, target)
);
CREATE TEXT SEARCH DICTIONARY search_simple_dictionary (
TEMPLATE = pg_catalog.simple
);
CREATE TEXT SEARCH CONFIGURATION @SCHEMA_PREFIX@search_simple ( COPY = pg_catalog.simple );
ALTER TEXT SEARCH CONFIGURATION @SCHEMA_PREFIX@search_simple ALTER MAPPING FOR asciiword, asciihword, hword_asciipart, hword, hword_part, word WITH @SCHEMA_PREFIX@search_simple_dictionary;
CREATE TABLE "Message" (
id SERIAL PRIMARY KEY,
target INTEGER NOT NULL REFERENCES "MessageTarget"(id) ON DELETE CASCADE,
raw TEXT NOT NULL,
time TIMESTAMP WITH TIME ZONE NOT NULL,
sender TEXT NOT NULL,
text TEXT,
text_search tsvector GENERATED ALWAYS AS (to_tsvector('@SCHEMA_PREFIX@search_simple', text)) STORED
);
CREATE INDEX "MessageIndex" ON "Message" (target, time);
CREATE INDEX "MessageSearchIndex" ON "Message" USING GIN (text_search);
`
// postgresMigrations lists the incremental schema migrations, indexed by the
// schema version they upgrade from. For a database that already has a
// recorded version, upgrade() executes the entries from that version up to
// the end of the slice, then stores the slice length as the new version.
//
// Each entry is passed through PostgresDB.template before execution, so
// @SCHEMA_PREFIX@ may be used.
var postgresMigrations = []string{
"", // migration #0 is reserved for schema initialization
`ALTER TABLE "Network" ALTER COLUMN nick DROP NOT NULL`,
`
CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL');
ALTER TABLE "Network"
ALTER COLUMN sasl_mechanism
TYPE sasl_mechanism
USING sasl_mechanism::sasl_mechanism;
`,
`
CREATE TABLE "ReadReceipt" (
id SERIAL PRIMARY KEY,
network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
target VARCHAR(255) NOT NULL,
timestamp TIMESTAMP WITH TIME ZONE NOT NULL,
UNIQUE(network, target)
);
`,
`
CREATE TABLE "WebPushConfig" (
id SERIAL PRIMARY KEY,
created_at TIMESTAMP WITH TIME ZONE NOT NULL,
vapid_key_public TEXT NOT NULL,
vapid_key_private TEXT NOT NULL,
UNIQUE(vapid_key_public)
);
CREATE TABLE "WebPushSubscription" (
id SERIAL PRIMARY KEY,
created_at TIMESTAMP WITH TIME ZONE NOT NULL,
updated_at TIMESTAMP WITH TIME ZONE NOT NULL,
network INTEGER REFERENCES "Network"(id) ON DELETE CASCADE,
endpoint TEXT NOT NULL,
key_vapid TEXT,
key_auth TEXT,
key_p256dh TEXT,
UNIQUE(network, endpoint)
);
`,
`
ALTER TABLE "WebPushSubscription"
ADD COLUMN "user" INTEGER
REFERENCES "User"(id) ON DELETE CASCADE
`,
`ALTER TABLE "User" ADD COLUMN nick VARCHAR(255)`,
// Before this migration, a bug swapped user and network, so empty the
// web push subscriptions table
`
DELETE FROM "WebPushSubscription";
ALTER TABLE "WebPushSubscription"
ALTER COLUMN "user"
SET NOT NULL;
`,
`ALTER TABLE "Network" ADD COLUMN auto_away BOOLEAN NOT NULL DEFAULT TRUE`,
`ALTER TABLE "Network" ADD COLUMN certfp TEXT`,
`ALTER TABLE "User" ADD COLUMN created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now()`,
`ALTER TABLE "User" ADD COLUMN enabled BOOLEAN NOT NULL DEFAULT TRUE`,
`ALTER TABLE "User" ADD COLUMN downstream_interacted_at TIMESTAMP WITH TIME ZONE`,
`
CREATE TABLE "MessageTarget" (
id SERIAL PRIMARY KEY,
network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
target TEXT NOT NULL,
UNIQUE(network, target)
);
CREATE TEXT SEARCH DICTIONARY search_simple_dictionary (
TEMPLATE = pg_catalog.simple
);
CREATE TEXT SEARCH CONFIGURATION @SCHEMA_PREFIX@search_simple ( COPY = pg_catalog.simple );
ALTER TEXT SEARCH CONFIGURATION @SCHEMA_PREFIX@search_simple ALTER MAPPING FOR asciiword, asciihword, hword_asciipart, hword, hword_part, word WITH @SCHEMA_PREFIX@search_simple_dictionary;
CREATE TABLE "Message" (
id SERIAL PRIMARY KEY,
target INTEGER NOT NULL REFERENCES "MessageTarget"(id) ON DELETE CASCADE,
raw TEXT NOT NULL,
time TIMESTAMP WITH TIME ZONE NOT NULL,
sender TEXT NOT NULL,
text TEXT,
text_search tsvector GENERATED ALWAYS AS (to_tsvector('@SCHEMA_PREFIX@search_simple', text)) STORED
);
CREATE INDEX "MessageIndex" ON "Message" (target, time);
CREATE INDEX "MessageSearchIndex" ON "Message" USING GIN (text_search);
`,
}
// PostgresDB is a PostgreSQL-backed store accessed through database/sql. It
// is what OpenPostgresDB and OpenTempPostgresDB return as a Database.
type PostgresDB struct {
// db is the underlying database/sql handle.
db *sql.DB
// temp is set by OpenTempPostgresDB. When true, template() expands
// @SCHEMA_PREFIX@ to "pg_temp." instead of the empty string.
temp bool
}
// OpenPostgresDB opens the PostgreSQL database described by source and
// brings its schema up to date before returning it.
//
// If the schema upgrade fails, the handle is closed and the upgrade error is
// returned.
func OpenPostgresDB(source string) (Database, error) {
	pool, err := sql.Open("postgres", source)
	if err != nil {
		return nil, err
	}

	// database/sql places no cap on open connections by default, which can
	// cause errors because PostgreSQL defaults to 100 max connections.
	pool.SetMaxOpenConns(25)

	pg := &PostgresDB{db: pool}
	if upgradeErr := pg.upgrade(); upgradeErr != nil {
		pool.Close()
		return nil, upgradeErr
	}
	return pg, nil
}
// openTempPostgresDB opens a handle to the PostgreSQL server described by
// source whose objects all live in the pg_temp schema.
//
// The pool is limited to a single connection: search_path is a per-session
// setting and the temporary schema is dropped when the connection to
// PostgreSQL is closed, so every query has to go through that one session.
//
// On failure the handle is closed before returning, and the underlying error
// is wrapped so callers can inspect it with errors.Is / errors.As.
func openTempPostgresDB(source string) (*sql.DB, error) {
	db, err := sql.Open("postgres", source)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to PostgreSQL: %w", err)
	}

	// Store all tables in a temporary schema which will be dropped when the
	// connection to PostgreSQL is closed.
	db.SetMaxOpenConns(1)
	if _, err := db.Exec("SET search_path TO pg_temp"); err != nil {
		// Do not leak the pool (and its connection) on the error path.
		db.Close()
		return nil, fmt.Errorf("failed to set PostgreSQL search_path: %w", err)
	}
	return db, nil
}
// OpenTempPostgresDB opens a database through openTempPostgresDB, marks it
// as temporary and initializes its schema.
//
// If the schema upgrade fails, the handle is closed and the upgrade error is
// returned.
func OpenTempPostgresDB(source string) (Database, error) {
	conn, err := openTempPostgresDB(source)
	if err != nil {
		return nil, err
	}

	pg := &PostgresDB{db: conn, temp: true}
	if upgradeErr := pg.upgrade(); upgradeErr != nil {
		conn.Close()
		return nil, upgradeErr
	}
	return pg, nil
}
// template expands every @SCHEMA_PREFIX@ placeholder in t.
//
// For temporary databases the placeholder becomes "pg_temp.": this is a hack
// to convince PostgreSQL to look up text search configurations in pg_temp.
// Otherwise the placeholder is simply removed.
func (db *PostgresDB) template(t string) string {
	prefix := ""
	if db.temp {
		prefix = "pg_temp."
	}
	return strings.ReplaceAll(t, "@SCHEMA_PREFIX@", prefix)
}
// upgrade brings the database schema to the latest version inside a single
// transaction.
//
// The recorded version is read from "Config" (a missing row counts as
// version 0). Version 0 installs postgresSchema in one go; any other
// outdated version replays the pending entries of postgresMigrations. The
// new version is then stored and the transaction committed. A database newer
// than this binary is rejected.
func (db *PostgresDB) upgrade() error {
	tx, err := db.db.Begin()
	if err != nil {
		return err
	}
	// Rolls back on every early return; harmless once Commit has run.
	defer tx.Rollback()

	if _, err := tx.Exec(postgresConfigSchema); err != nil {
		return fmt.Errorf("failed to create Config table: %s", err)
	}

	var version int
	row := tx.QueryRow(`SELECT version FROM "Config"`)
	if err := row.Scan(&version); err != nil && !errors.Is(err, sql.ErrNoRows) {
		return fmt.Errorf("failed to query schema version: %s", err)
	}

	latest := len(postgresMigrations)
	switch {
	case version == latest:
		// Already up to date.
		return nil
	case version > latest:
		return fmt.Errorf("soju (version %d) older than schema (version %d)", latest, version)
	case version == 0:
		if _, err := tx.Exec(db.template(postgresSchema)); err != nil {
			return fmt.Errorf("failed to initialize schema: %s", err)
		}
	default:
		for i := version; i < latest; i++ {
			migration := db.template(postgresMigrations[i])
			if _, err := tx.Exec(migration); err != nil {
				return fmt.Errorf("failed to execute migration #%v: %v", i, err)
			}
		}
	}

	_, err = tx.Exec(`INSERT INTO "Config" (id, version) VALUES (1, $1)
ON CONFLICT (id) DO UPDATE SET version = $1`, latest)
	if err != nil {
		return fmt.Errorf("failed to bump schema version: %v", err)
	}

	return tx.Commit()
}
// Close closes the underlying database/sql handle and reports its error.
func (db *PostgresDB) Close() error {
	err := db.db.Close()
	return err
}
// RegisterMetrics registers two collectors with r: the per-hostname network
// counter implemented by postgresMetricsCollector, then a database/sql stats
// collector named "main". The first registration error, if any, is returned
// and the second registration is skipped.
func (db *PostgresDB) RegisterMetrics(r prometheus.Registerer) error {
	err := r.Register(&postgresMetricsCollector{db})
	if err == nil {
		err = r.Register(promcollectors.NewDBStatsCollector(db.db, "main"))
	}
	return err
}
// Stats returns the number of rows in the "User", "Network" and "Channel"
// tables, fetched with a single query.
func (db *PostgresDB) Stats(ctx context.Context) (*DatabaseStats, error) {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	const query = `SELECT
(SELECT COUNT(*) FROM "User") AS users,
(SELECT COUNT(*) FROM "Network") AS networks,
(SELECT COUNT(*) FROM "Channel") AS channels`

	stats := new(DatabaseStats)
	err := db.db.QueryRowContext(ctx, query).Scan(&stats.Users, &stats.Networks, &stats.Channels)
	if err != nil {
		return nil, err
	}
	return stats, nil
}
// ListUsers returns every row of the "User" table. Nullable columns
// (password, nick, realname, downstream_interacted_at) are mapped to the
// zero value of the corresponding field when NULL.
func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	const query = `SELECT id, username, password, admin, nick, realname, enabled,
downstream_interacted_at
FROM "User"`

	rows, err := db.db.QueryContext(ctx, query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var users []User
	for rows.Next() {
		var (
			u                        User
			password, nick, realname sql.NullString
			interactedAt             sql.NullTime
		)
		err := rows.Scan(&u.ID, &u.Username, &password, &u.Admin, &nick, &realname, &u.Enabled, &interactedAt)
		if err != nil {
			return nil, err
		}

		u.Password = password.String
		u.Nick = nick.String
		u.Realname = realname.String
		u.DownstreamInteractedAt = interactedAt.Time
		users = append(users, u)
	}
	// Surface any error that interrupted the iteration.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return users, nil
}
// GetUser loads the user whose username matches exactly. The error from
// scanning the row is returned unchanged, so a missing user is reported
// through whatever error Scan yields for an empty result.
func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, error) {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	const query = `SELECT id, password, admin, nick, realname, enabled, downstream_interacted_at
FROM "User"
WHERE username = $1`

	var (
		u                        = &User{Username: username}
		password, nick, realname sql.NullString
		interactedAt             sql.NullTime
	)
	err := db.db.QueryRowContext(ctx, query, username).
		Scan(&u.ID, &password, &u.Admin, &nick, &realname, &u.Enabled, &interactedAt)
	if err != nil {
		return nil, err
	}

	// NULL columns become the zero value of their field.
	u.Password = password.String
	u.Nick = nick.String
	u.Realname = realname.String
	u.DownstreamInteractedAt = interactedAt.Time
	return u, nil
}
// ListInactiveUsernames returns the usernames whose last downstream
// interaction happened strictly before limit. For users with no recorded
// interaction, the creation time is compared instead.
func (db *PostgresDB) ListInactiveUsernames(ctx context.Context, limit time.Time) ([]string, error) {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	const query = `SELECT username FROM "User" WHERE COALESCE(downstream_interacted_at, created_at) < $1`

	rows, err := db.db.QueryContext(ctx, query, limit)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var names []string
	for rows.Next() {
		var name string
		if scanErr := rows.Scan(&name); scanErr != nil {
			return nil, scanErr
		}
		names = append(names, name)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return names, nil
}
// StoreUser persists user. A user with a zero ID is inserted and its ID
// field is filled in from the generated key; otherwise the row with that ID
// is updated. The username is only written on insert.
func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	const insertQuery = `
INSERT INTO "User" (username, password, admin, nick, realname,
enabled, downstream_interacted_at)
VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING id`
	const updateQuery = `
UPDATE "User"
SET password = $1, admin = $2, nick = $3, realname = $4,
enabled = $5, downstream_interacted_at = $6
WHERE id = $7`

	password := toNullString(user.Password)
	nick := toNullString(user.Nick)
	realname := toNullString(user.Realname)
	interactedAt := toNullTime(user.DownstreamInteractedAt)

	if user.ID != 0 {
		_, err := db.db.ExecContext(ctx, updateQuery,
			password, user.Admin, nick, realname, user.Enabled,
			interactedAt, user.ID)
		return err
	}

	row := db.db.QueryRowContext(ctx, insertQuery,
		user.Username, password, user.Admin, nick, realname, user.Enabled,
		interactedAt)
	return row.Scan(&user.ID)
}
// DeleteUser removes the "User" row with the given id.
func (db *PostgresDB) DeleteUser(ctx context.Context, id int64) error {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	const query = `DELETE FROM "User" WHERE id = $1`
	if _, err := db.db.ExecContext(ctx, query, id); err != nil {
		return err
	}
	return nil
}
// ListNetworks returns the networks owned by userID. NULL text columns map
// to empty strings; a non-NULL connect_commands value is split on "\r\n".
func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	const query = `
SELECT id, name, addr, nick, username, realname, certfp, pass, connect_commands, sasl_mechanism,
sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, auto_away, enabled
FROM "Network"
WHERE "user" = $1`

	rows, err := db.db.QueryContext(ctx, query, userID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var result []Network
	for rows.Next() {
		var (
			n                                              Network
			name, nick, username, realname, certfp, pass   sql.NullString
			commands, mechanism, plainUser, plainPass      sql.NullString
		)
		scanErr := rows.Scan(&n.ID, &name, &n.Addr, &nick, &username, &realname, &certfp,
			&pass, &commands, &mechanism, &plainUser, &plainPass,
			&n.SASL.External.CertBlob, &n.SASL.External.PrivKeyBlob, &n.AutoAway, &n.Enabled)
		if scanErr != nil {
			return nil, scanErr
		}

		n.Name = name.String
		n.Nick = nick.String
		n.Username = username.String
		n.Realname = realname.String
		n.CertFP = certfp.String
		n.Pass = pass.String
		// Leave ConnectCommands nil when the column is NULL.
		if commands.Valid {
			n.ConnectCommands = strings.Split(commands.String, "\r\n")
		}
		n.SASL.Mechanism = mechanism.String
		n.SASL.Plain.Username = plainUser.String
		n.SASL.Plain.Password = plainPass.String
		result = append(result, n)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return result, nil
}
// StoreNetwork persists network for the user userID. A network with a zero
// ID is inserted and its ID field is filled in from the generated key;
// otherwise the row with that ID is updated (the owning user is not
// changed).
//
// Only SASL credentials matching the selected mechanism are stored: with
// "PLAIN" the external certificate and key of network are cleared (the
// struct is modified in place), with "EXTERNAL" the plain username and
// password are stored as NULL. Any other non-empty mechanism is rejected
// with an error.
func (db *PostgresDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	netName := toNullString(network.Name)
	nick := toNullString(network.Nick)
	netUsername := toNullString(network.Username)
	realname := toNullString(network.Realname)
	certfp := toNullString(network.CertFP)
	pass := toNullString(network.Pass)
	connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n"))

	var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
	if network.SASL.Mechanism != "" {
		saslMechanism = toNullString(network.SASL.Mechanism)
		switch network.SASL.Mechanism {
		case "PLAIN":
			saslPlainUsername = toNullString(network.SASL.Plain.Username)
			saslPlainPassword = toNullString(network.SASL.Plain.Password)
			network.SASL.External.CertBlob = nil
			network.SASL.External.PrivKeyBlob = nil
		case "EXTERNAL":
			// keep saslPlain* nil
		default:
			return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
		}
	}

	var err error
	if network.ID == 0 {
		err = db.db.QueryRowContext(ctx, `
INSERT INTO "Network" ("user", name, addr, nick, username, realname, certfp, pass, connect_commands,
sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert,
sasl_external_key, auto_away, enabled)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
RETURNING id`,
			userID, netName, network.Addr, nick, netUsername, realname, certfp, pass, connectCommands,
			saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
			network.SASL.External.PrivKeyBlob, network.AutoAway, network.Enabled).Scan(&network.ID)
	} else {
		// Sixteen arguments are bound, so the placeholders must run from $1
		// to $16 with one column per placeholder: sasl_external_key is $14,
		// auto_away is $15 and enabled is $16.
		_, err = db.db.ExecContext(ctx, `
UPDATE "Network"
SET name = $2, addr = $3, nick = $4, username = $5, realname = $6, certfp = $7, pass = $8,
connect_commands = $9, sasl_mechanism = $10, sasl_plain_username = $11,
sasl_plain_password = $12, sasl_external_cert = $13, sasl_external_key = $14,
auto_away = $15, enabled = $16
WHERE id = $1`,
			network.ID, netName, network.Addr, nick, netUsername, realname, certfp, pass, connectCommands,
			saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
			network.SASL.External.PrivKeyBlob, network.AutoAway, network.Enabled)
	}
	return err
}
// DeleteNetwork removes the "Network" row with the given id.
func (db *PostgresDB) DeleteNetwork(ctx context.Context, id int64) error {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	const query = `DELETE FROM "Network" WHERE id = $1`
	if _, err := db.db.ExecContext(ctx, query, id); err != nil {
		return err
	}
	return nil
}
// ListChannels returns the channels saved for networkID. The detach_after
// column is an integer number of seconds and is converted to a
// time.Duration.
func (db *PostgresDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	const query = `
SELECT id, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after,
detach_on
FROM "Channel"
WHERE network = $1`

	rows, err := db.db.QueryContext(ctx, query, networkID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var result []Channel
	for rows.Next() {
		var (
			c             Channel
			key, msgID    sql.NullString
			detachAfterS  int64
		)
		scanErr := rows.Scan(&c.ID, &c.Name, &key, &c.Detached, &msgID, &c.RelayDetached, &c.ReattachOn, &detachAfterS, &c.DetachOn)
		if scanErr != nil {
			return nil, scanErr
		}

		c.Key = key.String
		c.DetachedInternalMsgID = msgID.String
		c.DetachAfter = time.Duration(detachAfterS) * time.Second
		result = append(result, c)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return result, nil
}
// StoreChannel persists ch for networkID. A channel with a zero ID is
// inserted and its ID field is filled in from the generated key; otherwise
// the row with that ID is updated (its network is not changed).
//
// DetachAfter is stored as a whole number of seconds, rounded up.
func (db *PostgresDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	const insertQuery = `
INSERT INTO "Channel" (network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on,
detach_after, detach_on)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
RETURNING id`
	const updateQuery = `
UPDATE "Channel"
SET name = $2, key = $3, detached = $4, detached_internal_msgid = $5,
relay_detached = $6, reattach_on = $7, detach_after = $8, detach_on = $9
WHERE id = $1`

	key := toNullString(ch.Key)
	detachAfterSecs := int64(math.Ceil(ch.DetachAfter.Seconds()))

	if ch.ID != 0 {
		_, err := db.db.ExecContext(ctx, updateQuery,
			ch.ID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID),
			ch.RelayDetached, ch.ReattachOn, detachAfterSecs, ch.DetachOn)
		return err
	}

	row := db.db.QueryRowContext(ctx, insertQuery,
		networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID),
		ch.RelayDetached, ch.ReattachOn, detachAfterSecs, ch.DetachOn)
	return row.Scan(&ch.ID)
}
// DeleteChannel removes the "Channel" row with the given id.
func (db *PostgresDB) DeleteChannel(ctx context.Context, id int64) error {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	const query = `DELETE FROM "Channel" WHERE id = $1`
	if _, err := db.db.ExecContext(ctx, query, id); err != nil {
		return err
	}
	return nil
}
// ListDeliveryReceipts returns every delivery receipt stored for networkID,
// across all clients and targets.
func (db *PostgresDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	const query = `
SELECT id, target, client, internal_msgid
FROM "DeliveryReceipt"
WHERE network = $1`

	rows, err := db.db.QueryContext(ctx, query, networkID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var result []DeliveryReceipt
	for rows.Next() {
		var r DeliveryReceipt
		scanErr := rows.Scan(&r.ID, &r.Target, &r.Client, &r.InternalMsgID)
		if scanErr != nil {
			return nil, scanErr
		}
		result = append(result, r)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return result, nil
}
// StoreClientDeliveryReceipts atomically replaces all delivery receipts of
// the given client on networkID with receipts.
//
// The existing rows for (networkID, client) are deleted and one row per
// element of receipts is inserted within a single transaction. The ID field
// of each element is set to the generated key. The Client field of the
// elements is ignored: the client argument is stored instead.
//
// The transaction is started with the timeout-bound context, so waiting for
// a free connection is subject to postgresQueryTimeout as well, and the
// transaction is rolled back if the context expires.
func (db *PostgresDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	tx, err := db.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// Rolls back on every early return; harmless once Commit has run.
	defer tx.Rollback()

	_, err = tx.ExecContext(ctx,
		`DELETE FROM "DeliveryReceipt" WHERE network = $1 AND client = $2`,
		networkID, client)
	if err != nil {
		return err
	}

	stmt, err := tx.PrepareContext(ctx, `
INSERT INTO "DeliveryReceipt" (network, target, client, internal_msgid)
VALUES ($1, $2, $3, $4)
RETURNING id`)
	if err != nil {
		return err
	}
	defer stmt.Close()

	for i := range receipts {
		// Take a pointer so the generated ID lands in the caller's slice.
		rcpt := &receipts[i]
		err := stmt.
			QueryRowContext(ctx, networkID, rcpt.Target, client, rcpt.InternalMsgID).
			Scan(&rcpt.ID)
		if err != nil {
			return err
		}
	}

	return tx.Commit()
}
// GetReadReceipt returns the read receipt stored for the target name on
// networkID. It returns (nil, nil) when no receipt exists.
func (db *PostgresDB) GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error) {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	receipt := &ReadReceipt{
		Target: name,
	}

	row := db.db.QueryRowContext(ctx,
		`SELECT id, timestamp FROM "ReadReceipt" WHERE network = $1 AND target = $2`,
		networkID, name)
	if err := row.Scan(&receipt.ID, &receipt.Timestamp); err != nil {
		// errors.Is also matches a wrapped sql.ErrNoRows, and is the form
		// upgrade() already uses for this check.
		if errors.Is(err, sql.ErrNoRows) {
			return nil, nil
		}
		return nil, err
	}
	return receipt, nil
}
// StoreReadReceipt persists receipt for networkID. A receipt with a zero ID
// is inserted and its ID field is filled in from the generated key;
// otherwise only the timestamp of the row with that ID is updated.
func (db *PostgresDB) StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	const insertQuery = `
INSERT INTO "ReadReceipt" (network, target, timestamp)
VALUES ($1, $2, $3)
RETURNING id`
	const updateQuery = `
UPDATE "ReadReceipt"
SET timestamp = $1
WHERE id = $2`

	if receipt.ID == 0 {
		row := db.db.QueryRowContext(ctx, insertQuery,
			networkID, receipt.Target, receipt.Timestamp)
		return row.Scan(&receipt.ID)
	}

	_, err := db.db.ExecContext(ctx, updateQuery,
		receipt.Timestamp, receipt.ID)
	return err
}
// listTopNetworkAddrs returns, for each distinct network address, the number
// of "Network" rows using it. The query sorts by count, but that order is
// not preserved since the result is a map.
func (db *PostgresDB) listTopNetworkAddrs(ctx context.Context) (map[string]int, error) {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	const query = `
SELECT addr, COUNT(addr) AS n
FROM "Network"
GROUP BY addr
ORDER BY n DESC`

	counts := make(map[string]int)

	rows, err := db.db.QueryContext(ctx, query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	for rows.Next() {
		var addr string
		var count int
		if scanErr := rows.Scan(&addr, &count); scanErr != nil {
			return nil, scanErr
		}
		counts[addr] = count
	}

	return counts, rows.Err()
}
// ListWebPushConfigs returns every stored Web Push configuration with its
// VAPID key pair.
func (db *PostgresDB) ListWebPushConfigs(ctx context.Context) ([]WebPushConfig, error) {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	const query = `
SELECT id, vapid_key_public, vapid_key_private
FROM "WebPushConfig"`

	rows, err := db.db.QueryContext(ctx, query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var result []WebPushConfig
	for rows.Next() {
		var c WebPushConfig
		scanErr := rows.Scan(&c.ID, &c.VAPIDKeys.Public, &c.VAPIDKeys.Private)
		if scanErr != nil {
			return nil, scanErr
		}
		result = append(result, c)
	}

	return result, rows.Err()
}
// StoreWebPushConfig inserts config, stamping created_at with the server's
// NOW(), and fills in its ID from the generated key. Existing configs cannot
// be updated: a non-zero ID is rejected with an error.
func (db *PostgresDB) StoreWebPushConfig(ctx context.Context, config *WebPushConfig) error {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	if config.ID != 0 {
		return fmt.Errorf("cannot update a WebPushConfig")
	}

	const query = `
INSERT INTO "WebPushConfig" (created_at, vapid_key_public, vapid_key_private)
VALUES (NOW(), $1, $2)
RETURNING id`

	row := db.db.QueryRowContext(ctx, query,
		config.VAPIDKeys.Public, config.VAPIDKeys.Private)
	return row.Scan(&config.ID)
}
// ListWebPushSubscriptions returns the Web Push subscriptions of userID for
// networkID. A networkID of zero is sent as NULL and therefore selects the
// subscriptions whose network column is NULL.
func (db *PostgresDB) ListWebPushSubscriptions(ctx context.Context, userID, networkID int64) ([]WebPushSubscription, error) {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	const query = `
SELECT id, endpoint, key_auth, key_p256dh, key_vapid
FROM "WebPushSubscription"
WHERE "user" = $1 AND network IS NOT DISTINCT FROM $2`

	var network sql.NullInt64
	network.Int64 = networkID
	network.Valid = networkID != 0

	rows, err := db.db.QueryContext(ctx, query, userID, network)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var result []WebPushSubscription
	for rows.Next() {
		var s WebPushSubscription
		scanErr := rows.Scan(&s.ID, &s.Endpoint, &s.Keys.Auth, &s.Keys.P256DH, &s.Keys.VAPID)
		if scanErr != nil {
			return nil, scanErr
		}
		result = append(result, s)
	}

	return result, rows.Err()
}
// StoreWebPushSubscription persists sub for userID and networkID. A
// subscription with a zero ID is inserted (a zero networkID is stored as
// NULL) and its ID field is filled in from the generated key. Otherwise only
// the keys and updated_at of the row with that ID are refreshed; user,
// network and endpoint are left untouched.
func (db *PostgresDB) StoreWebPushSubscription(ctx context.Context, userID, networkID int64, sub *WebPushSubscription) error {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	const insertQuery = `
INSERT INTO "WebPushSubscription" (created_at, updated_at, "user",
network, endpoint, key_auth, key_p256dh, key_vapid)
VALUES (NOW(), NOW(), $1, $2, $3, $4, $5, $6)
RETURNING id`
	const updateQuery = `
UPDATE "WebPushSubscription"
SET updated_at = NOW(), key_auth = $1, key_p256dh = $2,
key_vapid = $3
WHERE id = $4`

	var network sql.NullInt64
	network.Int64 = networkID
	network.Valid = networkID != 0

	if sub.ID == 0 {
		row := db.db.QueryRowContext(ctx, insertQuery,
			userID, network, sub.Endpoint, sub.Keys.Auth, sub.Keys.P256DH,
			sub.Keys.VAPID)
		return row.Scan(&sub.ID)
	}

	_, err := db.db.ExecContext(ctx, updateQuery,
		sub.Keys.Auth, sub.Keys.P256DH, sub.Keys.VAPID, sub.ID)
	return err
}
// DeleteWebPushSubscription removes the "WebPushSubscription" row with the
// given id.
func (db *PostgresDB) DeleteWebPushSubscription(ctx context.Context, id int64) error {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	const query = `DELETE FROM "WebPushSubscription" WHERE id = $1`
	if _, err := db.db.ExecContext(ctx, query, id); err != nil {
		return err
	}
	return nil
}
// GetMessageLastID returns the ID of the stored message with the most recent
// time for the target name on networkID. It returns 0 and a nil error when
// no message is stored for that target.
func (db *PostgresDB) GetMessageLastID(ctx context.Context, networkID int64, name string) (int64, error) {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	var msgID int64
	row := db.db.QueryRowContext(ctx, `
SELECT m.id FROM "Message" AS m, "MessageTarget" as t
WHERE t.network = $1 AND t.target = $2 AND m.target = t.id
ORDER BY m.time DESC LIMIT 1`,
		networkID,
		name,
	)
	if err := row.Scan(&msgID); err != nil {
		// errors.Is also matches a wrapped sql.ErrNoRows, and is the form
		// upgrade() already uses for this check.
		if errors.Is(err, sql.ErrNoRows) {
			return 0, nil
		}
		return 0, err
	}
	return msgID, nil
}
// StoreMessage records msg for the target name on networkID and returns the
// ID of the new "Message" row.
//
// The message time is taken from its "time" tag when present (an unparsable
// tag is an error) and from the current time otherwise. The text column is
// only filled for PRIVMSG and NOTICE messages that have at least two
// parameters; it stays NULL for everything else.
//
// The "MessageTarget" row is created on demand by a first statement, then
// the message is inserted by a second one; the two are not wrapped in a
// transaction.
func (db *PostgresDB) StoreMessage(ctx context.Context, networkID int64, name string, msg *irc.Message) (int64, error) {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	const ensureTargetQuery = `
INSERT INTO "MessageTarget" (network, target)
VALUES ($1, $2)
ON CONFLICT DO NOTHING`
	const insertMessageQuery = `
INSERT INTO "Message" (target, raw, time, sender, text)
SELECT id, $1, $2, $3, $4
FROM "MessageTarget" as t
WHERE network = $5 AND target = $6
RETURNING id`

	var sentAt time.Time
	tag, tagged := msg.Tags["time"]
	if !tagged {
		sentAt = time.Now()
	} else {
		parsed, err := time.Parse(xirc.ServerTimeLayout, tag)
		if err != nil {
			return 0, fmt.Errorf("failed to parse message time tag: %v", err)
		}
		sentAt = parsed
	}

	var text sql.NullString
	if (msg.Command == "PRIVMSG" || msg.Command == "NOTICE") && len(msg.Params) > 1 {
		text = sql.NullString{String: msg.Params[1], Valid: true}
	}

	if _, err := db.db.ExecContext(ctx, ensureTargetQuery, networkID, name); err != nil {
		return 0, err
	}

	var id int64
	row := db.db.QueryRowContext(ctx, insertMessageQuery,
		msg.String(),
		sentAt,
		msg.Name,
		text,
		networkID,
		name,
	)
	if err := row.Scan(&id); err != nil {
		return 0, err
	}
	return id, nil
}
// ListMessageLastPerTarget returns, for each target of networkID that has
// stored messages, the time of its latest message.
//
// options narrows the result: unless Events is set, only messages with a
// non-NULL text are considered; AfterTime and BeforeTime (when non-zero)
// bound the latest-message time exclusively; Limit caps the number of
// targets. With TakeLast the most recent targets are selected. The returned
// slice is always sorted by ascending latest-message time.
func (db *PostgresDB) ListMessageLastPerTarget(ctx context.Context, networkID int64, options *MessageOptions) ([]MessageTarget, error) {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	args := []interface{}{networkID}
	// bind appends v to the argument list and returns its placeholder number.
	bind := func(v interface{}) int {
		args = append(args, v)
		return len(args)
	}

	sqlText := `
SELECT t.target, MAX(m.time) AS latest
FROM "Message" m, "MessageTarget" t
WHERE m.target = t.id AND t.network = $1
`
	if !options.Events {
		sqlText += `AND m.text IS NOT NULL `
	}
	// "HAVING true" lets the optional conditions below all start with AND.
	sqlText += `
GROUP BY t.target
HAVING true
`
	if !options.AfterTime.IsZero() {
		sqlText += fmt.Sprintf(`AND MAX(m.time) > $%d `, bind(options.AfterTime))
	}
	if !options.BeforeTime.IsZero() {
		sqlText += fmt.Sprintf(`AND MAX(m.time) < $%d `, bind(options.BeforeTime))
	}

	ordering := `ORDER BY latest ASC `
	if options.TakeLast {
		ordering = `ORDER BY latest DESC `
	}
	sqlText += ordering
	sqlText += fmt.Sprintf(`LIMIT $%d`, bind(options.Limit))

	rows, err := db.db.QueryContext(ctx, sqlText, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var targets []MessageTarget
	for rows.Next() {
		var target MessageTarget
		if scanErr := rows.Scan(&target.Name, &target.LatestMessage); scanErr != nil {
			return nil, scanErr
		}
		targets = append(targets, target)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}

	if options.TakeLast {
		// The rows came back newest-first so that LIMIT keeps the last
		// ones; flip them back to ascending order.
		for i := 0; i < len(targets)/2; i++ {
			j := len(targets) - 1 - i
			targets[i], targets[j] = targets[j], targets[i]
		}
	}

	return targets, nil
}
// ListMessages returns the stored messages for the target name on
// networkID, parsed back from their raw form.
//
// options narrows the result: AfterID (when positive), AfterTime and
// BeforeTime (when non-zero) are exclusive bounds; Sender, when non-empty,
// must match exactly; Text, when non-empty, is a full-text search query
// evaluated with the search_simple configuration; unless Events is set, only
// messages with a non-NULL text are returned; Limit caps the number of
// messages. With TakeLast the most recent messages are selected. The
// returned slice is always sorted by ascending time.
//
// The query is passed through db.template so that the text search
// configuration is referenced with its schema prefix: in temporary databases
// it lives in pg_temp, which PostgreSQL never searches for text search
// objects, so an unqualified name would not be found there.
func (db *PostgresDB) ListMessages(ctx context.Context, networkID int64, name string, options *MessageOptions) ([]*irc.Message, error) {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	parameters := []interface{}{
		networkID,
		name,
	}
	query := `
SELECT m.raw
FROM "Message" AS m, "MessageTarget" AS t
WHERE m.target = t.id AND t.network = $1 AND t.target = $2 `
	if options.AfterID > 0 {
		parameters = append(parameters, options.AfterID)
		query += fmt.Sprintf(`AND m.id > $%d `, len(parameters))
	}
	if !options.AfterTime.IsZero() {
		parameters = append(parameters, options.AfterTime)
		query += fmt.Sprintf(`AND m.time > $%d `, len(parameters))
	}
	if !options.BeforeTime.IsZero() {
		parameters = append(parameters, options.BeforeTime)
		query += fmt.Sprintf(`AND m.time < $%d `, len(parameters))
	}
	if options.Sender != "" {
		parameters = append(parameters, options.Sender)
		query += fmt.Sprintf(`AND m.sender = $%d `, len(parameters))
	}
	if options.Text != "" {
		parameters = append(parameters, options.Text)
		query += fmt.Sprintf(`AND text_search @@ plainto_tsquery('@SCHEMA_PREFIX@search_simple', $%d) `, len(parameters))
	}
	if !options.Events {
		query += `AND m.text IS NOT NULL `
	}
	if options.TakeLast {
		query += `ORDER BY m.time DESC `
	} else {
		query += `ORDER BY m.time ASC `
	}
	parameters = append(parameters, options.Limit)
	query += fmt.Sprintf(`LIMIT $%d`, len(parameters))

	// Only the static query text goes through template(); user-supplied
	// values are bound as parameters and never substituted.
	rows, err := db.db.QueryContext(ctx, db.template(query), parameters...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var l []*irc.Message
	for rows.Next() {
		var raw string
		if err := rows.Scan(&raw); err != nil {
			return nil, err
		}

		msg, err := irc.ParseMessage(raw)
		if err != nil {
			return nil, err
		}

		l = append(l, msg)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	if options.TakeLast {
		// We ordered by DESC to limit to the last lines.
		// Reverse the list to order by ASC these last lines.
		for i, j := 0, len(l)-1; i < j; i, j = i+1, j-1 {
			l[i], l[j] = l[j], l[i]
		}
	}

	return l, nil
}
// postgresNetworksTotalDesc describes the "soju_networks_total" metric,
// which carries a single "hostname" label. It is the only metric emitted by
// postgresMetricsCollector.
var postgresNetworksTotalDesc = prometheus.NewDesc("soju_networks_total", "Number of networks", []string{"hostname"}, nil)
// postgresMetricsCollector is the Prometheus collector that RegisterMetrics
// registers to export the per-hostname network counts of a PostgresDB.
type postgresMetricsCollector struct {
// db is the database queried on each collection.
db *PostgresDB
}
// Compile-time check that *postgresMetricsCollector implements
// prometheus.Collector.
var _ prometheus.Collector = (*postgresMetricsCollector)(nil)
// Describe sends the descriptor of the only metric this collector emits.
func (c *postgresMetricsCollector) Describe(ch chan<- *prometheus.Desc) {
	desc := postgresNetworksTotalDesc
	ch <- desc
}
// Collect emits one soju_networks_total gauge per hostname.
//
// Network addresses are counted in the database, then merged by hostname;
// an address whose URL cannot be derived is used verbatim as the hostname.
// For privacy, hostnames with fewer than 10 networks, as well as the empty
// and "*" hostnames, are folded into a single "*" series, which is only
// emitted when its total is positive. A database error is reported as an
// invalid metric.
func (c *postgresMetricsCollector) Collect(ch chan<- prometheus.Metric) {
	addrCounts, err := c.db.listTopNetworkAddrs(context.TODO())
	if err != nil {
		ch <- prometheus.NewInvalidMetric(postgresNetworksTotalDesc, err)
		return
	}

	// Merge the per-address counts by hostname.
	perHost := make(map[string]int)
	for addr, count := range addrCounts {
		host := addr
		if u, urlErr := (&Network{Addr: addr}).URL(); urlErr == nil {
			host = u.Hostname()
		}
		perHost[host] += count
	}

	// Hostnames below the watermark are lumped together for privacy.
	const watermark = 10
	var lumped int
	for host, count := range perHost {
		if count < watermark || host == "" || host == "*" {
			lumped += count
			continue
		}
		ch <- prometheus.MustNewConstMetric(postgresNetworksTotalDesc, prometheus.GaugeValue, float64(count), host)
	}

	if lumped > 0 {
		ch <- prometheus.MustNewConstMetric(postgresNetworksTotalDesc, prometheus.GaugeValue, float64(lumped), "*")
	}
}