diff --git a/server/handlers.go b/server/handlers.go index cbeb0376a7..7aa11ae598 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -815,6 +815,54 @@ func (s *Server) withClientFromStorage(w http.ResponseWriter, r *http.Request, h handler(w, r, client) } +func (s *Server) withClientAndConnIDFromStorage(w http.ResponseWriter, r *http.Request, handler func(http.ResponseWriter, *http.Request, storage.Client, string)) { + clientID, clientSecret, ok := r.BasicAuth() + if ok { + var err error + if clientID, err = url.QueryUnescape(clientID); err != nil { + s.tokenErrHelper(w, errInvalidRequest, "client_id improperly encoded", http.StatusBadRequest) + return + } + if clientSecret, err = url.QueryUnescape(clientSecret); err != nil { + s.tokenErrHelper(w, errInvalidRequest, "client_secret improperly encoded", http.StatusBadRequest) + return + } + } else { + clientID = r.PostFormValue("client_id") + clientSecret = r.PostFormValue("client_secret") + } + + client, err := s.storage.GetClient(clientID) + if err != nil { + if err != storage.ErrNotFound { + s.logger.Errorf("failed to get client: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + } else { + s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized) + } + return + } + + if subtle.ConstantTimeCompare([]byte(client.Secret), []byte(clientSecret)) != 1 { + if clientSecret == "" { + s.logger.Infof("missing client_secret on token request for client: %s", client.ID) + } else { + s.logger.Infof("invalid client_secret on token request for client: %s", client.ID) + } + s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized) + return + } + + connID, err := url.PathUnescape(mux.Vars(r)["connector"]) + if err != nil { + s.logger.Errorf("Failed to parse connector: %v", err) + s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist") + return + } + + handler(w, r, client, connID) +} + func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") if r.Method != http.MethodPost { @@ -838,7 +886,36 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) { case grantTypeRefreshToken: s.withClientFromStorage(w, r, s.handleRefreshToken) case grantTypePassword: - s.withClientFromStorage(w, r, s.handlePasswordGrant) + s.withClientAndConnIDFromStorage(w, r, s.handlePasswordGrant) + default: + s.tokenErrHelper(w, errUnsupportedGrantType, "", http.StatusBadRequest) + } +} + +func (s *Server) handleTokenConnectorLogin(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if r.Method != http.MethodPost { + s.tokenErrHelper(w, errInvalidRequest, "method not allowed", http.StatusBadRequest) + return + } + + err := r.ParseForm() + if err != nil { + s.logger.Errorf("Could not parse request body: %v", err) + s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest) + return + } + + grantType := r.PostFormValue("grant_type") + switch grantType { + case grantTypeDeviceCode: + s.handleDeviceToken(w, r) + case grantTypeAuthorizationCode: + s.withClientFromStorage(w, r, s.handleAuthCode) + case grantTypeRefreshToken: + s.withClientFromStorage(w, r, s.handleRefreshToken) + case grantTypePassword: + s.withClientAndConnIDFromStorage(w, r, s.handlePasswordGrant) default: s.tokenErrHelper(w, errUnsupportedGrantType, "", http.StatusBadRequest) } @@ -1089,7 +1166,7 @@ func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) { w.Write(claims) } -func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, client storage.Client) { +func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, client storage.Client, connID string) { // Parse the fields if err := r.ParseForm(); err != nil { s.tokenErrHelper(w, errInvalidRequest, "Couldn't parse data", http.StatusBadRequest) @@ -1142,60 +1219,31 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli return } - // Try out every connector in the passwordConnectors list - var connID string - var identity connector.Identity - var conn Connector - var err error - var httpErrMsg string - var statusCode int - userLoggedIn := false - username := q.Get("username") - password := q.Get("password") - for _, id := range s.passwordConnectors { - connID = id - s.logger.Infof("Trying to login user with password connector: %v", id) - - // Get the connector - conn, err = s.getConnector(id) - if err != nil { - statusCode = http.StatusBadRequest - s.logger.Errorf("Requested connector does not exist.") - continue - } - - passwordConnector, ok := conn.Connector.(connector.PasswordConnector) - if !ok { - statusCode = http.StatusBadRequest - s.logger.Errorf("Requested password connector does not correct type.") - continue - } + // Get the connector + conn, err := s.getConnector(connID) + if err != nil { + s.logger.Errorf("Failed to find connector: %v", err) + s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest) + return + } - // Login - identity, ok, err = passwordConnector.Login(r.Context(), parseScopes(scopes), username, password) - if err != nil { - statusCode = http.StatusBadRequest - s.logger.Errorf("Could not login user") - continue - } - if !ok { - statusCode = http.StatusUnauthorized - httpErrMsg = "Invalid username or password" - s.logger.Errorf(httpErrMsg) - break - } else { - s.logger.Infof("User '%v' logged in successfully with the password connector: %v", username, connID) - userLoggedIn = true - break - } + passwordConnector, ok := conn.Connector.(connector.PasswordConnector) + if !ok { + s.tokenErrHelper(w, errInvalidRequest, "Requested password connector does not correct type.", http.StatusBadRequest) + return } - if !userLoggedIn { - s.logger.Errorf("Failed to login user to valid password connectors.") - if len(httpErrMsg) == 0 { - httpErrMsg = "Unable to login user with any password connector." - } - s.tokenErrHelper(w, errInvalidRequest, httpErrMsg, statusCode) + // Login + username := q.Get("username") + password := q.Get("password") + identity, ok, err := passwordConnector.Login(r.Context(), parseScopes(scopes), username, password) + if err != nil { + s.logger.Errorf("Failed to login user: %v", err) + s.tokenErrHelper(w, errInvalidRequest, "Could not login user", http.StatusBadRequest) + return + } + if !ok { + s.tokenErrHelper(w, errAccessDenied, "Invalid username or password", http.StatusUnauthorized) return } diff --git a/server/handlers_test.go b/server/handlers_test.go index c303183633..54d0bcd3e0 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -281,11 +281,11 @@ func TestHandlePassword(t *testing.T) { mockConnectorDataTestStorage(t, s.storage) - makeReq := func(username, password string) *httptest.ResponseRecorder { + makeReq := func(username, password string, connID string) *httptest.ResponseRecorder { u, err := url.Parse(s.issuerURL.String()) require.NoError(t, err) - u.Path = path.Join(u.Path, "/token") + u.Path = path.Join(u.Path, "/token/", connID) v := url.Values{} v.Add("scope", "openid offline_access email") v.Add("grant_type", "password") @@ -302,15 +302,27 @@ func TestHandlePassword(t *testing.T) { return rr } + // Check bad request error (bad connector & invalid credentials) + { + rr := makeReq("test", "invalid", "foobar") + require.Equal(t, 400, rr.Code) + } + + // Check bad request error (bad connector & correct credentials) + { + rr := makeReq("test", "test", "foobar") + require.Equal(t, 400, rr.Code) + } + // Check unauthorized error { - rr := makeReq("test", "invalid") + rr := makeReq("test", "invalid", "test") require.Equal(t, 401, rr.Code) } // Check that we received expected refresh token { - rr := makeReq("test", "test") + rr := makeReq("test", "test", "test") require.Equal(t, 200, rr.Code) var ref struct { diff --git a/server/server.go b/server/server.go index 67e37ff9da..c01bbcd2aa 100755 --- a/server/server.go +++ b/server/server.go @@ -355,6 +355,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) // TODO(ericchiang): rate limit certain paths based on IP. handleWithCORS("/token", s.handleToken) + handleWithCORS("/token/{connector}", s.handleTokenConnectorLogin) handleWithCORS("/keys", s.handlePublicKeys) handleWithCORS("/userinfo", s.handleUserInfo) handleFunc("/auth", s.handleAuthorization)