From c21202160cd827b2711fcbab39ffce3d93f70861 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Mon, 8 Nov 2021 19:36:10 +0100 Subject: [PATCH] Add context support to user and network mutations References: https://todo.sr.ht/~emersion/soju/141 --- downstream.go | 12 ++++++------ service.go | 8 ++++---- user.go | 24 ++++++++++++------------ 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/downstream.go b/downstream.go index 016f84f..46e2015 100644 --- a/downstream.go +++ b/downstream.go @@ -1134,7 +1134,7 @@ func (dc *downstreamConn) loadNetwork() error { dc.logger.Printf("auto-saving network %q", dc.networkName) var err error - network, err = dc.user.createNetwork(&Network{ + network, err = dc.user.createNetwork(context.TODO(), &Network{ Addr: dc.networkName, Nick: nick, Enabled: true, @@ -1536,7 +1536,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { // Walk the network list as a second step, because updateNetwork // mutates the original list for _, record := range needUpdate { - if _, err := dc.user.updateNetwork(&record); err != nil { + if _, err := dc.user.updateNetwork(ctx, &record); err != nil { dc.logger.Printf("failed to update network realname: %v", err) storeErr = err } @@ -1655,7 +1655,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { Params: params, }) - if err := uc.network.deleteChannel(upstreamName); err != nil { + if err := uc.network.deleteChannel(ctx, upstreamName); err != nil { dc.logger.Printf("failed to delete channel %q: %v", upstreamName, err) } } @@ -2441,7 +2441,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { record.Realname = "" } - network, err := dc.user.createNetwork(record) + network, err := dc.user.createNetwork(ctx, record) if err != nil { return ircError{&irc.Message{ Command: "FAIL", @@ -2485,7 +2485,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { record.Realname = "" } - _, err = dc.user.updateNetwork(&record) + _, err = dc.user.updateNetwork(ctx, &record) if err != nil { return ircError{&irc.Message{ Command: "FAIL", @@ -2516,7 +2516,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { }} } - if err := dc.user.deleteNetwork(net.ID); err != nil { + if err := dc.user.deleteNetwork(ctx, net.ID); err != nil { return err } diff --git a/service.go b/service.go index a596c14..93d2dba 100644 --- a/service.go +++ b/service.go @@ -490,7 +490,7 @@ func handleServiceNetworkCreate(dc *downstreamConn, params []string) error { return err } - network, err := dc.user.createNetwork(record) + network, err := dc.user.createNetwork(context.TODO(), record) if err != nil { return fmt.Errorf("could not create network: %v", err) } @@ -565,7 +565,7 @@ func handleServiceNetworkUpdate(dc *downstreamConn, params []string) error { return err } - network, err := dc.user.updateNetwork(&record) + network, err := dc.user.updateNetwork(context.TODO(), &record) if err != nil { return fmt.Errorf("could not update network: %v", err) } @@ -584,7 +584,7 @@ func handleServiceNetworkDelete(dc *downstreamConn, params []string) error { return fmt.Errorf("unknown network %q", params[0]) } - if err := dc.user.deleteNetwork(net.ID); err != nil { + if err := dc.user.deleteNetwork(context.TODO(), net.ID); err != nil { return err } @@ -837,7 +837,7 @@ func handleUserUpdate(dc *downstreamConn, params []string) error { return fmt.Errorf("cannot update -admin of own user") } - if err := dc.user.updateUser(&record); err != nil { + if err := dc.user.updateUser(context.TODO(), &record); err != nil { return err } diff --git a/user.go b/user.go index 8b142c6..f3db7eb 100644 --- a/user.go +++ b/user.go @@ -319,7 +319,7 @@ func (net *network) attach(ch *Channel) { }) } -func (net *network) deleteChannel(name string) error { +func (net *network) deleteChannel(ctx context.Context, name string) error { ch := net.channels.Value(name) if ch == nil { return fmt.Errorf("unknown channel %q", name) @@ -331,7 +331,7 @@ func (net *network) deleteChannel(name string) error { } } - if err := net.user.srv.db.DeleteChannel(context.TODO(), ch.ID); err != nil { + if err := net.user.srv.db.DeleteChannel(ctx, ch.ID); err != nil { return err } net.channels.Delete(name) @@ -660,7 +660,7 @@ func (u *user) run() { record.Admin = *e.admin } - e.done <- u.updateUser(&record) + e.done <- u.updateUser(context.TODO(), &record) // If the password was updated, kill all downstream connections to // force them to re-authenticate with the new credentials. @@ -766,7 +766,7 @@ func (u *user) checkNetwork(record *Network) error { return nil } -func (u *user) createNetwork(record *Network) (*network, error) { +func (u *user) createNetwork(ctx context.Context, record *Network) (*network, error) { if record.ID != 0 { panic("tried creating an already-existing network") } @@ -780,7 +780,7 @@ func (u *user) createNetwork(record *Network) (*network, error) { } network := newNetwork(u, record, nil) - err := u.srv.db.StoreNetwork(context.TODO(), u.ID, &network.Network) + err := u.srv.db.StoreNetwork(ctx, u.ID, &network.Network) if err != nil { return nil, err } @@ -802,7 +802,7 @@ func (u *user) createNetwork(record *Network) (*network, error) { return network, nil } -func (u *user) updateNetwork(record *Network) (*network, error) { +func (u *user) updateNetwork(ctx context.Context, record *Network) (*network, error) { if record.ID == 0 { panic("tried updating a new network") } @@ -822,7 +822,7 @@ func (u *user) updateNetwork(record *Network) (*network, error) { panic("tried updating a non-existing network") } - if err := u.srv.db.StoreNetwork(context.TODO(), u.ID, record); err != nil { + if err := u.srv.db.StoreNetwork(ctx, u.ID, record); err != nil { return nil, err } @@ -883,13 +883,13 @@ func (u *user) updateNetwork(record *Network) (*network, error) { return updatedNetwork, nil } -func (u *user) deleteNetwork(id int64) error { +func (u *user) deleteNetwork(ctx context.Context, id int64) error { network := u.getNetworkByID(id) if network == nil { panic("tried deleting a non-existing network") } - if err := u.srv.db.DeleteNetwork(context.TODO(), network.ID); err != nil { + if err := u.srv.db.DeleteNetwork(ctx, network.ID); err != nil { return err } @@ -909,13 +909,13 @@ func (u *user) deleteNetwork(id int64) error { return nil } -func (u *user) updateUser(record *User) error { +func (u *user) updateUser(ctx context.Context, record *User) error { if u.ID != record.ID { panic("ID mismatch when updating user") } realnameUpdated := u.Realname != record.Realname - if err := u.srv.db.StoreUser(context.TODO(), record); err != nil { + if err := u.srv.db.StoreUser(ctx, record); err != nil { return fmt.Errorf("failed to update user %q: %v", u.Username, err) } u.User = *record @@ -931,7 +931,7 @@ func (u *user) updateUser(record *User) error { var netErr error for _, net := range needUpdate { - if _, err := u.updateNetwork(&net); err != nil { + if _, err := u.updateNetwork(ctx, &net); err != nil { netErr = err } }