From 540f594d449b6dfc7e85e80431d924a1778712fd Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Mon, 4 Dec 2023 09:56:45 -0800 Subject: [PATCH] [Feature/agent_framework] Registers root agent with an agentId in ToolSteps (#242) * Add agentId to parameters map for root agent Signed-off-by: Owais Kazi * Modified ToolStep with new Util method Signed-off-by: Owais Kazi * Integrated RegisterAgentStep with new Util method Signed-off-by: Owais Kazi * Spotless fixes Signed-off-by: Owais Kazi * Removed TODO Signed-off-by: Owais Kazi --------- Signed-off-by: Owais Kazi --- .../workflow/RegisterAgentStep.java | 155 ++++++++---------- .../flowframework/workflow/ToolStep.java | 94 +++++------ 2 files changed, 112 insertions(+), 137 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java index 022d46c22..e055433b0 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java @@ -14,6 +14,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; @@ -25,13 +26,12 @@ import java.io.IOException; import java.time.Instant; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Map.Entry; import java.util.Optional; +import java.util.Set; import java.util.concurrent.CompletableFuture; -import java.util.stream.Stream; +import java.util.stream.Collectors; import static org.opensearch.flowframework.common.CommonValue.AGENT_ID; import static org.opensearch.flowframework.common.CommonValue.APP_TYPE_FIELD; @@ -101,84 +101,55 @@ public void onFailure(Exception e) { } }; - String name = null; - String type = null; - String description = null; - String llmModelId = null; - Map llmParameters = Collections.emptyMap(); - List tools = new ArrayList<>(); - Map parameters = Collections.emptyMap(); - MLMemorySpec memory = null; - Instant createdTime = null; - Instant lastUpdateTime = null; - String appType = null; - - // TODO: Recreating the list to get this compiling - // Need to refactor the below iteration to pull directly from the maps - List data = new ArrayList<>(); - data.add(currentNodeInputs); - data.addAll(outputs.values()); - - for (WorkflowData workflowData : data) { - Map content = workflowData.getContent(); - - for (Entry entry : content.entrySet()) { - switch (entry.getKey()) { - case NAME_FIELD: - name = (String) entry.getValue(); - break; - case DESCRIPTION_FIELD: - description = (String) entry.getValue(); - break; - case TYPE: - type = (String) entry.getValue(); - break; - case LLM_MODEL_ID: - llmModelId = (String) entry.getValue(); - break; - case LLM_PARAMETERS: - llmParameters = getStringToStringMap(entry.getValue(), LLM_PARAMETERS); - break; - case TOOLS_FIELD: - tools = addTools(entry.getValue()); - break; - case PARAMETERS_FIELD: - parameters = getStringToStringMap(entry.getValue(), PARAMETERS_FIELD); - break; - case MEMORY_FIELD: - memory = getMLMemorySpec(entry.getValue()); - break; - case CREATED_TIME: - createdTime = Instant.ofEpochMilli((Long) entry.getValue()); - break; - case LAST_UPDATED_TIME_FIELD: - lastUpdateTime = Instant.ofEpochMilli((Long) entry.getValue()); - break; - case APP_TYPE_FIELD: - appType = (String) entry.getValue(); - break; - default: - break; - } - } - } + Set requiredKeys = Set.of(NAME_FIELD, TYPE); + Set optionalKeys = Set.of( + DESCRIPTION_FIELD, + LLM_MODEL_ID, + LLM_PARAMETERS, + TOOLS_FIELD, + PARAMETERS_FIELD, + MEMORY_FIELD, + CREATED_TIME, + LAST_UPDATED_TIME_FIELD, + APP_TYPE_FIELD + ); + + try { + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs + ); - // Case when modelId is present in previous node inputs - if (llmModelId == null) { - llmModelId = getLlmModelId(previousNodeInputs, outputs); - } + String type = (String) inputs.get(TYPE); + String name = (String) inputs.get(NAME_FIELD); + String description = (String) inputs.get(DESCRIPTION_FIELD); + String llmModelId = (String) inputs.get(LLM_MODEL_ID); + Map llmParameters = getStringToStringMap(inputs.get(PARAMETERS_FIELD), LLM_PARAMETERS); + List tools = getTools(previousNodeInputs, outputs); + Map parameters = getStringToStringMap(inputs.get(PARAMETERS_FIELD), PARAMETERS_FIELD); + MLMemorySpec memory = getMLMemorySpec(inputs.get(MEMORY_FIELD)); + Instant createdTime = Instant.ofEpochMilli((Long) inputs.get(CREATED_TIME)); + Instant lastUpdateTime = Instant.ofEpochMilli((Long) inputs.get(LAST_UPDATED_TIME_FIELD)); + String appType = (String) inputs.get(APP_TYPE_FIELD); + + // Case when modelId is present in previous node inputs + if (llmModelId == null) { + llmModelId = getLlmModelId(previousNodeInputs, outputs); + } - // Case when modelId is not present at all - if (llmModelId == null) { - registerAgentModelFuture.completeExceptionally( - new FlowFrameworkException("llm model id is not provided", RestStatus.BAD_REQUEST) - ); - return registerAgentModelFuture; - } + // Case when modelId is not present at all + if (llmModelId == null) { + registerAgentModelFuture.completeExceptionally( + new FlowFrameworkException("llm model id is not provided", RestStatus.BAD_REQUEST) + ); + return registerAgentModelFuture; + } - LLMSpec llmSpec = getLLMSpec(llmModelId, llmParameters); + LLMSpec llmSpec = getLLMSpec(llmModelId, llmParameters); - if (Stream.of(name, type, llmSpec).allMatch(x -> x != null)) { MLAgentBuilder builder = MLAgent.builder().name(name); if (description != null) { @@ -198,12 +169,9 @@ public void onFailure(Exception e) { mlClient.registerAgent(mlAgent, actionListener); - } else { - registerAgentModelFuture.completeExceptionally( - new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST) - ); + } catch (FlowFrameworkException e) { + registerAgentModelFuture.completeExceptionally(e); } - return registerAgentModelFuture; } @@ -212,9 +180,24 @@ public String getName() { return NAME; } - private List addTools(Object tools) { - MLToolSpec mlToolSpec = (MLToolSpec) tools; - mlToolSpecList.add(mlToolSpec); + private List getTools(Map previousNodeInputs, Map outputs) { + List mlToolSpecList = new ArrayList<>(); + List previousNodes = previousNodeInputs.entrySet() + .stream() + .filter(e -> TOOLS_FIELD.equals(e.getValue())) + .map(Map.Entry::getKey) + .collect(Collectors.toList()); + + if (previousNodes != null) { + previousNodes.forEach((previousNode) -> { + WorkflowData previousNodeOutput = outputs.get(previousNode); + if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(TOOLS_FIELD)) { + MLToolSpec mlToolSpec = (MLToolSpec) previousNodeOutput.getContent().get(TOOLS_FIELD); + logger.info("Tool added {}", mlToolSpec.getType()); + mlToolSpecList.add(mlToolSpec); + } + }); + } return mlToolSpecList; } @@ -240,7 +223,7 @@ private String getLlmModelId(Map previousNodeInputs, Map llmParameters) { if (llmModelId == null) { - throw new IllegalArgumentException("model id for llm is null"); + throw new FlowFrameworkException("model id for llm is null", RestStatus.BAD_REQUEST); } LLMSpec.LLMSpecBuilder builder = LLMSpec.builder(); builder.modelId(llmModelId); diff --git a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java index af8556289..f12d9848e 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java @@ -10,19 +10,17 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.common.agent.MLToolSpec; import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; import java.util.Map; -import java.util.Map.Entry; import java.util.Optional; +import java.util.Set; import java.util.concurrent.CompletableFuture; +import static org.opensearch.flowframework.common.CommonValue.AGENT_ID; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; import static org.opensearch.flowframework.common.CommonValue.INCLUDE_OUTPUT_IN_AGENT_RESPONSE; import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; @@ -47,49 +45,25 @@ public CompletableFuture execute( Map outputs, Map previousNodeInputs ) throws IOException { - String type = null; - String name = null; - String description = null; - Map parameters = Collections.emptyMap(); - Boolean includeOutputInAgentResponse = null; - - // TODO: Recreating the list to get this compiling - // Need to refactor the below iteration to pull directly from the maps - List data = new ArrayList<>(); - data.add(currentNodeInputs); - data.addAll(outputs.values()); - - for (WorkflowData workflowData : data) { - Map content = workflowData.getContent(); - - for (Entry entry : content.entrySet()) { - switch (entry.getKey()) { - case TYPE: - type = (String) entry.getValue(); - break; - case NAME_FIELD: - name = (String) entry.getValue(); - break; - case DESCRIPTION_FIELD: - description = (String) entry.getValue(); - break; - case PARAMETERS_FIELD: - parameters = getToolsParametersMap(entry.getValue(), previousNodeInputs, outputs); - break; - case INCLUDE_OUTPUT_IN_AGENT_RESPONSE: - includeOutputInAgentResponse = (Boolean) entry.getValue(); - break; - default: - break; - } - } + Set requiredKeys = Set.of(TYPE); + Set optionalKeys = Set.of(NAME_FIELD, DESCRIPTION_FIELD, PARAMETERS_FIELD, INCLUDE_OUTPUT_IN_AGENT_RESPONSE); - } + try { + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs + ); + + String type = (String) inputs.get(TYPE); + String name = (String) inputs.get(NAME_FIELD); + String description = (String) inputs.get(DESCRIPTION_FIELD); + Boolean includeOutputInAgentResponse = (Boolean) inputs.get(INCLUDE_OUTPUT_IN_AGENT_RESPONSE); + Map parameters = getToolsParametersMap(inputs.get(PARAMETERS_FIELD), previousNodeInputs, outputs); - if (type == null) { - toolFuture.completeExceptionally(new FlowFrameworkException("Tool type is not provided", RestStatus.BAD_REQUEST)); - } else { MLToolSpec.MLToolSpecBuilder builder = MLToolSpec.builder(); builder.type(type); @@ -115,9 +89,12 @@ public CompletableFuture execute( currentNodeInputs.getNodeId() ) ); - } - logger.info("Tool registered successfully {}", type); + logger.info("Tool registered successfully {}", type); + + } catch (FlowFrameworkException e) { + toolFuture.completeExceptionally(e); + } return toolFuture; } @@ -132,19 +109,34 @@ private Map getToolsParametersMap( Map outputs ) { Map parametersMap = (Map) parameters; - Optional previousNode = previousNodeInputs.entrySet() + Optional previousNodeModel = previousNodeInputs.entrySet() .stream() .filter(e -> MODEL_ID.equals(e.getValue())) .map(Map.Entry::getKey) .findFirst(); + + Optional previousNodeAgent = previousNodeInputs.entrySet() + .stream() + .filter(e -> AGENT_ID.equals(e.getValue())) + .map(Map.Entry::getKey) + .findFirst(); + // Case when modelId is passed through previousSteps and not present already in parameters - if (previousNode.isPresent() && !parametersMap.containsKey(MODEL_ID)) { - WorkflowData previousNodeOutput = outputs.get(previousNode.get()); + if (previousNodeModel.isPresent() && !parametersMap.containsKey(MODEL_ID)) { + WorkflowData previousNodeOutput = outputs.get(previousNodeModel.get()); if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(MODEL_ID)) { parametersMap.put(MODEL_ID, previousNodeOutput.getContent().get(MODEL_ID).toString()); - return parametersMap; } } + + // Case when agentId is passed through previousSteps and not present already in parameters + if (previousNodeAgent.isPresent() && !parametersMap.containsKey(AGENT_ID)) { + WorkflowData previousNodeOutput = outputs.get(previousNodeAgent.get()); + if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(AGENT_ID)) { + parametersMap.put(AGENT_ID, previousNodeOutput.getContent().get(AGENT_ID).toString()); + } + } + // For other cases where modelId is already present in the parameters or not return the parametersMap return parametersMap; }