diff --git a/auth/auth.go b/auth/auth.go index e0f4608..188e149 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -29,3 +29,29 @@ func New(driver, source string) (Authenticator, error) { return nil, fmt.Errorf("unknown auth driver %q", driver) } } + +// Error is an authentication error. +type Error struct { + // Internal error cause. This will not be revealed to the user. + InternalErr error + // Message which can safely be sent to the user without compromising + // security. + ExternalMsg string +} + +func (err *Error) Error() string { + return err.InternalErr.Error() +} + +func (err *Error) Unwrap() error { + return err.InternalErr +} + +// newInvalidCredentialsError wraps the provided error into an Error and +// indicates to the user that the provided credentials were invalid. +func newInvalidCredentialsError(err error) *Error { + return &Error{ + InternalErr: err, + ExternalMsg: "Invalid credentials", + } +} diff --git a/auth/internal.go b/auth/internal.go index 509803c..70b4e80 100644 --- a/auth/internal.go +++ b/auth/internal.go @@ -16,12 +16,12 @@ func NewInternal() PlainAuthenticator { func (internal) AuthPlain(ctx context.Context, db database.Database, username, password string) error { u, err := db.GetUser(ctx, username) if err != nil { - return fmt.Errorf("user not found: %w", err) + return newInvalidCredentialsError(fmt.Errorf("user not found: %w", err)) } upgraded, err := u.CheckPassword(password) if err != nil { - return err + return newInvalidCredentialsError(err) } if upgraded { diff --git a/auth/oauth2.go b/auth/oauth2.go index 4485ebe..2ae519a 100644 --- a/auth/oauth2.go +++ b/auth/oauth2.go @@ -88,7 +88,7 @@ func (auth *oauth2) AuthPlain(ctx context.Context, db database.Database, usernam } if username != effectiveUsername { - return fmt.Errorf("username mismatch (OAuth 2.0 server returned %q)", effectiveUsername) + return newInvalidCredentialsError(fmt.Errorf("username mismatch (OAuth 2.0 server returned %q)", effectiveUsername)) } return nil @@ -127,7 +127,7 @@ func (auth *oauth2) AuthOAuthBearer(ctx context.Context, db database.Database, t } if !data.Active { - return "", fmt.Errorf("invalid access token") + return "", newInvalidCredentialsError(fmt.Errorf("invalid access token")) } if data.Username == "" { // We really need the username here, otherwise an OAuth 2.0 user can diff --git a/auth/pam.go b/auth/pam.go index 21c14de..f0e5aa4 100644 --- a/auth/pam.go +++ b/auth/pam.go @@ -37,7 +37,7 @@ func (pamAuth) AuthPlain(ctx context.Context, db database.Database, username, pa } if err := t.Authenticate(0); err != nil { - return fmt.Errorf("PAM auth error: %v", err) + return newInvalidCredentialsError(fmt.Errorf("PAM auth error: %v", err)) } if err := t.AcctMgmt(0); err != nil { diff --git a/downstream.go b/downstream.go index d7833bf..467f5be 100644 --- a/downstream.go +++ b/downstream.go @@ -67,40 +67,16 @@ func newChatHistoryError(subcommand string, target string) ircError { }} } -// authError is an authentication error. -type authError struct { - // Internal error cause. This will not be revealed to the user. - err error - // Error cause which can safely be sent to the user without compromising - // security. - reason string -} - -func (err *authError) Error() string { - return err.err.Error() -} - -func (err *authError) Unwrap() error { - return err.err -} - // authErrorReason returns the user-friendly reason of an authentication // failure. func authErrorReason(err error) string { - if authErr, ok := err.(*authError); ok { - return authErr.reason + if authErr, ok := err.(*auth.Error); ok { + return authErr.ExternalMsg } else { return "Authentication failed" } } -func newInvalidUsernameOrPasswordError(err error) error { - return &authError{ - err: err, - reason: "Invalid username or password", - } -} - func parseBouncerNetID(subcommand, s string) (int64, error) { id, err := strconv.ParseInt(s, 10, 64) if err != nil || id <= 0 { @@ -690,8 +666,7 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir break } - if authErr := auth.AuthPlain(ctx, dc.srv.db, username, password); authErr != nil { - err = newInvalidUsernameOrPasswordError(authErr) + if err = auth.AuthPlain(ctx, dc.srv.db, username, password); err != nil { break } case "OAUTHBEARER": @@ -701,15 +676,14 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir break } - var authErr error - username, authErr = auth.AuthOAuthBearer(ctx, dc.srv.db, credentials.oauthBearer.Token) - if authErr != nil { - err = newInvalidUsernameOrPasswordError(authErr) + username, err = auth.AuthOAuthBearer(ctx, dc.srv.db, credentials.oauthBearer.Token) + if err != nil { break } if credentials.oauthBearer.Username != "" && credentials.oauthBearer.Username != username { - err = newInvalidUsernameOrPasswordError(fmt.Errorf("username mismatch (server returned %q)", username)) + err = fmt.Errorf("username mismatch (server returned %q)", username) + break } default: panic(fmt.Errorf("unexpected SASL mechanism %q", credentials.mechanism)) @@ -1292,7 +1266,7 @@ func (dc *downstreamConn) authenticate(ctx context.Context, username, password s } if err := plainAuth.AuthPlain(ctx, dc.srv.db, username, password); err != nil { - return newInvalidUsernameOrPasswordError(err) + return err } return dc.setUser(username, clientName, networkName)