From 088ee992ca5b2aef07005817b347526256ec81db Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Wed, 29 Nov 2023 19:34:57 -0800 Subject: [PATCH] [Feature/agent_framework] Registers a single agent with multiple tools (#198) * Initial register agent workflow step Signed-off-by: Owais Kazi * Added tools step Signed-off-by: Owais Kazi * Fixed ClassCastException Signed-off-by: Owais Kazi * Handled exception for Instant Signed-off-by: Owais Kazi * Added type Instant for WorklowNode Parser Signed-off-by: Owais Kazi * Removed created and last updated time Signed-off-by: Owais Kazi * Addressed parsing error Signed-off-by: Owais Kazi * Handled parsing of Long values for Instant Signed-off-by: Owais Kazi * Handled nested object for llm key Signed-off-by: Owais Kazi * Handled parsing error Signed-off-by: Owais Kazi * Another attempt to fix parsing error for llm Signed-off-by: Owais Kazi * Another attemp to fix XContent Signed-off-by: Owais Kazi * Fixed Parsing error Signed-off-by: Owais Kazi * Added tests for toolstep and javadocs Signed-off-by: Owais Kazi * Undo CI changes Signed-off-by: Owais Kazi * Addressing PR comments Signed-off-by: Owais Kazi * Addressing PR comments Signed-off-by: Owais Kazi * Handled interface changes Signed-off-by: Owais Kazi * Addressed conflicts Signed-off-by: Owais Kazi * Added TODO Signed-off-by: Owais Kazi --------- Signed-off-by: Owais Kazi --- build.gradle | 1 + .../flowframework/common/CommonValue.java | 17 +- .../flowframework/model/Template.java | 2 +- .../flowframework/model/WorkflowNode.java | 35 ++- .../flowframework/util/ParseUtils.java | 99 +++++++ .../workflow/CreateConnectorStep.java | 9 +- .../workflow/CreateIndexStep.java | 4 +- .../workflow/DeployModelStep.java | 2 +- .../workflow/ModelGroupStep.java | 2 +- .../workflow/RegisterAgentStep.java | 241 ++++++++++++++++++ .../workflow/RegisterLocalModelStep.java | 2 +- .../workflow/RegisterRemoteModelStep.java | 2 +- .../flowframework/workflow/ToolStep.java | 127 +++++++++ .../flowframework/workflow/WorkflowData.java | 2 +- .../workflow/WorkflowStepFactory.java | 2 + .../resources/mappings/workflow-steps.json | 24 ++ .../model/WorkflowNodeTests.java | 22 +- .../flowframework/util/ParseUtilsTests.java | 19 ++ .../workflow/RegisterAgentTests.java | 130 ++++++++++ .../flowframework/workflow/ToolStepTests.java | 53 ++++ 20 files changed, 773 insertions(+), 22 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java create mode 100644 src/main/java/org/opensearch/flowframework/workflow/ToolStep.java create mode 100644 src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java create mode 100644 src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java diff --git a/build.gradle b/build.gradle index 20a9837ca..d6e7afbd8 100644 --- a/build.gradle +++ b/build.gradle @@ -155,6 +155,7 @@ dependencies { implementation "org.opensearch:common-utils:${common_utils_version}" implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.1' implementation 'org.bouncycastle:bcprov-jdk18on:1.77' + implementation "com.google.code.gson:gson:2.10.1" // ZipArchive dependencies used for integration tests zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}" diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 90f208c8d..dbfc17891 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -162,5 +162,20 @@ private CommonValue() {} public static final String RESOURCE_ID_FIELD = "resource_id"; /** The field name for the ResourceCreated's resource name */ public static final String WORKFLOW_STEP_NAME = "workflow_step_name"; - + /** LLM Name for registering an agent */ + public static final String LLM_FIELD = "llm"; + /** The tools' field for an agent */ + public static final String TOOLS_FIELD = "tools"; + /** The memory field for an agent */ + public static final String MEMORY_FIELD = "memory"; + /** The app type field for an agent */ + public static final String APP_TYPE_FIELD = "app_type"; + /** The agent id of an agent */ + public static final String AGENT_ID = "agent_id"; + /** To include field for an agent response */ + public static final String INCLUDE_OUTPUT_IN_AGENT_RESPONSE = "include_output_in_agent_response"; + /** The created time field for an agent */ + public static final String CREATED_TIME = "created_time"; + /** The last updated time field for an agent */ + public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; } diff --git a/src/main/java/org/opensearch/flowframework/model/Template.java b/src/main/java/org/opensearch/flowframework/model/Template.java index bfb40b696..f4a8b1958 100644 --- a/src/main/java/org/opensearch/flowframework/model/Template.java +++ b/src/main/java/org/opensearch/flowframework/model/Template.java @@ -229,12 +229,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws for (Entry e : workflows.entrySet()) { xContentBuilder.field(e.getKey(), e.getValue(), params); } + xContentBuilder.endObject(); if (uiMetadata != null && !uiMetadata.isEmpty()) { xContentBuilder.field(UI_METADATA_FIELD, uiMetadata); } - xContentBuilder.endObject(); if (user != null) { xContentBuilder.field(USER_FIELD, user); } diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java index 7d04a5a3f..999ba460f 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java @@ -14,6 +14,7 @@ import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowData; import org.opensearch.flowframework.workflow.WorkflowStep; +import org.opensearch.ml.common.agent.LLMSpec; import java.io.IOException; import java.util.ArrayList; @@ -24,7 +25,10 @@ import java.util.Objects; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.flowframework.common.CommonValue.LLM_FIELD; +import static org.opensearch.flowframework.util.ParseUtils.buildLLMMap; import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap; +import static org.opensearch.flowframework.util.ParseUtils.parseLLM; import static org.opensearch.flowframework.util.ParseUtils.parseStringToStringMap; /** @@ -34,7 +38,6 @@ * and its inputs are used to populate the {@link WorkflowData} input. */ public class WorkflowNode implements ToXContentObject { - /** The template field name for node id */ public static final String ID_FIELD = "id"; /** The template field name for node type */ @@ -82,7 +85,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws xContentBuilder.startObject(USER_INPUTS_FIELD); for (Entry e : userInputs.entrySet()) { xContentBuilder.field(e.getKey()); - if (e.getValue() instanceof String) { + if (e.getValue() instanceof String || e.getValue() instanceof Number) { xContentBuilder.value(e.getValue()); } else if (e.getValue() instanceof Map) { buildStringToStringMap(xContentBuilder, (Map) e.getValue()); @@ -98,6 +101,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } } xContentBuilder.endArray(); + } else if (e.getValue() instanceof LLMSpec) { + if (LLM_FIELD.equals(e.getKey())) { + xContentBuilder.startObject(); + buildLLMMap(xContentBuilder, (LLMSpec) e.getValue()); + xContentBuilder.endObject(); + } } } xContentBuilder.endObject(); @@ -141,7 +150,11 @@ public static WorkflowNode parse(XContentParser parser) throws IOException { userInputs.put(inputFieldName, parser.text()); break; case START_OBJECT: - userInputs.put(inputFieldName, parseStringToStringMap(parser)); + if (LLM_FIELD.equals(inputFieldName)) { + userInputs.put(inputFieldName, parseLLM(parser)); + } else { + userInputs.put(inputFieldName, parseStringToStringMap(parser)); + } break; case START_ARRAY: if (PROCESSORS_FIELD.equals(inputFieldName)) { @@ -158,6 +171,22 @@ public static WorkflowNode parse(XContentParser parser) throws IOException { userInputs.put(inputFieldName, mapList.toArray(new Map[0])); } break; + case VALUE_NUMBER: + switch (parser.numberType()) { + case INT: + userInputs.put(inputFieldName, parser.intValue()); + break; + case LONG: + userInputs.put(inputFieldName, parser.longValue()); + break; + case FLOAT: + userInputs.put(inputFieldName, parser.floatValue()); + break; + case DOUBLE: + userInputs.put(inputFieldName, parser.doubleValue()); + break; + } + break; default: throw new IOException("Unable to parse field [" + inputFieldName + "] in a node object."); } diff --git a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java index 0f725f687..14d113e34 100644 --- a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java @@ -8,6 +8,7 @@ */ package org.opensearch.flowframework.util; +import com.google.gson.Gson; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.client.Client; @@ -21,14 +22,21 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.agent.LLMSpec; import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; import java.time.Instant; import java.util.HashMap; import java.util.Map; import java.util.Map.Entry; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; +import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; +import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; /** * Utility methods for Template parsing @@ -36,6 +44,12 @@ public class ParseUtils { private static final Logger logger = LogManager.getLogger(ParseUtils.class); + public static final Gson gson; + + static { + gson = new Gson(); + } + private ParseUtils() {} /** @@ -70,6 +84,21 @@ public static void buildStringToStringMap(XContentBuilder xContentBuilder, Map parameters = llm.getParameters(); + xContentBuilder.field(MODEL_ID, modelId); + xContentBuilder.field(PARAMETERS_FIELD); + buildStringToStringMap(xContentBuilder, parameters); + } + /** * Parses an XContent object representing a map of String keys to String values. * @@ -88,6 +117,37 @@ public static Map parseStringToStringMap(XContentParser parser) return map; } + // TODO Figure out a way to use the parse method of LLMSpec of ml-commons + /** + * Parses an XContent object representing the object of LLMSpec + * @param parser An XContent parser whose position is at the start of the map object to parse + * @return instance of {@link org.opensearch.ml.common.agent.LLMSpec} + * @throws IOException parsing error + */ + public static LLMSpec parseLLM(XContentParser parser) throws IOException { + String modelId = null; + Map parameters = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case MODEL_ID: + modelId = parser.text(); + break; + case PARAMETERS_FIELD: + parameters = getParameterMap(parser.map()); + break; + default: + parser.skipChildren(); + break; + } + } + return LLMSpec.builder().modelId(modelId).parameters(parameters).build(); + } + /** * Parse content parser to {@link java.time.Instant}. * @@ -116,6 +176,31 @@ public static User getUserContext(Client client) { return User.parse(userStr); } + /** + * Generates a parameter map required when the parameter is nested within an object + * @param parameterObjs parameters + * @return a parameters map of type String,String + */ + public static Map getParameterMap(Map parameterObjs) { + Map parameters = new HashMap<>(); + for (String key : parameterObjs.keySet()) { + Object value = parameterObjs.get(key); + try { + AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + if (value instanceof String) { + parameters.put(key, (String) value); + } else { + parameters.put(key, gson.toJson(value)); + } + return null; + }); + } catch (PrivilegedActionException e) { + throw new RuntimeException(e); + } + } + return parameters; + } + /** * Creates a XContentParser from a given Registry * @@ -129,4 +214,18 @@ public static XContentParser createXContentParserFromRegistry(NamedXContentRegis return XContentHelper.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, bytesReference, XContentType.JSON); } + /** + * Generates a string to string Map + * @param map content map + * @param fieldName fieldName + * @return instance of the map + */ + @SuppressWarnings("unchecked") + public static Map getStringToStringMap(Object map, String fieldName) { + if (map instanceof Map) { + return (Map) map; + } + throw new IllegalArgumentException("[" + fieldName + "] must be a key-value map."); + } + } diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index b00857ff6..bc4132087 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -49,6 +49,7 @@ import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD; import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; +import static org.opensearch.flowframework.util.ParseUtils.getStringToStringMap; /** * Step to create a connector for a remote model @@ -210,14 +211,6 @@ public String getName() { return NAME; } - @SuppressWarnings("unchecked") - private static Map getStringToStringMap(Object map, String fieldName) { - if (map instanceof Map) { - return (Map) map; - } - throw new IllegalArgumentException("[" + fieldName + "] must be a key-value map."); - } - private static Map getParameterMap(Object parameterMap) throws PrivilegedActionException { Map parameters = new HashMap<>(); for (Entry entry : getStringToStringMap(parameterMap, PARAMETERS_FIELD).entrySet()) { diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java index f443e9c2c..07246134a 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java @@ -35,8 +35,8 @@ public class CreateIndexStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(CreateIndexStep.class); - private ClusterService clusterService; - private Client client; + private final ClusterService clusterService; + private final Client client; /** The name of this step, used as a key in the template and the {@link WorkflowStepFactory} */ static final String NAME = "create_index"; diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index aa6768605..81409ef77 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -30,7 +30,7 @@ public class DeployModelStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(DeployModelStep.class); - private MachineLearningNodeClient mlClient; + private final MachineLearningNodeClient mlClient; static final String NAME = "deploy_model"; /** diff --git a/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java b/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java index 325c3edb8..22c6ae810 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java @@ -40,7 +40,7 @@ public class ModelGroupStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(ModelGroupStep.class); - private MachineLearningNodeClient mlClient; + private final MachineLearningNodeClient mlClient; static final String NAME = "model_group"; diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java new file mode 100644 index 000000000..44270d8e6 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java @@ -0,0 +1,241 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.agent.LLMSpec; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.agent.MLAgent.MLAgentBuilder; +import org.opensearch.ml.common.agent.MLMemorySpec; +import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Stream; + +import static org.opensearch.flowframework.common.CommonValue.AGENT_ID; +import static org.opensearch.flowframework.common.CommonValue.APP_TYPE_FIELD; +import static org.opensearch.flowframework.common.CommonValue.CREATED_TIME; +import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; +import static org.opensearch.flowframework.common.CommonValue.LAST_UPDATED_TIME_FIELD; +import static org.opensearch.flowframework.common.CommonValue.LLM_FIELD; +import static org.opensearch.flowframework.common.CommonValue.MEMORY_FIELD; +import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; +import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.TYPE; +import static org.opensearch.flowframework.util.ParseUtils.getStringToStringMap; + +/** + * Step to register an agent + */ +public class RegisterAgentStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(RegisterAgentStep.class); + + private MachineLearningNodeClient mlClient; + + static final String NAME = "register_agent"; + + private List mlToolSpecList; + + /** + * Instantiate this class + * @param mlClient client to instantiate MLClient + */ + public RegisterAgentStep(MachineLearningNodeClient mlClient) { + this.mlClient = mlClient; + this.mlToolSpecList = new ArrayList<>(); + } + + @Override + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) throws IOException { + + CompletableFuture registerAgentModelFuture = new CompletableFuture<>(); + + ActionListener actionListener = new ActionListener<>() { + @Override + public void onResponse(MLRegisterAgentResponse mlRegisterAgentResponse) { + logger.info("Remote Agent registration successful"); + registerAgentModelFuture.complete( + new WorkflowData( + Map.ofEntries(Map.entry(AGENT_ID, mlRegisterAgentResponse.getAgentId())), + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() + ) + ); + } + + @Override + public void onFailure(Exception e) { + logger.error("Failed to register the agent"); + registerAgentModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } + }; + + String name = null; + String type = null; + String description = null; + LLMSpec llm = null; + List tools = new ArrayList<>(); + Map parameters = Collections.emptyMap(); + MLMemorySpec memory = null; + Instant createdTime = null; + Instant lastUpdateTime = null; + String appType = null; + + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + + for (WorkflowData workflowData : data) { + Map content = workflowData.getContent(); + + for (Entry entry : content.entrySet()) { + switch (entry.getKey()) { + case NAME_FIELD: + name = (String) entry.getValue(); + break; + case DESCRIPTION_FIELD: + description = (String) entry.getValue(); + break; + case TYPE: + type = (String) entry.getValue(); + break; + case LLM_FIELD: + llm = getLLMSpec(entry.getValue()); + break; + case TOOLS_FIELD: + tools = addTools(entry.getValue()); + break; + case PARAMETERS_FIELD: + parameters = getStringToStringMap(entry.getValue(), PARAMETERS_FIELD); + break; + case MEMORY_FIELD: + memory = getMLMemorySpec(entry.getValue()); + break; + case CREATED_TIME: + createdTime = Instant.ofEpochMilli((Long) entry.getValue()); + break; + case LAST_UPDATED_TIME_FIELD: + lastUpdateTime = Instant.ofEpochMilli((Long) entry.getValue()); + break; + case APP_TYPE_FIELD: + appType = (String) entry.getValue(); + break; + default: + break; + } + } + } + + if (Stream.of(name, type, llm, tools, parameters, memory, appType).allMatch(x -> x != null)) { + MLAgentBuilder builder = MLAgent.builder().name(name); + + if (description != null) { + builder.description(description); + } + + builder.type(type) + .llm(llm) + .tools(tools) + .parameters(parameters) + .memory(memory) + .createdTime(createdTime) + .lastUpdateTime(lastUpdateTime) + .appType(appType); + + MLAgent mlAgent = builder.build(); + + mlClient.registerAgent(mlAgent, actionListener); + + } else { + registerAgentModelFuture.completeExceptionally( + new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST) + ); + } + + return registerAgentModelFuture; + } + + @Override + public String getName() { + return NAME; + } + + private List addTools(Object tools) { + MLToolSpec mlToolSpec = (MLToolSpec) tools; + mlToolSpecList.add(mlToolSpec); + return mlToolSpecList; + } + + private LLMSpec getLLMSpec(Object llm) { + if (llm instanceof LLMSpec) { + return (LLMSpec) llm; + } + throw new IllegalArgumentException("[" + LLM_FIELD + "] must be of type LLMSpec."); + } + + private MLMemorySpec getMLMemorySpec(Object mlMemory) { + + Map map = (Map) mlMemory; + String type = null; + String sessionId = null; + Integer windowSize = null; + type = (String) map.get(MLMemorySpec.MEMORY_TYPE_FIELD); + if (type == null) { + throw new IllegalArgumentException("agent name is null"); + } + sessionId = (String) map.get(MLMemorySpec.SESSION_ID_FIELD); + windowSize = (Integer) map.get(MLMemorySpec.WINDOW_SIZE_FIELD); + + @SuppressWarnings("unchecked") + MLMemorySpec.MLMemorySpecBuilder builder = MLMemorySpec.builder(); + + builder.type(type); + if (sessionId != null) { + builder.sessionId(sessionId); + } + if (windowSize != null) { + builder.windowSize(windowSize); + } + + MLMemorySpec mlMemorySpec = builder.build(); + return mlMemorySpec; + + } + + private Instant getInstant(Object instant, String fieldName) { + if (instant instanceof Instant) { + return (Instant) instant; + } + throw new IllegalArgumentException("[" + fieldName + "] must be of type Instant."); + } + +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java index 19229efd1..cc5645306 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java @@ -56,7 +56,7 @@ public class RegisterLocalModelStep extends AbstractRetryableWorkflowStep { private static final Logger logger = LogManager.getLogger(RegisterLocalModelStep.class); - private MachineLearningNodeClient mlClient; + private final MachineLearningNodeClient mlClient; static final String NAME = "register_local_model"; diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index e41323a14..27a77cb98 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -43,7 +43,7 @@ public class RegisterRemoteModelStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(RegisterRemoteModelStep.class); - private MachineLearningNodeClient mlClient; + private final MachineLearningNodeClient mlClient; static final String NAME = "register_remote_model"; diff --git a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java new file mode 100644 index 000000000..339142139 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java @@ -0,0 +1,127 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.ml.common.agent.MLToolSpec; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.concurrent.CompletableFuture; + +import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; +import static org.opensearch.flowframework.common.CommonValue.INCLUDE_OUTPUT_IN_AGENT_RESPONSE; +import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; +import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.TYPE; +import static org.opensearch.flowframework.util.ParseUtils.getStringToStringMap; + +/** + * Step to register a tool for an agent + */ +public class ToolStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(ToolStep.class); + CompletableFuture toolFuture = new CompletableFuture<>(); + static final String NAME = "create_tool"; + + @Override + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) throws IOException { + String type = null; + String name = null; + String description = null; + Map parameters = Collections.emptyMap(); + Boolean includeOutputInAgentResponse = null; + + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + + for (WorkflowData workflowData : data) { + Map content = workflowData.getContent(); + + for (Entry entry : content.entrySet()) { + switch (entry.getKey()) { + case TYPE: + type = (String) content.get(TYPE); + break; + case NAME_FIELD: + name = (String) content.get(NAME_FIELD); + break; + case DESCRIPTION_FIELD: + description = (String) content.get(DESCRIPTION_FIELD); + break; + case PARAMETERS_FIELD: + parameters = getStringToStringMap(content.get(PARAMETERS_FIELD), PARAMETERS_FIELD); + break; + case INCLUDE_OUTPUT_IN_AGENT_RESPONSE: + includeOutputInAgentResponse = (Boolean) content.get(INCLUDE_OUTPUT_IN_AGENT_RESPONSE); + break; + default: + break; + } + + } + + } + + if (type == null) { + toolFuture.completeExceptionally(new FlowFrameworkException("Tool type is not provided", RestStatus.BAD_REQUEST)); + } else { + MLToolSpec.MLToolSpecBuilder builder = MLToolSpec.builder(); + + builder.type(type); + if (name != null) { + builder.name(name); + } + if (description != null) { + builder.description(description); + } + if (parameters != null) { + builder.parameters(parameters); + } + if (includeOutputInAgentResponse != null) { + builder.includeOutputInAgentResponse(includeOutputInAgentResponse); + } + + MLToolSpec mlToolSpec = builder.build(); + + toolFuture.complete( + new WorkflowData( + Map.ofEntries(Map.entry(TOOLS_FIELD, mlToolSpec)), + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() + ) + ); + } + + logger.info("Tool registered successfully {}", type); + return toolFuture; + } + + @Override + public String getName() { + return NAME; + } +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java index a0d901f74..ba19823a7 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java @@ -66,7 +66,7 @@ public WorkflowData(Map content, Map params, @Nu */ public Map getContent() { return this.content; - }; + } /** * Returns a map represents the params associated with a Rest API request, parsed from the URI. diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index b95a0449d..c9e565bba 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -62,6 +62,8 @@ private void populateMap( stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient)); stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)); stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient)); + stepMap.put(ToolStep.NAME, new ToolStep()); + stepMap.put(RegisterAgentStep.NAME, new RegisterAgentStep(mlClient)); } /** diff --git a/src/main/resources/mappings/workflow-steps.json b/src/main/resources/mappings/workflow-steps.json index 6256189c1..c9794d4ea 100644 --- a/src/main/resources/mappings/workflow-steps.json +++ b/src/main/resources/mappings/workflow-steps.json @@ -83,5 +83,29 @@ "model_group_id", "model_group_status" ] + }, + "register_agent": { + "inputs":[ + "name", + "type", + "llm", + "tools", + "parameters", + "memory", + "created_time", + "last_updated_time", + "app_type" + ], + "outputs":[ + "agent_id" + ] + }, + "create_tool": { + "inputs": [ + "type" + ], + "outputs": [ + "tools" + ] } } diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java index 700e1d0d2..c0011f7ae 100644 --- a/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java @@ -8,9 +8,11 @@ */ package org.opensearch.flowframework.model; +import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; +import java.util.HashMap; import java.util.Map; public class WorkflowNodeTests extends OpenSearchTestCase { @@ -21,6 +23,12 @@ public void setUp() throws Exception { } public void testNode() throws IOException { + Map parameters = new HashMap<>(); + parameters.put("stop", "true"); + parameters.put("max", "5"); + + LLMSpec llmSpec = new LLMSpec("modelId", parameters); + WorkflowNode nodeA = new WorkflowNode( "A", "a-type", @@ -29,7 +37,9 @@ public void testNode() throws IOException { Map.entry("foo", "a string"), Map.entry("bar", Map.of("key", "value")), Map.entry("baz", new Map[] { Map.of("A", "a"), Map.of("B", "b") }), - Map.entry("processors", new PipelineProcessor[] { new PipelineProcessor("test-type", Map.of("key2", "value2")) }) + Map.entry("processors", new PipelineProcessor[] { new PipelineProcessor("test-type", Map.of("key2", "value2")) }), + Map.entry("llm", llmSpec), + Map.entry("created_time", 1689793598499L) ) ); assertEquals("A", nodeA.id()); @@ -43,6 +53,7 @@ public void testNode() throws IOException { assertEquals(1, pp.length); assertEquals("test-type", pp[0].type()); assertEquals(Map.of("key2", "value2"), pp[0].params()); + assertEquals(1689793598499L, map.get("created_time")); // node equality is based only on ID WorkflowNode nodeA2 = new WorkflowNode("A", "a2-type", Map.of(), Map.of("bar", "baz")); @@ -52,13 +63,17 @@ public void testNode() throws IOException { assertNotEquals(nodeA, nodeB); String json = TemplateTestJsonUtil.parseToJson(nodeA); - logger.info("TESTING : " + json); + logger.info("JSON : " + json); assertTrue(json.startsWith("{\"id\":\"A\",\"type\":\"a-type\",\"previous_node_inputs\":{\"foo\":\"field\"},")); assertTrue(json.contains("\"user_inputs\":{")); assertTrue(json.contains("\"foo\":\"a string\"")); assertTrue(json.contains("\"baz\":[{\"A\":\"a\"},{\"B\":\"b\"}]")); assertTrue(json.contains("\"bar\":{\"key\":\"value\"}")); assertTrue(json.contains("\"processors\":[{\"type\":\"test-type\",\"params\":{\"key2\":\"value2\"}}]")); + assertTrue(json.contains("\"created_time\":1689793598499")); + assertTrue(json.contains("llm\":{")); + assertTrue(json.contains("\"parameters\":{\"stop\":\"true\",\"max\":\"5\"")); + assertTrue(json.contains("\"model_id\":\"modelId\"")); WorkflowNode nodeX = WorkflowNode.parse(TemplateTestJsonUtil.jsonToParser(json)); assertEquals("A", nodeX.id()); @@ -73,6 +88,9 @@ public void testNode() throws IOException { assertEquals(1, ppX.length); assertEquals("test-type", ppX[0].type()); assertEquals(Map.of("key2", "value2"), ppX[0].params()); + LLMSpec llm = (LLMSpec) mapX.get("llm"); + assertEquals("modelId", llm.getModelId()); + assertEquals(parameters, llm.getParameters()); } public void testExceptions() throws IOException { diff --git a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java index a5c4253b3..76334b52b 100644 --- a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java +++ b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java @@ -12,9 +12,12 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.test.OpenSearchTestCase; +import org.junit.Assert; import java.io.IOException; import java.time.Instant; +import java.util.HashMap; +import java.util.Map; public class ParseUtilsTests extends OpenSearchTestCase { public void testToInstant() throws IOException { @@ -54,4 +57,20 @@ public void testToInstantWithNotValue() throws IOException { Instant instant = ParseUtils.parseInstant(parser); assertNull(instant); } + + public void testGetParameterMap() { + Map parameters = new HashMap<>(); + parameters.put("key1", "value1"); + parameters.put("key2", 2); + parameters.put("key3", 2.1); + parameters.put("key4", new int[] { 10, 20 }); + parameters.put("key5", new Object[] { 1.01, "abc" }); + Map parameterMap = ParseUtils.getParameterMap(parameters); + Assert.assertEquals(5, parameterMap.size()); + Assert.assertEquals("value1", parameterMap.get("key1")); + Assert.assertEquals("2", parameterMap.get("key2")); + Assert.assertEquals("2.1", parameterMap.get("key3")); + Assert.assertEquals("[10,20]", parameterMap.get("key4")); + Assert.assertEquals("[1.01,\"abc\"]", parameterMap.get("key5")); + } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java new file mode 100644 index 000000000..0f4b33471 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java @@ -0,0 +1,130 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.agent.LLMSpec; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.agent.MLMemorySpec; +import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; + +public class RegisterAgentTests extends OpenSearchTestCase { + private WorkflowData inputData = WorkflowData.EMPTY; + + @Mock + MachineLearningNodeClient machineLearningNodeClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + + MockitoAnnotations.openMocks(this); + + MLToolSpec tools = new MLToolSpec("tool1", "CatIndexTool", "desc", Collections.emptyMap(), false); + + LLMSpec llmSpec = new LLMSpec("xyz", Collections.emptyMap()); + + Map mlMemorySpec = Map.ofEntries( + Map.entry(MLMemorySpec.MEMORY_TYPE_FIELD, "type"), + Map.entry(MLMemorySpec.SESSION_ID_FIELD, "abc"), + Map.entry(MLMemorySpec.WINDOW_SIZE_FIELD, 2) + ); + + inputData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "test"), + Map.entry("description", "description"), + Map.entry("type", "type"), + Map.entry("llm", llmSpec), + Map.entry("tools", tools), + Map.entry("parameters", Collections.emptyMap()), + Map.entry("memory", mlMemorySpec), + Map.entry("created_time", 1689793598499L), + Map.entry("last_updated_time", 1689793598499L), + Map.entry("app_type", "app") + ), + "test-id", + "test-node-id" + ); + } + + public void testRegisterAgent() throws IOException, ExecutionException, InterruptedException { + String agentId = "agent_id"; + RegisterAgentStep registerAgentStep = new RegisterAgentStep(machineLearningNodeClient); + + @SuppressWarnings("unchecked") + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLRegisterAgentResponse output = new MLRegisterAgentResponse(agentId); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).registerAgent(any(MLAgent.class), actionListenerCaptor.capture()); + + CompletableFuture future = registerAgentStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); + + verify(machineLearningNodeClient).registerAgent(any(MLAgent.class), actionListenerCaptor.capture()); + + assertTrue(future.isDone()); + assertEquals(agentId, future.get().getContent().get("agent_id")); + } + + public void testRegisterAgentFailure() throws IOException { + String agentId = "agent_id"; + RegisterAgentStep registerAgentStep = new RegisterAgentStep(machineLearningNodeClient); + + @SuppressWarnings("unchecked") + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new FlowFrameworkException("Failed to register the agent", RestStatus.INTERNAL_SERVER_ERROR)); + return null; + }).when(machineLearningNodeClient).registerAgent(any(MLAgent.class), actionListenerCaptor.capture()); + + CompletableFuture future = registerAgentStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); + + verify(machineLearningNodeClient).registerAgent(any(MLAgent.class), actionListenerCaptor.capture()); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to register the agent", ex.getCause().getMessage()); + } +} diff --git a/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java new file mode 100644 index 000000000..c7e8df2d8 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +public class ToolStepTests extends OpenSearchTestCase { + private WorkflowData inputData = WorkflowData.EMPTY; + + @Override + public void setUp() throws Exception { + super.setUp(); + + inputData = new WorkflowData( + Map.ofEntries( + Map.entry("type", "type"), + Map.entry("name", "name"), + Map.entry("description", "description"), + Map.entry("parameters", Collections.emptyMap()), + Map.entry("include_output_in_agent_response", false) + ), + "test-id", + "test-node-id" + ); + } + + public void testTool() throws IOException, ExecutionException, InterruptedException { + ToolStep toolStep = new ToolStep(); + + CompletableFuture future = toolStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + assertEquals(MLToolSpec.class, future.get().getContent().get("tools").getClass()); + } +}