Skip to content

Commit

Permalink
Modify client credential grant (#6)
Browse files Browse the repository at this point in the history
* feat: dynamic oauth2 credentials client_credential flow

Signed-off-by: Houssem Ben Mabrouk <[email protected]>

* adding tests for client_credentials flow

Signed-off-by: Houssem Ben Mabrouk <[email protected]>

* better credentials handling + adjust tests

Signed-off-by: Houssem Ben Mabrouk <[email protected]>

* fix lint

Signed-off-by: Houssem Ben Mabrouk <[email protected]>

---------

Signed-off-by: Houssem Ben Mabrouk <[email protected]>
  • Loading branch information
orange-hbenmabrouk authored Apr 22, 2024
1 parent 823f186 commit 818c8c9
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 15 deletions.
29 changes: 23 additions & 6 deletions connector/oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e
RedirectURL: c.RedirectURI,
},
verifier: provider.Verifier(
&oidc.Config{ClientID: clientID},
&oidc.Config{ClientID: clientID, SkipClientIDCheck: len(clientID) == 0},
),
pkceVerifier: pkceVerifier,
logger: logger,
Expand Down Expand Up @@ -369,14 +369,32 @@ const (
exchangeCaller
)

func (c *oidcConnector) getTokenViaClientCredentials() (token *oauth2.Token, err error) {
func (c *oidcConnector) getTokenViaClientCredentials(r *http.Request) (token *oauth2.Token, err error) {
// Setup default clientID & clientSecret
clientID := c.oauth2Config.ClientID
clientSecret := c.oauth2Config.ClientSecret

// Override clientID & clientSecret if they exist!
q := r.Form
if q.Has("custom_client_id") && q.Has("custom_client_secret") {
clientID = q.Get("custom_client_id")
clientSecret = q.Get("custom_client_secret")
}

// Check if oauth2 credentials are not empty
if len(clientID) == 0 || len(clientSecret) == 0 {
return nil, fmt.Errorf("oidc: unable to get clientID or clientSecret")
}

// Construct data to be sent to the external IdP
data := url.Values{
"grant_type": {"client_credentials"},
"client_id": {c.oauth2Config.ClientID},
"client_secret": {c.oauth2Config.ClientSecret},
"client_id": {clientID},
"client_secret": {clientSecret},
"scope": {strings.Join(c.oauth2Config.Scopes, " ")},
}

// Request token from external IdP
resp, err := c.httpClient.PostForm(c.oauth2Config.Endpoint.TokenURL, data)
if err != nil {
return nil, fmt.Errorf("oidc: failed to get token: %v", err)
Expand All @@ -401,7 +419,6 @@ func (c *oidcConnector) getTokenViaClientCredentials() (token *oauth2.Token, err
if err = json.Unmarshal(body, &response); err != nil {
return nil, fmt.Errorf("oidc: unable to parse response: %v", err)
}

token = &oauth2.Token{
AccessToken: response.AccessToken,
Expiry: time.Now().Add(time.Second * time.Duration(response.ExpiresIn)),
Expand Down Expand Up @@ -435,7 +452,7 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
}
} else {
// get token via client_credentials
token, err = c.getTokenViaClientCredentials()
token, err = c.getTokenViaClientCredentials(r)
if err != nil {
return identity, err
}
Expand Down
155 changes: 147 additions & 8 deletions connector/oidc/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strings"
"testing"
Expand Down Expand Up @@ -65,6 +66,12 @@ func TestHandleCallback(t *testing.T) {
token map[string]interface{}
pkce bool
newGroupFromClaims []NewGroupFromClaims
expectedHandlerError error
clientID string
clientSecret string
customClientID string
customClientSecret string
clientCredentials bool
}{
{
name: "simpleCase",
Expand All @@ -81,6 +88,8 @@ func TestHandleCallback(t *testing.T) {
"email": "emailvalue",
"email_verified": true,
},
clientCredentials: false,
expectedHandlerError: nil,
},
{
name: "customEmailClaim",
Expand All @@ -96,6 +105,7 @@ func TestHandleCallback(t *testing.T) {
"mail": "emailvalue",
"email_verified": true,
},
clientCredentials: false,
},
{
name: "overrideWithCustomEmailClaim",
Expand All @@ -113,6 +123,7 @@ func TestHandleCallback(t *testing.T) {
"custommail": "customemailvalue",
"email_verified": true,
},
clientCredentials: false,
},
{
name: "email_verified not in claims, configured to be skipped",
Expand All @@ -125,6 +136,7 @@ func TestHandleCallback(t *testing.T) {
"name": "namevalue",
"email": "emailvalue",
},
clientCredentials: false,
},
{
name: "withUserIDKey",
Expand All @@ -138,6 +150,7 @@ func TestHandleCallback(t *testing.T) {
"email": "emailvalue",
"email_verified": true,
},
clientCredentials: false,
},
{
name: "withUserNameKey",
Expand All @@ -151,6 +164,7 @@ func TestHandleCallback(t *testing.T) {
"email": "emailvalue",
"email_verified": true,
},
clientCredentials: false,
},
{
name: "withPreferredUsernameKey",
Expand All @@ -166,6 +180,7 @@ func TestHandleCallback(t *testing.T) {
"email": "emailvalue",
"email_verified": true,
},
clientCredentials: false,
},
{
name: "withoutPreferredUsernameKeyAndBackendReturns",
Expand All @@ -180,6 +195,7 @@ func TestHandleCallback(t *testing.T) {
"email": "emailvalue",
"email_verified": true,
},
clientCredentials: false,
},
{
name: "withoutPreferredUsernameKeyAndBackendNotReturn",
Expand All @@ -193,6 +209,7 @@ func TestHandleCallback(t *testing.T) {
"email": "emailvalue",
"email_verified": true,
},
clientCredentials: false,
},
{
name: "emptyEmailScope",
Expand All @@ -206,6 +223,7 @@ func TestHandleCallback(t *testing.T) {
"name": "namevalue",
"user_name": "username",
},
clientCredentials: false,
},
{
name: "emptyEmailScopeButEmailProvided",
Expand All @@ -220,6 +238,7 @@ func TestHandleCallback(t *testing.T) {
"user_name": "username",
"email": "emailvalue",
},
clientCredentials: false,
},
{
name: "customGroupsKey",
Expand All @@ -237,6 +256,7 @@ func TestHandleCallback(t *testing.T) {
"email": "emailvalue",
"cognito:groups": []string{"group3", "group4"},
},
clientCredentials: false,
},
{
name: "customGroupsKeyButGroupsProvided",
Expand All @@ -255,6 +275,7 @@ func TestHandleCallback(t *testing.T) {
"groups": []string{"group1", "group2"},
"cognito:groups": []string{"group3", "group4"},
},
clientCredentials: false,
},
{
name: "customGroupsKeyDespiteGroupsProvidedButOverride",
Expand All @@ -274,6 +295,7 @@ func TestHandleCallback(t *testing.T) {
"groups": []string{"group1", "group2"},
"cognito:groups": []string{"group3", "group4"},
},
clientCredentials: false,
},
{
name: "singularGroupResponseAsString",
Expand All @@ -290,6 +312,7 @@ func TestHandleCallback(t *testing.T) {
"email": "emailvalue",
"email_verified": true,
},
clientCredentials: false,
},
{
name: "newGroupFromClaims",
Expand Down Expand Up @@ -347,7 +370,6 @@ func TestHandleCallback(t *testing.T) {
Prefix: "bk",
},
},

token: map[string]interface{}{
"sub": "subvalue",
"name": "namevalue",
Expand All @@ -363,6 +385,7 @@ func TestHandleCallback(t *testing.T) {
},
"non-string-claim2": 666,
},
clientCredentials: false,
},
{
name: "withPKCE",
Expand All @@ -379,7 +402,8 @@ func TestHandleCallback(t *testing.T) {
"email": "emailvalue",
"email_verified": true,
},
pkce: true,
pkce: true,
clientCredentials: false,
},
{
name: "withoutPKCE",
Expand All @@ -396,7 +420,86 @@ func TestHandleCallback(t *testing.T) {
"email": "emailvalue",
"email_verified": true,
},
pkce: false,
pkce: false,
clientCredentials: false,
},
{
name: "withCustomCredentials",
userIDKey: "", // not configured
userNameKey: "", // not configured
clientID: "", // not configured
clientSecret: "", // not configured
expectUserID: "subvalue",
expectUserName: "namevalue",
expectGroups: nil,
expectedEmailField: "emailvalue",
customClientID: "clientidvalue",
customClientSecret: "clientsecretvalue",
scopes: []string{"openid"},
token: map[string]interface{}{
"sub": "subvalue",
"name": "namevalue",
"email": "emailvalue",
"email_verified": true,
},
expectedHandlerError: nil,
clientCredentials: true,
},
{
name: "withConfiguredAndCustomCredentials",
userIDKey: "", // not configured
userNameKey: "", // not configured
clientID: "defaultClientID",
clientSecret: "defaultClientSecret",
expectUserID: "subvalue",
expectUserName: "namevalue",
expectGroups: nil,
expectedEmailField: "emailvalue",
customClientID: "clientidvalue",
customClientSecret: "clientsecretvalue",
scopes: []string{"openid"},
token: map[string]interface{}{
"sub": "subvalue",
"name": "namevalue",
"email": "emailvalue",
"email_verified": true,
},
expectedHandlerError: fmt.Errorf("expected audience \"defaultClientID\""),
clientCredentials: true,
},
{
name: "withoutBothCredentials",
userIDKey: "", // not configured
userNameKey: "", // not configured
clientID: "", // not configured
clientSecret: "", // not configured
expectUserID: "",
expectUserName: "",
expectGroups: nil,
expectedEmailField: "",
customClientID: "", // not configured in the request
customClientSecret: "", // not configured in the request
scopes: []string{"openid"},
token: nil,
expectedHandlerError: fmt.Errorf("oidc: unable to get clientID or clientSecret"),
clientCredentials: true,
},
{
name: "missingConfiguredAndASingleCustomCredential",
userIDKey: "", // not configured
userNameKey: "", // not configured
clientID: "", // not configured
clientSecret: "", // not configured
expectUserID: "",
expectUserName: "",
expectGroups: nil,
expectedEmailField: "",
customClientID: "clientidvalue",
customClientSecret: "", // not configured in the request
scopes: []string{"openid"},
token: nil,
expectedHandlerError: fmt.Errorf("oidc: unable to get clientID or clientSecret"),
clientCredentials: true,
},
}

Expand All @@ -419,8 +522,8 @@ func TestHandleCallback(t *testing.T) {
basicAuth := true
config := Config{
Issuer: serverURL,
ClientID: "clientID",
ClientSecret: "clientSecret",
ClientID: tc.clientID,
ClientSecret: tc.clientSecret,
Scopes: scopes,
RedirectURI: fmt.Sprintf("%s/callback", serverURL),
UserIDKey: tc.userIDKey,
Expand All @@ -440,17 +543,26 @@ func TestHandleCallback(t *testing.T) {
if err != nil {
t.Fatal("failed to create new connector", err)
}
var req *http.Request
if tc.clientCredentials {
req, err = newRequestWithoutAuthCode(testServer.URL)
data := url.Values{}
data.Set("custom_client_id", tc.customClientID)
data.Set("custom_client_secret", tc.customClientSecret)
req.Form = data
} else {
req, err = newRequestWithAuthCode(testServer.URL, "someCode")
}

req, err := newRequestWithAuthCode(testServer.URL, "someCode")
if err != nil {
t.Fatal("failed to create request", err)
}

identity, err := conn.HandleCallback(connector.Scopes{Groups: true}, req)
compareErrors(t, err, tc.expectedHandlerError)
if err != nil {
t.Fatal("handle callback failed", err)
return
}

expectEquals(t, identity.UserID, tc.expectUserID)
expectEquals(t, identity.Username, tc.expectUserName)
expectEquals(t, identity.PreferredUsername, tc.expectPreferredUsername)
Expand Down Expand Up @@ -828,6 +940,15 @@ func newRequestWithAuthCode(serverURL string, code string) (*http.Request, error
return req, nil
}

func newRequestWithoutAuthCode(serverURL string) (*http.Request, error) {
req, err := http.NewRequest("GET", serverURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %v", err)
}

return req, nil
}

func n(pub *rsa.PublicKey) string {
return encode(pub.N.Bytes())
}
Expand All @@ -848,3 +969,21 @@ func expectEquals(t *testing.T, a interface{}, b interface{}) {
t.Errorf("Expected %+v to equal %+v", a, b)
}
}

func compareErrors(t *testing.T, a error, b error) {
if a == nil && b == nil {
return
}
if a == nil && b != nil {
t.Errorf("Expected \"%+v\" to be nil", b)
return
}
if a != nil && b == nil {
t.Errorf("Expected \"%+v\" to be \"%+v\"", b, a)
return
}

if !strings.Contains(a.Error(), b.Error()) {
t.Errorf("Expected \"%+v\" to be a part of \"%+v\"", b, a)
}
}
Loading

0 comments on commit 818c8c9

Please sign in to comment.