Skip to content

Commit

Permalink
Support custom validation in OidcLogoutAuthenticationProvider
Browse files Browse the repository at this point in the history
- Similar to custom validation in OAuth2AuthorizationCodeRequestAuthenticationProvider
- Closes gh-1693
  • Loading branch information
Kehrlann committed Sep 17, 2024
1 parent 052a0a6 commit bbd9476
Show file tree
Hide file tree
Showing 4 changed files with 296 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/*
* Copyright 2020-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.oidc.authentication;

import java.util.Map;
import java.util.function.Consumer;

import org.springframework.lang.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthenticationContext;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.util.Assert;

/**
* An {@link OAuth2AuthenticationContext} that holds an
* {@link OidcLogoutAuthenticationToken} and additional information and is used when
* validating the OpenID Connect RP-Initiated Logout Request parameters.
*
* @author Daniel Garnier-Moiroux
* @since 1.4
* @see OAuth2AuthenticationContext
* @see OidcLogoutAuthenticationToken
* @see OidcLogoutAuthenticationProvider#setAuthenticationValidator(Consumer)
*/
public final class OidcLogoutAuthenticationContext implements OAuth2AuthenticationContext {

private final Map<Object, Object> context;

private OidcLogoutAuthenticationContext(Map<Object, Object> context) {
this.context = context;
}

@SuppressWarnings("unchecked")
@Nullable
@Override
public <V> V get(Object key) {
return hasKey(key) ? (V) this.context.get(key) : null;
}

@Override
public boolean hasKey(Object key) {
Assert.notNull(key, "key cannot be null");
return this.context.containsKey(key);
}

/**
* Returns the {@link RegisteredClient registered client}.
* @return the {@link RegisteredClient}
*/
public RegisteredClient getRegisteredClient() {
return get(RegisteredClient.class);
}

/**
* Returns the {@link OAuth2Authorization authorization request}.
* @return the {@link OAuth2Authorization}
*/
public OAuth2Authorization getAuthorizationRequest() {
return get(OAuth2Authorization.class);
}

/**
* Returns the {@link OidcIdToken id_token}.
* @return the {@link OidcIdToken}
*/
public OidcIdToken getIdToken() {
return get(OidcIdToken.class);
}

/**
* Constructs a new {@link Builder} with the provided
* {@link OidcLogoutAuthenticationToken}.
* @param authentication the {@link OidcLogoutAuthenticationToken}
* @return the {@link Builder}
*/
public static Builder with(OidcLogoutAuthenticationToken authentication) {
return new Builder(authentication);
}

/**
* A builder for {@link OidcLogoutAuthenticationContext}.
*/
public static final class Builder extends AbstractBuilder<OidcLogoutAuthenticationContext, Builder> {

private Builder(Authentication authentication) {
super(authentication);
}

/**
* Sets the {@link RegisteredClient registered client}.
* @param registeredClient the {@link RegisteredClient}
* @return the {@link Builder} for further configuration
*/
public Builder registeredClient(RegisteredClient registeredClient) {
return put(RegisteredClient.class, registeredClient);
}

/**
* Sets the {@link OAuth2Authorization registered client}.
* @param authorization the {@link OAuth2Authorization}
* @return the {@link Builder} for further configuration
*/
public Builder authorization(OAuth2Authorization authorization) {
return put(OAuth2Authorization.class, authorization);
}

/**
* Sets the {@link OidcIdToken id_token}.
* @param idToken the {@link OidcIdToken}
* @return the {@link Builder} for further configuration
*/
public Builder idToken(OidcIdToken idToken) {
return put(OidcIdToken.class, idToken);
}

/**
* Builds a new {@link OidcLogoutAuthenticationContext}.
* @return the {@link OidcLogoutAuthenticationContext}
*/
@Override
public OidcLogoutAuthenticationContext build() {
return new OidcLogoutAuthenticationContext(getContext());
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.security.Principal;
import java.util.Base64;
import java.util.List;
import java.util.function.Consumer;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand Down Expand Up @@ -70,6 +71,8 @@ public final class OidcLogoutAuthenticationProvider implements AuthenticationPro

private final SessionRegistry sessionRegistry;

private Consumer<OidcLogoutAuthenticationContext> authenticationValidator = new OidcLogoutAuthenticationValidator();

/**
* Constructs an {@code OidcLogoutAuthenticationProvider} using the provided
* parameters.
Expand Down Expand Up @@ -118,19 +121,16 @@ public Authentication authenticate(Authentication authentication) throws Authent
OidcIdToken idToken = authorizedIdToken.getToken();

// Validate client identity
List<String> audClaim = idToken.getAudience();
if (CollectionUtils.isEmpty(audClaim) || !audClaim.contains(registeredClient.getClientId())) {
throwError(OAuth2ErrorCodes.INVALID_TOKEN, IdTokenClaimNames.AUD);
}
if (StringUtils.hasText(oidcLogoutAuthentication.getClientId())
&& !oidcLogoutAuthentication.getClientId().equals(registeredClient.getClientId())) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID);
}
if (StringUtils.hasText(oidcLogoutAuthentication.getPostLogoutRedirectUri())
&& !registeredClient.getPostLogoutRedirectUris()
.contains(oidcLogoutAuthentication.getPostLogoutRedirectUri())) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, "post_logout_redirect_uri");
}
OidcLogoutAuthenticationContext context = OidcLogoutAuthenticationContext.with(oidcLogoutAuthentication)
.registeredClient(registeredClient)
.authorization(authorization)
.idToken(idToken)
.build();
this.authenticationValidator.accept(context);

if (this.logger.isTraceEnabled()) {
this.logger.trace("Validated logout request parameters");
Expand Down Expand Up @@ -182,6 +182,26 @@ public boolean supports(Class<?> authentication) {
return OidcLogoutAuthenticationToken.class.isAssignableFrom(authentication);
}

/**
* Sets the {@code Consumer} providing access to the
* {@link OidcLogoutAuthenticationContext} and is responsible for validating specific
* Open ID Connect RP-Initiated Logout Request parameters associated in the
* {@link OidcLogoutAuthenticationToken}. The default authentication validator is
* {@link OidcLogoutAuthenticationValidator}.
*
* <p>
* <b>NOTE:</b> The authentication validator MUST throw
* {@link OAuth2AuthenticationException} if validation fails.
* @param authenticationValidator the {@code Consumer} providing access to the
* {@link OidcLogoutAuthenticationContext} and is responsible for validating specific
* Open ID Connect RP-Initiated Logout Request parameters
* @since 1.4
*/
public void setAuthenticationValidator(Consumer<OidcLogoutAuthenticationContext> authenticationValidator) {
Assert.notNull(authenticationValidator, "authenticationValidator cannot be null");
this.authenticationValidator = authenticationValidator;
}

private SessionInformation findSessionInformation(Authentication principal, String sessionId) {
List<SessionInformation> sessions = this.sessionRegistry.getAllSessions(principal.getPrincipal(), true);
SessionInformation sessionInformation = null;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Copyright 2020-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.oidc.authentication;

import java.util.List;
import java.util.function.Consumer;

import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

/**
* A {@code Consumer} providing access to the {@link OidcLogoutAuthenticationContext}
* containing an {@link OidcLogoutAuthenticationToken} and is the default
* {@link OidcLogoutAuthenticationProvider#setAuthenticationValidator(Consumer)
* authentication validator} used for validating specific OpenID Connect RP-Initiated
* Logout parameters used in the Authorization Code Grant.
*
* <p>
* The default implementation first validates {@link OidcIdToken#getAudience()}, and then
* {@link OidcLogoutAuthenticationToken#getPostLogoutRedirectUri()}. If validation fails,
* an {@link OAuth2AuthenticationException} is thrown.
*
* @author Daniel Garnier-Moiroux
* @since 1.4
* @see OidcLogoutAuthenticationContext
* @see OidcLogoutAuthenticationToken
* @see OidcLogoutAuthenticationProvider#setAuthenticationValidator(Consumer)
*/
public final class OidcLogoutAuthenticationValidator implements Consumer<OidcLogoutAuthenticationContext> {

/**
* The default validator for {@link OidcIdToken#getAudience()}.
*/
public static final Consumer<OidcLogoutAuthenticationContext> DEFAULT_AUDIENCE_VALIDATOR = OidcLogoutAuthenticationValidator::validateAudience;

/**
* The default validator for
* {@link OidcLogoutAuthenticationToken#getPostLogoutRedirectUri()}.
*/
public static final Consumer<OidcLogoutAuthenticationContext> DEFAULT_POST_LOGOUT_REDIRECT_URI_VALIDATOR = OidcLogoutAuthenticationValidator::validatePostLogoutRedirectUri;

private final Consumer<OidcLogoutAuthenticationContext> authenticationValidator = DEFAULT_AUDIENCE_VALIDATOR
.andThen(DEFAULT_POST_LOGOUT_REDIRECT_URI_VALIDATOR);

@Override
public void accept(OidcLogoutAuthenticationContext authenticationContext) {
this.authenticationValidator.accept(authenticationContext);
}

private static void validatePostLogoutRedirectUri(OidcLogoutAuthenticationContext authenticationContext) {
OidcLogoutAuthenticationToken oidcLogoutAuthentication = authenticationContext.getAuthentication();
RegisteredClient registeredClient = authenticationContext.getRegisteredClient();
if (StringUtils.hasText(oidcLogoutAuthentication.getPostLogoutRedirectUri())
&& !registeredClient.getPostLogoutRedirectUris()
.contains(oidcLogoutAuthentication.getPostLogoutRedirectUri())) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, "post_logout_redirect_uri");
}
}

private static void validateAudience(OidcLogoutAuthenticationContext authenticationContext) {
OidcIdToken idToken = authenticationContext.getIdToken();
List<String> audClaim = idToken.getAudience();
RegisteredClient registeredClient = authenticationContext.getRegisteredClient();
if (CollectionUtils.isEmpty(audClaim) || !audClaim.contains(registeredClient.getClientId())) {
throwError(OAuth2ErrorCodes.INVALID_TOKEN, IdTokenClaimNames.AUD);
}
}

private static void throwError(String errorCode, String parameterName) {
OAuth2Error error = new OAuth2Error(errorCode, "OpenID Connect 1.0 Logout Request Parameter: " + parameterName,
"https://openid.net/specs/openid-connect-rpinitiated-1_0.html#ValidationAndErrorHandling");
throw new OAuth2AuthenticationException(error);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.function.Consumer;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
Expand Down Expand Up @@ -53,6 +54,7 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -314,6 +316,35 @@ public void authenticateWhenInvalidPostLogoutRedirectUriThenThrowOAuth2Authentic
verify(this.registeredClientRepository).findById(eq(authorization.getRegisteredClientId()));
}

@Test
void setAuthenticationValidatorWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authenticationProvider.setAuthenticationValidator(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("authenticationValidator cannot be null");
}

@Test
public void authenticateWhenCustomAuthenticationValidatorThenUsed() throws NoSuchAlgorithmException {
TestingAuthenticationToken principal = new TestingAuthenticationToken("principal", "credentials");
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
String sessionId = "session-1";
OidcIdToken idToken = OidcIdToken.withTokenValue("id-token")
.issuer("https://provider.com")
.subject(principal.getName())
.audience(Collections.singleton(registeredClient.getClientId()))
.issuedAt(Instant.now().minusSeconds(60).truncatedTo(ChronoUnit.MILLIS))
.expiresAt(Instant.now().plusSeconds(60).truncatedTo(ChronoUnit.MILLIS))
.claim("sid", createHash(sessionId))
.build();

@SuppressWarnings("unchecked")
Consumer<OidcLogoutAuthenticationContext> authenticationValidator = mock(Consumer.class);
this.authenticationProvider.setAuthenticationValidator(authenticationValidator);

authenticateValidIdToken(principal, registeredClient, sessionId, idToken);
verify(authenticationValidator).accept(any());
}

@Test
public void authenticateWhenMissingSubThenThrowOAuth2AuthenticationException() {
TestingAuthenticationToken principal = new TestingAuthenticationToken("principal", "credentials");
Expand Down

0 comments on commit bbd9476

Please sign in to comment.