Introduce UserUpdateFunc
References: https://todo.sr.ht/~emersion/soju/206
This commit is contained in:
parent
67335130b1
commit
aecff32103
@ -1773,9 +1773,10 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
|||||||
record.Nick = nick
|
record.Nick = nick
|
||||||
err = dc.srv.db.StoreNetwork(ctx, dc.user.ID, &record)
|
err = dc.srv.db.StoreNetwork(ctx, dc.user.ID, &record)
|
||||||
} else {
|
} else {
|
||||||
record := dc.user.User
|
err = dc.user.updateUser(ctx, func(record *database.User) error {
|
||||||
record.Nick = nick
|
record.Nick = nick
|
||||||
err = dc.user.updateUser(ctx, &record)
|
return nil
|
||||||
|
})
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
dc.logger.Printf("failed to update nick: %v", err)
|
dc.logger.Printf("failed to update nick: %v", err)
|
||||||
@ -1840,9 +1841,10 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
|||||||
_, err = dc.user.updateNetwork(ctx, &record)
|
_, err = dc.user.updateNetwork(ctx, &record)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
record := dc.user.User
|
err = dc.user.updateUser(ctx, func(record *database.User) error {
|
||||||
record.Realname = realname
|
record.Realname = realname
|
||||||
err = dc.user.updateUser(ctx, &record)
|
return nil
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
36
service.go
36
service.go
@ -1066,23 +1066,6 @@ func handleUserUpdate(ctx *serviceContext, params []string) error {
|
|||||||
|
|
||||||
ctx.print(fmt.Sprintf("updated user %q", username))
|
ctx.print(fmt.Sprintf("updated user %q", username))
|
||||||
} else {
|
} else {
|
||||||
// copy the user record because we'll mutate it
|
|
||||||
record := ctx.user.User
|
|
||||||
|
|
||||||
if password != nil {
|
|
||||||
if err := record.SetPassword(*password); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if disablePassword {
|
|
||||||
record.Password = ""
|
|
||||||
}
|
|
||||||
if nick != nil {
|
|
||||||
record.Nick = *nick
|
|
||||||
}
|
|
||||||
if realname != nil {
|
|
||||||
record.Realname = *realname
|
|
||||||
}
|
|
||||||
if admin != nil {
|
if admin != nil {
|
||||||
return fmt.Errorf("cannot update -admin of own user")
|
return fmt.Errorf("cannot update -admin of own user")
|
||||||
}
|
}
|
||||||
@ -1090,7 +1073,24 @@ func handleUserUpdate(ctx *serviceContext, params []string) error {
|
|||||||
return fmt.Errorf("cannot update -enabled of own user")
|
return fmt.Errorf("cannot update -enabled of own user")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := ctx.user.updateUser(ctx, &record); err != nil {
|
err := ctx.user.updateUser(ctx, func(record *database.User) error {
|
||||||
|
if password != nil {
|
||||||
|
if err := record.SetPassword(*password); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if disablePassword {
|
||||||
|
record.Password = ""
|
||||||
|
}
|
||||||
|
if nick != nil {
|
||||||
|
record.Nick = *nick
|
||||||
|
}
|
||||||
|
if realname != nil {
|
||||||
|
record.Realname = *realname
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
55
user.go
55
user.go
@ -23,6 +23,8 @@ import (
|
|||||||
"git.sr.ht/~emersion/soju/msgstore"
|
"git.sr.ht/~emersion/soju/msgstore"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type UserUpdateFunc func(record *database.User) error
|
||||||
|
|
||||||
type event interface{}
|
type event interface{}
|
||||||
|
|
||||||
type eventUpstreamMessage struct {
|
type eventUpstreamMessage struct {
|
||||||
@ -702,9 +704,11 @@ func (u *user) run() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !u.Enabled && u.srv.Config().EnableUsersOnAuth {
|
if !u.Enabled && u.srv.Config().EnableUsersOnAuth {
|
||||||
record := u.User
|
err := u.updateUser(ctx, func(record *database.User) error {
|
||||||
record.Enabled = true
|
record.Enabled = true
|
||||||
if err := u.updateUser(ctx, &record); err != nil {
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
dc.logger.Printf("failed to enable user after successful authentication: %v", err)
|
dc.logger.Printf("failed to enable user after successful authentication: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -791,20 +795,18 @@ func (u *user) run() {
|
|||||||
dc.SendMessage(msg)
|
dc.SendMessage(msg)
|
||||||
}
|
}
|
||||||
case eventUserUpdate:
|
case eventUserUpdate:
|
||||||
// copy the user record because we'll mutate it
|
e.done <- u.updateUser(context.TODO(), func(record *database.User) error {
|
||||||
record := u.User
|
if e.password != nil {
|
||||||
|
record.Password = *e.password
|
||||||
if e.password != nil {
|
}
|
||||||
record.Password = *e.password
|
if e.admin != nil {
|
||||||
}
|
record.Admin = *e.admin
|
||||||
if e.admin != nil {
|
}
|
||||||
record.Admin = *e.admin
|
if e.enabled != nil {
|
||||||
}
|
record.Enabled = *e.enabled
|
||||||
if e.enabled != nil {
|
}
|
||||||
record.Enabled = *e.enabled
|
return nil
|
||||||
}
|
})
|
||||||
|
|
||||||
e.done <- u.updateUser(context.TODO(), &record)
|
|
||||||
|
|
||||||
// If the password was updated, kill all downstream connections to
|
// If the password was updated, kill all downstream connections to
|
||||||
// force them to re-authenticate with the new credentials.
|
// force them to re-authenticate with the new credentials.
|
||||||
@ -1110,18 +1112,19 @@ func (u *user) deleteNetwork(ctx context.Context, id int64) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *user) updateUser(ctx context.Context, record *database.User) error {
|
func (u *user) updateUser(ctx context.Context, update UserUpdateFunc) error {
|
||||||
if u.ID != record.ID {
|
record := u.User // copy
|
||||||
panic("ID mismatch when updating user")
|
if err := update(&record); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
nickUpdated := u.Nick != record.Nick
|
nickUpdated := u.Nick != record.Nick
|
||||||
realnameUpdated := u.Realname != record.Realname
|
realnameUpdated := u.Realname != record.Realname
|
||||||
enabledUpdated := u.Enabled != record.Enabled
|
enabledUpdated := u.Enabled != record.Enabled
|
||||||
if err := u.srv.db.StoreUser(ctx, record); err != nil {
|
if err := u.srv.db.StoreUser(ctx, &record); err != nil {
|
||||||
return fmt.Errorf("failed to update user %q: %v", u.Username, err)
|
return fmt.Errorf("failed to update user %q: %v", u.Username, err)
|
||||||
}
|
}
|
||||||
u.User = *record
|
u.User = record
|
||||||
|
|
||||||
if nickUpdated {
|
if nickUpdated {
|
||||||
for _, net := range u.networks {
|
for _, net := range u.networks {
|
||||||
@ -1264,9 +1267,11 @@ func (u *user) localTCPAddrForHost(ctx context.Context, host string) (*net.TCPAd
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *user) bumpDownstreamInteractionTime(ctx context.Context) {
|
func (u *user) bumpDownstreamInteractionTime(ctx context.Context) {
|
||||||
record := u.User
|
err := u.updateUser(ctx, func(record *database.User) error {
|
||||||
record.DownstreamInteractedAt = time.Now()
|
record.DownstreamInteractedAt = time.Now()
|
||||||
if err := u.updateUser(ctx, &record); err != nil {
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
u.logger.Printf("failed to bump downstream interaction time: %v", err)
|
u.logger.Printf("failed to bump downstream interaction time: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user