From 350057e60bd7a59e31fbeb18ba5631a2cf790dc4 Mon Sep 17 00:00:00 2001 From: Hubert Hirtz Date: Thu, 14 Oct 2021 16:13:24 +0200 Subject: [PATCH] Set hard timeouts on DB transactions --- db_postgres.go | 82 ++++++++++++++++++++++++++++++++--------- db_sqlite.go | 99 +++++++++++++++++++++++++++++++++++++++----------- 2 files changed, 141 insertions(+), 40 deletions(-) diff --git a/db_postgres.go b/db_postgres.go index 548f02c..cad8b55 100644 --- a/db_postgres.go +++ b/db_postgres.go @@ -1,6 +1,7 @@ package soju import ( + "context" "database/sql" "errors" "fmt" @@ -11,6 +12,8 @@ import ( _ "github.com/lib/pq" ) +const postgresQueryTimeout = 5 * time.Second + const postgresConfigSchema = ` CREATE TABLE IF NOT EXISTS "Config" ( id SMALLINT PRIMARY KEY, @@ -145,8 +148,11 @@ func (db *PostgresDB) Close() error { } func (db *PostgresDB) Stats() (*DatabaseStats, error) { + ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) + defer cancel() + var stats DatabaseStats - row := db.db.QueryRow(`SELECT + row := db.db.QueryRowContext(ctx, `SELECT (SELECT COUNT(*) FROM "User") AS users, (SELECT COUNT(*) FROM "Network") AS networks, (SELECT COUNT(*) FROM "Channel") AS channels`) @@ -158,7 +164,11 @@ func (db *PostgresDB) Stats() (*DatabaseStats, error) { } func (db *PostgresDB) ListUsers() ([]User, error) { - rows, err := db.db.Query(`SELECT id, username, password, admin, realname FROM "User"`) + ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, + `SELECT id, username, password, admin, realname FROM "User"`) if err != nil { return nil, err } @@ -183,10 +193,13 @@ func (db *PostgresDB) ListUsers() ([]User, error) { } func (db *PostgresDB) GetUser(username string) (*User, error) { + ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) + defer cancel() + user := &User{Username: username} var password, realname sql.NullString - row := db.db.QueryRow( + row := db.db.QueryRowContext(ctx, `SELECT id, password, admin, realname FROM "User" WHERE username = $1`, username) if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil { @@ -198,18 +211,21 @@ func (db *PostgresDB) GetUser(username string) (*User, error) { } func (db *PostgresDB) StoreUser(user *User) error { + ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) + defer cancel() + password := toNullString(user.Password) realname := toNullString(user.Realname) var err error if user.ID == 0 { - err = db.db.QueryRow(` + err = db.db.QueryRowContext(ctx, ` INSERT INTO "User" (username, password, admin, realname) VALUES ($1, $2, $3, $4) RETURNING id`, user.Username, password, user.Admin, realname).Scan(&user.ID) } else { - _, err = db.db.Exec(` + _, err = db.db.ExecContext(ctx, ` UPDATE "User" SET password = $1, admin = $2, realname = $3 WHERE id = $4`, @@ -219,12 +235,18 @@ func (db *PostgresDB) StoreUser(user *User) error { } func (db *PostgresDB) DeleteUser(id int64) error { - _, err := db.db.Exec(`DELETE FROM "User" WHERE id = $1`, id) + ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) + defer cancel() + + _, err := db.db.ExecContext(ctx, `DELETE FROM "User" WHERE id = $1`, id) return err } func (db *PostgresDB) ListNetworks(userID int64) ([]Network, error) { - rows, err := db.db.Query(` + ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, ` SELECT id, name, addr, nick, username, realname, pass, connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, enabled FROM "Network" @@ -265,6 +287,9 @@ func (db *PostgresDB) ListNetworks(userID int64) ([]Network, error) { } func (db *PostgresDB) StoreNetwork(userID int64, network *Network) error { + ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) + defer cancel() + netName := toNullString(network.Name) netUsername := toNullString(network.Username) realname := toNullString(network.Realname) @@ -289,7 +314,7 @@ func (db *PostgresDB) StoreNetwork(userID int64, network *Network) error { var err error if network.ID == 0 { - err = db.db.QueryRow(` + err = db.db.QueryRowContext(ctx, ` INSERT INTO "Network" ("user", name, addr, nick, username, realname, pass, connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, enabled) @@ -299,7 +324,7 @@ func (db *PostgresDB) StoreNetwork(userID int64, network *Network) error { saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob, network.SASL.External.PrivKeyBlob, network.Enabled).Scan(&network.ID) } else { - _, err = db.db.Exec(` + _, err = db.db.ExecContext(ctx, ` UPDATE "Network" SET name = $2, addr = $3, nick = $4, username = $5, realname = $6, pass = $7, connect_commands = $8, sasl_mechanism = $9, sasl_plain_username = $10, @@ -314,12 +339,18 @@ func (db *PostgresDB) StoreNetwork(userID int64, network *Network) error { } func (db *PostgresDB) DeleteNetwork(id int64) error { - _, err := db.db.Exec(`DELETE FROM "Network" WHERE id = $1`, id) + ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) + defer cancel() + + _, err := db.db.ExecContext(ctx, `DELETE FROM "Network" WHERE id = $1`, id) return err } func (db *PostgresDB) ListChannels(networkID int64) ([]Channel, error) { - rows, err := db.db.Query(` + ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, ` SELECT id, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after, detach_on FROM "Channel" @@ -350,12 +381,15 @@ func (db *PostgresDB) ListChannels(networkID int64) ([]Channel, error) { } func (db *PostgresDB) StoreChannel(networkID int64, ch *Channel) error { + ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) + defer cancel() + key := toNullString(ch.Key) detachAfter := int64(math.Ceil(ch.DetachAfter.Seconds())) var err error if ch.ID == 0 { - err = db.db.QueryRow(` + err = db.db.QueryRowContext(ctx, ` INSERT INTO "Channel" (network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after, detach_on) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) @@ -363,7 +397,7 @@ func (db *PostgresDB) StoreChannel(networkID int64, ch *Channel) error { networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID), ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn).Scan(&ch.ID) } else { - _, err = db.db.Exec(` + _, err = db.db.ExecContext(ctx, ` UPDATE "Channel" SET name = $2, key = $3, detached = $4, detached_internal_msgid = $5, relay_detached = $6, reattach_on = $7, detach_after = $8, detach_on = $9 @@ -375,12 +409,18 @@ func (db *PostgresDB) StoreChannel(networkID int64, ch *Channel) error { } func (db *PostgresDB) DeleteChannel(id int64) error { - _, err := db.db.Exec(`DELETE FROM "Channel" WHERE id = $1`, id) + ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) + defer cancel() + + _, err := db.db.ExecContext(ctx, `DELETE FROM "Channel" WHERE id = $1`, id) return err } func (db *PostgresDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) { - rows, err := db.db.Query(` + ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, ` SELECT id, target, client, internal_msgid FROM "DeliveryReceipt" WHERE network = $1`, networkID) @@ -405,19 +445,23 @@ func (db *PostgresDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, } func (db *PostgresDB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error { + ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) + defer cancel() + tx, err := db.db.Begin() if err != nil { return err } defer tx.Rollback() - _, err = tx.Exec(`DELETE FROM "DeliveryReceipt" WHERE network = $1 AND client = $2`, + _, err = tx.ExecContext(ctx, + `DELETE FROM "DeliveryReceipt" WHERE network = $1 AND client = $2`, networkID, client) if err != nil { return err } - stmt, err := tx.Prepare(` + stmt, err := tx.PrepareContext(ctx, ` INSERT INTO "DeliveryReceipt" (network, target, client, internal_msgid) VALUES ($1, $2, $3, $4) RETURNING id`) @@ -428,7 +472,9 @@ func (db *PostgresDB) StoreClientDeliveryReceipts(networkID int64, client string for i := range receipts { rcpt := &receipts[i] - err := stmt.QueryRow(networkID, rcpt.Target, client, rcpt.InternalMsgID).Scan(&rcpt.ID) + err := stmt. + QueryRowContext(ctx, networkID, rcpt.Target, client, rcpt.InternalMsgID). + Scan(&rcpt.ID) if err != nil { return err } diff --git a/db_sqlite.go b/db_sqlite.go index a6dc8eb..f5d2f9a 100644 --- a/db_sqlite.go +++ b/db_sqlite.go @@ -1,6 +1,7 @@ package soju import ( + "context" "database/sql" "fmt" "math" @@ -11,6 +12,8 @@ import ( _ "github.com/mattn/go-sqlite3" ) +const sqliteQueryTimeout = 5 * time.Second + const sqliteSchema = ` CREATE TABLE User ( id INTEGER PRIMARY KEY, @@ -209,8 +212,11 @@ func (db *SqliteDB) Stats() (*DatabaseStats, error) { db.lock.RLock() defer db.lock.RUnlock() + ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) + defer cancel() + var stats DatabaseStats - row := db.db.QueryRow(`SELECT + row := db.db.QueryRowContext(ctx, `SELECT (SELECT COUNT(*) FROM User) AS users, (SELECT COUNT(*) FROM Network) AS networks, (SELECT COUNT(*) FROM Channel) AS channels`) @@ -232,7 +238,11 @@ func (db *SqliteDB) ListUsers() ([]User, error) { db.lock.RLock() defer db.lock.RUnlock() - rows, err := db.db.Query("SELECT id, username, password, admin, realname FROM User") + ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, + "SELECT id, username, password, admin, realname FROM User") if err != nil { return nil, err } @@ -260,10 +270,15 @@ func (db *SqliteDB) GetUser(username string) (*User, error) { db.lock.RLock() defer db.lock.RUnlock() + ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) + defer cancel() + user := &User{Username: username} var password, realname sql.NullString - row := db.db.QueryRow("SELECT id, password, admin, realname FROM User WHERE username = ?", username) + row := db.db.QueryRowContext(ctx, + "SELECT id, password, admin, realname FROM User WHERE username = ?", + username) if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil { return nil, err } @@ -276,6 +291,9 @@ func (db *SqliteDB) StoreUser(user *User) error { db.lock.Lock() defer db.lock.Unlock() + ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) + defer cancel() + args := []interface{}{ sql.Named("username", user.Username), sql.Named("password", toNullString(user.Password)), @@ -285,10 +303,17 @@ func (db *SqliteDB) StoreUser(user *User) error { var err error if user.ID != 0 { - _, err = db.db.Exec("UPDATE User SET password = :password, admin = :admin, realname = :realname WHERE username = :username", args...) + _, err = db.db.ExecContext(ctx, ` + UPDATE User SET password = :password, admin = :admin, + realname = :realname WHERE username = :username`, + args...) } else { var res sql.Result - res, err = db.db.Exec("INSERT INTO User(username, password, admin, realname) VALUES (:username, :password, :admin, :realname)", args...) + res, err = db.db.ExecContext(ctx, ` + INSERT INTO + User(username, password, admin, realname) + VALUES (:username, :password, :admin, :realname)`, + args...) if err != nil { return err } @@ -302,13 +327,16 @@ func (db *SqliteDB) DeleteUser(id int64) error { db.lock.Lock() defer db.lock.Unlock() + ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) + defer cancel() + tx, err := db.db.Begin() if err != nil { return err } defer tx.Rollback() - _, err = tx.Exec(`DELETE FROM DeliveryReceipt + _, err = tx.ExecContext(ctx, `DELETE FROM DeliveryReceipt WHERE id IN ( SELECT DeliveryReceipt.id FROM DeliveryReceipt @@ -319,7 +347,7 @@ func (db *SqliteDB) DeleteUser(id int64) error { return err } - _, err = tx.Exec(`DELETE FROM Channel + _, err = tx.ExecContext(ctx, `DELETE FROM Channel WHERE id IN ( SELECT Channel.id FROM Channel @@ -330,12 +358,12 @@ func (db *SqliteDB) DeleteUser(id int64) error { return err } - _, err = tx.Exec("DELETE FROM Network WHERE user = ?", id) + _, err = tx.ExecContext(ctx, "DELETE FROM Network WHERE user = ?", id) if err != nil { return err } - _, err = tx.Exec("DELETE FROM User WHERE id = ?", id) + _, err = tx.ExecContext(ctx, "DELETE FROM User WHERE id = ?", id) if err != nil { return err } @@ -347,7 +375,11 @@ func (db *SqliteDB) ListNetworks(userID int64) ([]Network, error) { db.lock.RLock() defer db.lock.RUnlock() - rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass, + ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, ` + SELECT id, name, addr, nick, username, realname, pass, connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, enabled FROM Network @@ -392,6 +424,9 @@ func (db *SqliteDB) StoreNetwork(userID int64, network *Network) error { db.lock.Lock() defer db.lock.Unlock() + ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) + defer cancel() + var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString if network.SASL.Mechanism != "" { saslMechanism = toNullString(network.SASL.Mechanism) @@ -429,7 +464,7 @@ func (db *SqliteDB) StoreNetwork(userID int64, network *Network) error { var err error if network.ID != 0 { - _, err = db.db.Exec(` + _, err = db.db.ExecContext(ctx, ` UPDATE Network SET name = :name, addr = :addr, nick = :nick, username = :username, realname = :realname, pass = :pass, connect_commands = :connect_commands, @@ -439,7 +474,7 @@ func (db *SqliteDB) StoreNetwork(userID int64, network *Network) error { WHERE id = :id`, args...) } else { var res sql.Result - res, err = db.db.Exec(` + res, err = db.db.ExecContext(ctx, ` INSERT INTO Network(user, name, addr, nick, username, realname, pass, connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, enabled) @@ -459,23 +494,26 @@ func (db *SqliteDB) DeleteNetwork(id int64) error { db.lock.Lock() defer db.lock.Unlock() + ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) + defer cancel() + tx, err := db.db.Begin() if err != nil { return err } defer tx.Rollback() - _, err = tx.Exec("DELETE FROM DeliveryReceipt WHERE network = ?", id) + _, err = tx.ExecContext(ctx, "DELETE FROM DeliveryReceipt WHERE network = ?", id) if err != nil { return err } - _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id) + _, err = tx.ExecContext(ctx, "DELETE FROM Channel WHERE network = ?", id) if err != nil { return err } - _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id) + _, err = tx.ExecContext(ctx, "DELETE FROM Network WHERE id = ?", id) if err != nil { return err } @@ -487,7 +525,10 @@ func (db *SqliteDB) ListChannels(networkID int64) ([]Channel, error) { db.lock.RLock() defer db.lock.RUnlock() - rows, err := db.db.Query(`SELECT + ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, `SELECT id, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after, detach_on FROM Channel @@ -521,6 +562,9 @@ func (db *SqliteDB) StoreChannel(networkID int64, ch *Channel) error { db.lock.Lock() defer db.lock.Unlock() + ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) + defer cancel() + args := []interface{}{ sql.Named("network", networkID), sql.Named("name", ch.Name), @@ -537,14 +581,14 @@ func (db *SqliteDB) StoreChannel(networkID int64, ch *Channel) error { var err error if ch.ID != 0 { - _, err = db.db.Exec(`UPDATE Channel + _, err = db.db.ExecContext(ctx, `UPDATE Channel SET network = :network, name = :name, key = :key, detached = :detached, detached_internal_msgid = :detached_internal_msgid, relay_detached = :relay_detached, reattach_on = :reattach_on, detach_after = :detach_after, detach_on = :detach_on WHERE id = :id`, args...) } else { var res sql.Result - res, err = db.db.Exec(`INSERT INTO Channel(network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after, detach_on) + res, err = db.db.ExecContext(ctx, `INSERT INTO Channel(network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after, detach_on) VALUES (:network, :name, :key, :detached, :detached_internal_msgid, :relay_detached, :reattach_on, :detach_after, :detach_on)`, args...) if err != nil { return err @@ -558,7 +602,10 @@ func (db *SqliteDB) DeleteChannel(id int64) error { db.lock.Lock() defer db.lock.Unlock() - _, err := db.db.Exec("DELETE FROM Channel WHERE id = ?", id) + ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) + defer cancel() + + _, err := db.db.ExecContext(ctx, "DELETE FROM Channel WHERE id = ?", id) return err } @@ -566,7 +613,11 @@ func (db *SqliteDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, er db.lock.RLock() defer db.lock.RUnlock() - rows, err := db.db.Query(`SELECT id, target, client, internal_msgid + ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, ` + SELECT id, target, client, internal_msgid FROM DeliveryReceipt WHERE network = ?`, networkID) if err != nil { @@ -595,13 +646,16 @@ func (db *SqliteDB) StoreClientDeliveryReceipts(networkID int64, client string, db.lock.Lock() defer db.lock.Unlock() + ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) + defer cancel() + tx, err := db.db.Begin() if err != nil { return err } defer tx.Rollback() - _, err = tx.Exec("DELETE FROM DeliveryReceipt WHERE network = ? AND client IS ?", + _, err = tx.ExecContext(ctx, "DELETE FROM DeliveryReceipt WHERE network = ? AND client IS ?", networkID, toNullString(client)) if err != nil { return err @@ -610,7 +664,8 @@ func (db *SqliteDB) StoreClientDeliveryReceipts(networkID int64, client string, for i := range receipts { rcpt := &receipts[i] - res, err := tx.Exec(`INSERT INTO DeliveryReceipt(network, target, client, internal_msgid) + res, err := tx.ExecContext(ctx, ` + INSERT INTO DeliveryReceipt(network, target, client, internal_msgid) VALUES (:network, :target, :client, :internal_msgid)`, sql.Named("network", networkID), sql.Named("target", rcpt.Target),