Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement custom client_credentials grant & add support of pkce to device_grant #7

Merged
merged 5 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/dex/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,7 @@ func applyConfigOverrides(options serveOptions, config *Config) {
"refresh_token",
"urn:ietf:params:oauth:grant-type:device_code",
"urn:ietf:params:oauth:grant-type:token-exchange",
"client_credentials",
}
}
}
Expand Down
102 changes: 98 additions & 4 deletions connector/oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
Expand Down Expand Up @@ -40,6 +41,11 @@ type Config struct {

Scopes []string `json:"scopes"` // defaults to "profile" and "email"

PKCE struct {
// Configurable key which controls if pkce challenge should be created or not
Enabled bool `json:"enabled"` // defaults to "false"
} `json:"pkce"`

// HostedDomains was an optional list of whitelisted domains when using the OIDC connector with Google.
// Only users from a whitelisted domain were allowed to log in.
// Support for this option was removed from the OIDC connector.
Expand Down Expand Up @@ -247,6 +253,12 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e
promptType = *c.PromptType
}

// pkce
pkceVerifier := ""
if c.PKCE.Enabled {
pkceVerifier = oauth2.GenerateVerifier()
}

clientID := c.ClientID
return &oidcConnector{
provider: provider,
Expand All @@ -259,8 +271,9 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e
RedirectURL: c.RedirectURI,
},
verifier: provider.Verifier(
&oidc.Config{ClientID: clientID},
&oidc.Config{ClientID: clientID, SkipClientIDCheck: len(clientID) == 0},
),
pkceVerifier: pkceVerifier,
logger: logger,
cancel: cancel,
httpClient: httpClient,
Expand Down Expand Up @@ -290,6 +303,7 @@ type oidcConnector struct {
redirectURI string
oauth2Config *oauth2.Config
verifier *oidc.IDTokenVerifier
pkceVerifier string
cancel context.CancelFunc
logger log.Logger
httpClient *http.Client
Expand Down Expand Up @@ -328,6 +342,10 @@ func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string)
if s.OfflineAccess {
opts = append(opts, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", c.promptType))
}

if c.pkceVerifier != "" {
opts = append(opts, oauth2.S256ChallengeOption(c.pkceVerifier))
}
return c.oauth2Config.AuthCodeURL(state, opts...), nil
}

Expand All @@ -351,17 +369,93 @@ const (
exchangeCaller
)

func (c *oidcConnector) getTokenViaClientCredentials(r *http.Request) (token *oauth2.Token, err error) {
// Setup default clientID & clientSecret
clientID := c.oauth2Config.ClientID
clientSecret := c.oauth2Config.ClientSecret

// Override clientID & clientSecret if they exist!
q := r.Form
if q.Has("custom_client_id") && q.Has("custom_client_secret") {
clientID = q.Get("custom_client_id")
clientSecret = q.Get("custom_client_secret")
}

// Check if oauth2 credentials are not empty
if len(clientID) == 0 || len(clientSecret) == 0 {
return nil, fmt.Errorf("oidc: unable to get clientID or clientSecret")
}

// Construct data to be sent to the external IdP
data := url.Values{
"grant_type": {"client_credentials"},
"client_id": {clientID},
"client_secret": {clientSecret},
"scope": {strings.Join(c.oauth2Config.Scopes, " ")},
}

// Request token from external IdP
resp, err := c.httpClient.PostForm(c.oauth2Config.Endpoint.TokenURL, data)
if err != nil {
return nil, fmt.Errorf("oidc: failed to get token: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("oidc: issuer returned an error: %v", resp.Status)
}

body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("oidc: failed to get read token body: %v", err)
}

type AccessTokenType struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
}
response := AccessTokenType{}
if err = json.Unmarshal(body, &response); err != nil {
return nil, fmt.Errorf("oidc: unable to parse response: %v", err)
}
token = &oauth2.Token{
AccessToken: response.AccessToken,
Expiry: time.Now().Add(time.Second * time.Duration(response.ExpiresIn)),
}
raw := make(map[string]interface{})
json.Unmarshal(body, &raw) // no error checks for optional fields
token = token.WithExtra(raw)

return token, nil
}

func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
return identity, &oauth2Error{errType, q.Get("error_description")}
}

ctx := context.WithValue(r.Context(), oauth2.HTTPClient, c.httpClient)
var token *oauth2.Token
if q.Has("code") {
// exchange code to token
var opts []oauth2.AuthCodeOption

token, err := c.oauth2Config.Exchange(ctx, q.Get("code"))
if err != nil {
return identity, fmt.Errorf("oidc: failed to get token: %v", err)
if c.pkceVerifier != "" {
opts = append(opts, oauth2.VerifierOption(c.pkceVerifier))
}

token, err = c.oauth2Config.Exchange(ctx, q.Get("code"), opts...)
if err != nil {
return identity, fmt.Errorf("oidc: failed to get token: %v", err)
}
} else {
// get token via client_credentials
token, err = c.getTokenViaClientCredentials(r)
if err != nil {
return identity, err
}
}
return c.createIdentity(ctx, identity, token, createCaller)
}
Expand Down
Loading
Loading