diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 1c09f5996..d2c4c6029 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -12,12 +12,14 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.stream.Collectors; import lombok.NonNull; import lombok.RequiredArgsConstructor; import lombok.extern.log4j.Log4j2; +import org.apache.logging.log4j.util.Strings; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.ml.client.MachineLearningNodeClient; @@ -40,6 +42,7 @@ public class MLCommonsClientAccessor { private static final List TARGET_RESPONSE_FILTERS = List.of("sentence_embedding"); private final MachineLearningNodeClient mlClient; + private static final String EXCEPTION_MESSAGE_MODEL_PROCESSING_FAILED = "the system encountered an unexpected error during processing"; /** * Wrapper around {@link #inferenceSentences} that expected a single input text and produces a single floating @@ -187,6 +190,13 @@ private List> buildVectorFromResponse(MLOutput mlOutput) { for (final ModelTensors tensors : tensorOutputList) { final List tensorsList = tensors.getMlModelTensors(); for (final ModelTensor tensor : tensorsList) { + String exceptionMessage = EXCEPTION_MESSAGE_MODEL_PROCESSING_FAILED; + if (Objects.isNull(tensor.getData())) { + if (Objects.nonNull(tensor.getDataAsMap()) && Strings.isNotBlank((String) tensor.getDataAsMap().get("message"))) { + exceptionMessage = (String) tensor.getDataAsMap().get("message"); + } + throw new IllegalStateException(exceptionMessage); + } vector.add(Arrays.stream(tensor.getData()).map(value -> (Float) value).collect(Collectors.toList())); } } diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index ce2773f2f..d093b9e71 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -5,6 +5,7 @@ package org.opensearch.neuralsearch.ml; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -328,6 +329,21 @@ public void testInferenceSentencesMultimodal_whenNodeNotConnectedException_thenR Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException); } + public void testInferenceMultimodal_whenInvalidInputAndEmptyTensorOutput_thenFail() { + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(createEmptyModelTensorOutput()); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); + + Mockito.verify(client) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(singleSentenceResultListener).onFailure(any()); + Mockito.verifyNoMoreInteractions(singleSentenceResultListener); + } + private ModelTensorOutput createModelTensorOutput(final Float[] output) { final List tensorsList = new ArrayList<>(); final List mlModelTensorList = new ArrayList<>(); @@ -355,4 +371,22 @@ private ModelTensorOutput createModelTensorOutput(final Map map) tensorsList.add(modelTensors); return new ModelTensorOutput(tensorsList); } + + private ModelTensorOutput createEmptyModelTensorOutput() { + final List tensorsList = new ArrayList<>(); + final List mlModelTensorList = new ArrayList<>(); + final ModelTensor tensor = new ModelTensor( + "someValue", + null, + new long[] { 1, 2 }, + MLResultDataType.FLOAT64, + ByteBuffer.wrap(new byte[12]), + "mockResult", + ImmutableMap.of("message", "The system encountered an unexpected error during processing. Try your request again.") + ); + mlModelTensorList.add(tensor); + final ModelTensors modelTensors = new ModelTensors(mlModelTensorList); + tensorsList.add(modelTensors); + return new ModelTensorOutput(tensorsList); + } }