Introduce User.Created

For Network and Channel, the database only needed to define one Store
operation to create/update a record. However since User is missing an ID
we couldn't have a single StoreUser function like other types. We had
CreateUser and UpdatePassword. As new User fields get added (e.g. the
upcoming Admin flag) this isn't sustainable.

We could have CreateUser and UpdateUser, but this wouldn't be consistent
with other types. Instead, introduce User.Created which indicates
whether the record is already stored in the DB. This can be used in a
new StoreUser function to decide whether we need to UPDATE or INSERT
without relying on SQL constraints and INSERT OR UPDATE.

The ListUsers and GetUser functions set User.Created to true.
This commit is contained in:
Simon Ser 2020-06-08 11:59:03 +02:00
parent d0cf1d2882
commit 998546cdc3
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
3 changed files with 19 additions and 17 deletions

View File

@ -69,7 +69,7 @@ func main() {
Username: username, Username: username,
Password: string(hashed), Password: string(hashed),
} }
if err := db.CreateUser(&user); err != nil { if err := db.StoreUser(&user); err != nil {
log.Fatalf("failed to create user: %v", err) log.Fatalf("failed to create user: %v", err)
} }
case "change-password": case "change-password":
@ -90,13 +90,13 @@ func main() {
} }
user := soju.User{ user := soju.User{
Created: true,
Username: username, Username: username,
Password: string(hashed), Password: string(hashed),
} }
if err := db.UpdatePassword(&user); err != nil { if err := db.StoreUser(&user); err != nil {
log.Fatalf("failed to update password: %v", err) log.Fatalf("failed to update password: %v", err)
} }
default: default:
flag.Usage() flag.Usage()
if cmd != "help" { if cmd != "help" {

28
db.go
View File

@ -10,6 +10,7 @@ import (
) )
type User struct { type User struct {
Created bool
Username string Username string
Password string // hashed Password string // hashed
} }
@ -199,6 +200,7 @@ func (db *DB) ListUsers() ([]User, error) {
if err := rows.Scan(&user.Username, &password); err != nil { if err := rows.Scan(&user.Username, &password); err != nil {
return nil, err return nil, err
} }
user.Created = true
user.Password = fromStringPtr(password) user.Password = fromStringPtr(password)
users = append(users, user) users = append(users, user)
} }
@ -213,7 +215,7 @@ func (db *DB) GetUser(username string) (*User, error) {
db.lock.RLock() db.lock.RLock()
defer db.lock.RUnlock() defer db.lock.RUnlock()
user := &User{Username: username} user := &User{Created: true, Username: username}
var password *string var password *string
row := db.db.QueryRow("SELECT password FROM User WHERE username = ?", username) row := db.db.QueryRow("SELECT password FROM User WHERE username = ?", username)
@ -224,24 +226,24 @@ func (db *DB) GetUser(username string) (*User, error) {
return user, nil return user, nil
} }
func (db *DB) CreateUser(user *User) error { func (db *DB) StoreUser(user *User) error {
db.lock.Lock() db.lock.Lock()
defer db.lock.Unlock() defer db.lock.Unlock()
password := toStringPtr(user.Password) password := toStringPtr(user.Password)
_, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password)
return err
}
func (db *DB) UpdatePassword(user *User) error { var err error
db.lock.Lock() if user.Created {
defer db.lock.Unlock() _, err = db.db.Exec("UPDATE User SET password = ? WHERE username = ?",
password := toStringPtr(user.Password)
_, err := db.db.Exec(`UPDATE User
SET password = ?
WHERE username = ?`,
password, user.Username) password, user.Username)
} else {
_, err = db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)",
user.Username, password)
if err == nil {
user.Created = true
}
}
return err return err
} }

View File

@ -547,5 +547,5 @@ func (u *user) deleteNetwork(id int64) error {
func (u *user) updatePassword(hashed string) error { func (u *user) updatePassword(hashed string) error {
u.User.Password = hashed u.User.Password = hashed
return u.srv.db.UpdatePassword(&u.User) return u.srv.db.StoreUser(&u.User)
} }