Skip to content

Commit

Permalink
Add convenience method for invalidating an OAuth2Token
Browse files Browse the repository at this point in the history
Closes gh-1717
  • Loading branch information
jgrandja committed Sep 12, 2024
1 parent 5b7e815 commit 8edbc26
Show file tree
Hide file tree
Showing 13 changed files with 64 additions and 101 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2020-2022 the original author or authors.
* Copyright 2020-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -479,7 +479,6 @@ public <T extends OAuth2Token> Builder token(T token) {
* @return the {@link Builder}
*/
public <T extends OAuth2Token> Builder token(T token, Consumer<Map<String, Object>> metadataConsumer) {

Assert.notNull(token, "token cannot be null");
Map<String, Object> metadata = Token.defaultMetadata();
Token<?> existingToken = this.tokens.get(token.getClass());
Expand All @@ -492,6 +491,33 @@ public <T extends OAuth2Token> Builder token(T token, Consumer<Map<String, Objec
return this;
}

/**
* Invalidates the {@link OAuth2Token token}.
* @param token the token
* @param <T> the type of the token
* @return the {@link Builder}
* @since 1.4
*/
public <T extends OAuth2Token> Builder invalidate(T token) {
Assert.notNull(token, "token cannot be null");
if (this.tokens.get(token.getClass()) == null) {
return this;
}
token(token, (metadata) -> metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));
if (OAuth2RefreshToken.class.isAssignableFrom(token.getClass())) {
Token<?> accessToken = this.tokens.get(OAuth2AccessToken.class);
token(accessToken.getToken(),
(metadata) -> metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));

Token<?> authorizationCode = this.tokens.get(OAuth2AuthorizationCode.class);
if (authorizationCode != null && !authorizationCode.isInvalidated()) {
token(authorizationCode.getToken(),
(metadata) -> metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));
}
}
return this;
}

protected final Builder tokens(Map<Class<? extends OAuth2Token>, Token<?>> tokens) {
this.tokens = new HashMap<>(tokens);
return this;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2020-2022 the original author or authors.
* Copyright 2020-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -21,10 +21,8 @@
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.settings.OAuth2TokenFormat;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenContext;

Expand All @@ -50,34 +48,6 @@ static OAuth2ClientAuthenticationToken getAuthenticatedClientElseThrowInvalidCli
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_CLIENT);
}

static <T extends OAuth2Token> OAuth2Authorization invalidate(OAuth2Authorization authorization, T token) {

// @formatter:off
OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.from(authorization)
.token(token,
(metadata) ->
metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));

if (OAuth2RefreshToken.class.isAssignableFrom(token.getClass())) {
authorizationBuilder.token(
authorization.getAccessToken().getToken(),
(metadata) ->
metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));

OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode =
authorization.getToken(OAuth2AuthorizationCode.class);
if (authorizationCode != null && !authorizationCode.isInvalidated()) {
authorizationBuilder.token(
authorizationCode.getToken(),
(metadata) ->
metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));
}
}
// @formatter:on

return authorizationBuilder.build();
}

static <T extends OAuth2Token> OAuth2AccessToken accessToken(OAuth2Authorization.Builder builder, T token,
OAuth2TokenContext accessTokenContext) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,9 @@ public Authentication authenticate(Authentication authentication) throws Authent
if (!authorizationCode.isInvalidated()) {
// Invalidate the authorization code given that a different client is
// attempting to use it
authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization,
authorizationCode.getToken());
authorization = OAuth2Authorization.from(authorization)
.invalidate(authorizationCode.getToken())
.build();
this.authorizationService.save(authorization);
if (this.logger.isWarnEnabled()) {
this.logger.warn(LogMessage.format("Invalidated authorization code used by registered client '%s'",
Expand All @@ -172,7 +173,7 @@ public Authentication authenticate(Authentication authentication) throws Authent
if (token != null) {
// Invalidate the access (and refresh) token as the client is
// attempting to use the authorization code more than once
authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, token.getToken());
authorization = OAuth2Authorization.from(authorization).invalidate(token.getToken()).build();
this.authorizationService.save(authorization);
if (this.logger.isWarnEnabled()) {
this.logger.warn(LogMessage.format(
Expand Down Expand Up @@ -284,10 +285,10 @@ public Authentication authenticate(Authentication authentication) throws Authent
idToken = null;
}

authorization = authorizationBuilder.build();

// Invalidate the authorization code as it can only be used once
authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, authorizationCode.getToken());
authorizationBuilder.invalidate(authorizationCode.getToken());

authorization = authorizationBuilder.build();

this.authorizationService.save(authorization);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2020-2023 the original author or authors.
* Copyright 2020-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -187,10 +187,8 @@ public Authentication authenticate(Authentication authentication) throws Authent
}
}
authorization = OAuth2Authorization.from(authorization)
.token((deviceCodeToken.getToken()),
(metadata) -> metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true))
.token((userCodeToken.getToken()),
(metadata) -> metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true))
.invalidate(deviceCodeToken.getToken())
.invalidate(userCodeToken.getToken())
.attributes((attrs) -> attrs.remove(OAuth2ParameterNames.STATE))
.build();
this.authorizationService.save(authorization);
Expand All @@ -210,8 +208,7 @@ public Authentication authenticate(Authentication authentication) throws Authent

authorization = OAuth2Authorization.from(authorization)
.authorizedScopes(authorizedScopes)
.token((userCodeToken.getToken()),
(metadata) -> metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true))
.invalidate(userCodeToken.getToken())
.attributes((attrs) -> attrs.remove(OAuth2ParameterNames.STATE))
.attributes((attrs) -> attrs.remove(OAuth2ParameterNames.SCOPE))
.build();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2020-2023 the original author or authors.
* Copyright 2020-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -124,7 +124,7 @@ public Authentication authenticate(Authentication authentication) throws Authent
if (!deviceCode.isInvalidated()) {
// Invalidate the device code given that a different client is attempting
// to use it
authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, deviceCode.getToken());
authorization = OAuth2Authorization.from(authorization).invalidate(deviceCode.getToken()).build();
this.authorizationService.save(authorization);
if (this.logger.isWarnEnabled()) {
this.logger.warn(LogMessage.format("Invalidated device code used by registered client '%s'",
Expand Down Expand Up @@ -172,7 +172,7 @@ public Authentication authenticate(Authentication authentication) throws Authent
// restarting to avoid unnecessary polling.
if (deviceCode.isExpired()) {
// Invalidate the device code
authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, deviceCode.getToken());
authorization = OAuth2Authorization.from(authorization).invalidate(deviceCode.getToken()).build();
this.authorizationService.save(authorization);
if (this.logger.isWarnEnabled()) {
this.logger.warn(LogMessage.format("Invalidated device code used by registered client '%s'",
Expand Down Expand Up @@ -200,8 +200,7 @@ public Authentication authenticate(Authentication authentication) throws Authent
// @formatter:off
OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.from(authorization)
// Invalidate the device code as it can only be used (successfully) once
.token(deviceCode.getToken(), (metadata) ->
metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));
.invalidate(deviceCode.getToken());
// @formatter:on

// ----- Access token -----
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2020-2023 the original author or authors.
* Copyright 2020-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -166,8 +166,7 @@ public Authentication authenticate(Authentication authentication) throws Authent
authorization = OAuth2Authorization.from(authorization)
.principalName(principal.getName())
.authorizedScopes(requestedScopes)
.token(userCode.getToken(), (metadata) -> metadata
.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true))
.invalidate(userCode.getToken())
.attribute(Principal.class.getName(), principal)
.attributes((attributes) -> attributes.remove(OAuth2ParameterNames.SCOPE))
.build();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2020-2022 the original author or authors.
* Copyright 2020-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -79,7 +79,7 @@ public Authentication authenticate(Authentication authentication) throws Authent
}

OAuth2Authorization.Token<OAuth2Token> token = authorization.getToken(tokenRevocationAuthentication.getToken());
authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, token.getToken());
authorization = OAuth2Authorization.from(authorization).invalidate(token.getToken()).build();
this.authorizationService.save(authorization);

if (this.logger.isTraceEnabled()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2020-2022 the original author or authors.
* Copyright 2020-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,10 +18,8 @@
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.oauth2.core.ClaimAccessor;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.settings.OAuth2TokenFormat;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenContext;

Expand All @@ -36,34 +34,6 @@ final class OidcAuthenticationProviderUtils {
private OidcAuthenticationProviderUtils() {
}

static <T extends OAuth2Token> OAuth2Authorization invalidate(OAuth2Authorization authorization, T token) {

// @formatter:off
OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.from(authorization)
.token(token,
(metadata) ->
metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));

if (OAuth2RefreshToken.class.isAssignableFrom(token.getClass())) {
authorizationBuilder.token(
authorization.getAccessToken().getToken(),
(metadata) ->
metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));

OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode =
authorization.getToken(OAuth2AuthorizationCode.class);
if (authorizationCode != null && !authorizationCode.isInvalidated()) {
authorizationBuilder.token(
authorizationCode.getToken(),
(metadata) ->
metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));
}
}
// @formatter:on

return authorizationBuilder.build();
}

static <T extends OAuth2Token> OAuth2AccessToken accessToken(OAuth2Authorization.Builder builder, T token,
OAuth2TokenContext accessTokenContext) {

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2020-2023 the original author or authors.
* Copyright 2020-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -260,12 +260,12 @@ private OidcClientRegistrationAuthenticationToken registerClient(
OAuth2Authorization registeredClientAuthorization = registerAccessToken(registeredClient);

// Invalidate the "initial" access token as it can only be used once
authorization = OidcAuthenticationProviderUtils.invalidate(authorization,
authorization.getAccessToken().getToken());
OAuth2Authorization.Builder builder = OAuth2Authorization.from(authorization)
.invalidate(authorization.getAccessToken().getToken());
if (authorization.getRefreshToken() != null) {
authorization = OidcAuthenticationProviderUtils.invalidate(authorization,
authorization.getRefreshToken().getToken());
builder.invalidate(authorization.getRefreshToken().getToken());
}
authorization = builder.build();
this.authorizationService.save(authorization);

if (this.logger.isTraceEnabled()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2020-2022 the original author or authors.
* Copyright 2020-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -147,7 +147,7 @@ public void authenticateWhenTokenInvalidatedThenNotActive() {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
OAuth2AccessToken accessToken = authorization.getAccessToken().getToken();
authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, accessToken);
authorization = OAuth2Authorization.from(authorization).invalidate(accessToken).build();
given(this.authorizationService.findByToken(eq(accessToken.getTokenValue()), isNull()))
.willReturn(authorization);
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2020-2023 the original author or authors.
* Copyright 2020-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -176,8 +176,8 @@ public void authenticateWhenAccessTokenNotActiveThenThrowOAuth2AuthenticationExc
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
OAuth2Authorization authorization = TestOAuth2Authorizations
.authorization(registeredClient, jwtAccessToken, jwt.getClaims())
.invalidate(jwtAccessToken)
.build();
authorization = OidcAuthenticationProviderUtils.invalidate(authorization, jwtAccessToken);
given(this.authorizationService.findByToken(eq(jwtAccessToken.getTokenValue()),
eq(OAuth2TokenType.ACCESS_TOKEN)))
.willReturn(authorization);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2020-2023 the original author or authors.
* Copyright 2020-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -250,8 +250,8 @@ public void authenticateWhenAccessTokenNotActiveThenThrowOAuth2AuthenticationExc
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
OAuth2Authorization authorization = TestOAuth2Authorizations
.authorization(registeredClient, jwtAccessToken, jwt.getClaims())
.invalidate(jwtAccessToken)
.build();
authorization = OidcAuthenticationProviderUtils.invalidate(authorization, jwtAccessToken);
given(this.authorizationService.findByToken(eq(jwtAccessToken.getTokenValue()),
eq(OAuth2TokenType.ACCESS_TOKEN)))
.willReturn(authorization);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2020-2022 the original author or authors.
* Copyright 2020-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -133,8 +133,9 @@ public void authenticateWhenAccessTokenNotFoundThenThrowOAuth2AuthenticationExce
public void authenticateWhenAccessTokenNotActiveThenThrowOAuth2AuthenticationException() {
String tokenValue = "token";
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build();
authorization = OidcAuthenticationProviderUtils.invalidate(authorization,
authorization.getAccessToken().getToken());
authorization = OAuth2Authorization.from(authorization)
.invalidate(authorization.getAccessToken().getToken())
.build();
given(this.authorizationService.findByToken(eq(tokenValue), eq(OAuth2TokenType.ACCESS_TOKEN)))
.willReturn(authorization);

Expand Down

0 comments on commit 8edbc26

Please sign in to comment.