database/sqlite: introduce sqliteTime type

This implements sql.Scanner and sql/driver.Valuer, so that we can
load/store time values into SQLite with the format we want, and
properly handle NULL (which the go-sqlite3 package doesn't do
correctly).
This commit is contained in:
Simon Ser 2023-01-26 14:11:07 +01:00
parent d74b66f240
commit 2abe231eef

View File

@ -5,6 +5,7 @@ package database
import ( import (
"context" "context"
"database/sql" "database/sql"
sqldriver "database/sql/driver"
"fmt" "fmt"
"math" "math"
"strings" "strings"
@ -20,8 +21,36 @@ const sqliteQueryTimeout = 5 * time.Second
const sqliteTimeLayout = "2006-01-02T15:04:05.000Z" const sqliteTimeLayout = "2006-01-02T15:04:05.000Z"
func formatSqliteTime(t time.Time) string { type sqliteTime struct {
return t.UTC().Format(sqliteTimeLayout) time.Time
}
var (
_ sql.Scanner = (*sqliteTime)(nil)
_ sqldriver.Valuer = sqliteTime{}
)
func (t *sqliteTime) Scan(value interface{}) error {
if value == nil {
t.Time = time.Time{}
return nil
}
if s, ok := value.(string); ok {
tt, err := time.Parse(sqliteTimeLayout, s)
if err != nil {
return err
}
t.Time = tt
return nil
}
return fmt.Errorf("cannot scan time from type %T", value)
}
func (t sqliteTime) Value() (sqldriver.Value, error) {
if t.Time.IsZero() {
return nil, nil
}
return t.UTC().Format(sqliteTimeLayout), nil
} }
const sqliteSchema = ` const sqliteSchema = `
@ -800,18 +829,14 @@ func (db *SqliteDB) GetReadReceipt(ctx context.Context, networkID int64, name st
sql.Named("network", networkID), sql.Named("network", networkID),
sql.Named("target", name), sql.Named("target", name),
) )
var timestamp string var timestamp sqliteTime
if err := row.Scan(&receipt.ID, &timestamp); err != nil { if err := row.Scan(&receipt.ID, &timestamp); err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} }
return nil, err return nil, err
} }
if t, err := time.Parse(sqliteTimeLayout, timestamp); err != nil { receipt.Timestamp = timestamp.Time
return nil, err
} else {
receipt.Timestamp = t
}
return receipt, nil return receipt, nil
} }
@ -821,7 +846,7 @@ func (db *SqliteDB) StoreReadReceipt(ctx context.Context, networkID int64, recei
args := []interface{}{ args := []interface{}{
sql.Named("id", receipt.ID), sql.Named("id", receipt.ID),
sql.Named("timestamp", formatSqliteTime(receipt.Timestamp)), sql.Named("timestamp", sqliteTime{receipt.Timestamp}),
sql.Named("network", networkID), sql.Named("network", networkID),
sql.Named("target", receipt.Target), sql.Named("target", receipt.Target),
} }
@ -884,7 +909,7 @@ func (db *SqliteDB) StoreWebPushConfig(ctx context.Context, config *WebPushConfi
VALUES (:now, :vapid_key_public, :vapid_key_private)`, VALUES (:now, :vapid_key_public, :vapid_key_private)`,
sql.Named("vapid_key_public", config.VAPIDKeys.Public), sql.Named("vapid_key_public", config.VAPIDKeys.Public),
sql.Named("vapid_key_private", config.VAPIDKeys.Private), sql.Named("vapid_key_private", config.VAPIDKeys.Private),
sql.Named("now", formatSqliteTime(time.Now()))) sql.Named("now", sqliteTime{time.Now()}))
if err != nil { if err != nil {
return err return err
} }
@ -933,7 +958,7 @@ func (db *SqliteDB) StoreWebPushSubscription(ctx context.Context, userID, networ
Int64: networkID, Int64: networkID,
Valid: networkID != 0, Valid: networkID != 0,
}), }),
sql.Named("now", formatSqliteTime(time.Now())), sql.Named("now", sqliteTime{time.Now()}),
sql.Named("endpoint", sub.Endpoint), sql.Named("endpoint", sub.Endpoint),
sql.Named("key_auth", sub.Keys.Auth), sql.Named("key_auth", sub.Keys.Auth),
sql.Named("key_p256dh", sub.Keys.P256DH), sql.Named("key_p256dh", sub.Keys.P256DH),