Skip to content

Commit

Permalink
Enhance MSAL Managed Identity to Bypass Cache When Claims Are Present (
Browse files Browse the repository at this point in the history
…#4875)

* initial

* Apply suggestions from code review

Co-authored-by: Bogdan Gavril <[email protected]>

* pr comments

---------

Co-authored-by: Gladwin Johnson <[email protected]>
Co-authored-by: Bogdan Gavril <[email protected]>
  • Loading branch information
3 people authored Aug 5, 2024
1 parent 0f9c36f commit 83725aa
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,22 @@ public AcquireTokenForManagedIdentityParameterBuilder WithForceRefresh(bool forc
return this;
}

/// <summary>
/// Adds a claims challenge to the token request. The SDK will bypass the token cache when a claims challenge is specified. Retry the
/// token acquisition, and use this value in the <see cref="WithClaims(string)"/> method. A claims challenge typically arises when
/// calling the protected downstream API, for example when the tenant administrator revokes credentials. Apps are required
/// to look for a 401 Unauthorized response from the protected api and to parse the WWW-Authenticate response header in order to
/// extract the claims. See https://aka.ms/msal-net-claim-challenge for details.
/// </summary>
/// <param name="claims">A string with one or multiple claims.</param>
/// <returns>The builder to chain .With methods.</returns>
public AcquireTokenForManagedIdentityParameterBuilder WithClaims(string claims)
{
ValidateUseOfExperimentalFeature("WithClaims");
CommonParameters.Claims = claims;
return this;
}

/// <inheritdoc/>
internal override Task<AuthenticationResult> ExecuteInternalAsync(CancellationToken cancellationToken)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,10 @@ public class AssertionRequestOptions {
/// The intended token endpoint
/// </summary>
public string TokenEndpoint { get; set; }

/// <summary>
/// Claims to be included in the client assertion
/// </summary>
public string Claims { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,28 @@ internal ManagedIdentityApplicationBuilder WithAppTokenCacheInternalForTest(ITok
return this;
}

/// <summary>
/// Microsoft Identity specific OIDC extension that allows resource challenges to be resolved without interaction.
/// Allows configuration of one or more client capabilities, e.g. "llt"
/// </summary>
/// <remarks>
/// MSAL will transform these into special claims request. See https://openid.net/specs/openid-connect-core-1_0-final.html#ClaimsParameter for
/// details on claim requests. This is an experimental API. The method signature may change in the future
/// without involving a major version upgrade.
/// For more details see https://aka.ms/msal-net-claims-request
/// </remarks>
public ManagedIdentityApplicationBuilder WithClientCapabilities(IEnumerable<string> clientCapabilities)
{
ValidateUseOfExperimentalFeature();

if (clientCapabilities != null && clientCapabilities.Any())
{
Config.ClientCapabilities = clientCapabilities;
}

return this;
}

/// <summary>
/// Builds an instance of <see cref="IManagedIdentityApplication"/>
/// from the parameters set in the <see cref="ManagedIdentityApplicationBuilder"/>.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,14 @@ protected override async Task<AuthenticationResult> ExecuteAsync(CancellationTok
AuthenticationResult authResult = null;
ILoggerAdapter logger = AuthenticationRequestParameters.RequestContext.Logger;

// Skip checking cache when force refresh is specified
if (_managedIdentityParameters.ForceRefresh)
// Skip checking cache when force refresh or claims is specified
if (_managedIdentityParameters.ForceRefresh || !string.IsNullOrEmpty(AuthenticationRequestParameters.Claims))
{
AuthenticationRequestParameters.RequestContext.ApiEvent.CacheInfo = CacheRefreshReason.ForceRefreshOrClaims;
logger.Info("[ManagedIdentityRequest] Skipped looking for a cached access token because ForceRefresh was set.");

logger.Info("[ManagedIdentityRequest] Skipped looking for a cached access token because ForceRefresh or Claims were set. " +
"This means either a force refresh was requested or claims were present.");

authResult = await GetAccessTokenAsync(cancellationToken, logger).ConfigureAwait(false);
return authResult;
}
Expand Down Expand Up @@ -111,7 +114,8 @@ private async Task<AuthenticationResult> GetAccessTokenAsync(
// 1. Force refresh is requested, or
// 2. If the access token needs to be refreshed proactively.
if (_managedIdentityParameters.ForceRefresh ||
AuthenticationRequestParameters.RequestContext.ApiEvent.CacheInfo == CacheRefreshReason.ProactivelyRefreshed)
AuthenticationRequestParameters.RequestContext.ApiEvent.CacheInfo == CacheRefreshReason.ProactivelyRefreshed ||
!string.IsNullOrEmpty(AuthenticationRequestParameters.Claims))
{
authResult = await SendTokenRequestForManagedIdentityAsync(logger, cancellationToken).ConfigureAwait(false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ private IManagedIdentityApplication CreateMIAWithProxy(string url, string userAs
builder.Config.AccessorOptions = null;

IManagedIdentityApplication mia = builder
.WithExperimentalFeatures(true)
.WithHttpManager(proxyHttpManager).Build();

return mia;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,129 @@ public async Task ManagedIdentityForceRefreshTestAsync(
}
}

[DataTestMethod]
[DataRow(AppServiceEndpoint, Resource, ManagedIdentitySource.AppService)]
[DataRow(ImdsEndpoint, Resource, ManagedIdentitySource.Imds)]
[DataRow(AzureArcEndpoint, Resource, ManagedIdentitySource.AzureArc)]
[DataRow(CloudShellEndpoint, Resource, ManagedIdentitySource.CloudShell)]
[DataRow(ServiceFabricEndpoint, Resource, ManagedIdentitySource.ServiceFabric)]
public async Task ManagedIdentityWithClaimsAndCapabilitiesTestAsync(
string endpoint,
string scope,
ManagedIdentitySource managedIdentitySource)
{
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager(isManagedIdentity: true))
{
SetEnvironmentVariables(managedIdentitySource, endpoint);

var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned)
.WithExperimentalFeatures(true)
.WithClientCapabilities(TestConstants.ClientCapabilities)
.WithHttpManager(httpManager);

// Disabling shared cache options to avoid cross test pollution.
miBuilder.Config.AccessorOptions = null;

var mi = miBuilder.Build();

httpManager.AddManagedIdentityMockHandler(
endpoint,
Resource,
MockHelpers.GetMsiSuccessfulResponse(),
managedIdentitySource);

var result = await mi.AcquireTokenForManagedIdentity(scope).ExecuteAsync().ConfigureAwait(false);

Assert.IsNotNull(result);
Assert.IsNotNull(result.AccessToken);
Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource);

// Acquire token from cache
result = await mi.AcquireTokenForManagedIdentity(scope)
.ExecuteAsync().ConfigureAwait(false);

Assert.IsNotNull(result);
Assert.IsNotNull(result.AccessToken);
Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource);

httpManager.AddManagedIdentityMockHandler(
endpoint,
scope,
MockHelpers.GetMsiSuccessfulResponse(),
managedIdentitySource);

// Acquire token with force refresh
result = await mi.AcquireTokenForManagedIdentity(scope).WithClaims(TestConstants.Claims)
.ExecuteAsync().ConfigureAwait(false);

Assert.IsNotNull(result);
Assert.IsNotNull(result.AccessToken);
Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource);
}
}

[DataTestMethod]
[DataRow(AppServiceEndpoint, Resource, ManagedIdentitySource.AppService)]
[DataRow(ImdsEndpoint, Resource, ManagedIdentitySource.Imds)]
[DataRow(AzureArcEndpoint, Resource, ManagedIdentitySource.AzureArc)]
[DataRow(CloudShellEndpoint, Resource, ManagedIdentitySource.CloudShell)]
[DataRow(ServiceFabricEndpoint, Resource, ManagedIdentitySource.ServiceFabric)]
public async Task ManagedIdentityWithClaimsTestAsync(
string endpoint,
string scope,
ManagedIdentitySource managedIdentitySource)
{
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager(isManagedIdentity: true))
{
SetEnvironmentVariables(managedIdentitySource, endpoint);

var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned)
.WithExperimentalFeatures(true)
.WithHttpManager(httpManager);

// Disabling shared cache options to avoid cross test pollution.
miBuilder.Config.AccessorOptions = null;

var mi = miBuilder.Build();

httpManager.AddManagedIdentityMockHandler(
endpoint,
Resource,
MockHelpers.GetMsiSuccessfulResponse(),
managedIdentitySource);

var result = await mi.AcquireTokenForManagedIdentity(scope).ExecuteAsync().ConfigureAwait(false);

Assert.IsNotNull(result);
Assert.IsNotNull(result.AccessToken);
Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource);

// Acquire token from cache
result = await mi.AcquireTokenForManagedIdentity(scope)
.ExecuteAsync().ConfigureAwait(false);

Assert.IsNotNull(result);
Assert.IsNotNull(result.AccessToken);
Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource);

httpManager.AddManagedIdentityMockHandler(
endpoint,
scope,
MockHelpers.GetMsiSuccessfulResponse(),
managedIdentitySource);

// Acquire token with force refresh
result = await mi.AcquireTokenForManagedIdentity(scope).WithClaims(TestConstants.Claims)
.ExecuteAsync().ConfigureAwait(false);

Assert.IsNotNull(result);
Assert.IsNotNull(result.AccessToken);
Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource);
}
}

[DataTestMethod]
[DataRow("user.read", ManagedIdentitySource.AppService, AppServiceEndpoint)]
[DataRow("https://management.core.windows.net//user_impersonation", ManagedIdentitySource.AppService, AppServiceEndpoint)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1856,6 +1856,7 @@ public void AssertionInputIsMutable()
options.ClientID = "clientid";
options.TokenEndpoint = "https://login.microsoft.com/v2.0/token";
options.CancellationToken = CancellationToken.None;
options.Claims = TestConstants.Claims;
}
}
}

0 comments on commit 83725aa

Please sign in to comment.