database: add User.{Check,Set}Password

This commit is contained in:
Simon Ser 2022-06-08 13:27:33 +02:00
parent 09f2cf8489
commit fe40c51ff0
3 changed files with 43 additions and 27 deletions

View File

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"golang.org/x/crypto/bcrypt"
) )
type Database interface { type Database interface {
@ -63,6 +64,29 @@ type User struct {
Admin bool Admin bool
} }
func (u *User) CheckPassword(password string) error {
// Password auth disabled
if u.Password == "" {
return fmt.Errorf("password auth disabled")
}
err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
if err != nil {
return fmt.Errorf("wrong password: %v", err)
}
return nil
}
func (u *User) SetPassword(password string) error {
hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("failed to hash password: %v", err)
}
u.Password = string(hashed)
return nil
}
type SASL struct { type SASL struct {
Mechanism string Mechanism string

View File

@ -14,7 +14,6 @@ import (
"time" "time"
"github.com/emersion/go-sasl" "github.com/emersion/go-sasl"
"golang.org/x/crypto/bcrypt"
"gopkg.in/irc.v3" "gopkg.in/irc.v3"
"git.sr.ht/~emersion/soju/database" "git.sr.ht/~emersion/soju/database"
@ -1304,14 +1303,8 @@ func (dc *downstreamConn) authenticate(ctx context.Context, username, password s
return newInvalidUsernameOrPasswordError(fmt.Errorf("user not found: %w", err)) return newInvalidUsernameOrPasswordError(fmt.Errorf("user not found: %w", err))
} }
// Password auth disabled if err := u.CheckPassword(password); err != nil {
if u.Password == "" { return newInvalidUsernameOrPasswordError(err)
return newInvalidUsernameOrPasswordError(fmt.Errorf("password auth disabled"))
}
err = bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
if err != nil {
return newInvalidUsernameOrPasswordError(fmt.Errorf("wrong password"))
} }
dc.user = dc.srv.getUser(username) dc.user = dc.srv.getUser(username)

View File

@ -830,17 +830,14 @@ func handleUserCreate(ctx context.Context, dc *downstreamConn, params []string)
return fmt.Errorf("flag -password is required") return fmt.Errorf("flag -password is required")
} }
hashed, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("failed to hash password: %v", err)
}
user := &database.User{ user := &database.User{
Username: *username, Username: *username,
Password: string(hashed),
Realname: *realname, Realname: *realname,
Admin: *admin, Admin: *admin,
} }
if err := user.SetPassword(*password); err != nil {
return err
}
if _, err := dc.srv.createUser(ctx, user); err != nil { if _, err := dc.srv.createUser(ctx, user); err != nil {
return fmt.Errorf("could not create user: %v", err) return fmt.Errorf("could not create user: %v", err)
} }
@ -872,16 +869,6 @@ func handleUserUpdate(ctx context.Context, dc *downstreamConn, params []string)
return fmt.Errorf("unexpected argument") return fmt.Errorf("unexpected argument")
} }
var hashed *string
if password != nil {
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("failed to hash password: %v", err)
}
hashedStr := string(hashedBytes)
hashed = &hashedStr
}
if username != "" && username != dc.user.Username { if username != "" && username != dc.user.Username {
if !dc.user.Admin { if !dc.user.Admin {
return fmt.Errorf("you must be an admin to update other users") return fmt.Errorf("you must be an admin to update other users")
@ -890,6 +877,16 @@ func handleUserUpdate(ctx context.Context, dc *downstreamConn, params []string)
return fmt.Errorf("cannot update -realname of other user") return fmt.Errorf("cannot update -realname of other user")
} }
var hashed *string
if password != nil {
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("failed to hash password: %v", err)
}
hashedStr := string(hashedBytes)
hashed = &hashedStr
}
u := dc.srv.getUser(username) u := dc.srv.getUser(username)
if u == nil { if u == nil {
return fmt.Errorf("unknown username %q", username) return fmt.Errorf("unknown username %q", username)
@ -916,8 +913,10 @@ func handleUserUpdate(ctx context.Context, dc *downstreamConn, params []string)
// copy the user record because we'll mutate it // copy the user record because we'll mutate it
record := dc.user.User record := dc.user.User
if hashed != nil { if password != nil {
record.Password = *hashed if err := record.SetPassword(*password); err != nil {
return err
}
} }
if realname != nil { if realname != nil {
record.Realname = *realname record.Realname = *realname