From 9ec1f1a5b03e0b749d9b642c8e9067e779cc6200 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Mon, 18 Oct 2021 19:15:15 +0200 Subject: [PATCH] Add context args to Database interface This is a mecanical change, which just lifts up the context.TODO() calls from inside the DB implementations to the callers. Future work involves properly wiring up the contexts when it makes sense. --- cmd/sojuctl/main.go | 7 +++--- contrib/znc-import.go | 13 ++++++----- db.go | 27 +++++++++++----------- db_postgres.go | 52 +++++++++++++++++++++---------------------- db_sqlite.go | 52 +++++++++++++++++++++---------------------- downstream.go | 13 ++++++----- server.go | 5 +++-- server_test.go | 5 +++-- service.go | 13 ++++++----- upstream.go | 3 ++- user.go | 21 ++++++++--------- 11 files changed, 110 insertions(+), 101 deletions(-) diff --git a/cmd/sojuctl/main.go b/cmd/sojuctl/main.go index d19ccfb..8eb79b4 100644 --- a/cmd/sojuctl/main.go +++ b/cmd/sojuctl/main.go @@ -2,6 +2,7 @@ package main import ( "bufio" + "context" "flag" "fmt" "io" @@ -75,7 +76,7 @@ func main() { Password: string(hashed), Admin: *admin, } - if err := db.StoreUser(&user); err != nil { + if err := db.StoreUser(context.TODO(), &user); err != nil { log.Fatalf("failed to create user: %v", err) } case "change-password": @@ -85,7 +86,7 @@ func main() { os.Exit(1) } - user, err := db.GetUser(username) + user, err := db.GetUser(context.TODO(), username) if err != nil { log.Fatalf("failed to get user: %v", err) } @@ -101,7 +102,7 @@ func main() { } user.Password = string(hashed) - if err := db.StoreUser(user); err != nil { + if err := db.StoreUser(context.TODO(), user); err != nil { log.Fatalf("failed to update password: %v", err) } default: diff --git a/contrib/znc-import.go b/contrib/znc-import.go index 8dd02ed..a4b870b 100644 --- a/contrib/znc-import.go +++ b/contrib/znc-import.go @@ -2,6 +2,7 @@ package main import ( "bufio" + "context" "flag" "fmt" "io" @@ -79,7 +80,7 @@ func main() { log.Fatalf("failed to parse %q: line %v: %v", zncPath, zp.line, err) } - l, err := db.ListUsers() + l, err := db.ListUsers(context.TODO()) if err != nil { log.Fatalf("failed to list users in DB: %v", err) } @@ -111,12 +112,12 @@ func main() { u.Admin = section.Values.Get("Admin") == "true" - if err := db.StoreUser(u); err != nil { + if err := db.StoreUser(context.TODO(), u); err != nil { log.Fatalf("failed to store user %q: %v", username, err) } userID := u.ID - l, err := db.ListNetworks(userID) + l, err := db.ListNetworks(context.TODO(), userID) if err != nil { log.Fatalf("failed to list networks for user %q: %v", username, err) } @@ -183,11 +184,11 @@ func main() { n.Pass = pass n.Enabled = section.Values.Get("IRCConnectEnabled") != "false" - if err := db.StoreNetwork(userID, n); err != nil { + if err := db.StoreNetwork(context.TODO(), userID, n); err != nil { logger.Fatalf("failed to store network: %v", err) } - l, err := db.ListChannels(n.ID) + l, err := db.ListChannels(context.TODO(), n.ID) if err != nil { logger.Fatalf("failed to list channels: %v", err) } @@ -217,7 +218,7 @@ func main() { ch.Key = section.Values.Get("Key") ch.Detached = section.Values.Get("Detached") == "true" - if err := db.StoreChannel(n.ID, ch); err != nil { + if err := db.StoreChannel(context.TODO(), n.ID, ch); err != nil { logger.Printf("channel %q: failed to store channel: %v", chName, err) } }) diff --git a/db.go b/db.go index 703d993..d30352d 100644 --- a/db.go +++ b/db.go @@ -1,6 +1,7 @@ package soju import ( + "context" "fmt" "net/url" "strings" @@ -9,22 +10,22 @@ import ( type Database interface { Close() error - Stats() (*DatabaseStats, error) + Stats(ctx context.Context) (*DatabaseStats, error) - ListUsers() ([]User, error) - GetUser(username string) (*User, error) - StoreUser(user *User) error - DeleteUser(id int64) error + ListUsers(ctx context.Context) ([]User, error) + GetUser(ctx context.Context, username string) (*User, error) + StoreUser(ctx context.Context, user *User) error + DeleteUser(ctx context.Context, id int64) error - ListNetworks(userID int64) ([]Network, error) - StoreNetwork(userID int64, network *Network) error - DeleteNetwork(id int64) error - ListChannels(networkID int64) ([]Channel, error) - StoreChannel(networKID int64, ch *Channel) error - DeleteChannel(id int64) error + ListNetworks(ctx context.Context, userID int64) ([]Network, error) + StoreNetwork(ctx context.Context, userID int64, network *Network) error + DeleteNetwork(ctx context.Context, id int64) error + ListChannels(ctx context.Context, networkID int64) ([]Channel, error) + StoreChannel(ctx context.Context, networKID int64, ch *Channel) error + DeleteChannel(ctx context.Context, id int64) error - ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) - StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error + ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) + StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error } func OpenDB(driver, source string) (Database, error) { diff --git a/db_postgres.go b/db_postgres.go index cad8b55..d4d4b11 100644 --- a/db_postgres.go +++ b/db_postgres.go @@ -147,8 +147,8 @@ func (db *PostgresDB) Close() error { return db.db.Close() } -func (db *PostgresDB) Stats() (*DatabaseStats, error) { - ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) +func (db *PostgresDB) Stats(ctx context.Context) (*DatabaseStats, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) defer cancel() var stats DatabaseStats @@ -163,8 +163,8 @@ func (db *PostgresDB) Stats() (*DatabaseStats, error) { return &stats, nil } -func (db *PostgresDB) ListUsers() ([]User, error) { - ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) +func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) defer cancel() rows, err := db.db.QueryContext(ctx, @@ -192,8 +192,8 @@ func (db *PostgresDB) ListUsers() ([]User, error) { return users, nil } -func (db *PostgresDB) GetUser(username string) (*User, error) { - ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) +func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) defer cancel() user := &User{Username: username} @@ -210,8 +210,8 @@ func (db *PostgresDB) GetUser(username string) (*User, error) { return user, nil } -func (db *PostgresDB) StoreUser(user *User) error { - ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) +func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) defer cancel() password := toNullString(user.Password) @@ -234,16 +234,16 @@ func (db *PostgresDB) StoreUser(user *User) error { return err } -func (db *PostgresDB) DeleteUser(id int64) error { - ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) +func (db *PostgresDB) DeleteUser(ctx context.Context, id int64) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) defer cancel() _, err := db.db.ExecContext(ctx, `DELETE FROM "User" WHERE id = $1`, id) return err } -func (db *PostgresDB) ListNetworks(userID int64) ([]Network, error) { - ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) +func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) defer cancel() rows, err := db.db.QueryContext(ctx, ` @@ -286,8 +286,8 @@ func (db *PostgresDB) ListNetworks(userID int64) ([]Network, error) { return networks, nil } -func (db *PostgresDB) StoreNetwork(userID int64, network *Network) error { - ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) +func (db *PostgresDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) defer cancel() netName := toNullString(network.Name) @@ -338,16 +338,16 @@ func (db *PostgresDB) StoreNetwork(userID int64, network *Network) error { return err } -func (db *PostgresDB) DeleteNetwork(id int64) error { - ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) +func (db *PostgresDB) DeleteNetwork(ctx context.Context, id int64) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) defer cancel() _, err := db.db.ExecContext(ctx, `DELETE FROM "Network" WHERE id = $1`, id) return err } -func (db *PostgresDB) ListChannels(networkID int64) ([]Channel, error) { - ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) +func (db *PostgresDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) defer cancel() rows, err := db.db.QueryContext(ctx, ` @@ -380,8 +380,8 @@ func (db *PostgresDB) ListChannels(networkID int64) ([]Channel, error) { return channels, nil } -func (db *PostgresDB) StoreChannel(networkID int64, ch *Channel) error { - ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) +func (db *PostgresDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) defer cancel() key := toNullString(ch.Key) @@ -408,16 +408,16 @@ func (db *PostgresDB) StoreChannel(networkID int64, ch *Channel) error { return err } -func (db *PostgresDB) DeleteChannel(id int64) error { - ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) +func (db *PostgresDB) DeleteChannel(ctx context.Context, id int64) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) defer cancel() _, err := db.db.ExecContext(ctx, `DELETE FROM "Channel" WHERE id = $1`, id) return err } -func (db *PostgresDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) { - ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) +func (db *PostgresDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) defer cancel() rows, err := db.db.QueryContext(ctx, ` @@ -444,8 +444,8 @@ func (db *PostgresDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, return receipts, nil } -func (db *PostgresDB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error { - ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) +func (db *PostgresDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) defer cancel() tx, err := db.db.Begin() diff --git a/db_sqlite.go b/db_sqlite.go index f5d2f9a..b4c8f88 100644 --- a/db_sqlite.go +++ b/db_sqlite.go @@ -208,11 +208,11 @@ func (db *SqliteDB) upgrade() error { return tx.Commit() } -func (db *SqliteDB) Stats() (*DatabaseStats, error) { +func (db *SqliteDB) Stats(ctx context.Context) (*DatabaseStats, error) { db.lock.RLock() defer db.lock.RUnlock() - ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) defer cancel() var stats DatabaseStats @@ -234,11 +234,11 @@ func toNullString(s string) sql.NullString { } } -func (db *SqliteDB) ListUsers() ([]User, error) { +func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) { db.lock.RLock() defer db.lock.RUnlock() - ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) defer cancel() rows, err := db.db.QueryContext(ctx, @@ -266,11 +266,11 @@ func (db *SqliteDB) ListUsers() ([]User, error) { return users, nil } -func (db *SqliteDB) GetUser(username string) (*User, error) { +func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error) { db.lock.RLock() defer db.lock.RUnlock() - ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) defer cancel() user := &User{Username: username} @@ -287,11 +287,11 @@ func (db *SqliteDB) GetUser(username string) (*User, error) { return user, nil } -func (db *SqliteDB) StoreUser(user *User) error { +func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error { db.lock.Lock() defer db.lock.Unlock() - ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) defer cancel() args := []interface{}{ @@ -323,11 +323,11 @@ func (db *SqliteDB) StoreUser(user *User) error { return err } -func (db *SqliteDB) DeleteUser(id int64) error { +func (db *SqliteDB) DeleteUser(ctx context.Context, id int64) error { db.lock.Lock() defer db.lock.Unlock() - ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) defer cancel() tx, err := db.db.Begin() @@ -371,11 +371,11 @@ func (db *SqliteDB) DeleteUser(id int64) error { return tx.Commit() } -func (db *SqliteDB) ListNetworks(userID int64) ([]Network, error) { +func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) { db.lock.RLock() defer db.lock.RUnlock() - ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) defer cancel() rows, err := db.db.QueryContext(ctx, ` @@ -420,11 +420,11 @@ func (db *SqliteDB) ListNetworks(userID int64) ([]Network, error) { return networks, nil } -func (db *SqliteDB) StoreNetwork(userID int64, network *Network) error { +func (db *SqliteDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error { db.lock.Lock() defer db.lock.Unlock() - ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) defer cancel() var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString @@ -490,11 +490,11 @@ func (db *SqliteDB) StoreNetwork(userID int64, network *Network) error { return err } -func (db *SqliteDB) DeleteNetwork(id int64) error { +func (db *SqliteDB) DeleteNetwork(ctx context.Context, id int64) error { db.lock.Lock() defer db.lock.Unlock() - ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) defer cancel() tx, err := db.db.Begin() @@ -521,11 +521,11 @@ func (db *SqliteDB) DeleteNetwork(id int64) error { return tx.Commit() } -func (db *SqliteDB) ListChannels(networkID int64) ([]Channel, error) { +func (db *SqliteDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) { db.lock.RLock() defer db.lock.RUnlock() - ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) defer cancel() rows, err := db.db.QueryContext(ctx, `SELECT @@ -558,11 +558,11 @@ func (db *SqliteDB) ListChannels(networkID int64) ([]Channel, error) { return channels, nil } -func (db *SqliteDB) StoreChannel(networkID int64, ch *Channel) error { +func (db *SqliteDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error { db.lock.Lock() defer db.lock.Unlock() - ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) defer cancel() args := []interface{}{ @@ -598,22 +598,22 @@ func (db *SqliteDB) StoreChannel(networkID int64, ch *Channel) error { return err } -func (db *SqliteDB) DeleteChannel(id int64) error { +func (db *SqliteDB) DeleteChannel(ctx context.Context, id int64) error { db.lock.Lock() defer db.lock.Unlock() - ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) defer cancel() _, err := db.db.ExecContext(ctx, "DELETE FROM Channel WHERE id = ?", id) return err } -func (db *SqliteDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) { +func (db *SqliteDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) { db.lock.RLock() defer db.lock.RUnlock() - ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) defer cancel() rows, err := db.db.QueryContext(ctx, ` @@ -642,11 +642,11 @@ func (db *SqliteDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, er return receipts, nil } -func (db *SqliteDB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error { +func (db *SqliteDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error { db.lock.Lock() defer db.lock.Unlock() - ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) defer cancel() tx, err := db.db.Begin() diff --git a/downstream.go b/downstream.go index 31fb5ff..ba536ab 100644 --- a/downstream.go +++ b/downstream.go @@ -1,6 +1,7 @@ package soju import ( + "context" "crypto/tls" "encoding/base64" "fmt" @@ -976,7 +977,7 @@ func unmarshalUsername(rawUsername string) (username, client, network string) { func (dc *downstreamConn) authenticate(username, password string) error { username, clientName, networkName := unmarshalUsername(username) - u, err := dc.srv.db.GetUser(username) + u, err := dc.srv.db.GetUser(context.TODO(), username) if err != nil { dc.logger.Printf("failed authentication for %q: user not found: %v", username, err) return errAuthFailed @@ -1377,7 +1378,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { return } n.Nick = nick - err = dc.srv.db.StoreNetwork(dc.user.ID, &n.Network) + err = dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &n.Network) }) if err != nil { return err @@ -1427,7 +1428,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { }) n.Realname = storeRealname - if err := dc.srv.db.StoreNetwork(dc.user.ID, &n.Network); err != nil { + if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &n.Network); err != nil { dc.logger.Printf("failed to store network realname: %v", err) storeErr = err } @@ -1516,7 +1517,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { } uc.network.channels.SetValue(upstreamName, ch) } - if err := dc.srv.db.StoreChannel(uc.network.ID, ch); err != nil { + if err := dc.srv.db.StoreChannel(context.TODO(), uc.network.ID, ch); err != nil { dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err) } } @@ -1548,7 +1549,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { } uc.network.channels.SetValue(upstreamName, ch) } - if err := dc.srv.db.StoreChannel(uc.network.ID, ch); err != nil { + if err := dc.srv.db.StoreChannel(context.TODO(), uc.network.ID, ch); err != nil { dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err) } } else { @@ -2445,7 +2446,7 @@ func (dc *downstreamConn) handleNickServPRIVMSG(uc *upstreamConn, text string) { n.SASL.Mechanism = "PLAIN" n.SASL.Plain.Username = username n.SASL.Plain.Password = password - if err := dc.srv.db.StoreNetwork(dc.user.ID, &n.Network); err != nil { + if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &n.Network); err != nil { dc.logger.Printf("failed to save NickServ credentials: %v", err) } } diff --git a/server.go b/server.go index bda5cae..80a7e53 100644 --- a/server.go +++ b/server.go @@ -1,6 +1,7 @@ package soju import ( + "context" "fmt" "log" "mime" @@ -85,7 +86,7 @@ func (s *Server) prefix() *irc.Prefix { } func (s *Server) Start() error { - users, err := s.db.ListUsers() + users, err := s.db.ListUsers(context.TODO()) if err != nil { return err } @@ -126,7 +127,7 @@ func (s *Server) createUser(user *User) (*user, error) { return nil, fmt.Errorf("user %q already exists", user.Username) } - err := s.db.StoreUser(user) + err := s.db.StoreUser(context.TODO(), user) if err != nil { return nil, fmt.Errorf("could not create user in db: %v", err) } diff --git a/server_test.go b/server_test.go index 0f20a33..5ed700b 100644 --- a/server_test.go +++ b/server_test.go @@ -1,6 +1,7 @@ package soju import ( + "context" "net" "testing" @@ -43,7 +44,7 @@ func createTestUser(t *testing.T, db Database) *User { } record := &User{Username: testUsername, Password: string(hashed)} - if err := db.StoreUser(record); err != nil { + if err := db.StoreUser(context.TODO(), record); err != nil { t.Fatalf("failed to store test user: %v", err) } @@ -68,7 +69,7 @@ func createTestUpstream(t *testing.T, db Database, user *User) (*Network, net.Li Nick: user.Username, Enabled: true, } - if err := db.StoreNetwork(user.ID, network); err != nil { + if err := db.StoreNetwork(context.TODO(), user.ID, network); err != nil { t.Fatalf("failed to store test network: %v", err) } diff --git a/service.go b/service.go index b7da332..2748c81 100644 --- a/service.go +++ b/service.go @@ -1,6 +1,7 @@ package soju import ( + "context" "crypto/sha1" "crypto/sha256" "crypto/sha512" @@ -657,7 +658,7 @@ func handleServiceCertFPGenerate(dc *downstreamConn, params []string) error { net.SASL.External.PrivKeyBlob = privKey net.SASL.Mechanism = "EXTERNAL" - if err := dc.srv.db.StoreNetwork(dc.user.ID, &net.Network); err != nil { + if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &net.Network); err != nil { return err } @@ -698,7 +699,7 @@ func handleServiceSASLSetPlain(dc *downstreamConn, params []string) error { net.SASL.Plain.Password = params[2] net.SASL.Mechanism = "PLAIN" - if err := dc.srv.db.StoreNetwork(dc.user.ID, &net.Network); err != nil { + if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &net.Network); err != nil { return err } @@ -722,7 +723,7 @@ func handleServiceSASLReset(dc *downstreamConn, params []string) error { net.SASL.External.PrivKeyBlob = nil net.SASL.Mechanism = "" - if err := dc.srv.db.StoreNetwork(dc.user.ID, &net.Network); err != nil { + if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &net.Network); err != nil { return err } @@ -860,7 +861,7 @@ func handleUserDelete(dc *downstreamConn, params []string) error { u.stop() - if err := dc.srv.db.DeleteUser(u.ID); err != nil { + if err := dc.srv.db.DeleteUser(context.TODO(), u.ID); err != nil { return fmt.Errorf("failed to delete user: %v", err) } @@ -1015,7 +1016,7 @@ func handleServiceChannelUpdate(dc *downstreamConn, params []string) error { uc.updateChannelAutoDetach(upstreamName) - if err := dc.srv.db.StoreChannel(uc.network.ID, ch); err != nil { + if err := dc.srv.db.StoreChannel(context.TODO(), uc.network.ID, ch); err != nil { return fmt.Errorf("failed to update channel: %v", err) } @@ -1024,7 +1025,7 @@ func handleServiceChannelUpdate(dc *downstreamConn, params []string) error { } func handleServiceServerStatus(dc *downstreamConn, params []string) error { - dbStats, err := dc.user.srv.db.Stats() + dbStats, err := dc.user.srv.db.Stats(context.TODO()) if err != nil { return err } diff --git a/upstream.go b/upstream.go index 903df8e..1c73059 100644 --- a/upstream.go +++ b/upstream.go @@ -1,6 +1,7 @@ package soju import ( + "context" "crypto" "crypto/sha256" "crypto/tls" @@ -1516,7 +1517,7 @@ func (uc *upstreamConn) handleDetachedMessage(ch *Channel, msg *irc.Message) { } if ch.ReattachOn == FilterMessage || (ch.ReattachOn == FilterHighlight && uc.network.isHighlight(msg)) { uc.network.attach(ch) - if err := uc.srv.db.StoreChannel(uc.network.ID, ch); err != nil { + if err := uc.srv.db.StoreChannel(context.TODO(), uc.network.ID, ch); err != nil { uc.logger.Printf("failed to update channel %q: %v", ch.Name, err) } } diff --git a/user.go b/user.go index 08bc913..99ca941 100644 --- a/user.go +++ b/user.go @@ -1,6 +1,7 @@ package soju import ( + "context" "crypto/sha256" "encoding/binary" "encoding/hex" @@ -330,7 +331,7 @@ func (net *network) deleteChannel(name string) error { } } - if err := net.user.srv.db.DeleteChannel(ch.ID); err != nil { + if err := net.user.srv.db.DeleteChannel(context.TODO(), ch.ID); err != nil { return err } net.channels.Delete(name) @@ -367,7 +368,7 @@ func (net *network) storeClientDeliveryReceipts(clientName string) { }) }) - if err := net.user.srv.db.StoreClientDeliveryReceipts(net.ID, clientName, receipts); err != nil { + if err := net.user.srv.db.StoreClientDeliveryReceipts(context.TODO(), net.ID, clientName, receipts); err != nil { net.logger.Printf("failed to store delivery receipts for client %q: %v", clientName, err) } } @@ -487,7 +488,7 @@ func (u *user) run() { close(u.done) }() - networks, err := u.srv.db.ListNetworks(u.ID) + networks, err := u.srv.db.ListNetworks(context.TODO(), u.ID) if err != nil { u.logger.Printf("failed to list networks for user %q: %v", u.Username, err) return @@ -495,7 +496,7 @@ func (u *user) run() { for _, record := range networks { record := record - channels, err := u.srv.db.ListChannels(record.ID) + channels, err := u.srv.db.ListChannels(context.TODO(), record.ID) if err != nil { u.logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err) continue @@ -505,7 +506,7 @@ func (u *user) run() { u.networks = append(u.networks, network) if u.hasPersistentMsgStore() { - receipts, err := u.srv.db.ListDeliveryReceipts(record.ID) + receipts, err := u.srv.db.ListDeliveryReceipts(context.TODO(), record.ID) if err != nil { u.logger.Printf("failed to load delivery receipts for user %q, network %q: %v", u.Username, network.GetName(), err) return @@ -590,7 +591,7 @@ func (u *user) run() { continue } uc.network.detach(c) - if err := uc.srv.db.StoreChannel(uc.network.ID, c); err != nil { + if err := uc.srv.db.StoreChannel(context.TODO(), uc.network.ID, c); err != nil { u.logger.Printf("failed to store updated detached channel %q: %v", c.Name, err) } case eventDownstreamConnected: @@ -779,7 +780,7 @@ func (u *user) createNetwork(record *Network) (*network, error) { } network := newNetwork(u, record, nil) - err := u.srv.db.StoreNetwork(u.ID, &network.Network) + err := u.srv.db.StoreNetwork(context.TODO(), u.ID, &network.Network) if err != nil { return nil, err } @@ -821,7 +822,7 @@ func (u *user) updateNetwork(record *Network) (*network, error) { panic("tried updating a non-existing network") } - if err := u.srv.db.StoreNetwork(u.ID, record); err != nil { + if err := u.srv.db.StoreNetwork(context.TODO(), u.ID, record); err != nil { return nil, err } @@ -888,7 +889,7 @@ func (u *user) deleteNetwork(id int64) error { panic("tried deleting a non-existing network") } - if err := u.srv.db.DeleteNetwork(network.ID); err != nil { + if err := u.srv.db.DeleteNetwork(context.TODO(), network.ID); err != nil { return err } @@ -914,7 +915,7 @@ func (u *user) updateUser(record *User) error { } realnameUpdated := u.Realname != record.Realname - if err := u.srv.db.StoreUser(record); err != nil { + if err := u.srv.db.StoreUser(context.TODO(), record); err != nil { return fmt.Errorf("failed to update user %q: %v", u.Username, err) } u.User = *record