Skip to content

Commit

Permalink
SigV4: Add host header only when not already provided (#5608)
Browse files Browse the repository at this point in the history
* SigV4: Add host header only when not already provided

* http-client: respect user host header

---------

Co-authored-by: Vladimir Sudilovsky <[email protected]>
Co-authored-by: Daniel Cullen <[email protected]>
  • Loading branch information
3 people authored Nov 8, 2024
1 parent 97ee691 commit 7f0dddf
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 3 deletions.
6 changes: 6 additions & 0 deletions .changes/next-release/bugfix-AWSSDKforJavav2-b3fbc61.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"type": "bugfix",
"category": "AWS SDK for Java v2",
"contributor": "vsudilov",
"description": "SigV4: Add host header only when not already provided"
}
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,13 @@ public static void addHostHeader(SdkHttpRequest.Builder requestBuilder) {
// AWS4 requires that we sign the Host header, so we
// have to have it in the request by the time we sign.

// If the SdkHttpRequest has an associated Host header
// already set, prefer to use that.

if (requestBuilder.headers().get(SignerConstant.HOST) != null) {
return;
}

String host = requestBuilder.host();
if (!SdkHttpUtils.isUsingStandardPort(requestBuilder.protocol(), requestBuilder.port())) {
StringBuilder hostHeaderBuilder = new StringBuilder(host);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.http.SdkHttpMethod;
import software.amazon.awssdk.http.SdkHttpRequest;
import software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant;
import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity;
import software.amazon.awssdk.identity.spi.AwsSessionCredentialsIdentity;

Expand Down Expand Up @@ -58,6 +59,7 @@ public void sign_computesSigningResult() {
assertEquals(expectedCanonicalRequestString, result.getCanonicalRequest().getCanonicalRequestString());
}


@Test
public void sign_withHeader_addsAuthHeaders() {
String expectedAuthorization = "AWS4-HMAC-SHA256 Credential=access/19700101/us-east-1/demo/aws4_request, " +
Expand All @@ -82,6 +84,21 @@ public void sign_withHeaderAndSessionCredentials_addsAuthHeadersAndTokenHeader()
assertThat(result.getSignedRequest().firstMatchingHeader("X-Amz-Security-Token")).hasValue("token");
}

@Test
public void sign_withHeaderAndSessionCredentials_correctSigningUsingProvidedHostHeader() {
String expectedAuthorization = "AWS4-HMAC-SHA256 Credential=access/19700101/us-east-1/demo/aws4_request, " +
"SignedHeaders=host;x-amz-archive-description;x-amz-content-sha256;x-amz-date;"
+ "x-amz-security-token, " +
"Signature=c8228e7bef8a72a450df38e6e935ce61fdb8989670b41d97cfc20d04bb76b10a";
SdkHttpRequest.Builder request = getRequest().putHeader(SignerConstant.HOST, "virtual-host.localhost");
V4RequestSigningResult result = header(getProperties(sessionCreds)).sign(request);

assertThat(result.getSignedRequest().firstMatchingHeader("Host")).hasValue("virtual-host.localhost");
assertThat(result.getSignedRequest().firstMatchingHeader("X-Amz-Date")).hasValue("19700101T000000Z");
assertThat(result.getSignedRequest().firstMatchingHeader("Authorization")).hasValue(expectedAuthorization);
assertThat(result.getSignedRequest().firstMatchingHeader("X-Amz-Security-Token")).hasValue("token");
}

@Test
public void sign_withQuery_addsAuthQueryParams() {
V4RequestSigningResult result = query(getProperties(creds)).sign(getRequest());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.net.URI;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import org.apache.http.HttpEntity;
import org.apache.http.HttpHeaders;
import org.apache.http.client.config.RequestConfig;
Expand Down Expand Up @@ -55,7 +56,6 @@ public HttpRequestBase create(final HttpExecuteRequest request, final ApacheHttp
HttpRequestBase base = createApacheRequest(request, sanitizeUri(request.httpRequest()));
addHeadersToRequest(base, request.httpRequest());
addRequestConfig(base, request.httpRequest(), requestConfig);

return base;
}

Expand Down Expand Up @@ -172,7 +172,7 @@ private void addHeadersToRequest(HttpRequestBase httpRequest, SdkHttpRequest req
// it's already present, so we skip it here. We also skip the Host
// header to avoid sending it twice, which will interfere with some
// signing schemes.
if (!IGNORE_HEADERS.contains(name)) {
if (IGNORE_HEADERS.stream().noneMatch(name::equalsIgnoreCase)) {
for (String headerValue : value) {
httpRequest.addHeader(name, headerValue);
}
Expand All @@ -181,6 +181,11 @@ private void addHeadersToRequest(HttpRequestBase httpRequest, SdkHttpRequest req
}

private String getHostHeaderValue(SdkHttpRequest request) {
// Respect any user-specified Host header when present
Optional<String> existingHostHeader = request.firstMatchingHeader(HttpHeaders.HOST);
if (existingHostHeader.isPresent()) {
return existingHostHeader.get();
}
// Apache doesn't allow us to include the port in the host header if it's a standard port for that protocol. For that
// reason, we don't include the port when we sign the message. See {@link SdkHttpRequest#port()}.
return !SdkHttpUtils.isUsingStandardPort(request.protocol(), request.port())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,46 @@ public void createSetsHostHeaderByDefault() {
assertEquals("localhost:12345", hostHeaders[0].getValue());
}

@Test
public void createRespectsUserHostHeader() {
String hostOverride = "virtual.host:123";
SdkHttpRequest sdkRequest = SdkHttpRequest.builder()
.uri(URI.create("http://localhost:12345/"))
.method(SdkHttpMethod.HEAD)
.putHeader("Host", hostOverride)
.build();
HttpExecuteRequest request = HttpExecuteRequest.builder()
.request(sdkRequest)
.build();

HttpRequestBase result = instance.create(request, requestConfig);

Header[] hostHeaders = result.getHeaders(HttpHeaders.HOST);
assertNotNull(hostHeaders);
assertEquals(1, hostHeaders.length);
assertEquals(hostOverride, hostHeaders[0].getValue());
}

@Test
public void createRespectsLowercaseUserHostHeader() {
String hostOverride = "virtual.host:123";
SdkHttpRequest sdkRequest = SdkHttpRequest.builder()
.uri(URI.create("http://localhost:12345/"))
.method(SdkHttpMethod.HEAD)
.putHeader("host", hostOverride)
.build();
HttpExecuteRequest request = HttpExecuteRequest.builder()
.request(sdkRequest)
.build();

HttpRequestBase result = instance.create(request, requestConfig);

Header[] hostHeaders = result.getHeaders(HttpHeaders.HOST);
assertNotNull(hostHeaders);
assertEquals(1, hostHeaders.length);
assertEquals(hostOverride, hostHeaders[0].getValue());
}

@Test
public void putRequest_withTransferEncodingChunked_isChunkedAndDoesNotIncludeHeader() {
SdkHttpRequest sdkRequest = SdkHttpRequest.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.netty.handler.codec.http2.HttpConversionUtil.ExtensionHeaderNames;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.http.Protocol;
import software.amazon.awssdk.http.SdkHttpMethod;
Expand Down Expand Up @@ -87,13 +88,19 @@ private void addHeadersToRequest(DefaultHttpRequest httpRequest, SdkHttpRequest
// Copy over any other headers already in our request
request.forEachHeader((name, value) -> {
// Skip the Host header to avoid sending it twice, which will interfere with some signing schemes.
if (!IGNORE_HEADERS.contains(name)) {
if (IGNORE_HEADERS.stream().noneMatch(name::equalsIgnoreCase)) {
value.forEach(h -> httpRequest.headers().add(name, h));
}
});
}

private String getHostHeaderValue(SdkHttpRequest request) {
// Respect any user-specified Host header when present
Optional<String> existingHostHeader = request.firstMatchingHeader(HOST);
if (existingHostHeader.isPresent()) {
return existingHostHeader.get();
}

return SdkHttpUtils.isUsingStandardPort(request.protocol(), request.port())
? request.host()
: request.host() + ":" + request.port();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,34 @@ public void adapt_hostHeaderSet() {
assertThat(hostHeaders).containsExactly("localhost:12345");
}

@Test
public void adapt_keepsUserHostHeader() {
String hostOverride = "virtual.host:123";
SdkHttpRequest sdkRequest = SdkHttpRequest.builder()
.uri(URI.create("http://localhost:12345/"))
.method(SdkHttpMethod.HEAD)
.putHeader("Host", hostOverride)
.build();
HttpRequest result = h1Adapter.adapt(sdkRequest);
List<String> hostHeaders = result.headers()
.getAll(HttpHeaderNames.HOST.toString());
assertThat(hostHeaders).containsExactly(hostOverride);
}

@Test
public void adapt_keepsLowercaseUserHostHeader() {
String hostOverride = "virtual.host:123";
SdkHttpRequest sdkRequest = SdkHttpRequest.builder()
.uri(URI.create("http://localhost:12345/"))
.method(SdkHttpMethod.HEAD)
.putHeader("host", hostOverride)
.build();
HttpRequest result = h1Adapter.adapt(sdkRequest);
List<String> hostHeaders = result.headers()
.getAll(HttpHeaderNames.HOST.toString());
assertThat(hostHeaders).containsExactly(hostOverride);
}

@Test
public void adapt_standardHttpsPort_omittedInHeader() {
SdkHttpRequest sdkRequest = SdkHttpRequest.builder()
Expand Down

0 comments on commit 7f0dddf

Please sign in to comment.