From f885df4343b4961732157238082ed5407bc5632c Mon Sep 17 00:00:00 2001 From: Joe Grandja <10884212+jgrandja@users.noreply.github.com> Date: Fri, 9 Aug 2024 13:48:37 -0400 Subject: [PATCH] Allow customizing LogoutHandler in OidcLogoutEndpointFilter Closes gh-1244 --- .../ROOT/pages/protocol-endpoints.adoc | 2 +- .../oidc/web/OidcLogoutEndpointFilter.java | 60 +-------- ...idcLogoutAuthenticationSuccessHandler.java | 124 ++++++++++++++++++ ...goutAuthenticationSuccessHandlerTests.java | 94 +++++++++++++ 4 files changed, 223 insertions(+), 57 deletions(-) create mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcLogoutAuthenticationSuccessHandler.java create mode 100644 oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcLogoutAuthenticationSuccessHandlerTests.java diff --git a/docs/modules/ROOT/pages/protocol-endpoints.adoc b/docs/modules/ROOT/pages/protocol-endpoints.adoc index 5cdd65c52..011dbbffb 100644 --- a/docs/modules/ROOT/pages/protocol-endpoints.adoc +++ b/docs/modules/ROOT/pages/protocol-endpoints.adoc @@ -545,7 +545,7 @@ public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity h * `*AuthenticationConverter*` -- An `OidcLogoutAuthenticationConverter`. * `*AuthenticationManager*` -- An `AuthenticationManager` composed of `OidcLogoutAuthenticationProvider`. -* `*AuthenticationSuccessHandler*` -- An internal implementation that handles an "`authenticated`" `OidcLogoutAuthenticationToken` and performs the logout. +* `*AuthenticationSuccessHandler*` -- An `OidcLogoutAuthenticationSuccessHandler`. * `*AuthenticationFailureHandler*` -- An internal implementation that uses the `OAuth2Error` associated with the `OAuth2AuthenticationException` and returns the `OAuth2Error` response. [NOTE] diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcLogoutEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcLogoutEndpointFilter.java index aa3b9dd30..4921bc837 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcLogoutEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcLogoutEndpointFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2023 the original author or authors. + * Copyright 2020-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ package org.springframework.security.oauth2.server.authorization.oidc.web; import java.io.IOException; -import java.nio.charset.StandardCharsets; import jakarta.servlet.FilterChain; import jakarta.servlet.ServletException; @@ -32,27 +31,18 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; -import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcLogoutAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcLogoutAuthenticationToken; import org.springframework.security.oauth2.server.authorization.oidc.web.authentication.OidcLogoutAuthenticationConverter; -import org.springframework.security.web.DefaultRedirectStrategy; -import org.springframework.security.web.RedirectStrategy; +import org.springframework.security.oauth2.server.authorization.oidc.web.authentication.OidcLogoutAuthenticationSuccessHandler; import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; -import org.springframework.security.web.authentication.logout.LogoutHandler; -import org.springframework.security.web.authentication.logout.LogoutSuccessHandler; -import org.springframework.security.web.authentication.logout.SecurityContextLogoutHandler; -import org.springframework.security.web.authentication.logout.SimpleUrlLogoutSuccessHandler; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.OrRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; -import org.springframework.util.StringUtils; import org.springframework.web.filter.OncePerRequestFilter; -import org.springframework.web.util.UriComponentsBuilder; -import org.springframework.web.util.UriUtils; /** * A {@code Filter} that processes OpenID Connect 1.0 RP-Initiated Logout Requests. @@ -60,6 +50,7 @@ * @author Joe Grandja * @since 1.1 * @see OidcLogoutAuthenticationConverter + * @see OidcLogoutAuthenticationSuccessHandler * @see OidcLogoutAuthenticationProvider * @see 2. * RP-Initiated Logout @@ -76,15 +67,9 @@ public final class OidcLogoutEndpointFilter extends OncePerRequestFilter { private final RequestMatcher logoutEndpointMatcher; - private final LogoutHandler logoutHandler; - - private final LogoutSuccessHandler logoutSuccessHandler; - - private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy(); - private AuthenticationConverter authenticationConverter; - private AuthenticationSuccessHandler authenticationSuccessHandler = this::performLogout; + private AuthenticationSuccessHandler authenticationSuccessHandler = new OidcLogoutAuthenticationSuccessHandler(); private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse; @@ -109,10 +94,6 @@ public OidcLogoutEndpointFilter(AuthenticationManager authenticationManager, Str this.logoutEndpointMatcher = new OrRequestMatcher( new AntPathRequestMatcher(logoutEndpointUri, HttpMethod.GET.name()), new AntPathRequestMatcher(logoutEndpointUri, HttpMethod.POST.name())); - this.logoutHandler = new SecurityContextLogoutHandler(); - SimpleUrlLogoutSuccessHandler urlLogoutSuccessHandler = new SimpleUrlLogoutSuccessHandler(); - urlLogoutSuccessHandler.setDefaultTargetUrl("/"); - this.logoutSuccessHandler = urlLogoutSuccessHandler; this.authenticationConverter = new OidcLogoutAuthenticationConverter(); } @@ -187,39 +168,6 @@ public void setAuthenticationFailureHandler(AuthenticationFailureHandler authent this.authenticationFailureHandler = authenticationFailureHandler; } - private void performLogout(HttpServletRequest request, HttpServletResponse response, Authentication authentication) - throws IOException, ServletException { - - OidcLogoutAuthenticationToken oidcLogoutAuthentication = (OidcLogoutAuthenticationToken) authentication; - - // Check for active user session - if (oidcLogoutAuthentication.isPrincipalAuthenticated() - && StringUtils.hasText(oidcLogoutAuthentication.getSessionId())) { - // Perform logout - this.logoutHandler.logout(request, response, (Authentication) oidcLogoutAuthentication.getPrincipal()); - } - - if (oidcLogoutAuthentication.isAuthenticated() - && StringUtils.hasText(oidcLogoutAuthentication.getPostLogoutRedirectUri())) { - // Perform post-logout redirect - UriComponentsBuilder uriBuilder = UriComponentsBuilder - .fromUriString(oidcLogoutAuthentication.getPostLogoutRedirectUri()); - String redirectUri; - if (StringUtils.hasText(oidcLogoutAuthentication.getState())) { - uriBuilder.queryParam(OAuth2ParameterNames.STATE, - UriUtils.encode(oidcLogoutAuthentication.getState(), StandardCharsets.UTF_8)); - } - // build(true) -> Components are explicitly encoded - redirectUri = uriBuilder.build(true).toUriString(); - this.redirectStrategy.sendRedirect(request, response, redirectUri); - } - else { - // Perform default redirect - this.logoutSuccessHandler.onLogoutSuccess(request, response, - (Authentication) oidcLogoutAuthentication.getPrincipal()); - } - } - private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response, AuthenticationException exception) throws IOException { diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcLogoutAuthenticationSuccessHandler.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcLogoutAuthenticationSuccessHandler.java new file mode 100644 index 000000000..150088234 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcLogoutAuthenticationSuccessHandler.java @@ -0,0 +1,124 @@ +/* + * Copyright 2020-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.oidc.web.authentication; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcLogoutAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.oidc.web.OidcLogoutEndpointFilter; +import org.springframework.security.web.DefaultRedirectStrategy; +import org.springframework.security.web.RedirectStrategy; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; +import org.springframework.security.web.authentication.logout.LogoutHandler; +import org.springframework.security.web.authentication.logout.SecurityContextLogoutHandler; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.web.util.UriComponentsBuilder; +import org.springframework.web.util.UriUtils; + +/** + * An implementation of an {@link AuthenticationSuccessHandler} used for handling an + * {@link OidcLogoutAuthenticationToken} and performing the OpenID Connect 1.0 + * RP-Initiated Logout. + * + * @author Joe Grandja + * @since 1.4 + * @see OidcLogoutEndpointFilter#setAuthenticationSuccessHandler(AuthenticationSuccessHandler) + * @see LogoutHandler + */ +public final class OidcLogoutAuthenticationSuccessHandler implements AuthenticationSuccessHandler { + + private final Log logger = LogFactory.getLog(getClass()); + + private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy(); + + private final SecurityContextLogoutHandler securityContextLogoutHandler = new SecurityContextLogoutHandler(); + + private LogoutHandler logoutHandler = this::performLogout; + + @Override + public void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response, + Authentication authentication) throws IOException, ServletException { + + if (!(authentication instanceof OidcLogoutAuthenticationToken)) { + if (this.logger.isErrorEnabled()) { + this.logger.error(Authentication.class.getSimpleName() + " must be of type " + + OidcLogoutAuthenticationToken.class.getName() + " but was " + + authentication.getClass().getName()); + } + OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR, + "Unable to process the OpenID Connect 1.0 RP-Initiated Logout response.", null); + throw new OAuth2AuthenticationException(error); + } + + this.logoutHandler.logout(request, response, authentication); + + sendLogoutRedirect(request, response, authentication); + } + + /** + * Sets the {@link LogoutHandler} used for performing logout. + * @param logoutHandler the {@link LogoutHandler} used for performing logout + */ + public void setLogoutHandler(LogoutHandler logoutHandler) { + Assert.notNull(logoutHandler, "logoutHandler cannot be null"); + this.logoutHandler = logoutHandler; + } + + private void performLogout(HttpServletRequest request, HttpServletResponse response, + Authentication authentication) { + OidcLogoutAuthenticationToken oidcLogoutAuthentication = (OidcLogoutAuthenticationToken) authentication; + + // Check for active user session + if (oidcLogoutAuthentication.isPrincipalAuthenticated()) { + this.securityContextLogoutHandler.logout(request, response, + (Authentication) oidcLogoutAuthentication.getPrincipal()); + } + } + + private void sendLogoutRedirect(HttpServletRequest request, HttpServletResponse response, + Authentication authentication) throws IOException { + OidcLogoutAuthenticationToken oidcLogoutAuthentication = (OidcLogoutAuthenticationToken) authentication; + + String redirectUri = "/"; + if (oidcLogoutAuthentication.isAuthenticated() + && StringUtils.hasText(oidcLogoutAuthentication.getPostLogoutRedirectUri())) { + // Use the `post_logout_redirect_uri` parameter + UriComponentsBuilder uriBuilder = UriComponentsBuilder + .fromUriString(oidcLogoutAuthentication.getPostLogoutRedirectUri()); + if (StringUtils.hasText(oidcLogoutAuthentication.getState())) { + uriBuilder.queryParam(OAuth2ParameterNames.STATE, + UriUtils.encode(oidcLogoutAuthentication.getState(), StandardCharsets.UTF_8)); + } + // build(true) -> Components are explicitly encoded + redirectUri = uriBuilder.build(true).toUriString(); + } + this.redirectStrategy.sendRedirect(request, response, redirectUri); + } + +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcLogoutAuthenticationSuccessHandlerTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcLogoutAuthenticationSuccessHandlerTests.java new file mode 100644 index 000000000..b39b78cc5 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcLogoutAuthenticationSuccessHandlerTests.java @@ -0,0 +1,94 @@ +/* + * Copyright 2020-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.oidc.web.authentication; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.mock.web.MockHttpSession; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcLogoutAuthenticationToken; +import org.springframework.security.web.authentication.logout.LogoutHandler; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link OidcLogoutAuthenticationSuccessHandler}. + * + * @author Joe Grandja + */ +public class OidcLogoutAuthenticationSuccessHandlerTests { + + private TestingAuthenticationToken principal; + + private final OidcLogoutAuthenticationSuccessHandler authenticationSuccessHandler = new OidcLogoutAuthenticationSuccessHandler(); + + @BeforeEach + public void setUp() { + this.principal = new TestingAuthenticationToken("principal", "credentials"); + this.principal.setAuthenticated(true); + } + + @Test + public void setLogoutHandlerWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatThrownBy(() -> this.authenticationSuccessHandler.setLogoutHandler(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("logoutHandler cannot be null"); + // @formatter:on + } + + @Test + public void onAuthenticationSuccessWhenInvalidAuthenticationTypeThenThrowOAuth2AuthenticationException() { + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + + assertThatThrownBy( + () -> this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, this.principal)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting((ex) -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.SERVER_ERROR); + } + + @Test + public void onAuthenticationSuccessWhenLogoutHandlerSetThenUsed() throws Exception { + LogoutHandler logoutHandler = mock(LogoutHandler.class); + this.authenticationSuccessHandler.setLogoutHandler(logoutHandler); + + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpSession session = (MockHttpSession) request.getSession(true); + MockHttpServletResponse response = new MockHttpServletResponse(); + + OidcLogoutAuthenticationToken authentication = new OidcLogoutAuthenticationToken("id-token", this.principal, + session.getId(), null, null, null); + this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, authentication); + + verify(logoutHandler).logout(any(HttpServletRequest.class), any(HttpServletResponse.class), + any(Authentication.class)); + } + +}