Skip to content

Commit

Permalink
Fixed error for case when mltensor has data as null
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Oct 6, 2023
1 parent 841f280 commit e592ad7
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;

import org.apache.logging.log4j.util.Strings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
Expand All @@ -40,6 +42,7 @@
public class MLCommonsClientAccessor {
private static final List<String> TARGET_RESPONSE_FILTERS = List.of("sentence_embedding");
private final MachineLearningNodeClient mlClient;
private static final String EXCEPTION_MESSAGE_MODEL_PROCESSING_FAILED = "the system encountered an unexpected error during processing";

/**
* Wrapper around {@link #inferenceSentences} that expected a single input text and produces a single floating
Expand Down Expand Up @@ -187,6 +190,13 @@ private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
for (final ModelTensors tensors : tensorOutputList) {
final List<ModelTensor> tensorsList = tensors.getMlModelTensors();
for (final ModelTensor tensor : tensorsList) {
String exceptionMessage = EXCEPTION_MESSAGE_MODEL_PROCESSING_FAILED;
if (Objects.isNull(tensor.getData())) {
if (Objects.nonNull(tensor.getDataAsMap()) && Strings.isNotBlank((String) tensor.getDataAsMap().get("message"))) {
exceptionMessage = (String) tensor.getDataAsMap().get("message");
}
throw new IllegalStateException(exceptionMessage);
}
vector.add(Arrays.stream(tensor.getData()).map(value -> (Float) value).collect(Collectors.toList()));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.neuralsearch.ml;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;

Expand Down Expand Up @@ -328,6 +329,21 @@ public void testInferenceSentencesMultimodal_whenNodeNotConnectedException_thenR
Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException);
}

public void testInferenceMultimodal_whenInvalidInputAndEmptyTensorOutput_thenFail() {
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onResponse(createEmptyModelTensorOutput());
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener);

Mockito.verify(client)
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(singleSentenceResultListener).onFailure(any());
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
}

private ModelTensorOutput createModelTensorOutput(final Float[] output) {
final List<ModelTensors> tensorsList = new ArrayList<>();
final List<ModelTensor> mlModelTensorList = new ArrayList<>();
Expand Down Expand Up @@ -355,4 +371,22 @@ private ModelTensorOutput createModelTensorOutput(final Map<String, String> map)
tensorsList.add(modelTensors);
return new ModelTensorOutput(tensorsList);
}

private ModelTensorOutput createEmptyModelTensorOutput() {
final List<ModelTensors> tensorsList = new ArrayList<>();
final List<ModelTensor> mlModelTensorList = new ArrayList<>();
final ModelTensor tensor = new ModelTensor(
"someValue",
null,
new long[] { 1, 2 },
MLResultDataType.FLOAT64,
ByteBuffer.wrap(new byte[12]),
"mockResult",
ImmutableMap.of("message", "The system encountered an unexpected error during processing. Try your request again.")
);
mlModelTensorList.add(tensor);
final ModelTensors modelTensors = new ModelTensors(mlModelTensorList);
tensorsList.add(modelTensors);
return new ModelTensorOutput(tensorsList);
}
}

0 comments on commit e592ad7

Please sign in to comment.