Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for asymmetric embedding models #710

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.14...2.x)
### Features
- Add support for asymmetric embedding models ([#710](https://github.com/opensearch-project/neural-search/pull/710))
- Speed up NeuralSparseQuery by two-phase using a custom search pipeline.([#646](https://github.com/opensearch-project/neural-search/issues/646))
- Support batchExecute in TextEmbeddingProcessor and SparseEncodingProcessor ([#743](https://github.com/opensearch-project/neural-search/issues/743))
### Enhancements
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.env.Environment;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;

import lombok.extern.log4j.Log4j2;
Expand Down Expand Up @@ -46,10 +48,15 @@ public void doExecute(
List<String> inferenceList,
BiConsumer<IngestDocument, Exception> handler
) {
mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(vectors -> {
setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors);
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); }));
mlCommonsClientAccessor.inferenceSentences(
this.modelId,
inferenceList,
AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(),
ActionListener.wrap(vectors -> {
setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors);
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); })
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;

Expand Down Expand Up @@ -312,10 +314,15 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
inferenceInput.put(INPUT_IMAGE, queryImage());
}
queryRewriteContext.registerAsyncAction(
((client, actionListener) -> ML_CLIENT.inferenceSentences(modelId(), inferenceInput, ActionListener.wrap(floatList -> {
vectorSetOnce.set(vectorAsListToArray(floatList));
actionListener.onResponse(null);
}, actionListener::onFailure)))
((client, actionListener) -> ML_CLIENT.inferenceSentences(
modelId(),
inferenceInput,
AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.QUERY).build(),
ActionListener.wrap(floatList -> {
vectorSetOnce.set(vectorAsListToArray(floatList));
actionListener.onResponse(null);
}, actionListener::onFailure)
))
);
return new NeuralQueryBuilder(
fieldName(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;
Expand Down Expand Up @@ -59,15 +63,34 @@ public void testInferenceSentence_whenValidInput_thenSuccess() {
actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY));
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentence(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST.get(0), singleSentenceResultListener);
accessor.inferenceSentence(
TestCommonConstants.MODEL_ID,
TestCommonConstants.SENTENCES_LIST.get(0),
AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(),
singleSentenceResultListener
);

Mockito.verify(client)
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(singleSentenceResultListener).onResponse(vector);
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
}

private void setupMocksForTextEmbeddingModelAsymmetryCheck(boolean isAsymmetric) {
MLModel modelMock = mock(MLModel.class);
TextEmbeddingModelConfig textEmbeddingModelConfigMock = mock(TextEmbeddingModelConfig.class);
Mockito.when(textEmbeddingModelConfigMock.getPassagePrefix()).thenReturn(isAsymmetric ? "passage: " : null);
Mockito.when(textEmbeddingModelConfigMock.getQueryPrefix()).thenReturn(isAsymmetric ? "query: " : null);
Mockito.when(modelMock.getModelConfig()).thenReturn(textEmbeddingModelConfigMock);
Mockito.doAnswer(invocation -> {
final ActionListener<MLModel> actionListener = invocation.getArgument(1);
actionListener.onResponse(modelMock);
return null;
}).when(client).getModel(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(ActionListener.class));
}

public void testInferenceSentences_whenValidInputThenSuccess() {
final List<List<Float>> vectorList = new ArrayList<>();
vectorList.add(Arrays.asList(TestCommonConstants.PREDICT_VECTOR_ARRAY));
Expand All @@ -76,6 +99,8 @@ public void testInferenceSentences_whenValidInputThenSuccess() {
actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY));
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener);

Mockito.verify(client)
Expand All @@ -92,6 +117,8 @@ public void testInferenceSentences_whenResultFromClient_thenEmptyVectorList() {
actionListener.onResponse(createModelTensorOutput(new Float[] {}));
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener);

Mockito.verify(client)
Expand All @@ -107,6 +134,8 @@ public void testInferenceSentences_whenExceptionFromMLClient_thenFailure() {
actionListener.onFailure(exception);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentences(
TestCommonConstants.TARGET_RESPONSE_FILTERS,
TestCommonConstants.MODEL_ID,
Expand All @@ -130,6 +159,9 @@ public void testInferenceSentences_whenNodeNotConnectedException_thenRetry_3Time
actionListener.onFailure(nodeNodeConnectedException);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentences(
TestCommonConstants.TARGET_RESPONSE_FILTERS,
TestCommonConstants.MODEL_ID,
Expand All @@ -149,6 +181,9 @@ public void testInferenceSentences_whenNotConnectionException_thenNoRetry() {
actionListener.onFailure(illegalStateException);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentences(
TestCommonConstants.TARGET_RESPONSE_FILTERS,
TestCommonConstants.MODEL_ID,
Expand All @@ -161,6 +196,62 @@ public void testInferenceSentences_whenNotConnectionException_thenNoRetry() {
Mockito.verify(resultListener).onFailure(illegalStateException);
}

public void testInferenceSentences_whenModelAsymmetric_thenSuccess() {
final List<Float> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY));
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
setupMocksForTextEmbeddingModelAsymmetryCheck(true);

accessor.inferenceSentence(
TestCommonConstants.MODEL_ID,
TestCommonConstants.SENTENCES_LIST.get(0),
AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(),
singleSentenceResultListener
);

Mockito.verify(client)
.predict(
Mockito.eq(TestCommonConstants.MODEL_ID),
Mockito.argThat((MLInput input) -> input.getParameters() != null),
Mockito.isA(ActionListener.class)
);
Mockito.verify(singleSentenceResultListener).onResponse(vector);
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
}

public void testInferenceSentences_whenGetModelException_thenFailure() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to test scenario when we're retrying 1-2 times. I see scenario when first request has failed with error that isn't retryable, this isn't a full coverage

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@martin-gaievski thanks for pointing this out.

I'm wondering though, if we should make this retryable at all. Let me elaborate:

In my understanding, inference requests are retried because they tend to fail more often than regular operations in OpenSearch. I don't know the history and complete reasoning behind this, so I speculate it has to do with the fact that the inference is done natively and that many things can go wrong there.

With my change, if fetching the model information fails (mlClient.getModel(modelId, ...), there is no retry. Model information is fetched the first time inference is requested with a particular model. After that, the result is cached and the method behaves exactly as before the PR.

So my argument is: should we really add a retry logic to this relatively simple operation? If getModel fails, it is most likely to fail again, so retrying wouldn't make sense. If so, one could argue that all operations in OpenSearch should be wrapped in a retry logic.

final List<Float> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY));
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
RuntimeException exception = new RuntimeException("Bam!");
setupMocksForTextEmbeddingModelAsymmetryCheck(exception);

accessor.inferenceSentence(
TestCommonConstants.MODEL_ID,
TestCommonConstants.SENTENCES_LIST.get(0),
AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(),
singleSentenceResultListener
);

Mockito.verify(client).getModel(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(ActionListener.class));
Mockito.verify(singleSentenceResultListener).onFailure(exception);
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
}

private void setupMocksForTextEmbeddingModelAsymmetryCheck(Exception exception) {
Mockito.doAnswer(invocation -> {
final ActionListener<MLModel> actionListener = invocation.getArgument(1);
actionListener.onFailure(exception);
return null;
}).when(client).getModel(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(ActionListener.class));
}

public void testInferenceSentencesWithMapResult_whenValidInput_thenSuccess() {
final Map<String, String> map = Map.of("key", "value");
final ActionListener<List<Map<String, ?>>> resultListener = mock(ActionListener.class);
Expand All @@ -169,6 +260,9 @@ public void testInferenceSentencesWithMapResult_whenValidInput_thenSuccess() {
actionListener.onResponse(createModelTensorOutput(map));
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener);

Mockito.verify(client)
Expand All @@ -185,6 +279,9 @@ public void testInferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenEx
actionListener.onResponse(modelTensorOutput);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener);

Mockito.verify(client)
Expand All @@ -209,6 +306,9 @@ public void testInferenceSentencesWithMapResult_whenModelTensorListEmpty_thenExc
actionListener.onResponse(modelTensorOutput);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener);

Mockito.verify(client)
Expand Down Expand Up @@ -236,6 +336,9 @@ public void testInferenceSentencesWithMapResult_whenModelTensorListSizeBiggerTha
actionListener.onResponse(modelTensorOutput);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener);

Mockito.verify(client)
Expand All @@ -255,6 +358,9 @@ public void testInferenceSentencesWithMapResult_whenRetryableException_retry3Tim
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
final ActionListener<List<Map<String, ?>>> resultListener = mock(ActionListener.class);

setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener);

Mockito.verify(client, times(4))
Expand All @@ -270,6 +376,9 @@ public void testInferenceSentencesWithMapResult_whenNotRetryableException_thenFa
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
final ActionListener<List<Map<String, ?>>> resultListener = mock(ActionListener.class);

setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener);

Mockito.verify(client, times(1))
Expand All @@ -285,6 +394,8 @@ public void testInferenceMultimodal_whenValidInput_thenSuccess() {
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener);

Mockito.verify(client)
Expand All @@ -300,6 +411,9 @@ public void testInferenceMultimodal_whenExceptionFromMLClient_thenFailure() {
actionListener.onFailure(exception);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener);

Mockito.verify(client)
Expand All @@ -318,6 +432,9 @@ public void testInferenceSentencesMultimodal_whenNodeNotConnectedException_thenR
actionListener.onFailure(nodeNodeConnectedException);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener);

Mockito.verify(client, times(4))
Expand All @@ -333,6 +450,8 @@ public void testInferenceSimilarity_whenValidInput_thenSuccess() {
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSimilarity(
TestCommonConstants.MODEL_ID,
"is it sunny",
Expand All @@ -354,6 +473,8 @@ public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() {
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSimilarity(
TestCommonConstants.MODEL_ID,
"is it sunny",
Expand All @@ -378,6 +499,8 @@ public void testInferenceSimilarity_whenNodeNotConnectedException_ThenTryThreeTi
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSimilarity(
TestCommonConstants.MODEL_ID,
"is it sunny",
Expand Down
Loading
Loading