Add support for SASL OAUTHBEARER

This commit is contained in:
Simon Ser 2022-10-14 10:44:32 +02:00
parent 22a88079c2
commit c79fc0c19e
5 changed files with 141 additions and 33 deletions

View File

@ -7,11 +7,17 @@ import (
"git.sr.ht/~emersion/soju/database" "git.sr.ht/~emersion/soju/database"
) )
type Authenticator interface{}
type PlainAuthenticator interface { type PlainAuthenticator interface {
AuthPlain(ctx context.Context, db database.Database, username, password string) error AuthPlain(ctx context.Context, db database.Database, username, password string) error
} }
func New(driver, source string) (PlainAuthenticator, error) { type OAuthBearerAuthenticator interface {
AuthOAuthBearer(ctx context.Context, db database.Database, token string) (username string, err error)
}
func New(driver, source string) (Authenticator, error) {
switch driver { switch driver {
case "internal": case "internal":
return NewInternal(), nil return NewInternal(), nil

View File

@ -19,7 +19,12 @@ type oauth2 struct {
clientSecret string clientSecret string
} }
func newOAuth2(authURL string) (PlainAuthenticator, error) { var (
_ PlainAuthenticator = (*oauth2)(nil)
_ OAuthBearerAuthenticator = (*oauth2)(nil)
)
func newOAuth2(authURL string) (Authenticator, error) {
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second) ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
defer cancel() defer cancel()
@ -77,14 +82,27 @@ func newOAuth2(authURL string) (PlainAuthenticator, error) {
} }
func (auth *oauth2) AuthPlain(ctx context.Context, db database.Database, username, password string) error { func (auth *oauth2) AuthPlain(ctx context.Context, db database.Database, username, password string) error {
effectiveUsername, err := auth.AuthOAuthBearer(ctx, db, password)
if err != nil {
return err
}
if username != effectiveUsername {
return fmt.Errorf("username mismatch (OAuth 2.0 server returned %q)", effectiveUsername)
}
return nil
}
func (auth *oauth2) AuthOAuthBearer(ctx context.Context, db database.Database, token string) (username string, err error) {
reqValues := make(url.Values) reqValues := make(url.Values)
reqValues.Set("token", password) reqValues.Set("token", token)
reqBody := strings.NewReader(reqValues.Encode()) reqBody := strings.NewReader(reqValues.Encode())
req, err := http.NewRequestWithContext(ctx, http.MethodPost, auth.introspectionURL.String(), reqBody) req, err := http.NewRequestWithContext(ctx, http.MethodPost, auth.introspectionURL.String(), reqBody)
if err != nil { if err != nil {
return fmt.Errorf("failed to create OAuth 2.0 introspection request: %v", err) return "", fmt.Errorf("failed to create OAuth 2.0 introspection request: %v", err)
} }
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json") req.Header.Set("Accept", "application/json")
@ -95,32 +113,29 @@ func (auth *oauth2) AuthPlain(ctx context.Context, db database.Database, usernam
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
return fmt.Errorf("failed to send OAuth 2.0 introspection request: %v", err) return "", fmt.Errorf("failed to send OAuth 2.0 introspection request: %v", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return fmt.Errorf("OAuth 2.0 introspection error: %v", resp.Status) return "", fmt.Errorf("OAuth 2.0 introspection error: %v", resp.Status)
} }
var data oauth2Introspection var data oauth2Introspection
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
return fmt.Errorf("failed to decode OAuth 2.0 introspection response: %v", err) return "", fmt.Errorf("failed to decode OAuth 2.0 introspection response: %v", err)
} }
if !data.Active { if !data.Active {
return fmt.Errorf("invalid access token") return "", fmt.Errorf("invalid access token")
} }
if data.Username == "" { if data.Username == "" {
// We really need the username here, otherwise an OAuth 2.0 user can // We really need the username here, otherwise an OAuth 2.0 user can
// impersonate any other user. // impersonate any other user.
return fmt.Errorf("missing username in OAuth 2.0 introspection response") return "", fmt.Errorf("missing username in OAuth 2.0 introspection response")
}
if username != data.Username {
return fmt.Errorf("username mismatch (OAuth 2.0 server returned %q)", data.Username)
} }
return nil return data.Username, nil
} }
type oauth2Introspection struct { type oauth2Introspection struct {

View File

@ -17,6 +17,7 @@ import (
"github.com/emersion/go-sasl" "github.com/emersion/go-sasl"
"gopkg.in/irc.v4" "gopkg.in/irc.v4"
"git.sr.ht/~emersion/soju/auth"
"git.sr.ht/~emersion/soju/database" "git.sr.ht/~emersion/soju/database"
"git.sr.ht/~emersion/soju/msgstore" "git.sr.ht/~emersion/soju/msgstore"
"git.sr.ht/~emersion/soju/xirc" "git.sr.ht/~emersion/soju/xirc"
@ -310,10 +311,16 @@ var passthroughIsupport = map[string]bool{
"WHOX": true, "WHOX": true,
} }
type saslPlain struct {
Username, Password string
}
type downstreamSASL struct { type downstreamSASL struct {
server sasl.Server server sasl.Server
plainUsername, plainPassword string mechanism string
pendingResp bytes.Buffer plain *saslPlain
oauthBearer *sasl.OAuthBearerOptions
pendingResp bytes.Buffer
} }
type downstreamRegistration struct { type downstreamRegistration struct {
@ -327,6 +334,17 @@ type downstreamRegistration struct {
negotiatingCaps bool negotiatingCaps bool
} }
func serverSASLMechanisms(srv *Server) []string {
var l []string
if _, ok := srv.Config().Auth.(auth.PlainAuthenticator); ok {
l = append(l, "PLAIN")
}
if _, ok := srv.Config().Auth.(auth.OAuthBearerAuthenticator); ok {
l = append(l, "OAUTHBEARER")
}
return l
}
type downstreamConn struct { type downstreamConn struct {
conn conn
@ -379,7 +397,7 @@ func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn {
for k, v := range permanentDownstreamCaps { for k, v := range permanentDownstreamCaps {
dc.caps.Available[k] = v dc.caps.Available[k] = v
} }
dc.caps.Available["sasl"] = "PLAIN" dc.caps.Available["sasl"] = strings.Join(serverSASLMechanisms(dc.srv), ",")
// TODO: this is racy, we should only enable chathistory after // TODO: this is racy, we should only enable chathistory after
// authentication and then check that user.msgStore implements // authentication and then check that user.msgStore implements
// chatHistoryMessageStore // chatHistoryMessageStore
@ -659,8 +677,52 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir
break break
} }
if err := dc.authenticate(ctx, credentials.plainUsername, credentials.plainPassword); err != nil { var username, clientName, networkName string
dc.logger.Printf("SASL authentication error for user %q: %v", credentials.plainUsername, err) switch credentials.mechanism {
case "PLAIN":
username, clientName, networkName = unmarshalUsername(credentials.plain.Username)
password := credentials.plain.Password
auth, ok := dc.srv.Config().Auth.(auth.PlainAuthenticator)
if !ok {
err = fmt.Errorf("SASL PLAIN not supported")
break
}
if authErr := auth.AuthPlain(ctx, dc.srv.db, username, password); authErr != nil {
err = newInvalidUsernameOrPasswordError(authErr)
break
}
case "OAUTHBEARER":
auth, ok := dc.srv.Config().Auth.(auth.OAuthBearerAuthenticator)
if !ok {
err = fmt.Errorf("SASL OAUTHBEARER not supported")
break
}
var authErr error
username, authErr = auth.AuthOAuthBearer(ctx, dc.srv.db, credentials.oauthBearer.Token)
if authErr != nil {
err = newInvalidUsernameOrPasswordError(authErr)
break
}
if credentials.oauthBearer.Username != "" && credentials.oauthBearer.Username != username {
err = newInvalidUsernameOrPasswordError(fmt.Errorf("username mismatch (server returned %q)", username))
}
default:
panic(fmt.Errorf("unexpected SASL mechanism %q", credentials.mechanism))
}
if err == nil {
if username == "" {
panic(fmt.Errorf("username unset after SASL authentication"))
}
err = dc.setUser(username, clientName, networkName)
}
if err != nil {
dc.logger.Printf("SASL %v authentication error for nick %q: %v", credentials.mechanism, dc.nick, err)
dc.endSASL(&irc.Message{ dc.endSASL(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.ERR_SASLFAIL, Command: irc.ERR_SASLFAIL,
@ -878,8 +940,15 @@ func (dc *downstreamConn) handleAuthenticateCommand(msg *irc.Message) (result *d
switch mech { switch mech {
case "PLAIN": case "PLAIN":
server = sasl.NewPlainServer(sasl.PlainAuthenticator(func(identity, username, password string) error { server = sasl.NewPlainServer(sasl.PlainAuthenticator(func(identity, username, password string) error {
dc.sasl.plainUsername = username dc.sasl.plain = &saslPlain{
dc.sasl.plainPassword = password Username: username,
Password: password,
}
return nil
}))
case "OAUTHBEARER":
server = sasl.NewOAuthBearerServer(sasl.OAuthBearerAuthenticator(func(options sasl.OAuthBearerOptions) *sasl.OAuthBearerError {
dc.sasl.oauthBearer = &options
return nil return nil
})) }))
default: default:
@ -890,7 +959,7 @@ func (dc *downstreamConn) handleAuthenticateCommand(msg *irc.Message) (result *d
}} }}
} }
dc.sasl = &downstreamSASL{server: server} dc.sasl = &downstreamSASL{server: server, mechanism: mech}
} else { } else {
chunk := msg.Params[0] chunk := msg.Params[0]
if chunk == "+" { if chunk == "+" {
@ -1189,13 +1258,7 @@ func unmarshalUsername(rawUsername string) (username, client, network string) {
return username, client, network return username, client, network
} }
func (dc *downstreamConn) authenticate(ctx context.Context, username, password string) error { func (dc *downstreamConn) setUser(username, clientName, networkName string) error {
username, clientName, networkName := unmarshalUsername(username)
if err := dc.srv.Config().Auth.AuthPlain(ctx, dc.srv.db, username, password); err != nil {
return newInvalidUsernameOrPasswordError(err)
}
dc.user = dc.srv.getUser(username) dc.user = dc.srv.getUser(username)
if dc.user == nil { if dc.user == nil {
return fmt.Errorf("user exists in the DB but hasn't been loaded by the bouncer -- a restart may help") return fmt.Errorf("user exists in the DB but hasn't been loaded by the bouncer -- a restart may help")
@ -1205,6 +1268,21 @@ func (dc *downstreamConn) authenticate(ctx context.Context, username, password s
return nil return nil
} }
func (dc *downstreamConn) authenticate(ctx context.Context, username, password string) error {
username, clientName, networkName := unmarshalUsername(username)
plainAuth, ok := dc.srv.Config().Auth.(auth.PlainAuthenticator)
if !ok {
return fmt.Errorf("PLAIN authentication unsupported")
}
if err := plainAuth.AuthPlain(ctx, dc.srv.db, username, password); err != nil {
return newInvalidUsernameOrPasswordError(err)
}
return dc.setUser(username, clientName, networkName)
}
func (dc *downstreamConn) register(ctx context.Context) error { func (dc *downstreamConn) register(ctx context.Context) error {
if dc.registered { if dc.registered {
panic("tried to register twice") panic("tried to register twice")
@ -2420,6 +2498,15 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
} }
if credentials != nil { if credentials != nil {
if credentials.mechanism != "PLAIN" {
dc.endSASL(&irc.Message{
Prefix: dc.srv.prefix(),
Command: irc.ERR_SASLFAIL,
Params: []string{dc.nick, "Unsupported SASL authentication mechanism"},
})
return nil
}
if uc.saslClient != nil { if uc.saslClient != nil {
dc.endSASL(&irc.Message{ dc.endSASL(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
@ -2429,8 +2516,8 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
return nil return nil
} }
uc.logger.Printf("starting post-registration SASL PLAIN authentication with username %q", credentials.plainUsername) uc.logger.Printf("starting post-registration SASL PLAIN authentication with username %q", credentials.plain.Username)
uc.saslClient = sasl.NewPlainClient("", credentials.plainUsername, credentials.plainPassword) uc.saslClient = sasl.NewPlainClient("", credentials.plain.Username, credentials.plain.Password)
uc.enqueueCommand(dc, &irc.Message{ uc.enqueueCommand(dc, &irc.Message{
Command: "AUTHENTICATE", Command: "AUTHENTICATE",
Params: []string{"PLAIN"}, Params: []string{"PLAIN"},

View File

@ -144,7 +144,7 @@ type Config struct {
UpstreamUserIPs []*net.IPNet UpstreamUserIPs []*net.IPNet
DisableInactiveUsersDelay time.Duration DisableInactiveUsersDelay time.Duration
EnableUsersOnAuth bool EnableUsersOnAuth bool
Auth auth.PlainAuthenticator Auth auth.Authenticator
} }
type Server struct { type Server struct {

View File

@ -862,7 +862,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
if dc, _ := uc.dequeueCommand("AUTHENTICATE"); dc != nil && dc.sasl != nil { if dc, _ := uc.dequeueCommand("AUTHENTICATE"); dc != nil && dc.sasl != nil {
if msg.Command == irc.RPL_SASLSUCCESS { if msg.Command == irc.RPL_SASLSUCCESS {
uc.network.autoSaveSASLPlain(ctx, dc.sasl.plainUsername, dc.sasl.plainPassword) uc.network.autoSaveSASLPlain(ctx, dc.sasl.plain.Username, dc.sasl.plain.Password)
} }
dc.endSASL(msg) dc.endSASL(msg)