Introduce user.updateUser

Unify updatePassword and updateRealname into a single function. This
allows "user update" to be atomic.
This commit is contained in:
Simon Ser 2021-06-28 18:05:03 +02:00
parent 00538e7028
commit acde97ca37
2 changed files with 33 additions and 25 deletions

View File

@ -775,19 +775,22 @@ func handleUserUpdate(dc *downstreamConn, params []string) error {
return err return err
} }
// copy the user record because we'll mutate it
record := dc.user.User
if password != nil { if password != nil {
hashed, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost) hashed, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost)
if err != nil { if err != nil {
return fmt.Errorf("failed to hash password: %v", err) return fmt.Errorf("failed to hash password: %v", err)
} }
if err := dc.user.updatePassword(string(hashed)); err != nil { record.Password = string(hashed)
return err
}
} }
if realname != nil { if realname != nil {
if err := dc.user.updateRealname(*realname); err != nil { record.Realname = *realname
return err }
}
if err := dc.user.updateUser(&record); err != nil {
return err
} }
sendServicePRIVMSG(dc, fmt.Sprintf("updated user %q", dc.user.Username)) sendServicePRIVMSG(dc, fmt.Sprintf("updated user %q", dc.user.Username))

43
user.go
View File

@ -856,33 +856,38 @@ func (u *user) deleteNetwork(id int64) error {
return nil return nil
} }
func (u *user) updatePassword(hashed string) error { func (u *user) updateUser(record *User) error {
u.User.Password = hashed if u.ID != record.ID {
return u.srv.db.StoreUser(&u.User) panic("ID mismatch when updating user")
} }
func (u *user) updateRealname(realname string) error { realnameUpdated := u.Realname != record.Realname
u.User.Realname = realname if err := u.srv.db.StoreUser(record); err != nil {
if err := u.srv.db.StoreUser(&u.User); err != nil {
return fmt.Errorf("failed to update user %q: %v", u.Username, err) return fmt.Errorf("failed to update user %q: %v", u.Username, err)
} }
u.User = *record
// Re-connect to networks which use the default realname if realnameUpdated {
var needUpdate []Network // Re-connect to networks which use the default realname
u.forEachNetwork(func(net *network) { var needUpdate []Network
if net.Realname == "" { u.forEachNetwork(func(net *network) {
needUpdate = append(needUpdate, net.Network) if net.Realname == "" {
needUpdate = append(needUpdate, net.Network)
}
})
var netErr error
for _, net := range needUpdate {
if _, err := u.updateNetwork(&net); err != nil {
netErr = err
}
} }
}) if netErr != nil {
return netErr
var netErr error
for _, net := range needUpdate {
if _, err := u.updateNetwork(&net); err != nil {
netErr = err
} }
} }
return netErr return nil
} }
func (u *user) stop() { func (u *user) stop() {