-
Notifications
You must be signed in to change notification settings - Fork 31
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature/agent_framework] Registers root agent with an agentId in ToolSteps #242
Changes from all commits
7cf2126
7100663
a8c5cb9
9fd114a
4f1d0f5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ | |
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.core.rest.RestStatus; | ||
import org.opensearch.flowframework.exception.FlowFrameworkException; | ||
import org.opensearch.flowframework.util.ParseUtils; | ||
import org.opensearch.ml.client.MachineLearningNodeClient; | ||
import org.opensearch.ml.common.agent.LLMSpec; | ||
import org.opensearch.ml.common.agent.MLAgent; | ||
|
@@ -25,13 +26,12 @@ | |
import java.io.IOException; | ||
import java.time.Instant; | ||
import java.util.ArrayList; | ||
import java.util.Collections; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Map.Entry; | ||
import java.util.Optional; | ||
import java.util.Set; | ||
import java.util.concurrent.CompletableFuture; | ||
import java.util.stream.Stream; | ||
import java.util.stream.Collectors; | ||
|
||
import static org.opensearch.flowframework.common.CommonValue.AGENT_ID; | ||
import static org.opensearch.flowframework.common.CommonValue.APP_TYPE_FIELD; | ||
|
@@ -101,84 +101,55 @@ public void onFailure(Exception e) { | |
} | ||
}; | ||
|
||
String name = null; | ||
String type = null; | ||
String description = null; | ||
String llmModelId = null; | ||
Map<String, String> llmParameters = Collections.emptyMap(); | ||
List<MLToolSpec> tools = new ArrayList<>(); | ||
Map<String, String> parameters = Collections.emptyMap(); | ||
MLMemorySpec memory = null; | ||
Instant createdTime = null; | ||
Instant lastUpdateTime = null; | ||
String appType = null; | ||
|
||
// TODO: Recreating the list to get this compiling | ||
// Need to refactor the below iteration to pull directly from the maps | ||
List<WorkflowData> data = new ArrayList<>(); | ||
data.add(currentNodeInputs); | ||
data.addAll(outputs.values()); | ||
|
||
for (WorkflowData workflowData : data) { | ||
Map<String, Object> content = workflowData.getContent(); | ||
|
||
for (Entry<String, Object> entry : content.entrySet()) { | ||
switch (entry.getKey()) { | ||
case NAME_FIELD: | ||
name = (String) entry.getValue(); | ||
break; | ||
case DESCRIPTION_FIELD: | ||
description = (String) entry.getValue(); | ||
break; | ||
case TYPE: | ||
type = (String) entry.getValue(); | ||
break; | ||
case LLM_MODEL_ID: | ||
llmModelId = (String) entry.getValue(); | ||
break; | ||
case LLM_PARAMETERS: | ||
llmParameters = getStringToStringMap(entry.getValue(), LLM_PARAMETERS); | ||
break; | ||
case TOOLS_FIELD: | ||
tools = addTools(entry.getValue()); | ||
break; | ||
case PARAMETERS_FIELD: | ||
parameters = getStringToStringMap(entry.getValue(), PARAMETERS_FIELD); | ||
break; | ||
case MEMORY_FIELD: | ||
memory = getMLMemorySpec(entry.getValue()); | ||
break; | ||
case CREATED_TIME: | ||
createdTime = Instant.ofEpochMilli((Long) entry.getValue()); | ||
break; | ||
case LAST_UPDATED_TIME_FIELD: | ||
lastUpdateTime = Instant.ofEpochMilli((Long) entry.getValue()); | ||
break; | ||
case APP_TYPE_FIELD: | ||
appType = (String) entry.getValue(); | ||
break; | ||
default: | ||
break; | ||
} | ||
} | ||
} | ||
Set<String> requiredKeys = Set.of(NAME_FIELD, TYPE); | ||
Set<String> optionalKeys = Set.of( | ||
DESCRIPTION_FIELD, | ||
LLM_MODEL_ID, | ||
LLM_PARAMETERS, | ||
TOOLS_FIELD, | ||
PARAMETERS_FIELD, | ||
MEMORY_FIELD, | ||
CREATED_TIME, | ||
LAST_UPDATED_TIME_FIELD, | ||
APP_TYPE_FIELD | ||
); | ||
|
||
try { | ||
Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps( | ||
requiredKeys, | ||
optionalKeys, | ||
currentNodeInputs, | ||
outputs, | ||
previousNodeInputs | ||
); | ||
|
||
// Case when modelId is present in previous node inputs | ||
if (llmModelId == null) { | ||
llmModelId = getLlmModelId(previousNodeInputs, outputs); | ||
} | ||
String type = (String) inputs.get(TYPE); | ||
String name = (String) inputs.get(NAME_FIELD); | ||
String description = (String) inputs.get(DESCRIPTION_FIELD); | ||
String llmModelId = (String) inputs.get(LLM_MODEL_ID); | ||
Map<String, String> llmParameters = getStringToStringMap(inputs.get(PARAMETERS_FIELD), LLM_PARAMETERS); | ||
List<MLToolSpec> tools = getTools(previousNodeInputs, outputs); | ||
Map<String, String> parameters = getStringToStringMap(inputs.get(PARAMETERS_FIELD), PARAMETERS_FIELD); | ||
MLMemorySpec memory = getMLMemorySpec(inputs.get(MEMORY_FIELD)); | ||
Instant createdTime = Instant.ofEpochMilli((Long) inputs.get(CREATED_TIME)); | ||
Instant lastUpdateTime = Instant.ofEpochMilli((Long) inputs.get(LAST_UPDATED_TIME_FIELD)); | ||
String appType = (String) inputs.get(APP_TYPE_FIELD); | ||
|
||
// Case when modelId is present in previous node inputs | ||
if (llmModelId == null) { | ||
llmModelId = getLlmModelId(previousNodeInputs, outputs); | ||
} | ||
|
||
// Case when modelId is not present at all | ||
if (llmModelId == null) { | ||
registerAgentModelFuture.completeExceptionally( | ||
new FlowFrameworkException("llm model id is not provided", RestStatus.BAD_REQUEST) | ||
); | ||
return registerAgentModelFuture; | ||
} | ||
// Case when modelId is not present at all | ||
if (llmModelId == null) { | ||
registerAgentModelFuture.completeExceptionally( | ||
new FlowFrameworkException("llm model id is not provided", RestStatus.BAD_REQUEST) | ||
); | ||
return registerAgentModelFuture; | ||
} | ||
|
||
LLMSpec llmSpec = getLLMSpec(llmModelId, llmParameters); | ||
LLMSpec llmSpec = getLLMSpec(llmModelId, llmParameters); | ||
|
||
if (Stream.of(name, type, llmSpec).allMatch(x -> x != null)) { | ||
MLAgentBuilder builder = MLAgent.builder().name(name); | ||
|
||
if (description != null) { | ||
|
@@ -198,12 +169,9 @@ public void onFailure(Exception e) { | |
|
||
mlClient.registerAgent(mlAgent, actionListener); | ||
|
||
} else { | ||
registerAgentModelFuture.completeExceptionally( | ||
new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST) | ||
); | ||
} catch (FlowFrameworkException e) { | ||
registerAgentModelFuture.completeExceptionally(e); | ||
} | ||
|
||
return registerAgentModelFuture; | ||
} | ||
|
||
|
@@ -212,9 +180,24 @@ public String getName() { | |
return NAME; | ||
} | ||
|
||
private List<MLToolSpec> addTools(Object tools) { | ||
MLToolSpec mlToolSpec = (MLToolSpec) tools; | ||
mlToolSpecList.add(mlToolSpec); | ||
private List<MLToolSpec> getTools(Map<String, String> previousNodeInputs, Map<String, WorkflowData> outputs) { | ||
List<MLToolSpec> mlToolSpecList = new ArrayList<>(); | ||
List<String> previousNodes = previousNodeInputs.entrySet() | ||
.stream() | ||
.filter(e -> TOOLS_FIELD.equals(e.getValue())) | ||
.map(Map.Entry::getKey) | ||
.collect(Collectors.toList()); | ||
|
||
if (previousNodes != null) { | ||
previousNodes.forEach((previousNode) -> { | ||
WorkflowData previousNodeOutput = outputs.get(previousNode); | ||
if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(TOOLS_FIELD)) { | ||
MLToolSpec mlToolSpec = (MLToolSpec) previousNodeOutput.getContent().get(TOOLS_FIELD); | ||
logger.info("Tool added {}", mlToolSpec.getType()); | ||
mlToolSpecList.add(mlToolSpec); | ||
} | ||
}); | ||
} | ||
return mlToolSpecList; | ||
} | ||
|
||
|
@@ -240,7 +223,7 @@ private String getLlmModelId(Map<String, String> previousNodeInputs, Map<String, | |
|
||
private LLMSpec getLLMSpec(String llmModelId, Map<String, String> llmParameters) { | ||
if (llmModelId == null) { | ||
throw new IllegalArgumentException("model id for llm is null"); | ||
throw new FlowFrameworkException("model id for llm is null", RestStatus.BAD_REQUEST); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we pass workflow ID and step ID to this method to include in the exception? (Can handle in a future PR) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need to do this for all the exception message for the workflow steps. I can raise a follow up PR for all of them. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
} | ||
LLMSpec.LLMSpecBuilder builder = LLMSpec.builder(); | ||
builder.modelId(llmModelId); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add workflowID and step (node) ID to this error message. (Can handle in a future PR)