diff --git a/cmd/sojuctl/main.go b/cmd/sojuctl/main.go index 714a33d..368d2d8 100644 --- a/cmd/sojuctl/main.go +++ b/cmd/sojuctl/main.go @@ -69,7 +69,7 @@ func main() { Username: username, Password: string(hashed), } - if err := db.CreateUser(&user); err != nil { + if err := db.StoreUser(&user); err != nil { log.Fatalf("failed to create user: %v", err) } case "change-password": @@ -90,13 +90,13 @@ func main() { } user := soju.User{ + Created: true, Username: username, Password: string(hashed), } - if err := db.UpdatePassword(&user); err != nil { + if err := db.StoreUser(&user); err != nil { log.Fatalf("failed to update password: %v", err) } - default: flag.Usage() if cmd != "help" { diff --git a/db.go b/db.go index 31bc146..416c699 100644 --- a/db.go +++ b/db.go @@ -10,6 +10,7 @@ import ( ) type User struct { + Created bool Username string Password string // hashed } @@ -199,6 +200,7 @@ func (db *DB) ListUsers() ([]User, error) { if err := rows.Scan(&user.Username, &password); err != nil { return nil, err } + user.Created = true user.Password = fromStringPtr(password) users = append(users, user) } @@ -213,7 +215,7 @@ func (db *DB) GetUser(username string) (*User, error) { db.lock.RLock() defer db.lock.RUnlock() - user := &User{Username: username} + user := &User{Created: true, Username: username} var password *string row := db.db.QueryRow("SELECT password FROM User WHERE username = ?", username) @@ -224,24 +226,24 @@ func (db *DB) GetUser(username string) (*User, error) { return user, nil } -func (db *DB) CreateUser(user *User) error { +func (db *DB) StoreUser(user *User) error { db.lock.Lock() defer db.lock.Unlock() password := toStringPtr(user.Password) - _, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password) - return err -} -func (db *DB) UpdatePassword(user *User) error { - db.lock.Lock() - defer db.lock.Unlock() + var err error + if user.Created { + _, err = db.db.Exec("UPDATE User SET password = ? WHERE username = ?", + password, user.Username) + } else { + _, err = db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", + user.Username, password) + if err == nil { + user.Created = true + } + } - password := toStringPtr(user.Password) - _, err := db.db.Exec(`UPDATE User - SET password = ? - WHERE username = ?`, - password, user.Username) return err } diff --git a/user.go b/user.go index 6d08caa..ea55e8a 100644 --- a/user.go +++ b/user.go @@ -547,5 +547,5 @@ func (u *user) deleteNetwork(id int64) error { func (u *user) updatePassword(hashed string) error { u.User.Password = hashed - return u.srv.db.UpdatePassword(&u.User) + return u.srv.db.StoreUser(&u.User) }