diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 1c09f5996..2ccef0746 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -11,13 +11,16 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Locale; import java.util.Map; +import java.util.Objects; import java.util.stream.Collectors; import lombok.NonNull; import lombok.RequiredArgsConstructor; import lombok.extern.log4j.Log4j2; +import org.apache.logging.log4j.util.Strings; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.ml.client.MachineLearningNodeClient; @@ -40,6 +43,8 @@ public class MLCommonsClientAccessor { private static final List TARGET_RESPONSE_FILTERS = List.of("sentence_embedding"); private final MachineLearningNodeClient mlClient; + private static final String EXCEPTION_MESSAGE_MODEL_PREDICT_FAILED = "failed while calling model, check error log for details"; + private static final String EXCEPTION_MESSAGE_PREFIX_MODEL_PREDICT_FAILED = "encountered following error while calling a model"; /** * Wrapper around {@link #inferenceSentences} that expected a single input text and produces a single floating @@ -187,6 +192,20 @@ private List> buildVectorFromResponse(MLOutput mlOutput) { for (final ModelTensors tensors : tensorOutputList) { final List tensorsList = tensors.getMlModelTensors(); for (final ModelTensor tensor : tensorsList) { + if (Objects.isNull(tensor.getData())) { + if (Objects.nonNull(tensor.getDataAsMap()) && Strings.isNotBlank((String) tensor.getDataAsMap().get("message"))) { + String errorFromModel = (String) tensor.getDataAsMap().get("message"); + throw new IllegalStateException( + String.format(Locale.ROOT, "%s: %s", EXCEPTION_MESSAGE_PREFIX_MODEL_PREDICT_FAILED, errorFromModel) + ); + } else { + log.error( + "Received following output tensor from a model, there is no detailed error message: {}", + tensor.toString() + ); + throw new IllegalStateException(EXCEPTION_MESSAGE_MODEL_PREDICT_FAILED); + } + } vector.add(Arrays.stream(tensor.getData()).map(value -> (Float) value).collect(Collectors.toList())); } } diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index ce2773f2f..eeb528823 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -5,6 +5,8 @@ package org.opensearch.neuralsearch.ml; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -328,6 +330,65 @@ public void testInferenceSentencesMultimodal_whenNodeNotConnectedException_thenR Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException); } + public void testInferenceMultimodal_whenInvalidInputAndEmptyTensorOutput_thenFail() { + List tensorsList = new ArrayList<>(); + List mlModelTensorList = List.of( + new ModelTensor( + "someValue", + null, + new long[] { 1, 2 }, + MLResultDataType.FLOAT64, + ByteBuffer.wrap(new byte[12]), + "mockResult", + ImmutableMap.of("message", "The system encountered an unexpected error during processing. Try your request again.") + ) + ); + final ModelTensors modelTensors = new ModelTensors(mlModelTensorList); + ModelTensorOutput outputWithErrorMessage = new ModelTensorOutput(List.of(modelTensors)); + + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(outputWithErrorMessage); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); + + Mockito.verify(client) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(singleSentenceResultListener).onFailure(any()); + Mockito.verifyNoMoreInteractions(singleSentenceResultListener); + + clearInvocations(client, singleSentenceResultListener); + + List mlModelTensorList2 = List.of( + new ModelTensor( + "someValue", + null, + new long[] { 1, 2 }, + MLResultDataType.FLOAT64, + ByteBuffer.wrap(new byte[12]), + "mockResult", + ImmutableMap.of("test_key", "test_value") + ) + ); + final ModelTensors modelTensors2 = new ModelTensors(mlModelTensorList2); + ModelTensorOutput outputWithErrorMessage2 = new ModelTensorOutput(List.of(modelTensors2)); + + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(outputWithErrorMessage2); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); + + Mockito.verify(client) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(singleSentenceResultListener).onFailure(any()); + Mockito.verifyNoMoreInteractions(singleSentenceResultListener); + } + private ModelTensorOutput createModelTensorOutput(final Float[] output) { final List tensorsList = new ArrayList<>(); final List mlModelTensorList = new ArrayList<>();