Skip to content

Commit

Permalink
wip wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Kehrlann committed Sep 17, 2024
1 parent 8edbc26 commit d4653af
Show file tree
Hide file tree
Showing 3 changed files with 232 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* Copyright 2020-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.oidc.authentication;

import java.util.Map;

import org.springframework.lang.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthenticationContext;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.util.Assert;

/**
* TODO
*
* @author Daniel Garnier-Moiroux
* @since 1.4
*/
public final class OidcLogoutAuthenticationContext implements OAuth2AuthenticationContext {

private final Map<Object, Object> context;

private OidcLogoutAuthenticationContext(Map<Object, Object> context) {
this.context = context;
}

@SuppressWarnings("unchecked")
@Nullable
@Override
public <V> V get(Object key) {
return hasKey(key) ? (V) this.context.get(key) : null;
}

@Override
public boolean hasKey(Object key) {
Assert.notNull(key, "key cannot be null");
return this.context.containsKey(key);
}

/**
* Returns the {@link RegisteredClient registered client}.
* @return the {@link RegisteredClient}
*/
public RegisteredClient getRegisteredClient() {
return get(RegisteredClient.class);
}

/**
* Returns the {@link OAuth2Authorization authorization request}.
* @return the {@link OAuth2Authorization}
*/
public OAuth2Authorization getAuthorizationRequest() {
return get(OAuth2Authorization.class);
}

/**
* Returns the {@link OidcIdToken id_token}.
* @return the {@link OidcIdToken}
*/
public OidcIdToken getIdToken() {
return get(OidcIdToken.class);
}

/**
* Constructs a new {@link Builder} with the provided
* {@link OidcLogoutAuthenticationToken}.
* @param authentication the {@link OidcLogoutAuthenticationToken}
* @return the {@link Builder}
*/
public static Builder with(OidcLogoutAuthenticationToken authentication) {
return new Builder(authentication);
}

/**
* A builder for {@link OidcLogoutAuthenticationContext}.
*/
public static final class Builder extends AbstractBuilder<OidcLogoutAuthenticationContext, Builder> {

private Builder(Authentication authentication) {
super(authentication);
}

/**
* Sets the {@link RegisteredClient registered client}.
* @param registeredClient the {@link RegisteredClient}
* @return the {@link Builder} for further configuration
*/
public Builder registeredClient(RegisteredClient registeredClient) {
return put(RegisteredClient.class, registeredClient);
}

/**
* Sets the {@link OAuth2Authorization registered client}.
* @param authorization the {@link OAuth2Authorization}
* @return the {@link Builder} for further configuration
*/
public Builder authorization(OAuth2Authorization authorization) {
return put(OAuth2Authorization.class, authorization);
}

/**
* Sets the {@link OidcIdToken id_token}.
* @param idToken the {@link OidcIdToken}
* @return the {@link Builder} for further configuration
*/
public Builder idToken(OidcIdToken idToken) {
return put(OidcIdToken.class, idToken);
}

/**
* Builds a new {@link OidcLogoutAuthenticationContext}.
* @return the {@link OidcLogoutAuthenticationContext}
*/
@Override
public OidcLogoutAuthenticationContext build() {
return new OidcLogoutAuthenticationContext(getContext());
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.security.Principal;
import java.util.Base64;
import java.util.List;
import java.util.function.Consumer;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand Down Expand Up @@ -70,6 +71,8 @@ public final class OidcLogoutAuthenticationProvider implements AuthenticationPro

private final SessionRegistry sessionRegistry;

private final Consumer<OidcLogoutAuthenticationContext> oidcLogoutAuthenticationValidator = new OidcLogoutAuthenticationValidator();

/**
* Constructs an {@code OidcLogoutAuthenticationProvider} using the provided
* parameters.
Expand Down Expand Up @@ -118,19 +121,16 @@ public Authentication authenticate(Authentication authentication) throws Authent
OidcIdToken idToken = authorizedIdToken.getToken();

// Validate client identity
List<String> audClaim = idToken.getAudience();
if (CollectionUtils.isEmpty(audClaim) || !audClaim.contains(registeredClient.getClientId())) {
throwError(OAuth2ErrorCodes.INVALID_TOKEN, IdTokenClaimNames.AUD);
}
if (StringUtils.hasText(oidcLogoutAuthentication.getClientId())
&& !oidcLogoutAuthentication.getClientId().equals(registeredClient.getClientId())) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID);
}
if (StringUtils.hasText(oidcLogoutAuthentication.getPostLogoutRedirectUri())
&& !registeredClient.getPostLogoutRedirectUris()
.contains(oidcLogoutAuthentication.getPostLogoutRedirectUri())) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, "post_logout_redirect_uri");
}
OidcLogoutAuthenticationContext context = OidcLogoutAuthenticationContext.with(oidcLogoutAuthentication)
.registeredClient(registeredClient)
.authorization(authorization)
.idToken(idToken)
.build();
this.oidcLogoutAuthenticationValidator.accept(context);

if (this.logger.isTraceEnabled()) {
this.logger.trace("Validated logout request parameters");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Copyright 2020-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.oidc.authentication;

import java.util.List;
import java.util.function.Consumer;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

/**
* TODO
*
* @author Daniel Garnier-Moiroux
* @since 1.4
*/
public final class OidcLogoutAuthenticationValidator implements Consumer<OidcLogoutAuthenticationContext> {

private static final Log LOGGER = LogFactory.getLog(OidcLogoutAuthenticationValidator.class);

/**
* The default validator for {@link OidcIdToken#getAudience()}.
*/
public static final Consumer<OidcLogoutAuthenticationContext> DEFAULT_AUDIENCE_VALIDATOR = OidcLogoutAuthenticationValidator::validateAudience;

/**
* The default validator for
* {@link OidcLogoutAuthenticationToken#getPostLogoutRedirectUri()}.
*/
public static final Consumer<OidcLogoutAuthenticationContext> DEFAULT_POST_LOGOUT_REDIRECT_URI_VALIDATOR = OidcLogoutAuthenticationValidator::validatePostLogoutRedirectUri;

private final Consumer<OidcLogoutAuthenticationContext> authenticationValidator = DEFAULT_AUDIENCE_VALIDATOR
.andThen(DEFAULT_POST_LOGOUT_REDIRECT_URI_VALIDATOR);

@Override
public void accept(OidcLogoutAuthenticationContext authenticationContext) {
this.authenticationValidator.accept(authenticationContext);
}

private static void validatePostLogoutRedirectUri(OidcLogoutAuthenticationContext authenticationContext) {
OidcLogoutAuthenticationToken oidcLogoutAuthentication = authenticationContext.getAuthentication();
RegisteredClient registeredClient = authenticationContext.getRegisteredClient();
if (StringUtils.hasText(oidcLogoutAuthentication.getPostLogoutRedirectUri())
&& !registeredClient.getPostLogoutRedirectUris()
.contains(oidcLogoutAuthentication.getPostLogoutRedirectUri())) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, "post_logout_redirect_uri");
}
}

private static void validateAudience(OidcLogoutAuthenticationContext authenticationContext) {
OidcIdToken idToken = authenticationContext.getIdToken();
List<String> audClaim = idToken.getAudience();
RegisteredClient registeredClient = authenticationContext.getRegisteredClient();
if (CollectionUtils.isEmpty(audClaim) || !audClaim.contains(registeredClient.getClientId())) {
throwError(OAuth2ErrorCodes.INVALID_TOKEN, IdTokenClaimNames.AUD);
}
}

private static void throwError(String errorCode, String parameterName) {
OAuth2Error error = new OAuth2Error(errorCode, "OpenID Connect 1.0 Logout Request Parameter: " + parameterName,
"https://openid.net/specs/openid-connect-rpinitiated-1_0.html#ValidationAndErrorHandling");
throw new OAuth2AuthenticationException(error);
}

}

0 comments on commit d4653af

Please sign in to comment.