database/sqlite: introduce sqliteTime type

This implements sql.Scanner and sql/driver.Valuer, so that we can
load/store time values into SQLite with the format we want, and
properly handle NULL (which the go-sqlite3 package doesn't do
correctly).
This commit is contained in:
Simon Ser 2023-01-26 14:11:07 +01:00
parent d74b66f240
commit 2abe231eef

View File

@ -5,6 +5,7 @@ package database
import (
"context"
"database/sql"
sqldriver "database/sql/driver"
"fmt"
"math"
"strings"
@ -20,8 +21,36 @@ const sqliteQueryTimeout = 5 * time.Second
const sqliteTimeLayout = "2006-01-02T15:04:05.000Z"
func formatSqliteTime(t time.Time) string {
return t.UTC().Format(sqliteTimeLayout)
type sqliteTime struct {
time.Time
}
var (
_ sql.Scanner = (*sqliteTime)(nil)
_ sqldriver.Valuer = sqliteTime{}
)
func (t *sqliteTime) Scan(value interface{}) error {
if value == nil {
t.Time = time.Time{}
return nil
}
if s, ok := value.(string); ok {
tt, err := time.Parse(sqliteTimeLayout, s)
if err != nil {
return err
}
t.Time = tt
return nil
}
return fmt.Errorf("cannot scan time from type %T", value)
}
func (t sqliteTime) Value() (sqldriver.Value, error) {
if t.Time.IsZero() {
return nil, nil
}
return t.UTC().Format(sqliteTimeLayout), nil
}
const sqliteSchema = `
@ -800,18 +829,14 @@ func (db *SqliteDB) GetReadReceipt(ctx context.Context, networkID int64, name st
sql.Named("network", networkID),
sql.Named("target", name),
)
var timestamp string
var timestamp sqliteTime
if err := row.Scan(&receipt.ID, &timestamp); err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
if t, err := time.Parse(sqliteTimeLayout, timestamp); err != nil {
return nil, err
} else {
receipt.Timestamp = t
}
receipt.Timestamp = timestamp.Time
return receipt, nil
}
@ -821,7 +846,7 @@ func (db *SqliteDB) StoreReadReceipt(ctx context.Context, networkID int64, recei
args := []interface{}{
sql.Named("id", receipt.ID),
sql.Named("timestamp", formatSqliteTime(receipt.Timestamp)),
sql.Named("timestamp", sqliteTime{receipt.Timestamp}),
sql.Named("network", networkID),
sql.Named("target", receipt.Target),
}
@ -884,7 +909,7 @@ func (db *SqliteDB) StoreWebPushConfig(ctx context.Context, config *WebPushConfi
VALUES (:now, :vapid_key_public, :vapid_key_private)`,
sql.Named("vapid_key_public", config.VAPIDKeys.Public),
sql.Named("vapid_key_private", config.VAPIDKeys.Private),
sql.Named("now", formatSqliteTime(time.Now())))
sql.Named("now", sqliteTime{time.Now()}))
if err != nil {
return err
}
@ -933,7 +958,7 @@ func (db *SqliteDB) StoreWebPushSubscription(ctx context.Context, userID, networ
Int64: networkID,
Valid: networkID != 0,
}),
sql.Named("now", formatSqliteTime(time.Now())),
sql.Named("now", sqliteTime{time.Now()}),
sql.Named("endpoint", sub.Endpoint),
sql.Named("key_auth", sub.Keys.Auth),
sql.Named("key_p256dh", sub.Keys.P256DH),