From a3466929e735eb1e60cd2567da03a99c56a5c474 Mon Sep 17 00:00:00 2001 From: Josh Long Date: Tue, 14 Nov 2023 16:01:53 -0800 Subject: [PATCH] This commit introduces AOT hints for types and resources used across the codebase and, more interestingly, specifically by the various JDBC-persisting repository implementations. This follows up on https://github.com/spring-projects/spring-authorization-server/issues/1380 --- ...horizationServerRuntimeHintsRegistrar.java | 150 ++++++------- ...JdbcOAuth2AuthorizationConsentService.java | 77 ++++--- .../JdbcOAuth2AuthorizationService.java | 190 ++++++++++------- .../JdbcRegisteredClientRepository.java | 197 +++++++++++------- .../resources/META-INF/spring/aot.factories | 2 + .../samples-demo-authorizationserver.gradle | 1 + .../DemoAuthorizationServerApplication.java | 4 + .../resources/META-INF/spring/aot.factories | 2 - 8 files changed, 358 insertions(+), 265 deletions(-) rename samples/demo-authorizationserver/src/main/java/sample/aot/hint/DemoAuthorizationServerRuntimeHints.java => oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/AuthorizationServerRuntimeHintsRegistrar.java (57%) create mode 100644 oauth2-authorization-server/src/main/resources/META-INF/spring/aot.factories delete mode 100644 samples/demo-authorizationserver/src/main/resources/META-INF/spring/aot.factories diff --git a/samples/demo-authorizationserver/src/main/java/sample/aot/hint/DemoAuthorizationServerRuntimeHints.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/AuthorizationServerRuntimeHintsRegistrar.java similarity index 57% rename from samples/demo-authorizationserver/src/main/java/sample/aot/hint/DemoAuthorizationServerRuntimeHints.java rename to oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/AuthorizationServerRuntimeHintsRegistrar.java index 804679b0d..06c162f74 100644 --- a/samples/demo-authorizationserver/src/main/java/sample/aot/hint/DemoAuthorizationServerRuntimeHints.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/AuthorizationServerRuntimeHintsRegistrar.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2023 the original author or authors. + * Copyright 2020-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,27 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package sample.aot.hint; +package org.springframework.security.oauth2.server.authorization; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashSet; - -import org.thymeleaf.expression.Lists; -import sample.web.AuthorizationConsentController; - -import org.springframework.aot.hint.BindingReflectionHintsRegistrar; -import org.springframework.aot.hint.MemberCategory; -import org.springframework.aot.hint.RuntimeHints; -import org.springframework.aot.hint.RuntimeHintsRegistrar; -import org.springframework.aot.hint.TypeReference; +import org.springframework.aot.hint.*; import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.userdetails.User; import org.springframework.security.jackson2.CoreJackson2Module; -import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; -import org.springframework.security.oauth2.client.jackson2.OAuth2ClientJackson2Module; import org.springframework.security.oauth2.core.AbstractOAuth2Token; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; @@ -48,75 +35,78 @@ import org.springframework.security.oauth2.server.authorization.settings.OAuth2TokenFormat; import org.springframework.security.web.authentication.WebAuthenticationDetails; import org.springframework.security.web.jackson2.WebServletJackson2Module; +import org.springframework.util.ClassUtils; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; /** - * {@link RuntimeHintsRegistrar} that registers {@link RuntimeHints} required for the sample. - * Statically registered via META-INF/spring/aot.factories. + * {@link RuntimeHintsRegistrar} that registers {@link RuntimeHints} required for the + * sample. Statically registered via META-INF/spring/aot.factories. * * @author Joe Grandja + * @author Josh Long * @since 1.2 */ -class DemoAuthorizationServerRuntimeHints implements RuntimeHintsRegistrar { +class AuthorizationServerRuntimeHintsRegistrar implements RuntimeHintsRegistrar { + private final BindingReflectionHintsRegistrar reflectionHintsRegistrar = new BindingReflectionHintsRegistrar(); @Override public void registerHints(RuntimeHints hints, ClassLoader classLoader) { - // Thymeleaf - hints.reflection().registerTypes( - Arrays.asList( - TypeReference.of(AuthorizationConsentController.ScopeWithDescription.class), - TypeReference.of(Lists.class) - ), builder -> - builder.withMembers(MemberCategory.DECLARED_FIELDS, - MemberCategory.INVOKE_DECLARED_CONSTRUCTORS, MemberCategory.INVOKE_DECLARED_METHODS) - ); - - // Collections -> UnmodifiableSet, UnmodifiableList, UnmodifiableMap, UnmodifiableRandomAccessList, etc. - hints.reflection().registerType( - Collections.class, MemberCategory.DECLARED_CLASSES); + // Collections -> UnmodifiableSet, UnmodifiableList, UnmodifiableMap, + // UnmodifiableRandomAccessList, etc. + hints.reflection().registerType(Collections.class, MemberCategory.DECLARED_CLASSES); // HashSet - hints.reflection().registerType( - HashSet.class, MemberCategory.DECLARED_FIELDS, MemberCategory.INVOKE_DECLARED_CONSTRUCTORS, - MemberCategory.INVOKE_DECLARED_METHODS); + hints.reflection().registerType(HashSet.class, MemberCategory.DECLARED_FIELDS, + MemberCategory.INVOKE_DECLARED_CONSTRUCTORS, MemberCategory.INVOKE_DECLARED_METHODS); // Spring Security and Spring Authorization Server - hints.reflection().registerTypes( - Arrays.asList( - TypeReference.of(AbstractAuthenticationToken.class), - TypeReference.of(WebAuthenticationDetails.class), - TypeReference.of(UsernamePasswordAuthenticationToken.class), - TypeReference.of(User.class), - TypeReference.of(OAuth2AuthenticationToken.class), - TypeReference.of(DefaultOidcUser.class), - TypeReference.of(DefaultOAuth2User.class), - TypeReference.of(OidcUserAuthority.class), - TypeReference.of(OAuth2UserAuthority.class), - TypeReference.of(SimpleGrantedAuthority.class), - TypeReference.of(OidcIdToken.class), - TypeReference.of(AbstractOAuth2Token.class), - TypeReference.of(OidcUserInfo.class), - TypeReference.of(OAuth2AuthorizationRequest.class), - TypeReference.of(AuthorizationGrantType.class), - TypeReference.of(OAuth2AuthorizationResponseType.class), - TypeReference.of(OAuth2TokenFormat.class) - ), builder -> - builder.withMembers(MemberCategory.DECLARED_FIELDS, - MemberCategory.INVOKE_DECLARED_CONSTRUCTORS, MemberCategory.INVOKE_DECLARED_METHODS) - ); + if (ClassUtils.isPresent("org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken", + ClassUtils.getDefaultClassLoader())) + hints.reflection().registerType( + TypeReference + .of("org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken"), + builder -> builder.withMembers(MemberCategory.DECLARED_FIELDS, + MemberCategory.INVOKE_DECLARED_CONSTRUCTORS, MemberCategory.INVOKE_DECLARED_METHODS)); + + hints.reflection().registerTypes(Arrays.asList(TypeReference.of(AbstractAuthenticationToken.class), + TypeReference.of(WebAuthenticationDetails.class), + TypeReference.of(UsernamePasswordAuthenticationToken.class), TypeReference.of(User.class), + TypeReference.of(DefaultOidcUser.class), TypeReference.of(DefaultOAuth2User.class), + TypeReference.of(OidcUserAuthority.class), TypeReference.of(OAuth2UserAuthority.class), + TypeReference.of(SimpleGrantedAuthority.class), TypeReference.of(OidcIdToken.class), + TypeReference.of(AbstractOAuth2Token.class), TypeReference.of(OidcUserInfo.class), + TypeReference.of(OAuth2AuthorizationRequest.class), TypeReference.of(AuthorizationGrantType.class), + TypeReference.of(OAuth2AuthorizationResponseType.class), TypeReference.of(OAuth2TokenFormat.class)), + builder -> builder.withMembers(MemberCategory.DECLARED_FIELDS, + MemberCategory.INVOKE_DECLARED_CONSTRUCTORS, MemberCategory.INVOKE_DECLARED_METHODS)); // Jackson Modules - Spring Security and Spring Authorization Server hints.reflection().registerTypes( - Arrays.asList( - TypeReference.of(CoreJackson2Module.class), - TypeReference.of(WebServletJackson2Module.class), - TypeReference.of(OAuth2ClientJackson2Module.class), - TypeReference.of(OAuth2AuthorizationServerJackson2Module.class) - ), builder -> - builder.withMembers(MemberCategory.DECLARED_FIELDS, - MemberCategory.INVOKE_DECLARED_CONSTRUCTORS, MemberCategory.INVOKE_DECLARED_METHODS) - ); + Set.of(CoreJackson2Module.class.getName(), WebServletJackson2Module.class.getName(), + OAuth2AuthorizationServerJackson2Module.class.getName()).stream().map(TypeReference::of) + .toList(), + builder -> builder.withMembers(MemberCategory.DECLARED_FIELDS, + MemberCategory.INVOKE_DECLARED_CONSTRUCTORS, MemberCategory.INVOKE_DECLARED_METHODS)); + + if (ClassUtils.isPresent("org.springframework.security.oauth2.client.jackson2.OAuth2ClientJackson2Module", + ClassUtils.getDefaultClassLoader())) + hints.reflection().registerType( + TypeReference.of("org.springframework.security.oauth2.client.jackson2.OAuth2ClientJackson2Module"), + b -> b.withMembers(MemberCategory.DECLARED_FIELDS, MemberCategory.INVOKE_DECLARED_CONSTRUCTORS, + MemberCategory.INVOKE_DECLARED_METHODS)); + + hints.reflection().registerTypes( + Set.of(CoreJackson2Module.class.getName(), WebServletJackson2Module.class.getName(), + OAuth2AuthorizationServerJackson2Module.class.getName()).stream().map(TypeReference::of) + .toList(), + builder -> builder.withMembers(MemberCategory.DECLARED_FIELDS, + MemberCategory.INVOKE_DECLARED_CONSTRUCTORS, MemberCategory.INVOKE_DECLARED_METHODS)); // Jackson Mixins - Spring Security and Spring Authorization Server try { @@ -126,8 +116,8 @@ public void registerHints(RuntimeHints hints, ClassLoader classLoader) { Class.forName("org.springframework.security.jackson2.UnmodifiableListMixin")); this.reflectionHintsRegistrar.registerReflectionHints(hints.reflection(), Class.forName("org.springframework.security.jackson2.UnmodifiableMapMixin")); - this.reflectionHintsRegistrar.registerReflectionHints(hints.reflection(), - Class.forName("org.springframework.security.oauth2.server.authorization.jackson2.UnmodifiableMapMixin")); + this.reflectionHintsRegistrar.registerReflectionHints(hints.reflection(), Class + .forName("org.springframework.security.oauth2.server.authorization.jackson2.UnmodifiableMapMixin")); this.reflectionHintsRegistrar.registerReflectionHints(hints.reflection(), Class.forName("org.springframework.security.oauth2.server.authorization.jackson2.HashSetMixin")); this.reflectionHintsRegistrar.registerReflectionHints(hints.reflection(), @@ -136,8 +126,8 @@ public void registerHints(RuntimeHints hints, ClassLoader classLoader) { Class.forName("org.springframework.security.jackson2.UsernamePasswordAuthenticationTokenMixin")); this.reflectionHintsRegistrar.registerReflectionHints(hints.reflection(), Class.forName("org.springframework.security.jackson2.UserMixin")); - this.reflectionHintsRegistrar.registerReflectionHints(hints.reflection(), - Class.forName("org.springframework.security.oauth2.client.jackson2.OAuth2AuthenticationTokenMixin")); + this.reflectionHintsRegistrar.registerReflectionHints(hints.reflection(), Class + .forName("org.springframework.security.oauth2.client.jackson2.OAuth2AuthenticationTokenMixin")); this.reflectionHintsRegistrar.registerReflectionHints(hints.reflection(), Class.forName("org.springframework.security.oauth2.client.jackson2.DefaultOidcUserMixin")); this.reflectionHintsRegistrar.registerReflectionHints(hints.reflection(), @@ -152,21 +142,15 @@ public void registerHints(RuntimeHints hints, ClassLoader classLoader) { Class.forName("org.springframework.security.oauth2.client.jackson2.OidcIdTokenMixin")); this.reflectionHintsRegistrar.registerReflectionHints(hints.reflection(), Class.forName("org.springframework.security.oauth2.client.jackson2.OidcUserInfoMixin")); - this.reflectionHintsRegistrar.registerReflectionHints(hints.reflection(), - Class.forName("org.springframework.security.oauth2.server.authorization.jackson2.OAuth2AuthorizationRequestMixin")); - this.reflectionHintsRegistrar.registerReflectionHints(hints.reflection(), - Class.forName("org.springframework.security.oauth2.server.authorization.jackson2.OAuth2TokenFormatMixin")); - } catch (ClassNotFoundException ex) { + this.reflectionHintsRegistrar.registerReflectionHints(hints.reflection(), Class.forName( + "org.springframework.security.oauth2.server.authorization.jackson2.OAuth2AuthorizationRequestMixin")); + this.reflectionHintsRegistrar.registerReflectionHints(hints.reflection(), Class.forName( + "org.springframework.security.oauth2.server.authorization.jackson2.OAuth2TokenFormatMixin")); + } + catch (ClassNotFoundException ex) { throw new RuntimeException(ex); } - // Sql Schema Resources - hints.resources().registerPattern( - "org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client-schema.sql"); - hints.resources().registerPattern( - "org/springframework/security/oauth2/server/authorization/oauth2-authorization-schema.sql"); - hints.resources().registerPattern( - "org/springframework/security/oauth2/server/authorization/oauth2-authorization-consent-schema.sql"); } } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationConsentService.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationConsentService.java index e489fc0a4..04e505b0c 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationConsentService.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationConsentService.java @@ -24,6 +24,10 @@ import java.util.Set; import java.util.function.Function; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.RuntimeHintsRegistrar; +import org.springframework.context.annotation.ImportRuntimeHints; +import org.springframework.core.io.ClassPathResource; import org.springframework.dao.DataRetrievalFailureException; import org.springframework.jdbc.core.ArgumentPreparedStatementSetter; import org.springframework.jdbc.core.JdbcOperations; @@ -43,20 +47,32 @@ * {@link JdbcOperations} for {@link OAuth2AuthorizationConsent} persistence. * *

- * NOTE: This {@code OAuth2AuthorizationConsentService} depends on the table definition - * described in - * "classpath:org/springframework/security/oauth2/server/authorization/oauth2-authorization-consent-schema.sql" and - * therefore MUST be defined in the database schema. + * NOTE: This {@code OAuth2AuthorizationConsentService} depends on the table + * definition described in + * "classpath:org/springframework/security/oauth2/server/authorization/oauth2-authorization-consent-schema.sql" + * and therefore MUST be defined in the database schema. * * @author Ovidiu Popa + * @author Josh Long * @since 0.1.2 * @see OAuth2AuthorizationConsentService * @see OAuth2AuthorizationConsent * @see JdbcOperations * @see RowMapper */ +@ImportRuntimeHints(JdbcOAuth2AuthorizationConsentService.JdbcOAuth2AuthorizationConsentServiceRuntimeHintsRegistrar.class) public class JdbcOAuth2AuthorizationConsentService implements OAuth2AuthorizationConsentService { + static class JdbcOAuth2AuthorizationConsentServiceRuntimeHintsRegistrar implements RuntimeHintsRegistrar { + + @Override + public void registerHints(RuntimeHints hints, ClassLoader classLoader) { + hints.resources().registerResource(new ClassPathResource( + "org/springframework/security/oauth2/server/authorization/oauth2-authorization-consent-schema.sql")); + } + + } + // @formatter:off private static final String COLUMN_NAMES = "registered_client_id, " + "principal_name, " @@ -87,13 +103,15 @@ public class JdbcOAuth2AuthorizationConsentService implements OAuth2Authorizatio private static final String REMOVE_AUTHORIZATION_CONSENT_SQL = "DELETE FROM " + TABLE_NAME + " WHERE " + PK_FILTER; private final JdbcOperations jdbcOperations; + private RowMapper authorizationConsentRowMapper; + private Function> authorizationConsentParametersMapper; /** - * Constructs a {@code JdbcOAuth2AuthorizationConsentService} using the provided parameters. - * - * @param jdbcOperations the JDBC operations + * Constructs a {@code JdbcOAuth2AuthorizationConsentService} using the provided + * parameters. + * @param jdbcOperations the JDBC operations * @param registeredClientRepository the registered client repository */ public JdbcOAuth2AuthorizationConsentService(JdbcOperations jdbcOperations, @@ -108,11 +126,12 @@ public JdbcOAuth2AuthorizationConsentService(JdbcOperations jdbcOperations, @Override public void save(OAuth2AuthorizationConsent authorizationConsent) { Assert.notNull(authorizationConsent, "authorizationConsent cannot be null"); - OAuth2AuthorizationConsent existingAuthorizationConsent = findById( - authorizationConsent.getRegisteredClientId(), authorizationConsent.getPrincipalName()); + OAuth2AuthorizationConsent existingAuthorizationConsent = findById(authorizationConsent.getRegisteredClientId(), + authorizationConsent.getPrincipalName()); if (existingAuthorizationConsent == null) { insertAuthorizationConsent(authorizationConsent); - } else { + } + else { updateAuthorizationConsent(authorizationConsent); } } @@ -138,8 +157,7 @@ public void remove(OAuth2AuthorizationConsent authorizationConsent) { Assert.notNull(authorizationConsent, "authorizationConsent cannot be null"); SqlParameterValue[] parameters = new SqlParameterValue[] { new SqlParameterValue(Types.VARCHAR, authorizationConsent.getRegisteredClientId()), - new SqlParameterValue(Types.VARCHAR, authorizationConsent.getPrincipalName()) - }; + new SqlParameterValue(Types.VARCHAR, authorizationConsent.getPrincipalName()) }; PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters); this.jdbcOperations.update(REMOVE_AUTHORIZATION_CONSENT_SQL, pss); } @@ -151,7 +169,7 @@ public OAuth2AuthorizationConsent findById(String registeredClientId, String pri Assert.hasText(principalName, "principalName cannot be empty"); SqlParameterValue[] parameters = new SqlParameterValue[] { new SqlParameterValue(Types.VARCHAR, registeredClientId), - new SqlParameterValue(Types.VARCHAR, principalName)}; + new SqlParameterValue(Types.VARCHAR, principalName) }; PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters); List result = this.jdbcOperations.query(LOAD_AUTHORIZATION_CONSENT_SQL, pss, this.authorizationConsentRowMapper); @@ -162,22 +180,21 @@ public OAuth2AuthorizationConsent findById(String registeredClientId, String pri * Sets the {@link RowMapper} used for mapping the current row in * {@code java.sql.ResultSet} to {@link OAuth2AuthorizationConsent}. The default is * {@link OAuth2AuthorizationConsentRowMapper}. - * - * @param authorizationConsentRowMapper the {@link RowMapper} used for mapping the current - * row in {@code ResultSet} to {@link OAuth2AuthorizationConsent} + * @param authorizationConsentRowMapper the {@link RowMapper} used for mapping the + * current row in {@code ResultSet} to {@link OAuth2AuthorizationConsent} */ - public final void setAuthorizationConsentRowMapper(RowMapper authorizationConsentRowMapper) { + public final void setAuthorizationConsentRowMapper( + RowMapper authorizationConsentRowMapper) { Assert.notNull(authorizationConsentRowMapper, "authorizationConsentRowMapper cannot be null"); this.authorizationConsentRowMapper = authorizationConsentRowMapper; } /** - * Sets the {@code Function} used for mapping {@link OAuth2AuthorizationConsent} to - * a {@code List} of {@link SqlParameterValue}. The default is + * Sets the {@code Function} used for mapping {@link OAuth2AuthorizationConsent} to a + * {@code List} of {@link SqlParameterValue}. The default is * {@link OAuth2AuthorizationConsentParametersMapper}. - * * @param authorizationConsentParametersMapper the {@code Function} used for mapping - * {@link OAuth2AuthorizationConsent} to a {@code List} of {@link SqlParameterValue} + * {@link OAuth2AuthorizationConsent} to a {@code List} of {@link SqlParameterValue} */ public final void setAuthorizationConsentParametersMapper( Function> authorizationConsentParametersMapper) { @@ -198,10 +215,11 @@ protected final Function> ge } /** - * The default {@link RowMapper} that maps the current row in - * {@code ResultSet} to {@link OAuth2AuthorizationConsent}. + * The default {@link RowMapper} that maps the current row in {@code ResultSet} to + * {@link OAuth2AuthorizationConsent}. */ public static class OAuth2AuthorizationConsentRowMapper implements RowMapper { + private final RegisteredClientRepository registeredClientRepository; public OAuth2AuthorizationConsentRowMapper(RegisteredClientRepository registeredClientRepository) { @@ -214,13 +232,14 @@ public OAuth2AuthorizationConsent mapRow(ResultSet rs, int rowNum) throws SQLExc String registeredClientId = rs.getString("registered_client_id"); RegisteredClient registeredClient = this.registeredClientRepository.findById(registeredClientId); if (registeredClient == null) { - throw new DataRetrievalFailureException( - "The RegisteredClient with id '" + registeredClientId + "' was not found in the RegisteredClientRepository."); + throw new DataRetrievalFailureException("The RegisteredClient with id '" + registeredClientId + + "' was not found in the RegisteredClientRepository."); } String principalName = rs.getString("principal_name"); - OAuth2AuthorizationConsent.Builder builder = OAuth2AuthorizationConsent.withId(registeredClientId, principalName); + OAuth2AuthorizationConsent.Builder builder = OAuth2AuthorizationConsent.withId(registeredClientId, + principalName); String authorizationConsentAuthorities = rs.getString("authorities"); if (authorizationConsentAuthorities != null) { for (String authority : StringUtils.commaDelimitedListToSet(authorizationConsentAuthorities)) { @@ -240,7 +259,8 @@ protected final RegisteredClientRepository getRegisteredClientRepository() { * The default {@code Function} that maps {@link OAuth2AuthorizationConsent} to a * {@code List} of {@link SqlParameterValue}. */ - public static class OAuth2AuthorizationConsentParametersMapper implements Function> { + public static class OAuth2AuthorizationConsentParametersMapper + implements Function> { @Override public List apply(OAuth2AuthorizationConsent authorizationConsent) { @@ -252,7 +272,8 @@ public List apply(OAuth2AuthorizationConsent authorizationCon for (GrantedAuthority authority : authorizationConsent.getAuthorities()) { authorities.add(authority.getAuthority()); } - parameters.add(new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToDelimitedString(authorities, ","))); + parameters.add( + new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToDelimitedString(authorities, ","))); return parameters; } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationService.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationService.java index e11a3271e..0d6f8185c 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationService.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationService.java @@ -35,6 +35,10 @@ import com.fasterxml.jackson.databind.Module; import com.fasterxml.jackson.databind.ObjectMapper; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.RuntimeHintsRegistrar; +import org.springframework.context.annotation.ImportRuntimeHints; +import org.springframework.core.io.ClassPathResource; import org.springframework.dao.DataRetrievalFailureException; import org.springframework.jdbc.core.ArgumentPreparedStatementSetter; import org.springframework.jdbc.core.ConnectionCallback; @@ -70,19 +74,31 @@ *

* NOTE: This {@code OAuth2AuthorizationService} depends on the table definition * described in - * "classpath:org/springframework/security/oauth2/server/authorization/oauth2-authorization-schema.sql" and - * therefore MUST be defined in the database schema. + * "classpath:org/springframework/security/oauth2/server/authorization/oauth2-authorization-schema.sql" + * and therefore MUST be defined in the database schema. * * @author Ovidiu Popa * @author Joe Grandja + * @author Josh Long * @since 0.1.2 * @see OAuth2AuthorizationService * @see OAuth2Authorization * @see JdbcOperations * @see RowMapper */ +@ImportRuntimeHints(JdbcOAuth2AuthorizationService.JdbcOAuth2AuthorizationServiceServiceRuntimeHintsRegistrar.class) public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationService { + static class JdbcOAuth2AuthorizationServiceServiceRuntimeHintsRegistrar implements RuntimeHintsRegistrar { + + @Override + public void registerHints(RuntimeHints hints, ClassLoader classLoader) { + hints.resources().registerResource(new ClassPathResource( + "org/springframework/security/oauth2/server/authorization/oauth2-authorization-schema.sql")); + } + + } + // @formatter:off private static final String COLUMN_NAMES = "id, " + "registered_client_id, " @@ -122,16 +138,23 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic private static final String TABLE_NAME = "oauth2_authorization"; private static final String PK_FILTER = "id = ?"; + private static final String UNKNOWN_TOKEN_TYPE_FILTER = "state = ? OR authorization_code_value = ? OR " + "access_token_value = ? OR oidc_id_token_value = ? OR refresh_token_value = ? OR user_code_value = ? OR " + "device_code_value = ?"; private static final String STATE_FILTER = "state = ?"; + private static final String AUTHORIZATION_CODE_FILTER = "authorization_code_value = ?"; + private static final String ACCESS_TOKEN_FILTER = "access_token_value = ?"; + private static final String ID_TOKEN_FILTER = "oidc_id_token_value = ?"; + private static final String REFRESH_TOKEN_FILTER = "refresh_token_value = ?"; + private static final String USER_CODE_FILTER = "user_code_value = ?"; + private static final String DEVICE_CODE_FILTER = "device_code_value = ?"; // @formatter:off @@ -162,14 +185,16 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic private static Map columnMetadataMap; private final JdbcOperations jdbcOperations; + private final LobHandler lobHandler; + private RowMapper authorizationRowMapper; + private Function> authorizationParametersMapper; /** * Constructs a {@code JdbcOAuth2AuthorizationService} using the provided parameters. - * - * @param jdbcOperations the JDBC operations + * @param jdbcOperations the JDBC operations * @param registeredClientRepository the registered client repository */ public JdbcOAuth2AuthorizationService(JdbcOperations jdbcOperations, @@ -179,10 +204,9 @@ public JdbcOAuth2AuthorizationService(JdbcOperations jdbcOperations, /** * Constructs a {@code JdbcOAuth2AuthorizationService} using the provided parameters. - * - * @param jdbcOperations the JDBC operations + * @param jdbcOperations the JDBC operations * @param registeredClientRepository the registered client repository - * @param lobHandler the handler for large binary fields and large text fields + * @param lobHandler the handler for large binary fields and large text fields */ public JdbcOAuth2AuthorizationService(JdbcOperations jdbcOperations, RegisteredClientRepository registeredClientRepository, LobHandler lobHandler) { @@ -191,7 +215,8 @@ public JdbcOAuth2AuthorizationService(JdbcOperations jdbcOperations, Assert.notNull(lobHandler, "lobHandler cannot be null"); this.jdbcOperations = jdbcOperations; this.lobHandler = lobHandler; - OAuth2AuthorizationRowMapper authorizationRowMapper = new OAuth2AuthorizationRowMapper(registeredClientRepository); + OAuth2AuthorizationRowMapper authorizationRowMapper = new OAuth2AuthorizationRowMapper( + registeredClientRepository); authorizationRowMapper.setLobHandler(lobHandler); this.authorizationRowMapper = authorizationRowMapper; this.authorizationParametersMapper = new OAuth2AuthorizationParametersMapper(); @@ -204,7 +229,8 @@ public void save(OAuth2Authorization authorization) { OAuth2Authorization existingAuthorization = findById(authorization.getId()); if (existingAuthorization == null) { insertAuthorization(authorization); - } else { + } + else { updateAuthorization(authorization); } } @@ -233,8 +259,7 @@ private void insertAuthorization(OAuth2Authorization authorization) { public void remove(OAuth2Authorization authorization) { Assert.notNull(authorization, "authorization cannot be null"); SqlParameterValue[] parameters = new SqlParameterValue[] { - new SqlParameterValue(Types.VARCHAR, authorization.getId()) - }; + new SqlParameterValue(Types.VARCHAR, authorization.getId()) }; PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters); this.jdbcOperations.update(REMOVE_AUTHORIZATION_SQL, pss); } @@ -262,25 +287,32 @@ public OAuth2Authorization findByToken(String token, @Nullable OAuth2TokenType t parameters.add(mapToSqlParameter("user_code_value", token)); parameters.add(mapToSqlParameter("device_code_value", token)); return findBy(UNKNOWN_TOKEN_TYPE_FILTER, parameters); - } else if (OAuth2ParameterNames.STATE.equals(tokenType.getValue())) { + } + else if (OAuth2ParameterNames.STATE.equals(tokenType.getValue())) { parameters.add(new SqlParameterValue(Types.VARCHAR, token)); return findBy(STATE_FILTER, parameters); - } else if (OAuth2ParameterNames.CODE.equals(tokenType.getValue())) { + } + else if (OAuth2ParameterNames.CODE.equals(tokenType.getValue())) { parameters.add(mapToSqlParameter("authorization_code_value", token)); return findBy(AUTHORIZATION_CODE_FILTER, parameters); - } else if (OAuth2TokenType.ACCESS_TOKEN.equals(tokenType)) { + } + else if (OAuth2TokenType.ACCESS_TOKEN.equals(tokenType)) { parameters.add(mapToSqlParameter("access_token_value", token)); return findBy(ACCESS_TOKEN_FILTER, parameters); - } else if (OidcParameterNames.ID_TOKEN.equals(tokenType.getValue())) { + } + else if (OidcParameterNames.ID_TOKEN.equals(tokenType.getValue())) { parameters.add(mapToSqlParameter("oidc_id_token_value", token)); return findBy(ID_TOKEN_FILTER, parameters); - } else if (OAuth2TokenType.REFRESH_TOKEN.equals(tokenType)) { + } + else if (OAuth2TokenType.REFRESH_TOKEN.equals(tokenType)) { parameters.add(mapToSqlParameter("refresh_token_value", token)); return findBy(REFRESH_TOKEN_FILTER, parameters); - } else if (OAuth2ParameterNames.USER_CODE.equals(tokenType.getValue())) { + } + else if (OAuth2ParameterNames.USER_CODE.equals(tokenType.getValue())) { parameters.add(mapToSqlParameter("user_code_value", token)); return findBy(USER_CODE_FILTER, parameters); - } else if (OAuth2ParameterNames.DEVICE_CODE.equals(tokenType.getValue())) { + } + else if (OAuth2ParameterNames.DEVICE_CODE.equals(tokenType.getValue())) { parameters.add(mapToSqlParameter("device_code_value", token)); return findBy(DEVICE_CODE_FILTER, parameters); } @@ -291,7 +323,8 @@ private OAuth2Authorization findBy(String filter, List parame try (LobCreator lobCreator = getLobHandler().getLobCreator()) { PreparedStatementSetter pss = new LobCreatorArgumentPreparedStatementSetter(lobCreator, parameters.toArray()); - List result = getJdbcOperations().query(LOAD_AUTHORIZATION_SQL + filter, pss, getAuthorizationRowMapper()); + List result = getJdbcOperations().query(LOAD_AUTHORIZATION_SQL + filter, pss, + getAuthorizationRowMapper()); return !result.isEmpty() ? result.get(0) : null; } } @@ -300,9 +333,8 @@ private OAuth2Authorization findBy(String filter, List parame * Sets the {@link RowMapper} used for mapping the current row in * {@code java.sql.ResultSet} to {@link OAuth2Authorization}. The default is * {@link OAuth2AuthorizationRowMapper}. - * * @param authorizationRowMapper the {@link RowMapper} used for mapping the current - * row in {@code ResultSet} to {@link OAuth2Authorization} + * row in {@code ResultSet} to {@link OAuth2Authorization} */ public final void setAuthorizationRowMapper(RowMapper authorizationRowMapper) { Assert.notNull(authorizationRowMapper, "authorizationRowMapper cannot be null"); @@ -310,12 +342,11 @@ public final void setAuthorizationRowMapper(RowMapper autho } /** - * Sets the {@code Function} used for mapping {@link OAuth2Authorization} to - * a {@code List} of {@link SqlParameterValue}. The default is + * Sets the {@code Function} used for mapping {@link OAuth2Authorization} to a + * {@code List} of {@link SqlParameterValue}. The default is * {@link OAuth2AuthorizationParametersMapper}. - * * @param authorizationParametersMapper the {@code Function} used for mapping - * {@link OAuth2Authorization} to a {@code List} of {@link SqlParameterValue} + * {@link OAuth2Authorization} to a {@code List} of {@link SqlParameterValue} */ public final void setAuthorizationParametersMapper( Function> authorizationParametersMapper) { @@ -344,8 +375,11 @@ protected final Function> getAuthor * {@code java.sql.ResultSet} to {@link OAuth2Authorization}. */ public static class OAuth2AuthorizationRowMapper implements RowMapper { + private final RegisteredClientRepository registeredClientRepository; + private LobHandler lobHandler = new DefaultLobHandler(); + private ObjectMapper objectMapper = new ObjectMapper(); public OAuth2AuthorizationRowMapper(RegisteredClientRepository registeredClientRepository) { @@ -364,8 +398,8 @@ public OAuth2Authorization mapRow(ResultSet rs, int rowNum) throws SQLException String registeredClientId = rs.getString("registered_client_id"); RegisteredClient registeredClient = this.registeredClientRepository.findById(registeredClientId); if (registeredClient == null) { - throw new DataRetrievalFailureException( - "The RegisteredClient with id '" + registeredClientId + "' was not found in the RegisteredClientRepository."); + throw new DataRetrievalFailureException("The RegisteredClient with id '" + registeredClientId + + "' was not found in the RegisteredClientRepository."); } OAuth2Authorization.Builder builder = OAuth2Authorization.withRegisteredClient(registeredClient); @@ -379,11 +413,9 @@ public OAuth2Authorization mapRow(ResultSet rs, int rowNum) throws SQLException } Map attributes = parseMap(getLobValue(rs, "attributes")); - builder.id(id) - .principalName(principalName) + builder.id(id).principalName(principalName) .authorizationGrantType(new AuthorizationGrantType(authorizationGrantType)) - .authorizedScopes(authorizedScopes) - .attributes((attrs) -> attrs.putAll(attributes)); + .authorizedScopes(authorizedScopes).attributes((attrs) -> attrs.putAll(attributes)); String state = rs.getString("state"); if (StringUtils.hasText(state)) { @@ -397,10 +429,11 @@ public OAuth2Authorization mapRow(ResultSet rs, int rowNum) throws SQLException if (StringUtils.hasText(authorizationCodeValue)) { tokenIssuedAt = rs.getTimestamp("authorization_code_issued_at").toInstant(); tokenExpiresAt = rs.getTimestamp("authorization_code_expires_at").toInstant(); - Map authorizationCodeMetadata = parseMap(getLobValue(rs, "authorization_code_metadata")); + Map authorizationCodeMetadata = parseMap( + getLobValue(rs, "authorization_code_metadata")); - OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode( - authorizationCodeValue, tokenIssuedAt, tokenExpiresAt); + OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(authorizationCodeValue, + tokenIssuedAt, tokenExpiresAt); builder.token(authorizationCode, (metadata) -> metadata.putAll(authorizationCodeMetadata)); } @@ -419,7 +452,8 @@ public OAuth2Authorization mapRow(ResultSet rs, int rowNum) throws SQLException if (accessTokenScopes != null) { scopes = StringUtils.commaDelimitedListToSet(accessTokenScopes); } - OAuth2AccessToken accessToken = new OAuth2AccessToken(tokenType, accessTokenValue, tokenIssuedAt, tokenExpiresAt, scopes); + OAuth2AccessToken accessToken = new OAuth2AccessToken(tokenType, accessTokenValue, tokenIssuedAt, + tokenExpiresAt, scopes); builder.token(accessToken, (metadata) -> metadata.putAll(accessTokenMetadata)); } @@ -429,8 +463,8 @@ public OAuth2Authorization mapRow(ResultSet rs, int rowNum) throws SQLException tokenExpiresAt = rs.getTimestamp("oidc_id_token_expires_at").toInstant(); Map oidcTokenMetadata = parseMap(getLobValue(rs, "oidc_id_token_metadata")); - OidcIdToken oidcToken = new OidcIdToken( - oidcIdTokenValue, tokenIssuedAt, tokenExpiresAt, (Map) oidcTokenMetadata.get(OAuth2Authorization.Token.CLAIMS_METADATA_NAME)); + OidcIdToken oidcToken = new OidcIdToken(oidcIdTokenValue, tokenIssuedAt, tokenExpiresAt, + (Map) oidcTokenMetadata.get(OAuth2Authorization.Token.CLAIMS_METADATA_NAME)); builder.token(oidcToken, (metadata) -> metadata.putAll(oidcTokenMetadata)); } @@ -444,8 +478,8 @@ public OAuth2Authorization mapRow(ResultSet rs, int rowNum) throws SQLException } Map refreshTokenMetadata = parseMap(getLobValue(rs, "refresh_token_metadata")); - OAuth2RefreshToken refreshToken = new OAuth2RefreshToken( - refreshTokenValue, tokenIssuedAt, tokenExpiresAt); + OAuth2RefreshToken refreshToken = new OAuth2RefreshToken(refreshTokenValue, tokenIssuedAt, + tokenExpiresAt); builder.token(refreshToken, (metadata) -> metadata.putAll(refreshTokenMetadata)); } @@ -480,9 +514,11 @@ private String getLobValue(ResultSet rs, String columnName) throws SQLException if (columnValueBytes != null) { columnValue = new String(columnValueBytes, StandardCharsets.UTF_8); } - } else if (Types.CLOB == columnMetadata.getDataType()) { + } + else if (Types.CLOB == columnMetadata.getDataType()) { columnValue = this.lobHandler.getClobAsString(rs, columnName); - } else { + } + else { columnValue = rs.getString(columnName); } return columnValue; @@ -512,8 +548,10 @@ protected final ObjectMapper getObjectMapper() { private Map parseMap(String data) { try { - return this.objectMapper.readValue(data, new TypeReference>() {}); - } catch (Exception ex) { + return this.objectMapper.readValue(data, new TypeReference>() { + }); + } + catch (Exception ex) { throw new IllegalArgumentException(ex.getMessage(), ex); } } @@ -524,7 +562,9 @@ private Map parseMap(String data) { * The default {@code Function} that maps {@link OAuth2Authorization} to a * {@code List} of {@link SqlParameterValue}. */ - public static class OAuth2AuthorizationParametersMapper implements Function> { + public static class OAuth2AuthorizationParametersMapper + implements Function> { + private ObjectMapper objectMapper = new ObjectMapper(); public OAuth2AuthorizationParametersMapper() { @@ -558,46 +598,46 @@ public List apply(OAuth2Authorization authorization) { } parameters.add(new SqlParameterValue(Types.VARCHAR, state)); - OAuth2Authorization.Token authorizationCode = - authorization.getToken(OAuth2AuthorizationCode.class); - List authorizationCodeSqlParameters = toSqlParameterList( - "authorization_code_value", "authorization_code_metadata", authorizationCode); + OAuth2Authorization.Token authorizationCode = authorization + .getToken(OAuth2AuthorizationCode.class); + List authorizationCodeSqlParameters = toSqlParameterList("authorization_code_value", + "authorization_code_metadata", authorizationCode); parameters.addAll(authorizationCodeSqlParameters); - OAuth2Authorization.Token accessToken = - authorization.getToken(OAuth2AccessToken.class); - List accessTokenSqlParameters = toSqlParameterList( - "access_token_value", "access_token_metadata", accessToken); + OAuth2Authorization.Token accessToken = authorization.getToken(OAuth2AccessToken.class); + List accessTokenSqlParameters = toSqlParameterList("access_token_value", + "access_token_metadata", accessToken); parameters.addAll(accessTokenSqlParameters); String accessTokenType = null; String accessTokenScopes = null; if (accessToken != null) { accessTokenType = accessToken.getToken().getTokenType().getValue(); if (!CollectionUtils.isEmpty(accessToken.getToken().getScopes())) { - accessTokenScopes = StringUtils.collectionToDelimitedString(accessToken.getToken().getScopes(), ","); + accessTokenScopes = StringUtils.collectionToDelimitedString(accessToken.getToken().getScopes(), + ","); } } parameters.add(new SqlParameterValue(Types.VARCHAR, accessTokenType)); parameters.add(new SqlParameterValue(Types.VARCHAR, accessTokenScopes)); OAuth2Authorization.Token oidcIdToken = authorization.getToken(OidcIdToken.class); - List oidcIdTokenSqlParameters = toSqlParameterList( - "oidc_id_token_value", "oidc_id_token_metadata", oidcIdToken); + List oidcIdTokenSqlParameters = toSqlParameterList("oidc_id_token_value", + "oidc_id_token_metadata", oidcIdToken); parameters.addAll(oidcIdTokenSqlParameters); OAuth2Authorization.Token refreshToken = authorization.getRefreshToken(); - List refreshTokenSqlParameters = toSqlParameterList( - "refresh_token_value", "refresh_token_metadata", refreshToken); + List refreshTokenSqlParameters = toSqlParameterList("refresh_token_value", + "refresh_token_metadata", refreshToken); parameters.addAll(refreshTokenSqlParameters); OAuth2Authorization.Token userCode = authorization.getToken(OAuth2UserCode.class); - List userCodeSqlParameters = toSqlParameterList( - "user_code_value", "user_code_metadata", userCode); + List userCodeSqlParameters = toSqlParameterList("user_code_value", "user_code_metadata", + userCode); parameters.addAll(userCodeSqlParameters); OAuth2Authorization.Token deviceCode = authorization.getToken(OAuth2DeviceCode.class); - List deviceCodeSqlParameters = toSqlParameterList( - "device_code_value", "device_code_metadata", deviceCode); + List deviceCodeSqlParameters = toSqlParameterList("device_code_value", + "device_code_metadata", deviceCode); parameters.addAll(deviceCodeSqlParameters); return parameters; @@ -612,8 +652,8 @@ protected final ObjectMapper getObjectMapper() { return this.objectMapper; } - private List toSqlParameterList( - String tokenColumnName, String tokenMetadataColumnName, OAuth2Authorization.Token token) { + private List toSqlParameterList(String tokenColumnName, + String tokenMetadataColumnName, OAuth2Authorization.Token token) { List parameters = new ArrayList<>(); String tokenValue = null; @@ -641,7 +681,8 @@ private List toSqlParameterList( private String writeMap(Map data) { try { return this.objectMapper.writeValueAsString(data); - } catch (Exception ex) { + } + catch (Exception ex) { throw new IllegalArgumentException(ex.getMessage(), ex); } } @@ -649,6 +690,7 @@ private String writeMap(Map data) { } private static final class LobCreatorArgumentPreparedStatementSetter extends ArgumentPreparedStatementSetter { + private final LobCreator lobCreator; private LobCreatorArgumentPreparedStatementSetter(LobCreator lobCreator, Object[] args) { @@ -685,7 +727,9 @@ protected void doSetValue(PreparedStatement ps, int parameterPosition, Object ar } private static final class ColumnMetadata { + private final String columnName; + private final int dataType; private ColumnMetadata(String columnName, int dataType) { @@ -735,7 +779,8 @@ private static void initColumnMetadata(JdbcOperations jdbcOperations) { columnMetadataMap.put(columnMetadata.getColumnName(), columnMetadata); } - private static ColumnMetadata getColumnMetadata(JdbcOperations jdbcOperations, String columnName, int defaultDataType) { + private static ColumnMetadata getColumnMetadata(JdbcOperations jdbcOperations, String columnName, + int defaultDataType) { Integer dataType = jdbcOperations.execute((ConnectionCallback) conn -> { DatabaseMetaData databaseMetaData = conn.getMetaData(); ResultSet rs = databaseMetaData.getColumns(null, null, TABLE_NAME, columnName); @@ -743,10 +788,13 @@ private static ColumnMetadata getColumnMetadata(JdbcOperations jdbcOperations, S return rs.getInt("DATA_TYPE"); } // NOTE: (Applies to HSQL) - // When a database object is created with one of the CREATE statements or renamed with the ALTER statement, - // if the name is enclosed in double quotes, the exact name is used as the case-normal form. + // When a database object is created with one of the CREATE statements or + // renamed with the ALTER statement, + // if the name is enclosed in double quotes, the exact name is used as the + // case-normal form. // But if it is not enclosed in double quotes, - // the name is converted to uppercase and this uppercase version is stored in the database as the case-normal form. + // the name is converted to uppercase and this uppercase version is stored in + // the database as the case-normal form. rs = databaseMetaData.getColumns(null, null, TABLE_NAME.toUpperCase(), columnName.toUpperCase()); if (rs.next()) { return rs.getInt("DATA_TYPE"); @@ -758,9 +806,9 @@ private static ColumnMetadata getColumnMetadata(JdbcOperations jdbcOperations, S private static SqlParameterValue mapToSqlParameter(String columnName, String value) { ColumnMetadata columnMetadata = columnMetadataMap.get(columnName); - return Types.BLOB == columnMetadata.getDataType() && StringUtils.hasText(value) ? - new SqlParameterValue(Types.BLOB, value.getBytes(StandardCharsets.UTF_8)) : - new SqlParameterValue(columnMetadata.getDataType(), value); + return Types.BLOB == columnMetadata.getDataType() && StringUtils.hasText(value) + ? new SqlParameterValue(Types.BLOB, value.getBytes(StandardCharsets.UTF_8)) + : new SqlParameterValue(columnMetadata.getDataType(), value); } } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepository.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepository.java index 37a5cb69e..13eee21f1 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepository.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepository.java @@ -15,27 +15,14 @@ */ package org.springframework.security.oauth2.server.authorization.client; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.Timestamp; -import java.sql.Types; -import java.time.Instant; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.function.Function; - import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.Module; import com.fasterxml.jackson.databind.ObjectMapper; - -import org.springframework.jdbc.core.ArgumentPreparedStatementSetter; -import org.springframework.jdbc.core.JdbcOperations; -import org.springframework.jdbc.core.PreparedStatementSetter; -import org.springframework.jdbc.core.RowMapper; -import org.springframework.jdbc.core.SqlParameterValue; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.RuntimeHintsRegistrar; +import org.springframework.context.annotation.ImportRuntimeHints; +import org.springframework.core.io.ClassPathResource; +import org.springframework.jdbc.core.*; import org.springframework.security.jackson2.SecurityJackson2Modules; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; @@ -47,26 +34,47 @@ import org.springframework.util.Assert; import org.springframework.util.StringUtils; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.sql.Types; +import java.time.Instant; +import java.util.*; +import java.util.function.Function; + /** * A JDBC implementation of a {@link RegisteredClientRepository} that uses a * {@link JdbcOperations} for {@link RegisteredClient} persistence. * *

- * NOTE: This {@code RegisteredClientRepository} depends on the table definition described in - * "classpath:org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client-schema.sql" and - * therefore MUST be defined in the database schema. + * NOTE: This {@code RegisteredClientRepository} depends on the table definition + * described in + * "classpath:org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client-schema.sql" + * and therefore MUST be defined in the database schema. * * @author Rafal Lewczuk * @author Joe Grandja * @author Ovidiu Popa - * @since 0.1.2 + * @author Josh Long * @see RegisteredClientRepository * @see RegisteredClient * @see JdbcOperations * @see RowMapper + * @since 0.1.2 */ +@ImportRuntimeHints(JdbcRegisteredClientRepository.JdbcRegisteredClientRepositoryRuntimeHintsRegistrar.class) public class JdbcRegisteredClientRepository implements RegisteredClientRepository { + static class JdbcRegisteredClientRepositoryRuntimeHintsRegistrar implements RuntimeHintsRegistrar { + + @Override + public void registerHints(RuntimeHints hints, ClassLoader classLoader) { + hints.resources().registerResource(new ClassPathResource( + "org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client-schema.sql")); + } + + } + // @formatter:off private static final String COLUMN_NAMES = "id, " + "client_id, " @@ -87,7 +95,8 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor private static final String PK_FILTER = "id = ?"; - private static final String LOAD_REGISTERED_CLIENT_SQL = "SELECT " + COLUMN_NAMES + " FROM " + TABLE_NAME + " WHERE "; + private static final String LOAD_REGISTERED_CLIENT_SQL = "SELECT " + COLUMN_NAMES + " FROM " + TABLE_NAME + + " WHERE "; // @formatter:off private static final String INSERT_REGISTERED_CLIENT_SQL = "INSERT INTO " + TABLE_NAME @@ -105,12 +114,13 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor private static final String COUNT_REGISTERED_CLIENT_SQL = "SELECT COUNT(*) FROM " + TABLE_NAME + " WHERE "; private final JdbcOperations jdbcOperations; + private RowMapper registeredClientRowMapper; + private Function> registeredClientParametersMapper; /** * Constructs a {@code JdbcRegisteredClientRepository} using the provided parameters. - * * @param jdbcOperations the JDBC operations */ public JdbcRegisteredClientRepository(JdbcOperations jdbcOperations) { @@ -123,17 +133,18 @@ public JdbcRegisteredClientRepository(JdbcOperations jdbcOperations) { @Override public void save(RegisteredClient registeredClient) { Assert.notNull(registeredClient, "registeredClient cannot be null"); - RegisteredClient existingRegisteredClient = findBy(PK_FILTER, - registeredClient.getId()); + RegisteredClient existingRegisteredClient = findBy(PK_FILTER, registeredClient.getId()); if (existingRegisteredClient != null) { updateRegisteredClient(registeredClient); - } else { + } + else { insertRegisteredClient(registeredClient); } } private void updateRegisteredClient(RegisteredClient registeredClient) { - List parameters = new ArrayList<>(this.registeredClientParametersMapper.apply(registeredClient)); + List parameters = new ArrayList<>( + this.registeredClientParametersMapper.apply(registeredClient)); SqlParameterValue id = parameters.remove(0); parameters.remove(0); // remove client_id parameters.remove(0); // remove client_id_issued_at @@ -150,21 +161,17 @@ private void insertRegisteredClient(RegisteredClient registeredClient) { } private void assertUniqueIdentifiers(RegisteredClient registeredClient) { - Integer count = this.jdbcOperations.queryForObject( - COUNT_REGISTERED_CLIENT_SQL + "client_id = ?", - Integer.class, + Integer count = this.jdbcOperations.queryForObject(COUNT_REGISTERED_CLIENT_SQL + "client_id = ?", Integer.class, registeredClient.getClientId()); if (count != null && count > 0) { - throw new IllegalArgumentException("Registered client must be unique. " + - "Found duplicate client identifier: " + registeredClient.getClientId()); + throw new IllegalArgumentException("Registered client must be unique. " + + "Found duplicate client identifier: " + registeredClient.getClientId()); } - count = this.jdbcOperations.queryForObject( - COUNT_REGISTERED_CLIENT_SQL + "client_secret = ?", - Integer.class, + count = this.jdbcOperations.queryForObject(COUNT_REGISTERED_CLIENT_SQL + "client_secret = ?", Integer.class, registeredClient.getClientSecret()); if (count != null && count > 0) { - throw new IllegalArgumentException("Registered client must be unique. " + - "Found duplicate client secret for identifier: " + registeredClient.getId()); + throw new IllegalArgumentException("Registered client must be unique. " + + "Found duplicate client secret for identifier: " + registeredClient.getId()); } } @@ -181,16 +188,17 @@ public RegisteredClient findByClientId(String clientId) { } private RegisteredClient findBy(String filter, Object... args) { - List result = this.jdbcOperations.query( - LOAD_REGISTERED_CLIENT_SQL + filter, this.registeredClientRowMapper, args); + List result = this.jdbcOperations.query(LOAD_REGISTERED_CLIENT_SQL + filter, + this.registeredClientRowMapper, args); return !result.isEmpty() ? result.get(0) : null; } /** - * Sets the {@link RowMapper} used for mapping the current row in {@code java.sql.ResultSet} to {@link RegisteredClient}. - * The default is {@link RegisteredClientRowMapper}. - * - * @param registeredClientRowMapper the {@link RowMapper} used for mapping the current row in {@code ResultSet} to {@link RegisteredClient} + * Sets the {@link RowMapper} used for mapping the current row in + * {@code java.sql.ResultSet} to {@link RegisteredClient}. The default is + * {@link RegisteredClientRowMapper}. + * @param registeredClientRowMapper the {@link RowMapper} used for mapping the current + * row in {@code ResultSet} to {@link RegisteredClient} */ public final void setRegisteredClientRowMapper(RowMapper registeredClientRowMapper) { Assert.notNull(registeredClientRowMapper, "registeredClientRowMapper cannot be null"); @@ -198,12 +206,14 @@ public final void setRegisteredClientRowMapper(RowMapper regis } /** - * Sets the {@code Function} used for mapping {@link RegisteredClient} to a {@code List} of {@link SqlParameterValue}. - * The default is {@link RegisteredClientParametersMapper}. - * - * @param registeredClientParametersMapper the {@code Function} used for mapping {@link RegisteredClient} to a {@code List} of {@link SqlParameterValue} + * Sets the {@code Function} used for mapping {@link RegisteredClient} to a + * {@code List} of {@link SqlParameterValue}. The default is + * {@link RegisteredClientParametersMapper}. + * @param registeredClientParametersMapper the {@code Function} used for mapping + * {@link RegisteredClient} to a {@code List} of {@link SqlParameterValue} */ - public final void setRegisteredClientParametersMapper(Function> registeredClientParametersMapper) { + public final void setRegisteredClientParametersMapper( + Function> registeredClientParametersMapper) { Assert.notNull(registeredClientParametersMapper, "registeredClientParametersMapper cannot be null"); this.registeredClientParametersMapper = registeredClientParametersMapper; } @@ -225,6 +235,7 @@ protected final Function> getRegistere * {@code java.sql.ResultSet} to {@link RegisteredClient}. */ public static class RegisteredClientRowMapper implements RowMapper { + private ObjectMapper objectMapper = new ObjectMapper(); public RegisteredClientRowMapper() { @@ -238,10 +249,13 @@ public RegisteredClientRowMapper() { public RegisteredClient mapRow(ResultSet rs, int rowNum) throws SQLException { Timestamp clientIdIssuedAt = rs.getTimestamp("client_id_issued_at"); Timestamp clientSecretExpiresAt = rs.getTimestamp("client_secret_expires_at"); - Set clientAuthenticationMethods = StringUtils.commaDelimitedListToSet(rs.getString("client_authentication_methods")); - Set authorizationGrantTypes = StringUtils.commaDelimitedListToSet(rs.getString("authorization_grant_types")); + Set clientAuthenticationMethods = StringUtils + .commaDelimitedListToSet(rs.getString("client_authentication_methods")); + Set authorizationGrantTypes = StringUtils + .commaDelimitedListToSet(rs.getString("authorization_grant_types")); Set redirectUris = StringUtils.commaDelimitedListToSet(rs.getString("redirect_uris")); - Set postLogoutRedirectUris = StringUtils.commaDelimitedListToSet(rs.getString("post_logout_redirect_uris")); + Set postLogoutRedirectUris = StringUtils + .commaDelimitedListToSet(rs.getString("post_logout_redirect_uris")); Set clientScopes = StringUtils.commaDelimitedListToSet(rs.getString("scopes")); // @formatter:off @@ -286,8 +300,10 @@ protected final ObjectMapper getObjectMapper() { private Map parseMap(String data) { try { - return this.objectMapper.readValue(data, new TypeReference>() {}); - } catch (Exception ex) { + return this.objectMapper.readValue(data, new TypeReference>() { + }); + } + catch (Exception ex) { throw new IllegalArgumentException(ex.getMessage(), ex); } } @@ -295,32 +311,43 @@ private Map parseMap(String data) { private static AuthorizationGrantType resolveAuthorizationGrantType(String authorizationGrantType) { if (AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals(authorizationGrantType)) { return AuthorizationGrantType.AUTHORIZATION_CODE; - } else if (AuthorizationGrantType.CLIENT_CREDENTIALS.getValue().equals(authorizationGrantType)) { + } + else if (AuthorizationGrantType.CLIENT_CREDENTIALS.getValue().equals(authorizationGrantType)) { return AuthorizationGrantType.CLIENT_CREDENTIALS; - } else if (AuthorizationGrantType.REFRESH_TOKEN.getValue().equals(authorizationGrantType)) { + } + else if (AuthorizationGrantType.REFRESH_TOKEN.getValue().equals(authorizationGrantType)) { return AuthorizationGrantType.REFRESH_TOKEN; } - return new AuthorizationGrantType(authorizationGrantType); // Custom authorization grant type + return new AuthorizationGrantType(authorizationGrantType); // Custom + // authorization + // grant type } private static ClientAuthenticationMethod resolveClientAuthenticationMethod(String clientAuthenticationMethod) { if (ClientAuthenticationMethod.CLIENT_SECRET_BASIC.getValue().equals(clientAuthenticationMethod)) { return ClientAuthenticationMethod.CLIENT_SECRET_BASIC; - } else if (ClientAuthenticationMethod.CLIENT_SECRET_POST.getValue().equals(clientAuthenticationMethod)) { + } + else if (ClientAuthenticationMethod.CLIENT_SECRET_POST.getValue().equals(clientAuthenticationMethod)) { return ClientAuthenticationMethod.CLIENT_SECRET_POST; - } else if (ClientAuthenticationMethod.NONE.getValue().equals(clientAuthenticationMethod)) { + } + else if (ClientAuthenticationMethod.NONE.getValue().equals(clientAuthenticationMethod)) { return ClientAuthenticationMethod.NONE; } - return new ClientAuthenticationMethod(clientAuthenticationMethod); // Custom client authentication method + return new ClientAuthenticationMethod(clientAuthenticationMethod); // Custom + // client + // authentication + // method } } /** - * The default {@code Function} that maps {@link RegisteredClient} to a - * {@code List} of {@link SqlParameterValue}. + * The default {@code Function} that maps {@link RegisteredClient} to a {@code List} + * of {@link SqlParameterValue}. */ - public static class RegisteredClientParametersMapper implements Function> { + public static class RegisteredClientParametersMapper + implements Function> { + private ObjectMapper objectMapper = new ObjectMapper(); public RegisteredClientParametersMapper() { @@ -332,32 +359,39 @@ public RegisteredClientParametersMapper() { @Override public List apply(RegisteredClient registeredClient) { - Timestamp clientIdIssuedAt = registeredClient.getClientIdIssuedAt() != null ? - Timestamp.from(registeredClient.getClientIdIssuedAt()) : Timestamp.from(Instant.now()); + Timestamp clientIdIssuedAt = registeredClient.getClientIdIssuedAt() != null + ? Timestamp.from(registeredClient.getClientIdIssuedAt()) : Timestamp.from(Instant.now()); - Timestamp clientSecretExpiresAt = registeredClient.getClientSecretExpiresAt() != null ? - Timestamp.from(registeredClient.getClientSecretExpiresAt()) : null; + Timestamp clientSecretExpiresAt = registeredClient.getClientSecretExpiresAt() != null + ? Timestamp.from(registeredClient.getClientSecretExpiresAt()) : null; - List clientAuthenticationMethods = new ArrayList<>(registeredClient.getClientAuthenticationMethods().size()); - registeredClient.getClientAuthenticationMethods().forEach(clientAuthenticationMethod -> - clientAuthenticationMethods.add(clientAuthenticationMethod.getValue())); + List clientAuthenticationMethods = new ArrayList<>( + registeredClient.getClientAuthenticationMethods().size()); + registeredClient.getClientAuthenticationMethods() + .forEach(clientAuthenticationMethod -> clientAuthenticationMethods + .add(clientAuthenticationMethod.getValue())); - List authorizationGrantTypes = new ArrayList<>(registeredClient.getAuthorizationGrantTypes().size()); - registeredClient.getAuthorizationGrantTypes().forEach(authorizationGrantType -> - authorizationGrantTypes.add(authorizationGrantType.getValue())); + List authorizationGrantTypes = new ArrayList<>( + registeredClient.getAuthorizationGrantTypes().size()); + registeredClient.getAuthorizationGrantTypes() + .forEach(authorizationGrantType -> authorizationGrantTypes.add(authorizationGrantType.getValue())); - return Arrays.asList( - new SqlParameterValue(Types.VARCHAR, registeredClient.getId()), + return Arrays.asList(new SqlParameterValue(Types.VARCHAR, registeredClient.getId()), new SqlParameterValue(Types.VARCHAR, registeredClient.getClientId()), new SqlParameterValue(Types.TIMESTAMP, clientIdIssuedAt), new SqlParameterValue(Types.VARCHAR, registeredClient.getClientSecret()), new SqlParameterValue(Types.TIMESTAMP, clientSecretExpiresAt), new SqlParameterValue(Types.VARCHAR, registeredClient.getClientName()), - new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(clientAuthenticationMethods)), - new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(authorizationGrantTypes)), - new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(registeredClient.getRedirectUris())), - new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(registeredClient.getPostLogoutRedirectUris())), - new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(registeredClient.getScopes())), + new SqlParameterValue(Types.VARCHAR, + StringUtils.collectionToCommaDelimitedString(clientAuthenticationMethods)), + new SqlParameterValue(Types.VARCHAR, + StringUtils.collectionToCommaDelimitedString(authorizationGrantTypes)), + new SqlParameterValue(Types.VARCHAR, + StringUtils.collectionToCommaDelimitedString(registeredClient.getRedirectUris())), + new SqlParameterValue(Types.VARCHAR, + StringUtils.collectionToCommaDelimitedString(registeredClient.getPostLogoutRedirectUris())), + new SqlParameterValue(Types.VARCHAR, + StringUtils.collectionToCommaDelimitedString(registeredClient.getScopes())), new SqlParameterValue(Types.VARCHAR, writeMap(registeredClient.getClientSettings().getSettings())), new SqlParameterValue(Types.VARCHAR, writeMap(registeredClient.getTokenSettings().getSettings()))); } @@ -374,7 +408,8 @@ protected final ObjectMapper getObjectMapper() { private String writeMap(Map data) { try { return this.objectMapper.writeValueAsString(data); - } catch (Exception ex) { + } + catch (Exception ex) { throw new IllegalArgumentException(ex.getMessage(), ex); } } diff --git a/oauth2-authorization-server/src/main/resources/META-INF/spring/aot.factories b/oauth2-authorization-server/src/main/resources/META-INF/spring/aot.factories new file mode 100644 index 000000000..73ad27d60 --- /dev/null +++ b/oauth2-authorization-server/src/main/resources/META-INF/spring/aot.factories @@ -0,0 +1,2 @@ +org.springframework.aot.hint.RuntimeHintsRegistrar=\ + org.springframework.security.oauth2.server.authorization.AuthorizationServerRuntimeHintsRegistrar diff --git a/samples/demo-authorizationserver/samples-demo-authorizationserver.gradle b/samples/demo-authorizationserver/samples-demo-authorizationserver.gradle index d6070d1ce..3bcd3d564 100644 --- a/samples/demo-authorizationserver/samples-demo-authorizationserver.gradle +++ b/samples/demo-authorizationserver/samples-demo-authorizationserver.gradle @@ -2,6 +2,7 @@ plugins { id "org.springframework.boot" version "3.1.0" id "io.spring.dependency-management" version "1.1.0" id "java" + id 'org.graalvm.buildtools.native' version '0.9.27' } group = project.rootProject.group diff --git a/samples/demo-authorizationserver/src/main/java/sample/DemoAuthorizationServerApplication.java b/samples/demo-authorizationserver/src/main/java/sample/DemoAuthorizationServerApplication.java index 88d788b24..c2a8203b3 100644 --- a/samples/demo-authorizationserver/src/main/java/sample/DemoAuthorizationServerApplication.java +++ b/samples/demo-authorizationserver/src/main/java/sample/DemoAuthorizationServerApplication.java @@ -15,13 +15,17 @@ */ package sample; +import org.springframework.aot.hint.annotation.RegisterReflectionForBinding; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; +import sample.web.AuthorizationConsentController; /** * @author Joe Grandja + * @author Josh Long * @since 1.1 */ +@RegisterReflectionForBinding(AuthorizationConsentController.ScopeWithDescription.class) @SpringBootApplication public class DemoAuthorizationServerApplication { diff --git a/samples/demo-authorizationserver/src/main/resources/META-INF/spring/aot.factories b/samples/demo-authorizationserver/src/main/resources/META-INF/spring/aot.factories deleted file mode 100644 index 9bce6054f..000000000 --- a/samples/demo-authorizationserver/src/main/resources/META-INF/spring/aot.factories +++ /dev/null @@ -1,2 +0,0 @@ -org.springframework.aot.hint.RuntimeHintsRegistrar=\ -sample.aot.hint.DemoAuthorizationServerRuntimeHints