diff --git a/database/sqlite.go b/database/sqlite.go index 11e83ee..6974ba4 100644 --- a/database/sqlite.go +++ b/database/sqlite.go @@ -5,6 +5,7 @@ package database import ( "context" "database/sql" + sqldriver "database/sql/driver" "fmt" "math" "strings" @@ -20,8 +21,36 @@ const sqliteQueryTimeout = 5 * time.Second const sqliteTimeLayout = "2006-01-02T15:04:05.000Z" -func formatSqliteTime(t time.Time) string { - return t.UTC().Format(sqliteTimeLayout) +type sqliteTime struct { + time.Time +} + +var ( + _ sql.Scanner = (*sqliteTime)(nil) + _ sqldriver.Valuer = sqliteTime{} +) + +func (t *sqliteTime) Scan(value interface{}) error { + if value == nil { + t.Time = time.Time{} + return nil + } + if s, ok := value.(string); ok { + tt, err := time.Parse(sqliteTimeLayout, s) + if err != nil { + return err + } + t.Time = tt + return nil + } + return fmt.Errorf("cannot scan time from type %T", value) +} + +func (t sqliteTime) Value() (sqldriver.Value, error) { + if t.Time.IsZero() { + return nil, nil + } + return t.UTC().Format(sqliteTimeLayout), nil } const sqliteSchema = ` @@ -800,18 +829,14 @@ func (db *SqliteDB) GetReadReceipt(ctx context.Context, networkID int64, name st sql.Named("network", networkID), sql.Named("target", name), ) - var timestamp string + var timestamp sqliteTime if err := row.Scan(&receipt.ID, ×tamp); err != nil { if err == sql.ErrNoRows { return nil, nil } return nil, err } - if t, err := time.Parse(sqliteTimeLayout, timestamp); err != nil { - return nil, err - } else { - receipt.Timestamp = t - } + receipt.Timestamp = timestamp.Time return receipt, nil } @@ -821,7 +846,7 @@ func (db *SqliteDB) StoreReadReceipt(ctx context.Context, networkID int64, recei args := []interface{}{ sql.Named("id", receipt.ID), - sql.Named("timestamp", formatSqliteTime(receipt.Timestamp)), + sql.Named("timestamp", sqliteTime{receipt.Timestamp}), sql.Named("network", networkID), sql.Named("target", receipt.Target), } @@ -884,7 +909,7 @@ func (db *SqliteDB) StoreWebPushConfig(ctx context.Context, config *WebPushConfi VALUES (:now, :vapid_key_public, :vapid_key_private)`, sql.Named("vapid_key_public", config.VAPIDKeys.Public), sql.Named("vapid_key_private", config.VAPIDKeys.Private), - sql.Named("now", formatSqliteTime(time.Now()))) + sql.Named("now", sqliteTime{time.Now()})) if err != nil { return err } @@ -933,7 +958,7 @@ func (db *SqliteDB) StoreWebPushSubscription(ctx context.Context, userID, networ Int64: networkID, Valid: networkID != 0, }), - sql.Named("now", formatSqliteTime(time.Now())), + sql.Named("now", sqliteTime{time.Now()}), sql.Named("endpoint", sub.Endpoint), sql.Named("key_auth", sub.Keys.Auth), sql.Named("key_p256dh", sub.Keys.P256DH),