diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index 659f76611b..3e20262019 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -271,7 +271,7 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e RedirectURL: c.RedirectURI, }, verifier: provider.Verifier( - &oidc.Config{ClientID: clientID}, + &oidc.Config{ClientID: clientID, SkipClientIDCheck: len(clientID) == 0}, ), pkceVerifier: pkceVerifier, logger: logger, @@ -369,14 +369,32 @@ const ( exchangeCaller ) -func (c *oidcConnector) getTokenViaClientCredentials() (token *oauth2.Token, err error) { +func (c *oidcConnector) getTokenViaClientCredentials(r *http.Request) (token *oauth2.Token, err error) { + // Setup default clientID & clientSecret + clientID := c.oauth2Config.ClientID + clientSecret := c.oauth2Config.ClientSecret + + // Override clientID & clientSecret if they exist! + q := r.Form + if q.Has("custom_client_id") && q.Has("custom_client_secret") { + clientID = q.Get("custom_client_id") + clientSecret = q.Get("custom_client_secret") + } + + // Check if oauth2 credentials are not empty + if len(clientID) == 0 || len(clientSecret) == 0 { + return nil, fmt.Errorf("oidc: unable to get clientID or clientSecret") + } + + // Construct data to be sent to the external IdP data := url.Values{ "grant_type": {"client_credentials"}, - "client_id": {c.oauth2Config.ClientID}, - "client_secret": {c.oauth2Config.ClientSecret}, + "client_id": {clientID}, + "client_secret": {clientSecret}, "scope": {strings.Join(c.oauth2Config.Scopes, " ")}, } + // Request token from external IdP resp, err := c.httpClient.PostForm(c.oauth2Config.Endpoint.TokenURL, data) if err != nil { return nil, fmt.Errorf("oidc: failed to get token: %v", err) @@ -401,7 +419,6 @@ func (c *oidcConnector) getTokenViaClientCredentials() (token *oauth2.Token, err if err = json.Unmarshal(body, &response); err != nil { return nil, fmt.Errorf("oidc: unable to parse response: %v", err) } - token = &oauth2.Token{ AccessToken: response.AccessToken, Expiry: time.Now().Add(time.Second * time.Duration(response.ExpiresIn)), @@ -435,7 +452,7 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide } } else { // get token via client_credentials - token, err = c.getTokenViaClientCredentials() + token, err = c.getTokenViaClientCredentials(r) if err != nil { return identity, err } diff --git a/connector/oidc/oidc_test.go b/connector/oidc/oidc_test.go index 10550dfe79..310ef386ee 100644 --- a/connector/oidc/oidc_test.go +++ b/connector/oidc/oidc_test.go @@ -12,6 +12,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "net/url" "reflect" "strings" "testing" @@ -65,6 +66,12 @@ func TestHandleCallback(t *testing.T) { token map[string]interface{} pkce bool newGroupFromClaims []NewGroupFromClaims + expectedHandlerError error + clientID string + clientSecret string + customClientID string + customClientSecret string + clientCredentials bool }{ { name: "simpleCase", @@ -81,6 +88,8 @@ func TestHandleCallback(t *testing.T) { "email": "emailvalue", "email_verified": true, }, + clientCredentials: false, + expectedHandlerError: nil, }, { name: "customEmailClaim", @@ -96,6 +105,7 @@ func TestHandleCallback(t *testing.T) { "mail": "emailvalue", "email_verified": true, }, + clientCredentials: false, }, { name: "overrideWithCustomEmailClaim", @@ -113,6 +123,7 @@ func TestHandleCallback(t *testing.T) { "custommail": "customemailvalue", "email_verified": true, }, + clientCredentials: false, }, { name: "email_verified not in claims, configured to be skipped", @@ -125,6 +136,7 @@ func TestHandleCallback(t *testing.T) { "name": "namevalue", "email": "emailvalue", }, + clientCredentials: false, }, { name: "withUserIDKey", @@ -138,6 +150,7 @@ func TestHandleCallback(t *testing.T) { "email": "emailvalue", "email_verified": true, }, + clientCredentials: false, }, { name: "withUserNameKey", @@ -151,6 +164,7 @@ func TestHandleCallback(t *testing.T) { "email": "emailvalue", "email_verified": true, }, + clientCredentials: false, }, { name: "withPreferredUsernameKey", @@ -166,6 +180,7 @@ func TestHandleCallback(t *testing.T) { "email": "emailvalue", "email_verified": true, }, + clientCredentials: false, }, { name: "withoutPreferredUsernameKeyAndBackendReturns", @@ -180,6 +195,7 @@ func TestHandleCallback(t *testing.T) { "email": "emailvalue", "email_verified": true, }, + clientCredentials: false, }, { name: "withoutPreferredUsernameKeyAndBackendNotReturn", @@ -193,6 +209,7 @@ func TestHandleCallback(t *testing.T) { "email": "emailvalue", "email_verified": true, }, + clientCredentials: false, }, { name: "emptyEmailScope", @@ -206,6 +223,7 @@ func TestHandleCallback(t *testing.T) { "name": "namevalue", "user_name": "username", }, + clientCredentials: false, }, { name: "emptyEmailScopeButEmailProvided", @@ -220,6 +238,7 @@ func TestHandleCallback(t *testing.T) { "user_name": "username", "email": "emailvalue", }, + clientCredentials: false, }, { name: "customGroupsKey", @@ -237,6 +256,7 @@ func TestHandleCallback(t *testing.T) { "email": "emailvalue", "cognito:groups": []string{"group3", "group4"}, }, + clientCredentials: false, }, { name: "customGroupsKeyButGroupsProvided", @@ -255,6 +275,7 @@ func TestHandleCallback(t *testing.T) { "groups": []string{"group1", "group2"}, "cognito:groups": []string{"group3", "group4"}, }, + clientCredentials: false, }, { name: "customGroupsKeyDespiteGroupsProvidedButOverride", @@ -274,6 +295,7 @@ func TestHandleCallback(t *testing.T) { "groups": []string{"group1", "group2"}, "cognito:groups": []string{"group3", "group4"}, }, + clientCredentials: false, }, { name: "singularGroupResponseAsString", @@ -290,6 +312,7 @@ func TestHandleCallback(t *testing.T) { "email": "emailvalue", "email_verified": true, }, + clientCredentials: false, }, { name: "newGroupFromClaims", @@ -347,7 +370,6 @@ func TestHandleCallback(t *testing.T) { Prefix: "bk", }, }, - token: map[string]interface{}{ "sub": "subvalue", "name": "namevalue", @@ -363,6 +385,7 @@ func TestHandleCallback(t *testing.T) { }, "non-string-claim2": 666, }, + clientCredentials: false, }, { name: "withPKCE", @@ -379,7 +402,8 @@ func TestHandleCallback(t *testing.T) { "email": "emailvalue", "email_verified": true, }, - pkce: true, + pkce: true, + clientCredentials: false, }, { name: "withoutPKCE", @@ -396,7 +420,86 @@ func TestHandleCallback(t *testing.T) { "email": "emailvalue", "email_verified": true, }, - pkce: false, + pkce: false, + clientCredentials: false, + }, + { + name: "withCustomCredentials", + userIDKey: "", // not configured + userNameKey: "", // not configured + clientID: "", // not configured + clientSecret: "", // not configured + expectUserID: "subvalue", + expectUserName: "namevalue", + expectGroups: nil, + expectedEmailField: "emailvalue", + customClientID: "clientidvalue", + customClientSecret: "clientsecretvalue", + scopes: []string{"openid"}, + token: map[string]interface{}{ + "sub": "subvalue", + "name": "namevalue", + "email": "emailvalue", + "email_verified": true, + }, + expectedHandlerError: nil, + clientCredentials: true, + }, + { + name: "withConfiguredAndCustomCredentials", + userIDKey: "", // not configured + userNameKey: "", // not configured + clientID: "defaultClientID", + clientSecret: "defaultClientSecret", + expectUserID: "subvalue", + expectUserName: "namevalue", + expectGroups: nil, + expectedEmailField: "emailvalue", + customClientID: "clientidvalue", + customClientSecret: "clientsecretvalue", + scopes: []string{"openid"}, + token: map[string]interface{}{ + "sub": "subvalue", + "name": "namevalue", + "email": "emailvalue", + "email_verified": true, + }, + expectedHandlerError: fmt.Errorf("expected audience \"defaultClientID\""), + clientCredentials: true, + }, + { + name: "withoutBothCredentials", + userIDKey: "", // not configured + userNameKey: "", // not configured + clientID: "", // not configured + clientSecret: "", // not configured + expectUserID: "", + expectUserName: "", + expectGroups: nil, + expectedEmailField: "", + customClientID: "", // not configured in the request + customClientSecret: "", // not configured in the request + scopes: []string{"openid"}, + token: nil, + expectedHandlerError: fmt.Errorf("oidc: unable to get clientID or clientSecret"), + clientCredentials: true, + }, + { + name: "missingConfiguredAndASingleCustomCredential", + userIDKey: "", // not configured + userNameKey: "", // not configured + clientID: "", // not configured + clientSecret: "", // not configured + expectUserID: "", + expectUserName: "", + expectGroups: nil, + expectedEmailField: "", + customClientID: "clientidvalue", + customClientSecret: "", // not configured in the request + scopes: []string{"openid"}, + token: nil, + expectedHandlerError: fmt.Errorf("oidc: unable to get clientID or clientSecret"), + clientCredentials: true, }, } @@ -419,8 +522,8 @@ func TestHandleCallback(t *testing.T) { basicAuth := true config := Config{ Issuer: serverURL, - ClientID: "clientID", - ClientSecret: "clientSecret", + ClientID: tc.clientID, + ClientSecret: tc.clientSecret, Scopes: scopes, RedirectURI: fmt.Sprintf("%s/callback", serverURL), UserIDKey: tc.userIDKey, @@ -440,17 +543,26 @@ func TestHandleCallback(t *testing.T) { if err != nil { t.Fatal("failed to create new connector", err) } + var req *http.Request + if tc.clientCredentials { + req, err = newRequestWithoutAuthCode(testServer.URL) + data := url.Values{} + data.Set("custom_client_id", tc.customClientID) + data.Set("custom_client_secret", tc.customClientSecret) + req.Form = data + } else { + req, err = newRequestWithAuthCode(testServer.URL, "someCode") + } - req, err := newRequestWithAuthCode(testServer.URL, "someCode") if err != nil { t.Fatal("failed to create request", err) } identity, err := conn.HandleCallback(connector.Scopes{Groups: true}, req) + compareErrors(t, err, tc.expectedHandlerError) if err != nil { - t.Fatal("handle callback failed", err) + return } - expectEquals(t, identity.UserID, tc.expectUserID) expectEquals(t, identity.Username, tc.expectUserName) expectEquals(t, identity.PreferredUsername, tc.expectPreferredUsername) @@ -828,6 +940,15 @@ func newRequestWithAuthCode(serverURL string, code string) (*http.Request, error return req, nil } +func newRequestWithoutAuthCode(serverURL string) (*http.Request, error) { + req, err := http.NewRequest("GET", serverURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %v", err) + } + + return req, nil +} + func n(pub *rsa.PublicKey) string { return encode(pub.N.Bytes()) } @@ -848,3 +969,21 @@ func expectEquals(t *testing.T, a interface{}, b interface{}) { t.Errorf("Expected %+v to equal %+v", a, b) } } + +func compareErrors(t *testing.T, a error, b error) { + if a == nil && b == nil { + return + } + if a == nil && b != nil { + t.Errorf("Expected \"%+v\" to be nil", b) + return + } + if a != nil && b == nil { + t.Errorf("Expected \"%+v\" to be \"%+v\"", b, a) + return + } + + if !strings.Contains(a.Error(), b.Error()) { + t.Errorf("Expected \"%+v\" to be a part of \"%+v\"", b, a) + } +} diff --git a/server/refreshhandlers.go b/server/refreshhandlers.go index 69424a4a7b..c84609d7e3 100644 --- a/server/refreshhandlers.go +++ b/server/refreshhandlers.go @@ -411,7 +411,7 @@ func (s *Server) handleClientCredentials(w http.ResponseWriter, r *http.Request, return } - // Login + // Callback identity, err := callbackConnector.HandleCallback(parseScopes(scopes), r) if err != nil { s.logger.Errorf("Failed to login user: %v", err)