Skip to content

Commit

Permalink
Merge pull request #569 from AzureAD/SJAIN/instance-discovery-endpoint
Browse files Browse the repository at this point in the history
expose instanceDiscovery flag
  • Loading branch information
siddhijain authored Jan 25, 2023
2 parents 1177c5b + dee401f commit 0b81ab6
Show file tree
Hide file tree
Showing 4 changed files with 272 additions and 9 deletions.
1 change: 1 addition & 0 deletions msal4j-sdk/changelog.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ Version 1.13.4
=============
- regional endpoint updates
- fixed manifest
- Expose instance discovery flag to perform instance discovery.

Version 1.13.3
=============
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
package com.microsoft.aad.msal4j;

import org.easymock.Capture;
import org.easymock.EasyMock;
import org.powermock.api.easymock.PowerMock;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.testng.Assert;
import org.testng.IObjectFactory;
import org.testng.annotations.DataProvider;
import org.testng.annotations.ObjectFactory;
import org.testng.annotations.Test;

import java.net.URI;
import java.util.Collections;
import java.util.Date;
import java.util.concurrent.CompletableFuture;

@PrepareForTest({HttpHelper.class, PublicClientApplication.class})
public class InstanceDiscoveryTest {

private PublicClientApplication app;

@ObjectFactory
public IObjectFactory getObjectFactory() {
return new org.powermock.modules.testng.PowerMockObjectFactory();
}

@DataProvider(name = "aadClouds")
private static Object[][] getAadClouds(){
return new Object[][] {{"https://login.microsoftonline.com/common"} , // #Known to Microsoft
{"https://private.cloud/foo"}//Private Cloud
};
}

/**
* when instance_discovery flag is set to true (by default), an instance_discovery is performed for authorityType = AAD
*/
@Test( dataProvider = "aadClouds")
public void aadInstanceDiscoveryTrue(String authority) throws Exception{
app = PowerMock.createPartialMock(PublicClientApplication.class,
new String[]{"acquireTokenCommon"},
PublicClientApplication.builder(TestConfiguration.AAD_CLIENT_ID)
.authority(authority));

Capture<MsalRequest> capturedMsalRequest = Capture.newInstance();

PowerMock.expectPrivate(app, "acquireTokenCommon",
EasyMock.capture(capturedMsalRequest), EasyMock.isA(AADAuthority.class)).andReturn(
AuthenticationResult.builder().
accessToken("accessToken").
expiresOn(new Date().getTime() + 100).
refreshToken("refreshToken").
idToken("idToken").environment("environment").build());

PowerMock.mockStatic(HttpHelper.class);

HttpResponse instanceDiscoveryResponse = new HttpResponse();
instanceDiscoveryResponse.statusCode(200);
instanceDiscoveryResponse.body(TestConfiguration.INSTANCE_DISCOVERY_RESPONSE);

Capture<HttpRequest> capturedHttpRequest = Capture.newInstance();

EasyMock.expect(
HttpHelper.executeHttpRequest(
EasyMock.capture(capturedHttpRequest),
EasyMock.isA(RequestContext.class),
EasyMock.isA(ServiceBundle.class)))
.andReturn(instanceDiscoveryResponse);

PowerMock.replay(HttpHelper.class, HttpResponse.class);

CompletableFuture<IAuthenticationResult> completableFuture = app.acquireToken(
AuthorizationCodeParameters.builder
("auth_code",
new URI(TestConfiguration.AAD_DEFAULT_REDIRECT_URI))
.scopes(Collections.singleton("default-scope"))
.build());

completableFuture.get();
Assert.assertEquals(capturedHttpRequest.getValues().size(),1);

}

/**
* when instance_discovery flag is set to false, instance_discovery is not performed
*/
@Test (dataProvider = "aadClouds")
public void aadInstanceDiscoveryFalse(String authority) throws Exception {

app = PowerMock.createPartialMock(PublicClientApplication.class,
new String[]{"acquireTokenCommon"},
PublicClientApplication.builder(TestConfiguration.AAD_CLIENT_ID)
.authority(authority)
.instanceDiscovery(false));

Capture<MsalRequest> capturedMsalRequest = Capture.newInstance();

PowerMock.expectPrivate(app, "acquireTokenCommon",
EasyMock.capture(capturedMsalRequest), EasyMock.isA(AADAuthority.class)).andReturn(
AuthenticationResult.builder().
accessToken("accessToken").
expiresOn(new Date().getTime() + 100).
refreshToken("refreshToken").
idToken("idToken").environment("environment").build());

PowerMock.mockStatic(HttpHelper.class);

HttpResponse instanceDiscoveryResponse = new HttpResponse();
instanceDiscoveryResponse.statusCode(200);
instanceDiscoveryResponse.body(TestConfiguration.INSTANCE_DISCOVERY_RESPONSE);

Capture<HttpRequest> capturedHttpRequest = Capture.newInstance();

EasyMock.expect(
HttpHelper.executeHttpRequest(
EasyMock.capture(capturedHttpRequest),
EasyMock.isA(RequestContext.class),
EasyMock.isA(ServiceBundle.class)))
.andReturn(instanceDiscoveryResponse);

PowerMock.replay(HttpHelper.class, HttpResponse.class);

CompletableFuture<IAuthenticationResult> completableFuture = app.acquireToken(
AuthorizationCodeParameters.builder
("auth_code",
new URI(TestConfiguration.AAD_DEFAULT_REDIRECT_URI))
.scopes(Collections.singleton("default-scope"))
.build());

completableFuture.get();
Assert.assertEquals(capturedHttpRequest.getValues().size(),0);
}

/**
* when instance_discovery flag is set to true (by default), an instance_discovery is NOT performed for adfs.
*/
@Test
public void adfsInstanceDiscoveryTrue() throws Exception{
app = PowerMock.createPartialMock(PublicClientApplication.class,
new String[]{"acquireTokenCommon"},
PublicClientApplication.builder(TestConstants.ADFS_APP_ID)
.authority("https://contoso.com/adfs")
.instanceDiscovery(true));

Capture<MsalRequest> capturedMsalRequest = Capture.newInstance();

PowerMock.expectPrivate(app, "acquireTokenCommon",
EasyMock.capture(capturedMsalRequest), EasyMock.isA(AADAuthority.class)).andReturn(
AuthenticationResult.builder().
accessToken("accessToken").
expiresOn(new Date().getTime() + 100).
refreshToken("refreshToken").
idToken("idToken").environment("environment").build());

PowerMock.mockStatic(HttpHelper.class);

HttpResponse instanceDiscoveryResponse = new HttpResponse();
instanceDiscoveryResponse.statusCode(200);
instanceDiscoveryResponse.body(TestConfiguration.INSTANCE_DISCOVERY_RESPONSE);

Capture<HttpRequest> capturedHttpRequest = Capture.newInstance();

EasyMock.expect(
HttpHelper.executeHttpRequest(
EasyMock.capture(capturedHttpRequest),
EasyMock.isA(RequestContext.class),
EasyMock.isA(ServiceBundle.class)))
.andReturn(instanceDiscoveryResponse);

PowerMock.replay(HttpHelper.class, HttpResponse.class);

CompletableFuture<IAuthenticationResult> completableFuture = app.acquireToken(
AuthorizationCodeParameters.builder
("auth_code",
new URI(TestConfiguration.AAD_DEFAULT_REDIRECT_URI))
.scopes(Collections.singleton("default-scope"))
.build());

completableFuture.get();
Assert.assertEquals(capturedHttpRequest.getValues().size(),0);

}

/**
* when instance_discovery flag is set to true (by default), an instance_discovery is NOT performed for b2c.
*/
@Test
public void b2cInstanceDiscoveryTrue() throws Exception{
app = PowerMock.createPartialMock(PublicClientApplication.class,
new String[]{"acquireTokenCommon"},
PublicClientApplication.builder(TestConstants.ADFS_APP_ID)
.b2cAuthority(TestConstants.B2C_MICROSOFTLOGIN_ROPC)
.instanceDiscovery(true));

Capture<MsalRequest> capturedMsalRequest = Capture.newInstance();

PowerMock.expectPrivate(app, "acquireTokenCommon",
EasyMock.capture(capturedMsalRequest), EasyMock.isA(AADAuthority.class)).andReturn(
AuthenticationResult.builder().
accessToken("accessToken").
expiresOn(new Date().getTime() + 100).
refreshToken("refreshToken").
idToken("idToken").environment("environment").build());

PowerMock.mockStatic(HttpHelper.class);

HttpResponse instanceDiscoveryResponse = new HttpResponse();
instanceDiscoveryResponse.statusCode(200);
instanceDiscoveryResponse.body(TestConfiguration.INSTANCE_DISCOVERY_RESPONSE);

Capture<HttpRequest> capturedHttpRequest = Capture.newInstance();

EasyMock.expect(
HttpHelper.executeHttpRequest(
EasyMock.capture(capturedHttpRequest),
EasyMock.isA(RequestContext.class),
EasyMock.isA(ServiceBundle.class)))
.andReturn(instanceDiscoveryResponse);

PowerMock.replay(HttpHelper.class, HttpResponse.class);

CompletableFuture<IAuthenticationResult> completableFuture = app.acquireToken(
AuthorizationCodeParameters.builder
("auth_code",
new URI(TestConfiguration.AAD_DEFAULT_REDIRECT_URI))
.scopes(Collections.singleton("default-scope"))
.build());

completableFuture.get();
Assert.assertEquals(capturedHttpRequest.getValues().size(),0);

}


}
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@ class AadInstanceDiscoveryProvider {
private final static String SOVEREIGN_HOST_TEMPLATE_WITH_REGION = "{region}.{host}";
private final static String REGION_NAME = "REGION_NAME";
private final static int PORT_NOT_SET = -1;

// For information of the current api-version refer: https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service#versioning
private final static String DEFAULT_API_VERSION = "2020-06-01";
private final static String IMDS_ENDPOINT = "https://169.254.169.254/metadata/instance/compute/location?" + DEFAULT_API_VERSION + "&format=text";
private static final String DEFAULT_API_VERSION = "2020-06-01";
private static final String IMDS_ENDPOINT = "https://169.254.169.254/metadata/instance/compute/location?" + DEFAULT_API_VERSION + "&format=text";

final static TreeSet<String> TRUSTED_HOSTS_SET = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
final static TreeSet<String> TRUSTED_SOVEREIGN_HOSTS_SET = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
static final TreeSet<String> TRUSTED_HOSTS_SET = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
static final TreeSet<String> TRUSTED_SOVEREIGN_HOSTS_SET = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);

private final static Logger log = LoggerFactory.getLogger(HttpHelper.class);
private static final Logger log = LoggerFactory.getLogger(AadInstanceDiscoveryProvider.class);

static ConcurrentHashMap<String, InstanceDiscoveryMetadataEntry> cache = new ConcurrentHashMap<>();

Expand Down Expand Up @@ -67,10 +68,9 @@ static InstanceDiscoveryMetadataEntry getMetadataEntry(URL authorityUrl,

//If region autodetection is enabled and a specific region not already set,
// set the application's region to the discovered region so that future requests can skip the IMDS endpoint call
if (msalRequest.application().azureRegion() == null && msalRequest.application().autoDetectRegion()) {
if (detectedRegion != null) {
if (null == msalRequest.application().azureRegion() && msalRequest.application().autoDetectRegion()
&& null != detectedRegion) {
msalRequest.application().azureRegion = detectedRegion;
}
}
cacheRegionInstanceMetadata(authorityUrl.getHost(), msalRequest.application().azureRegion());
serviceBundle.getServerSideTelemetry().getCurrentRequest().regionOutcome(
Expand All @@ -80,7 +80,16 @@ static InstanceDiscoveryMetadataEntry getMetadataEntry(URL authorityUrl,
InstanceDiscoveryMetadataEntry result = cache.get(host);

if (result == null) {
doInstanceDiscoveryAndCache(authorityUrl, validateAuthority, msalRequest, serviceBundle);
if(msalRequest.application().instanceDiscovery()){
doInstanceDiscoveryAndCache(authorityUrl, validateAuthority, msalRequest, serviceBundle);
} else {
// instanceDiscovery flag is set to False. Do not perform instanceDiscovery.
return InstanceDiscoveryMetadataEntry.builder().
preferredCache(host).
preferredNetwork(host).
aliases(Collections.singleton(host)).
build();
}
}

return cache.get(host);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ public abstract class AbstractClientApplicationBase implements IClientApplicatio
@Getter
protected String azureRegion;

@Accessors(fluent = true)
@Getter
private boolean instanceDiscovery;

@Override
public CompletableFuture<IAuthenticationResult> acquireToken(AuthorizationCodeParameters parameters) {

Expand Down Expand Up @@ -325,6 +329,7 @@ public abstract static class Builder<T extends Builder<T>> {
private String azureRegion;
private Integer connectTimeoutForDefaultHttpClient;
private Integer readTimeoutForDefaultHttpClient;
private boolean instanceDiscovery = true;

/**
* Constructor to create instance of Builder of client application
Expand Down Expand Up @@ -643,6 +648,18 @@ public T azureRegion(String val) {
return self();
}

/** Historically, MSAL would connect to a central endpoint located at
``https://login.microsoftonline.com`` to acquire some metadata, especially when using an unfamiliar authority.
This behavior is known as Instance Discovery.
This parameter defaults to true, which enables the Instance Discovery.
If you do not know some authorities beforehand,
yet still want MSAL to accept any authority that you will provide,
you can use a ``False`` to unconditionally disable Instance Discovery. */
public T instanceDiscovery(boolean val) {
instanceDiscovery = val;
return self();
}

abstract AbstractClientApplicationBase build();
}

Expand Down Expand Up @@ -671,6 +688,7 @@ public T azureRegion(String val) {
clientCapabilities = builder.clientCapabilities;
autoDetectRegion = builder.autoDetectRegion;
azureRegion = builder.azureRegion;
instanceDiscovery = builder.instanceDiscovery;

if (aadAadInstanceDiscoveryResponse != null) {
AadInstanceDiscoveryProvider.cacheInstanceDiscoveryMetadata(
Expand Down

0 comments on commit 0b81ab6

Please sign in to comment.