Add context support to user and network mutations

References: https://todo.sr.ht/~emersion/soju/141
This commit is contained in:
Simon Ser 2021-11-08 19:36:10 +01:00
parent 8b3e5e7465
commit c21202160c
3 changed files with 22 additions and 22 deletions

View File

@ -1134,7 +1134,7 @@ func (dc *downstreamConn) loadNetwork() error {
dc.logger.Printf("auto-saving network %q", dc.networkName) dc.logger.Printf("auto-saving network %q", dc.networkName)
var err error var err error
network, err = dc.user.createNetwork(&Network{ network, err = dc.user.createNetwork(context.TODO(), &Network{
Addr: dc.networkName, Addr: dc.networkName,
Nick: nick, Nick: nick,
Enabled: true, Enabled: true,
@ -1536,7 +1536,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
// Walk the network list as a second step, because updateNetwork // Walk the network list as a second step, because updateNetwork
// mutates the original list // mutates the original list
for _, record := range needUpdate { for _, record := range needUpdate {
if _, err := dc.user.updateNetwork(&record); err != nil { if _, err := dc.user.updateNetwork(ctx, &record); err != nil {
dc.logger.Printf("failed to update network realname: %v", err) dc.logger.Printf("failed to update network realname: %v", err)
storeErr = err storeErr = err
} }
@ -1655,7 +1655,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
Params: params, Params: params,
}) })
if err := uc.network.deleteChannel(upstreamName); err != nil { if err := uc.network.deleteChannel(ctx, upstreamName); err != nil {
dc.logger.Printf("failed to delete channel %q: %v", upstreamName, err) dc.logger.Printf("failed to delete channel %q: %v", upstreamName, err)
} }
} }
@ -2441,7 +2441,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
record.Realname = "" record.Realname = ""
} }
network, err := dc.user.createNetwork(record) network, err := dc.user.createNetwork(ctx, record)
if err != nil { if err != nil {
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: "FAIL", Command: "FAIL",
@ -2485,7 +2485,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
record.Realname = "" record.Realname = ""
} }
_, err = dc.user.updateNetwork(&record) _, err = dc.user.updateNetwork(ctx, &record)
if err != nil { if err != nil {
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: "FAIL", Command: "FAIL",
@ -2516,7 +2516,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
}} }}
} }
if err := dc.user.deleteNetwork(net.ID); err != nil { if err := dc.user.deleteNetwork(ctx, net.ID); err != nil {
return err return err
} }

View File

@ -490,7 +490,7 @@ func handleServiceNetworkCreate(dc *downstreamConn, params []string) error {
return err return err
} }
network, err := dc.user.createNetwork(record) network, err := dc.user.createNetwork(context.TODO(), record)
if err != nil { if err != nil {
return fmt.Errorf("could not create network: %v", err) return fmt.Errorf("could not create network: %v", err)
} }
@ -565,7 +565,7 @@ func handleServiceNetworkUpdate(dc *downstreamConn, params []string) error {
return err return err
} }
network, err := dc.user.updateNetwork(&record) network, err := dc.user.updateNetwork(context.TODO(), &record)
if err != nil { if err != nil {
return fmt.Errorf("could not update network: %v", err) return fmt.Errorf("could not update network: %v", err)
} }
@ -584,7 +584,7 @@ func handleServiceNetworkDelete(dc *downstreamConn, params []string) error {
return fmt.Errorf("unknown network %q", params[0]) return fmt.Errorf("unknown network %q", params[0])
} }
if err := dc.user.deleteNetwork(net.ID); err != nil { if err := dc.user.deleteNetwork(context.TODO(), net.ID); err != nil {
return err return err
} }
@ -837,7 +837,7 @@ func handleUserUpdate(dc *downstreamConn, params []string) error {
return fmt.Errorf("cannot update -admin of own user") return fmt.Errorf("cannot update -admin of own user")
} }
if err := dc.user.updateUser(&record); err != nil { if err := dc.user.updateUser(context.TODO(), &record); err != nil {
return err return err
} }

24
user.go
View File

@ -319,7 +319,7 @@ func (net *network) attach(ch *Channel) {
}) })
} }
func (net *network) deleteChannel(name string) error { func (net *network) deleteChannel(ctx context.Context, name string) error {
ch := net.channels.Value(name) ch := net.channels.Value(name)
if ch == nil { if ch == nil {
return fmt.Errorf("unknown channel %q", name) return fmt.Errorf("unknown channel %q", name)
@ -331,7 +331,7 @@ func (net *network) deleteChannel(name string) error {
} }
} }
if err := net.user.srv.db.DeleteChannel(context.TODO(), ch.ID); err != nil { if err := net.user.srv.db.DeleteChannel(ctx, ch.ID); err != nil {
return err return err
} }
net.channels.Delete(name) net.channels.Delete(name)
@ -660,7 +660,7 @@ func (u *user) run() {
record.Admin = *e.admin record.Admin = *e.admin
} }
e.done <- u.updateUser(&record) e.done <- u.updateUser(context.TODO(), &record)
// If the password was updated, kill all downstream connections to // If the password was updated, kill all downstream connections to
// force them to re-authenticate with the new credentials. // force them to re-authenticate with the new credentials.
@ -766,7 +766,7 @@ func (u *user) checkNetwork(record *Network) error {
return nil return nil
} }
func (u *user) createNetwork(record *Network) (*network, error) { func (u *user) createNetwork(ctx context.Context, record *Network) (*network, error) {
if record.ID != 0 { if record.ID != 0 {
panic("tried creating an already-existing network") panic("tried creating an already-existing network")
} }
@ -780,7 +780,7 @@ func (u *user) createNetwork(record *Network) (*network, error) {
} }
network := newNetwork(u, record, nil) network := newNetwork(u, record, nil)
err := u.srv.db.StoreNetwork(context.TODO(), u.ID, &network.Network) err := u.srv.db.StoreNetwork(ctx, u.ID, &network.Network)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -802,7 +802,7 @@ func (u *user) createNetwork(record *Network) (*network, error) {
return network, nil return network, nil
} }
func (u *user) updateNetwork(record *Network) (*network, error) { func (u *user) updateNetwork(ctx context.Context, record *Network) (*network, error) {
if record.ID == 0 { if record.ID == 0 {
panic("tried updating a new network") panic("tried updating a new network")
} }
@ -822,7 +822,7 @@ func (u *user) updateNetwork(record *Network) (*network, error) {
panic("tried updating a non-existing network") panic("tried updating a non-existing network")
} }
if err := u.srv.db.StoreNetwork(context.TODO(), u.ID, record); err != nil { if err := u.srv.db.StoreNetwork(ctx, u.ID, record); err != nil {
return nil, err return nil, err
} }
@ -883,13 +883,13 @@ func (u *user) updateNetwork(record *Network) (*network, error) {
return updatedNetwork, nil return updatedNetwork, nil
} }
func (u *user) deleteNetwork(id int64) error { func (u *user) deleteNetwork(ctx context.Context, id int64) error {
network := u.getNetworkByID(id) network := u.getNetworkByID(id)
if network == nil { if network == nil {
panic("tried deleting a non-existing network") panic("tried deleting a non-existing network")
} }
if err := u.srv.db.DeleteNetwork(context.TODO(), network.ID); err != nil { if err := u.srv.db.DeleteNetwork(ctx, network.ID); err != nil {
return err return err
} }
@ -909,13 +909,13 @@ func (u *user) deleteNetwork(id int64) error {
return nil return nil
} }
func (u *user) updateUser(record *User) error { func (u *user) updateUser(ctx context.Context, record *User) error {
if u.ID != record.ID { if u.ID != record.ID {
panic("ID mismatch when updating user") panic("ID mismatch when updating user")
} }
realnameUpdated := u.Realname != record.Realname realnameUpdated := u.Realname != record.Realname
if err := u.srv.db.StoreUser(context.TODO(), record); err != nil { if err := u.srv.db.StoreUser(ctx, record); err != nil {
return fmt.Errorf("failed to update user %q: %v", u.Username, err) return fmt.Errorf("failed to update user %q: %v", u.Username, err)
} }
u.User = *record u.User = *record
@ -931,7 +931,7 @@ func (u *user) updateUser(record *User) error {
var netErr error var netErr error
for _, net := range needUpdate { for _, net := range needUpdate {
if _, err := u.updateNetwork(&net); err != nil { if _, err := u.updateNetwork(ctx, &net); err != nil {
netErr = err netErr = err
} }
} }