Skip to content

Commit

Permalink
Add support for OpenSearch Serverless collections to the opensearch s…
Browse files Browse the repository at this point in the history
…ource (opensearch-project#3288)

Signed-off-by: Taylor Gray <[email protected]>
  • Loading branch information
graytaylor0 authored and asifsmohammed committed Sep 27, 2023
1 parent 5f29a30 commit be68221
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ public class AwsAuthenticationConfiguration {
@Size(max = 5, message = "sts_header_overrides supports a maximum of 5 headers to override")
private Map<String, String> awsStsHeaderOverrides;

@JsonProperty("serverless")
private Boolean serverless = false;

public String getAwsStsRoleArn() {
return awsStsRoleArn;
}
Expand All @@ -44,5 +47,9 @@ public Region getAwsRegion() {
public Map<String, String> getAwsStsHeaderOverrides() {
return awsStsHeaderOverrides;
}

public Boolean isServerlessCollection() {
return serverless;
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ public class SearchConfiguration {
private static final ObjectMapper objectMapper = new ObjectMapper();
private static final Logger LOG = LoggerFactory.getLogger(SearchConfiguration.class);

// TODO: Should we default this to NONE and remove the version lookup to determine scroll or point-in-time as the default behavior?
@JsonProperty("search_context_type")
private SearchContextType searchContextType;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public class OpenSearchClientFactory {
private static final Logger LOG = LoggerFactory.getLogger(OpenSearchClientFactory.class);

private static final String AOS_SERVICE_NAME = "es";
private static final String AOSS_SERVICE_NAME = "aoss";

private final AwsCredentialsSupplier awsCredentialsSupplier;

Expand Down Expand Up @@ -96,9 +97,13 @@ private OpenSearchTransport createOpenSearchTransportForAws(final OpenSearchSour
.withStsHeaderOverrides(openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsStsHeaderOverrides())
.build());

final boolean isServerlessCollection = Objects.nonNull(openSearchSourceConfiguration.getAwsAuthenticationOptions()) &&
openSearchSourceConfiguration.getAwsAuthenticationOptions().isServerlessCollection();

return new AwsSdk2Transport(createSdkHttpClient(openSearchSourceConfiguration),
HttpHost.create(openSearchSourceConfiguration.getHosts().get(0)).getHostName(),
AOS_SERVICE_NAME, openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsRegion(),
isServerlessCollection ? AOSS_SERVICE_NAME : AOS_SERVICE_NAME,
openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsRegion(),
AwsSdk2TransportOptions.builder()
.setCredentials(awsCredentialsProvider)
.setMapper(new JacksonJsonpMapper())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.opensearch.client.opensearch._types.OpenSearchException;
import org.opensearch.client.opensearch.core.InfoResponse;
import org.opensearch.client.util.MissingRequiredPropertyException;
import org.opensearch.dataprepper.model.plugin.InvalidPluginConfigurationException;
import org.opensearch.dataprepper.plugins.source.opensearch.OpenSearchSourceConfiguration;
import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.SearchContextType;
import org.slf4j.Logger;
Expand All @@ -31,7 +32,6 @@ public class SearchAccessorStrategy {
static final String OPENSEARCH_DISTRIBUTION = "opensearch";
static final String ELASTICSEARCH_DISTRIBUTION = "elasticsearch";
static final String ELASTICSEARCH_OSS_BUILD_FLAVOR = "oss";
static final String OPENDISTRO_DISTRIUBTION = "opendistro";

private static final String OPENSEARCH_POINT_IN_TIME_SUPPORT_VERSION_CUTOFF = "2.5.0";
private static final String ELASTICSEARCH_POINT_IN_TIME_SUPPORT_VERSION_CUTOFF = "7.10.0";
Expand Down Expand Up @@ -60,6 +60,11 @@ public SearchAccessor getSearchAccessor() {

final OpenSearchClient openSearchClient = openSearchClientFactory.provideOpenSearchClient(openSearchSourceConfiguration);

if (Objects.nonNull(openSearchSourceConfiguration.getAwsAuthenticationOptions()) &&
openSearchSourceConfiguration.getAwsAuthenticationOptions().isServerlessCollection()) {
return createSearchAccessorForServerlessCollection(openSearchClient);
}

InfoResponse infoResponse = null;

ElasticsearchClient elasticsearchClient = null;
Expand Down Expand Up @@ -102,6 +107,20 @@ public SearchAccessor getSearchAccessor() {
return new ElasticsearchAccessor(elasticsearchClient, searchContextType);
}

private SearchAccessor createSearchAccessorForServerlessCollection(final OpenSearchClient openSearchClient) {
if (Objects.isNull(openSearchSourceConfiguration.getSearchConfiguration().getSearchContextType())) {
LOG.info("Configured with AOS serverless flag as true, defaulting to search_context_type as 'none', which uses search_after");
return new OpenSearchAccessor(openSearchClient, SearchContextType.NONE);
} else {
if (SearchContextType.POINT_IN_TIME.equals(openSearchSourceConfiguration.getSearchConfiguration().getSearchContextType())) {
throw new InvalidPluginConfigurationException("A search_context_type of point_in_time is not supported for serverless collections");
}

LOG.info("Using search_context_type set in the config: '{}'", openSearchSourceConfiguration.getSearchConfiguration().getSearchContextType().toString().toLowerCase());
return new OpenSearchAccessor(openSearchClient, openSearchSourceConfiguration.getSearchConfiguration().getSearchContextType());
}
}

private void validateSearchContextTypeOverride(final SearchContextType searchContextType, final String distribution, final String version) {

if (searchContextType.equals(SearchContextType.POINT_IN_TIME) && !versionSupportsPointInTime(distribution, version)) {
Expand Down Expand Up @@ -142,9 +161,9 @@ private Pair<String, String> getDistributionAndVersionNumber(final InfoResponse
}

private void validateDistribution(final String distribution) {
if (!distribution.equals(OPENSEARCH_DISTRIBUTION) && !distribution.startsWith(ELASTICSEARCH_DISTRIBUTION) && !distribution.equals(OPENDISTRO_DISTRIUBTION)) {
throw new IllegalArgumentException(String.format("Only %s, %s, or %s distributions are supported at this time. The cluster distribution being used is '%s'",
OPENSEARCH_DISTRIBUTION, OPENDISTRO_DISTRIUBTION, ELASTICSEARCH_DISTRIBUTION, distribution));
if (!distribution.equals(OPENSEARCH_DISTRIBUTION) && !distribution.startsWith(ELASTICSEARCH_DISTRIBUTION)) {
throw new IllegalArgumentException(String.format("Only %s or %s distributions are supported at this time. The cluster distribution being used is '%s'",
OPENSEARCH_DISTRIBUTION, ELASTICSEARCH_DISTRIBUTION, distribution));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ void provideOpenSearchClient_with_aws_auth() {
final String stsRoleArn = "arn:aws:iam::123456789012:role/my-role";
when(awsAuthenticationConfiguration.getAwsStsRoleArn()).thenReturn(stsRoleArn);
when(awsAuthenticationConfiguration.getAwsStsHeaderOverrides()).thenReturn(Collections.emptyMap());
when(awsAuthenticationConfiguration.isServerlessCollection()).thenReturn(false);
when(openSearchSourceConfiguration.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationConfiguration);

final ArgumentCaptor<AwsCredentialsOptions> awsCredentialsOptionsArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsOptions.class);
Expand Down Expand Up @@ -155,4 +156,32 @@ void provideOpenSearchClient_with_auth_disabled() {
verify(openSearchSourceConfiguration, never()).getUsername();
verify(openSearchSourceConfiguration, never()).getPassword();
}

@Test
void provideOpenSearchClient_with_aws_auth_and_serverless_flag_true() {
when(connectionConfiguration.getCertPath()).thenReturn(null);
when(connectionConfiguration.getSocketTimeout()).thenReturn(null);
when(connectionConfiguration.getConnectTimeout()).thenReturn(null);

final AwsAuthenticationConfiguration awsAuthenticationConfiguration = mock(AwsAuthenticationConfiguration.class);
when(awsAuthenticationConfiguration.getAwsRegion()).thenReturn(Region.US_EAST_1);
final String stsRoleArn = "arn:aws:iam::123456789012:role/my-role";
when(awsAuthenticationConfiguration.getAwsStsRoleArn()).thenReturn(stsRoleArn);
when(awsAuthenticationConfiguration.getAwsStsHeaderOverrides()).thenReturn(Collections.emptyMap());
when(awsAuthenticationConfiguration.isServerlessCollection()).thenReturn(true);
when(openSearchSourceConfiguration.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationConfiguration);

final ArgumentCaptor<AwsCredentialsOptions> awsCredentialsOptionsArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsOptions.class);
final AwsCredentialsProvider awsCredentialsProvider = mock(AwsCredentialsProvider.class);
when(awsCredentialsSupplier.getProvider(awsCredentialsOptionsArgumentCaptor.capture())).thenReturn(awsCredentialsProvider);

final OpenSearchClient openSearchClient = createObjectUnderTest().provideOpenSearchClient(openSearchSourceConfiguration);
assertThat(openSearchClient, notNullValue());

final AwsCredentialsOptions awsCredentialsOptions = awsCredentialsOptionsArgumentCaptor.getValue();
assertThat(awsCredentialsOptions, notNullValue());
assertThat(awsCredentialsOptions.getRegion(), equalTo(Region.US_EAST_1));
assertThat(awsCredentialsOptions.getStsHeaderOverrides(), equalTo(Collections.emptyMap()));
assertThat(awsCredentialsOptions.getStsRoleArn(), equalTo(stsRoleArn));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import co.elastic.clients.elasticsearch.ElasticsearchClient;
import co.elastic.clients.elasticsearch._types.ElasticsearchVersionInfo;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
Expand All @@ -17,7 +18,9 @@
import org.opensearch.client.opensearch._types.OpenSearchVersionInfo;
import org.opensearch.client.opensearch.core.InfoResponse;
import org.opensearch.client.util.MissingRequiredPropertyException;
import org.opensearch.dataprepper.model.plugin.InvalidPluginConfigurationException;
import org.opensearch.dataprepper.plugins.source.opensearch.OpenSearchSourceConfiguration;
import org.opensearch.dataprepper.plugins.source.opensearch.configuration.AwsAuthenticationConfiguration;
import org.opensearch.dataprepper.plugins.source.opensearch.configuration.SearchConfiguration;
import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.SearchContextType;

Expand Down Expand Up @@ -189,4 +192,54 @@ void search_context_type_set_to_none_uses_that_search_context_regardless_of_vers
assertThat(searchAccessor, notNullValue());
assertThat(searchAccessor.getSearchContextType(), equalTo(SearchContextType.NONE));
}

@Test
void serverless_flag_true_defaults_to_search_context_type_none() {

final AwsAuthenticationConfiguration awsAuthenticationConfiguration = mock(AwsAuthenticationConfiguration.class);
when(awsAuthenticationConfiguration.isServerlessCollection()).thenReturn(true);
when(openSearchSourceConfiguration.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationConfiguration);

final SearchConfiguration searchConfiguration = mock(SearchConfiguration.class);
when(openSearchSourceConfiguration.getSearchConfiguration()).thenReturn(searchConfiguration);

final SearchAccessor searchAccessor = createObjectUnderTest().getSearchAccessor();

assertThat(searchAccessor, notNullValue());
assertThat(searchAccessor.getSearchContextType(), equalTo(SearchContextType.NONE));
}

@Test
void serverless_flag_true_throws_InvalidPluginConfiguration_if_search_context_type_is_point_in_time() {

final AwsAuthenticationConfiguration awsAuthenticationConfiguration = mock(AwsAuthenticationConfiguration.class);
when(awsAuthenticationConfiguration.isServerlessCollection()).thenReturn(true);
when(openSearchSourceConfiguration.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationConfiguration);

final SearchConfiguration searchConfiguration = mock(SearchConfiguration.class);
when(searchConfiguration.getSearchContextType()).thenReturn(SearchContextType.POINT_IN_TIME);
when(openSearchSourceConfiguration.getSearchConfiguration()).thenReturn(searchConfiguration);

final SearchAccessorStrategy objectUnderTest = createObjectUnderTest();

assertThrows(InvalidPluginConfigurationException.class, objectUnderTest::getSearchAccessor);
}

@ParameterizedTest
@ValueSource(strings = {"NONE", "SCROLL"})
void serverless_flag_true_uses_search_context_type_from_config(final String searchContextType) {

final AwsAuthenticationConfiguration awsAuthenticationConfiguration = mock(AwsAuthenticationConfiguration.class);
when(awsAuthenticationConfiguration.isServerlessCollection()).thenReturn(true);
when(openSearchSourceConfiguration.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationConfiguration);

final SearchConfiguration searchConfiguration = mock(SearchConfiguration.class);
when(searchConfiguration.getSearchContextType()).thenReturn(SearchContextType.valueOf(searchContextType));
when(openSearchSourceConfiguration.getSearchConfiguration()).thenReturn(searchConfiguration);

final SearchAccessor searchAccessor = createObjectUnderTest().getSearchAccessor();

assertThat(searchAccessor, notNullValue());
assertThat(searchAccessor.getSearchContextType(), equalTo(SearchContextType.valueOf(searchContextType)));
}
}

0 comments on commit be68221

Please sign in to comment.