database: add missing user column to WebPushSubscription table

Some WebPushSubscription entries aren't tried to a network, in
which case the "network" column is NULL. But then all users share
the same row. Oops.

Fortunately network-less subscriptions aren't used for anything
yet, they're just stored. So the impact should be minimal.
This commit is contained in:
Simon Ser 2022-06-16 19:33:39 +02:00
parent de0992d41e
commit f0db261fc0
5 changed files with 30 additions and 18 deletions

View File

@ -36,8 +36,8 @@ type Database interface {
ListWebPushConfigs(ctx context.Context) ([]WebPushConfig, error) ListWebPushConfigs(ctx context.Context) ([]WebPushConfig, error)
StoreWebPushConfig(ctx context.Context, config *WebPushConfig) error StoreWebPushConfig(ctx context.Context, config *WebPushConfig) error
ListWebPushSubscriptions(ctx context.Context, networkID int64) ([]WebPushSubscription, error) ListWebPushSubscriptions(ctx context.Context, userID, networkID int64) ([]WebPushSubscription, error)
StoreWebPushSubscription(ctx context.Context, networkID int64, sub *WebPushSubscription) error StoreWebPushSubscription(ctx context.Context, userID, networkID int64, sub *WebPushSubscription) error
DeleteWebPushSubscription(ctx context.Context, id int64) error DeleteWebPushSubscription(ctx context.Context, id int64) error
} }

View File

@ -98,6 +98,7 @@ CREATE TABLE "WebPushSubscription" (
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
created_at TIMESTAMP WITH TIME ZONE NOT NULL, created_at TIMESTAMP WITH TIME ZONE NOT NULL,
updated_at TIMESTAMP WITH TIME ZONE NOT NULL, updated_at TIMESTAMP WITH TIME ZONE NOT NULL,
"user" INTEGER REFERENCES "User"(id) ON DELETE CASCADE,
network INTEGER REFERENCES "Network"(id) ON DELETE CASCADE, network INTEGER REFERENCES "Network"(id) ON DELETE CASCADE,
endpoint TEXT NOT NULL, endpoint TEXT NOT NULL,
key_vapid TEXT, key_vapid TEXT,
@ -147,6 +148,11 @@ var postgresMigrations = []string{
UNIQUE(network, endpoint) UNIQUE(network, endpoint)
); );
`, `,
`
ALTER TABLE "WebPushSubscription"
ADD COLUMN "user" INTEGER
REFERENCES "User"(id) ON DELETE CASCADE
`,
} }
type PostgresDB struct { type PostgresDB struct {
@ -704,7 +710,7 @@ func (db *PostgresDB) StoreWebPushConfig(ctx context.Context, config *WebPushCon
return err return err
} }
func (db *PostgresDB) ListWebPushSubscriptions(ctx context.Context, networkID int64) ([]WebPushSubscription, error) { func (db *PostgresDB) ListWebPushSubscriptions(ctx context.Context, userID, networkID int64) ([]WebPushSubscription, error) {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel() defer cancel()
@ -716,7 +722,7 @@ func (db *PostgresDB) ListWebPushSubscriptions(ctx context.Context, networkID in
rows, err := db.db.QueryContext(ctx, ` rows, err := db.db.QueryContext(ctx, `
SELECT id, endpoint, key_auth, key_p256dh, key_vapid SELECT id, endpoint, key_auth, key_p256dh, key_vapid
FROM "WebPushSubscription" FROM "WebPushSubscription"
WHERE network IS NOT DISTINCT FROM $1`, nullNetworkID) WHERE "user" = $1 AND network IS NOT DISTINCT FROM $2`, userID, nullNetworkID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -734,7 +740,7 @@ func (db *PostgresDB) ListWebPushSubscriptions(ctx context.Context, networkID in
return subs, rows.Err() return subs, rows.Err()
} }
func (db *PostgresDB) StoreWebPushSubscription(ctx context.Context, networkID int64, sub *WebPushSubscription) error { func (db *PostgresDB) StoreWebPushSubscription(ctx context.Context, userID, networkID int64, sub *WebPushSubscription) error {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel() defer cancel()
@ -753,11 +759,11 @@ func (db *PostgresDB) StoreWebPushSubscription(ctx context.Context, networkID in
sub.Keys.Auth, sub.Keys.P256DH, sub.Keys.VAPID, sub.ID) sub.Keys.Auth, sub.Keys.P256DH, sub.Keys.VAPID, sub.ID)
} else { } else {
err = db.db.QueryRowContext(ctx, ` err = db.db.QueryRowContext(ctx, `
INSERT INTO "WebPushSubscription" (created_at, updated_at, network, INSERT INTO "WebPushSubscription" (created_at, updated_at, "user",
endpoint, key_auth, key_p256dh, key_vapid) network, endpoint, key_auth, key_p256dh, key_vapid)
VALUES (NOW(), NOW(), $1, $2, $3, $4, $5) VALUES (NOW(), NOW(), $1, $2, $3, $4, $5, $6)
RETURNING id`, RETURNING id`,
nullNetworkID, sub.Endpoint, sub.Keys.Auth, sub.Keys.P256DH, nullNetworkID, userID, sub.Endpoint, sub.Keys.Auth, sub.Keys.P256DH,
sub.Keys.VAPID).Scan(&sub.ID) sub.Keys.VAPID).Scan(&sub.ID)
} }

View File

@ -97,11 +97,13 @@ CREATE TABLE WebPushSubscription (
id INTEGER PRIMARY KEY, id INTEGER PRIMARY KEY,
created_at TEXT NOT NULL, created_at TEXT NOT NULL,
updated_at TEXT NOT NULL, updated_at TEXT NOT NULL,
user INTEGER NOT NULL,
network INTEGER, network INTEGER,
endpoint TEXT NOT NULL, endpoint TEXT NOT NULL,
key_vapid TEXT, key_vapid TEXT,
key_auth TEXT, key_auth TEXT,
key_p256dh TEXT, key_p256dh TEXT,
FOREIGN KEY(user) REFERENCES User(id),
FOREIGN KEY(network) REFERENCES Network(id), FOREIGN KEY(network) REFERENCES Network(id),
UNIQUE(network, endpoint) UNIQUE(network, endpoint)
); );
@ -237,6 +239,9 @@ var sqliteMigrations = []string{
UNIQUE(network, endpoint) UNIQUE(network, endpoint)
); );
`, `,
`
ALTER TABLE WebPushSubscription ADD COLUMN user INTEGER REFERENCES User(id);
`,
} }
type SqliteDB struct { type SqliteDB struct {
@ -878,7 +883,7 @@ func (db *SqliteDB) StoreWebPushConfig(ctx context.Context, config *WebPushConfi
return err return err
} }
func (db *SqliteDB) ListWebPushSubscriptions(ctx context.Context, networkID int64) ([]WebPushSubscription, error) { func (db *SqliteDB) ListWebPushSubscriptions(ctx context.Context, userID, networkID int64) ([]WebPushSubscription, error) {
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()
@ -890,7 +895,7 @@ func (db *SqliteDB) ListWebPushSubscriptions(ctx context.Context, networkID int6
rows, err := db.db.QueryContext(ctx, ` rows, err := db.db.QueryContext(ctx, `
SELECT id, endpoint, key_auth, key_p256dh, key_vapid SELECT id, endpoint, key_auth, key_p256dh, key_vapid
FROM WebPushSubscription FROM WebPushSubscription
WHERE network IS ?`, nullNetworkID) WHERE user = ? AND network IS ?`, userID, nullNetworkID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -908,12 +913,13 @@ func (db *SqliteDB) ListWebPushSubscriptions(ctx context.Context, networkID int6
return subs, rows.Err() return subs, rows.Err()
} }
func (db *SqliteDB) StoreWebPushSubscription(ctx context.Context, networkID int64, sub *WebPushSubscription) error { func (db *SqliteDB) StoreWebPushSubscription(ctx context.Context, userID, networkID int64, sub *WebPushSubscription) error {
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()
args := []interface{}{ args := []interface{}{
sql.Named("id", sub.ID), sql.Named("id", sub.ID),
sql.Named("user", userID),
sql.Named("network", sql.NullInt64{ sql.Named("network", sql.NullInt64{
Int64: networkID, Int64: networkID,
Valid: networkID != 0, Valid: networkID != 0,
@ -937,10 +943,10 @@ func (db *SqliteDB) StoreWebPushSubscription(ctx context.Context, networkID int6
var res sql.Result var res sql.Result
res, err = db.db.ExecContext(ctx, ` res, err = db.db.ExecContext(ctx, `
INSERT INTO INSERT INTO
WebPushSubscription(created_at, updated_at, network, endpoint, WebPushSubscription(created_at, updated_at, user, network, endpoint,
key_auth, key_p256dh, key_vapid) key_auth, key_p256dh, key_vapid)
VALUES (:now, :now, :network, :endpoint, :key_auth, :key_p256dh, VALUES (:now, :now, :user, :network, :endpoint, :key_auth,
:key_vapid)`, :key_p256dh, :key_vapid)`,
args...) args...)
if err != nil { if err != nil {
return err return err

View File

@ -3278,7 +3278,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
// TODO: limit max number of subscriptions, prune old ones // TODO: limit max number of subscriptions, prune old ones
if err := dc.user.srv.db.StoreWebPushSubscription(ctx, networkID, &newSub); err != nil { if err := dc.user.srv.db.StoreWebPushSubscription(ctx, dc.user.ID, networkID, &newSub); err != nil {
dc.logger.Printf("failed to store Web push subscription: %v", err) dc.logger.Printf("failed to store Web push subscription: %v", err)
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: "FAIL", Command: "FAIL",
@ -3382,7 +3382,7 @@ func (dc *downstreamConn) findWebPushSubscription(ctx context.Context, endpoint
networkID = dc.network.ID networkID = dc.network.ID
} }
subs, err := dc.user.srv.db.ListWebPushSubscriptions(ctx, networkID) subs, err := dc.user.srv.db.ListWebPushSubscriptions(ctx, dc.user.ID, networkID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -445,7 +445,7 @@ func (net *network) autoSaveSASLPlain(ctx context.Context, username, password st
} }
func (net *network) broadcastWebPush(ctx context.Context, msg *irc.Message) { func (net *network) broadcastWebPush(ctx context.Context, msg *irc.Message) {
subs, err := net.user.srv.db.ListWebPushSubscriptions(ctx, net.ID) subs, err := net.user.srv.db.ListWebPushSubscriptions(ctx, net.user.ID, net.ID)
if err != nil { if err != nil {
net.logger.Printf("failed to list Web push subscriptions: %v", err) net.logger.Printf("failed to list Web push subscriptions: %v", err)
return return