From f0db261fc00c4c499aacbf978897c62c7696a515 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Thu, 16 Jun 2022 19:33:39 +0200 Subject: [PATCH] database: add missing user column to WebPushSubscription table Some WebPushSubscription entries aren't tried to a network, in which case the "network" column is NULL. But then all users share the same row. Oops. Fortunately network-less subscriptions aren't used for anything yet, they're just stored. So the impact should be minimal. --- database/database.go | 4 ++-- database/postgres.go | 20 +++++++++++++------- database/sqlite.go | 18 ++++++++++++------ downstream.go | 4 ++-- user.go | 2 +- 5 files changed, 30 insertions(+), 18 deletions(-) diff --git a/database/database.go b/database/database.go index 89bd0ac..977d1b7 100644 --- a/database/database.go +++ b/database/database.go @@ -36,8 +36,8 @@ type Database interface { ListWebPushConfigs(ctx context.Context) ([]WebPushConfig, error) StoreWebPushConfig(ctx context.Context, config *WebPushConfig) error - ListWebPushSubscriptions(ctx context.Context, networkID int64) ([]WebPushSubscription, error) - StoreWebPushSubscription(ctx context.Context, networkID int64, sub *WebPushSubscription) error + ListWebPushSubscriptions(ctx context.Context, userID, networkID int64) ([]WebPushSubscription, error) + StoreWebPushSubscription(ctx context.Context, userID, networkID int64, sub *WebPushSubscription) error DeleteWebPushSubscription(ctx context.Context, id int64) error } diff --git a/database/postgres.go b/database/postgres.go index 51649f4..79a09ad 100644 --- a/database/postgres.go +++ b/database/postgres.go @@ -98,6 +98,7 @@ CREATE TABLE "WebPushSubscription" ( id SERIAL PRIMARY KEY, created_at TIMESTAMP WITH TIME ZONE NOT NULL, updated_at TIMESTAMP WITH TIME ZONE NOT NULL, + "user" INTEGER REFERENCES "User"(id) ON DELETE CASCADE, network INTEGER REFERENCES "Network"(id) ON DELETE CASCADE, endpoint TEXT NOT NULL, key_vapid TEXT, @@ -147,6 +148,11 @@ var postgresMigrations = []string{ UNIQUE(network, endpoint) ); `, + ` + ALTER TABLE "WebPushSubscription" + ADD COLUMN "user" INTEGER + REFERENCES "User"(id) ON DELETE CASCADE + `, } type PostgresDB struct { @@ -704,7 +710,7 @@ func (db *PostgresDB) StoreWebPushConfig(ctx context.Context, config *WebPushCon return err } -func (db *PostgresDB) ListWebPushSubscriptions(ctx context.Context, networkID int64) ([]WebPushSubscription, error) { +func (db *PostgresDB) ListWebPushSubscriptions(ctx context.Context, userID, networkID int64) ([]WebPushSubscription, error) { ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) defer cancel() @@ -716,7 +722,7 @@ func (db *PostgresDB) ListWebPushSubscriptions(ctx context.Context, networkID in rows, err := db.db.QueryContext(ctx, ` SELECT id, endpoint, key_auth, key_p256dh, key_vapid FROM "WebPushSubscription" - WHERE network IS NOT DISTINCT FROM $1`, nullNetworkID) + WHERE "user" = $1 AND network IS NOT DISTINCT FROM $2`, userID, nullNetworkID) if err != nil { return nil, err } @@ -734,7 +740,7 @@ func (db *PostgresDB) ListWebPushSubscriptions(ctx context.Context, networkID in return subs, rows.Err() } -func (db *PostgresDB) StoreWebPushSubscription(ctx context.Context, networkID int64, sub *WebPushSubscription) error { +func (db *PostgresDB) StoreWebPushSubscription(ctx context.Context, userID, networkID int64, sub *WebPushSubscription) error { ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) defer cancel() @@ -753,11 +759,11 @@ func (db *PostgresDB) StoreWebPushSubscription(ctx context.Context, networkID in sub.Keys.Auth, sub.Keys.P256DH, sub.Keys.VAPID, sub.ID) } else { err = db.db.QueryRowContext(ctx, ` - INSERT INTO "WebPushSubscription" (created_at, updated_at, network, - endpoint, key_auth, key_p256dh, key_vapid) - VALUES (NOW(), NOW(), $1, $2, $3, $4, $5) + INSERT INTO "WebPushSubscription" (created_at, updated_at, "user", + network, endpoint, key_auth, key_p256dh, key_vapid) + VALUES (NOW(), NOW(), $1, $2, $3, $4, $5, $6) RETURNING id`, - nullNetworkID, sub.Endpoint, sub.Keys.Auth, sub.Keys.P256DH, + nullNetworkID, userID, sub.Endpoint, sub.Keys.Auth, sub.Keys.P256DH, sub.Keys.VAPID).Scan(&sub.ID) } diff --git a/database/sqlite.go b/database/sqlite.go index e8aff59..21ae8f1 100644 --- a/database/sqlite.go +++ b/database/sqlite.go @@ -97,11 +97,13 @@ CREATE TABLE WebPushSubscription ( id INTEGER PRIMARY KEY, created_at TEXT NOT NULL, updated_at TEXT NOT NULL, + user INTEGER NOT NULL, network INTEGER, endpoint TEXT NOT NULL, key_vapid TEXT, key_auth TEXT, key_p256dh TEXT, + FOREIGN KEY(user) REFERENCES User(id), FOREIGN KEY(network) REFERENCES Network(id), UNIQUE(network, endpoint) ); @@ -237,6 +239,9 @@ var sqliteMigrations = []string{ UNIQUE(network, endpoint) ); `, + ` + ALTER TABLE WebPushSubscription ADD COLUMN user INTEGER REFERENCES User(id); + `, } type SqliteDB struct { @@ -878,7 +883,7 @@ func (db *SqliteDB) StoreWebPushConfig(ctx context.Context, config *WebPushConfi return err } -func (db *SqliteDB) ListWebPushSubscriptions(ctx context.Context, networkID int64) ([]WebPushSubscription, error) { +func (db *SqliteDB) ListWebPushSubscriptions(ctx context.Context, userID, networkID int64) ([]WebPushSubscription, error) { ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) defer cancel() @@ -890,7 +895,7 @@ func (db *SqliteDB) ListWebPushSubscriptions(ctx context.Context, networkID int6 rows, err := db.db.QueryContext(ctx, ` SELECT id, endpoint, key_auth, key_p256dh, key_vapid FROM WebPushSubscription - WHERE network IS ?`, nullNetworkID) + WHERE user = ? AND network IS ?`, userID, nullNetworkID) if err != nil { return nil, err } @@ -908,12 +913,13 @@ func (db *SqliteDB) ListWebPushSubscriptions(ctx context.Context, networkID int6 return subs, rows.Err() } -func (db *SqliteDB) StoreWebPushSubscription(ctx context.Context, networkID int64, sub *WebPushSubscription) error { +func (db *SqliteDB) StoreWebPushSubscription(ctx context.Context, userID, networkID int64, sub *WebPushSubscription) error { ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) defer cancel() args := []interface{}{ sql.Named("id", sub.ID), + sql.Named("user", userID), sql.Named("network", sql.NullInt64{ Int64: networkID, Valid: networkID != 0, @@ -937,10 +943,10 @@ func (db *SqliteDB) StoreWebPushSubscription(ctx context.Context, networkID int6 var res sql.Result res, err = db.db.ExecContext(ctx, ` INSERT INTO - WebPushSubscription(created_at, updated_at, network, endpoint, + WebPushSubscription(created_at, updated_at, user, network, endpoint, key_auth, key_p256dh, key_vapid) - VALUES (:now, :now, :network, :endpoint, :key_auth, :key_p256dh, - :key_vapid)`, + VALUES (:now, :now, :user, :network, :endpoint, :key_auth, + :key_p256dh, :key_vapid)`, args...) if err != nil { return err diff --git a/downstream.go b/downstream.go index 57f9732..417b513 100644 --- a/downstream.go +++ b/downstream.go @@ -3278,7 +3278,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. // TODO: limit max number of subscriptions, prune old ones - if err := dc.user.srv.db.StoreWebPushSubscription(ctx, networkID, &newSub); err != nil { + if err := dc.user.srv.db.StoreWebPushSubscription(ctx, dc.user.ID, networkID, &newSub); err != nil { dc.logger.Printf("failed to store Web push subscription: %v", err) return ircError{&irc.Message{ Command: "FAIL", @@ -3382,7 +3382,7 @@ func (dc *downstreamConn) findWebPushSubscription(ctx context.Context, endpoint networkID = dc.network.ID } - subs, err := dc.user.srv.db.ListWebPushSubscriptions(ctx, networkID) + subs, err := dc.user.srv.db.ListWebPushSubscriptions(ctx, dc.user.ID, networkID) if err != nil { return nil, err } diff --git a/user.go b/user.go index 05bdea8..b50628d 100644 --- a/user.go +++ b/user.go @@ -445,7 +445,7 @@ func (net *network) autoSaveSASLPlain(ctx context.Context, username, password st } func (net *network) broadcastWebPush(ctx context.Context, msg *irc.Message) { - subs, err := net.user.srv.db.ListWebPushSubscriptions(ctx, net.ID) + subs, err := net.user.srv.db.ListWebPushSubscriptions(ctx, net.user.ID, net.ID) if err != nil { net.logger.Printf("failed to list Web push subscriptions: %v", err) return