Skip to content

Commit

Permalink
Update the MsalRequest object flow
Browse files Browse the repository at this point in the history
  • Loading branch information
neha-bhargava committed Aug 9, 2023
1 parent 60bbf7c commit 6688149
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ abstract class AbstractManagedIdentitySource {
private static final String MANAGED_IDENTITY_NO_RESPONSE_RECEIVED = "[Managed Identity] Authentication unavailable. No response received from the managed identity endpoint.";
public static final String MANAGED_IDENTITY_REQUEST_FAILED = "managed_identity_request_failed";

protected final RequestContext requestContext;
protected final ManagedIdentityRequest managedIdentityRequest;
private ServiceBundle serviceBundle;
private ManagedIdentitySourceType managedIdentitySourceType;

Expand All @@ -37,22 +37,22 @@ abstract class AbstractManagedIdentitySource {
@Setter
private String managedIdentityUserAssignedResourceId;

public AbstractManagedIdentitySource(RequestContext requestContext, ServiceBundle serviceBundle,
public AbstractManagedIdentitySource(MsalRequest msalRequest, ServiceBundle serviceBundle,
ManagedIdentitySourceType sourceType) {
this.requestContext = requestContext;
this.managedIdentityRequest = (ManagedIdentityRequest) msalRequest;
this.managedIdentitySourceType = sourceType;
this.serviceBundle = serviceBundle;
}

public ManagedIdentityResponse getManagedIdentityResponse(
ManagedIdentityParameters parameters) {

ManagedIdentityRequest request = createManagedIdentityRequest(parameters.resource);
createManagedIdentityRequest(parameters.resource);
IHttpResponse response;

try {
HttpRequest httpRequest = new HttpRequest(HttpMethod.GET, request.computeURI().toString());
response = HttpHelper.executeHttpRequest(httpRequest, requestContext, serviceBundle);
HttpRequest httpRequest = new HttpRequest(HttpMethod.GET, managedIdentityRequest.computeURI().toString());
response = HttpHelper.executeHttpRequest(httpRequest, managedIdentityRequest.requestContext(), serviceBundle);
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}
Expand Down Expand Up @@ -90,7 +90,7 @@ public ManagedIdentityResponse handleResponse(
}
}

public abstract ManagedIdentityRequest createManagedIdentityRequest(String resource);
public abstract void createManagedIdentityRequest(String resource);

protected ManagedIdentityResponse getSuccessfulResponse(IHttpResponse response) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,10 @@ class AppServiceManagedIdentity extends AbstractManagedIdentitySource{
private static URI endpointUri;

@Override
public ManagedIdentityRequest createManagedIdentityRequest(String resource) {
ManagedIdentityRequest request = new ManagedIdentityRequest(HttpMethod.GET, endpoint);

public void createManagedIdentityRequest(String resource) {
Map<String, String> headers = new HashMap<>();
headers.put(SecretHeaderName, secret);
request.headers = headers;
managedIdentityRequest.headers = headers;

Map<String, String> queryParameters = new HashMap<>();
queryParameters.put("api-version", APP_SERVICE_MSI_API_VERSION );
Expand All @@ -48,26 +46,24 @@ public ManagedIdentityRequest createManagedIdentityRequest(String resource) {
queryParameters.put(Constants.MANAGED_IDENTITY_RESOURCE_ID, getManagedIdentityUserAssignedResourceId());
}

request.queryParameters = queryParameters;

return request;
managedIdentityRequest.queryParameters = queryParameters;
}

private AppServiceManagedIdentity(RequestContext requestContext, ServiceBundle serviceBundle, URI endpoint, String secret)
private AppServiceManagedIdentity(MsalRequest msalRequest, ServiceBundle serviceBundle, URI endpoint, String secret)
{
super(requestContext, serviceBundle, ManagedIdentitySourceType.AppService);
super(msalRequest, serviceBundle, ManagedIdentitySourceType.AppService);
this.endpoint = endpoint;
this.secret = secret;
}

protected static AbstractManagedIdentitySource create(RequestContext requestContext, ServiceBundle serviceBundle) {
protected static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle) {

IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) requestContext.apiParameters());
IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) msalRequest.requestContext().apiParameters());
String msiSecret = environmentVariables.getEnvironmentVariable(IEnvironmentVariables.IDENTITY_HEADER);
String msiEndpoint = environmentVariables.getEnvironmentVariable(IEnvironmentVariables.IDENTITY_ENDPOINT);

return validateEnvironmentVariables(msiEndpoint, msiSecret)
? new AppServiceManagedIdentity(requestContext, serviceBundle, endpointUri, msiSecret)
? new AppServiceManagedIdentity(msalRequest, serviceBundle, endpointUri, msiSecret)
: null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ public class IMDSManagedIdentity extends AbstractManagedIdentitySource{

private URI imdsEndpoint;

public IMDSManagedIdentity(RequestContext requestContext,
public IMDSManagedIdentity(MsalRequest msalRequest,
ServiceBundle serviceBundle) {
super(requestContext, serviceBundle, ManagedIdentitySourceType.Imds);
ManagedIdentityParameters parameters = (ManagedIdentityParameters) requestContext.apiParameters();
IEnvironmentVariables environmentVariables = ((ManagedIdentityParameters) requestContext.apiParameters()).environmentVariables == null ?
super(msalRequest, serviceBundle, ManagedIdentitySourceType.Imds);
ManagedIdentityParameters parameters = (ManagedIdentityParameters) msalRequest.requestContext().apiParameters();
IEnvironmentVariables environmentVariables = ((ManagedIdentityParameters) msalRequest.requestContext().apiParameters()).environmentVariables == null ?
new EnvironmentVariables() :
parameters.environmentVariables;
if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(IEnvironmentVariables.AZURE_POD_IDENTITY_AUTHORITY_HOST))){
Expand All @@ -53,7 +53,7 @@ public IMDSManagedIdentity(RequestContext requestContext,
StringBuilder builder = new StringBuilder(environmentVariables.getEnvironmentVariable(IEnvironmentVariables.AZURE_POD_IDENTITY_AUTHORITY_HOST));
builder.append("/" + imdsTokenPath);
try {
URI imdsEndpoint = new URI(builder.toString());
imdsEndpoint = new URI(builder.toString());
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}
Expand All @@ -68,12 +68,10 @@ public IMDSManagedIdentity(RequestContext requestContext,
}

@Override
public ManagedIdentityRequest createManagedIdentityRequest(String resource) {
ManagedIdentityRequest request = new ManagedIdentityRequest(HttpMethod.GET, imdsEndpoint);

public void createManagedIdentityRequest(String resource) {
Map<String, String> headers = new HashMap<>();
headers.put("Metadata", "true");
request.headers = headers;
managedIdentityRequest.headers = headers;

Map<String, String> queryParameters = new HashMap<>();
queryParameters.put("api-version",imdsApiVersion);
Expand All @@ -93,8 +91,7 @@ public ManagedIdentityRequest createManagedIdentityRequest(String resource) {
queryParameters.put(Constants.MANAGED_IDENTITY_RESOURCE_ID, resourceId);
}

request.queryParameters = queryParameters;
return request;
managedIdentityRequest.queryParameters = queryParameters;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class ManagedIdentityClient {
private AbstractManagedIdentitySource managedIdentitySource;

public ManagedIdentityClient(MsalRequest msalRequest, ServiceBundle serviceBundle) throws Exception {
managedIdentitySource = createManagedIdentitySource(msalRequest.requestContext(), serviceBundle);
managedIdentitySource = createManagedIdentitySource(msalRequest, serviceBundle);

ManagedIdentityApplication managedIdentityApplication = (ManagedIdentityApplication) msalRequest.application();
ManagedIdentityIdType identityIdType = managedIdentityApplication.getManagedIdentityId().getIdType();
Expand All @@ -31,13 +31,13 @@ public ManagedIdentityResponse getManagedIdentityResponse(ManagedIdentityParamet
}

// This method tries to create managed identity source for different sources, if none is created then defaults to IMDS.
private static AbstractManagedIdentitySource createManagedIdentitySource(RequestContext requestContext,
private static AbstractManagedIdentitySource createManagedIdentitySource(MsalRequest msalRequest,
ServiceBundle serviceBundle) throws Exception {
AbstractManagedIdentitySource managedIdentitySource;
if ((managedIdentitySource = AppServiceManagedIdentity.create(requestContext, serviceBundle)) != null) {
if ((managedIdentitySource = AppServiceManagedIdentity.create(msalRequest, serviceBundle)) != null) {
return managedIdentitySource;
} else {
return new IMDSManagedIdentity(requestContext, serviceBundle);
return new IMDSManagedIdentity(msalRequest, serviceBundle);
}
}
}

0 comments on commit 6688149

Please sign in to comment.