Skip to content

Commit

Permalink
include backward compatibility & improved logic
Browse files Browse the repository at this point in the history
Signed-off-by: Houssem Ben Mabrouk <[email protected]>
  • Loading branch information
orange-hbenmabrouk committed Jun 16, 2023
1 parent 0d13a63 commit ec74512
Show file tree
Hide file tree
Showing 9 changed files with 88 additions and 124 deletions.
6 changes: 4 additions & 2 deletions cmd/dex/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,10 @@ type OAuth2 struct {
SkipApprovalScreen bool `json:"skipApprovalScreen"`
// If specified, show the connector selection screen even if there's only one
AlwaysShowLoginScreen bool `json:"alwaysShowLoginScreen"`
// This is a list of the connectors that can be used for password grant
PasswordConnectors []string `json:"passwordConnectors"`
// This is the connector that can be used for password grant
PasswordConnector string `json:"passwordConnector"`
// This is the default connector that can be used for password grant
DefaultPasswordConnector string `json:"defaultPasswordConnector"`
}

// Web is the config format for the HTTP server.
Expand Down
31 changes: 17 additions & 14 deletions cmd/dex/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,10 @@ func runServe(options serveOptions) error {
if c.OAuth2.SkipApprovalScreen {
logger.Infof("config skipping approval screen")
}
if len(c.OAuth2.PasswordConnectors) > 0 {
logger.Infof("config using password grant connectors: %s", c.OAuth2.PasswordConnectors)
if c.OAuth2.DefaultPasswordConnector != "" {
logger.Infof("config using the default password grant connector: %s", c.OAuth2.DefaultPasswordConnector)
} else if c.OAuth2.PasswordConnector != "" {
logger.Infof("config using the password grant connector: %s", c.OAuth2.PasswordConnector)
}
if len(c.Web.AllowedOrigins) > 0 {
logger.Infof("config allowed origins: %s", c.Web.AllowedOrigins)
Expand All @@ -270,18 +272,19 @@ func runServe(options serveOptions) error {
healthChecker := gosundheit.New()

serverConfig := server.Config{
SupportedResponseTypes: c.OAuth2.ResponseTypes,
SkipApprovalScreen: c.OAuth2.SkipApprovalScreen,
AlwaysShowLoginScreen: c.OAuth2.AlwaysShowLoginScreen,
PasswordConnectors: c.OAuth2.PasswordConnectors,
AllowedOrigins: c.Web.AllowedOrigins,
Issuer: c.Issuer,
Storage: s,
Web: c.Frontend,
Logger: logger,
Now: now,
PrometheusRegistry: prometheusRegistry,
HealthChecker: healthChecker,
SupportedResponseTypes: c.OAuth2.ResponseTypes,
SkipApprovalScreen: c.OAuth2.SkipApprovalScreen,
AlwaysShowLoginScreen: c.OAuth2.AlwaysShowLoginScreen,
PasswordConnector: c.OAuth2.PasswordConnector,
DefaultPasswordConnector: c.OAuth2.DefaultPasswordConnector,
AllowedOrigins: c.Web.AllowedOrigins,
Issuer: c.Issuer,
Storage: s,
Web: c.Frontend,
Logger: logger,
Now: now,
PrometheusRegistry: prometheusRegistry,
HealthChecker: healthChecker,
}
if c.Expiry.SigningKeys != "" {
signingKeys, err := time.ParseDuration(c.Expiry.SigningKeys)
Expand Down
7 changes: 5 additions & 2 deletions config.docker.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,11 @@ oauth2:
responseTypes: {{ getenv "DEX_OAUTH2_RESPONSE_TYPES" "[code]" }}
skipApprovalScreen: {{ getenv "DEX_OAUTH2_SKIP_APPROVAL_SCREEN" "false" }}
alwaysShowLoginScreen: {{ getenv "DEX_OAUTH2_ALWAYS_SHOW_LOGIN_SCREEN" "false" }}
{{- if getenv "DEX_OAUTH2_PASSWORD_CONNECTORS" "[]" }}
passwordConnectors: {{ .Env.DEX_OAUTH2_PASSWORD_CONNECTORS }}
{{- if getenv "DEX_OAUTH2_DEFAULT_PASSWORD_CONNECTOR" "" }}
defaultPasswordConnector: {{ .Env.DEX_OAUTH2_DEFAULT_PASSWORD_CONNECTOR }}
{{- end }}
{{- if getenv "DEX_OAUTH2_PASSWORD_CONNECTOR" "" }}
passwordConnector: {{ .Env.DEX_OAUTH2_PASSWORD_CONNECTOR }}
{{- end }}

enablePasswordDB: {{ getenv "DEX_ENABLE_PASSWORD_DB" "true" }}
Expand Down
5 changes: 3 additions & 2 deletions config.yaml.dist
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,9 @@ web:
# alwaysShowLoginScreen: false
#
# # Uncomment to use a specific connector for password grants
# passwordConnectors:
# - local
# defaultPasswordConnector: local
# # Deprecated option
# passwordConnector: local

# Static clients registered in Dex by default.
#
Expand Down
6 changes: 4 additions & 2 deletions examples/config-dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,10 @@ telemetry:
# go directly to it. For connected IdPs, this redirects the browser away
# from application to upstream provider such as the Google login page
# alwaysShowLoginScreen: false
# Uncomment the passwordConnectors list to attempt authentication using multiple password connectors for password grants
# passwordConnectors: ['local']
# Uncomment the defaultPasswordConnector to use a specific connector as default for password grants
# defaultPasswordConnector: local
# Deprecated option
# passwordConnector: local

# Instead of reading from an external storage, use this list of clients.
#
Expand Down
87 changes: 10 additions & 77 deletions server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -815,54 +815,6 @@ func (s *Server) withClientFromStorage(w http.ResponseWriter, r *http.Request, h
handler(w, r, client)
}

func (s *Server) withClientAndConnIDFromStorage(w http.ResponseWriter, r *http.Request, handler func(http.ResponseWriter, *http.Request, storage.Client, string)) {
clientID, clientSecret, ok := r.BasicAuth()
if ok {
var err error
if clientID, err = url.QueryUnescape(clientID); err != nil {
s.tokenErrHelper(w, errInvalidRequest, "client_id improperly encoded", http.StatusBadRequest)
return
}
if clientSecret, err = url.QueryUnescape(clientSecret); err != nil {
s.tokenErrHelper(w, errInvalidRequest, "client_secret improperly encoded", http.StatusBadRequest)
return
}
} else {
clientID = r.PostFormValue("client_id")
clientSecret = r.PostFormValue("client_secret")
}

client, err := s.storage.GetClient(clientID)
if err != nil {
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get client: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
} else {
s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
}
return
}

if subtle.ConstantTimeCompare([]byte(client.Secret), []byte(clientSecret)) != 1 {
if clientSecret == "" {
s.logger.Infof("missing client_secret on token request for client: %s", client.ID)
} else {
s.logger.Infof("invalid client_secret on token request for client: %s", client.ID)
}
s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
return
}

connID, err := url.PathUnescape(mux.Vars(r)["connector"])
if err != nil {
s.logger.Errorf("Failed to parse connector: %v", err)
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist")
return
}

handler(w, r, client, connID)
}

func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if r.Method != http.MethodPost {
Expand All @@ -877,33 +829,9 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
return
}

grantType := r.PostFormValue("grant_type")
switch grantType {
case grantTypeDeviceCode:
s.handleDeviceToken(w, r)
case grantTypeAuthorizationCode:
s.withClientFromStorage(w, r, s.handleAuthCode)
case grantTypeRefreshToken:
s.withClientFromStorage(w, r, s.handleRefreshToken)
case grantTypePassword:
s.withClientAndConnIDFromStorage(w, r, s.handlePasswordGrant)
default:
s.tokenErrHelper(w, errUnsupportedGrantType, "", http.StatusBadRequest)
}
}

func (s *Server) handleTokenConnectorLogin(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if r.Method != http.MethodPost {
s.tokenErrHelper(w, errInvalidRequest, "method not allowed", http.StatusBadRequest)
return
}

err := r.ParseForm()
if err != nil {
s.logger.Errorf("Could not parse request body: %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
return
// Set the connector_id to the default value if it wasn't passed to the query
if !r.Form.Has("connector_id") || r.Form.Get("connector_id") == "" {
r.Form.Set("connector_id", s.defaultPasswordConnector)
}

grantType := r.PostFormValue("grant_type")
Expand All @@ -915,7 +843,7 @@ func (s *Server) handleTokenConnectorLogin(w http.ResponseWriter, r *http.Reques
case grantTypeRefreshToken:
s.withClientFromStorage(w, r, s.handleRefreshToken)
case grantTypePassword:
s.withClientAndConnIDFromStorage(w, r, s.handlePasswordGrant)
s.withClientFromStorage(w, r, s.handlePasswordGrant)
default:
s.tokenErrHelper(w, errUnsupportedGrantType, "", http.StatusBadRequest)
}
Expand Down Expand Up @@ -1166,7 +1094,7 @@ func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
w.Write(claims)
}

func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, client storage.Client, connID string) {
func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, client storage.Client) {
// Parse the fields
if err := r.ParseForm(); err != nil {
s.tokenErrHelper(w, errInvalidRequest, "Couldn't parse data", http.StatusBadRequest)
Expand Down Expand Up @@ -1220,6 +1148,11 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
}

// Get the connector
connID := q.Get("connector_id")
if connID == "" {
// backward compatibility support
connID = s.passwordConnector
}
conn, err := s.getConnector(connID)
if err != nil {
s.logger.Errorf("Failed to find connector: %v", err)
Expand Down
20 changes: 17 additions & 3 deletions server/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,8 @@ func TestHandlePassword(t *testing.T) {

// Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) {
c.PasswordConnectors = []string{"foobar", "test", "mock"}
c.DefaultPasswordConnector = "test"
c.PasswordConnector = "test"
c.Now = func() time.Time { return t0 }
})
defer httpServer.Close()
Expand All @@ -285,12 +286,13 @@ func TestHandlePassword(t *testing.T) {
u, err := url.Parse(s.issuerURL.String())
require.NoError(t, err)

u.Path = path.Join(u.Path, "/token/", connID)
u.Path = path.Join(u.Path, "/token")
v := url.Values{}
v.Add("scope", "openid offline_access email")
v.Add("grant_type", "password")
v.Add("username", username)
v.Add("password", password)
v.Add("connector_id", connID)

req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(v.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
Expand All @@ -314,12 +316,24 @@ func TestHandlePassword(t *testing.T) {
require.Equal(t, 400, rr.Code)
}

// Check unauthorized error
// Check unauthorized error (valid connector & invalid credentials)
{
rr := makeReq("test", "invalid", "test")
require.Equal(t, 401, rr.Code)
}

// Check unauthorized error (default connector & invalid credentials)
{
rr := makeReq("test", "invalid", "")
require.Equal(t, 401, rr.Code)
}

// default connector & valid credentials
{
rr := makeReq("test", "test", "")
require.Equal(t, 200, rr.Code)
}

// Check that we received expected refresh token
{
rr := makeReq("test", "test", "test")
Expand Down
46 changes: 26 additions & 20 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,11 @@ type Config struct {
// Refresh token expiration settings
RefreshTokenPolicy *RefreshTokenPolicy

// If set, the server will attempt to use a valid connector from this list to handle password grants
PasswordConnectors []string
// This is the default connctor the server will use to handle password grants
DefaultPasswordConnector string

// If set, the server will use this connector to handle password grants
PasswordConnector string

GCFrequency time.Duration // Defaults to 5 minutes

Expand Down Expand Up @@ -166,8 +169,11 @@ type Server struct {
// If enabled, show the connector selection screen even if there's only one
alwaysShowLogin bool

// Default connector used for password grant
defaultPasswordConnector string

// Used for password grant
passwordConnectors []string
passwordConnector string

supportedResponseTypes map[string]bool

Expand Down Expand Up @@ -230,7 +236,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
supportedRes[respType] = true
}

if len(c.PasswordConnectors) > 0 {
if c.DefaultPasswordConnector != "" || c.PasswordConnector != "" {
supportedGrant = append(supportedGrant, grantTypePassword)
}

Expand Down Expand Up @@ -263,21 +269,22 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
}

s := &Server{
issuerURL: *issuerURL,
connectors: make(map[string]Connector),
storage: newKeyCacher(c.Storage, now),
supportedResponseTypes: supportedRes,
supportedGrantTypes: supportedGrant,
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour),
deviceRequestsValidFor: value(c.DeviceRequestsValidFor, 5*time.Minute),
refreshTokenPolicy: c.RefreshTokenPolicy,
skipApproval: c.SkipApprovalScreen,
alwaysShowLogin: c.AlwaysShowLoginScreen,
now: now,
templates: tmpls,
passwordConnectors: c.PasswordConnectors,
logger: c.Logger,
issuerURL: *issuerURL,
connectors: make(map[string]Connector),
storage: newKeyCacher(c.Storage, now),
supportedResponseTypes: supportedRes,
supportedGrantTypes: supportedGrant,
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour),
deviceRequestsValidFor: value(c.DeviceRequestsValidFor, 5*time.Minute),
refreshTokenPolicy: c.RefreshTokenPolicy,
skipApproval: c.SkipApprovalScreen,
alwaysShowLogin: c.AlwaysShowLoginScreen,
now: now,
templates: tmpls,
defaultPasswordConnector: c.DefaultPasswordConnector,
passwordConnector: c.PasswordConnector,
logger: c.Logger,
}

// Retrieves connector objects in backend storage. This list includes the static connectors
Expand Down Expand Up @@ -355,7 +362,6 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)

// TODO(ericchiang): rate limit certain paths based on IP.
handleWithCORS("/token", s.handleToken)
handleWithCORS("/token/{connector}", s.handleTokenConnectorLogin)
handleWithCORS("/keys", s.handlePublicKeys)
handleWithCORS("/userinfo", s.handleUserInfo)
handleFunc("/auth", s.handleAuthorization)
Expand Down
4 changes: 2 additions & 2 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1760,7 +1760,7 @@ func TestServerSupportedGrants(t *testing.T) {
},
{
name: "With password connector",
config: func(c *Config) { c.PasswordConnectors = []string{"local"} },
config: func(c *Config) { c.DefaultPasswordConnector = "local" },
resGrants: []string{grantTypeAuthorizationCode, grantTypePassword, grantTypeRefreshToken, grantTypeDeviceCode},
},
{
Expand All @@ -1771,7 +1771,7 @@ func TestServerSupportedGrants(t *testing.T) {
{
name: "All",
config: func(c *Config) {
c.PasswordConnectors = []string{"local"}
c.DefaultPasswordConnector = "local"
c.SupportedResponseTypes = append(c.SupportedResponseTypes, responseTypeToken)
},
resGrants: []string{grantTypeAuthorizationCode, grantTypeImplicit, grantTypePassword, grantTypeRefreshToken, grantTypeDeviceCode},
Expand Down

0 comments on commit ec74512

Please sign in to comment.