diff --git a/msal4j-sdk/src/integrationtest/java/com.microsoft.aad.msal4j/ClientCredentialsIT.java b/msal4j-sdk/src/integrationtest/java/com.microsoft.aad.msal4j/ClientCredentialsIT.java index 8c1f5256..1e9b9ceb 100644 --- a/msal4j-sdk/src/integrationtest/java/com.microsoft.aad.msal4j/ClientCredentialsIT.java +++ b/msal4j-sdk/src/integrationtest/java/com.microsoft.aad.msal4j/ClientCredentialsIT.java @@ -7,6 +7,7 @@ import labapi.AzureEnvironment; import org.testng.Assert; import org.testng.annotations.BeforeClass; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.io.IOException; @@ -118,13 +119,18 @@ public void acquireTokenClientCredentials_DefaultCacheLookup() throws Exception Assert.assertNotEquals(result2.accessToken(), result3.accessToken()); } - @Test - public void acquireTokenClientCredentials_Regional() throws Exception { + @DataProvider(name = "regionWithAuthority") + public static Object[][] createData() { + return new Object[][]{{"westus", TestConstants.REGIONAL_MICROSOFT_AUTHORITY_BASIC_HOST_WESTUS}, + {"eastus", TestConstants.REGIONAL_MICROSOFT_AUTHORITY_BASIC_HOST_EASTUS}}; + } + + @Test(dataProvider = "regionWithAuthority") + public void acquireTokenClientCredentials_Regional(String[] regionWithAuthority) throws Exception { String clientId = "2afb0add-2f32-4946-ac90-81a02aa4550e"; - assertAcquireTokenCommon_withRegion(clientId, certificate); + assertAcquireTokenCommon_withRegion(clientId, certificate, regionWithAuthority[0], regionWithAuthority[1]); } - private ClientAssertion getClientAssertion(String clientId) { return JwtHelper.buildJwt( clientId, @@ -164,7 +170,7 @@ private void assertAcquireTokenCommon_withParameters(String clientId, IClientCre Assert.assertNotNull(result.accessToken()); } - private void assertAcquireTokenCommon_withRegion(String clientId, IClientCredential credential) throws Exception { + private void assertAcquireTokenCommon_withRegion(String clientId, IClientCredential credential, String region, String regionalAuthority) throws Exception { ConfidentialClientApplication ccaNoRegion = ConfidentialClientApplication.builder( clientId, credential). authority(TestConstants.MICROSOFT_AUTHORITY). @@ -172,7 +178,7 @@ private void assertAcquireTokenCommon_withRegion(String clientId, IClientCredent ConfidentialClientApplication ccaRegion = ConfidentialClientApplication.builder( clientId, credential). - authority("https://login.microsoft.com/microsoft.onmicrosoft.com").azureRegion("westus"). + authority("https://login.microsoft.com/microsoft.onmicrosoft.com").azureRegion(region). build(); //Ensure behavior when region not specified @@ -193,7 +199,7 @@ private void assertAcquireTokenCommon_withRegion(String clientId, IClientCredent Assert.assertNotNull(resultRegion); Assert.assertNotNull(resultRegion.accessToken()); - Assert.assertEquals(resultRegion.environment(), TestConstants.REGIONAL_MICROSOFT_AUTHORITY_BASIC_HOST_WESTUS); + Assert.assertEquals(resultRegion.environment(), regionalAuthority); IAuthenticationResult resultRegionCached = ccaRegion.acquireToken(ClientCredentialParameters .builder(Collections.singleton(KEYVAULT_DEFAULT_SCOPE)) diff --git a/msal4j-sdk/src/integrationtest/java/com.microsoft.aad.msal4j/TestConstants.java b/msal4j-sdk/src/integrationtest/java/com.microsoft.aad.msal4j/TestConstants.java index bd81b076..b9603d52 100644 --- a/msal4j-sdk/src/integrationtest/java/com.microsoft.aad.msal4j/TestConstants.java +++ b/msal4j-sdk/src/integrationtest/java/com.microsoft.aad.msal4j/TestConstants.java @@ -32,14 +32,15 @@ public class TestConstants { public final static String TENANT_SPECIFIC_AUTHORITY = MICROSOFT_AUTHORITY_HOST + MICROSOFT_AUTHORITY_TENANT; public final static String REGIONAL_MICROSOFT_AUTHORITY_BASIC_HOST_WESTUS = "westus.login.microsoft.com"; + public final static String REGIONAL_MICROSOFT_AUTHORITY_BASIC_HOST_EASTUS = "eastus.login.microsoft.com"; + public final static String ARLINGTON_ORGANIZATIONS_AUTHORITY = ARLINGTON_MICROSOFT_AUTHORITY_HOST + "organizations/"; - public final static String ARLINGTON_COMMON_AUTHORITY = ARLINGTON_MICROSOFT_AUTHORITY_HOST + "common/"; public final static String ARLINGTON_TENANT_SPECIFIC_AUTHORITY = ARLINGTON_MICROSOFT_AUTHORITY_HOST + ARLINGTON_AUTHORITY_TENANT; public final static String ARLINGTON_GRAPH_DEFAULT_SCOPE = "https://graph.microsoft.us/.default"; - public final static String B2C_AUTHORITY = "https://msidlabb2c.b2clogin.com/msidlabb2c.onmicrosoft.com/"; public final static String B2C_AUTHORITY_LEGACY_FORMAT = "https://msidlabb2c.b2clogin.com/tfp/msidlabb2c.onmicrosoft.com/"; + public final static String B2C_ROPC_POLICY = "B2C_1_ROPC_Auth"; public final static String B2C_SIGN_IN_POLICY = "B2C_1_SignInPolicy"; public final static String B2C_AUTHORITY_SIGN_IN = B2C_AUTHORITY + B2C_SIGN_IN_POLICY; @@ -49,7 +50,6 @@ public class TestConstants { public final static String B2C_MICROSOFTLOGIN_ROPC = B2C_MICROSOFTLOGIN_AUTHORITY + B2C_ROPC_POLICY; public final static String LOCALHOST = "http://localhost:"; - public final static String LOCAL_FLAG_ENV_VAR = "MSAL_JAVA_RUN_LOCAL"; public final static String ADFS_AUTHORITY = "https://fs.msidlab8.com/adfs/"; public final static String ADFS_SCOPE = USER_READ_SCOPE; @@ -57,11 +57,6 @@ public class TestConstants { public final static String CLAIMS = "{\"id_token\":{\"auth_time\":{\"essential\":true}}}"; public final static Set CLIENT_CAPABILITIES_EMPTY = new HashSet<>(Collections.emptySet()); - public final static Set CLIENT_CAPABILITIES_LLT = new HashSet<>(Collections.singletonList("llt")); - - // cross cloud b2b settings - public final static String AUTHORITY_ARLINGTON = "https://login.microsoftonline.us/" + ARLINGTON_AUTHORITY_TENANT; - public final static String AUTHORITY_MOONCAKE = "https://login.chinacloudapi.cn/mncmsidlab1.partner.onmschina.cn"; public final static String AUTHORITY_PUBLIC_TENANT_SPECIFIC = "https://login.microsoftonline.com/" + MICROSOFT_AUTHORITY_TENANT; public final static String DEFAULT_ACCESS_TOKEN = "defaultAccessToken"; diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AadInstanceDiscoveryProvider.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AadInstanceDiscoveryProvider.java index c548fb73..b4d61b27 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AadInstanceDiscoveryProvider.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AadInstanceDiscoveryProvider.java @@ -14,7 +14,7 @@ import java.util.TreeSet; import java.util.Map; import java.util.HashMap; -import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.*; class AadInstanceDiscoveryProvider { @@ -31,6 +31,8 @@ class AadInstanceDiscoveryProvider { private static final String DEFAULT_API_VERSION = "2020-06-01"; private static final String IMDS_ENDPOINT = "https://169.254.169.254/metadata/instance/compute/location?" + DEFAULT_API_VERSION + "&format=text"; + private static final int IMDS_TIMEOUT = 2; + private static final TimeUnit IMDS_TIMEOUT_UNIT = TimeUnit.SECONDS; static final TreeSet TRUSTED_HOSTS_SET = new TreeSet<>(String.CASE_INSENSITIVE_ORDER); static final TreeSet TRUSTED_SOVEREIGN_HOSTS_SET = new TreeSet<>(String.CASE_INSENSITIVE_ORDER); @@ -71,8 +73,8 @@ static InstanceDiscoveryMetadataEntry getMetadataEntry(URL authorityUrl, //If region autodetection is enabled and a specific region not already set, // set the application's region to the discovered region so that future requests can skip the IMDS endpoint call if (null == msalRequest.application().azureRegion() && msalRequest.application().autoDetectRegion() - && null != detectedRegion) { - msalRequest.application().azureRegion = detectedRegion; + && null != detectedRegion) { + msalRequest.application().azureRegion = detectedRegion; } cacheRegionInstanceMetadata(authorityUrl.getHost(), msalRequest.application().azureRegion()); serviceBundle.getServerSideTelemetry().getCurrentRequest().regionOutcome( @@ -291,33 +293,39 @@ private static String discoverRegion(MsalRequest msalRequest, ServiceBundle serv return System.getenv(REGION_NAME); } - try { - //Check the IMDS endpoint to retrieve current region (will only work if application is running in an Azure VM) - Map headers = new HashMap<>(); - headers.put("Metadata", "true"); - IHttpResponse httpResponse = executeRequest(IMDS_ENDPOINT, headers, msalRequest, serviceBundle); + //Check the IMDS endpoint to retrieve current region (will only work if application is running in an Azure VM) + Map headers = new HashMap<>(); + headers.put("Metadata", "true"); + + ExecutorService executor = Executors.newSingleThreadExecutor(); + Future future = executor.submit(() -> executeRequest(IMDS_ENDPOINT, headers, msalRequest, serviceBundle)); + try { + log.info("Starting call to IMDS endpoint."); + IHttpResponse httpResponse = future.get(IMDS_TIMEOUT, IMDS_TIMEOUT_UNIT); //If call to IMDS endpoint was successful, return region from response body if (httpResponse.statusCode() == HttpHelper.HTTP_STATUS_200 && !httpResponse.body().isEmpty()) { - log.info("Region retrieved from IMDS endpoint: " + httpResponse.body()); + log.info(String.format("Region retrieved from IMDS endpoint: %s", httpResponse.body())); currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_IMDS.telemetryValue); return httpResponse.body(); } - log.warn(String.format("Call to local IMDS failed with status code: %s, or response was empty", httpResponse.statusCode())); currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_FAILED_AUTODETECT.telemetryValue); - - return null; - } catch (Exception e) { + } catch (Exception ex) { + // handle other exceptions //IMDS call failed, cannot find region //The IMDS endpoint is only available from within an Azure environment, so the most common cause of this // exception will likely be java.net.SocketException: Network is unreachable: connect - log.warn(String.format("Exception during call to local IMDS endpoint: %s", e.getMessage())); + log.warn(String.format("Exception during call to local IMDS endpoint: %s", ex.getMessage())); currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_FAILED_AUTODETECT.telemetryValue); + future.cancel(true); - return null; + } finally { + executor.shutdownNow(); } + + return null; } private static void doInstanceDiscoveryAndCache(URL authorityUrl,