Add context support to user and network mutations

References: https://todo.sr.ht/~emersion/soju/141
This commit is contained in:
Simon Ser 2021-11-08 19:36:10 +01:00
parent 8b3e5e7465
commit c21202160c
3 changed files with 22 additions and 22 deletions

View File

@ -1134,7 +1134,7 @@ func (dc *downstreamConn) loadNetwork() error {
dc.logger.Printf("auto-saving network %q", dc.networkName)
var err error
network, err = dc.user.createNetwork(&Network{
network, err = dc.user.createNetwork(context.TODO(), &Network{
Addr: dc.networkName,
Nick: nick,
Enabled: true,
@ -1536,7 +1536,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
// Walk the network list as a second step, because updateNetwork
// mutates the original list
for _, record := range needUpdate {
if _, err := dc.user.updateNetwork(&record); err != nil {
if _, err := dc.user.updateNetwork(ctx, &record); err != nil {
dc.logger.Printf("failed to update network realname: %v", err)
storeErr = err
}
@ -1655,7 +1655,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
Params: params,
})
if err := uc.network.deleteChannel(upstreamName); err != nil {
if err := uc.network.deleteChannel(ctx, upstreamName); err != nil {
dc.logger.Printf("failed to delete channel %q: %v", upstreamName, err)
}
}
@ -2441,7 +2441,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
record.Realname = ""
}
network, err := dc.user.createNetwork(record)
network, err := dc.user.createNetwork(ctx, record)
if err != nil {
return ircError{&irc.Message{
Command: "FAIL",
@ -2485,7 +2485,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
record.Realname = ""
}
_, err = dc.user.updateNetwork(&record)
_, err = dc.user.updateNetwork(ctx, &record)
if err != nil {
return ircError{&irc.Message{
Command: "FAIL",
@ -2516,7 +2516,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
}}
}
if err := dc.user.deleteNetwork(net.ID); err != nil {
if err := dc.user.deleteNetwork(ctx, net.ID); err != nil {
return err
}

View File

@ -490,7 +490,7 @@ func handleServiceNetworkCreate(dc *downstreamConn, params []string) error {
return err
}
network, err := dc.user.createNetwork(record)
network, err := dc.user.createNetwork(context.TODO(), record)
if err != nil {
return fmt.Errorf("could not create network: %v", err)
}
@ -565,7 +565,7 @@ func handleServiceNetworkUpdate(dc *downstreamConn, params []string) error {
return err
}
network, err := dc.user.updateNetwork(&record)
network, err := dc.user.updateNetwork(context.TODO(), &record)
if err != nil {
return fmt.Errorf("could not update network: %v", err)
}
@ -584,7 +584,7 @@ func handleServiceNetworkDelete(dc *downstreamConn, params []string) error {
return fmt.Errorf("unknown network %q", params[0])
}
if err := dc.user.deleteNetwork(net.ID); err != nil {
if err := dc.user.deleteNetwork(context.TODO(), net.ID); err != nil {
return err
}
@ -837,7 +837,7 @@ func handleUserUpdate(dc *downstreamConn, params []string) error {
return fmt.Errorf("cannot update -admin of own user")
}
if err := dc.user.updateUser(&record); err != nil {
if err := dc.user.updateUser(context.TODO(), &record); err != nil {
return err
}

24
user.go
View File

@ -319,7 +319,7 @@ func (net *network) attach(ch *Channel) {
})
}
func (net *network) deleteChannel(name string) error {
func (net *network) deleteChannel(ctx context.Context, name string) error {
ch := net.channels.Value(name)
if ch == nil {
return fmt.Errorf("unknown channel %q", name)
@ -331,7 +331,7 @@ func (net *network) deleteChannel(name string) error {
}
}
if err := net.user.srv.db.DeleteChannel(context.TODO(), ch.ID); err != nil {
if err := net.user.srv.db.DeleteChannel(ctx, ch.ID); err != nil {
return err
}
net.channels.Delete(name)
@ -660,7 +660,7 @@ func (u *user) run() {
record.Admin = *e.admin
}
e.done <- u.updateUser(&record)
e.done <- u.updateUser(context.TODO(), &record)
// If the password was updated, kill all downstream connections to
// force them to re-authenticate with the new credentials.
@ -766,7 +766,7 @@ func (u *user) checkNetwork(record *Network) error {
return nil
}
func (u *user) createNetwork(record *Network) (*network, error) {
func (u *user) createNetwork(ctx context.Context, record *Network) (*network, error) {
if record.ID != 0 {
panic("tried creating an already-existing network")
}
@ -780,7 +780,7 @@ func (u *user) createNetwork(record *Network) (*network, error) {
}
network := newNetwork(u, record, nil)
err := u.srv.db.StoreNetwork(context.TODO(), u.ID, &network.Network)
err := u.srv.db.StoreNetwork(ctx, u.ID, &network.Network)
if err != nil {
return nil, err
}
@ -802,7 +802,7 @@ func (u *user) createNetwork(record *Network) (*network, error) {
return network, nil
}
func (u *user) updateNetwork(record *Network) (*network, error) {
func (u *user) updateNetwork(ctx context.Context, record *Network) (*network, error) {
if record.ID == 0 {
panic("tried updating a new network")
}
@ -822,7 +822,7 @@ func (u *user) updateNetwork(record *Network) (*network, error) {
panic("tried updating a non-existing network")
}
if err := u.srv.db.StoreNetwork(context.TODO(), u.ID, record); err != nil {
if err := u.srv.db.StoreNetwork(ctx, u.ID, record); err != nil {
return nil, err
}
@ -883,13 +883,13 @@ func (u *user) updateNetwork(record *Network) (*network, error) {
return updatedNetwork, nil
}
func (u *user) deleteNetwork(id int64) error {
func (u *user) deleteNetwork(ctx context.Context, id int64) error {
network := u.getNetworkByID(id)
if network == nil {
panic("tried deleting a non-existing network")
}
if err := u.srv.db.DeleteNetwork(context.TODO(), network.ID); err != nil {
if err := u.srv.db.DeleteNetwork(ctx, network.ID); err != nil {
return err
}
@ -909,13 +909,13 @@ func (u *user) deleteNetwork(id int64) error {
return nil
}
func (u *user) updateUser(record *User) error {
func (u *user) updateUser(ctx context.Context, record *User) error {
if u.ID != record.ID {
panic("ID mismatch when updating user")
}
realnameUpdated := u.Realname != record.Realname
if err := u.srv.db.StoreUser(context.TODO(), record); err != nil {
if err := u.srv.db.StoreUser(ctx, record); err != nil {
return fmt.Errorf("failed to update user %q: %v", u.Username, err)
}
u.User = *record
@ -931,7 +931,7 @@ func (u *user) updateUser(record *User) error {
var netErr error
for _, net := range needUpdate {
if _, err := u.updateNetwork(&net); err != nil {
if _, err := u.updateNetwork(ctx, &net); err != nil {
netErr = err
}
}