Introduce UserUpdateFunc

References: https://todo.sr.ht/~emersion/soju/206
This commit is contained in:
Simon Ser 2023-03-01 14:16:33 +01:00
parent 67335130b1
commit aecff32103
3 changed files with 56 additions and 49 deletions

View File

@ -1773,9 +1773,10 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
record.Nick = nick record.Nick = nick
err = dc.srv.db.StoreNetwork(ctx, dc.user.ID, &record) err = dc.srv.db.StoreNetwork(ctx, dc.user.ID, &record)
} else { } else {
record := dc.user.User err = dc.user.updateUser(ctx, func(record *database.User) error {
record.Nick = nick record.Nick = nick
err = dc.user.updateUser(ctx, &record) return nil
})
} }
if err != nil { if err != nil {
dc.logger.Printf("failed to update nick: %v", err) dc.logger.Printf("failed to update nick: %v", err)
@ -1840,9 +1841,10 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
_, err = dc.user.updateNetwork(ctx, &record) _, err = dc.user.updateNetwork(ctx, &record)
} }
} else { } else {
record := dc.user.User err = dc.user.updateUser(ctx, func(record *database.User) error {
record.Realname = realname record.Realname = realname
err = dc.user.updateUser(ctx, &record) return nil
})
} }
if err != nil { if err != nil {

View File

@ -1066,23 +1066,6 @@ func handleUserUpdate(ctx *serviceContext, params []string) error {
ctx.print(fmt.Sprintf("updated user %q", username)) ctx.print(fmt.Sprintf("updated user %q", username))
} else { } else {
// copy the user record because we'll mutate it
record := ctx.user.User
if password != nil {
if err := record.SetPassword(*password); err != nil {
return err
}
}
if disablePassword {
record.Password = ""
}
if nick != nil {
record.Nick = *nick
}
if realname != nil {
record.Realname = *realname
}
if admin != nil { if admin != nil {
return fmt.Errorf("cannot update -admin of own user") return fmt.Errorf("cannot update -admin of own user")
} }
@ -1090,7 +1073,24 @@ func handleUserUpdate(ctx *serviceContext, params []string) error {
return fmt.Errorf("cannot update -enabled of own user") return fmt.Errorf("cannot update -enabled of own user")
} }
if err := ctx.user.updateUser(ctx, &record); err != nil { err := ctx.user.updateUser(ctx, func(record *database.User) error {
if password != nil {
if err := record.SetPassword(*password); err != nil {
return err
}
}
if disablePassword {
record.Password = ""
}
if nick != nil {
record.Nick = *nick
}
if realname != nil {
record.Realname = *realname
}
return nil
})
if err != nil {
return err return err
} }

55
user.go
View File

@ -23,6 +23,8 @@ import (
"git.sr.ht/~emersion/soju/msgstore" "git.sr.ht/~emersion/soju/msgstore"
) )
type UserUpdateFunc func(record *database.User) error
type event interface{} type event interface{}
type eventUpstreamMessage struct { type eventUpstreamMessage struct {
@ -702,9 +704,11 @@ func (u *user) run() {
} }
if !u.Enabled && u.srv.Config().EnableUsersOnAuth { if !u.Enabled && u.srv.Config().EnableUsersOnAuth {
record := u.User err := u.updateUser(ctx, func(record *database.User) error {
record.Enabled = true record.Enabled = true
if err := u.updateUser(ctx, &record); err != nil { return nil
})
if err != nil {
dc.logger.Printf("failed to enable user after successful authentication: %v", err) dc.logger.Printf("failed to enable user after successful authentication: %v", err)
} }
} }
@ -791,20 +795,18 @@ func (u *user) run() {
dc.SendMessage(msg) dc.SendMessage(msg)
} }
case eventUserUpdate: case eventUserUpdate:
// copy the user record because we'll mutate it e.done <- u.updateUser(context.TODO(), func(record *database.User) error {
record := u.User if e.password != nil {
record.Password = *e.password
if e.password != nil { }
record.Password = *e.password if e.admin != nil {
} record.Admin = *e.admin
if e.admin != nil { }
record.Admin = *e.admin if e.enabled != nil {
} record.Enabled = *e.enabled
if e.enabled != nil { }
record.Enabled = *e.enabled return nil
} })
e.done <- u.updateUser(context.TODO(), &record)
// If the password was updated, kill all downstream connections to // If the password was updated, kill all downstream connections to
// force them to re-authenticate with the new credentials. // force them to re-authenticate with the new credentials.
@ -1110,18 +1112,19 @@ func (u *user) deleteNetwork(ctx context.Context, id int64) error {
return nil return nil
} }
func (u *user) updateUser(ctx context.Context, record *database.User) error { func (u *user) updateUser(ctx context.Context, update UserUpdateFunc) error {
if u.ID != record.ID { record := u.User // copy
panic("ID mismatch when updating user") if err := update(&record); err != nil {
return err
} }
nickUpdated := u.Nick != record.Nick nickUpdated := u.Nick != record.Nick
realnameUpdated := u.Realname != record.Realname realnameUpdated := u.Realname != record.Realname
enabledUpdated := u.Enabled != record.Enabled enabledUpdated := u.Enabled != record.Enabled
if err := u.srv.db.StoreUser(ctx, record); err != nil { if err := u.srv.db.StoreUser(ctx, &record); err != nil {
return fmt.Errorf("failed to update user %q: %v", u.Username, err) return fmt.Errorf("failed to update user %q: %v", u.Username, err)
} }
u.User = *record u.User = record
if nickUpdated { if nickUpdated {
for _, net := range u.networks { for _, net := range u.networks {
@ -1264,9 +1267,11 @@ func (u *user) localTCPAddrForHost(ctx context.Context, host string) (*net.TCPAd
} }
func (u *user) bumpDownstreamInteractionTime(ctx context.Context) { func (u *user) bumpDownstreamInteractionTime(ctx context.Context) {
record := u.User err := u.updateUser(ctx, func(record *database.User) error {
record.DownstreamInteractedAt = time.Now() record.DownstreamInteractedAt = time.Now()
if err := u.updateUser(ctx, &record); err != nil { return nil
})
if err != nil {
u.logger.Printf("failed to bump downstream interaction time: %v", err) u.logger.Printf("failed to bump downstream interaction time: %v", err)
} }
} }