db_sqlite: drop mutex

See [1] for details.

[1]: https://github.com/mattn/go-sqlite3/issues/209
This commit is contained in:
Simon Ser 2022-05-03 23:17:56 +02:00
parent d37f946e83
commit 09d581dba4

View File

@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"math" "math"
"strings" "strings"
"sync"
"time" "time"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
@ -192,15 +191,17 @@ var sqliteMigrations = []string{
} }
type SqliteDB struct { type SqliteDB struct {
lock sync.RWMutex
db *sql.DB db *sql.DB
} }
func OpenSqliteDB(source string) (Database, error) { func OpenSqliteDB(source string) (Database, error) {
sqlSqliteDB, err := sql.Open("sqlite3", source) // Open the DB with cache=shared and SetMaxOpenConns(1) to allow usage from
// multiple goroutines
sqlSqliteDB, err := sql.Open("sqlite3", source+"?cache=shared")
if err != nil { if err != nil {
return nil, err return nil, err
} }
sqlSqliteDB.SetMaxOpenConns(1)
db := &SqliteDB{db: sqlSqliteDB} db := &SqliteDB{db: sqlSqliteDB}
if err := db.upgrade(); err != nil { if err := db.upgrade(); err != nil {
@ -212,15 +213,10 @@ func OpenSqliteDB(source string) (Database, error) {
} }
func (db *SqliteDB) Close() error { func (db *SqliteDB) Close() error {
db.lock.Lock()
defer db.lock.Unlock()
return db.db.Close() return db.db.Close()
} }
func (db *SqliteDB) upgrade() error { func (db *SqliteDB) upgrade() error {
db.lock.Lock()
defer db.lock.Unlock()
var version int var version int
if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil { if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
return fmt.Errorf("failed to query schema version: %v", err) return fmt.Errorf("failed to query schema version: %v", err)
@ -264,9 +260,6 @@ func (db *SqliteDB) RegisterMetrics(r prometheus.Registerer) error {
} }
func (db *SqliteDB) Stats(ctx context.Context) (*DatabaseStats, error) { func (db *SqliteDB) Stats(ctx context.Context) (*DatabaseStats, error) {
db.lock.RLock()
defer db.lock.RUnlock()
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()
@ -290,9 +283,6 @@ func toNullString(s string) sql.NullString {
} }
func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) { func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) {
db.lock.RLock()
defer db.lock.RUnlock()
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()
@ -322,9 +312,6 @@ func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) {
} }
func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error) { func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error) {
db.lock.RLock()
defer db.lock.RUnlock()
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()
@ -343,9 +330,6 @@ func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error)
} }
func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error { func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
db.lock.Lock()
defer db.lock.Unlock()
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()
@ -379,9 +363,6 @@ func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
} }
func (db *SqliteDB) DeleteUser(ctx context.Context, id int64) error { func (db *SqliteDB) DeleteUser(ctx context.Context, id int64) error {
db.lock.Lock()
defer db.lock.Unlock()
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()
@ -438,9 +419,6 @@ func (db *SqliteDB) DeleteUser(ctx context.Context, id int64) error {
} }
func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) { func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) {
db.lock.RLock()
defer db.lock.RUnlock()
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()
@ -488,9 +466,6 @@ func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network,
} }
func (db *SqliteDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error { func (db *SqliteDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error {
db.lock.Lock()
defer db.lock.Unlock()
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()
@ -558,9 +533,6 @@ func (db *SqliteDB) StoreNetwork(ctx context.Context, userID int64, network *Net
} }
func (db *SqliteDB) DeleteNetwork(ctx context.Context, id int64) error { func (db *SqliteDB) DeleteNetwork(ctx context.Context, id int64) error {
db.lock.Lock()
defer db.lock.Unlock()
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()
@ -594,9 +566,6 @@ func (db *SqliteDB) DeleteNetwork(ctx context.Context, id int64) error {
} }
func (db *SqliteDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) { func (db *SqliteDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) {
db.lock.RLock()
defer db.lock.RUnlock()
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()
@ -631,9 +600,6 @@ func (db *SqliteDB) ListChannels(ctx context.Context, networkID int64) ([]Channe
} }
func (db *SqliteDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error { func (db *SqliteDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error {
db.lock.Lock()
defer db.lock.Unlock()
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()
@ -671,9 +637,6 @@ func (db *SqliteDB) StoreChannel(ctx context.Context, networkID int64, ch *Chann
} }
func (db *SqliteDB) DeleteChannel(ctx context.Context, id int64) error { func (db *SqliteDB) DeleteChannel(ctx context.Context, id int64) error {
db.lock.Lock()
defer db.lock.Unlock()
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()
@ -682,9 +645,6 @@ func (db *SqliteDB) DeleteChannel(ctx context.Context, id int64) error {
} }
func (db *SqliteDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) { func (db *SqliteDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) {
db.lock.RLock()
defer db.lock.RUnlock()
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()
@ -715,9 +675,6 @@ func (db *SqliteDB) ListDeliveryReceipts(ctx context.Context, networkID int64) (
} }
func (db *SqliteDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error { func (db *SqliteDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error {
db.lock.Lock()
defer db.lock.Unlock()
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()
@ -756,9 +713,6 @@ func (db *SqliteDB) StoreClientDeliveryReceipts(ctx context.Context, networkID i
} }
func (db *SqliteDB) GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error) { func (db *SqliteDB) GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error) {
db.lock.RLock()
defer db.lock.RUnlock()
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()
@ -787,9 +741,6 @@ func (db *SqliteDB) GetReadReceipt(ctx context.Context, networkID int64, name st
} }
func (db *SqliteDB) StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error { func (db *SqliteDB) StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error {
db.lock.Lock()
defer db.lock.Unlock()
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()