Move authError to auth package

This allows auth backends to customize the error message displayed
to the user.
This commit is contained in:
Simon Ser 2023-02-23 22:32:24 +01:00
parent 05a0775658
commit 36d6cb19a4
5 changed files with 39 additions and 39 deletions

View File

@ -29,3 +29,29 @@ func New(driver, source string) (Authenticator, error) {
return nil, fmt.Errorf("unknown auth driver %q", driver) return nil, fmt.Errorf("unknown auth driver %q", driver)
} }
} }
// Error is an authentication error.
type Error struct {
// Internal error cause. This will not be revealed to the user.
InternalErr error
// Message which can safely be sent to the user without compromising
// security.
ExternalMsg string
}
func (err *Error) Error() string {
return err.InternalErr.Error()
}
func (err *Error) Unwrap() error {
return err.InternalErr
}
// newInvalidCredentialsError wraps the provided error into an Error and
// indicates to the user that the provided credentials were invalid.
func newInvalidCredentialsError(err error) *Error {
return &Error{
InternalErr: err,
ExternalMsg: "Invalid credentials",
}
}

View File

@ -16,12 +16,12 @@ func NewInternal() PlainAuthenticator {
func (internal) AuthPlain(ctx context.Context, db database.Database, username, password string) error { func (internal) AuthPlain(ctx context.Context, db database.Database, username, password string) error {
u, err := db.GetUser(ctx, username) u, err := db.GetUser(ctx, username)
if err != nil { if err != nil {
return fmt.Errorf("user not found: %w", err) return newInvalidCredentialsError(fmt.Errorf("user not found: %w", err))
} }
upgraded, err := u.CheckPassword(password) upgraded, err := u.CheckPassword(password)
if err != nil { if err != nil {
return err return newInvalidCredentialsError(err)
} }
if upgraded { if upgraded {

View File

@ -88,7 +88,7 @@ func (auth *oauth2) AuthPlain(ctx context.Context, db database.Database, usernam
} }
if username != effectiveUsername { if username != effectiveUsername {
return fmt.Errorf("username mismatch (OAuth 2.0 server returned %q)", effectiveUsername) return newInvalidCredentialsError(fmt.Errorf("username mismatch (OAuth 2.0 server returned %q)", effectiveUsername))
} }
return nil return nil
@ -127,7 +127,7 @@ func (auth *oauth2) AuthOAuthBearer(ctx context.Context, db database.Database, t
} }
if !data.Active { if !data.Active {
return "", fmt.Errorf("invalid access token") return "", newInvalidCredentialsError(fmt.Errorf("invalid access token"))
} }
if data.Username == "" { if data.Username == "" {
// We really need the username here, otherwise an OAuth 2.0 user can // We really need the username here, otherwise an OAuth 2.0 user can

View File

@ -37,7 +37,7 @@ func (pamAuth) AuthPlain(ctx context.Context, db database.Database, username, pa
} }
if err := t.Authenticate(0); err != nil { if err := t.Authenticate(0); err != nil {
return fmt.Errorf("PAM auth error: %v", err) return newInvalidCredentialsError(fmt.Errorf("PAM auth error: %v", err))
} }
if err := t.AcctMgmt(0); err != nil { if err := t.AcctMgmt(0); err != nil {

View File

@ -67,40 +67,16 @@ func newChatHistoryError(subcommand string, target string) ircError {
}} }}
} }
// authError is an authentication error.
type authError struct {
// Internal error cause. This will not be revealed to the user.
err error
// Error cause which can safely be sent to the user without compromising
// security.
reason string
}
func (err *authError) Error() string {
return err.err.Error()
}
func (err *authError) Unwrap() error {
return err.err
}
// authErrorReason returns the user-friendly reason of an authentication // authErrorReason returns the user-friendly reason of an authentication
// failure. // failure.
func authErrorReason(err error) string { func authErrorReason(err error) string {
if authErr, ok := err.(*authError); ok { if authErr, ok := err.(*auth.Error); ok {
return authErr.reason return authErr.ExternalMsg
} else { } else {
return "Authentication failed" return "Authentication failed"
} }
} }
func newInvalidUsernameOrPasswordError(err error) error {
return &authError{
err: err,
reason: "Invalid username or password",
}
}
func parseBouncerNetID(subcommand, s string) (int64, error) { func parseBouncerNetID(subcommand, s string) (int64, error) {
id, err := strconv.ParseInt(s, 10, 64) id, err := strconv.ParseInt(s, 10, 64)
if err != nil || id <= 0 { if err != nil || id <= 0 {
@ -690,8 +666,7 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir
break break
} }
if authErr := auth.AuthPlain(ctx, dc.srv.db, username, password); authErr != nil { if err = auth.AuthPlain(ctx, dc.srv.db, username, password); err != nil {
err = newInvalidUsernameOrPasswordError(authErr)
break break
} }
case "OAUTHBEARER": case "OAUTHBEARER":
@ -701,15 +676,14 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir
break break
} }
var authErr error username, err = auth.AuthOAuthBearer(ctx, dc.srv.db, credentials.oauthBearer.Token)
username, authErr = auth.AuthOAuthBearer(ctx, dc.srv.db, credentials.oauthBearer.Token) if err != nil {
if authErr != nil {
err = newInvalidUsernameOrPasswordError(authErr)
break break
} }
if credentials.oauthBearer.Username != "" && credentials.oauthBearer.Username != username { if credentials.oauthBearer.Username != "" && credentials.oauthBearer.Username != username {
err = newInvalidUsernameOrPasswordError(fmt.Errorf("username mismatch (server returned %q)", username)) err = fmt.Errorf("username mismatch (server returned %q)", username)
break
} }
default: default:
panic(fmt.Errorf("unexpected SASL mechanism %q", credentials.mechanism)) panic(fmt.Errorf("unexpected SASL mechanism %q", credentials.mechanism))
@ -1292,7 +1266,7 @@ func (dc *downstreamConn) authenticate(ctx context.Context, username, password s
} }
if err := plainAuth.AuthPlain(ctx, dc.srv.db, username, password); err != nil { if err := plainAuth.AuthPlain(ctx, dc.srv.db, username, password); err != nil {
return newInvalidUsernameOrPasswordError(err) return err
} }
return dc.setUser(username, clientName, networkName) return dc.setUser(username, clientName, networkName)