database: add missing user column to WebPushSubscription table

Some WebPushSubscription entries aren't tried to a network, in
which case the "network" column is NULL. But then all users share
the same row. Oops.

Fortunately network-less subscriptions aren't used for anything
yet, they're just stored. So the impact should be minimal.
This commit is contained in:
Simon Ser 2022-06-16 19:33:39 +02:00
parent de0992d41e
commit f0db261fc0
5 changed files with 30 additions and 18 deletions

View File

@ -36,8 +36,8 @@ type Database interface {
ListWebPushConfigs(ctx context.Context) ([]WebPushConfig, error)
StoreWebPushConfig(ctx context.Context, config *WebPushConfig) error
ListWebPushSubscriptions(ctx context.Context, networkID int64) ([]WebPushSubscription, error)
StoreWebPushSubscription(ctx context.Context, networkID int64, sub *WebPushSubscription) error
ListWebPushSubscriptions(ctx context.Context, userID, networkID int64) ([]WebPushSubscription, error)
StoreWebPushSubscription(ctx context.Context, userID, networkID int64, sub *WebPushSubscription) error
DeleteWebPushSubscription(ctx context.Context, id int64) error
}

View File

@ -98,6 +98,7 @@ CREATE TABLE "WebPushSubscription" (
id SERIAL PRIMARY KEY,
created_at TIMESTAMP WITH TIME ZONE NOT NULL,
updated_at TIMESTAMP WITH TIME ZONE NOT NULL,
"user" INTEGER REFERENCES "User"(id) ON DELETE CASCADE,
network INTEGER REFERENCES "Network"(id) ON DELETE CASCADE,
endpoint TEXT NOT NULL,
key_vapid TEXT,
@ -147,6 +148,11 @@ var postgresMigrations = []string{
UNIQUE(network, endpoint)
);
`,
`
ALTER TABLE "WebPushSubscription"
ADD COLUMN "user" INTEGER
REFERENCES "User"(id) ON DELETE CASCADE
`,
}
type PostgresDB struct {
@ -704,7 +710,7 @@ func (db *PostgresDB) StoreWebPushConfig(ctx context.Context, config *WebPushCon
return err
}
func (db *PostgresDB) ListWebPushSubscriptions(ctx context.Context, networkID int64) ([]WebPushSubscription, error) {
func (db *PostgresDB) ListWebPushSubscriptions(ctx context.Context, userID, networkID int64) ([]WebPushSubscription, error) {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
@ -716,7 +722,7 @@ func (db *PostgresDB) ListWebPushSubscriptions(ctx context.Context, networkID in
rows, err := db.db.QueryContext(ctx, `
SELECT id, endpoint, key_auth, key_p256dh, key_vapid
FROM "WebPushSubscription"
WHERE network IS NOT DISTINCT FROM $1`, nullNetworkID)
WHERE "user" = $1 AND network IS NOT DISTINCT FROM $2`, userID, nullNetworkID)
if err != nil {
return nil, err
}
@ -734,7 +740,7 @@ func (db *PostgresDB) ListWebPushSubscriptions(ctx context.Context, networkID in
return subs, rows.Err()
}
func (db *PostgresDB) StoreWebPushSubscription(ctx context.Context, networkID int64, sub *WebPushSubscription) error {
func (db *PostgresDB) StoreWebPushSubscription(ctx context.Context, userID, networkID int64, sub *WebPushSubscription) error {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
@ -753,11 +759,11 @@ func (db *PostgresDB) StoreWebPushSubscription(ctx context.Context, networkID in
sub.Keys.Auth, sub.Keys.P256DH, sub.Keys.VAPID, sub.ID)
} else {
err = db.db.QueryRowContext(ctx, `
INSERT INTO "WebPushSubscription" (created_at, updated_at, network,
endpoint, key_auth, key_p256dh, key_vapid)
VALUES (NOW(), NOW(), $1, $2, $3, $4, $5)
INSERT INTO "WebPushSubscription" (created_at, updated_at, "user",
network, endpoint, key_auth, key_p256dh, key_vapid)
VALUES (NOW(), NOW(), $1, $2, $3, $4, $5, $6)
RETURNING id`,
nullNetworkID, sub.Endpoint, sub.Keys.Auth, sub.Keys.P256DH,
nullNetworkID, userID, sub.Endpoint, sub.Keys.Auth, sub.Keys.P256DH,
sub.Keys.VAPID).Scan(&sub.ID)
}

View File

@ -97,11 +97,13 @@ CREATE TABLE WebPushSubscription (
id INTEGER PRIMARY KEY,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL,
user INTEGER NOT NULL,
network INTEGER,
endpoint TEXT NOT NULL,
key_vapid TEXT,
key_auth TEXT,
key_p256dh TEXT,
FOREIGN KEY(user) REFERENCES User(id),
FOREIGN KEY(network) REFERENCES Network(id),
UNIQUE(network, endpoint)
);
@ -237,6 +239,9 @@ var sqliteMigrations = []string{
UNIQUE(network, endpoint)
);
`,
`
ALTER TABLE WebPushSubscription ADD COLUMN user INTEGER REFERENCES User(id);
`,
}
type SqliteDB struct {
@ -878,7 +883,7 @@ func (db *SqliteDB) StoreWebPushConfig(ctx context.Context, config *WebPushConfi
return err
}
func (db *SqliteDB) ListWebPushSubscriptions(ctx context.Context, networkID int64) ([]WebPushSubscription, error) {
func (db *SqliteDB) ListWebPushSubscriptions(ctx context.Context, userID, networkID int64) ([]WebPushSubscription, error) {
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
@ -890,7 +895,7 @@ func (db *SqliteDB) ListWebPushSubscriptions(ctx context.Context, networkID int6
rows, err := db.db.QueryContext(ctx, `
SELECT id, endpoint, key_auth, key_p256dh, key_vapid
FROM WebPushSubscription
WHERE network IS ?`, nullNetworkID)
WHERE user = ? AND network IS ?`, userID, nullNetworkID)
if err != nil {
return nil, err
}
@ -908,12 +913,13 @@ func (db *SqliteDB) ListWebPushSubscriptions(ctx context.Context, networkID int6
return subs, rows.Err()
}
func (db *SqliteDB) StoreWebPushSubscription(ctx context.Context, networkID int64, sub *WebPushSubscription) error {
func (db *SqliteDB) StoreWebPushSubscription(ctx context.Context, userID, networkID int64, sub *WebPushSubscription) error {
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
args := []interface{}{
sql.Named("id", sub.ID),
sql.Named("user", userID),
sql.Named("network", sql.NullInt64{
Int64: networkID,
Valid: networkID != 0,
@ -937,10 +943,10 @@ func (db *SqliteDB) StoreWebPushSubscription(ctx context.Context, networkID int6
var res sql.Result
res, err = db.db.ExecContext(ctx, `
INSERT INTO
WebPushSubscription(created_at, updated_at, network, endpoint,
WebPushSubscription(created_at, updated_at, user, network, endpoint,
key_auth, key_p256dh, key_vapid)
VALUES (:now, :now, :network, :endpoint, :key_auth, :key_p256dh,
:key_vapid)`,
VALUES (:now, :now, :user, :network, :endpoint, :key_auth,
:key_p256dh, :key_vapid)`,
args...)
if err != nil {
return err

View File

@ -3278,7 +3278,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
// TODO: limit max number of subscriptions, prune old ones
if err := dc.user.srv.db.StoreWebPushSubscription(ctx, networkID, &newSub); err != nil {
if err := dc.user.srv.db.StoreWebPushSubscription(ctx, dc.user.ID, networkID, &newSub); err != nil {
dc.logger.Printf("failed to store Web push subscription: %v", err)
return ircError{&irc.Message{
Command: "FAIL",
@ -3382,7 +3382,7 @@ func (dc *downstreamConn) findWebPushSubscription(ctx context.Context, endpoint
networkID = dc.network.ID
}
subs, err := dc.user.srv.db.ListWebPushSubscriptions(ctx, networkID)
subs, err := dc.user.srv.db.ListWebPushSubscriptions(ctx, dc.user.ID, networkID)
if err != nil {
return nil, err
}

View File

@ -445,7 +445,7 @@ func (net *network) autoSaveSASLPlain(ctx context.Context, username, password st
}
func (net *network) broadcastWebPush(ctx context.Context, msg *irc.Message) {
subs, err := net.user.srv.db.ListWebPushSubscriptions(ctx, net.ID)
subs, err := net.user.srv.db.ListWebPushSubscriptions(ctx, net.user.ID, net.ID)
if err != nil {
net.logger.Printf("failed to list Web push subscriptions: %v", err)
return