From 2abe231eefcfe02e114ff2e549d2d2df5c4f2027 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Thu, 26 Jan 2023 14:11:07 +0100 Subject: [PATCH] database/sqlite: introduce sqliteTime type This implements sql.Scanner and sql/driver.Valuer, so that we can load/store time values into SQLite with the format we want, and properly handle NULL (which the go-sqlite3 package doesn't do correctly). --- database/sqlite.go | 47 +++++++++++++++++++++++++++++++++++----------- 1 file changed, 36 insertions(+), 11 deletions(-) diff --git a/database/sqlite.go b/database/sqlite.go index 11e83ee..6974ba4 100644 --- a/database/sqlite.go +++ b/database/sqlite.go @@ -5,6 +5,7 @@ package database import ( "context" "database/sql" + sqldriver "database/sql/driver" "fmt" "math" "strings" @@ -20,8 +21,36 @@ const sqliteQueryTimeout = 5 * time.Second const sqliteTimeLayout = "2006-01-02T15:04:05.000Z" -func formatSqliteTime(t time.Time) string { - return t.UTC().Format(sqliteTimeLayout) +type sqliteTime struct { + time.Time +} + +var ( + _ sql.Scanner = (*sqliteTime)(nil) + _ sqldriver.Valuer = sqliteTime{} +) + +func (t *sqliteTime) Scan(value interface{}) error { + if value == nil { + t.Time = time.Time{} + return nil + } + if s, ok := value.(string); ok { + tt, err := time.Parse(sqliteTimeLayout, s) + if err != nil { + return err + } + t.Time = tt + return nil + } + return fmt.Errorf("cannot scan time from type %T", value) +} + +func (t sqliteTime) Value() (sqldriver.Value, error) { + if t.Time.IsZero() { + return nil, nil + } + return t.UTC().Format(sqliteTimeLayout), nil } const sqliteSchema = ` @@ -800,18 +829,14 @@ func (db *SqliteDB) GetReadReceipt(ctx context.Context, networkID int64, name st sql.Named("network", networkID), sql.Named("target", name), ) - var timestamp string + var timestamp sqliteTime if err := row.Scan(&receipt.ID, ×tamp); err != nil { if err == sql.ErrNoRows { return nil, nil } return nil, err } - if t, err := time.Parse(sqliteTimeLayout, timestamp); err != nil { - return nil, err - } else { - receipt.Timestamp = t - } + receipt.Timestamp = timestamp.Time return receipt, nil } @@ -821,7 +846,7 @@ func (db *SqliteDB) StoreReadReceipt(ctx context.Context, networkID int64, recei args := []interface{}{ sql.Named("id", receipt.ID), - sql.Named("timestamp", formatSqliteTime(receipt.Timestamp)), + sql.Named("timestamp", sqliteTime{receipt.Timestamp}), sql.Named("network", networkID), sql.Named("target", receipt.Target), } @@ -884,7 +909,7 @@ func (db *SqliteDB) StoreWebPushConfig(ctx context.Context, config *WebPushConfi VALUES (:now, :vapid_key_public, :vapid_key_private)`, sql.Named("vapid_key_public", config.VAPIDKeys.Public), sql.Named("vapid_key_private", config.VAPIDKeys.Private), - sql.Named("now", formatSqliteTime(time.Now()))) + sql.Named("now", sqliteTime{time.Now()})) if err != nil { return err } @@ -933,7 +958,7 @@ func (db *SqliteDB) StoreWebPushSubscription(ctx context.Context, userID, networ Int64: networkID, Valid: networkID != 0, }), - sql.Named("now", formatSqliteTime(time.Now())), + sql.Named("now", sqliteTime{time.Now()}), sql.Named("endpoint", sub.Endpoint), sql.Named("key_auth", sub.Keys.Auth), sql.Named("key_p256dh", sub.Keys.P256DH),