Add a flag to disable users

Add a new flag to disable users. This can be useful to temporarily
deactivate an account without erasing data.

The user goroutine is kept alive for simplicity's sake. Most of the
infrastructure assumes that each user always has a running goroutine.
A disabled user's goroutine is responsible for sending back an error
to downstream connections, and listening for potential events to
re-enable the account.
This commit is contained in:
Simon Ser 2023-01-26 18:33:55 +01:00
parent bbf234d441
commit d7d9d45b45
9 changed files with 98 additions and 37 deletions

View File

@ -78,6 +78,7 @@ func main() {
Username: username, Username: username,
Password: string(hashed), Password: string(hashed),
Admin: *admin, Admin: *admin,
Enabled: true,
} }
if err := db.StoreUser(ctx, &user); err != nil { if err := db.StoreUser(ctx, &user); err != nil {
log.Fatalf("failed to create user: %v", err) log.Fatalf("failed to create user: %v", err)

View File

@ -107,7 +107,7 @@ func main() {
log.Printf("user %q: updating existing user", username) log.Printf("user %q: updating existing user", username)
} else { } else {
// "!!" is an invalid crypt format, thus disables password auth // "!!" is an invalid crypt format, thus disables password auth
u = &database.User{Username: username, Password: "!!"} u = &database.User{Username: username, Password: "!!", Enabled: true}
usersCreated++ usersCreated++
log.Printf("user %q: creating new user", username) log.Printf("user %q: creating new user", username)
} }

View File

@ -71,6 +71,7 @@ type User struct {
Nick string Nick string
Realname string Realname string
Admin bool Admin bool
Enabled bool
} }
func (u *User) CheckPassword(password string) (upgraded bool, err error) { func (u *User) CheckPassword(password string) (upgraded bool, err error) {

View File

@ -32,7 +32,8 @@ CREATE TABLE "User" (
admin BOOLEAN NOT NULL DEFAULT FALSE, admin BOOLEAN NOT NULL DEFAULT FALSE,
nick VARCHAR(255), nick VARCHAR(255),
realname VARCHAR(255), realname VARCHAR(255),
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now() created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(),
enabled BOOLEAN NOT NULL DEFAULT TRUE
); );
CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL'); CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL');
@ -169,6 +170,7 @@ var postgresMigrations = []string{
`ALTER TABLE "Network" ADD COLUMN auto_away BOOLEAN NOT NULL DEFAULT TRUE`, `ALTER TABLE "Network" ADD COLUMN auto_away BOOLEAN NOT NULL DEFAULT TRUE`,
`ALTER TABLE "Network" ADD COLUMN certfp TEXT`, `ALTER TABLE "Network" ADD COLUMN certfp TEXT`,
`ALTER TABLE "User" ADD COLUMN created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now()`, `ALTER TABLE "User" ADD COLUMN created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now()`,
`ALTER TABLE "User" ADD COLUMN enabled BOOLEAN NOT NULL DEFAULT TRUE`,
} }
type PostgresDB struct { type PostgresDB struct {
@ -302,7 +304,8 @@ func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {
defer cancel() defer cancel()
rows, err := db.db.QueryContext(ctx, rows, err := db.db.QueryContext(ctx,
`SELECT id, username, password, admin, nick, realname FROM "User"`) `SELECT id, username, password, admin, nick, realname, enabled
FROM "User"`)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -312,7 +315,7 @@ func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {
for rows.Next() { for rows.Next() {
var user User var user User
var password, nick, realname sql.NullString var password, nick, realname sql.NullString
if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname); err != nil { if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled); err != nil {
return nil, err return nil, err
} }
user.Password = password.String user.Password = password.String
@ -335,9 +338,11 @@ func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, erro
var password, nick, realname sql.NullString var password, nick, realname sql.NullString
row := db.db.QueryRowContext(ctx, row := db.db.QueryRowContext(ctx,
`SELECT id, password, admin, nick, realname FROM "User" WHERE username = $1`, `SELECT id, password, admin, nick, realname, enabled
FROM "User"
WHERE username = $1`,
username) username)
if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname); err != nil { if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled); err != nil {
return nil, err return nil, err
} }
user.Password = password.String user.Password = password.String
@ -357,16 +362,16 @@ func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error {
var err error var err error
if user.ID == 0 { if user.ID == 0 {
err = db.db.QueryRowContext(ctx, ` err = db.db.QueryRowContext(ctx, `
INSERT INTO "User" (username, password, admin, nick, realname) INSERT INTO "User" (username, password, admin, nick, realname, enabled)
VALUES ($1, $2, $3, $4, $5) VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id`, RETURNING id`,
user.Username, password, user.Admin, nick, realname).Scan(&user.ID) user.Username, password, user.Admin, nick, realname, user.Enabled).Scan(&user.ID)
} else { } else {
_, err = db.db.ExecContext(ctx, ` _, err = db.db.ExecContext(ctx, `
UPDATE "User" UPDATE "User"
SET password = $1, admin = $2, nick = $3, realname = $4 SET password = $1, admin = $2, nick = $3, realname = $4, enabled = $5
WHERE id = $5`, WHERE id = $6`,
password, user.Admin, nick, realname, user.ID) password, user.Admin, nick, realname, user.Enabled, user.ID)
} }
return err return err
} }

View File

@ -62,7 +62,8 @@ CREATE TABLE User (
admin INTEGER NOT NULL DEFAULT 0, admin INTEGER NOT NULL DEFAULT 0,
realname TEXT, realname TEXT,
nick TEXT, nick TEXT,
created_at TEXT NOT NULL created_at TEXT NOT NULL,
enabled INTEGER NOT NULL DEFAULT 1
); );
CREATE TABLE Network ( CREATE TABLE Network (
@ -289,6 +290,7 @@ var sqliteMigrations = []string{
ALTER TABLE User ADD COLUMN created_at TEXT NOT NULL DEFAULT ''; ALTER TABLE User ADD COLUMN created_at TEXT NOT NULL DEFAULT '';
UPDATE User SET created_at = strftime('` + sqliteTimeFormat + `', 'now'); UPDATE User SET created_at = strftime('` + sqliteTimeFormat + `', 'now');
`, `,
"ALTER TABLE User ADD COLUMN enabled INTEGER NOT NULL DEFAULT 1",
} }
type SqliteDB struct { type SqliteDB struct {
@ -388,7 +390,8 @@ func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) {
defer cancel() defer cancel()
rows, err := db.db.QueryContext(ctx, rows, err := db.db.QueryContext(ctx,
"SELECT id, username, password, admin, nick, realname FROM User") `SELECT id, username, password, admin, nick, realname, enabled
FROM User`)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -398,7 +401,7 @@ func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) {
for rows.Next() { for rows.Next() {
var user User var user User
var password, nick, realname sql.NullString var password, nick, realname sql.NullString
if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname); err != nil { if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled); err != nil {
return nil, err return nil, err
} }
user.Password = password.String user.Password = password.String
@ -421,9 +424,11 @@ func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error)
var password, nick, realname sql.NullString var password, nick, realname sql.NullString
row := db.db.QueryRowContext(ctx, row := db.db.QueryRowContext(ctx,
"SELECT id, password, admin, nick, realname FROM User WHERE username = ?", `SELECT id, password, admin, nick, realname, enabled
FROM User
WHERE username = ?`,
username) username)
if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname); err != nil { if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled); err != nil {
return nil, err return nil, err
} }
user.Password = password.String user.Password = password.String
@ -442,21 +447,26 @@ func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
sql.Named("admin", user.Admin), sql.Named("admin", user.Admin),
sql.Named("nick", toNullString(user.Nick)), sql.Named("nick", toNullString(user.Nick)),
sql.Named("realname", toNullString(user.Realname)), sql.Named("realname", toNullString(user.Realname)),
sql.Named("enabled", user.Enabled),
sql.Named("now", sqliteTime{time.Now()}), sql.Named("now", sqliteTime{time.Now()}),
} }
var err error var err error
if user.ID != 0 { if user.ID != 0 {
_, err = db.db.ExecContext(ctx, ` _, err = db.db.ExecContext(ctx, `
UPDATE User SET password = :password, admin = :admin, nick = :nick, UPDATE User
realname = :realname WHERE username = :username`, SET password = :password, admin = :admin, nick = :nick,
realname = :realname, enabled = :enabled
WHERE username = :username`,
args...) args...)
} else { } else {
var res sql.Result var res sql.Result
res, err = db.db.ExecContext(ctx, ` res, err = db.db.ExecContext(ctx, `
INSERT INTO INSERT INTO
User(username, password, admin, nick, realname, created_at) User(username, password, admin, nick, realname, created_at,
VALUES (:username, :password, :admin, :nick, :realname, :now)`, enabled)
VALUES (:username, :password, :admin, :nick, :realname, :now,
:enabled)`,
args...) args...)
if err != nil { if err != nil {
return err return err

View File

@ -434,6 +434,11 @@ character.
Set the user's realname. This is used as a fallback if there is no Set the user's realname. This is used as a fallback if there is no
realname set for a network. realname set for a network.
*-enabled* true|false
Enable or disable the user. If the user is disabled, the bouncer will
not connect to any of their networks, and downstream connections will
be immediately closed. By default, users are enabled.
*user update* [username] [options...] *user update* [username] [options...]
Update a user. The options are the same as the _user create_ command. Update a user. The options are the same as the _user create_ command.
@ -445,7 +450,8 @@ character.
- The _-username_ flag is never valid, usernames are immutable. - The _-username_ flag is never valid, usernames are immutable.
- The _-nick_ and _-realname_ flag are only valid when updating the current - The _-nick_ and _-realname_ flag are only valid when updating the current
user. user.
- The _-admin_ flag is only valid when updating another user. - The _-admin_ and _-enabled_ flags are only valid when updating another
user.
*user delete* <username> [confirmation token] *user delete* <username> [confirmation token]
Delete a soju user. Delete a soju user.

View File

@ -51,7 +51,11 @@ func createTestUser(t *testing.T, db database.Database) *database.User {
t.Fatalf("failed to generate bcrypt hash: %v", err) t.Fatalf("failed to generate bcrypt hash: %v", err)
} }
record := &database.User{Username: testUsername, Password: string(hashed)} record := &database.User{
Username: testUsername,
Password: string(hashed),
Enabled: true,
}
if err := db.StoreUser(context.Background(), record); err != nil { if err := db.StoreUser(context.Background(), record); err != nil {
t.Fatalf("failed to store test user: %v", err) t.Fatalf("failed to store test user: %v", err)
} }

View File

@ -920,6 +920,7 @@ func handleUserCreate(ctx *serviceContext, params []string) error {
nick := fs.String("nick", "", "") nick := fs.String("nick", "", "")
realname := fs.String("realname", "", "") realname := fs.String("realname", "", "")
admin := fs.Bool("admin", false, "") admin := fs.Bool("admin", false, "")
enabled := fs.Bool("enabled", true, "")
if err := fs.Parse(params); err != nil { if err := fs.Parse(params); err != nil {
return err return err
@ -939,6 +940,7 @@ func handleUserCreate(ctx *serviceContext, params []string) error {
Nick: *nick, Nick: *nick,
Realname: *realname, Realname: *realname,
Admin: *admin, Admin: *admin,
Enabled: *enabled,
} }
if err := user.SetPassword(*password); err != nil { if err := user.SetPassword(*password); err != nil {
return err return err
@ -960,12 +962,13 @@ func popArg(params []string) (string, []string) {
func handleUserUpdate(ctx *serviceContext, params []string) error { func handleUserUpdate(ctx *serviceContext, params []string) error {
var password, nick, realname *string var password, nick, realname *string
var admin *bool var admin, enabled *bool
fs := newFlagSet() fs := newFlagSet()
fs.Var(stringPtrFlag{&password}, "password", "") fs.Var(stringPtrFlag{&password}, "password", "")
fs.Var(stringPtrFlag{&nick}, "nick", "") fs.Var(stringPtrFlag{&nick}, "nick", "")
fs.Var(stringPtrFlag{&realname}, "realname", "") fs.Var(stringPtrFlag{&realname}, "realname", "")
fs.Var(boolPtrFlag{&admin}, "admin", "") fs.Var(boolPtrFlag{&admin}, "admin", "")
fs.Var(boolPtrFlag{&enabled}, "enabled", "")
username, params := popArg(params) username, params := popArg(params)
if err := fs.Parse(params); err != nil { if err := fs.Parse(params); err != nil {
@ -1005,6 +1008,7 @@ func handleUserUpdate(ctx *serviceContext, params []string) error {
event := eventUserUpdate{ event := eventUserUpdate{
password: hashed, password: hashed,
admin: admin, admin: admin,
enabled: enabled,
done: done, done: done,
} }
select { select {
@ -1036,6 +1040,9 @@ func handleUserUpdate(ctx *serviceContext, params []string) error {
if admin != nil { if admin != nil {
return fmt.Errorf("cannot update -admin of own user") return fmt.Errorf("cannot update -admin of own user")
} }
if enabled != nil {
return fmt.Errorf("cannot update -enabled of own user")
}
if err := ctx.user.updateUser(ctx, &record); err != nil { if err := ctx.user.updateUser(ctx, &record); err != nil {
return err return err

53
user.go
View File

@ -74,6 +74,7 @@ type eventStop struct{}
type eventUserUpdate struct { type eventUserUpdate struct {
password *string password *string
admin *bool admin *bool
enabled *bool
done chan error done chan error
} }
@ -246,7 +247,7 @@ func (net *network) runConn(ctx context.Context) error {
} }
func (net *network) run() { func (net *network) run() {
if !net.Enabled { if !net.user.Enabled || !net.Enabled {
return return
} }
@ -687,6 +688,15 @@ func (u *user) run() {
dc.monitored.SetCasemapping(dc.network.casemap) dc.monitored.SetCasemapping(dc.network.casemap)
} }
if !u.Enabled {
dc.SendMessage(&irc.Message{
Command: "ERROR",
Params: []string{"This bouncer account is disabled"},
})
// TODO: close dc after the error message is sent
break
}
if err := dc.welcome(context.TODO()); err != nil { if err := dc.welcome(context.TODO()); err != nil {
if ircErr, ok := err.(ircError); ok { if ircErr, ok := err.(ircError); ok {
msg := ircErr.Message.Copy() msg := ircErr.Message.Copy()
@ -762,6 +772,9 @@ func (u *user) run() {
if e.admin != nil { if e.admin != nil {
record.Admin = *e.admin record.Admin = *e.admin
} }
if e.enabled != nil {
record.Enabled = *e.enabled
}
e.done <- u.updateUser(context.TODO(), &record) e.done <- u.updateUser(context.TODO(), &record)
@ -1071,6 +1084,7 @@ func (u *user) updateUser(ctx context.Context, record *database.User) error {
nickUpdated := u.Nick != record.Nick nickUpdated := u.Nick != record.Nick
realnameUpdated := u.Realname != record.Realname realnameUpdated := u.Realname != record.Realname
enabledUpdated := u.Enabled != record.Enabled
if err := u.srv.db.StoreUser(ctx, record); err != nil { if err := u.srv.db.StoreUser(ctx, record); err != nil {
return fmt.Errorf("failed to update user %q: %v", u.Username, err) return fmt.Errorf("failed to update user %q: %v", u.Username, err)
} }
@ -1091,22 +1105,28 @@ func (u *user) updateUser(ctx context.Context, record *database.User) error {
} }
} }
if realnameUpdated { if realnameUpdated || enabledUpdated {
// Re-connect to networks which use the default realname // Re-connect to networks which use the default realname
var needUpdate []database.Network var needUpdate []database.Network
for _, net := range u.networks { for _, net := range u.networks {
if net.Realname != "" { // If only the realname was updated, maybe we can skip the
continue // re-connect
} if realnameUpdated && !enabledUpdated {
// If this network has a custom realname set, no need to
// re-connect: the user-wide realname remains unused
if net.Realname != "" {
continue
}
// We only need to call updateNetwork for upstreams that don't // We only need to call updateNetwork for upstreams that don't
// support setname // support setname
if uc := net.conn; uc != nil && uc.caps.IsEnabled("setname") { if uc := net.conn; uc != nil && uc.caps.IsEnabled("setname") {
uc.SendMessage(ctx, &irc.Message{ uc.SendMessage(ctx, &irc.Message{
Command: "SETNAME", Command: "SETNAME",
Params: []string{database.GetRealname(&u.User, &net.Network)}, Params: []string{database.GetRealname(&u.User, &net.Network)},
}) })
continue continue
}
} }
needUpdate = append(needUpdate, net.Network) needUpdate = append(needUpdate, net.Network)
@ -1123,6 +1143,13 @@ func (u *user) updateUser(ctx context.Context, record *database.User) error {
} }
} }
if !u.Enabled {
// TODO: send an error message before disconnecting
for _, dc := range u.downstreamConns {
dc.Close()
}
}
return nil return nil
} }