Add context args to Database interface
This is a mecanical change, which just lifts up the context.TODO() calls from inside the DB implementations to the callers. Future work involves properly wiring up the contexts when it makes sense.
This commit is contained in:
parent
4be6c4b19c
commit
9ec1f1a5b0
@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -75,7 +76,7 @@ func main() {
|
||||
Password: string(hashed),
|
||||
Admin: *admin,
|
||||
}
|
||||
if err := db.StoreUser(&user); err != nil {
|
||||
if err := db.StoreUser(context.TODO(), &user); err != nil {
|
||||
log.Fatalf("failed to create user: %v", err)
|
||||
}
|
||||
case "change-password":
|
||||
@ -85,7 +86,7 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
user, err := db.GetUser(username)
|
||||
user, err := db.GetUser(context.TODO(), username)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to get user: %v", err)
|
||||
}
|
||||
@ -101,7 +102,7 @@ func main() {
|
||||
}
|
||||
|
||||
user.Password = string(hashed)
|
||||
if err := db.StoreUser(user); err != nil {
|
||||
if err := db.StoreUser(context.TODO(), user); err != nil {
|
||||
log.Fatalf("failed to update password: %v", err)
|
||||
}
|
||||
default:
|
||||
|
@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -79,7 +80,7 @@ func main() {
|
||||
log.Fatalf("failed to parse %q: line %v: %v", zncPath, zp.line, err)
|
||||
}
|
||||
|
||||
l, err := db.ListUsers()
|
||||
l, err := db.ListUsers(context.TODO())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to list users in DB: %v", err)
|
||||
}
|
||||
@ -111,12 +112,12 @@ func main() {
|
||||
|
||||
u.Admin = section.Values.Get("Admin") == "true"
|
||||
|
||||
if err := db.StoreUser(u); err != nil {
|
||||
if err := db.StoreUser(context.TODO(), u); err != nil {
|
||||
log.Fatalf("failed to store user %q: %v", username, err)
|
||||
}
|
||||
userID := u.ID
|
||||
|
||||
l, err := db.ListNetworks(userID)
|
||||
l, err := db.ListNetworks(context.TODO(), userID)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to list networks for user %q: %v", username, err)
|
||||
}
|
||||
@ -183,11 +184,11 @@ func main() {
|
||||
n.Pass = pass
|
||||
n.Enabled = section.Values.Get("IRCConnectEnabled") != "false"
|
||||
|
||||
if err := db.StoreNetwork(userID, n); err != nil {
|
||||
if err := db.StoreNetwork(context.TODO(), userID, n); err != nil {
|
||||
logger.Fatalf("failed to store network: %v", err)
|
||||
}
|
||||
|
||||
l, err := db.ListChannels(n.ID)
|
||||
l, err := db.ListChannels(context.TODO(), n.ID)
|
||||
if err != nil {
|
||||
logger.Fatalf("failed to list channels: %v", err)
|
||||
}
|
||||
@ -217,7 +218,7 @@ func main() {
|
||||
ch.Key = section.Values.Get("Key")
|
||||
ch.Detached = section.Values.Get("Detached") == "true"
|
||||
|
||||
if err := db.StoreChannel(n.ID, ch); err != nil {
|
||||
if err := db.StoreChannel(context.TODO(), n.ID, ch); err != nil {
|
||||
logger.Printf("channel %q: failed to store channel: %v", chName, err)
|
||||
}
|
||||
})
|
||||
|
27
db.go
27
db.go
@ -1,6 +1,7 @@
|
||||
package soju
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
@ -9,22 +10,22 @@ import (
|
||||
|
||||
type Database interface {
|
||||
Close() error
|
||||
Stats() (*DatabaseStats, error)
|
||||
Stats(ctx context.Context) (*DatabaseStats, error)
|
||||
|
||||
ListUsers() ([]User, error)
|
||||
GetUser(username string) (*User, error)
|
||||
StoreUser(user *User) error
|
||||
DeleteUser(id int64) error
|
||||
ListUsers(ctx context.Context) ([]User, error)
|
||||
GetUser(ctx context.Context, username string) (*User, error)
|
||||
StoreUser(ctx context.Context, user *User) error
|
||||
DeleteUser(ctx context.Context, id int64) error
|
||||
|
||||
ListNetworks(userID int64) ([]Network, error)
|
||||
StoreNetwork(userID int64, network *Network) error
|
||||
DeleteNetwork(id int64) error
|
||||
ListChannels(networkID int64) ([]Channel, error)
|
||||
StoreChannel(networKID int64, ch *Channel) error
|
||||
DeleteChannel(id int64) error
|
||||
ListNetworks(ctx context.Context, userID int64) ([]Network, error)
|
||||
StoreNetwork(ctx context.Context, userID int64, network *Network) error
|
||||
DeleteNetwork(ctx context.Context, id int64) error
|
||||
ListChannels(ctx context.Context, networkID int64) ([]Channel, error)
|
||||
StoreChannel(ctx context.Context, networKID int64, ch *Channel) error
|
||||
DeleteChannel(ctx context.Context, id int64) error
|
||||
|
||||
ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error)
|
||||
StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error
|
||||
ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error)
|
||||
StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error
|
||||
}
|
||||
|
||||
func OpenDB(driver, source string) (Database, error) {
|
||||
|
@ -147,8 +147,8 @@ func (db *PostgresDB) Close() error {
|
||||
return db.db.Close()
|
||||
}
|
||||
|
||||
func (db *PostgresDB) Stats() (*DatabaseStats, error) {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
|
||||
func (db *PostgresDB) Stats(ctx context.Context) (*DatabaseStats, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
var stats DatabaseStats
|
||||
@ -163,8 +163,8 @@ func (db *PostgresDB) Stats() (*DatabaseStats, error) {
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
func (db *PostgresDB) ListUsers() ([]User, error) {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
|
||||
func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
rows, err := db.db.QueryContext(ctx,
|
||||
@ -192,8 +192,8 @@ func (db *PostgresDB) ListUsers() ([]User, error) {
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (db *PostgresDB) GetUser(username string) (*User, error) {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
|
||||
func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
user := &User{Username: username}
|
||||
@ -210,8 +210,8 @@ func (db *PostgresDB) GetUser(username string) (*User, error) {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (db *PostgresDB) StoreUser(user *User) error {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
|
||||
func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
password := toNullString(user.Password)
|
||||
@ -234,16 +234,16 @@ func (db *PostgresDB) StoreUser(user *User) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *PostgresDB) DeleteUser(id int64) error {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
|
||||
func (db *PostgresDB) DeleteUser(ctx context.Context, id int64) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
_, err := db.db.ExecContext(ctx, `DELETE FROM "User" WHERE id = $1`, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *PostgresDB) ListNetworks(userID int64) ([]Network, error) {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
|
||||
func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
rows, err := db.db.QueryContext(ctx, `
|
||||
@ -286,8 +286,8 @@ func (db *PostgresDB) ListNetworks(userID int64) ([]Network, error) {
|
||||
return networks, nil
|
||||
}
|
||||
|
||||
func (db *PostgresDB) StoreNetwork(userID int64, network *Network) error {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
|
||||
func (db *PostgresDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
netName := toNullString(network.Name)
|
||||
@ -338,16 +338,16 @@ func (db *PostgresDB) StoreNetwork(userID int64, network *Network) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *PostgresDB) DeleteNetwork(id int64) error {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
|
||||
func (db *PostgresDB) DeleteNetwork(ctx context.Context, id int64) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
_, err := db.db.ExecContext(ctx, `DELETE FROM "Network" WHERE id = $1`, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *PostgresDB) ListChannels(networkID int64) ([]Channel, error) {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
|
||||
func (db *PostgresDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
rows, err := db.db.QueryContext(ctx, `
|
||||
@ -380,8 +380,8 @@ func (db *PostgresDB) ListChannels(networkID int64) ([]Channel, error) {
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
func (db *PostgresDB) StoreChannel(networkID int64, ch *Channel) error {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
|
||||
func (db *PostgresDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
key := toNullString(ch.Key)
|
||||
@ -408,16 +408,16 @@ func (db *PostgresDB) StoreChannel(networkID int64, ch *Channel) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *PostgresDB) DeleteChannel(id int64) error {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
|
||||
func (db *PostgresDB) DeleteChannel(ctx context.Context, id int64) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
_, err := db.db.ExecContext(ctx, `DELETE FROM "Channel" WHERE id = $1`, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *PostgresDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
|
||||
func (db *PostgresDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
rows, err := db.db.QueryContext(ctx, `
|
||||
@ -444,8 +444,8 @@ func (db *PostgresDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt,
|
||||
return receipts, nil
|
||||
}
|
||||
|
||||
func (db *PostgresDB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
|
||||
func (db *PostgresDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
tx, err := db.db.Begin()
|
||||
|
52
db_sqlite.go
52
db_sqlite.go
@ -208,11 +208,11 @@ func (db *SqliteDB) upgrade() error {
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (db *SqliteDB) Stats() (*DatabaseStats, error) {
|
||||
func (db *SqliteDB) Stats(ctx context.Context) (*DatabaseStats, error) {
|
||||
db.lock.RLock()
|
||||
defer db.lock.RUnlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
var stats DatabaseStats
|
||||
@ -234,11 +234,11 @@ func toNullString(s string) sql.NullString {
|
||||
}
|
||||
}
|
||||
|
||||
func (db *SqliteDB) ListUsers() ([]User, error) {
|
||||
func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) {
|
||||
db.lock.RLock()
|
||||
defer db.lock.RUnlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
rows, err := db.db.QueryContext(ctx,
|
||||
@ -266,11 +266,11 @@ func (db *SqliteDB) ListUsers() ([]User, error) {
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (db *SqliteDB) GetUser(username string) (*User, error) {
|
||||
func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error) {
|
||||
db.lock.RLock()
|
||||
defer db.lock.RUnlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
user := &User{Username: username}
|
||||
@ -287,11 +287,11 @@ func (db *SqliteDB) GetUser(username string) (*User, error) {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (db *SqliteDB) StoreUser(user *User) error {
|
||||
func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
|
||||
db.lock.Lock()
|
||||
defer db.lock.Unlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
args := []interface{}{
|
||||
@ -323,11 +323,11 @@ func (db *SqliteDB) StoreUser(user *User) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *SqliteDB) DeleteUser(id int64) error {
|
||||
func (db *SqliteDB) DeleteUser(ctx context.Context, id int64) error {
|
||||
db.lock.Lock()
|
||||
defer db.lock.Unlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
tx, err := db.db.Begin()
|
||||
@ -371,11 +371,11 @@ func (db *SqliteDB) DeleteUser(id int64) error {
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (db *SqliteDB) ListNetworks(userID int64) ([]Network, error) {
|
||||
func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) {
|
||||
db.lock.RLock()
|
||||
defer db.lock.RUnlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
rows, err := db.db.QueryContext(ctx, `
|
||||
@ -420,11 +420,11 @@ func (db *SqliteDB) ListNetworks(userID int64) ([]Network, error) {
|
||||
return networks, nil
|
||||
}
|
||||
|
||||
func (db *SqliteDB) StoreNetwork(userID int64, network *Network) error {
|
||||
func (db *SqliteDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error {
|
||||
db.lock.Lock()
|
||||
defer db.lock.Unlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
|
||||
@ -490,11 +490,11 @@ func (db *SqliteDB) StoreNetwork(userID int64, network *Network) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *SqliteDB) DeleteNetwork(id int64) error {
|
||||
func (db *SqliteDB) DeleteNetwork(ctx context.Context, id int64) error {
|
||||
db.lock.Lock()
|
||||
defer db.lock.Unlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
tx, err := db.db.Begin()
|
||||
@ -521,11 +521,11 @@ func (db *SqliteDB) DeleteNetwork(id int64) error {
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (db *SqliteDB) ListChannels(networkID int64) ([]Channel, error) {
|
||||
func (db *SqliteDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) {
|
||||
db.lock.RLock()
|
||||
defer db.lock.RUnlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
rows, err := db.db.QueryContext(ctx, `SELECT
|
||||
@ -558,11 +558,11 @@ func (db *SqliteDB) ListChannels(networkID int64) ([]Channel, error) {
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
func (db *SqliteDB) StoreChannel(networkID int64, ch *Channel) error {
|
||||
func (db *SqliteDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error {
|
||||
db.lock.Lock()
|
||||
defer db.lock.Unlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
args := []interface{}{
|
||||
@ -598,22 +598,22 @@ func (db *SqliteDB) StoreChannel(networkID int64, ch *Channel) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *SqliteDB) DeleteChannel(id int64) error {
|
||||
func (db *SqliteDB) DeleteChannel(ctx context.Context, id int64) error {
|
||||
db.lock.Lock()
|
||||
defer db.lock.Unlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
_, err := db.db.ExecContext(ctx, "DELETE FROM Channel WHERE id = ?", id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *SqliteDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) {
|
||||
func (db *SqliteDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) {
|
||||
db.lock.RLock()
|
||||
defer db.lock.RUnlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
rows, err := db.db.QueryContext(ctx, `
|
||||
@ -642,11 +642,11 @@ func (db *SqliteDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, er
|
||||
return receipts, nil
|
||||
}
|
||||
|
||||
func (db *SqliteDB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error {
|
||||
func (db *SqliteDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error {
|
||||
db.lock.Lock()
|
||||
defer db.lock.Unlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
tx, err := db.db.Begin()
|
||||
|
@ -1,6 +1,7 @@
|
||||
package soju
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
@ -976,7 +977,7 @@ func unmarshalUsername(rawUsername string) (username, client, network string) {
|
||||
func (dc *downstreamConn) authenticate(username, password string) error {
|
||||
username, clientName, networkName := unmarshalUsername(username)
|
||||
|
||||
u, err := dc.srv.db.GetUser(username)
|
||||
u, err := dc.srv.db.GetUser(context.TODO(), username)
|
||||
if err != nil {
|
||||
dc.logger.Printf("failed authentication for %q: user not found: %v", username, err)
|
||||
return errAuthFailed
|
||||
@ -1377,7 +1378,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
|
||||
return
|
||||
}
|
||||
n.Nick = nick
|
||||
err = dc.srv.db.StoreNetwork(dc.user.ID, &n.Network)
|
||||
err = dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &n.Network)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@ -1427,7 +1428,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
|
||||
})
|
||||
|
||||
n.Realname = storeRealname
|
||||
if err := dc.srv.db.StoreNetwork(dc.user.ID, &n.Network); err != nil {
|
||||
if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &n.Network); err != nil {
|
||||
dc.logger.Printf("failed to store network realname: %v", err)
|
||||
storeErr = err
|
||||
}
|
||||
@ -1516,7 +1517,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
|
||||
}
|
||||
uc.network.channels.SetValue(upstreamName, ch)
|
||||
}
|
||||
if err := dc.srv.db.StoreChannel(uc.network.ID, ch); err != nil {
|
||||
if err := dc.srv.db.StoreChannel(context.TODO(), uc.network.ID, ch); err != nil {
|
||||
dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
|
||||
}
|
||||
}
|
||||
@ -1548,7 +1549,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
|
||||
}
|
||||
uc.network.channels.SetValue(upstreamName, ch)
|
||||
}
|
||||
if err := dc.srv.db.StoreChannel(uc.network.ID, ch); err != nil {
|
||||
if err := dc.srv.db.StoreChannel(context.TODO(), uc.network.ID, ch); err != nil {
|
||||
dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
|
||||
}
|
||||
} else {
|
||||
@ -2445,7 +2446,7 @@ func (dc *downstreamConn) handleNickServPRIVMSG(uc *upstreamConn, text string) {
|
||||
n.SASL.Mechanism = "PLAIN"
|
||||
n.SASL.Plain.Username = username
|
||||
n.SASL.Plain.Password = password
|
||||
if err := dc.srv.db.StoreNetwork(dc.user.ID, &n.Network); err != nil {
|
||||
if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &n.Network); err != nil {
|
||||
dc.logger.Printf("failed to save NickServ credentials: %v", err)
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package soju
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"mime"
|
||||
@ -85,7 +86,7 @@ func (s *Server) prefix() *irc.Prefix {
|
||||
}
|
||||
|
||||
func (s *Server) Start() error {
|
||||
users, err := s.db.ListUsers()
|
||||
users, err := s.db.ListUsers(context.TODO())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -126,7 +127,7 @@ func (s *Server) createUser(user *User) (*user, error) {
|
||||
return nil, fmt.Errorf("user %q already exists", user.Username)
|
||||
}
|
||||
|
||||
err := s.db.StoreUser(user)
|
||||
err := s.db.StoreUser(context.TODO(), user)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create user in db: %v", err)
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package soju
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
@ -43,7 +44,7 @@ func createTestUser(t *testing.T, db Database) *User {
|
||||
}
|
||||
|
||||
record := &User{Username: testUsername, Password: string(hashed)}
|
||||
if err := db.StoreUser(record); err != nil {
|
||||
if err := db.StoreUser(context.TODO(), record); err != nil {
|
||||
t.Fatalf("failed to store test user: %v", err)
|
||||
}
|
||||
|
||||
@ -68,7 +69,7 @@ func createTestUpstream(t *testing.T, db Database, user *User) (*Network, net.Li
|
||||
Nick: user.Username,
|
||||
Enabled: true,
|
||||
}
|
||||
if err := db.StoreNetwork(user.ID, network); err != nil {
|
||||
if err := db.StoreNetwork(context.TODO(), user.ID, network); err != nil {
|
||||
t.Fatalf("failed to store test network: %v", err)
|
||||
}
|
||||
|
||||
|
13
service.go
13
service.go
@ -1,6 +1,7 @@
|
||||
package soju
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
@ -657,7 +658,7 @@ func handleServiceCertFPGenerate(dc *downstreamConn, params []string) error {
|
||||
net.SASL.External.PrivKeyBlob = privKey
|
||||
net.SASL.Mechanism = "EXTERNAL"
|
||||
|
||||
if err := dc.srv.db.StoreNetwork(dc.user.ID, &net.Network); err != nil {
|
||||
if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &net.Network); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -698,7 +699,7 @@ func handleServiceSASLSetPlain(dc *downstreamConn, params []string) error {
|
||||
net.SASL.Plain.Password = params[2]
|
||||
net.SASL.Mechanism = "PLAIN"
|
||||
|
||||
if err := dc.srv.db.StoreNetwork(dc.user.ID, &net.Network); err != nil {
|
||||
if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &net.Network); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -722,7 +723,7 @@ func handleServiceSASLReset(dc *downstreamConn, params []string) error {
|
||||
net.SASL.External.PrivKeyBlob = nil
|
||||
net.SASL.Mechanism = ""
|
||||
|
||||
if err := dc.srv.db.StoreNetwork(dc.user.ID, &net.Network); err != nil {
|
||||
if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &net.Network); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -860,7 +861,7 @@ func handleUserDelete(dc *downstreamConn, params []string) error {
|
||||
|
||||
u.stop()
|
||||
|
||||
if err := dc.srv.db.DeleteUser(u.ID); err != nil {
|
||||
if err := dc.srv.db.DeleteUser(context.TODO(), u.ID); err != nil {
|
||||
return fmt.Errorf("failed to delete user: %v", err)
|
||||
}
|
||||
|
||||
@ -1015,7 +1016,7 @@ func handleServiceChannelUpdate(dc *downstreamConn, params []string) error {
|
||||
|
||||
uc.updateChannelAutoDetach(upstreamName)
|
||||
|
||||
if err := dc.srv.db.StoreChannel(uc.network.ID, ch); err != nil {
|
||||
if err := dc.srv.db.StoreChannel(context.TODO(), uc.network.ID, ch); err != nil {
|
||||
return fmt.Errorf("failed to update channel: %v", err)
|
||||
}
|
||||
|
||||
@ -1024,7 +1025,7 @@ func handleServiceChannelUpdate(dc *downstreamConn, params []string) error {
|
||||
}
|
||||
|
||||
func handleServiceServerStatus(dc *downstreamConn, params []string) error {
|
||||
dbStats, err := dc.user.srv.db.Stats()
|
||||
dbStats, err := dc.user.srv.db.Stats(context.TODO())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package soju
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
@ -1516,7 +1517,7 @@ func (uc *upstreamConn) handleDetachedMessage(ch *Channel, msg *irc.Message) {
|
||||
}
|
||||
if ch.ReattachOn == FilterMessage || (ch.ReattachOn == FilterHighlight && uc.network.isHighlight(msg)) {
|
||||
uc.network.attach(ch)
|
||||
if err := uc.srv.db.StoreChannel(uc.network.ID, ch); err != nil {
|
||||
if err := uc.srv.db.StoreChannel(context.TODO(), uc.network.ID, ch); err != nil {
|
||||
uc.logger.Printf("failed to update channel %q: %v", ch.Name, err)
|
||||
}
|
||||
}
|
||||
|
21
user.go
21
user.go
@ -1,6 +1,7 @@
|
||||
package soju
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
@ -330,7 +331,7 @@ func (net *network) deleteChannel(name string) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := net.user.srv.db.DeleteChannel(ch.ID); err != nil {
|
||||
if err := net.user.srv.db.DeleteChannel(context.TODO(), ch.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
net.channels.Delete(name)
|
||||
@ -367,7 +368,7 @@ func (net *network) storeClientDeliveryReceipts(clientName string) {
|
||||
})
|
||||
})
|
||||
|
||||
if err := net.user.srv.db.StoreClientDeliveryReceipts(net.ID, clientName, receipts); err != nil {
|
||||
if err := net.user.srv.db.StoreClientDeliveryReceipts(context.TODO(), net.ID, clientName, receipts); err != nil {
|
||||
net.logger.Printf("failed to store delivery receipts for client %q: %v", clientName, err)
|
||||
}
|
||||
}
|
||||
@ -487,7 +488,7 @@ func (u *user) run() {
|
||||
close(u.done)
|
||||
}()
|
||||
|
||||
networks, err := u.srv.db.ListNetworks(u.ID)
|
||||
networks, err := u.srv.db.ListNetworks(context.TODO(), u.ID)
|
||||
if err != nil {
|
||||
u.logger.Printf("failed to list networks for user %q: %v", u.Username, err)
|
||||
return
|
||||
@ -495,7 +496,7 @@ func (u *user) run() {
|
||||
|
||||
for _, record := range networks {
|
||||
record := record
|
||||
channels, err := u.srv.db.ListChannels(record.ID)
|
||||
channels, err := u.srv.db.ListChannels(context.TODO(), record.ID)
|
||||
if err != nil {
|
||||
u.logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err)
|
||||
continue
|
||||
@ -505,7 +506,7 @@ func (u *user) run() {
|
||||
u.networks = append(u.networks, network)
|
||||
|
||||
if u.hasPersistentMsgStore() {
|
||||
receipts, err := u.srv.db.ListDeliveryReceipts(record.ID)
|
||||
receipts, err := u.srv.db.ListDeliveryReceipts(context.TODO(), record.ID)
|
||||
if err != nil {
|
||||
u.logger.Printf("failed to load delivery receipts for user %q, network %q: %v", u.Username, network.GetName(), err)
|
||||
return
|
||||
@ -590,7 +591,7 @@ func (u *user) run() {
|
||||
continue
|
||||
}
|
||||
uc.network.detach(c)
|
||||
if err := uc.srv.db.StoreChannel(uc.network.ID, c); err != nil {
|
||||
if err := uc.srv.db.StoreChannel(context.TODO(), uc.network.ID, c); err != nil {
|
||||
u.logger.Printf("failed to store updated detached channel %q: %v", c.Name, err)
|
||||
}
|
||||
case eventDownstreamConnected:
|
||||
@ -779,7 +780,7 @@ func (u *user) createNetwork(record *Network) (*network, error) {
|
||||
}
|
||||
|
||||
network := newNetwork(u, record, nil)
|
||||
err := u.srv.db.StoreNetwork(u.ID, &network.Network)
|
||||
err := u.srv.db.StoreNetwork(context.TODO(), u.ID, &network.Network)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -821,7 +822,7 @@ func (u *user) updateNetwork(record *Network) (*network, error) {
|
||||
panic("tried updating a non-existing network")
|
||||
}
|
||||
|
||||
if err := u.srv.db.StoreNetwork(u.ID, record); err != nil {
|
||||
if err := u.srv.db.StoreNetwork(context.TODO(), u.ID, record); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -888,7 +889,7 @@ func (u *user) deleteNetwork(id int64) error {
|
||||
panic("tried deleting a non-existing network")
|
||||
}
|
||||
|
||||
if err := u.srv.db.DeleteNetwork(network.ID); err != nil {
|
||||
if err := u.srv.db.DeleteNetwork(context.TODO(), network.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -914,7 +915,7 @@ func (u *user) updateUser(record *User) error {
|
||||
}
|
||||
|
||||
realnameUpdated := u.Realname != record.Realname
|
||||
if err := u.srv.db.StoreUser(record); err != nil {
|
||||
if err := u.srv.db.StoreUser(context.TODO(), record); err != nil {
|
||||
return fmt.Errorf("failed to update user %q: %v", u.Username, err)
|
||||
}
|
||||
u.User = *record
|
||||
|
Loading…
Reference in New Issue
Block a user