diff --git a/db.go b/db.go index b18b73e..54d80a9 100644 --- a/db.go +++ b/db.go @@ -10,7 +10,7 @@ import ( ) type User struct { - Created bool + ID int64 Username string Password string // hashed Admin bool @@ -190,7 +190,7 @@ func (db *DB) ListUsers() ([]User, error) { db.lock.RLock() defer db.lock.RUnlock() - rows, err := db.db.Query("SELECT username, password, admin FROM User") + rows, err := db.db.Query("SELECT rowid, username, password, admin FROM User") if err != nil { return nil, err } @@ -200,10 +200,9 @@ func (db *DB) ListUsers() ([]User, error) { for rows.Next() { var user User var password *string - if err := rows.Scan(&user.Username, &password, &user.Admin); err != nil { + if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin); err != nil { return nil, err } - user.Created = true user.Password = fromStringPtr(password) users = append(users, user) } @@ -218,11 +217,11 @@ func (db *DB) GetUser(username string) (*User, error) { db.lock.RLock() defer db.lock.RUnlock() - user := &User{Created: true, Username: username} + user := &User{Username: username} var password *string - row := db.db.QueryRow("SELECT password, admin FROM User WHERE username = ?", username) - if err := row.Scan(&password, &user.Admin); err != nil { + row := db.db.QueryRow("SELECT rowid, password, admin FROM User WHERE username = ?", username) + if err := row.Scan(&user.ID, &password, &user.Admin); err != nil { return nil, err } user.Password = fromStringPtr(password) @@ -236,15 +235,17 @@ func (db *DB) StoreUser(user *User) error { password := toStringPtr(user.Password) var err error - if user.Created { + if user.ID != 0 { _, err = db.db.Exec("UPDATE User SET password = ?, admin = ? WHERE username = ?", password, user.Admin, user.Username) } else { - _, err = db.db.Exec("INSERT INTO User(username, password, admin) VALUES (?, ?, ?)", + var res sql.Result + res, err = db.db.Exec("INSERT INTO User(username, password, admin) VALUES (?, ?, ?, ?)", user.Username, password, user.Admin) - if err == nil { - user.Created = true + if err != nil { + return err } + user.ID, err = res.LastInsertId() } return err