soju/auth/oauth2.go

186 lines
5.1 KiB
Go

package auth
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"path"
"strings"
"time"
"git.sr.ht/~emersion/soju/database"
)
type oauth2 struct {
introspectionURL *url.URL
clientID string
clientSecret string
}
var (
_ PlainAuthenticator = (*oauth2)(nil)
_ OAuthBearerAuthenticator = (*oauth2)(nil)
)
func newOAuth2(authURL string) (Authenticator, error) {
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
defer cancel()
u, err := url.Parse(authURL)
if err != nil {
return nil, fmt.Errorf("failed to parse OAuth 2.0 server URL: %v", err)
}
var clientID, clientSecret string
if u.User != nil {
clientID = u.User.Username()
clientSecret, _ = u.User.Password()
}
discoveryURL := *u
discoveryURL.User = nil
discoveryURL.Path = path.Join("/.well-known/oauth-authorization-server", u.Path)
server, err := discoverOAuth2(ctx, discoveryURL.String())
if err != nil {
return nil, fmt.Errorf("OAuth 2.0 discovery failed: %v", err)
}
if server.IntrospectionEndpoint == "" {
return nil, fmt.Errorf("OAuth 2.0 server doesn't support token introspection")
}
introspectionURL, err := url.Parse(server.IntrospectionEndpoint)
if err != nil {
return nil, fmt.Errorf("failed to parse OAuth 2.0 introspection URL")
}
if server.IntrospectionEndpointAuthMethodsSupported != nil {
var supportsNone, supportsBasic bool
for _, name := range server.IntrospectionEndpointAuthMethodsSupported {
switch name {
case "none":
supportsNone = true
case "client_secret_basic":
supportsBasic = true
}
}
if clientID == "" && !supportsNone {
return nil, fmt.Errorf("OAuth 2.0 server requires authentication for introspection")
}
if clientID != "" && !supportsBasic {
return nil, fmt.Errorf("OAuth 2.0 server doesn't support Basic HTTP authentication for introspection")
}
}
return &oauth2{
introspectionURL: introspectionURL,
clientID: clientID,
clientSecret: clientSecret,
}, nil
}
func (auth *oauth2) AuthPlain(ctx context.Context, db database.Database, username, password string) error {
effectiveUsername, err := auth.AuthOAuthBearer(ctx, db, password)
if err != nil {
return err
}
if username != effectiveUsername {
return newInvalidCredentialsError(fmt.Errorf("username mismatch (OAuth 2.0 server returned %q)", effectiveUsername))
}
return nil
}
func (auth *oauth2) AuthOAuthBearer(ctx context.Context, db database.Database, token string) (username string, err error) {
reqValues := make(url.Values)
reqValues.Set("token", token)
reqBody := strings.NewReader(reqValues.Encode())
req, err := http.NewRequestWithContext(ctx, http.MethodPost, auth.introspectionURL.String(), reqBody)
if err != nil {
return "", fmt.Errorf("failed to create OAuth 2.0 introspection request: %v", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
if auth.clientID != "" {
req.SetBasicAuth(url.QueryEscape(auth.clientID), url.QueryEscape(auth.clientSecret))
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", fmt.Errorf("failed to send OAuth 2.0 introspection request: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("OAuth 2.0 introspection error: %v", resp.Status)
}
var data oauth2Introspection
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
return "", fmt.Errorf("failed to decode OAuth 2.0 introspection response: %v", err)
}
if !data.Active {
return "", newInvalidCredentialsError(fmt.Errorf("invalid access token"))
}
if data.Username == "" {
// We really need the username here, otherwise an OAuth 2.0 user can
// impersonate any other user.
return "", fmt.Errorf("missing username in OAuth 2.0 introspection response")
}
return data.Username, nil
}
type oauth2Introspection struct {
Active bool `json:"active"`
Username string `json:"username"`
}
type oauth2Server struct {
Issuer string `json:"issuer"`
IntrospectionEndpoint string `json:"introspection_endpoint"`
IntrospectionEndpointAuthMethodsSupported []string `json:"introspection_endpoint_auth_methods_supported"`
}
type oauth2HTTPError string
func (err oauth2HTTPError) Error() string {
return fmt.Sprintf("OAuth 2.0 HTTP error: %v", string(err))
}
func discoverOAuth2(ctx context.Context, discoveryURL string) (*oauth2Server, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, discoveryURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %v", err)
}
req.Header.Set("Accept", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, oauth2HTTPError(resp.Status)
}
var data oauth2Server
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
return nil, fmt.Errorf("failed to decode response: %v", err)
}
if data.Issuer == "" {
return nil, fmt.Errorf("missing issuer in response")
}
return &data, nil
}