diff --git a/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java b/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java index 2fdbe6ab..cc7b2702 100644 --- a/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java +++ b/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java @@ -61,6 +61,7 @@ public void updateClusterSettings() { updateClusterSettings("plugins.ml_commons.native_memory_threshold", 100); updateClusterSettings("plugins.ml_commons.jvm_heap_memory_threshold", 100); updateClusterSettings("plugins.ml_commons.allow_registering_model_via_url", true); + updateClusterSettings("plugins.ml_commons.agent_framework_enabled", true); } @SneakyThrows diff --git a/src/test/java/org/opensearch/integTest/VisualizationsToolIT.java b/src/test/java/org/opensearch/integTest/VisualizationsToolIT.java index e7f54521..2bf0e611 100644 --- a/src/test/java/org/opensearch/integTest/VisualizationsToolIT.java +++ b/src/test/java/org/opensearch/integTest/VisualizationsToolIT.java @@ -16,6 +16,8 @@ import org.opensearch.client.Response; import org.opensearch.core.rest.RestStatus; +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; import com.google.gson.JsonParser; import lombok.extern.log4j.Log4j2; @@ -86,7 +88,7 @@ private void prepareVisualization(String title, String id) { } private String extractAdditionalInfo(String responseStr) { - return JsonParser + JsonArray output = JsonParser .parseString(responseStr) .getAsJsonObject() .get("inference_results") @@ -94,14 +96,19 @@ private String extractAdditionalInfo(String responseStr) { .get(0) .getAsJsonObject() .get("output") - .getAsJsonArray() - .get(0) - .getAsJsonObject() - .get("dataAsMap") - .getAsJsonObject() - .get("additional_info") - .getAsJsonObject() - .get(String.format(Locale.ROOT, "%s.output", toolType())) - .getAsString(); + .getAsJsonArray(); + for (JsonElement element : output) { + if ("response".equals(element.getAsJsonObject().get("name").getAsString())) { + return element + .getAsJsonObject() + .get("dataAsMap") + .getAsJsonObject() + .get("additional_info") + .getAsJsonObject() + .get(String.format(Locale.ROOT, "%s.output", toolType())) + .getAsString(); + } + } + return null; } }