database: add User.{Check,Set}Password
This commit is contained in:
parent
09f2cf8489
commit
fe40c51ff0
@ -8,6 +8,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Database interface {
|
type Database interface {
|
||||||
@ -63,6 +64,29 @@ type User struct {
|
|||||||
Admin bool
|
Admin bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *User) CheckPassword(password string) error {
|
||||||
|
// Password auth disabled
|
||||||
|
if u.Password == "" {
|
||||||
|
return fmt.Errorf("password auth disabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("wrong password: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) SetPassword(password string) error {
|
||||||
|
hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to hash password: %v", err)
|
||||||
|
}
|
||||||
|
u.Password = string(hashed)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type SASL struct {
|
type SASL struct {
|
||||||
Mechanism string
|
Mechanism string
|
||||||
|
|
||||||
|
@ -14,7 +14,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/emersion/go-sasl"
|
"github.com/emersion/go-sasl"
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
"gopkg.in/irc.v3"
|
"gopkg.in/irc.v3"
|
||||||
|
|
||||||
"git.sr.ht/~emersion/soju/database"
|
"git.sr.ht/~emersion/soju/database"
|
||||||
@ -1304,14 +1303,8 @@ func (dc *downstreamConn) authenticate(ctx context.Context, username, password s
|
|||||||
return newInvalidUsernameOrPasswordError(fmt.Errorf("user not found: %w", err))
|
return newInvalidUsernameOrPasswordError(fmt.Errorf("user not found: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Password auth disabled
|
if err := u.CheckPassword(password); err != nil {
|
||||||
if u.Password == "" {
|
return newInvalidUsernameOrPasswordError(err)
|
||||||
return newInvalidUsernameOrPasswordError(fmt.Errorf("password auth disabled"))
|
|
||||||
}
|
|
||||||
|
|
||||||
err = bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
|
|
||||||
if err != nil {
|
|
||||||
return newInvalidUsernameOrPasswordError(fmt.Errorf("wrong password"))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
dc.user = dc.srv.getUser(username)
|
dc.user = dc.srv.getUser(username)
|
||||||
|
35
service.go
35
service.go
@ -830,17 +830,14 @@ func handleUserCreate(ctx context.Context, dc *downstreamConn, params []string)
|
|||||||
return fmt.Errorf("flag -password is required")
|
return fmt.Errorf("flag -password is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
hashed, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to hash password: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
user := &database.User{
|
user := &database.User{
|
||||||
Username: *username,
|
Username: *username,
|
||||||
Password: string(hashed),
|
|
||||||
Realname: *realname,
|
Realname: *realname,
|
||||||
Admin: *admin,
|
Admin: *admin,
|
||||||
}
|
}
|
||||||
|
if err := user.SetPassword(*password); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
if _, err := dc.srv.createUser(ctx, user); err != nil {
|
if _, err := dc.srv.createUser(ctx, user); err != nil {
|
||||||
return fmt.Errorf("could not create user: %v", err)
|
return fmt.Errorf("could not create user: %v", err)
|
||||||
}
|
}
|
||||||
@ -872,16 +869,6 @@ func handleUserUpdate(ctx context.Context, dc *downstreamConn, params []string)
|
|||||||
return fmt.Errorf("unexpected argument")
|
return fmt.Errorf("unexpected argument")
|
||||||
}
|
}
|
||||||
|
|
||||||
var hashed *string
|
|
||||||
if password != nil {
|
|
||||||
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to hash password: %v", err)
|
|
||||||
}
|
|
||||||
hashedStr := string(hashedBytes)
|
|
||||||
hashed = &hashedStr
|
|
||||||
}
|
|
||||||
|
|
||||||
if username != "" && username != dc.user.Username {
|
if username != "" && username != dc.user.Username {
|
||||||
if !dc.user.Admin {
|
if !dc.user.Admin {
|
||||||
return fmt.Errorf("you must be an admin to update other users")
|
return fmt.Errorf("you must be an admin to update other users")
|
||||||
@ -890,6 +877,16 @@ func handleUserUpdate(ctx context.Context, dc *downstreamConn, params []string)
|
|||||||
return fmt.Errorf("cannot update -realname of other user")
|
return fmt.Errorf("cannot update -realname of other user")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var hashed *string
|
||||||
|
if password != nil {
|
||||||
|
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to hash password: %v", err)
|
||||||
|
}
|
||||||
|
hashedStr := string(hashedBytes)
|
||||||
|
hashed = &hashedStr
|
||||||
|
}
|
||||||
|
|
||||||
u := dc.srv.getUser(username)
|
u := dc.srv.getUser(username)
|
||||||
if u == nil {
|
if u == nil {
|
||||||
return fmt.Errorf("unknown username %q", username)
|
return fmt.Errorf("unknown username %q", username)
|
||||||
@ -916,8 +913,10 @@ func handleUserUpdate(ctx context.Context, dc *downstreamConn, params []string)
|
|||||||
// copy the user record because we'll mutate it
|
// copy the user record because we'll mutate it
|
||||||
record := dc.user.User
|
record := dc.user.User
|
||||||
|
|
||||||
if hashed != nil {
|
if password != nil {
|
||||||
record.Password = *hashed
|
if err := record.SetPassword(*password); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if realname != nil {
|
if realname != nil {
|
||||||
record.Realname = *realname
|
record.Realname = *realname
|
||||||
|
Loading…
Reference in New Issue
Block a user