Add a flag to disable users
Add a new flag to disable users. This can be useful to temporarily deactivate an account without erasing data. The user goroutine is kept alive for simplicity's sake. Most of the infrastructure assumes that each user always has a running goroutine. A disabled user's goroutine is responsible for sending back an error to downstream connections, and listening for potential events to re-enable the account.
This commit is contained in:
parent
bbf234d441
commit
d7d9d45b45
@ -78,6 +78,7 @@ func main() {
|
||||
Username: username,
|
||||
Password: string(hashed),
|
||||
Admin: *admin,
|
||||
Enabled: true,
|
||||
}
|
||||
if err := db.StoreUser(ctx, &user); err != nil {
|
||||
log.Fatalf("failed to create user: %v", err)
|
||||
|
@ -107,7 +107,7 @@ func main() {
|
||||
log.Printf("user %q: updating existing user", username)
|
||||
} else {
|
||||
// "!!" is an invalid crypt format, thus disables password auth
|
||||
u = &database.User{Username: username, Password: "!!"}
|
||||
u = &database.User{Username: username, Password: "!!", Enabled: true}
|
||||
usersCreated++
|
||||
log.Printf("user %q: creating new user", username)
|
||||
}
|
||||
|
@ -71,6 +71,7 @@ type User struct {
|
||||
Nick string
|
||||
Realname string
|
||||
Admin bool
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
func (u *User) CheckPassword(password string) (upgraded bool, err error) {
|
||||
|
@ -32,7 +32,8 @@ CREATE TABLE "User" (
|
||||
admin BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
nick VARCHAR(255),
|
||||
realname VARCHAR(255),
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now()
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(),
|
||||
enabled BOOLEAN NOT NULL DEFAULT TRUE
|
||||
);
|
||||
|
||||
CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL');
|
||||
@ -169,6 +170,7 @@ var postgresMigrations = []string{
|
||||
`ALTER TABLE "Network" ADD COLUMN auto_away BOOLEAN NOT NULL DEFAULT TRUE`,
|
||||
`ALTER TABLE "Network" ADD COLUMN certfp TEXT`,
|
||||
`ALTER TABLE "User" ADD COLUMN created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now()`,
|
||||
`ALTER TABLE "User" ADD COLUMN enabled BOOLEAN NOT NULL DEFAULT TRUE`,
|
||||
}
|
||||
|
||||
type PostgresDB struct {
|
||||
@ -302,7 +304,8 @@ func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {
|
||||
defer cancel()
|
||||
|
||||
rows, err := db.db.QueryContext(ctx,
|
||||
`SELECT id, username, password, admin, nick, realname FROM "User"`)
|
||||
`SELECT id, username, password, admin, nick, realname, enabled
|
||||
FROM "User"`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -312,7 +315,7 @@ func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {
|
||||
for rows.Next() {
|
||||
var user User
|
||||
var password, nick, realname sql.NullString
|
||||
if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname); err != nil {
|
||||
if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user.Password = password.String
|
||||
@ -335,9 +338,11 @@ func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, erro
|
||||
|
||||
var password, nick, realname sql.NullString
|
||||
row := db.db.QueryRowContext(ctx,
|
||||
`SELECT id, password, admin, nick, realname FROM "User" WHERE username = $1`,
|
||||
`SELECT id, password, admin, nick, realname, enabled
|
||||
FROM "User"
|
||||
WHERE username = $1`,
|
||||
username)
|
||||
if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname); err != nil {
|
||||
if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user.Password = password.String
|
||||
@ -357,16 +362,16 @@ func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error {
|
||||
var err error
|
||||
if user.ID == 0 {
|
||||
err = db.db.QueryRowContext(ctx, `
|
||||
INSERT INTO "User" (username, password, admin, nick, realname)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
INSERT INTO "User" (username, password, admin, nick, realname, enabled)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
RETURNING id`,
|
||||
user.Username, password, user.Admin, nick, realname).Scan(&user.ID)
|
||||
user.Username, password, user.Admin, nick, realname, user.Enabled).Scan(&user.ID)
|
||||
} else {
|
||||
_, err = db.db.ExecContext(ctx, `
|
||||
UPDATE "User"
|
||||
SET password = $1, admin = $2, nick = $3, realname = $4
|
||||
WHERE id = $5`,
|
||||
password, user.Admin, nick, realname, user.ID)
|
||||
SET password = $1, admin = $2, nick = $3, realname = $4, enabled = $5
|
||||
WHERE id = $6`,
|
||||
password, user.Admin, nick, realname, user.Enabled, user.ID)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -62,7 +62,8 @@ CREATE TABLE User (
|
||||
admin INTEGER NOT NULL DEFAULT 0,
|
||||
realname TEXT,
|
||||
nick TEXT,
|
||||
created_at TEXT NOT NULL
|
||||
created_at TEXT NOT NULL,
|
||||
enabled INTEGER NOT NULL DEFAULT 1
|
||||
);
|
||||
|
||||
CREATE TABLE Network (
|
||||
@ -289,6 +290,7 @@ var sqliteMigrations = []string{
|
||||
ALTER TABLE User ADD COLUMN created_at TEXT NOT NULL DEFAULT '';
|
||||
UPDATE User SET created_at = strftime('` + sqliteTimeFormat + `', 'now');
|
||||
`,
|
||||
"ALTER TABLE User ADD COLUMN enabled INTEGER NOT NULL DEFAULT 1",
|
||||
}
|
||||
|
||||
type SqliteDB struct {
|
||||
@ -388,7 +390,8 @@ func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) {
|
||||
defer cancel()
|
||||
|
||||
rows, err := db.db.QueryContext(ctx,
|
||||
"SELECT id, username, password, admin, nick, realname FROM User")
|
||||
`SELECT id, username, password, admin, nick, realname, enabled
|
||||
FROM User`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -398,7 +401,7 @@ func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) {
|
||||
for rows.Next() {
|
||||
var user User
|
||||
var password, nick, realname sql.NullString
|
||||
if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname); err != nil {
|
||||
if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user.Password = password.String
|
||||
@ -421,9 +424,11 @@ func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error)
|
||||
|
||||
var password, nick, realname sql.NullString
|
||||
row := db.db.QueryRowContext(ctx,
|
||||
"SELECT id, password, admin, nick, realname FROM User WHERE username = ?",
|
||||
`SELECT id, password, admin, nick, realname, enabled
|
||||
FROM User
|
||||
WHERE username = ?`,
|
||||
username)
|
||||
if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname); err != nil {
|
||||
if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user.Password = password.String
|
||||
@ -442,21 +447,26 @@ func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
|
||||
sql.Named("admin", user.Admin),
|
||||
sql.Named("nick", toNullString(user.Nick)),
|
||||
sql.Named("realname", toNullString(user.Realname)),
|
||||
sql.Named("enabled", user.Enabled),
|
||||
sql.Named("now", sqliteTime{time.Now()}),
|
||||
}
|
||||
|
||||
var err error
|
||||
if user.ID != 0 {
|
||||
_, err = db.db.ExecContext(ctx, `
|
||||
UPDATE User SET password = :password, admin = :admin, nick = :nick,
|
||||
realname = :realname WHERE username = :username`,
|
||||
UPDATE User
|
||||
SET password = :password, admin = :admin, nick = :nick,
|
||||
realname = :realname, enabled = :enabled
|
||||
WHERE username = :username`,
|
||||
args...)
|
||||
} else {
|
||||
var res sql.Result
|
||||
res, err = db.db.ExecContext(ctx, `
|
||||
INSERT INTO
|
||||
User(username, password, admin, nick, realname, created_at)
|
||||
VALUES (:username, :password, :admin, :nick, :realname, :now)`,
|
||||
User(username, password, admin, nick, realname, created_at,
|
||||
enabled)
|
||||
VALUES (:username, :password, :admin, :nick, :realname, :now,
|
||||
:enabled)`,
|
||||
args...)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -434,6 +434,11 @@ character.
|
||||
Set the user's realname. This is used as a fallback if there is no
|
||||
realname set for a network.
|
||||
|
||||
*-enabled* true|false
|
||||
Enable or disable the user. If the user is disabled, the bouncer will
|
||||
not connect to any of their networks, and downstream connections will
|
||||
be immediately closed. By default, users are enabled.
|
||||
|
||||
*user update* [username] [options...]
|
||||
Update a user. The options are the same as the _user create_ command.
|
||||
|
||||
@ -445,7 +450,8 @@ character.
|
||||
- The _-username_ flag is never valid, usernames are immutable.
|
||||
- The _-nick_ and _-realname_ flag are only valid when updating the current
|
||||
user.
|
||||
- The _-admin_ flag is only valid when updating another user.
|
||||
- The _-admin_ and _-enabled_ flags are only valid when updating another
|
||||
user.
|
||||
|
||||
*user delete* <username> [confirmation token]
|
||||
Delete a soju user.
|
||||
|
@ -51,7 +51,11 @@ func createTestUser(t *testing.T, db database.Database) *database.User {
|
||||
t.Fatalf("failed to generate bcrypt hash: %v", err)
|
||||
}
|
||||
|
||||
record := &database.User{Username: testUsername, Password: string(hashed)}
|
||||
record := &database.User{
|
||||
Username: testUsername,
|
||||
Password: string(hashed),
|
||||
Enabled: true,
|
||||
}
|
||||
if err := db.StoreUser(context.Background(), record); err != nil {
|
||||
t.Fatalf("failed to store test user: %v", err)
|
||||
}
|
||||
|
@ -920,6 +920,7 @@ func handleUserCreate(ctx *serviceContext, params []string) error {
|
||||
nick := fs.String("nick", "", "")
|
||||
realname := fs.String("realname", "", "")
|
||||
admin := fs.Bool("admin", false, "")
|
||||
enabled := fs.Bool("enabled", true, "")
|
||||
|
||||
if err := fs.Parse(params); err != nil {
|
||||
return err
|
||||
@ -939,6 +940,7 @@ func handleUserCreate(ctx *serviceContext, params []string) error {
|
||||
Nick: *nick,
|
||||
Realname: *realname,
|
||||
Admin: *admin,
|
||||
Enabled: *enabled,
|
||||
}
|
||||
if err := user.SetPassword(*password); err != nil {
|
||||
return err
|
||||
@ -960,12 +962,13 @@ func popArg(params []string) (string, []string) {
|
||||
|
||||
func handleUserUpdate(ctx *serviceContext, params []string) error {
|
||||
var password, nick, realname *string
|
||||
var admin *bool
|
||||
var admin, enabled *bool
|
||||
fs := newFlagSet()
|
||||
fs.Var(stringPtrFlag{&password}, "password", "")
|
||||
fs.Var(stringPtrFlag{&nick}, "nick", "")
|
||||
fs.Var(stringPtrFlag{&realname}, "realname", "")
|
||||
fs.Var(boolPtrFlag{&admin}, "admin", "")
|
||||
fs.Var(boolPtrFlag{&enabled}, "enabled", "")
|
||||
|
||||
username, params := popArg(params)
|
||||
if err := fs.Parse(params); err != nil {
|
||||
@ -1005,6 +1008,7 @@ func handleUserUpdate(ctx *serviceContext, params []string) error {
|
||||
event := eventUserUpdate{
|
||||
password: hashed,
|
||||
admin: admin,
|
||||
enabled: enabled,
|
||||
done: done,
|
||||
}
|
||||
select {
|
||||
@ -1036,6 +1040,9 @@ func handleUserUpdate(ctx *serviceContext, params []string) error {
|
||||
if admin != nil {
|
||||
return fmt.Errorf("cannot update -admin of own user")
|
||||
}
|
||||
if enabled != nil {
|
||||
return fmt.Errorf("cannot update -enabled of own user")
|
||||
}
|
||||
|
||||
if err := ctx.user.updateUser(ctx, &record); err != nil {
|
||||
return err
|
||||
|
31
user.go
31
user.go
@ -74,6 +74,7 @@ type eventStop struct{}
|
||||
type eventUserUpdate struct {
|
||||
password *string
|
||||
admin *bool
|
||||
enabled *bool
|
||||
done chan error
|
||||
}
|
||||
|
||||
@ -246,7 +247,7 @@ func (net *network) runConn(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (net *network) run() {
|
||||
if !net.Enabled {
|
||||
if !net.user.Enabled || !net.Enabled {
|
||||
return
|
||||
}
|
||||
|
||||
@ -687,6 +688,15 @@ func (u *user) run() {
|
||||
dc.monitored.SetCasemapping(dc.network.casemap)
|
||||
}
|
||||
|
||||
if !u.Enabled {
|
||||
dc.SendMessage(&irc.Message{
|
||||
Command: "ERROR",
|
||||
Params: []string{"This bouncer account is disabled"},
|
||||
})
|
||||
// TODO: close dc after the error message is sent
|
||||
break
|
||||
}
|
||||
|
||||
if err := dc.welcome(context.TODO()); err != nil {
|
||||
if ircErr, ok := err.(ircError); ok {
|
||||
msg := ircErr.Message.Copy()
|
||||
@ -762,6 +772,9 @@ func (u *user) run() {
|
||||
if e.admin != nil {
|
||||
record.Admin = *e.admin
|
||||
}
|
||||
if e.enabled != nil {
|
||||
record.Enabled = *e.enabled
|
||||
}
|
||||
|
||||
e.done <- u.updateUser(context.TODO(), &record)
|
||||
|
||||
@ -1071,6 +1084,7 @@ func (u *user) updateUser(ctx context.Context, record *database.User) error {
|
||||
|
||||
nickUpdated := u.Nick != record.Nick
|
||||
realnameUpdated := u.Realname != record.Realname
|
||||
enabledUpdated := u.Enabled != record.Enabled
|
||||
if err := u.srv.db.StoreUser(ctx, record); err != nil {
|
||||
return fmt.Errorf("failed to update user %q: %v", u.Username, err)
|
||||
}
|
||||
@ -1091,10 +1105,15 @@ func (u *user) updateUser(ctx context.Context, record *database.User) error {
|
||||
}
|
||||
}
|
||||
|
||||
if realnameUpdated {
|
||||
if realnameUpdated || enabledUpdated {
|
||||
// Re-connect to networks which use the default realname
|
||||
var needUpdate []database.Network
|
||||
for _, net := range u.networks {
|
||||
// If only the realname was updated, maybe we can skip the
|
||||
// re-connect
|
||||
if realnameUpdated && !enabledUpdated {
|
||||
// If this network has a custom realname set, no need to
|
||||
// re-connect: the user-wide realname remains unused
|
||||
if net.Realname != "" {
|
||||
continue
|
||||
}
|
||||
@ -1108,6 +1127,7 @@ func (u *user) updateUser(ctx context.Context, record *database.User) error {
|
||||
})
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
needUpdate = append(needUpdate, net.Network)
|
||||
}
|
||||
@ -1123,6 +1143,13 @@ func (u *user) updateUser(ctx context.Context, record *database.User) error {
|
||||
}
|
||||
}
|
||||
|
||||
if !u.Enabled {
|
||||
// TODO: send an error message before disconnecting
|
||||
for _, dc := range u.downstreamConns {
|
||||
dc.Close()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user