From 8c26f4ccaaa6d302d64e1d927344c4edefbfcb05 Mon Sep 17 00:00:00 2001 From: siddhijain Date: Wed, 22 Feb 2023 13:13:36 -0600 Subject: [PATCH] Fix failing tests --- .../msal4j/AadInstanceDiscoveryProvider.java | 113 +++++++++--------- 1 file changed, 59 insertions(+), 54 deletions(-) diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AadInstanceDiscoveryProvider.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AadInstanceDiscoveryProvider.java index ade8a784..6722787b 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AadInstanceDiscoveryProvider.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AadInstanceDiscoveryProvider.java @@ -62,27 +62,23 @@ static InstanceDiscoveryMetadataEntry getMetadataEntry(URL authorityUrl, ServiceBundle serviceBundle) { String host = authorityUrl.getHost(); - ExecutorService executor = Executors.newSingleThreadExecutor(); - - Future future = executor.submit(() -> performRegionalDiscovery(authorityUrl, msalRequest, serviceBundle)); + if (shouldUseRegionalEndpoint(msalRequest)) { + //Server side telemetry requires the result from region discovery when any part of the region API is used + String detectedRegion = discoverRegion(msalRequest, serviceBundle); - try { - log.info("Starting call to IMDS endpoint."); - host = future.get(IMDS_TIMEOUT, IMDS_TIMEOUT_UNIT); - } catch (TimeoutException ex) { - log.info("Cancelled call to IMDS endpoint after waiting for 2 seconds"); - future.cancel(true); if (msalRequest.application().azureRegion() != null) { host = getRegionalizedHost(authorityUrl.getHost(), msalRequest.application().azureRegion()); } - } catch (Exception ex) { - // handle other exceptions - log.info("Exception while calling IMDS endpoint" + ex.getMessage()); - if (msalRequest.application().azureRegion() != null) { - host = getRegionalizedHost(authorityUrl.getHost(), msalRequest.application().azureRegion()); + + //If region autodetection is enabled and a specific region not already set, + // set the application's region to the discovered region so that future requests can skip the IMDS endpoint call + if (null == msalRequest.application().azureRegion() && msalRequest.application().autoDetectRegion() + && null != detectedRegion) { + msalRequest.application().azureRegion = detectedRegion; } - } finally { - executor.shutdownNow(); + cacheRegionInstanceMetadata(authorityUrl.getHost(), msalRequest.application().azureRegion()); + serviceBundle.getServerSideTelemetry().getCurrentRequest().regionOutcome( + determineRegionOutcome(detectedRegion, msalRequest.application().azureRegion(), msalRequest.application().autoDetectRegion())); } InstanceDiscoveryMetadataEntry result = cache.get(host); @@ -103,32 +99,6 @@ static InstanceDiscoveryMetadataEntry getMetadataEntry(URL authorityUrl, return cache.get(host); } - private static String performRegionalDiscovery(URL authorityUrl, MsalRequest msalRequest, ServiceBundle serviceBundle){ - - String host = authorityUrl.getHost(); - - if (shouldUseRegionalEndpoint(msalRequest)) { - //Server side telemetry requires the result from region discovery when any part of the region API is used - String detectedRegion = discoverRegion(msalRequest, serviceBundle); - - if (msalRequest.application().azureRegion() != null) { - host = getRegionalizedHost(authorityUrl.getHost(), msalRequest.application().azureRegion()); - } - - //If region autodetection is enabled and a specific region not already set, - // set the application's region to the discovered region so that future requests can skip the IMDS endpoint call - if (null == msalRequest.application().azureRegion() && msalRequest.application().autoDetectRegion() - && null != detectedRegion) { - msalRequest.application().azureRegion = detectedRegion; - } - cacheRegionInstanceMetadata(host, authorityUrl.getHost()); - serviceBundle.getServerSideTelemetry().getCurrentRequest().regionOutcome( - determineRegionOutcome(detectedRegion, msalRequest.application().azureRegion(), msalRequest.application().autoDetectRegion())); - } - - return host; - } - static Set getAliases(String host) { if (cache.containsKey(host)) { return cache.get(host).aliases(); @@ -192,10 +162,11 @@ private static boolean shouldUseRegionalEndpoint(MsalRequest msalRequest){ return false; } - static void cacheRegionInstanceMetadata(String regionalHost, String host) { + static void cacheRegionInstanceMetadata(String host, String region) { Set aliases = new HashSet<>(); aliases.add(host); + String regionalHost = getRegionalizedHost(host, region); cache.putIfAbsent(regionalHost, InstanceDiscoveryMetadataEntry.builder(). preferredCache(host). @@ -322,12 +293,44 @@ private static String discoverRegion(MsalRequest msalRequest, ServiceBundle serv return System.getenv(REGION_NAME); } - try { - //Check the IMDS endpoint to retrieve current region (will only work if application is running in an Azure VM) - Map headers = new HashMap<>(); - headers.put("Metadata", "true"); - IHttpResponse httpResponse = executeRequest(IMDS_ENDPOINT, headers, msalRequest, serviceBundle); +// try { +// //Check the IMDS endpoint to retrieve current region (will only work if application is running in an Azure VM) +// Map headers = new HashMap<>(); +// headers.put("Metadata", "true"); +// IHttpResponse httpResponse = executeRequest(IMDS_ENDPOINT, headers, msalRequest, serviceBundle); +// +// //If call to IMDS endpoint was successful, return region from response body +// if (httpResponse.statusCode() == HttpHelper.HTTP_STATUS_200 && !httpResponse.body().isEmpty()) { +// log.info(String.format("Region retrieved from IMDS endpoint: %s", httpResponse.body())); +// currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_IMDS.telemetryValue); +// +// return httpResponse.body(); +// } +// +// log.warn(String.format("Call to local IMDS failed with status code: %s, or response was empty", httpResponse.statusCode())); +// currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_FAILED_AUTODETECT.telemetryValue); +// +// return null; +// } catch (Exception e) { +// //IMDS call failed, cannot find region +// //The IMDS endpoint is only available from within an Azure environment, so the most common cause of this +// // exception will likely be java.net.SocketException: Network is unreachable: connect +// log.warn(String.format("Exception during call to local IMDS endpoint: %s", e.getMessage())); +// currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_FAILED_AUTODETECT.telemetryValue); +// +// return null; +// } + + //Check the IMDS endpoint to retrieve current region (will only work if application is running in an Azure VM) + Map headers = new HashMap<>(); + headers.put("Metadata", "true"); + + ExecutorService executor = Executors.newSingleThreadExecutor(); + Future future = executor.submit(() -> executeRequest(IMDS_ENDPOINT, headers, msalRequest, serviceBundle)); + try { + log.info("Starting call to IMDS endpoint."); + IHttpResponse httpResponse = future.get(IMDS_TIMEOUT, IMDS_TIMEOUT_UNIT); //If call to IMDS endpoint was successful, return region from response body if (httpResponse.statusCode() == HttpHelper.HTTP_STATUS_200 && !httpResponse.body().isEmpty()) { log.info(String.format("Region retrieved from IMDS endpoint: %s", httpResponse.body())); @@ -335,20 +338,22 @@ private static String discoverRegion(MsalRequest msalRequest, ServiceBundle serv return httpResponse.body(); } - log.warn(String.format("Call to local IMDS failed with status code: %s, or response was empty", httpResponse.statusCode())); currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_FAILED_AUTODETECT.telemetryValue); - - return null; - } catch (Exception e) { + } catch (Exception ex) { + // handle other exceptions //IMDS call failed, cannot find region //The IMDS endpoint is only available from within an Azure environment, so the most common cause of this // exception will likely be java.net.SocketException: Network is unreachable: connect - log.warn(String.format("Exception during call to local IMDS endpoint: %s", e.getMessage())); + log.warn(String.format("Exception during call to local IMDS endpoint: %s", ex.getMessage())); currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_FAILED_AUTODETECT.telemetryValue); + future.cancel(true); - return null; + } finally { + executor.shutdownNow(); } + + return null; } private static void doInstanceDiscoveryAndCache(URL authorityUrl,