Add support for SASL OAUTHBEARER
This commit is contained in:
parent
22a88079c2
commit
c79fc0c19e
@ -7,11 +7,17 @@ import (
|
|||||||
"git.sr.ht/~emersion/soju/database"
|
"git.sr.ht/~emersion/soju/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Authenticator interface{}
|
||||||
|
|
||||||
type PlainAuthenticator interface {
|
type PlainAuthenticator interface {
|
||||||
AuthPlain(ctx context.Context, db database.Database, username, password string) error
|
AuthPlain(ctx context.Context, db database.Database, username, password string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(driver, source string) (PlainAuthenticator, error) {
|
type OAuthBearerAuthenticator interface {
|
||||||
|
AuthOAuthBearer(ctx context.Context, db database.Database, token string) (username string, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(driver, source string) (Authenticator, error) {
|
||||||
switch driver {
|
switch driver {
|
||||||
case "internal":
|
case "internal":
|
||||||
return NewInternal(), nil
|
return NewInternal(), nil
|
||||||
|
@ -19,7 +19,12 @@ type oauth2 struct {
|
|||||||
clientSecret string
|
clientSecret string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newOAuth2(authURL string) (PlainAuthenticator, error) {
|
var (
|
||||||
|
_ PlainAuthenticator = (*oauth2)(nil)
|
||||||
|
_ OAuthBearerAuthenticator = (*oauth2)(nil)
|
||||||
|
)
|
||||||
|
|
||||||
|
func newOAuth2(authURL string) (Authenticator, error) {
|
||||||
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@ -77,14 +82,27 @@ func newOAuth2(authURL string) (PlainAuthenticator, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *oauth2) AuthPlain(ctx context.Context, db database.Database, username, password string) error {
|
func (auth *oauth2) AuthPlain(ctx context.Context, db database.Database, username, password string) error {
|
||||||
|
effectiveUsername, err := auth.AuthOAuthBearer(ctx, db, password)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if username != effectiveUsername {
|
||||||
|
return fmt.Errorf("username mismatch (OAuth 2.0 server returned %q)", effectiveUsername)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *oauth2) AuthOAuthBearer(ctx context.Context, db database.Database, token string) (username string, err error) {
|
||||||
reqValues := make(url.Values)
|
reqValues := make(url.Values)
|
||||||
reqValues.Set("token", password)
|
reqValues.Set("token", token)
|
||||||
|
|
||||||
reqBody := strings.NewReader(reqValues.Encode())
|
reqBody := strings.NewReader(reqValues.Encode())
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, auth.introspectionURL.String(), reqBody)
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, auth.introspectionURL.String(), reqBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create OAuth 2.0 introspection request: %v", err)
|
return "", fmt.Errorf("failed to create OAuth 2.0 introspection request: %v", err)
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
req.Header.Set("Accept", "application/json")
|
req.Header.Set("Accept", "application/json")
|
||||||
@ -95,32 +113,29 @@ func (auth *oauth2) AuthPlain(ctx context.Context, db database.Database, usernam
|
|||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
resp, err := http.DefaultClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to send OAuth 2.0 introspection request: %v", err)
|
return "", fmt.Errorf("failed to send OAuth 2.0 introspection request: %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return fmt.Errorf("OAuth 2.0 introspection error: %v", resp.Status)
|
return "", fmt.Errorf("OAuth 2.0 introspection error: %v", resp.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
var data oauth2Introspection
|
var data oauth2Introspection
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
||||||
return fmt.Errorf("failed to decode OAuth 2.0 introspection response: %v", err)
|
return "", fmt.Errorf("failed to decode OAuth 2.0 introspection response: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !data.Active {
|
if !data.Active {
|
||||||
return fmt.Errorf("invalid access token")
|
return "", fmt.Errorf("invalid access token")
|
||||||
}
|
}
|
||||||
if data.Username == "" {
|
if data.Username == "" {
|
||||||
// We really need the username here, otherwise an OAuth 2.0 user can
|
// We really need the username here, otherwise an OAuth 2.0 user can
|
||||||
// impersonate any other user.
|
// impersonate any other user.
|
||||||
return fmt.Errorf("missing username in OAuth 2.0 introspection response")
|
return "", fmt.Errorf("missing username in OAuth 2.0 introspection response")
|
||||||
}
|
|
||||||
if username != data.Username {
|
|
||||||
return fmt.Errorf("username mismatch (OAuth 2.0 server returned %q)", data.Username)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return data.Username, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type oauth2Introspection struct {
|
type oauth2Introspection struct {
|
||||||
|
123
downstream.go
123
downstream.go
@ -17,6 +17,7 @@ import (
|
|||||||
"github.com/emersion/go-sasl"
|
"github.com/emersion/go-sasl"
|
||||||
"gopkg.in/irc.v4"
|
"gopkg.in/irc.v4"
|
||||||
|
|
||||||
|
"git.sr.ht/~emersion/soju/auth"
|
||||||
"git.sr.ht/~emersion/soju/database"
|
"git.sr.ht/~emersion/soju/database"
|
||||||
"git.sr.ht/~emersion/soju/msgstore"
|
"git.sr.ht/~emersion/soju/msgstore"
|
||||||
"git.sr.ht/~emersion/soju/xirc"
|
"git.sr.ht/~emersion/soju/xirc"
|
||||||
@ -310,10 +311,16 @@ var passthroughIsupport = map[string]bool{
|
|||||||
"WHOX": true,
|
"WHOX": true,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type saslPlain struct {
|
||||||
|
Username, Password string
|
||||||
|
}
|
||||||
|
|
||||||
type downstreamSASL struct {
|
type downstreamSASL struct {
|
||||||
server sasl.Server
|
server sasl.Server
|
||||||
plainUsername, plainPassword string
|
mechanism string
|
||||||
pendingResp bytes.Buffer
|
plain *saslPlain
|
||||||
|
oauthBearer *sasl.OAuthBearerOptions
|
||||||
|
pendingResp bytes.Buffer
|
||||||
}
|
}
|
||||||
|
|
||||||
type downstreamRegistration struct {
|
type downstreamRegistration struct {
|
||||||
@ -327,6 +334,17 @@ type downstreamRegistration struct {
|
|||||||
negotiatingCaps bool
|
negotiatingCaps bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func serverSASLMechanisms(srv *Server) []string {
|
||||||
|
var l []string
|
||||||
|
if _, ok := srv.Config().Auth.(auth.PlainAuthenticator); ok {
|
||||||
|
l = append(l, "PLAIN")
|
||||||
|
}
|
||||||
|
if _, ok := srv.Config().Auth.(auth.OAuthBearerAuthenticator); ok {
|
||||||
|
l = append(l, "OAUTHBEARER")
|
||||||
|
}
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
type downstreamConn struct {
|
type downstreamConn struct {
|
||||||
conn
|
conn
|
||||||
|
|
||||||
@ -379,7 +397,7 @@ func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn {
|
|||||||
for k, v := range permanentDownstreamCaps {
|
for k, v := range permanentDownstreamCaps {
|
||||||
dc.caps.Available[k] = v
|
dc.caps.Available[k] = v
|
||||||
}
|
}
|
||||||
dc.caps.Available["sasl"] = "PLAIN"
|
dc.caps.Available["sasl"] = strings.Join(serverSASLMechanisms(dc.srv), ",")
|
||||||
// TODO: this is racy, we should only enable chathistory after
|
// TODO: this is racy, we should only enable chathistory after
|
||||||
// authentication and then check that user.msgStore implements
|
// authentication and then check that user.msgStore implements
|
||||||
// chatHistoryMessageStore
|
// chatHistoryMessageStore
|
||||||
@ -659,8 +677,52 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := dc.authenticate(ctx, credentials.plainUsername, credentials.plainPassword); err != nil {
|
var username, clientName, networkName string
|
||||||
dc.logger.Printf("SASL authentication error for user %q: %v", credentials.plainUsername, err)
|
switch credentials.mechanism {
|
||||||
|
case "PLAIN":
|
||||||
|
username, clientName, networkName = unmarshalUsername(credentials.plain.Username)
|
||||||
|
password := credentials.plain.Password
|
||||||
|
|
||||||
|
auth, ok := dc.srv.Config().Auth.(auth.PlainAuthenticator)
|
||||||
|
if !ok {
|
||||||
|
err = fmt.Errorf("SASL PLAIN not supported")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if authErr := auth.AuthPlain(ctx, dc.srv.db, username, password); authErr != nil {
|
||||||
|
err = newInvalidUsernameOrPasswordError(authErr)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
case "OAUTHBEARER":
|
||||||
|
auth, ok := dc.srv.Config().Auth.(auth.OAuthBearerAuthenticator)
|
||||||
|
if !ok {
|
||||||
|
err = fmt.Errorf("SASL OAUTHBEARER not supported")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
var authErr error
|
||||||
|
username, authErr = auth.AuthOAuthBearer(ctx, dc.srv.db, credentials.oauthBearer.Token)
|
||||||
|
if authErr != nil {
|
||||||
|
err = newInvalidUsernameOrPasswordError(authErr)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if credentials.oauthBearer.Username != "" && credentials.oauthBearer.Username != username {
|
||||||
|
err = newInvalidUsernameOrPasswordError(fmt.Errorf("username mismatch (server returned %q)", username))
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
panic(fmt.Errorf("unexpected SASL mechanism %q", credentials.mechanism))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
if username == "" {
|
||||||
|
panic(fmt.Errorf("username unset after SASL authentication"))
|
||||||
|
}
|
||||||
|
err = dc.setUser(username, clientName, networkName)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
dc.logger.Printf("SASL %v authentication error for nick %q: %v", credentials.mechanism, dc.nick, err)
|
||||||
dc.endSASL(&irc.Message{
|
dc.endSASL(&irc.Message{
|
||||||
Prefix: dc.srv.prefix(),
|
Prefix: dc.srv.prefix(),
|
||||||
Command: irc.ERR_SASLFAIL,
|
Command: irc.ERR_SASLFAIL,
|
||||||
@ -878,8 +940,15 @@ func (dc *downstreamConn) handleAuthenticateCommand(msg *irc.Message) (result *d
|
|||||||
switch mech {
|
switch mech {
|
||||||
case "PLAIN":
|
case "PLAIN":
|
||||||
server = sasl.NewPlainServer(sasl.PlainAuthenticator(func(identity, username, password string) error {
|
server = sasl.NewPlainServer(sasl.PlainAuthenticator(func(identity, username, password string) error {
|
||||||
dc.sasl.plainUsername = username
|
dc.sasl.plain = &saslPlain{
|
||||||
dc.sasl.plainPassword = password
|
Username: username,
|
||||||
|
Password: password,
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
case "OAUTHBEARER":
|
||||||
|
server = sasl.NewOAuthBearerServer(sasl.OAuthBearerAuthenticator(func(options sasl.OAuthBearerOptions) *sasl.OAuthBearerError {
|
||||||
|
dc.sasl.oauthBearer = &options
|
||||||
return nil
|
return nil
|
||||||
}))
|
}))
|
||||||
default:
|
default:
|
||||||
@ -890,7 +959,7 @@ func (dc *downstreamConn) handleAuthenticateCommand(msg *irc.Message) (result *d
|
|||||||
}}
|
}}
|
||||||
}
|
}
|
||||||
|
|
||||||
dc.sasl = &downstreamSASL{server: server}
|
dc.sasl = &downstreamSASL{server: server, mechanism: mech}
|
||||||
} else {
|
} else {
|
||||||
chunk := msg.Params[0]
|
chunk := msg.Params[0]
|
||||||
if chunk == "+" {
|
if chunk == "+" {
|
||||||
@ -1189,13 +1258,7 @@ func unmarshalUsername(rawUsername string) (username, client, network string) {
|
|||||||
return username, client, network
|
return username, client, network
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dc *downstreamConn) authenticate(ctx context.Context, username, password string) error {
|
func (dc *downstreamConn) setUser(username, clientName, networkName string) error {
|
||||||
username, clientName, networkName := unmarshalUsername(username)
|
|
||||||
|
|
||||||
if err := dc.srv.Config().Auth.AuthPlain(ctx, dc.srv.db, username, password); err != nil {
|
|
||||||
return newInvalidUsernameOrPasswordError(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
dc.user = dc.srv.getUser(username)
|
dc.user = dc.srv.getUser(username)
|
||||||
if dc.user == nil {
|
if dc.user == nil {
|
||||||
return fmt.Errorf("user exists in the DB but hasn't been loaded by the bouncer -- a restart may help")
|
return fmt.Errorf("user exists in the DB but hasn't been loaded by the bouncer -- a restart may help")
|
||||||
@ -1205,6 +1268,21 @@ func (dc *downstreamConn) authenticate(ctx context.Context, username, password s
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (dc *downstreamConn) authenticate(ctx context.Context, username, password string) error {
|
||||||
|
username, clientName, networkName := unmarshalUsername(username)
|
||||||
|
|
||||||
|
plainAuth, ok := dc.srv.Config().Auth.(auth.PlainAuthenticator)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("PLAIN authentication unsupported")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := plainAuth.AuthPlain(ctx, dc.srv.db, username, password); err != nil {
|
||||||
|
return newInvalidUsernameOrPasswordError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return dc.setUser(username, clientName, networkName)
|
||||||
|
}
|
||||||
|
|
||||||
func (dc *downstreamConn) register(ctx context.Context) error {
|
func (dc *downstreamConn) register(ctx context.Context) error {
|
||||||
if dc.registered {
|
if dc.registered {
|
||||||
panic("tried to register twice")
|
panic("tried to register twice")
|
||||||
@ -2420,6 +2498,15 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
|||||||
}
|
}
|
||||||
|
|
||||||
if credentials != nil {
|
if credentials != nil {
|
||||||
|
if credentials.mechanism != "PLAIN" {
|
||||||
|
dc.endSASL(&irc.Message{
|
||||||
|
Prefix: dc.srv.prefix(),
|
||||||
|
Command: irc.ERR_SASLFAIL,
|
||||||
|
Params: []string{dc.nick, "Unsupported SASL authentication mechanism"},
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
if uc.saslClient != nil {
|
if uc.saslClient != nil {
|
||||||
dc.endSASL(&irc.Message{
|
dc.endSASL(&irc.Message{
|
||||||
Prefix: dc.srv.prefix(),
|
Prefix: dc.srv.prefix(),
|
||||||
@ -2429,8 +2516,8 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
uc.logger.Printf("starting post-registration SASL PLAIN authentication with username %q", credentials.plainUsername)
|
uc.logger.Printf("starting post-registration SASL PLAIN authentication with username %q", credentials.plain.Username)
|
||||||
uc.saslClient = sasl.NewPlainClient("", credentials.plainUsername, credentials.plainPassword)
|
uc.saslClient = sasl.NewPlainClient("", credentials.plain.Username, credentials.plain.Password)
|
||||||
uc.enqueueCommand(dc, &irc.Message{
|
uc.enqueueCommand(dc, &irc.Message{
|
||||||
Command: "AUTHENTICATE",
|
Command: "AUTHENTICATE",
|
||||||
Params: []string{"PLAIN"},
|
Params: []string{"PLAIN"},
|
||||||
|
@ -144,7 +144,7 @@ type Config struct {
|
|||||||
UpstreamUserIPs []*net.IPNet
|
UpstreamUserIPs []*net.IPNet
|
||||||
DisableInactiveUsersDelay time.Duration
|
DisableInactiveUsersDelay time.Duration
|
||||||
EnableUsersOnAuth bool
|
EnableUsersOnAuth bool
|
||||||
Auth auth.PlainAuthenticator
|
Auth auth.Authenticator
|
||||||
}
|
}
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
|
@ -862,7 +862,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
|||||||
|
|
||||||
if dc, _ := uc.dequeueCommand("AUTHENTICATE"); dc != nil && dc.sasl != nil {
|
if dc, _ := uc.dequeueCommand("AUTHENTICATE"); dc != nil && dc.sasl != nil {
|
||||||
if msg.Command == irc.RPL_SASLSUCCESS {
|
if msg.Command == irc.RPL_SASLSUCCESS {
|
||||||
uc.network.autoSaveSASLPlain(ctx, dc.sasl.plainUsername, dc.sasl.plainPassword)
|
uc.network.autoSaveSASLPlain(ctx, dc.sasl.plain.Username, dc.sasl.plain.Password)
|
||||||
}
|
}
|
||||||
|
|
||||||
dc.endSASL(msg)
|
dc.endSASL(msg)
|
||||||
|
Loading…
Reference in New Issue
Block a user