Skip to content

Commit

Permalink
[Backport 2.x] Fix MLModelTool returns null if the response of LLM is…
Browse files Browse the repository at this point in the history
… a pure json object (#2675) (#2685)

* Fix MLModelTool returns null if the response of LLM is a pure json object (#2655)

* Fix MLModelTool returns null if the response of LLM is a pure json object

Signed-off-by: Heng Qian <[email protected]>

* Fix UT failure

Signed-off-by: Heng Qian <[email protected]>

* Avoid NPE

Signed-off-by: Heng Qian <[email protected]>

* spotlessApply

Signed-off-by: Heng Qian <[email protected]>

---------

Signed-off-by: Heng Qian <[email protected]>
(cherry picked from commit 007b914)

* remove java21 API for backporting to 2.x

Signed-off-by: Heng Qian <[email protected]>

---------

Signed-off-by: Heng Qian <[email protected]>
(cherry picked from commit 0a6a2b0)

Co-authored-by: qianheng <[email protected]>
  • Loading branch information
opensearch-trigger-bot[bot] and qianheng-aws committed Jul 19, 2024
1 parent dcfe439 commit 51cba97
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;

import lombok.Getter;
Expand Down Expand Up @@ -54,6 +55,7 @@ public class MLModelTool implements Tool {
private Parser inputParser;
@Setter
@Getter
@VisibleForTesting
private Parser outputParser;
@Setter
@Getter
Expand All @@ -65,8 +67,18 @@ public MLModelTool(Client client, String modelId, String responseField) {
this.responseField = responseField;

outputParser = o -> {
List<ModelTensors> mlModelOutputs = (List<ModelTensors>) o;
return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get(responseField);
try {
List<ModelTensors> mlModelOutputs = (List<ModelTensors>) o;
Map<String, ?> dataAsMap = mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap();
// Return the response field if it exists, otherwise return the whole response as json string.
if (dataAsMap.containsKey(responseField)) {
return dataAsMap.get(responseField);
} else {
return StringUtils.toJson(dataAsMap);
}
} catch (Exception e) {
throw new IllegalStateException("LLM returns wrong or empty tensors", e);
}
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ public void testMLModelsWithDefaultOutputParserAndMalformedResponseField() throw
tool.run(null, listener);

future.join();
assertEquals(null, future.get());
assertEquals("{\"response\":\"response 1\",\"action\":\"action1\"}", future.get());
}

@Test
Expand Down Expand Up @@ -170,6 +170,26 @@ public void testOutputParserLambda() {
assertEquals("testResponse", result);
}

@Test
public void testOutputParserWithJsonResponse() {
Parser outputParser = new MLModelTool(client, "modelId", "response").getOutputParser();
String expectedJson = "{\"key1\":\"value1\",\"key2\":\"value2\"}";

// Create a mock ModelTensors with json object
ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("key1", "value1", "key2", "value2")).build();
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
Object result = outputParser.parse(mlModelTensorOutput.getMlModelOutputs());
assertEquals(expectedJson, result);

// Create a mock ModelTensors with response string
modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "{\"key1\":\"value1\",\"key2\":\"value2\"}")).build();
modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
result = outputParser.parse(mlModelTensorOutput.getMlModelOutputs());
assertEquals(expectedJson, result);
}

@Test
public void testRunWithError() {
// Mocking the client.execute to simulate an error
Expand Down

0 comments on commit 51cba97

Please sign in to comment.