Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature/agent_framework] Registers a single agent with multiple tools #198

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ dependencies {
implementation "org.opensearch:common-utils:${common_utils_version}"
implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.1'
implementation 'org.bouncycastle:bcprov-jdk18on:1.77'
implementation "com.google.code.gson:gson:2.10.1"
owaiskazi19 marked this conversation as resolved.
Show resolved Hide resolved

configurations.all {
resolutionStrategy {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,5 +162,20 @@ private CommonValue() {}
public static final String RESOURCE_ID_FIELD = "resource_id";
/** The field name for the ResourceCreated's resource name */
public static final String WORKFLOW_STEP_NAME = "workflow_step_name";

/** LLM Name for registering an agent */
public static final String LLM_FIELD = "llm";
/** The tools' field for an agent */
public static final String TOOLS_FIELD = "tools";
/** The memory field for an agent */
public static final String MEMORY_FIELD = "memory";
/** The app type field for an agent */
public static final String APP_TYPE_FIELD = "app_type";
/** The agent id of an agent */
public static final String AGENT_ID = "agent_id";
/** To include field for an agent response */
public static final String INCLUDE_OUTPUT_IN_AGENT_RESPONSE = "include_output_in_agent_response";
/** The created time field for an agent */
public static final String CREATED_TIME = "created_time";
/** The last updated time field for an agent */
public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time";
amitgalitz marked this conversation as resolved.
Show resolved Hide resolved
}
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
for (Entry<String, Workflow> e : workflows.entrySet()) {
xContentBuilder.field(e.getKey(), e.getValue(), params);
}
xContentBuilder.endObject();

if (uiMetadata != null && !uiMetadata.isEmpty()) {
xContentBuilder.field(UI_METADATA_FIELD, uiMetadata);
}

xContentBuilder.endObject();
if (user != null) {
xContentBuilder.field(USER_FIELD, user);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.opensearch.flowframework.workflow.ProcessNode;
import org.opensearch.flowframework.workflow.WorkflowData;
import org.opensearch.flowframework.workflow.WorkflowStep;
import org.opensearch.ml.common.agent.LLMSpec;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -24,7 +25,10 @@
import java.util.Objects;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.LLM_FIELD;
import static org.opensearch.flowframework.util.ParseUtils.buildLLMMap;
import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap;
import static org.opensearch.flowframework.util.ParseUtils.parseLLM;
import static org.opensearch.flowframework.util.ParseUtils.parseStringToStringMap;

/**
Expand All @@ -34,7 +38,6 @@
* and its inputs are used to populate the {@link WorkflowData} input.
*/
public class WorkflowNode implements ToXContentObject {

/** The template field name for node id */
public static final String ID_FIELD = "id";
/** The template field name for node type */
Expand Down Expand Up @@ -82,7 +85,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
xContentBuilder.startObject(USER_INPUTS_FIELD);
for (Entry<String, Object> e : userInputs.entrySet()) {
xContentBuilder.field(e.getKey());
if (e.getValue() instanceof String) {
if (e.getValue() instanceof String || e.getValue() instanceof Number) {
xContentBuilder.value(e.getValue());
} else if (e.getValue() instanceof Map<?, ?>) {
buildStringToStringMap(xContentBuilder, (Map<?, ?>) e.getValue());
Expand All @@ -98,6 +101,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}
}
xContentBuilder.endArray();
} else if (e.getValue() instanceof LLMSpec) {
if (LLM_FIELD.equals(e.getKey())) {
xContentBuilder.startObject();
buildLLMMap(xContentBuilder, (LLMSpec) e.getValue());
xContentBuilder.endObject();
}
}
}
xContentBuilder.endObject();
Expand Down Expand Up @@ -141,7 +150,11 @@ public static WorkflowNode parse(XContentParser parser) throws IOException {
userInputs.put(inputFieldName, parser.text());
break;
case START_OBJECT:
userInputs.put(inputFieldName, parseStringToStringMap(parser));
if (LLM_FIELD.equals(inputFieldName)) {
owaiskazi19 marked this conversation as resolved.
Show resolved Hide resolved
userInputs.put(inputFieldName, parseLLM(parser));
} else {
userInputs.put(inputFieldName, parseStringToStringMap(parser));
}
break;
case START_ARRAY:
if (PROCESSORS_FIELD.equals(inputFieldName)) {
Expand All @@ -158,6 +171,22 @@ public static WorkflowNode parse(XContentParser parser) throws IOException {
userInputs.put(inputFieldName, mapList.toArray(new Map[0]));
}
break;
case VALUE_NUMBER:
switch (parser.numberType()) {
case INT:
userInputs.put(inputFieldName, parser.intValue());
break;
case LONG:
userInputs.put(inputFieldName, parser.longValue());
break;
case FLOAT:
userInputs.put(inputFieldName, parser.floatValue());
break;
case DOUBLE:
userInputs.put(inputFieldName, parser.doubleValue());
break;
}
break;
default:
throw new IOException("Unable to parse field [" + inputFieldName + "] in a node object.");
}
Expand Down
99 changes: 99 additions & 0 deletions src/main/java/org/opensearch/flowframework/util/ParseUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
*/
package org.opensearch.flowframework.util;

import com.google.gson.Gson;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.Client;
Expand All @@ -21,21 +22,34 @@
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.agent.LLMSpec;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.time.Instant;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.MODEL_ID;
import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;

/**
* Utility methods for Template parsing
*/
public class ParseUtils {
private static final Logger logger = LogManager.getLogger(ParseUtils.class);

public static final Gson gson;

static {
gson = new Gson();
}

private ParseUtils() {}

/**
Expand Down Expand Up @@ -70,6 +84,21 @@ public static void buildStringToStringMap(XContentBuilder xContentBuilder, Map<?
xContentBuilder.endObject();
}

/**
* Builds an XContent object representing a LLMSpec.
*
* @param xContentBuilder An XContent builder whose position is at the start of the map object to build
* @param llm LLMSpec
* @throws IOException on a build failure
*/
public static void buildLLMMap(XContentBuilder xContentBuilder, LLMSpec llm) throws IOException {
String modelId = llm.getModelId();
Map<String, String> parameters = llm.getParameters();
xContentBuilder.field(MODEL_ID, modelId);
xContentBuilder.field(PARAMETERS_FIELD);
buildStringToStringMap(xContentBuilder, parameters);
}

/**
* Parses an XContent object representing a map of String keys to String values.
*
Expand All @@ -88,6 +117,37 @@ public static Map<String, String> parseStringToStringMap(XContentParser parser)
return map;
}

// TODO Figure out a way to use the parse method of LLMSpec of ml-commons
/**
* Parses an XContent object representing the object of LLMSpec
* @param parser An XContent parser whose position is at the start of the map object to parse
* @return instance of {@link org.opensearch.ml.common.agent.LLMSpec}
* @throws IOException parsing error
*/
public static LLMSpec parseLLM(XContentParser parser) throws IOException {
owaiskazi19 marked this conversation as resolved.
Show resolved Hide resolved
String modelId = null;
Map<String, String> parameters = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case MODEL_ID:
modelId = parser.text();
break;
case PARAMETERS_FIELD:
parameters = getParameterMap(parser.map());
break;
default:
parser.skipChildren();
break;
}
}
return LLMSpec.builder().modelId(modelId).parameters(parameters).build();
}

/**
* Parse content parser to {@link java.time.Instant}.
*
Expand Down Expand Up @@ -116,6 +176,31 @@ public static User getUserContext(Client client) {
return User.parse(userStr);
}

/**
* Generates a parameter map required when the parameter is nested within an object
* @param parameterObjs parameters
* @return a parameters map of type String,String
*/
public static Map<String, String> getParameterMap(Map<String, ?> parameterObjs) {
Map<String, String> parameters = new HashMap<>();
for (String key : parameterObjs.keySet()) {
Object value = parameterObjs.get(key);
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
if (value instanceof String) {
parameters.put(key, (String) value);
} else {
parameters.put(key, gson.toJson(value));
owaiskazi19 marked this conversation as resolved.
Show resolved Hide resolved
}
return null;
});
} catch (PrivilegedActionException e) {
throw new RuntimeException(e);
}
}
return parameters;
}

/**
* Creates a XContentParser from a given Registry
*
Expand All @@ -129,4 +214,18 @@ public static XContentParser createXContentParserFromRegistry(NamedXContentRegis
return XContentHelper.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, bytesReference, XContentType.JSON);
}

/**
* Generates a string to string Map
* @param map content map
* @param fieldName fieldName
* @return instance of the map
*/
@SuppressWarnings("unchecked")
public static Map<String, String> getStringToStringMap(Object map, String fieldName) {
if (map instanceof Map) {
return (Map<String, String>) map;
owaiskazi19 marked this conversation as resolved.
Show resolved Hide resolved
}
throw new IllegalArgumentException("[" + fieldName + "] must be a key-value map.");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD;
import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX;
import static org.opensearch.flowframework.util.ParseUtils.getStringToStringMap;

/**
* Step to create a connector for a remote model
Expand Down Expand Up @@ -210,14 +211,6 @@ public String getName() {
return NAME;
}

@SuppressWarnings("unchecked")
private static Map<String, String> getStringToStringMap(Object map, String fieldName) {
if (map instanceof Map) {
return (Map<String, String>) map;
}
throw new IllegalArgumentException("[" + fieldName + "] must be a key-value map.");
}

private static Map<String, String> getParameterMap(Object parameterMap) throws PrivilegedActionException {
Map<String, String> parameters = new HashMap<>();
for (Entry<String, String> entry : getStringToStringMap(parameterMap, PARAMETERS_FIELD).entrySet()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
public class CreateIndexStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(CreateIndexStep.class);
private ClusterService clusterService;
private Client client;
private final ClusterService clusterService;
private final Client client;

/** The name of this step, used as a key in the template and the {@link WorkflowStepFactory} */
static final String NAME = "create_index";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
public class DeployModelStep implements WorkflowStep {
private static final Logger logger = LogManager.getLogger(DeployModelStep.class);

private MachineLearningNodeClient mlClient;
private final MachineLearningNodeClient mlClient;
static final String NAME = "deploy_model";

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public class ModelGroupStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(ModelGroupStep.class);

private MachineLearningNodeClient mlClient;
private final MachineLearningNodeClient mlClient;

static final String NAME = "model_group";

Expand Down
Loading
Loading