Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature/agent_framework] Registers root agent with an agentId in ToolSteps #242

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.agent.LLMSpec;
import org.opensearch.ml.common.agent.MLAgent;
Expand All @@ -25,13 +26,12 @@
import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;
import java.util.stream.Collectors;

import static org.opensearch.flowframework.common.CommonValue.AGENT_ID;
import static org.opensearch.flowframework.common.CommonValue.APP_TYPE_FIELD;
Expand Down Expand Up @@ -101,84 +101,55 @@ public void onFailure(Exception e) {
}
};

String name = null;
String type = null;
String description = null;
String llmModelId = null;
Map<String, String> llmParameters = Collections.emptyMap();
List<MLToolSpec> tools = new ArrayList<>();
Map<String, String> parameters = Collections.emptyMap();
MLMemorySpec memory = null;
Instant createdTime = null;
Instant lastUpdateTime = null;
String appType = null;

// TODO: Recreating the list to get this compiling
// Need to refactor the below iteration to pull directly from the maps
List<WorkflowData> data = new ArrayList<>();
data.add(currentNodeInputs);
data.addAll(outputs.values());

for (WorkflowData workflowData : data) {
Map<String, Object> content = workflowData.getContent();

for (Entry<String, Object> entry : content.entrySet()) {
switch (entry.getKey()) {
case NAME_FIELD:
name = (String) entry.getValue();
break;
case DESCRIPTION_FIELD:
description = (String) entry.getValue();
break;
case TYPE:
type = (String) entry.getValue();
break;
case LLM_MODEL_ID:
llmModelId = (String) entry.getValue();
break;
case LLM_PARAMETERS:
llmParameters = getStringToStringMap(entry.getValue(), LLM_PARAMETERS);
break;
case TOOLS_FIELD:
tools = addTools(entry.getValue());
break;
case PARAMETERS_FIELD:
parameters = getStringToStringMap(entry.getValue(), PARAMETERS_FIELD);
break;
case MEMORY_FIELD:
memory = getMLMemorySpec(entry.getValue());
break;
case CREATED_TIME:
createdTime = Instant.ofEpochMilli((Long) entry.getValue());
break;
case LAST_UPDATED_TIME_FIELD:
lastUpdateTime = Instant.ofEpochMilli((Long) entry.getValue());
break;
case APP_TYPE_FIELD:
appType = (String) entry.getValue();
break;
default:
break;
}
}
}
Set<String> requiredKeys = Set.of(NAME_FIELD, TYPE);
Set<String> optionalKeys = Set.of(
DESCRIPTION_FIELD,
LLM_MODEL_ID,
LLM_PARAMETERS,
TOOLS_FIELD,
PARAMETERS_FIELD,
MEMORY_FIELD,
CREATED_TIME,
LAST_UPDATED_TIME_FIELD,
APP_TYPE_FIELD
);

try {
Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
requiredKeys,
optionalKeys,
currentNodeInputs,
outputs,
previousNodeInputs
);

// Case when modelId is present in previous node inputs
if (llmModelId == null) {
llmModelId = getLlmModelId(previousNodeInputs, outputs);
}
String type = (String) inputs.get(TYPE);
String name = (String) inputs.get(NAME_FIELD);
String description = (String) inputs.get(DESCRIPTION_FIELD);
String llmModelId = (String) inputs.get(LLM_MODEL_ID);
Map<String, String> llmParameters = getStringToStringMap(inputs.get(PARAMETERS_FIELD), LLM_PARAMETERS);
List<MLToolSpec> tools = getTools(previousNodeInputs, outputs);
Map<String, String> parameters = getStringToStringMap(inputs.get(PARAMETERS_FIELD), PARAMETERS_FIELD);
MLMemorySpec memory = getMLMemorySpec(inputs.get(MEMORY_FIELD));
Instant createdTime = Instant.ofEpochMilli((Long) inputs.get(CREATED_TIME));
Instant lastUpdateTime = Instant.ofEpochMilli((Long) inputs.get(LAST_UPDATED_TIME_FIELD));
String appType = (String) inputs.get(APP_TYPE_FIELD);

// Case when modelId is present in previous node inputs
if (llmModelId == null) {
llmModelId = getLlmModelId(previousNodeInputs, outputs);
}

// Case when modelId is not present at all
if (llmModelId == null) {
registerAgentModelFuture.completeExceptionally(
new FlowFrameworkException("llm model id is not provided", RestStatus.BAD_REQUEST)
);
return registerAgentModelFuture;
}
// Case when modelId is not present at all
if (llmModelId == null) {
registerAgentModelFuture.completeExceptionally(
new FlowFrameworkException("llm model id is not provided", RestStatus.BAD_REQUEST)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add workflowID and step (node) ID to this error message. (Can handle in a future PR)

);
return registerAgentModelFuture;
}

LLMSpec llmSpec = getLLMSpec(llmModelId, llmParameters);
LLMSpec llmSpec = getLLMSpec(llmModelId, llmParameters);

if (Stream.of(name, type, llmSpec).allMatch(x -> x != null)) {
MLAgentBuilder builder = MLAgent.builder().name(name);

if (description != null) {
Expand All @@ -198,12 +169,9 @@ public void onFailure(Exception e) {

mlClient.registerAgent(mlAgent, actionListener);

} else {
registerAgentModelFuture.completeExceptionally(
new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST)
);
} catch (FlowFrameworkException e) {
registerAgentModelFuture.completeExceptionally(e);
}

return registerAgentModelFuture;
}

Expand All @@ -212,9 +180,24 @@ public String getName() {
return NAME;
}

private List<MLToolSpec> addTools(Object tools) {
MLToolSpec mlToolSpec = (MLToolSpec) tools;
mlToolSpecList.add(mlToolSpec);
private List<MLToolSpec> getTools(Map<String, String> previousNodeInputs, Map<String, WorkflowData> outputs) {
List<MLToolSpec> mlToolSpecList = new ArrayList<>();
List<String> previousNodes = previousNodeInputs.entrySet()
.stream()
.filter(e -> TOOLS_FIELD.equals(e.getValue()))
.map(Map.Entry::getKey)
.collect(Collectors.toList());

if (previousNodes != null) {
previousNodes.forEach((previousNode) -> {
WorkflowData previousNodeOutput = outputs.get(previousNode);
if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(TOOLS_FIELD)) {
MLToolSpec mlToolSpec = (MLToolSpec) previousNodeOutput.getContent().get(TOOLS_FIELD);
logger.info("Tool added {}", mlToolSpec.getType());
mlToolSpecList.add(mlToolSpec);
}
});
}
return mlToolSpecList;
}

Expand All @@ -240,7 +223,7 @@ private String getLlmModelId(Map<String, String> previousNodeInputs, Map<String,

private LLMSpec getLLMSpec(String llmModelId, Map<String, String> llmParameters) {
if (llmModelId == null) {
throw new IllegalArgumentException("model id for llm is null");
throw new FlowFrameworkException("model id for llm is null", RestStatus.BAD_REQUEST);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we pass workflow ID and step ID to this method to include in the exception? (Can handle in a future PR)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to do this for all the exception message for the workflow steps. I can raise a follow up PR for all of them.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}
LLMSpec.LLMSpecBuilder builder = LLMSpec.builder();
builder.modelId(llmModelId);
Expand Down
94 changes: 43 additions & 51 deletions src/main/java/org/opensearch/flowframework/workflow/ToolStep.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,17 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.ml.common.agent.MLToolSpec;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;

import static org.opensearch.flowframework.common.CommonValue.AGENT_ID;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.INCLUDE_OUTPUT_IN_AGENT_RESPONSE;
import static org.opensearch.flowframework.common.CommonValue.MODEL_ID;
Expand All @@ -47,49 +45,25 @@ public CompletableFuture<WorkflowData> execute(
Map<String, WorkflowData> outputs,
Map<String, String> previousNodeInputs
) throws IOException {
String type = null;
String name = null;
String description = null;
Map<String, String> parameters = Collections.emptyMap();
Boolean includeOutputInAgentResponse = null;

// TODO: Recreating the list to get this compiling
// Need to refactor the below iteration to pull directly from the maps
List<WorkflowData> data = new ArrayList<>();
data.add(currentNodeInputs);
data.addAll(outputs.values());

for (WorkflowData workflowData : data) {
Map<String, Object> content = workflowData.getContent();

for (Entry<String, Object> entry : content.entrySet()) {
switch (entry.getKey()) {
case TYPE:
type = (String) entry.getValue();
break;
case NAME_FIELD:
name = (String) entry.getValue();
break;
case DESCRIPTION_FIELD:
description = (String) entry.getValue();
break;
case PARAMETERS_FIELD:
parameters = getToolsParametersMap(entry.getValue(), previousNodeInputs, outputs);
break;
case INCLUDE_OUTPUT_IN_AGENT_RESPONSE:
includeOutputInAgentResponse = (Boolean) entry.getValue();
break;
default:
break;
}

}
Set<String> requiredKeys = Set.of(TYPE);
Set<String> optionalKeys = Set.of(NAME_FIELD, DESCRIPTION_FIELD, PARAMETERS_FIELD, INCLUDE_OUTPUT_IN_AGENT_RESPONSE);

}
try {
Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
requiredKeys,
optionalKeys,
currentNodeInputs,
outputs,
previousNodeInputs
);

String type = (String) inputs.get(TYPE);
String name = (String) inputs.get(NAME_FIELD);
String description = (String) inputs.get(DESCRIPTION_FIELD);
Boolean includeOutputInAgentResponse = (Boolean) inputs.get(INCLUDE_OUTPUT_IN_AGENT_RESPONSE);
Map<String, String> parameters = getToolsParametersMap(inputs.get(PARAMETERS_FIELD), previousNodeInputs, outputs);

if (type == null) {
toolFuture.completeExceptionally(new FlowFrameworkException("Tool type is not provided", RestStatus.BAD_REQUEST));
} else {
MLToolSpec.MLToolSpecBuilder builder = MLToolSpec.builder();

builder.type(type);
Expand All @@ -115,9 +89,12 @@ public CompletableFuture<WorkflowData> execute(
currentNodeInputs.getNodeId()
)
);
}

logger.info("Tool registered successfully {}", type);
logger.info("Tool registered successfully {}", type);

} catch (FlowFrameworkException e) {
toolFuture.completeExceptionally(e);
}
return toolFuture;
}

Expand All @@ -132,19 +109,34 @@ private Map<String, String> getToolsParametersMap(
Map<String, WorkflowData> outputs
) {
Map<String, String> parametersMap = (Map<String, String>) parameters;
Optional<String> previousNode = previousNodeInputs.entrySet()
Optional<String> previousNodeModel = previousNodeInputs.entrySet()
.stream()
.filter(e -> MODEL_ID.equals(e.getValue()))
.map(Map.Entry::getKey)
.findFirst();

Optional<String> previousNodeAgent = previousNodeInputs.entrySet()
.stream()
.filter(e -> AGENT_ID.equals(e.getValue()))
.map(Map.Entry::getKey)
.findFirst();

// Case when modelId is passed through previousSteps and not present already in parameters
if (previousNode.isPresent() && !parametersMap.containsKey(MODEL_ID)) {
WorkflowData previousNodeOutput = outputs.get(previousNode.get());
if (previousNodeModel.isPresent() && !parametersMap.containsKey(MODEL_ID)) {
WorkflowData previousNodeOutput = outputs.get(previousNodeModel.get());
if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(MODEL_ID)) {
parametersMap.put(MODEL_ID, previousNodeOutput.getContent().get(MODEL_ID).toString());
return parametersMap;
}
}

// Case when agentId is passed through previousSteps and not present already in parameters
if (previousNodeAgent.isPresent() && !parametersMap.containsKey(AGENT_ID)) {
WorkflowData previousNodeOutput = outputs.get(previousNodeAgent.get());
if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(AGENT_ID)) {
parametersMap.put(AGENT_ID, previousNodeOutput.getContent().get(AGENT_ID).toString());
}
}

// For other cases where modelId is already present in the parameters or not return the parametersMap
return parametersMap;
}
Expand Down
Loading