Skip to content

Commit

Permalink
Merge pull request #598 from AzureAD/SJAIN/add-2s-timeout-to-IMDS-call
Browse files Browse the repository at this point in the history
add 2 seconds timeout while calling IMDS
  • Loading branch information
siddhijain authored Feb 23, 2023
2 parents 92eace8 + d6ac699 commit 290b543
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import labapi.AzureEnvironment;
import org.testng.Assert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import java.io.IOException;
Expand Down Expand Up @@ -118,13 +119,18 @@ public void acquireTokenClientCredentials_DefaultCacheLookup() throws Exception
Assert.assertNotEquals(result2.accessToken(), result3.accessToken());
}

@Test
public void acquireTokenClientCredentials_Regional() throws Exception {
@DataProvider(name = "regionWithAuthority")
public static Object[][] createData() {
return new Object[][]{{"westus", TestConstants.REGIONAL_MICROSOFT_AUTHORITY_BASIC_HOST_WESTUS},
{"eastus", TestConstants.REGIONAL_MICROSOFT_AUTHORITY_BASIC_HOST_EASTUS}};
}

@Test(dataProvider = "regionWithAuthority")
public void acquireTokenClientCredentials_Regional(String[] regionWithAuthority) throws Exception {
String clientId = "2afb0add-2f32-4946-ac90-81a02aa4550e";

assertAcquireTokenCommon_withRegion(clientId, certificate);
assertAcquireTokenCommon_withRegion(clientId, certificate, regionWithAuthority[0], regionWithAuthority[1]);
}

private ClientAssertion getClientAssertion(String clientId) {
return JwtHelper.buildJwt(
clientId,
Expand Down Expand Up @@ -164,15 +170,15 @@ private void assertAcquireTokenCommon_withParameters(String clientId, IClientCre
Assert.assertNotNull(result.accessToken());
}

private void assertAcquireTokenCommon_withRegion(String clientId, IClientCredential credential) throws Exception {
private void assertAcquireTokenCommon_withRegion(String clientId, IClientCredential credential, String region, String regionalAuthority) throws Exception {
ConfidentialClientApplication ccaNoRegion = ConfidentialClientApplication.builder(
clientId, credential).
authority(TestConstants.MICROSOFT_AUTHORITY).
build();

ConfidentialClientApplication ccaRegion = ConfidentialClientApplication.builder(
clientId, credential).
authority("https://login.microsoft.com/microsoft.onmicrosoft.com").azureRegion("westus").
authority("https://login.microsoft.com/microsoft.onmicrosoft.com").azureRegion(region).
build();

//Ensure behavior when region not specified
Expand All @@ -193,7 +199,7 @@ private void assertAcquireTokenCommon_withRegion(String clientId, IClientCredent

Assert.assertNotNull(resultRegion);
Assert.assertNotNull(resultRegion.accessToken());
Assert.assertEquals(resultRegion.environment(), TestConstants.REGIONAL_MICROSOFT_AUTHORITY_BASIC_HOST_WESTUS);
Assert.assertEquals(resultRegion.environment(), regionalAuthority);

IAuthenticationResult resultRegionCached = ccaRegion.acquireToken(ClientCredentialParameters
.builder(Collections.singleton(KEYVAULT_DEFAULT_SCOPE))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,15 @@ public class TestConstants {
public final static String TENANT_SPECIFIC_AUTHORITY = MICROSOFT_AUTHORITY_HOST + MICROSOFT_AUTHORITY_TENANT;
public final static String REGIONAL_MICROSOFT_AUTHORITY_BASIC_HOST_WESTUS = "westus.login.microsoft.com";

public final static String REGIONAL_MICROSOFT_AUTHORITY_BASIC_HOST_EASTUS = "eastus.login.microsoft.com";

public final static String ARLINGTON_ORGANIZATIONS_AUTHORITY = ARLINGTON_MICROSOFT_AUTHORITY_HOST + "organizations/";
public final static String ARLINGTON_COMMON_AUTHORITY = ARLINGTON_MICROSOFT_AUTHORITY_HOST + "common/";
public final static String ARLINGTON_TENANT_SPECIFIC_AUTHORITY = ARLINGTON_MICROSOFT_AUTHORITY_HOST + ARLINGTON_AUTHORITY_TENANT;
public final static String ARLINGTON_GRAPH_DEFAULT_SCOPE = "https://graph.microsoft.us/.default";


public final static String B2C_AUTHORITY = "https://msidlabb2c.b2clogin.com/msidlabb2c.onmicrosoft.com/";
public final static String B2C_AUTHORITY_LEGACY_FORMAT = "https://msidlabb2c.b2clogin.com/tfp/msidlabb2c.onmicrosoft.com/";

public final static String B2C_ROPC_POLICY = "B2C_1_ROPC_Auth";
public final static String B2C_SIGN_IN_POLICY = "B2C_1_SignInPolicy";
public final static String B2C_AUTHORITY_SIGN_IN = B2C_AUTHORITY + B2C_SIGN_IN_POLICY;
Expand All @@ -49,19 +50,13 @@ public class TestConstants {
public final static String B2C_MICROSOFTLOGIN_ROPC = B2C_MICROSOFTLOGIN_AUTHORITY + B2C_ROPC_POLICY;

public final static String LOCALHOST = "http://localhost:";
public final static String LOCAL_FLAG_ENV_VAR = "MSAL_JAVA_RUN_LOCAL";

public final static String ADFS_AUTHORITY = "https://fs.msidlab8.com/adfs/";
public final static String ADFS_SCOPE = USER_READ_SCOPE;
public final static String ADFS_APP_ID = "PublicClientId";

public final static String CLAIMS = "{\"id_token\":{\"auth_time\":{\"essential\":true}}}";
public final static Set<String> CLIENT_CAPABILITIES_EMPTY = new HashSet<>(Collections.emptySet());
public final static Set<String> CLIENT_CAPABILITIES_LLT = new HashSet<>(Collections.singletonList("llt"));

// cross cloud b2b settings
public final static String AUTHORITY_ARLINGTON = "https://login.microsoftonline.us/" + ARLINGTON_AUTHORITY_TENANT;
public final static String AUTHORITY_MOONCAKE = "https://login.chinacloudapi.cn/mncmsidlab1.partner.onmschina.cn";
public final static String AUTHORITY_PUBLIC_TENANT_SPECIFIC = "https://login.microsoftonline.com/" + MICROSOFT_AUTHORITY_TENANT;

public final static String DEFAULT_ACCESS_TOKEN = "defaultAccessToken";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import java.util.TreeSet;
import java.util.Map;
import java.util.HashMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.*;

class AadInstanceDiscoveryProvider {

Expand All @@ -31,6 +31,8 @@ class AadInstanceDiscoveryProvider {
private static final String DEFAULT_API_VERSION = "2020-06-01";
private static final String IMDS_ENDPOINT = "https://169.254.169.254/metadata/instance/compute/location?" + DEFAULT_API_VERSION + "&format=text";

private static final int IMDS_TIMEOUT = 2;
private static final TimeUnit IMDS_TIMEOUT_UNIT = TimeUnit.SECONDS;
static final TreeSet<String> TRUSTED_HOSTS_SET = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
static final TreeSet<String> TRUSTED_SOVEREIGN_HOSTS_SET = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);

Expand Down Expand Up @@ -71,8 +73,8 @@ static InstanceDiscoveryMetadataEntry getMetadataEntry(URL authorityUrl,
//If region autodetection is enabled and a specific region not already set,
// set the application's region to the discovered region so that future requests can skip the IMDS endpoint call
if (null == msalRequest.application().azureRegion() && msalRequest.application().autoDetectRegion()
&& null != detectedRegion) {
msalRequest.application().azureRegion = detectedRegion;
&& null != detectedRegion) {
msalRequest.application().azureRegion = detectedRegion;
}
cacheRegionInstanceMetadata(authorityUrl.getHost(), msalRequest.application().azureRegion());
serviceBundle.getServerSideTelemetry().getCurrentRequest().regionOutcome(
Expand Down Expand Up @@ -291,33 +293,39 @@ private static String discoverRegion(MsalRequest msalRequest, ServiceBundle serv
return System.getenv(REGION_NAME);
}

try {
//Check the IMDS endpoint to retrieve current region (will only work if application is running in an Azure VM)
Map<String, String> headers = new HashMap<>();
headers.put("Metadata", "true");
IHttpResponse httpResponse = executeRequest(IMDS_ENDPOINT, headers, msalRequest, serviceBundle);
//Check the IMDS endpoint to retrieve current region (will only work if application is running in an Azure VM)
Map<String, String> headers = new HashMap<>();
headers.put("Metadata", "true");

ExecutorService executor = Executors.newSingleThreadExecutor();
Future<IHttpResponse> future = executor.submit(() -> executeRequest(IMDS_ENDPOINT, headers, msalRequest, serviceBundle));

try {
log.info("Starting call to IMDS endpoint.");
IHttpResponse httpResponse = future.get(IMDS_TIMEOUT, IMDS_TIMEOUT_UNIT);
//If call to IMDS endpoint was successful, return region from response body
if (httpResponse.statusCode() == HttpHelper.HTTP_STATUS_200 && !httpResponse.body().isEmpty()) {
log.info("Region retrieved from IMDS endpoint: " + httpResponse.body());
log.info(String.format("Region retrieved from IMDS endpoint: %s", httpResponse.body()));
currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_IMDS.telemetryValue);

return httpResponse.body();
}

log.warn(String.format("Call to local IMDS failed with status code: %s, or response was empty", httpResponse.statusCode()));
currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_FAILED_AUTODETECT.telemetryValue);

return null;
} catch (Exception e) {
} catch (Exception ex) {
// handle other exceptions
//IMDS call failed, cannot find region
//The IMDS endpoint is only available from within an Azure environment, so the most common cause of this
// exception will likely be java.net.SocketException: Network is unreachable: connect
log.warn(String.format("Exception during call to local IMDS endpoint: %s", e.getMessage()));
log.warn(String.format("Exception during call to local IMDS endpoint: %s", ex.getMessage()));
currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_FAILED_AUTODETECT.telemetryValue);
future.cancel(true);

return null;
} finally {
executor.shutdownNow();
}

return null;
}

private static void doInstanceDiscoveryAndCache(URL authorityUrl,
Expand Down

0 comments on commit 290b543

Please sign in to comment.