From bbd9476a44b4c4b2eb99d9a6744ad332d31f94d7 Mon Sep 17 00:00:00 2001 From: Daniel Garnier-Moiroux Date: Mon, 16 Sep 2024 16:01:16 +0200 Subject: [PATCH] Support custom validation in OidcLogoutAuthenticationProvider - Similar to custom validation in OAuth2AuthorizationCodeRequestAuthenticationProvider - Closes gh-1693 --- .../OidcLogoutAuthenticationContext.java | 142 ++++++++++++++++++ .../OidcLogoutAuthenticationProvider.java | 38 +++-- .../OidcLogoutAuthenticationValidator.java | 94 ++++++++++++ ...OidcLogoutAuthenticationProviderTests.java | 31 ++++ 4 files changed, 296 insertions(+), 9 deletions(-) create mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationContext.java create mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationValidator.java diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationContext.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationContext.java new file mode 100644 index 000000000..5cf6171a0 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationContext.java @@ -0,0 +1,142 @@ +/* + * Copyright 2020-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.oidc.authentication; + +import java.util.Map; +import java.util.function.Consumer; + +import org.springframework.lang.Nullable; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthenticationContext; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.util.Assert; + +/** + * An {@link OAuth2AuthenticationContext} that holds an + * {@link OidcLogoutAuthenticationToken} and additional information and is used when + * validating the OpenID Connect RP-Initiated Logout Request parameters. + * + * @author Daniel Garnier-Moiroux + * @since 1.4 + * @see OAuth2AuthenticationContext + * @see OidcLogoutAuthenticationToken + * @see OidcLogoutAuthenticationProvider#setAuthenticationValidator(Consumer) + */ +public final class OidcLogoutAuthenticationContext implements OAuth2AuthenticationContext { + + private final Map context; + + private OidcLogoutAuthenticationContext(Map context) { + this.context = context; + } + + @SuppressWarnings("unchecked") + @Nullable + @Override + public V get(Object key) { + return hasKey(key) ? (V) this.context.get(key) : null; + } + + @Override + public boolean hasKey(Object key) { + Assert.notNull(key, "key cannot be null"); + return this.context.containsKey(key); + } + + /** + * Returns the {@link RegisteredClient registered client}. + * @return the {@link RegisteredClient} + */ + public RegisteredClient getRegisteredClient() { + return get(RegisteredClient.class); + } + + /** + * Returns the {@link OAuth2Authorization authorization request}. + * @return the {@link OAuth2Authorization} + */ + public OAuth2Authorization getAuthorizationRequest() { + return get(OAuth2Authorization.class); + } + + /** + * Returns the {@link OidcIdToken id_token}. + * @return the {@link OidcIdToken} + */ + public OidcIdToken getIdToken() { + return get(OidcIdToken.class); + } + + /** + * Constructs a new {@link Builder} with the provided + * {@link OidcLogoutAuthenticationToken}. + * @param authentication the {@link OidcLogoutAuthenticationToken} + * @return the {@link Builder} + */ + public static Builder with(OidcLogoutAuthenticationToken authentication) { + return new Builder(authentication); + } + + /** + * A builder for {@link OidcLogoutAuthenticationContext}. + */ + public static final class Builder extends AbstractBuilder { + + private Builder(Authentication authentication) { + super(authentication); + } + + /** + * Sets the {@link RegisteredClient registered client}. + * @param registeredClient the {@link RegisteredClient} + * @return the {@link Builder} for further configuration + */ + public Builder registeredClient(RegisteredClient registeredClient) { + return put(RegisteredClient.class, registeredClient); + } + + /** + * Sets the {@link OAuth2Authorization registered client}. + * @param authorization the {@link OAuth2Authorization} + * @return the {@link Builder} for further configuration + */ + public Builder authorization(OAuth2Authorization authorization) { + return put(OAuth2Authorization.class, authorization); + } + + /** + * Sets the {@link OidcIdToken id_token}. + * @param idToken the {@link OidcIdToken} + * @return the {@link Builder} for further configuration + */ + public Builder idToken(OidcIdToken idToken) { + return put(OidcIdToken.class, idToken); + } + + /** + * Builds a new {@link OidcLogoutAuthenticationContext}. + * @return the {@link OidcLogoutAuthenticationContext} + */ + @Override + public OidcLogoutAuthenticationContext build() { + return new OidcLogoutAuthenticationContext(getContext()); + } + + } + +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProvider.java index fcbb591c3..a50d96314 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProvider.java @@ -21,6 +21,7 @@ import java.security.Principal; import java.util.Base64; import java.util.List; +import java.util.function.Consumer; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -70,6 +71,8 @@ public final class OidcLogoutAuthenticationProvider implements AuthenticationPro private final SessionRegistry sessionRegistry; + private Consumer authenticationValidator = new OidcLogoutAuthenticationValidator(); + /** * Constructs an {@code OidcLogoutAuthenticationProvider} using the provided * parameters. @@ -118,19 +121,16 @@ public Authentication authenticate(Authentication authentication) throws Authent OidcIdToken idToken = authorizedIdToken.getToken(); // Validate client identity - List audClaim = idToken.getAudience(); - if (CollectionUtils.isEmpty(audClaim) || !audClaim.contains(registeredClient.getClientId())) { - throwError(OAuth2ErrorCodes.INVALID_TOKEN, IdTokenClaimNames.AUD); - } if (StringUtils.hasText(oidcLogoutAuthentication.getClientId()) && !oidcLogoutAuthentication.getClientId().equals(registeredClient.getClientId())) { throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID); } - if (StringUtils.hasText(oidcLogoutAuthentication.getPostLogoutRedirectUri()) - && !registeredClient.getPostLogoutRedirectUris() - .contains(oidcLogoutAuthentication.getPostLogoutRedirectUri())) { - throwError(OAuth2ErrorCodes.INVALID_REQUEST, "post_logout_redirect_uri"); - } + OidcLogoutAuthenticationContext context = OidcLogoutAuthenticationContext.with(oidcLogoutAuthentication) + .registeredClient(registeredClient) + .authorization(authorization) + .idToken(idToken) + .build(); + this.authenticationValidator.accept(context); if (this.logger.isTraceEnabled()) { this.logger.trace("Validated logout request parameters"); @@ -182,6 +182,26 @@ public boolean supports(Class authentication) { return OidcLogoutAuthenticationToken.class.isAssignableFrom(authentication); } + /** + * Sets the {@code Consumer} providing access to the + * {@link OidcLogoutAuthenticationContext} and is responsible for validating specific + * Open ID Connect RP-Initiated Logout Request parameters associated in the + * {@link OidcLogoutAuthenticationToken}. The default authentication validator is + * {@link OidcLogoutAuthenticationValidator}. + * + *

+ * NOTE: The authentication validator MUST throw + * {@link OAuth2AuthenticationException} if validation fails. + * @param authenticationValidator the {@code Consumer} providing access to the + * {@link OidcLogoutAuthenticationContext} and is responsible for validating specific + * Open ID Connect RP-Initiated Logout Request parameters + * @since 1.4 + */ + public void setAuthenticationValidator(Consumer authenticationValidator) { + Assert.notNull(authenticationValidator, "authenticationValidator cannot be null"); + this.authenticationValidator = authenticationValidator; + } + private SessionInformation findSessionInformation(Authentication principal, String sessionId) { List sessions = this.sessionRegistry.getAllSessions(principal.getPrincipal(), true); SessionInformation sessionInformation = null; diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationValidator.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationValidator.java new file mode 100644 index 000000000..4524b4c02 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationValidator.java @@ -0,0 +1,94 @@ +/* + * Copyright 2020-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.oidc.authentication; + +import java.util.List; +import java.util.function.Consumer; + +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + +/** + * A {@code Consumer} providing access to the {@link OidcLogoutAuthenticationContext} + * containing an {@link OidcLogoutAuthenticationToken} and is the default + * {@link OidcLogoutAuthenticationProvider#setAuthenticationValidator(Consumer) + * authentication validator} used for validating specific OpenID Connect RP-Initiated + * Logout parameters used in the Authorization Code Grant. + * + *

+ * The default implementation first validates {@link OidcIdToken#getAudience()}, and then + * {@link OidcLogoutAuthenticationToken#getPostLogoutRedirectUri()}. If validation fails, + * an {@link OAuth2AuthenticationException} is thrown. + * + * @author Daniel Garnier-Moiroux + * @since 1.4 + * @see OidcLogoutAuthenticationContext + * @see OidcLogoutAuthenticationToken + * @see OidcLogoutAuthenticationProvider#setAuthenticationValidator(Consumer) + */ +public final class OidcLogoutAuthenticationValidator implements Consumer { + + /** + * The default validator for {@link OidcIdToken#getAudience()}. + */ + public static final Consumer DEFAULT_AUDIENCE_VALIDATOR = OidcLogoutAuthenticationValidator::validateAudience; + + /** + * The default validator for + * {@link OidcLogoutAuthenticationToken#getPostLogoutRedirectUri()}. + */ + public static final Consumer DEFAULT_POST_LOGOUT_REDIRECT_URI_VALIDATOR = OidcLogoutAuthenticationValidator::validatePostLogoutRedirectUri; + + private final Consumer authenticationValidator = DEFAULT_AUDIENCE_VALIDATOR + .andThen(DEFAULT_POST_LOGOUT_REDIRECT_URI_VALIDATOR); + + @Override + public void accept(OidcLogoutAuthenticationContext authenticationContext) { + this.authenticationValidator.accept(authenticationContext); + } + + private static void validatePostLogoutRedirectUri(OidcLogoutAuthenticationContext authenticationContext) { + OidcLogoutAuthenticationToken oidcLogoutAuthentication = authenticationContext.getAuthentication(); + RegisteredClient registeredClient = authenticationContext.getRegisteredClient(); + if (StringUtils.hasText(oidcLogoutAuthentication.getPostLogoutRedirectUri()) + && !registeredClient.getPostLogoutRedirectUris() + .contains(oidcLogoutAuthentication.getPostLogoutRedirectUri())) { + throwError(OAuth2ErrorCodes.INVALID_REQUEST, "post_logout_redirect_uri"); + } + } + + private static void validateAudience(OidcLogoutAuthenticationContext authenticationContext) { + OidcIdToken idToken = authenticationContext.getIdToken(); + List audClaim = idToken.getAudience(); + RegisteredClient registeredClient = authenticationContext.getRegisteredClient(); + if (CollectionUtils.isEmpty(audClaim) || !audClaim.contains(registeredClient.getClientId())) { + throwError(OAuth2ErrorCodes.INVALID_TOKEN, IdTokenClaimNames.AUD); + } + } + + private static void throwError(String errorCode, String parameterName) { + OAuth2Error error = new OAuth2Error(errorCode, "OpenID Connect 1.0 Logout Request Parameter: " + parameterName, + "https://openid.net/specs/openid-connect-rpinitiated-1_0.html#ValidationAndErrorHandling"); + throw new OAuth2AuthenticationException(error); + } + +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProviderTests.java index b7f91b808..30a73db25 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProviderTests.java @@ -24,6 +24,7 @@ import java.util.Collections; import java.util.Date; import java.util.List; +import java.util.function.Consumer; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -53,6 +54,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; @@ -314,6 +316,35 @@ public void authenticateWhenInvalidPostLogoutRedirectUriThenThrowOAuth2Authentic verify(this.registeredClientRepository).findById(eq(authorization.getRegisteredClientId())); } + @Test + void setAuthenticationValidatorWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authenticationProvider.setAuthenticationValidator(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authenticationValidator cannot be null"); + } + + @Test + public void authenticateWhenCustomAuthenticationValidatorThenUsed() throws NoSuchAlgorithmException { + TestingAuthenticationToken principal = new TestingAuthenticationToken("principal", "credentials"); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + String sessionId = "session-1"; + OidcIdToken idToken = OidcIdToken.withTokenValue("id-token") + .issuer("https://provider.com") + .subject(principal.getName()) + .audience(Collections.singleton(registeredClient.getClientId())) + .issuedAt(Instant.now().minusSeconds(60).truncatedTo(ChronoUnit.MILLIS)) + .expiresAt(Instant.now().plusSeconds(60).truncatedTo(ChronoUnit.MILLIS)) + .claim("sid", createHash(sessionId)) + .build(); + + @SuppressWarnings("unchecked") + Consumer authenticationValidator = mock(Consumer.class); + this.authenticationProvider.setAuthenticationValidator(authenticationValidator); + + authenticateValidIdToken(principal, registeredClient, sessionId, idToken); + verify(authenticationValidator).accept(any()); + } + @Test public void authenticateWhenMissingSubThenThrowOAuth2AuthenticationException() { TestingAuthenticationToken principal = new TestingAuthenticationToken("principal", "credentials");