Skip to content

Commit

Permalink
[Feature/agent_framework] Registers a single agent with multiple tools (
Browse files Browse the repository at this point in the history
#198)

* Initial register agent workflow step

Signed-off-by: Owais Kazi <[email protected]>

* Added tools step

Signed-off-by: Owais Kazi <[email protected]>

* Fixed ClassCastException

Signed-off-by: Owais Kazi <[email protected]>

* Handled exception for Instant

Signed-off-by: Owais Kazi <[email protected]>

* Added type Instant for WorklowNode Parser

Signed-off-by: Owais Kazi <[email protected]>

* Removed created and last updated time

Signed-off-by: Owais Kazi <[email protected]>

* Addressed parsing error

Signed-off-by: Owais Kazi <[email protected]>

* Handled parsing of Long values for Instant

Signed-off-by: Owais Kazi <[email protected]>

* Handled nested object for llm key

Signed-off-by: Owais Kazi <[email protected]>

* Handled parsing error

Signed-off-by: Owais Kazi <[email protected]>

* Another attempt to fix parsing error for llm

Signed-off-by: Owais Kazi <[email protected]>

* Another attemp to fix XContent

Signed-off-by: Owais Kazi <[email protected]>

* Fixed Parsing error

Signed-off-by: Owais Kazi <[email protected]>

* Added tests for toolstep and javadocs

Signed-off-by: Owais Kazi <[email protected]>

* Undo CI changes

Signed-off-by: Owais Kazi <[email protected]>

* Addressing PR comments

Signed-off-by: Owais Kazi <[email protected]>

* Addressing PR comments

Signed-off-by: Owais Kazi <[email protected]>

* Handled interface changes

Signed-off-by: Owais Kazi <[email protected]>

* Addressed conflicts

Signed-off-by: Owais Kazi <[email protected]>

* Added TODO

Signed-off-by: Owais Kazi <[email protected]>

---------

Signed-off-by: Owais Kazi <[email protected]>
  • Loading branch information
owaiskazi19 authored and dbwiddis committed Dec 18, 2023
1 parent 860c445 commit 069c907
Show file tree
Hide file tree
Showing 20 changed files with 773 additions and 22 deletions.
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ dependencies {
implementation "org.opensearch:common-utils:${common_utils_version}"
implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.1'
implementation 'org.bouncycastle:bcprov-jdk18on:1.77'
implementation "com.google.code.gson:gson:2.10.1"

// ZipArchive dependencies used for integration tests
zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,5 +162,20 @@ private CommonValue() {}
public static final String RESOURCE_ID_FIELD = "resource_id";
/** The field name for the ResourceCreated's resource name */
public static final String WORKFLOW_STEP_NAME = "workflow_step_name";

/** LLM Name for registering an agent */
public static final String LLM_FIELD = "llm";
/** The tools' field for an agent */
public static final String TOOLS_FIELD = "tools";
/** The memory field for an agent */
public static final String MEMORY_FIELD = "memory";
/** The app type field for an agent */
public static final String APP_TYPE_FIELD = "app_type";
/** The agent id of an agent */
public static final String AGENT_ID = "agent_id";
/** To include field for an agent response */
public static final String INCLUDE_OUTPUT_IN_AGENT_RESPONSE = "include_output_in_agent_response";
/** The created time field for an agent */
public static final String CREATED_TIME = "created_time";
/** The last updated time field for an agent */
public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time";
}
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
for (Entry<String, Workflow> e : workflows.entrySet()) {
xContentBuilder.field(e.getKey(), e.getValue(), params);
}
xContentBuilder.endObject();

if (uiMetadata != null && !uiMetadata.isEmpty()) {
xContentBuilder.field(UI_METADATA_FIELD, uiMetadata);
}

xContentBuilder.endObject();
if (user != null) {
xContentBuilder.field(USER_FIELD, user);
}
Expand Down
35 changes: 32 additions & 3 deletions src/main/java/org/opensearch/flowframework/model/WorkflowNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.opensearch.flowframework.workflow.ProcessNode;
import org.opensearch.flowframework.workflow.WorkflowData;
import org.opensearch.flowframework.workflow.WorkflowStep;
import org.opensearch.ml.common.agent.LLMSpec;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -24,7 +25,10 @@
import java.util.Objects;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.LLM_FIELD;
import static org.opensearch.flowframework.util.ParseUtils.buildLLMMap;
import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap;
import static org.opensearch.flowframework.util.ParseUtils.parseLLM;
import static org.opensearch.flowframework.util.ParseUtils.parseStringToStringMap;

/**
Expand All @@ -34,7 +38,6 @@
* and its inputs are used to populate the {@link WorkflowData} input.
*/
public class WorkflowNode implements ToXContentObject {

/** The template field name for node id */
public static final String ID_FIELD = "id";
/** The template field name for node type */
Expand Down Expand Up @@ -82,7 +85,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
xContentBuilder.startObject(USER_INPUTS_FIELD);
for (Entry<String, Object> e : userInputs.entrySet()) {
xContentBuilder.field(e.getKey());
if (e.getValue() instanceof String) {
if (e.getValue() instanceof String || e.getValue() instanceof Number) {
xContentBuilder.value(e.getValue());
} else if (e.getValue() instanceof Map<?, ?>) {
buildStringToStringMap(xContentBuilder, (Map<?, ?>) e.getValue());
Expand All @@ -98,6 +101,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}
}
xContentBuilder.endArray();
} else if (e.getValue() instanceof LLMSpec) {
if (LLM_FIELD.equals(e.getKey())) {
xContentBuilder.startObject();
buildLLMMap(xContentBuilder, (LLMSpec) e.getValue());
xContentBuilder.endObject();
}
}
}
xContentBuilder.endObject();
Expand Down Expand Up @@ -141,7 +150,11 @@ public static WorkflowNode parse(XContentParser parser) throws IOException {
userInputs.put(inputFieldName, parser.text());
break;
case START_OBJECT:
userInputs.put(inputFieldName, parseStringToStringMap(parser));
if (LLM_FIELD.equals(inputFieldName)) {
userInputs.put(inputFieldName, parseLLM(parser));
} else {
userInputs.put(inputFieldName, parseStringToStringMap(parser));
}
break;
case START_ARRAY:
if (PROCESSORS_FIELD.equals(inputFieldName)) {
Expand All @@ -158,6 +171,22 @@ public static WorkflowNode parse(XContentParser parser) throws IOException {
userInputs.put(inputFieldName, mapList.toArray(new Map[0]));
}
break;
case VALUE_NUMBER:
switch (parser.numberType()) {
case INT:
userInputs.put(inputFieldName, parser.intValue());
break;
case LONG:
userInputs.put(inputFieldName, parser.longValue());
break;
case FLOAT:
userInputs.put(inputFieldName, parser.floatValue());
break;
case DOUBLE:
userInputs.put(inputFieldName, parser.doubleValue());
break;
}
break;
default:
throw new IOException("Unable to parse field [" + inputFieldName + "] in a node object.");
}
Expand Down
99 changes: 99 additions & 0 deletions src/main/java/org/opensearch/flowframework/util/ParseUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
*/
package org.opensearch.flowframework.util;

import com.google.gson.Gson;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.Client;
Expand All @@ -21,21 +22,34 @@
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.agent.LLMSpec;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.time.Instant;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.MODEL_ID;
import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;

/**
* Utility methods for Template parsing
*/
public class ParseUtils {
private static final Logger logger = LogManager.getLogger(ParseUtils.class);

public static final Gson gson;

static {
gson = new Gson();
}

private ParseUtils() {}

/**
Expand Down Expand Up @@ -70,6 +84,21 @@ public static void buildStringToStringMap(XContentBuilder xContentBuilder, Map<?
xContentBuilder.endObject();
}

/**
* Builds an XContent object representing a LLMSpec.
*
* @param xContentBuilder An XContent builder whose position is at the start of the map object to build
* @param llm LLMSpec
* @throws IOException on a build failure
*/
public static void buildLLMMap(XContentBuilder xContentBuilder, LLMSpec llm) throws IOException {
String modelId = llm.getModelId();
Map<String, String> parameters = llm.getParameters();
xContentBuilder.field(MODEL_ID, modelId);
xContentBuilder.field(PARAMETERS_FIELD);
buildStringToStringMap(xContentBuilder, parameters);
}

/**
* Parses an XContent object representing a map of String keys to String values.
*
Expand All @@ -88,6 +117,37 @@ public static Map<String, String> parseStringToStringMap(XContentParser parser)
return map;
}

// TODO Figure out a way to use the parse method of LLMSpec of ml-commons
/**
* Parses an XContent object representing the object of LLMSpec
* @param parser An XContent parser whose position is at the start of the map object to parse
* @return instance of {@link org.opensearch.ml.common.agent.LLMSpec}
* @throws IOException parsing error
*/
public static LLMSpec parseLLM(XContentParser parser) throws IOException {
String modelId = null;
Map<String, String> parameters = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case MODEL_ID:
modelId = parser.text();
break;
case PARAMETERS_FIELD:
parameters = getParameterMap(parser.map());
break;
default:
parser.skipChildren();
break;
}
}
return LLMSpec.builder().modelId(modelId).parameters(parameters).build();
}

/**
* Parse content parser to {@link java.time.Instant}.
*
Expand Down Expand Up @@ -116,6 +176,31 @@ public static User getUserContext(Client client) {
return User.parse(userStr);
}

/**
* Generates a parameter map required when the parameter is nested within an object
* @param parameterObjs parameters
* @return a parameters map of type String,String
*/
public static Map<String, String> getParameterMap(Map<String, ?> parameterObjs) {
Map<String, String> parameters = new HashMap<>();
for (String key : parameterObjs.keySet()) {
Object value = parameterObjs.get(key);
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
if (value instanceof String) {
parameters.put(key, (String) value);
} else {
parameters.put(key, gson.toJson(value));
}
return null;
});
} catch (PrivilegedActionException e) {
throw new RuntimeException(e);
}
}
return parameters;
}

/**
* Creates a XContentParser from a given Registry
*
Expand All @@ -129,4 +214,18 @@ public static XContentParser createXContentParserFromRegistry(NamedXContentRegis
return XContentHelper.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, bytesReference, XContentType.JSON);
}

/**
* Generates a string to string Map
* @param map content map
* @param fieldName fieldName
* @return instance of the map
*/
@SuppressWarnings("unchecked")
public static Map<String, String> getStringToStringMap(Object map, String fieldName) {
if (map instanceof Map) {
return (Map<String, String>) map;
}
throw new IllegalArgumentException("[" + fieldName + "] must be a key-value map.");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD;
import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX;
import static org.opensearch.flowframework.util.ParseUtils.getStringToStringMap;

/**
* Step to create a connector for a remote model
Expand Down Expand Up @@ -210,14 +211,6 @@ public String getName() {
return NAME;
}

@SuppressWarnings("unchecked")
private static Map<String, String> getStringToStringMap(Object map, String fieldName) {
if (map instanceof Map) {
return (Map<String, String>) map;
}
throw new IllegalArgumentException("[" + fieldName + "] must be a key-value map.");
}

private static Map<String, String> getParameterMap(Object parameterMap) throws PrivilegedActionException {
Map<String, String> parameters = new HashMap<>();
for (Entry<String, String> entry : getStringToStringMap(parameterMap, PARAMETERS_FIELD).entrySet()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
public class CreateIndexStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(CreateIndexStep.class);
private ClusterService clusterService;
private Client client;
private final ClusterService clusterService;
private final Client client;

/** The name of this step, used as a key in the template and the {@link WorkflowStepFactory} */
static final String NAME = "create_index";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
public class DeployModelStep implements WorkflowStep {
private static final Logger logger = LogManager.getLogger(DeployModelStep.class);

private MachineLearningNodeClient mlClient;
private final MachineLearningNodeClient mlClient;
static final String NAME = "deploy_model";

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public class ModelGroupStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(ModelGroupStep.class);

private MachineLearningNodeClient mlClient;
private final MachineLearningNodeClient mlClient;

static final String NAME = "model_group";

Expand Down
Loading

0 comments on commit 069c907

Please sign in to comment.