diff --git a/database/database.go b/database/database.go index 977d1b7..fda61c8 100644 --- a/database/database.go +++ b/database/database.go @@ -67,6 +67,7 @@ type User struct { ID int64 Username string Password string // hashed + Nick string Realname string Admin bool } @@ -150,6 +151,9 @@ func GetNick(user *User, net *Network) string { if net != nil && net.Nick != "" { return net.Nick } + if user.Nick != "" { + return user.Nick + } return user.Username } diff --git a/database/postgres.go b/database/postgres.go index 79a09ad..1015b52 100644 --- a/database/postgres.go +++ b/database/postgres.go @@ -30,6 +30,7 @@ CREATE TABLE "User" ( username VARCHAR(255) NOT NULL UNIQUE, password VARCHAR(255), admin BOOLEAN NOT NULL DEFAULT FALSE, + nick VARCHAR(255), realname VARCHAR(255) ); @@ -153,6 +154,7 @@ var postgresMigrations = []string{ ADD COLUMN "user" INTEGER REFERENCES "User"(id) ON DELETE CASCADE `, + `ALTER TABLE "User" ADD COLUMN nick VARCHAR(255)`, } type PostgresDB struct { @@ -282,7 +284,7 @@ func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) { defer cancel() rows, err := db.db.QueryContext(ctx, - `SELECT id, username, password, admin, realname FROM "User"`) + `SELECT id, username, password, admin, nick, realname FROM "User"`) if err != nil { return nil, err } @@ -291,11 +293,12 @@ func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) { var users []User for rows.Next() { var user User - var password, realname sql.NullString - if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil { + var password, nick, realname sql.NullString + if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname); err != nil { return nil, err } user.Password = password.String + user.Nick = nick.String user.Realname = realname.String users = append(users, user) } @@ -312,14 +315,15 @@ func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, erro user := &User{Username: username} - var password, realname sql.NullString + var password, nick, realname sql.NullString row := db.db.QueryRowContext(ctx, - `SELECT id, password, admin, realname FROM "User" WHERE username = $1`, + `SELECT id, password, admin, nick, realname FROM "User" WHERE username = $1`, username) - if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil { + if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname); err != nil { return nil, err } user.Password = password.String + user.Nick = nick.String user.Realname = realname.String return user, nil } @@ -329,21 +333,22 @@ func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error { defer cancel() password := toNullString(user.Password) + nick := toNullString(user.Nick) realname := toNullString(user.Realname) var err error if user.ID == 0 { err = db.db.QueryRowContext(ctx, ` - INSERT INTO "User" (username, password, admin, realname) - VALUES ($1, $2, $3, $4) + INSERT INTO "User" (username, password, admin, nick, realname) + VALUES ($1, $2, $3, $4, $5) RETURNING id`, - user.Username, password, user.Admin, realname).Scan(&user.ID) + user.Username, password, user.Admin, nick, realname).Scan(&user.ID) } else { _, err = db.db.ExecContext(ctx, ` UPDATE "User" - SET password = $1, admin = $2, realname = $3 - WHERE id = $4`, - password, user.Admin, realname, user.ID) + SET password = $1, admin = $2, nick = $3, realname = $4 + WHERE id = $5`, + password, user.Admin, nick, realname, user.ID) } return err } diff --git a/database/sqlite.go b/database/sqlite.go index dbf45d2..fc70f11 100644 --- a/database/sqlite.go +++ b/database/sqlite.go @@ -27,7 +27,8 @@ CREATE TABLE User ( username TEXT NOT NULL UNIQUE, password TEXT, admin INTEGER NOT NULL DEFAULT 0, - realname TEXT + realname TEXT, + nick TEXT ); CREATE TABLE Network ( @@ -243,6 +244,7 @@ var sqliteMigrations = []string{ ALTER TABLE WebPushSubscription ADD COLUMN user INTEGER REFERENCES User(id); UPDATE WebPushSubscription AS wps SET user = (SELECT n.user FROM Network AS n WHERE n.id = wps.network); `, + "ALTER TABLE User ADD COLUMN nick TEXT;", } type SqliteDB struct { @@ -349,7 +351,7 @@ func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) { defer cancel() rows, err := db.db.QueryContext(ctx, - "SELECT id, username, password, admin, realname FROM User") + "SELECT id, username, password, admin, nick, realname FROM User") if err != nil { return nil, err } @@ -358,11 +360,12 @@ func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) { var users []User for rows.Next() { var user User - var password, realname sql.NullString - if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil { + var password, nick, realname sql.NullString + if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname); err != nil { return nil, err } user.Password = password.String + user.Nick = nick.String user.Realname = realname.String users = append(users, user) } @@ -379,14 +382,15 @@ func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error) user := &User{Username: username} - var password, realname sql.NullString + var password, nick, realname sql.NullString row := db.db.QueryRowContext(ctx, - "SELECT id, password, admin, realname FROM User WHERE username = ?", + "SELECT id, password, admin, nick, realname FROM User WHERE username = ?", username) - if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil { + if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname); err != nil { return nil, err } user.Password = password.String + user.Nick = nick.String user.Realname = realname.String return user, nil } @@ -399,21 +403,22 @@ func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error { sql.Named("username", user.Username), sql.Named("password", toNullString(user.Password)), sql.Named("admin", user.Admin), + sql.Named("nick", toNullString(user.Nick)), sql.Named("realname", toNullString(user.Realname)), } var err error if user.ID != 0 { _, err = db.db.ExecContext(ctx, ` - UPDATE User SET password = :password, admin = :admin, + UPDATE User SET password = :password, admin = :admin, nick = :nick, realname = :realname WHERE username = :username`, args...) } else { var res sql.Result res, err = db.db.ExecContext(ctx, ` INSERT INTO - User(username, password, admin, realname) - VALUES (:username, :password, :admin, :realname)`, + User(username, password, admin, nick, realname) + VALUES (:username, :password, :admin, :nick, :realname)`, args...) if err != nil { return err diff --git a/doc/soju.1.scd b/doc/soju.1.scd index 4b7a7de..adf6f62 100644 --- a/doc/soju.1.scd +++ b/doc/soju.1.scd @@ -416,6 +416,10 @@ abbreviated form, for instance *network* can be abbreviated as *net* or just *-admin* true|false Make the new user an administrator. + *-nick* + Set the user's nickname. This is used as a fallback if there is no + nickname set for a network. + *-realname* Set the user's realname. This is used as a fallback if there is no realname set for a network. @@ -429,7 +433,8 @@ abbreviated form, for instance *network* can be abbreviated as *net* or just Not all flags are valid in all contexts: - The _-username_ flag is never valid, usernames are immutable. - - The _-realname_ flag is only valid when updating the current user. + - The _-nick_ and _-realname_ flag are only valid when updating the current + user. - The _-admin_ flag is only valid when updating another user. *user delete* diff --git a/downstream.go b/downstream.go index dea8ca5..db29912 100644 --- a/downstream.go +++ b/downstream.go @@ -1798,12 +1798,6 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. return err } - if dc.network == nil { - return ircError{&irc.Message{ - Command: xirc.ERR_UNKNOWNERROR, - Params: []string{dc.nick, "NICK", "Cannot change nickname on the bouncer connection"}, - }} - } if nick == "" || strings.ContainsAny(nick, illegalNickChars) { return ircError{&irc.Message{ Command: irc.ERR_ERRONEUSNICKNAME, @@ -1817,25 +1811,45 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. }} } - record := dc.network.Network - record.Nick = nick - if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &record); err != nil { - return err + var err error + if dc.network != nil { + record := dc.network.Network + record.Nick = nick + err = dc.srv.db.StoreNetwork(ctx, dc.user.ID, &record) + } else { + record := dc.user.User + record.Nick = nick + err = dc.user.updateUser(ctx, &record) + } + if err != nil { + dc.logger.Printf("failed to update nick: %v", err) + return ircError{&irc.Message{ + Command: xirc.ERR_UNKNOWNERROR, + Params: []string{dc.nick, "NICK", "Failed to update nick"}, + }} } - if uc := dc.upstream(); uc != nil { - uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ - Command: "NICK", - Params: []string{nick}, - }) + if dc.network != nil { + if uc := dc.upstream(); uc != nil { + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "NICK", + Params: []string{nick}, + }) + } else { + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "NICK", + Params: []string{nick}, + }) + dc.nick = nick + dc.nickCM = casemapASCII(dc.nick) + } } else { - dc.SendMessage(&irc.Message{ - Prefix: dc.prefix(), - Command: "NICK", - Params: []string{nick}, - }) - dc.nick = nick - dc.nickCM = casemapASCII(dc.nick) + for _, c := range dc.user.downstreamConns { + if c.network == nil { + c.updateNick() + } + } } case "SETNAME": var realname string diff --git a/service.go b/service.go index 689555f..29424c0 100644 --- a/service.go +++ b/service.go @@ -817,6 +817,7 @@ func handleUserCreate(ctx context.Context, dc *downstreamConn, params []string) fs := newFlagSet() username := fs.String("username", "", "") password := fs.String("password", "", "") + nick := fs.String("nick", "", "") realname := fs.String("realname", "", "") admin := fs.Bool("admin", false, "") @@ -832,6 +833,7 @@ func handleUserCreate(ctx context.Context, dc *downstreamConn, params []string) user := &database.User{ Username: *username, + Nick: *nick, Realname: *realname, Admin: *admin, } @@ -854,10 +856,11 @@ func popArg(params []string) (string, []string) { } func handleUserUpdate(ctx context.Context, dc *downstreamConn, params []string) error { - var password, realname *string + var password, nick, realname *string var admin *bool fs := newFlagSet() fs.Var(stringPtrFlag{&password}, "password", "") + fs.Var(stringPtrFlag{&nick}, "nick", "") fs.Var(stringPtrFlag{&realname}, "realname", "") fs.Var(boolPtrFlag{&admin}, "admin", "") @@ -873,6 +876,9 @@ func handleUserUpdate(ctx context.Context, dc *downstreamConn, params []string) if !dc.user.Admin { return fmt.Errorf("you must be an admin to update other users") } + if nick != nil { + return fmt.Errorf("cannot update -nick of other user") + } if realname != nil { return fmt.Errorf("cannot update -realname of other user") } @@ -918,6 +924,9 @@ func handleUserUpdate(ctx context.Context, dc *downstreamConn, params []string) return err } } + if nick != nil { + record.Nick = *nick + } if realname != nil { record.Realname = *realname } diff --git a/user.go b/user.go index 2ab3bbf..7e22c45 100644 --- a/user.go +++ b/user.go @@ -933,8 +933,11 @@ func (u *user) updateNetwork(ctx context.Context, record *database.Network) (*ne panic("tried updating a new network") } - // If the realname is reset to the default, just wipe the per-network - // setting + // If the nickname/realname is reset to the default, just wipe the + // per-network setting + if record.Nick == u.Nick { + record.Nick = "" + } if record.Realname == u.Realname { record.Realname = "" } @@ -1030,12 +1033,28 @@ func (u *user) updateUser(ctx context.Context, record *database.User) error { panic("ID mismatch when updating user") } + nickUpdated := u.Nick != record.Nick realnameUpdated := u.Realname != record.Realname if err := u.srv.db.StoreUser(ctx, record); err != nil { return fmt.Errorf("failed to update user %q: %v", u.Username, err) } u.User = *record + if nickUpdated { + for _, net := range u.networks { + if net.Nick != "" { + continue + } + + if uc := net.conn; uc != nil { + uc.SendMessage(ctx, &irc.Message{ + Command: "NICK", + Params: []string{database.GetNick(&u.User, &net.Network)}, + }) + } + } + } + if realnameUpdated { // Re-connect to networks which use the default realname var needUpdate []database.Network