Skip to content

Commit

Permalink
[Feature/agent_framework] Registers root agent with an agentId in Too…
Browse files Browse the repository at this point in the history
…lSteps (#242)

* Add agentId to parameters map for root agent

Signed-off-by: Owais Kazi <[email protected]>

* Modified ToolStep with new Util method

Signed-off-by: Owais Kazi <[email protected]>

* Integrated RegisterAgentStep with new Util method

Signed-off-by: Owais Kazi <[email protected]>

* Spotless fixes

Signed-off-by: Owais Kazi <[email protected]>

* Removed TODO

Signed-off-by: Owais Kazi <[email protected]>

---------

Signed-off-by: Owais Kazi <[email protected]>
  • Loading branch information
owaiskazi19 committed Dec 4, 2023
1 parent 69b035f commit 53daf61
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 137 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.agent.LLMSpec;
import org.opensearch.ml.common.agent.MLAgent;
Expand All @@ -25,13 +26,12 @@
import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;
import java.util.stream.Collectors;

import static org.opensearch.flowframework.common.CommonValue.AGENT_ID;
import static org.opensearch.flowframework.common.CommonValue.APP_TYPE_FIELD;
Expand Down Expand Up @@ -101,84 +101,55 @@ public void onFailure(Exception e) {
}
};

String name = null;
String type = null;
String description = null;
String llmModelId = null;
Map<String, String> llmParameters = Collections.emptyMap();
List<MLToolSpec> tools = new ArrayList<>();
Map<String, String> parameters = Collections.emptyMap();
MLMemorySpec memory = null;
Instant createdTime = null;
Instant lastUpdateTime = null;
String appType = null;

// TODO: Recreating the list to get this compiling
// Need to refactor the below iteration to pull directly from the maps
List<WorkflowData> data = new ArrayList<>();
data.add(currentNodeInputs);
data.addAll(outputs.values());

for (WorkflowData workflowData : data) {
Map<String, Object> content = workflowData.getContent();

for (Entry<String, Object> entry : content.entrySet()) {
switch (entry.getKey()) {
case NAME_FIELD:
name = (String) entry.getValue();
break;
case DESCRIPTION_FIELD:
description = (String) entry.getValue();
break;
case TYPE:
type = (String) entry.getValue();
break;
case LLM_MODEL_ID:
llmModelId = (String) entry.getValue();
break;
case LLM_PARAMETERS:
llmParameters = getStringToStringMap(entry.getValue(), LLM_PARAMETERS);
break;
case TOOLS_FIELD:
tools = addTools(entry.getValue());
break;
case PARAMETERS_FIELD:
parameters = getStringToStringMap(entry.getValue(), PARAMETERS_FIELD);
break;
case MEMORY_FIELD:
memory = getMLMemorySpec(entry.getValue());
break;
case CREATED_TIME:
createdTime = Instant.ofEpochMilli((Long) entry.getValue());
break;
case LAST_UPDATED_TIME_FIELD:
lastUpdateTime = Instant.ofEpochMilli((Long) entry.getValue());
break;
case APP_TYPE_FIELD:
appType = (String) entry.getValue();
break;
default:
break;
}
}
}
Set<String> requiredKeys = Set.of(NAME_FIELD, TYPE);
Set<String> optionalKeys = Set.of(
DESCRIPTION_FIELD,
LLM_MODEL_ID,
LLM_PARAMETERS,
TOOLS_FIELD,
PARAMETERS_FIELD,
MEMORY_FIELD,
CREATED_TIME,
LAST_UPDATED_TIME_FIELD,
APP_TYPE_FIELD
);

try {
Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
requiredKeys,
optionalKeys,
currentNodeInputs,
outputs,
previousNodeInputs
);

// Case when modelId is present in previous node inputs
if (llmModelId == null) {
llmModelId = getLlmModelId(previousNodeInputs, outputs);
}
String type = (String) inputs.get(TYPE);
String name = (String) inputs.get(NAME_FIELD);
String description = (String) inputs.get(DESCRIPTION_FIELD);
String llmModelId = (String) inputs.get(LLM_MODEL_ID);
Map<String, String> llmParameters = getStringToStringMap(inputs.get(PARAMETERS_FIELD), LLM_PARAMETERS);
List<MLToolSpec> tools = getTools(previousNodeInputs, outputs);
Map<String, String> parameters = getStringToStringMap(inputs.get(PARAMETERS_FIELD), PARAMETERS_FIELD);
MLMemorySpec memory = getMLMemorySpec(inputs.get(MEMORY_FIELD));
Instant createdTime = Instant.ofEpochMilli((Long) inputs.get(CREATED_TIME));
Instant lastUpdateTime = Instant.ofEpochMilli((Long) inputs.get(LAST_UPDATED_TIME_FIELD));
String appType = (String) inputs.get(APP_TYPE_FIELD);

// Case when modelId is present in previous node inputs
if (llmModelId == null) {
llmModelId = getLlmModelId(previousNodeInputs, outputs);
}

// Case when modelId is not present at all
if (llmModelId == null) {
registerAgentModelFuture.completeExceptionally(
new FlowFrameworkException("llm model id is not provided", RestStatus.BAD_REQUEST)
);
return registerAgentModelFuture;
}
// Case when modelId is not present at all
if (llmModelId == null) {
registerAgentModelFuture.completeExceptionally(
new FlowFrameworkException("llm model id is not provided", RestStatus.BAD_REQUEST)
);
return registerAgentModelFuture;
}

LLMSpec llmSpec = getLLMSpec(llmModelId, llmParameters);
LLMSpec llmSpec = getLLMSpec(llmModelId, llmParameters);

if (Stream.of(name, type, llmSpec).allMatch(x -> x != null)) {
MLAgentBuilder builder = MLAgent.builder().name(name);

if (description != null) {
Expand All @@ -198,12 +169,9 @@ public void onFailure(Exception e) {

mlClient.registerAgent(mlAgent, actionListener);

} else {
registerAgentModelFuture.completeExceptionally(
new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST)
);
} catch (FlowFrameworkException e) {
registerAgentModelFuture.completeExceptionally(e);
}

return registerAgentModelFuture;
}

Expand All @@ -212,9 +180,24 @@ public String getName() {
return NAME;
}

private List<MLToolSpec> addTools(Object tools) {
MLToolSpec mlToolSpec = (MLToolSpec) tools;
mlToolSpecList.add(mlToolSpec);
private List<MLToolSpec> getTools(Map<String, String> previousNodeInputs, Map<String, WorkflowData> outputs) {
List<MLToolSpec> mlToolSpecList = new ArrayList<>();
List<String> previousNodes = previousNodeInputs.entrySet()
.stream()
.filter(e -> TOOLS_FIELD.equals(e.getValue()))
.map(Map.Entry::getKey)
.collect(Collectors.toList());

if (previousNodes != null) {
previousNodes.forEach((previousNode) -> {
WorkflowData previousNodeOutput = outputs.get(previousNode);
if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(TOOLS_FIELD)) {
MLToolSpec mlToolSpec = (MLToolSpec) previousNodeOutput.getContent().get(TOOLS_FIELD);
logger.info("Tool added {}", mlToolSpec.getType());
mlToolSpecList.add(mlToolSpec);
}
});
}
return mlToolSpecList;
}

Expand All @@ -240,7 +223,7 @@ private String getLlmModelId(Map<String, String> previousNodeInputs, Map<String,

private LLMSpec getLLMSpec(String llmModelId, Map<String, String> llmParameters) {
if (llmModelId == null) {
throw new IllegalArgumentException("model id for llm is null");
throw new FlowFrameworkException("model id for llm is null", RestStatus.BAD_REQUEST);
}
LLMSpec.LLMSpecBuilder builder = LLMSpec.builder();
builder.modelId(llmModelId);
Expand Down
94 changes: 43 additions & 51 deletions src/main/java/org/opensearch/flowframework/workflow/ToolStep.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,17 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.ml.common.agent.MLToolSpec;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;

import static org.opensearch.flowframework.common.CommonValue.AGENT_ID;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.INCLUDE_OUTPUT_IN_AGENT_RESPONSE;
import static org.opensearch.flowframework.common.CommonValue.MODEL_ID;
Expand All @@ -47,49 +45,25 @@ public CompletableFuture<WorkflowData> execute(
Map<String, WorkflowData> outputs,
Map<String, String> previousNodeInputs
) throws IOException {
String type = null;
String name = null;
String description = null;
Map<String, String> parameters = Collections.emptyMap();
Boolean includeOutputInAgentResponse = null;

// TODO: Recreating the list to get this compiling
// Need to refactor the below iteration to pull directly from the maps
List<WorkflowData> data = new ArrayList<>();
data.add(currentNodeInputs);
data.addAll(outputs.values());

for (WorkflowData workflowData : data) {
Map<String, Object> content = workflowData.getContent();

for (Entry<String, Object> entry : content.entrySet()) {
switch (entry.getKey()) {
case TYPE:
type = (String) entry.getValue();
break;
case NAME_FIELD:
name = (String) entry.getValue();
break;
case DESCRIPTION_FIELD:
description = (String) entry.getValue();
break;
case PARAMETERS_FIELD:
parameters = getToolsParametersMap(entry.getValue(), previousNodeInputs, outputs);
break;
case INCLUDE_OUTPUT_IN_AGENT_RESPONSE:
includeOutputInAgentResponse = (Boolean) entry.getValue();
break;
default:
break;
}

}
Set<String> requiredKeys = Set.of(TYPE);
Set<String> optionalKeys = Set.of(NAME_FIELD, DESCRIPTION_FIELD, PARAMETERS_FIELD, INCLUDE_OUTPUT_IN_AGENT_RESPONSE);

}
try {
Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
requiredKeys,
optionalKeys,
currentNodeInputs,
outputs,
previousNodeInputs
);

String type = (String) inputs.get(TYPE);
String name = (String) inputs.get(NAME_FIELD);
String description = (String) inputs.get(DESCRIPTION_FIELD);
Boolean includeOutputInAgentResponse = (Boolean) inputs.get(INCLUDE_OUTPUT_IN_AGENT_RESPONSE);
Map<String, String> parameters = getToolsParametersMap(inputs.get(PARAMETERS_FIELD), previousNodeInputs, outputs);

if (type == null) {
toolFuture.completeExceptionally(new FlowFrameworkException("Tool type is not provided", RestStatus.BAD_REQUEST));
} else {
MLToolSpec.MLToolSpecBuilder builder = MLToolSpec.builder();

builder.type(type);
Expand All @@ -115,9 +89,12 @@ public CompletableFuture<WorkflowData> execute(
currentNodeInputs.getNodeId()
)
);
}

logger.info("Tool registered successfully {}", type);
logger.info("Tool registered successfully {}", type);

} catch (FlowFrameworkException e) {
toolFuture.completeExceptionally(e);
}
return toolFuture;
}

Expand All @@ -132,19 +109,34 @@ private Map<String, String> getToolsParametersMap(
Map<String, WorkflowData> outputs
) {
Map<String, String> parametersMap = (Map<String, String>) parameters;
Optional<String> previousNode = previousNodeInputs.entrySet()
Optional<String> previousNodeModel = previousNodeInputs.entrySet()
.stream()
.filter(e -> MODEL_ID.equals(e.getValue()))
.map(Map.Entry::getKey)
.findFirst();

Optional<String> previousNodeAgent = previousNodeInputs.entrySet()
.stream()
.filter(e -> AGENT_ID.equals(e.getValue()))
.map(Map.Entry::getKey)
.findFirst();

// Case when modelId is passed through previousSteps and not present already in parameters
if (previousNode.isPresent() && !parametersMap.containsKey(MODEL_ID)) {
WorkflowData previousNodeOutput = outputs.get(previousNode.get());
if (previousNodeModel.isPresent() && !parametersMap.containsKey(MODEL_ID)) {
WorkflowData previousNodeOutput = outputs.get(previousNodeModel.get());
if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(MODEL_ID)) {
parametersMap.put(MODEL_ID, previousNodeOutput.getContent().get(MODEL_ID).toString());
return parametersMap;
}
}

// Case when agentId is passed through previousSteps and not present already in parameters
if (previousNodeAgent.isPresent() && !parametersMap.containsKey(AGENT_ID)) {
WorkflowData previousNodeOutput = outputs.get(previousNodeAgent.get());
if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(AGENT_ID)) {
parametersMap.put(AGENT_ID, previousNodeOutput.getContent().get(AGENT_ID).toString());
}
}

// For other cases where modelId is already present in the parameters or not return the parametersMap
return parametersMap;
}
Expand Down

0 comments on commit 53daf61

Please sign in to comment.