diff --git a/cmd/dex/config.go b/cmd/dex/config.go index 7bb7fbb780..831156fd40 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -129,6 +129,10 @@ func (p *password) UnmarshalJSON(b []byte) error { // OAuth2 describes enabled OAuth2 extensions. type OAuth2 struct { + // list of allowed grant types, + // defaults to all supported types + GrantTypes []string `json:"grantTypes"` + ResponseTypes []string `json:"responseTypes"` // If specified, do not prompt the user to approve client authorization. The // act of logging in implies authorization. diff --git a/cmd/dex/config_test.go b/cmd/dex/config_test.go index 8ee02d5aa2..2103708478 100644 --- a/cmd/dex/config_test.go +++ b/cmd/dex/config_test.go @@ -87,6 +87,9 @@ staticClients: oauth2: alwaysShowLoginScreen: true + grantTypes: + - refresh_token + - "urn:ietf:params:oauth:grant-type:token-exchange" connectors: - type: mockCallback @@ -161,6 +164,10 @@ logger: }, OAuth2: OAuth2{ AlwaysShowLoginScreen: true, + GrantTypes: []string{ + "refresh_token", + "urn:ietf:params:oauth:grant-type:token-exchange", + }, }, StaticConnectors: []Connector{ { diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index 70906a866d..47b090aeab 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -259,6 +259,7 @@ func runServe(options serveOptions) error { healthChecker := gosundheit.New() serverConfig := server.Config{ + AllowedGrantTypes: c.OAuth2.GrantTypes, SupportedResponseTypes: c.OAuth2.ResponseTypes, SkipApprovalScreen: c.OAuth2.SkipApprovalScreen, AlwaysShowLoginScreen: c.OAuth2.AlwaysShowLoginScreen, @@ -550,6 +551,17 @@ func applyConfigOverrides(options serveOptions, config *Config) { if config.Frontend.Dir == "" { config.Frontend.Dir = os.Getenv("DEX_FRONTEND_DIR") } + + if len(config.OAuth2.GrantTypes) == 0 { + config.OAuth2.GrantTypes = []string{ + "authorization_code", + "implicit", + "password", + "refresh_token", + "urn:ietf:params:oauth:grant-type:device_code", + "urn:ietf:params:oauth:grant-type:token-exchange", + } + } } func pprofHandler(router *http.ServeMux) { diff --git a/connector/connector.go b/connector/connector.go index e4cf58c0ae..d812390f0c 100644 --- a/connector/connector.go +++ b/connector/connector.go @@ -99,3 +99,7 @@ type RefreshConnector interface { // changes since the token was last refreshed. Refresh(ctx context.Context, s Scopes, identity Identity) (Identity, error) } + +type TokenIdentityConnector interface { + TokenIdentity(ctx context.Context, subjectTokenType, subjectToken string) (Identity, error) +} diff --git a/connector/mock/connectortest.go b/connector/mock/connectortest.go index e7ee438625..e97f986574 100644 --- a/connector/mock/connectortest.go +++ b/connector/mock/connectortest.go @@ -66,6 +66,10 @@ func (m *Callback) Refresh(ctx context.Context, s connector.Scopes, identity con return m.Identity, nil } +func (m *Callback) TokenIdentity(ctx context.Context, subjectTokenType, subjectToken string) (connector.Identity, error) { + return m.Identity, nil +} + // CallbackConfig holds the configuration parameters for a connector which requires no interaction. type CallbackConfig struct{} diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index b38915e303..14329c0040 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -258,6 +258,7 @@ type caller uint const ( createCaller caller = iota refreshCaller + exchangeCaller ) func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { @@ -296,16 +297,32 @@ func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identit return c.createIdentity(ctx, identity, token, refreshCaller) } +func (c *oidcConnector) TokenIdentity(ctx context.Context, subjectTokenType, subjectToken string) (connector.Identity, error) { + var identity connector.Identity + token := &oauth2.Token{ + AccessToken: subjectToken, + } + return c.createIdentity(ctx, identity, token, exchangeCaller) +} + func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.Identity, token *oauth2.Token, caller caller) (connector.Identity, error) { var claims map[string]interface{} - rawIDToken, ok := token.Extra("id_token").(string) - if ok { + if rawIDToken, ok := token.Extra("id_token").(string); ok { idToken, err := c.verifier.Verify(ctx, rawIDToken) if err != nil { return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err) } + if err := idToken.Claims(&claims); err != nil { + return identity, fmt.Errorf("oidc: failed to decode claims: %v", err) + } + } else if caller == exchangeCaller { + // AccessToken here could be either an id token or an access token + idToken, err := c.provider.Verifier(&oidc.Config{SkipClientIDCheck: true}).Verify(ctx, token.AccessToken) + if err != nil { + return identity, fmt.Errorf("oidc: failed to verify token: %v", err) + } if err := idToken.Claims(&claims); err != nil { return identity, fmt.Errorf("oidc: failed to decode claims: %v", err) } diff --git a/connector/oidc/oidc_test.go b/connector/oidc/oidc_test.go index d94af79de8..5c5208a60e 100644 --- a/connector/oidc/oidc_test.go +++ b/connector/oidc/oidc_test.go @@ -2,6 +2,7 @@ package oidc import ( "bytes" + "context" "crypto/rand" "crypto/rsa" "encoding/base64" @@ -428,6 +429,81 @@ func TestRefresh(t *testing.T) { } } +func TestTokenIdentity(t *testing.T) { + tokenTypeAccess := "urn:ietf:params:oauth:token-type:access_token" + tokenTypeID := "urn:ietf:params:oauth:token-type:id_token" + long2short := map[string]string{ + tokenTypeAccess: "access_token", + tokenTypeID: "id_token", + } + + tests := []struct { + name string + subjectType string + userInfo bool + }{ + { + name: "id_token", + subjectType: tokenTypeID, + }, { + name: "access_token", + subjectType: tokenTypeAccess, + }, { + name: "id_token with user info", + subjectType: tokenTypeID, + userInfo: true, + }, { + name: "access_token with user info", + subjectType: tokenTypeAccess, + userInfo: true, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + testServer, err := setupServer(map[string]any{ + "sub": "subvalue", + "name": "namevalue", + }, true) + if err != nil { + t.Fatal("failed to setup test server", err) + } + conn, err := newConnector(Config{ + Issuer: testServer.URL, + Scopes: []string{"openid", "groups"}, + GetUserInfo: tc.userInfo, + }) + if err != nil { + t.Fatal("failed to create new connector", err) + } + + res, err := http.Get(testServer.URL + "/token") + if err != nil { + t.Fatal("failed to get initial token", err) + } + defer res.Body.Close() + var tokenResponse map[string]any + err = json.NewDecoder(res.Body).Decode(&tokenResponse) + if err != nil { + t.Fatal("failed to decode initial token", err) + } + + origToken := tokenResponse[long2short[tc.subjectType]].(string) + identity, err := conn.TokenIdentity(ctx, tc.subjectType, origToken) + if err != nil { + t.Fatal("failed to get token identity", err) + } + + // assert identity + expectEquals(t, identity.UserID, "subvalue") + expectEquals(t, identity.Username, "namevalue") + }) + } +} + func setupServer(tok map[string]interface{}, idTokenDesired bool) (*httptest.Server, error) { key, err := rsa.GenerateKey(rand.Reader, 1024) if err != nil { diff --git a/server/handlers.go b/server/handlers.go old mode 100755 new mode 100644 index 08004c6d0e..9438d8072b --- a/server/handlers.go +++ b/server/handlers.go @@ -710,7 +710,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe implicitOrHybrid = true var err error - accessToken, err = s.newAccessToken(authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, authReq.ConnectorID) + accessToken, _, err = s.newAccessToken(authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, authReq.ConnectorID) if err != nil { s.logger.Errorf("failed to create new access token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) @@ -830,6 +830,11 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) { } grantType := r.PostFormValue("grant_type") + if !contains(s.supportedGrantTypes, grantType) { + s.logger.Errorf("unsupported grant type: %v", grantType) + s.tokenErrHelper(w, errUnsupportedGrantType, "", http.StatusBadRequest) + return + } switch grantType { case grantTypeDeviceCode: s.handleDeviceToken(w, r) @@ -839,6 +844,8 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) { s.withClientFromStorage(w, r, s.handleRefreshToken) case grantTypePassword: s.withClientFromStorage(w, r, s.handlePasswordGrant) + case grantTypeTokenExchange: + s.withClientFromStorage(w, r, s.handleTokenExchange) default: s.tokenErrHelper(w, errUnsupportedGrantType, "", http.StatusBadRequest) } @@ -917,7 +924,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s } func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCode, client storage.Client) (*accessTokenResponse, error) { - accessToken, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID) + accessToken, _, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID) if err != nil { s.logger.Errorf("failed to create new access token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) @@ -1180,7 +1187,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli Groups: identity.Groups, } - accessToken, err := s.newAccessToken(client.ID, claims, scopes, nonce, connID) + accessToken, _, err := s.newAccessToken(client.ID, claims, scopes, nonce, connID) if err != nil { s.logger.Errorf("password grant failed to create new access token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) @@ -1319,21 +1326,109 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli s.writeAccessToken(w, resp) } +func (s *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request, client storage.Client) { + ctx := r.Context() + + if err := r.ParseForm(); err != nil { + s.logger.Errorf("could not parse request body: %v", err) + s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest) + return + } + q := r.Form + + scopes := strings.Fields(q.Get("scope")) // OPTIONAL, map to issued token scope + requestedTokenType := q.Get("requested_token_type") // OPTIONAL, default to access token + if requestedTokenType == "" { + requestedTokenType = tokenTypeAccess + } + subjectToken := q.Get("subject_token") // REQUIRED + subjectTokenType := q.Get("subject_token_type") // REQUIRED + connID := q.Get("connector_id") // REQUIRED, not in RFC + + switch subjectTokenType { + case tokenTypeID, tokenTypeAccess: // ok, continue + default: + s.tokenErrHelper(w, errRequestNotSupported, "Invalid subject_token_type.", http.StatusBadRequest) + return + } + + if subjectToken == "" { + s.tokenErrHelper(w, errInvalidRequest, "Missing subject_token", http.StatusBadRequest) + return + } + + conn, err := s.getConnector(connID) + if err != nil { + s.logger.Errorf("failed to get connector: %v", err) + s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest) + return + } + teConn, ok := conn.Connector.(connector.TokenIdentityConnector) + if !ok { + s.logger.Errorf("connector doesn't implement token exchange: %v", connID) + s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest) + return + } + identity, err := teConn.TokenIdentity(ctx, subjectTokenType, subjectToken) + if err != nil { + s.logger.Errorf("failed to verify subject token: %v", err) + s.tokenErrHelper(w, errAccessDenied, "", http.StatusUnauthorized) + return + } + + claims := storage.Claims{ + UserID: identity.UserID, + Username: identity.Username, + PreferredUsername: identity.PreferredUsername, + Email: identity.Email, + EmailVerified: identity.EmailVerified, + Groups: identity.Groups, + } + resp := accessTokenResponse{ + IssuedTokenType: requestedTokenType, + TokenType: "bearer", + } + var expiry time.Time + switch requestedTokenType { + case tokenTypeID: + resp.AccessToken, expiry, err = s.newIDToken(client.ID, claims, scopes, "", "", "", connID) + case tokenTypeAccess: + resp.AccessToken, expiry, err = s.newAccessToken(client.ID, claims, scopes, "", connID) + default: + s.tokenErrHelper(w, errRequestNotSupported, "Invalid requested_token_type.", http.StatusBadRequest) + return + } + if err != nil { + s.logger.Errorf("token exchange failed to create new %v token: %v", requestedTokenType, err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + return + } + resp.ExpiresIn = int(time.Until(expiry).Seconds()) + + // Token response must include cache headers https://tools.ietf.org/html/rfc6749#section-5.1 + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Pragma", "no-cache") + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) +} + type accessTokenResponse struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - ExpiresIn int `json:"expires_in"` - RefreshToken string `json:"refresh_token,omitempty"` - IDToken string `json:"id_token"` + AccessToken string `json:"access_token"` + IssuedTokenType string `json:"issued_token_type,omitempty"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + IDToken string `json:"id_token,omitempty"` + Scope string `json:"scope,omitempty"` } func (s *Server) toAccessTokenResponse(idToken, accessToken, refreshToken string, expiry time.Time) *accessTokenResponse { return &accessTokenResponse{ - accessToken, - "bearer", - int(expiry.Sub(s.now()).Seconds()), - refreshToken, - idToken, + AccessToken: accessToken, + TokenType: "bearer", + ExpiresIn: int(expiry.Sub(s.now()).Seconds()), + RefreshToken: refreshToken, + IDToken: idToken, } } @@ -1355,7 +1450,7 @@ func (s *Server) writeAccessToken(w http.ResponseWriter, resp *accessTokenRespon func (s *Server) renderError(r *http.Request, w http.ResponseWriter, status int, description string) { if err := s.templates.err(r, w, status, description); err != nil { - s.logger.Errorf("Server template error: %v", err) + s.logger.Errorf("server template error: %v", err) } } diff --git a/server/handlers_test.go b/server/handlers_test.go index b9340c9bd5..4d32684b37 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -10,6 +10,7 @@ import ( "net/http/httptest" "net/url" "path" + "strings" "testing" "time" @@ -332,7 +333,7 @@ func TestHandlePasswordLoginWithSkipApproval(t *testing.T) { connID := "mockPw" authReqID := "test" expiry := time.Now().Add(100 * time.Second) - resTypes := []string{"code"} + resTypes := []string{responseTypeCode} tests := []struct { name string @@ -441,7 +442,7 @@ func TestHandleConnectorCallbackWithSkipApproval(t *testing.T) { connID := "mock" authReqID := "test" expiry := time.Now().Add(100 * time.Second) - resTypes := []string{"code"} + resTypes := []string{responseTypeCode} tests := []struct { name string @@ -527,3 +528,114 @@ func TestHandleConnectorCallbackWithSkipApproval(t *testing.T) { require.Equal(t, tc.expectedRes, cb.Path) } } + +func TestHandleTokenExchange(t *testing.T) { + tests := []struct { + name string + scope string + requestedTokenType string + subjectTokenType string + subjectToken string + + expectedCode int + expectedTokenType string + }{ + { + "id-for-acccess", + "openid", + tokenTypeAccess, + tokenTypeID, + "foobar", + http.StatusOK, + tokenTypeAccess, + }, + { + "id-for-id", + "openid", + tokenTypeID, + tokenTypeID, + "foobar", + http.StatusOK, + tokenTypeID, + }, + { + "id-for-default", + "openid", + "", + tokenTypeID, + "foobar", + http.StatusOK, + tokenTypeAccess, + }, + { + "access-for-access", + "openid", + tokenTypeAccess, + tokenTypeAccess, + "foobar", + http.StatusOK, + tokenTypeAccess, + }, + { + "missing-subject_token_type", + "openid", + tokenTypeAccess, + "", + "foobar", + http.StatusBadRequest, + "", + }, + { + "missing-subject_token", + "openid", + tokenTypeAccess, + tokenTypeAccess, + "", + http.StatusBadRequest, + "", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + httpServer, s := newTestServer(ctx, t, func(c *Config) { + c.Storage.CreateClient(storage.Client{ + ID: "client_1", + Secret: "secret_1", + }) + }) + defer httpServer.Close() + vals := make(url.Values) + vals.Set("grant_type", grantTypeTokenExchange) + setNonEmpty(vals, "connector_id", "mock") + setNonEmpty(vals, "scope", tc.scope) + setNonEmpty(vals, "requested_token_type", tc.requestedTokenType) + setNonEmpty(vals, "subject_token_type", tc.subjectTokenType) + setNonEmpty(vals, "subject_token", tc.subjectToken) + setNonEmpty(vals, "client_id", "client_1") + setNonEmpty(vals, "client_secret", "secret_1") + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, httpServer.URL+"/token", strings.NewReader(vals.Encode())) + req.Header.Set("content-type", "application/x-www-form-urlencoded") + + s.handleToken(rr, req) + + require.Equal(t, tc.expectedCode, rr.Code, rr.Body.String()) + require.Equal(t, "application/json", rr.Result().Header.Get("content-type")) + if tc.expectedCode == http.StatusOK { + var res accessTokenResponse + err := json.NewDecoder(rr.Result().Body).Decode(&res) + require.NoError(t, err) + require.Equal(t, tc.expectedTokenType, res.IssuedTokenType) + } + }) + } +} + +func setNonEmpty(vals url.Values, key, value string) { + if value != "" { + vals.Set(key, value) + } +} diff --git a/server/oauth2.go b/server/oauth2.go index bb0058a74a..cfae540528 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -93,7 +93,6 @@ func tokenErr(w http.ResponseWriter, typ, description string, statusCode int) er return nil } -//nolint const ( errInvalidRequest = "invalid_request" errUnauthorizedClient = "unauthorized_client" @@ -132,6 +131,17 @@ const ( grantTypeImplicit = "implicit" grantTypePassword = "password" grantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code" + grantTypeTokenExchange = "urn:ietf:params:oauth:grant-type:token-exchange" +) + +const ( + // https://www.rfc-editor.org/rfc/rfc8693.html#section-3 + tokenTypeAccess = "urn:ietf:params:oauth:token-type:access_token" + tokenTypeRefresh = "urn:ietf:params:oauth:token-type:refresh_token" + tokenTypeID = "urn:ietf:params:oauth:token-type:id_token" + tokenTypeSAML1 = "urn:ietf:params:oauth:token-type:saml1" + tokenTypeSAML2 = "urn:ietf:params:oauth:token-type:saml2" + tokenTypeJWT = "urn:ietf:params:oauth:token-type:jwt" ) const ( @@ -288,9 +298,8 @@ type federatedIDClaims struct { UserID string `json:"user_id,omitempty"` } -func (s *Server) newAccessToken(clientID string, claims storage.Claims, scopes []string, nonce, connID string) (accessToken string, err error) { - idToken, _, err := s.newIDToken(clientID, claims, scopes, nonce, storage.NewID(), "", connID) - return idToken, err +func (s *Server) newAccessToken(clientID string, claims storage.Claims, scopes []string, nonce, connID string) (accessToken string, expiry time.Time, err error) { + return s.newIDToken(clientID, claims, scopes, nonce, storage.NewID(), "", connID) } func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []string, nonce, accessToken, code, connID string) (idToken string, expiry time.Time, err error) { diff --git a/server/oauth2_test.go b/server/oauth2_test.go index 710382aa23..1acff6518a 100644 --- a/server/oauth2_test.go +++ b/server/oauth2_test.go @@ -290,7 +290,7 @@ func TestParseAuthorizationRequest(t *testing.T) { } for _, tc := range tests { - func() { + t.Run(tc.name, func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -343,7 +343,7 @@ func TestParseAuthorizationRequest(t *testing.T) { t.Fatalf("%s: unsupported error type", tc.name) } } - }() + }) } } diff --git a/server/refreshhandlers.go b/server/refreshhandlers.go index 11eaf2e702..b3918ab475 100644 --- a/server/refreshhandlers.go +++ b/server/refreshhandlers.go @@ -361,7 +361,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie Groups: ident.Groups, } - accessToken, err := s.newAccessToken(client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, rCtx.storageToken.ConnectorID) + accessToken, _, err := s.newAccessToken(client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, rCtx.storageToken.ConnectorID) if err != nil { s.logger.Errorf("failed to create new access token: %v", err) s.refreshTokenErrHelper(w, newInternalServerError()) diff --git a/server/server.go b/server/server.go old mode 100755 new mode 100644 index f23eb54b7c..444fb7e15a --- a/server/server.go +++ b/server/server.go @@ -66,6 +66,8 @@ type Config struct { // The backing persistence layer. Storage storage.Storage + AllowedGrantTypes []string + // Valid values are "code" to enable the code flow and "token" to enable the implicit // flow. If no response types are supplied this value defaults to "code". SupportedResponseTypes []string @@ -213,7 +215,12 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) c.SupportedResponseTypes = []string{responseTypeCode} } - supportedGrant := []string{grantTypeAuthorizationCode, grantTypeRefreshToken, grantTypeDeviceCode} // default + allSupportedGrants := map[string]bool{ + grantTypeAuthorizationCode: true, + grantTypeRefreshToken: true, + grantTypeDeviceCode: true, + grantTypeTokenExchange: true, + } supportedRes := make(map[string]bool) for _, respType := range c.SupportedResponseTypes { @@ -223,7 +230,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) case responseTypeToken: // response_type=token is an implicit flow, let's add it to the discovery info // https://datatracker.ietf.org/doc/html/rfc6749#section-4.2.1 - supportedGrant = append(supportedGrant, grantTypeImplicit) + allSupportedGrants[grantTypeImplicit] = true default: return nil, fmt.Errorf("unsupported response_type %q", respType) } @@ -231,10 +238,22 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) } if c.PasswordConnector != "" { - supportedGrant = append(supportedGrant, grantTypePassword) + allSupportedGrants[grantTypePassword] = true } - sort.Strings(supportedGrant) + var supportedGrants []string + if len(c.AllowedGrantTypes) > 0 { + for _, grant := range c.AllowedGrantTypes { + if allSupportedGrants[grant] { + supportedGrants = append(supportedGrants, grant) + } + } + } else { + for grant := range allSupportedGrants { + supportedGrants = append(supportedGrants, grant) + } + } + sort.Strings(supportedGrants) webFS := web.FS() if c.Web.Dir != "" { @@ -267,7 +286,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) connectors: make(map[string]Connector), storage: newKeyCacher(c.Storage, now), supportedResponseTypes: supportedRes, - supportedGrantTypes: supportedGrant, + supportedGrantTypes: supportedGrants, idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour), authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour), deviceRequestsValidFor: value(c.DeviceRequestsValidFor, 5*time.Minute), diff --git a/server/server_test.go b/server/server_test.go index bedc336be3..dd21d737e0 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -99,6 +99,14 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi PrometheusRegistry: prometheus.NewRegistry(), HealthChecker: gosundheit.New(), SkipApprovalScreen: true, // Don't prompt for approval, just immediately redirect with code. + AllowedGrantTypes: []string{ // all implemented types + grantTypeDeviceCode, + grantTypeAuthorizationCode, + grantTypeRefreshToken, + grantTypeTokenExchange, + grantTypeImplicit, + grantTypePassword, + }, } if updateConfig != nil { updateConfig(&config) @@ -1756,17 +1764,22 @@ func TestServerSupportedGrants(t *testing.T) { { name: "Simple", config: func(c *Config) {}, - resGrants: []string{grantTypeAuthorizationCode, grantTypeRefreshToken, grantTypeDeviceCode}, + resGrants: []string{grantTypeAuthorizationCode, grantTypeRefreshToken, grantTypeDeviceCode, grantTypeTokenExchange}, + }, + { + name: "Minimal", + config: func(c *Config) { c.AllowedGrantTypes = []string{grantTypeTokenExchange} }, + resGrants: []string{grantTypeTokenExchange}, }, { name: "With password connector", config: func(c *Config) { c.PasswordConnector = "local" }, - resGrants: []string{grantTypeAuthorizationCode, grantTypePassword, grantTypeRefreshToken, grantTypeDeviceCode}, + resGrants: []string{grantTypeAuthorizationCode, grantTypePassword, grantTypeRefreshToken, grantTypeDeviceCode, grantTypeTokenExchange}, }, { name: "With token response", config: func(c *Config) { c.SupportedResponseTypes = append(c.SupportedResponseTypes, responseTypeToken) }, - resGrants: []string{grantTypeAuthorizationCode, grantTypeImplicit, grantTypeRefreshToken, grantTypeDeviceCode}, + resGrants: []string{grantTypeAuthorizationCode, grantTypeImplicit, grantTypeRefreshToken, grantTypeDeviceCode, grantTypeTokenExchange}, }, { name: "All", @@ -1774,14 +1787,14 @@ func TestServerSupportedGrants(t *testing.T) { c.PasswordConnector = "local" c.SupportedResponseTypes = append(c.SupportedResponseTypes, responseTypeToken) }, - resGrants: []string{grantTypeAuthorizationCode, grantTypeImplicit, grantTypePassword, grantTypeRefreshToken, grantTypeDeviceCode}, + resGrants: []string{grantTypeAuthorizationCode, grantTypeImplicit, grantTypePassword, grantTypeRefreshToken, grantTypeDeviceCode, grantTypeTokenExchange}, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { _, srv := newTestServer(context.TODO(), t, tc.config) - require.Equal(t, srv.supportedGrantTypes, tc.resGrants) + require.Equal(t, tc.resGrants, srv.supportedGrantTypes) }) } }