Add support for SASL OAUTHBEARER
This commit is contained in:
parent
22a88079c2
commit
c79fc0c19e
@ -7,11 +7,17 @@ import (
|
||||
"git.sr.ht/~emersion/soju/database"
|
||||
)
|
||||
|
||||
type Authenticator interface{}
|
||||
|
||||
type PlainAuthenticator interface {
|
||||
AuthPlain(ctx context.Context, db database.Database, username, password string) error
|
||||
}
|
||||
|
||||
func New(driver, source string) (PlainAuthenticator, error) {
|
||||
type OAuthBearerAuthenticator interface {
|
||||
AuthOAuthBearer(ctx context.Context, db database.Database, token string) (username string, err error)
|
||||
}
|
||||
|
||||
func New(driver, source string) (Authenticator, error) {
|
||||
switch driver {
|
||||
case "internal":
|
||||
return NewInternal(), nil
|
||||
|
@ -19,7 +19,12 @@ type oauth2 struct {
|
||||
clientSecret string
|
||||
}
|
||||
|
||||
func newOAuth2(authURL string) (PlainAuthenticator, error) {
|
||||
var (
|
||||
_ PlainAuthenticator = (*oauth2)(nil)
|
||||
_ OAuthBearerAuthenticator = (*oauth2)(nil)
|
||||
)
|
||||
|
||||
func newOAuth2(authURL string) (Authenticator, error) {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@ -77,14 +82,27 @@ func newOAuth2(authURL string) (PlainAuthenticator, error) {
|
||||
}
|
||||
|
||||
func (auth *oauth2) AuthPlain(ctx context.Context, db database.Database, username, password string) error {
|
||||
effectiveUsername, err := auth.AuthOAuthBearer(ctx, db, password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if username != effectiveUsername {
|
||||
return fmt.Errorf("username mismatch (OAuth 2.0 server returned %q)", effectiveUsername)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (auth *oauth2) AuthOAuthBearer(ctx context.Context, db database.Database, token string) (username string, err error) {
|
||||
reqValues := make(url.Values)
|
||||
reqValues.Set("token", password)
|
||||
reqValues.Set("token", token)
|
||||
|
||||
reqBody := strings.NewReader(reqValues.Encode())
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, auth.introspectionURL.String(), reqBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create OAuth 2.0 introspection request: %v", err)
|
||||
return "", fmt.Errorf("failed to create OAuth 2.0 introspection request: %v", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
@ -95,32 +113,29 @@ func (auth *oauth2) AuthPlain(ctx context.Context, db database.Database, usernam
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send OAuth 2.0 introspection request: %v", err)
|
||||
return "", fmt.Errorf("failed to send OAuth 2.0 introspection request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("OAuth 2.0 introspection error: %v", resp.Status)
|
||||
return "", fmt.Errorf("OAuth 2.0 introspection error: %v", resp.Status)
|
||||
}
|
||||
|
||||
var data oauth2Introspection
|
||||
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
||||
return fmt.Errorf("failed to decode OAuth 2.0 introspection response: %v", err)
|
||||
return "", fmt.Errorf("failed to decode OAuth 2.0 introspection response: %v", err)
|
||||
}
|
||||
|
||||
if !data.Active {
|
||||
return fmt.Errorf("invalid access token")
|
||||
return "", fmt.Errorf("invalid access token")
|
||||
}
|
||||
if data.Username == "" {
|
||||
// We really need the username here, otherwise an OAuth 2.0 user can
|
||||
// impersonate any other user.
|
||||
return fmt.Errorf("missing username in OAuth 2.0 introspection response")
|
||||
}
|
||||
if username != data.Username {
|
||||
return fmt.Errorf("username mismatch (OAuth 2.0 server returned %q)", data.Username)
|
||||
return "", fmt.Errorf("missing username in OAuth 2.0 introspection response")
|
||||
}
|
||||
|
||||
return nil
|
||||
return data.Username, nil
|
||||
}
|
||||
|
||||
type oauth2Introspection struct {
|
||||
|
123
downstream.go
123
downstream.go
@ -17,6 +17,7 @@ import (
|
||||
"github.com/emersion/go-sasl"
|
||||
"gopkg.in/irc.v4"
|
||||
|
||||
"git.sr.ht/~emersion/soju/auth"
|
||||
"git.sr.ht/~emersion/soju/database"
|
||||
"git.sr.ht/~emersion/soju/msgstore"
|
||||
"git.sr.ht/~emersion/soju/xirc"
|
||||
@ -310,10 +311,16 @@ var passthroughIsupport = map[string]bool{
|
||||
"WHOX": true,
|
||||
}
|
||||
|
||||
type saslPlain struct {
|
||||
Username, Password string
|
||||
}
|
||||
|
||||
type downstreamSASL struct {
|
||||
server sasl.Server
|
||||
plainUsername, plainPassword string
|
||||
pendingResp bytes.Buffer
|
||||
server sasl.Server
|
||||
mechanism string
|
||||
plain *saslPlain
|
||||
oauthBearer *sasl.OAuthBearerOptions
|
||||
pendingResp bytes.Buffer
|
||||
}
|
||||
|
||||
type downstreamRegistration struct {
|
||||
@ -327,6 +334,17 @@ type downstreamRegistration struct {
|
||||
negotiatingCaps bool
|
||||
}
|
||||
|
||||
func serverSASLMechanisms(srv *Server) []string {
|
||||
var l []string
|
||||
if _, ok := srv.Config().Auth.(auth.PlainAuthenticator); ok {
|
||||
l = append(l, "PLAIN")
|
||||
}
|
||||
if _, ok := srv.Config().Auth.(auth.OAuthBearerAuthenticator); ok {
|
||||
l = append(l, "OAUTHBEARER")
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
type downstreamConn struct {
|
||||
conn
|
||||
|
||||
@ -379,7 +397,7 @@ func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn {
|
||||
for k, v := range permanentDownstreamCaps {
|
||||
dc.caps.Available[k] = v
|
||||
}
|
||||
dc.caps.Available["sasl"] = "PLAIN"
|
||||
dc.caps.Available["sasl"] = strings.Join(serverSASLMechanisms(dc.srv), ",")
|
||||
// TODO: this is racy, we should only enable chathistory after
|
||||
// authentication and then check that user.msgStore implements
|
||||
// chatHistoryMessageStore
|
||||
@ -659,8 +677,52 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir
|
||||
break
|
||||
}
|
||||
|
||||
if err := dc.authenticate(ctx, credentials.plainUsername, credentials.plainPassword); err != nil {
|
||||
dc.logger.Printf("SASL authentication error for user %q: %v", credentials.plainUsername, err)
|
||||
var username, clientName, networkName string
|
||||
switch credentials.mechanism {
|
||||
case "PLAIN":
|
||||
username, clientName, networkName = unmarshalUsername(credentials.plain.Username)
|
||||
password := credentials.plain.Password
|
||||
|
||||
auth, ok := dc.srv.Config().Auth.(auth.PlainAuthenticator)
|
||||
if !ok {
|
||||
err = fmt.Errorf("SASL PLAIN not supported")
|
||||
break
|
||||
}
|
||||
|
||||
if authErr := auth.AuthPlain(ctx, dc.srv.db, username, password); authErr != nil {
|
||||
err = newInvalidUsernameOrPasswordError(authErr)
|
||||
break
|
||||
}
|
||||
case "OAUTHBEARER":
|
||||
auth, ok := dc.srv.Config().Auth.(auth.OAuthBearerAuthenticator)
|
||||
if !ok {
|
||||
err = fmt.Errorf("SASL OAUTHBEARER not supported")
|
||||
break
|
||||
}
|
||||
|
||||
var authErr error
|
||||
username, authErr = auth.AuthOAuthBearer(ctx, dc.srv.db, credentials.oauthBearer.Token)
|
||||
if authErr != nil {
|
||||
err = newInvalidUsernameOrPasswordError(authErr)
|
||||
break
|
||||
}
|
||||
|
||||
if credentials.oauthBearer.Username != "" && credentials.oauthBearer.Username != username {
|
||||
err = newInvalidUsernameOrPasswordError(fmt.Errorf("username mismatch (server returned %q)", username))
|
||||
}
|
||||
default:
|
||||
panic(fmt.Errorf("unexpected SASL mechanism %q", credentials.mechanism))
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
if username == "" {
|
||||
panic(fmt.Errorf("username unset after SASL authentication"))
|
||||
}
|
||||
err = dc.setUser(username, clientName, networkName)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
dc.logger.Printf("SASL %v authentication error for nick %q: %v", credentials.mechanism, dc.nick, err)
|
||||
dc.endSASL(&irc.Message{
|
||||
Prefix: dc.srv.prefix(),
|
||||
Command: irc.ERR_SASLFAIL,
|
||||
@ -878,8 +940,15 @@ func (dc *downstreamConn) handleAuthenticateCommand(msg *irc.Message) (result *d
|
||||
switch mech {
|
||||
case "PLAIN":
|
||||
server = sasl.NewPlainServer(sasl.PlainAuthenticator(func(identity, username, password string) error {
|
||||
dc.sasl.plainUsername = username
|
||||
dc.sasl.plainPassword = password
|
||||
dc.sasl.plain = &saslPlain{
|
||||
Username: username,
|
||||
Password: password,
|
||||
}
|
||||
return nil
|
||||
}))
|
||||
case "OAUTHBEARER":
|
||||
server = sasl.NewOAuthBearerServer(sasl.OAuthBearerAuthenticator(func(options sasl.OAuthBearerOptions) *sasl.OAuthBearerError {
|
||||
dc.sasl.oauthBearer = &options
|
||||
return nil
|
||||
}))
|
||||
default:
|
||||
@ -890,7 +959,7 @@ func (dc *downstreamConn) handleAuthenticateCommand(msg *irc.Message) (result *d
|
||||
}}
|
||||
}
|
||||
|
||||
dc.sasl = &downstreamSASL{server: server}
|
||||
dc.sasl = &downstreamSASL{server: server, mechanism: mech}
|
||||
} else {
|
||||
chunk := msg.Params[0]
|
||||
if chunk == "+" {
|
||||
@ -1189,13 +1258,7 @@ func unmarshalUsername(rawUsername string) (username, client, network string) {
|
||||
return username, client, network
|
||||
}
|
||||
|
||||
func (dc *downstreamConn) authenticate(ctx context.Context, username, password string) error {
|
||||
username, clientName, networkName := unmarshalUsername(username)
|
||||
|
||||
if err := dc.srv.Config().Auth.AuthPlain(ctx, dc.srv.db, username, password); err != nil {
|
||||
return newInvalidUsernameOrPasswordError(err)
|
||||
}
|
||||
|
||||
func (dc *downstreamConn) setUser(username, clientName, networkName string) error {
|
||||
dc.user = dc.srv.getUser(username)
|
||||
if dc.user == nil {
|
||||
return fmt.Errorf("user exists in the DB but hasn't been loaded by the bouncer -- a restart may help")
|
||||
@ -1205,6 +1268,21 @@ func (dc *downstreamConn) authenticate(ctx context.Context, username, password s
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dc *downstreamConn) authenticate(ctx context.Context, username, password string) error {
|
||||
username, clientName, networkName := unmarshalUsername(username)
|
||||
|
||||
plainAuth, ok := dc.srv.Config().Auth.(auth.PlainAuthenticator)
|
||||
if !ok {
|
||||
return fmt.Errorf("PLAIN authentication unsupported")
|
||||
}
|
||||
|
||||
if err := plainAuth.AuthPlain(ctx, dc.srv.db, username, password); err != nil {
|
||||
return newInvalidUsernameOrPasswordError(err)
|
||||
}
|
||||
|
||||
return dc.setUser(username, clientName, networkName)
|
||||
}
|
||||
|
||||
func (dc *downstreamConn) register(ctx context.Context) error {
|
||||
if dc.registered {
|
||||
panic("tried to register twice")
|
||||
@ -2420,6 +2498,15 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
}
|
||||
|
||||
if credentials != nil {
|
||||
if credentials.mechanism != "PLAIN" {
|
||||
dc.endSASL(&irc.Message{
|
||||
Prefix: dc.srv.prefix(),
|
||||
Command: irc.ERR_SASLFAIL,
|
||||
Params: []string{dc.nick, "Unsupported SASL authentication mechanism"},
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
if uc.saslClient != nil {
|
||||
dc.endSASL(&irc.Message{
|
||||
Prefix: dc.srv.prefix(),
|
||||
@ -2429,8 +2516,8 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
return nil
|
||||
}
|
||||
|
||||
uc.logger.Printf("starting post-registration SASL PLAIN authentication with username %q", credentials.plainUsername)
|
||||
uc.saslClient = sasl.NewPlainClient("", credentials.plainUsername, credentials.plainPassword)
|
||||
uc.logger.Printf("starting post-registration SASL PLAIN authentication with username %q", credentials.plain.Username)
|
||||
uc.saslClient = sasl.NewPlainClient("", credentials.plain.Username, credentials.plain.Password)
|
||||
uc.enqueueCommand(dc, &irc.Message{
|
||||
Command: "AUTHENTICATE",
|
||||
Params: []string{"PLAIN"},
|
||||
|
@ -144,7 +144,7 @@ type Config struct {
|
||||
UpstreamUserIPs []*net.IPNet
|
||||
DisableInactiveUsersDelay time.Duration
|
||||
EnableUsersOnAuth bool
|
||||
Auth auth.PlainAuthenticator
|
||||
Auth auth.Authenticator
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
|
@ -862,7 +862,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
|
||||
if dc, _ := uc.dequeueCommand("AUTHENTICATE"); dc != nil && dc.sasl != nil {
|
||||
if msg.Command == irc.RPL_SASLSUCCESS {
|
||||
uc.network.autoSaveSASLPlain(ctx, dc.sasl.plainUsername, dc.sasl.plainPassword)
|
||||
uc.network.autoSaveSASLPlain(ctx, dc.sasl.plain.Username, dc.sasl.plain.Password)
|
||||
}
|
||||
|
||||
dc.endSASL(msg)
|
||||
|
Loading…
Reference in New Issue
Block a user