Introduce UserUpdateFunc
References: https://todo.sr.ht/~emersion/soju/206
This commit is contained in:
parent
67335130b1
commit
aecff32103
@ -1773,9 +1773,10 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
record.Nick = nick
|
||||
err = dc.srv.db.StoreNetwork(ctx, dc.user.ID, &record)
|
||||
} else {
|
||||
record := dc.user.User
|
||||
record.Nick = nick
|
||||
err = dc.user.updateUser(ctx, &record)
|
||||
err = dc.user.updateUser(ctx, func(record *database.User) error {
|
||||
record.Nick = nick
|
||||
return nil
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
dc.logger.Printf("failed to update nick: %v", err)
|
||||
@ -1840,9 +1841,10 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
_, err = dc.user.updateNetwork(ctx, &record)
|
||||
}
|
||||
} else {
|
||||
record := dc.user.User
|
||||
record.Realname = realname
|
||||
err = dc.user.updateUser(ctx, &record)
|
||||
err = dc.user.updateUser(ctx, func(record *database.User) error {
|
||||
record.Realname = realname
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
36
service.go
36
service.go
@ -1066,23 +1066,6 @@ func handleUserUpdate(ctx *serviceContext, params []string) error {
|
||||
|
||||
ctx.print(fmt.Sprintf("updated user %q", username))
|
||||
} else {
|
||||
// copy the user record because we'll mutate it
|
||||
record := ctx.user.User
|
||||
|
||||
if password != nil {
|
||||
if err := record.SetPassword(*password); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if disablePassword {
|
||||
record.Password = ""
|
||||
}
|
||||
if nick != nil {
|
||||
record.Nick = *nick
|
||||
}
|
||||
if realname != nil {
|
||||
record.Realname = *realname
|
||||
}
|
||||
if admin != nil {
|
||||
return fmt.Errorf("cannot update -admin of own user")
|
||||
}
|
||||
@ -1090,7 +1073,24 @@ func handleUserUpdate(ctx *serviceContext, params []string) error {
|
||||
return fmt.Errorf("cannot update -enabled of own user")
|
||||
}
|
||||
|
||||
if err := ctx.user.updateUser(ctx, &record); err != nil {
|
||||
err := ctx.user.updateUser(ctx, func(record *database.User) error {
|
||||
if password != nil {
|
||||
if err := record.SetPassword(*password); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if disablePassword {
|
||||
record.Password = ""
|
||||
}
|
||||
if nick != nil {
|
||||
record.Nick = *nick
|
||||
}
|
||||
if realname != nil {
|
||||
record.Realname = *realname
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
55
user.go
55
user.go
@ -23,6 +23,8 @@ import (
|
||||
"git.sr.ht/~emersion/soju/msgstore"
|
||||
)
|
||||
|
||||
type UserUpdateFunc func(record *database.User) error
|
||||
|
||||
type event interface{}
|
||||
|
||||
type eventUpstreamMessage struct {
|
||||
@ -702,9 +704,11 @@ func (u *user) run() {
|
||||
}
|
||||
|
||||
if !u.Enabled && u.srv.Config().EnableUsersOnAuth {
|
||||
record := u.User
|
||||
record.Enabled = true
|
||||
if err := u.updateUser(ctx, &record); err != nil {
|
||||
err := u.updateUser(ctx, func(record *database.User) error {
|
||||
record.Enabled = true
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
dc.logger.Printf("failed to enable user after successful authentication: %v", err)
|
||||
}
|
||||
}
|
||||
@ -791,20 +795,18 @@ func (u *user) run() {
|
||||
dc.SendMessage(msg)
|
||||
}
|
||||
case eventUserUpdate:
|
||||
// copy the user record because we'll mutate it
|
||||
record := u.User
|
||||
|
||||
if e.password != nil {
|
||||
record.Password = *e.password
|
||||
}
|
||||
if e.admin != nil {
|
||||
record.Admin = *e.admin
|
||||
}
|
||||
if e.enabled != nil {
|
||||
record.Enabled = *e.enabled
|
||||
}
|
||||
|
||||
e.done <- u.updateUser(context.TODO(), &record)
|
||||
e.done <- u.updateUser(context.TODO(), func(record *database.User) error {
|
||||
if e.password != nil {
|
||||
record.Password = *e.password
|
||||
}
|
||||
if e.admin != nil {
|
||||
record.Admin = *e.admin
|
||||
}
|
||||
if e.enabled != nil {
|
||||
record.Enabled = *e.enabled
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// If the password was updated, kill all downstream connections to
|
||||
// force them to re-authenticate with the new credentials.
|
||||
@ -1110,18 +1112,19 @@ func (u *user) deleteNetwork(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *user) updateUser(ctx context.Context, record *database.User) error {
|
||||
if u.ID != record.ID {
|
||||
panic("ID mismatch when updating user")
|
||||
func (u *user) updateUser(ctx context.Context, update UserUpdateFunc) error {
|
||||
record := u.User // copy
|
||||
if err := update(&record); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nickUpdated := u.Nick != record.Nick
|
||||
realnameUpdated := u.Realname != record.Realname
|
||||
enabledUpdated := u.Enabled != record.Enabled
|
||||
if err := u.srv.db.StoreUser(ctx, record); err != nil {
|
||||
if err := u.srv.db.StoreUser(ctx, &record); err != nil {
|
||||
return fmt.Errorf("failed to update user %q: %v", u.Username, err)
|
||||
}
|
||||
u.User = *record
|
||||
u.User = record
|
||||
|
||||
if nickUpdated {
|
||||
for _, net := range u.networks {
|
||||
@ -1264,9 +1267,11 @@ func (u *user) localTCPAddrForHost(ctx context.Context, host string) (*net.TCPAd
|
||||
}
|
||||
|
||||
func (u *user) bumpDownstreamInteractionTime(ctx context.Context) {
|
||||
record := u.User
|
||||
record.DownstreamInteractedAt = time.Now()
|
||||
if err := u.updateUser(ctx, &record); err != nil {
|
||||
err := u.updateUser(ctx, func(record *database.User) error {
|
||||
record.DownstreamInteractedAt = time.Now()
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
u.logger.Printf("failed to bump downstream interaction time: %v", err)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user