Add context args to Database interface

This is a mecanical change, which just lifts up the context.TODO()
calls from inside the DB implementations to the callers.

Future work involves properly wiring up the contexts when it makes
sense.
This commit is contained in:
Simon Ser 2021-10-18 19:15:15 +02:00
parent 4be6c4b19c
commit 9ec1f1a5b0
11 changed files with 110 additions and 101 deletions

View File

@ -2,6 +2,7 @@ package main
import (
"bufio"
"context"
"flag"
"fmt"
"io"
@ -75,7 +76,7 @@ func main() {
Password: string(hashed),
Admin: *admin,
}
if err := db.StoreUser(&user); err != nil {
if err := db.StoreUser(context.TODO(), &user); err != nil {
log.Fatalf("failed to create user: %v", err)
}
case "change-password":
@ -85,7 +86,7 @@ func main() {
os.Exit(1)
}
user, err := db.GetUser(username)
user, err := db.GetUser(context.TODO(), username)
if err != nil {
log.Fatalf("failed to get user: %v", err)
}
@ -101,7 +102,7 @@ func main() {
}
user.Password = string(hashed)
if err := db.StoreUser(user); err != nil {
if err := db.StoreUser(context.TODO(), user); err != nil {
log.Fatalf("failed to update password: %v", err)
}
default:

View File

@ -2,6 +2,7 @@ package main
import (
"bufio"
"context"
"flag"
"fmt"
"io"
@ -79,7 +80,7 @@ func main() {
log.Fatalf("failed to parse %q: line %v: %v", zncPath, zp.line, err)
}
l, err := db.ListUsers()
l, err := db.ListUsers(context.TODO())
if err != nil {
log.Fatalf("failed to list users in DB: %v", err)
}
@ -111,12 +112,12 @@ func main() {
u.Admin = section.Values.Get("Admin") == "true"
if err := db.StoreUser(u); err != nil {
if err := db.StoreUser(context.TODO(), u); err != nil {
log.Fatalf("failed to store user %q: %v", username, err)
}
userID := u.ID
l, err := db.ListNetworks(userID)
l, err := db.ListNetworks(context.TODO(), userID)
if err != nil {
log.Fatalf("failed to list networks for user %q: %v", username, err)
}
@ -183,11 +184,11 @@ func main() {
n.Pass = pass
n.Enabled = section.Values.Get("IRCConnectEnabled") != "false"
if err := db.StoreNetwork(userID, n); err != nil {
if err := db.StoreNetwork(context.TODO(), userID, n); err != nil {
logger.Fatalf("failed to store network: %v", err)
}
l, err := db.ListChannels(n.ID)
l, err := db.ListChannels(context.TODO(), n.ID)
if err != nil {
logger.Fatalf("failed to list channels: %v", err)
}
@ -217,7 +218,7 @@ func main() {
ch.Key = section.Values.Get("Key")
ch.Detached = section.Values.Get("Detached") == "true"
if err := db.StoreChannel(n.ID, ch); err != nil {
if err := db.StoreChannel(context.TODO(), n.ID, ch); err != nil {
logger.Printf("channel %q: failed to store channel: %v", chName, err)
}
})

27
db.go
View File

@ -1,6 +1,7 @@
package soju
import (
"context"
"fmt"
"net/url"
"strings"
@ -9,22 +10,22 @@ import (
type Database interface {
Close() error
Stats() (*DatabaseStats, error)
Stats(ctx context.Context) (*DatabaseStats, error)
ListUsers() ([]User, error)
GetUser(username string) (*User, error)
StoreUser(user *User) error
DeleteUser(id int64) error
ListUsers(ctx context.Context) ([]User, error)
GetUser(ctx context.Context, username string) (*User, error)
StoreUser(ctx context.Context, user *User) error
DeleteUser(ctx context.Context, id int64) error
ListNetworks(userID int64) ([]Network, error)
StoreNetwork(userID int64, network *Network) error
DeleteNetwork(id int64) error
ListChannels(networkID int64) ([]Channel, error)
StoreChannel(networKID int64, ch *Channel) error
DeleteChannel(id int64) error
ListNetworks(ctx context.Context, userID int64) ([]Network, error)
StoreNetwork(ctx context.Context, userID int64, network *Network) error
DeleteNetwork(ctx context.Context, id int64) error
ListChannels(ctx context.Context, networkID int64) ([]Channel, error)
StoreChannel(ctx context.Context, networKID int64, ch *Channel) error
DeleteChannel(ctx context.Context, id int64) error
ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error)
StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error
ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error)
StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error
}
func OpenDB(driver, source string) (Database, error) {

View File

@ -147,8 +147,8 @@ func (db *PostgresDB) Close() error {
return db.db.Close()
}
func (db *PostgresDB) Stats() (*DatabaseStats, error) {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
func (db *PostgresDB) Stats(ctx context.Context) (*DatabaseStats, error) {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
var stats DatabaseStats
@ -163,8 +163,8 @@ func (db *PostgresDB) Stats() (*DatabaseStats, error) {
return &stats, nil
}
func (db *PostgresDB) ListUsers() ([]User, error) {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
rows, err := db.db.QueryContext(ctx,
@ -192,8 +192,8 @@ func (db *PostgresDB) ListUsers() ([]User, error) {
return users, nil
}
func (db *PostgresDB) GetUser(username string) (*User, error) {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, error) {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
user := &User{Username: username}
@ -210,8 +210,8 @@ func (db *PostgresDB) GetUser(username string) (*User, error) {
return user, nil
}
func (db *PostgresDB) StoreUser(user *User) error {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
password := toNullString(user.Password)
@ -234,16 +234,16 @@ func (db *PostgresDB) StoreUser(user *User) error {
return err
}
func (db *PostgresDB) DeleteUser(id int64) error {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
func (db *PostgresDB) DeleteUser(ctx context.Context, id int64) error {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
_, err := db.db.ExecContext(ctx, `DELETE FROM "User" WHERE id = $1`, id)
return err
}
func (db *PostgresDB) ListNetworks(userID int64) ([]Network, error) {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
rows, err := db.db.QueryContext(ctx, `
@ -286,8 +286,8 @@ func (db *PostgresDB) ListNetworks(userID int64) ([]Network, error) {
return networks, nil
}
func (db *PostgresDB) StoreNetwork(userID int64, network *Network) error {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
func (db *PostgresDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
netName := toNullString(network.Name)
@ -338,16 +338,16 @@ func (db *PostgresDB) StoreNetwork(userID int64, network *Network) error {
return err
}
func (db *PostgresDB) DeleteNetwork(id int64) error {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
func (db *PostgresDB) DeleteNetwork(ctx context.Context, id int64) error {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
_, err := db.db.ExecContext(ctx, `DELETE FROM "Network" WHERE id = $1`, id)
return err
}
func (db *PostgresDB) ListChannels(networkID int64) ([]Channel, error) {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
func (db *PostgresDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
rows, err := db.db.QueryContext(ctx, `
@ -380,8 +380,8 @@ func (db *PostgresDB) ListChannels(networkID int64) ([]Channel, error) {
return channels, nil
}
func (db *PostgresDB) StoreChannel(networkID int64, ch *Channel) error {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
func (db *PostgresDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
key := toNullString(ch.Key)
@ -408,16 +408,16 @@ func (db *PostgresDB) StoreChannel(networkID int64, ch *Channel) error {
return err
}
func (db *PostgresDB) DeleteChannel(id int64) error {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
func (db *PostgresDB) DeleteChannel(ctx context.Context, id int64) error {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
_, err := db.db.ExecContext(ctx, `DELETE FROM "Channel" WHERE id = $1`, id)
return err
}
func (db *PostgresDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
func (db *PostgresDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
rows, err := db.db.QueryContext(ctx, `
@ -444,8 +444,8 @@ func (db *PostgresDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt,
return receipts, nil
}
func (db *PostgresDB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
func (db *PostgresDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
tx, err := db.db.Begin()

View File

@ -208,11 +208,11 @@ func (db *SqliteDB) upgrade() error {
return tx.Commit()
}
func (db *SqliteDB) Stats() (*DatabaseStats, error) {
func (db *SqliteDB) Stats(ctx context.Context) (*DatabaseStats, error) {
db.lock.RLock()
defer db.lock.RUnlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
var stats DatabaseStats
@ -234,11 +234,11 @@ func toNullString(s string) sql.NullString {
}
}
func (db *SqliteDB) ListUsers() ([]User, error) {
func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) {
db.lock.RLock()
defer db.lock.RUnlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
rows, err := db.db.QueryContext(ctx,
@ -266,11 +266,11 @@ func (db *SqliteDB) ListUsers() ([]User, error) {
return users, nil
}
func (db *SqliteDB) GetUser(username string) (*User, error) {
func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error) {
db.lock.RLock()
defer db.lock.RUnlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
user := &User{Username: username}
@ -287,11 +287,11 @@ func (db *SqliteDB) GetUser(username string) (*User, error) {
return user, nil
}
func (db *SqliteDB) StoreUser(user *User) error {
func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
db.lock.Lock()
defer db.lock.Unlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
args := []interface{}{
@ -323,11 +323,11 @@ func (db *SqliteDB) StoreUser(user *User) error {
return err
}
func (db *SqliteDB) DeleteUser(id int64) error {
func (db *SqliteDB) DeleteUser(ctx context.Context, id int64) error {
db.lock.Lock()
defer db.lock.Unlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
tx, err := db.db.Begin()
@ -371,11 +371,11 @@ func (db *SqliteDB) DeleteUser(id int64) error {
return tx.Commit()
}
func (db *SqliteDB) ListNetworks(userID int64) ([]Network, error) {
func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) {
db.lock.RLock()
defer db.lock.RUnlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
rows, err := db.db.QueryContext(ctx, `
@ -420,11 +420,11 @@ func (db *SqliteDB) ListNetworks(userID int64) ([]Network, error) {
return networks, nil
}
func (db *SqliteDB) StoreNetwork(userID int64, network *Network) error {
func (db *SqliteDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error {
db.lock.Lock()
defer db.lock.Unlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
@ -490,11 +490,11 @@ func (db *SqliteDB) StoreNetwork(userID int64, network *Network) error {
return err
}
func (db *SqliteDB) DeleteNetwork(id int64) error {
func (db *SqliteDB) DeleteNetwork(ctx context.Context, id int64) error {
db.lock.Lock()
defer db.lock.Unlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
tx, err := db.db.Begin()
@ -521,11 +521,11 @@ func (db *SqliteDB) DeleteNetwork(id int64) error {
return tx.Commit()
}
func (db *SqliteDB) ListChannels(networkID int64) ([]Channel, error) {
func (db *SqliteDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) {
db.lock.RLock()
defer db.lock.RUnlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
rows, err := db.db.QueryContext(ctx, `SELECT
@ -558,11 +558,11 @@ func (db *SqliteDB) ListChannels(networkID int64) ([]Channel, error) {
return channels, nil
}
func (db *SqliteDB) StoreChannel(networkID int64, ch *Channel) error {
func (db *SqliteDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error {
db.lock.Lock()
defer db.lock.Unlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
args := []interface{}{
@ -598,22 +598,22 @@ func (db *SqliteDB) StoreChannel(networkID int64, ch *Channel) error {
return err
}
func (db *SqliteDB) DeleteChannel(id int64) error {
func (db *SqliteDB) DeleteChannel(ctx context.Context, id int64) error {
db.lock.Lock()
defer db.lock.Unlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
_, err := db.db.ExecContext(ctx, "DELETE FROM Channel WHERE id = ?", id)
return err
}
func (db *SqliteDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) {
func (db *SqliteDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) {
db.lock.RLock()
defer db.lock.RUnlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
rows, err := db.db.QueryContext(ctx, `
@ -642,11 +642,11 @@ func (db *SqliteDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, er
return receipts, nil
}
func (db *SqliteDB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error {
func (db *SqliteDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error {
db.lock.Lock()
defer db.lock.Unlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
tx, err := db.db.Begin()

View File

@ -1,6 +1,7 @@
package soju
import (
"context"
"crypto/tls"
"encoding/base64"
"fmt"
@ -976,7 +977,7 @@ func unmarshalUsername(rawUsername string) (username, client, network string) {
func (dc *downstreamConn) authenticate(username, password string) error {
username, clientName, networkName := unmarshalUsername(username)
u, err := dc.srv.db.GetUser(username)
u, err := dc.srv.db.GetUser(context.TODO(), username)
if err != nil {
dc.logger.Printf("failed authentication for %q: user not found: %v", username, err)
return errAuthFailed
@ -1377,7 +1378,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
return
}
n.Nick = nick
err = dc.srv.db.StoreNetwork(dc.user.ID, &n.Network)
err = dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &n.Network)
})
if err != nil {
return err
@ -1427,7 +1428,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
})
n.Realname = storeRealname
if err := dc.srv.db.StoreNetwork(dc.user.ID, &n.Network); err != nil {
if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &n.Network); err != nil {
dc.logger.Printf("failed to store network realname: %v", err)
storeErr = err
}
@ -1516,7 +1517,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
}
uc.network.channels.SetValue(upstreamName, ch)
}
if err := dc.srv.db.StoreChannel(uc.network.ID, ch); err != nil {
if err := dc.srv.db.StoreChannel(context.TODO(), uc.network.ID, ch); err != nil {
dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
}
}
@ -1548,7 +1549,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
}
uc.network.channels.SetValue(upstreamName, ch)
}
if err := dc.srv.db.StoreChannel(uc.network.ID, ch); err != nil {
if err := dc.srv.db.StoreChannel(context.TODO(), uc.network.ID, ch); err != nil {
dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
}
} else {
@ -2445,7 +2446,7 @@ func (dc *downstreamConn) handleNickServPRIVMSG(uc *upstreamConn, text string) {
n.SASL.Mechanism = "PLAIN"
n.SASL.Plain.Username = username
n.SASL.Plain.Password = password
if err := dc.srv.db.StoreNetwork(dc.user.ID, &n.Network); err != nil {
if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &n.Network); err != nil {
dc.logger.Printf("failed to save NickServ credentials: %v", err)
}
}

View File

@ -1,6 +1,7 @@
package soju
import (
"context"
"fmt"
"log"
"mime"
@ -85,7 +86,7 @@ func (s *Server) prefix() *irc.Prefix {
}
func (s *Server) Start() error {
users, err := s.db.ListUsers()
users, err := s.db.ListUsers(context.TODO())
if err != nil {
return err
}
@ -126,7 +127,7 @@ func (s *Server) createUser(user *User) (*user, error) {
return nil, fmt.Errorf("user %q already exists", user.Username)
}
err := s.db.StoreUser(user)
err := s.db.StoreUser(context.TODO(), user)
if err != nil {
return nil, fmt.Errorf("could not create user in db: %v", err)
}

View File

@ -1,6 +1,7 @@
package soju
import (
"context"
"net"
"testing"
@ -43,7 +44,7 @@ func createTestUser(t *testing.T, db Database) *User {
}
record := &User{Username: testUsername, Password: string(hashed)}
if err := db.StoreUser(record); err != nil {
if err := db.StoreUser(context.TODO(), record); err != nil {
t.Fatalf("failed to store test user: %v", err)
}
@ -68,7 +69,7 @@ func createTestUpstream(t *testing.T, db Database, user *User) (*Network, net.Li
Nick: user.Username,
Enabled: true,
}
if err := db.StoreNetwork(user.ID, network); err != nil {
if err := db.StoreNetwork(context.TODO(), user.ID, network); err != nil {
t.Fatalf("failed to store test network: %v", err)
}

View File

@ -1,6 +1,7 @@
package soju
import (
"context"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
@ -657,7 +658,7 @@ func handleServiceCertFPGenerate(dc *downstreamConn, params []string) error {
net.SASL.External.PrivKeyBlob = privKey
net.SASL.Mechanism = "EXTERNAL"
if err := dc.srv.db.StoreNetwork(dc.user.ID, &net.Network); err != nil {
if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &net.Network); err != nil {
return err
}
@ -698,7 +699,7 @@ func handleServiceSASLSetPlain(dc *downstreamConn, params []string) error {
net.SASL.Plain.Password = params[2]
net.SASL.Mechanism = "PLAIN"
if err := dc.srv.db.StoreNetwork(dc.user.ID, &net.Network); err != nil {
if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &net.Network); err != nil {
return err
}
@ -722,7 +723,7 @@ func handleServiceSASLReset(dc *downstreamConn, params []string) error {
net.SASL.External.PrivKeyBlob = nil
net.SASL.Mechanism = ""
if err := dc.srv.db.StoreNetwork(dc.user.ID, &net.Network); err != nil {
if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &net.Network); err != nil {
return err
}
@ -860,7 +861,7 @@ func handleUserDelete(dc *downstreamConn, params []string) error {
u.stop()
if err := dc.srv.db.DeleteUser(u.ID); err != nil {
if err := dc.srv.db.DeleteUser(context.TODO(), u.ID); err != nil {
return fmt.Errorf("failed to delete user: %v", err)
}
@ -1015,7 +1016,7 @@ func handleServiceChannelUpdate(dc *downstreamConn, params []string) error {
uc.updateChannelAutoDetach(upstreamName)
if err := dc.srv.db.StoreChannel(uc.network.ID, ch); err != nil {
if err := dc.srv.db.StoreChannel(context.TODO(), uc.network.ID, ch); err != nil {
return fmt.Errorf("failed to update channel: %v", err)
}
@ -1024,7 +1025,7 @@ func handleServiceChannelUpdate(dc *downstreamConn, params []string) error {
}
func handleServiceServerStatus(dc *downstreamConn, params []string) error {
dbStats, err := dc.user.srv.db.Stats()
dbStats, err := dc.user.srv.db.Stats(context.TODO())
if err != nil {
return err
}

View File

@ -1,6 +1,7 @@
package soju
import (
"context"
"crypto"
"crypto/sha256"
"crypto/tls"
@ -1516,7 +1517,7 @@ func (uc *upstreamConn) handleDetachedMessage(ch *Channel, msg *irc.Message) {
}
if ch.ReattachOn == FilterMessage || (ch.ReattachOn == FilterHighlight && uc.network.isHighlight(msg)) {
uc.network.attach(ch)
if err := uc.srv.db.StoreChannel(uc.network.ID, ch); err != nil {
if err := uc.srv.db.StoreChannel(context.TODO(), uc.network.ID, ch); err != nil {
uc.logger.Printf("failed to update channel %q: %v", ch.Name, err)
}
}

21
user.go
View File

@ -1,6 +1,7 @@
package soju
import (
"context"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
@ -330,7 +331,7 @@ func (net *network) deleteChannel(name string) error {
}
}
if err := net.user.srv.db.DeleteChannel(ch.ID); err != nil {
if err := net.user.srv.db.DeleteChannel(context.TODO(), ch.ID); err != nil {
return err
}
net.channels.Delete(name)
@ -367,7 +368,7 @@ func (net *network) storeClientDeliveryReceipts(clientName string) {
})
})
if err := net.user.srv.db.StoreClientDeliveryReceipts(net.ID, clientName, receipts); err != nil {
if err := net.user.srv.db.StoreClientDeliveryReceipts(context.TODO(), net.ID, clientName, receipts); err != nil {
net.logger.Printf("failed to store delivery receipts for client %q: %v", clientName, err)
}
}
@ -487,7 +488,7 @@ func (u *user) run() {
close(u.done)
}()
networks, err := u.srv.db.ListNetworks(u.ID)
networks, err := u.srv.db.ListNetworks(context.TODO(), u.ID)
if err != nil {
u.logger.Printf("failed to list networks for user %q: %v", u.Username, err)
return
@ -495,7 +496,7 @@ func (u *user) run() {
for _, record := range networks {
record := record
channels, err := u.srv.db.ListChannels(record.ID)
channels, err := u.srv.db.ListChannels(context.TODO(), record.ID)
if err != nil {
u.logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err)
continue
@ -505,7 +506,7 @@ func (u *user) run() {
u.networks = append(u.networks, network)
if u.hasPersistentMsgStore() {
receipts, err := u.srv.db.ListDeliveryReceipts(record.ID)
receipts, err := u.srv.db.ListDeliveryReceipts(context.TODO(), record.ID)
if err != nil {
u.logger.Printf("failed to load delivery receipts for user %q, network %q: %v", u.Username, network.GetName(), err)
return
@ -590,7 +591,7 @@ func (u *user) run() {
continue
}
uc.network.detach(c)
if err := uc.srv.db.StoreChannel(uc.network.ID, c); err != nil {
if err := uc.srv.db.StoreChannel(context.TODO(), uc.network.ID, c); err != nil {
u.logger.Printf("failed to store updated detached channel %q: %v", c.Name, err)
}
case eventDownstreamConnected:
@ -779,7 +780,7 @@ func (u *user) createNetwork(record *Network) (*network, error) {
}
network := newNetwork(u, record, nil)
err := u.srv.db.StoreNetwork(u.ID, &network.Network)
err := u.srv.db.StoreNetwork(context.TODO(), u.ID, &network.Network)
if err != nil {
return nil, err
}
@ -821,7 +822,7 @@ func (u *user) updateNetwork(record *Network) (*network, error) {
panic("tried updating a non-existing network")
}
if err := u.srv.db.StoreNetwork(u.ID, record); err != nil {
if err := u.srv.db.StoreNetwork(context.TODO(), u.ID, record); err != nil {
return nil, err
}
@ -888,7 +889,7 @@ func (u *user) deleteNetwork(id int64) error {
panic("tried deleting a non-existing network")
}
if err := u.srv.db.DeleteNetwork(network.ID); err != nil {
if err := u.srv.db.DeleteNetwork(context.TODO(), network.ID); err != nil {
return err
}
@ -914,7 +915,7 @@ func (u *user) updateUser(record *User) error {
}
realnameUpdated := u.Realname != record.Realname
if err := u.srv.db.StoreUser(record); err != nil {
if err := u.srv.db.StoreUser(context.TODO(), record); err != nil {
return fmt.Errorf("failed to update user %q: %v", u.Username, err)
}
u.User = *record