Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache improvements #708

Open
wants to merge 2 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ void acquireTokenSilent_LabAuthority_TokenNotRefreshed(String environment) throw
// Check that access and id tokens are coming from cache
assertEquals(result.accessToken(), acquireSilentResult.accessToken());
assertEquals(result.idToken(), acquireSilentResult.idToken());
assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource());
assertEquals(CacheRefreshReason.NOT_APPLICABLE, result.metadata().cacheRefreshReason());
assertEquals(TokenSource.CACHE, acquireSilentResult.metadata().tokenSource());
assertEquals(CacheRefreshReason.NOT_APPLICABLE, acquireSilentResult.metadata().cacheRefreshReason());
Copy link
Member

@bgavrilMS bgavrilMS Sep 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add more tests. Consider adding these assertions to all integrationt tests, not just public client ones. Client_credentials is of highest priority.

In particular:

  • client_credentials + all combos of CacheRefreshReason (especially PROACTIVE_REFRESH)

  • client_credentials + claims

  • obo and auth code

}

@ParameterizedTest
Expand All @@ -92,6 +96,10 @@ void acquireTokenSilent_ForceRefresh(String environment) throws Exception {

// Check that new refresh and id tokens are being returned
assertTokensAreNotEqual(result, resultAfterRefresh);
assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource());
assertEquals(CacheRefreshReason.NOT_APPLICABLE, result.metadata().cacheRefreshReason());
assertEquals(TokenSource.IDENTITY_PROVIDER, resultAfterRefresh.metadata().tokenSource());
assertEquals(CacheRefreshReason.FORCE_REFRESH, resultAfterRefresh.metadata().cacheRefreshReason());
}

@ParameterizedTest
Expand Down Expand Up @@ -253,6 +261,11 @@ void acquireTokenSilent_WithRefreshOn(String environment) throws Exception {
//Current time is after refreshOn, so token should be refreshed
assertNotNull(resultSilentWithRefreshOn);
assertTokensAreNotEqual(resultSilent, resultSilentWithRefreshOn);

assertEquals(TokenSource.CACHE, resultSilent.metadata().tokenSource());
assertEquals(CacheRefreshReason.NOT_APPLICABLE, resultSilent.metadata().cacheRefreshReason());
assertEquals(TokenSource.IDENTITY_PROVIDER, resultSilentWithRefreshOn.metadata().tokenSource());
assertEquals(CacheRefreshReason.PROACTIVE_REFRESH, resultSilentWithRefreshOn.metadata().cacheRefreshReason());
}

@ParameterizedTest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.TestInstance;

import static com.microsoft.aad.msal4j.TestConstants.KEYVAULT_DEFAULT_SCOPE;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;

import java.util.Collections;
Expand Down Expand Up @@ -195,6 +197,68 @@ void retrieveAccounts_ADFSOnPrem() throws Exception {
assertEquals(pca.getAccounts().join().size(), 1);
}

@Test
void testStaticCache() throws Exception {
AppCredentialProvider appProvider = new AppCredentialProvider(AzureEnvironment.AZURE);
final String clientId = appProvider.getLabVaultAppId();
final String password = appProvider.getLabVaultPassword();
IClientCredential credential = ClientCredentialFactory.createFromSecret(password);

//Create three client applications: one that uses its own instance of a TokenCache,
// and two that use the useSharedCache option to use the same static TokenCache
ConfidentialClientApplication cca_notStatic = ConfidentialClientApplication.builder(
clientId, credential).
authority(TestConstants.MICROSOFT_AUTHORITY).
build();

ConfidentialClientApplication cca_sharedCache1 = ConfidentialClientApplication.builder(
clientId, credential).
authority(TestConstants.MICROSOFT_AUTHORITY).
useSharedCache(true).
bgavrilMS marked this conversation as resolved.
Show resolved Hide resolved
build();

ConfidentialClientApplication cca_sharedCache2 = ConfidentialClientApplication.builder(
clientId, credential).
authority(TestConstants.MICROSOFT_AUTHORITY).
useSharedCache(true).
build();

ClientCredentialParameters parameters = ClientCredentialParameters
.builder(Collections.singleton(KEYVAULT_DEFAULT_SCOPE))
.build();

//Make a number of token calls using the different ConfidentialClientApplications
// 1. Retrieve and cache new tokens using the ConfidentialClientApplication that does not use the shared cache
IAuthenticationResult result_notStatic1 = cca_notStatic.acquireToken(parameters).get();
// 2. The client credential flow does a cache lookup by default, so making the same acquireToken call should retrieve the tokens cached during call 1
IAuthenticationResult result_notStatic2 = cca_notStatic.acquireToken(parameters).get();
// 3. Retrieve and cache new tokens using the ConfidentialClientApplication that uses the static cache
IAuthenticationResult result_sharedCache1 = cca_sharedCache1.acquireToken(parameters).get();
// 4. Due to using the static cache this should behave like token call 2 and retrieve the tokens cached in call 3
IAuthenticationResult result_sharedCache2 = cca_sharedCache2.acquireToken(parameters).get();

assertNotNull(result_notStatic1);
assertNotNull(result_notStatic1.accessToken());
assertNotNull(result_notStatic2);
assertNotNull(result_notStatic2.accessToken());
assertNotNull(result_sharedCache1);
assertNotNull(result_sharedCache1.accessToken());
assertNotNull(result_sharedCache2);
assertNotNull(result_sharedCache2.accessToken());

//None of the tokens retrieved using cca_notStatic should be the same as those retrieved using cca_sharedCache1 or cca_sharedCache2
assertNotEquals(result_notStatic1.accessToken(), result_sharedCache1.accessToken());
assertNotEquals(result_notStatic1.accessToken(), result_sharedCache2.accessToken());
assertNotEquals(result_notStatic2.accessToken(), result_sharedCache1.accessToken());
assertNotEquals(result_notStatic2.accessToken(), result_sharedCache2.accessToken());

//Because the confidential client flow has an internal silent call:
// -result_notStatic1 and result_notStatic2 should be the same, because they both used the non-static cache from one ConfidentialClientApplication instance
// -result_sharedCache1 and result_sharedCache2 should be the same, because they both used the static cache shared between two ConfidentialClientApplication instances
assertEquals(result_notStatic1.accessToken(), result_notStatic2.accessToken());
assertEquals(result_sharedCache1.accessToken(), result_sharedCache2.accessToken());
}


private static class TokenPersistence implements ITokenCacheAccessAspect {
String data;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ public abstract class AbstractClientApplicationBase implements IClientApplicatio
@Getter
protected TokenCache tokenCache;

@Accessors(fluent = true)
protected static TokenCache sharedTokenCache = new TokenCache();

@Accessors(fluent = true)
@Getter
private String applicationName;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ class AcquireTokenSilentSupplier extends AuthenticationResultSupplier {
AuthenticationResult execute() throws Exception {
Authority requestAuthority = silentRequest.requestAuthority();
if (requestAuthority.authorityType != AuthorityType.B2C) {
requestAuthority =
getAuthorityWithPrefNetworkHost(silentRequest.requestAuthority().authority());
requestAuthority = getAuthorityWithPrefNetworkHost(silentRequest.requestAuthority().authority());
}

AuthenticationResult res;
Expand All @@ -46,6 +45,9 @@ AuthenticationResult execute() throws Exception {
throw new MsalClientException(AuthenticationErrorMessage.NO_TOKEN_IN_CACHE, AuthenticationErrorCode.CACHE_MISS);
}

//Some cached tokens were found, but this metadata will be overwritten if token needs to be refreshed
res.metadata().tokenSource(TokenSource.CACHE);

if (!StringHelper.isBlank(res.accessToken())) {
clientApplication.getServiceBundle().getServerSideTelemetry().incrementSilentSuccessfulCount();
}
Expand All @@ -60,15 +62,19 @@ AuthenticationResult execute() throws Exception {
//As of version 3 of the telemetry schema, there is a field for collecting data about why a token was refreshed,
// so here we set the telemetry value based on the cause of the refresh
if (silentRequest.parameters().forceRefresh()) {
this.silentRequest.requestContext().refreshReason(CacheRefreshReason.FORCE_REFRESH);
clientApplication.getServiceBundle().getServerSideTelemetry().getCurrentRequest().cacheInfo(
CacheTelemetry.REFRESH_FORCE_REFRESH.telemetryValue);
} else if (afterRefreshOn) {
this.silentRequest.requestContext().refreshReason(CacheRefreshReason.PROACTIVE_REFRESH);
clientApplication.getServiceBundle().getServerSideTelemetry().getCurrentRequest().cacheInfo(
CacheTelemetry.REFRESH_REFRESH_IN.telemetryValue);
} else if (res.expiresOn() < currTimeStampSec) {
this.silentRequest.requestContext().refreshReason(CacheRefreshReason.EXPIRED);
clientApplication.getServiceBundle().getServerSideTelemetry().getCurrentRequest().cacheInfo(
CacheTelemetry.REFRESH_ACCESS_TOKEN_EXPIRED.telemetryValue);
} else if (StringHelper.isBlank(res.accessToken())) {
this.silentRequest.requestContext().refreshReason(CacheRefreshReason.NO_CACHED_ACCESS_TOKEN);
bgavrilMS marked this conversation as resolved.
Show resolved Hide resolved
clientApplication.getServiceBundle().getServerSideTelemetry().getCurrentRequest().cacheInfo(
CacheTelemetry.REFRESH_NO_ACCESS_TOKEN.telemetryValue);
}
Expand All @@ -93,14 +99,20 @@ AuthenticationResult execute() throws Exception {

try {
res = acquireTokenByAuthorisationGrantSupplier.execute();
res.metadata().cacheRefreshReason(this.silentRequest.requestContext().refreshReason());
res.metadata().tokenSource(TokenSource.IDENTITY_PROVIDER);

log.info("Access token refreshed successfully.");
} catch (MsalServiceException ex) {
//If the token refresh attempt threw a MsalServiceException but the refresh attempt was done
// only because of refreshOn, then simply return the existing cached token
if (afterRefreshOn && !(silentRequest.parameters().forceRefresh() || StringHelper.isBlank(res.accessToken()))) {

return res;
} else throw ex;
}
} else {
log.warn("Refresh token not found in cache, cannot return valid tokens.");
res = null;
}
}
Expand All @@ -109,8 +121,6 @@ AuthenticationResult execute() throws Exception {
throw new MsalClientException(AuthenticationErrorMessage.NO_TOKEN_IN_CACHE, AuthenticationErrorCode.CACHE_MISS);
}

log.info("Returning token from cache");

return res;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,7 @@ private ITenantProfile getTenantProfile() {
private final Date expiresOnDate = new Date(expiresOn * 1000);

private final String scopes;

@Builder.Default
private final AuthenticationResultMetadata metadata = new AuthenticationResultMetadata();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.microsoft.aad.msal4j;

import lombok.AccessLevel;
import lombok.Getter;
import lombok.Setter;
import lombok.experimental.Accessors;

import java.io.Serializable;

/**
* Contains metadata and additional context for the contents of an AuthenticationResult
*/
@Accessors(fluent = true)
@Getter
@Setter(AccessLevel.PACKAGE)
public class AuthenticationResultMetadata implements Serializable {

private TokenSource tokenSource;
private CacheRefreshReason cacheRefreshReason;

/**
* Sets default metadata values. Used when creating an {@link IAuthenticationResult} before the values are known.
*/
AuthenticationResultMetadata() {
this.tokenSource = TokenSource.UNKNOWN;
this.cacheRefreshReason = CacheRefreshReason.NOT_APPLICABLE;
}

AuthenticationResultMetadata(TokenSource tokenSource, CacheRefreshReason refreshReason) {
this.tokenSource = tokenSource;
this.cacheRefreshReason = refreshReason;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.microsoft.aad.msal4j;

/**
* List of possible reasons the tokens in an {@link IAuthenticationResult} were refreshed.
*/
public enum CacheRefreshReason {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are missing a case here when "Claims" are specified. When this happens, MSAL has to bypass the token cache and to ESTS.

But not for client capabilities. Client capabilities are internally still claims. But we should not bypass the cache.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still open comment :)


/**
* Token did not need to be refreshed, or was retrieved in a non-silent call
*/
NOT_APPLICABLE,
/**
* Silent call was made with the force refresh option
*/
FORCE_REFRESH,
/**
* Access token was missing from the cache, but a valid refresh token was used to retrieve a new access token
*/
NO_CACHED_ACCESS_TOKEN,
/**
* Cached access token was expired and successfully refreshed
*/
EXPIRED,
/**
* Cached access token was not expired but was after the 'refresh_in' value, and was proactively refreshed before the expiration date
*/
PROACTIVE_REFRESH
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ public class ConfidentialClientApplication extends AbstractClientApplicationBase
@Getter
private boolean sendX5c;

@Accessors(fluent = true)
@Getter
private boolean useSharedCache;

@Override
public CompletableFuture<IAuthenticationResult> acquireToken(ClientCredentialParameters parameters) {
validateNotNull("parameters", parameters);
Expand Down Expand Up @@ -83,6 +87,8 @@ private ConfidentialClientApplication(Builder builder) {
super(builder);
sendX5c = builder.sendX5c;
appTokenProvider = builder.appTokenProvider;
useSharedCache = builder.useSharedCache;
if (useSharedCache) tokenCache = sharedTokenCache;

log = LoggerFactory.getLogger(ConfidentialClientApplication.class);

Expand Down Expand Up @@ -169,6 +175,7 @@ public static class Builder extends AbstractClientApplicationBase.Builder<Builde
private IClientCredential clientCredential;

private boolean sendX5c = true;
private boolean useSharedCache = false;

private Function<AppTokenProviderParameters, CompletableFuture<TokenProviderResult>> appTokenProvider;

Expand Down Expand Up @@ -208,6 +215,12 @@ public ConfidentialClientApplication.Builder appTokenProvider(Function<AppTokenP
throw new NullPointerException("appTokenProvider is null") ;
}

public ConfidentialClientApplication.Builder useSharedCache(boolean val) {
useSharedCache = val;

return self();
}

@Override
public ConfidentialClientApplication build() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,9 @@ public interface IAuthenticationResult extends Serializable {
* @return access token expiration date
*/
java.util.Date expiresOnDate();

/**
* @return various metadata relating to this authentication result
*/
AuthenticationResultMetadata metadata();
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class RequestContext {
private IAcquireTokenParameters apiParameters;
private IClientApplicationBase clientApplication;
private UserIdentifier userIdentifier;
@Setter(AccessLevel.PACKAGE)
private CacheRefreshReason refreshReason;

public RequestContext(AbstractClientApplicationBase clientApplication,
PublicApi publicApi,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ private AuthenticationResult createAuthenticationResultFromOauthHttpResponse(
refreshOn(response.getRefreshIn() > 0 ? currTimestampSec + response.getRefreshIn() : 0).
accountCacheEntity(accountCacheEntity).
scopes(response.getScope()).
metadata(new AuthenticationResultMetadata(TokenSource.IDENTITY_PROVIDER, CacheRefreshReason.NOT_APPLICABLE)).
build();

} else {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.microsoft.aad.msal4j;

/**
* A list of possible sources for the tokens found in an {@link IAuthenticationResult}
*/
public enum TokenSource {

/**
* A default value, likely indicates tokens could not be retrieved
*/
UNKNOWN,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MSAL throws exception when token cannot be retried. I don't think you can get this value. Also, there is no unit test covering this path.


/**
* Indicates tokens came from an identity provider, such as Azure AD
*/
IDENTITY_PROVIDER,

/**
* Indicates tokens came from MSAL's cache
*/
CACHE
}