Skip to content

Commit

Permalink
integration test: create agent with connector tool
Browse files Browse the repository at this point in the history
Signed-off-by: yuye-aws <[email protected]>
  • Loading branch information
yuye-aws committed Aug 26, 2024
1 parent 44cf0be commit 11e8197
Show file tree
Hide file tree
Showing 4 changed files with 259 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,38 @@ protected SearchResponse searchWorkflowState(RestClient client, String query) th
}
}

/**
* Helper method to invoke the Search Agent Rest Action
* @param client the rest client
* @param query the search query
* @return
* @throws Exception
*/
protected SearchResponse searchAgent(RestClient client, String query) throws Exception {
Response restSearchResponse = TestHelpers.makeRequest(
client,
"GET",
"/_plugins/_ml/agents/_search",
Collections.emptyMap(),
query,
null
);
assertEquals(RestStatus.OK, TestHelpers.restStatus(restSearchResponse));

// Parse entity content into SearchResponse
MediaType mediaType = MediaType.fromMediaType(restSearchResponse.getEntity().getContentType());
try (
XContentParser parser = mediaType.xContent()
.createParser(
NamedXContentRegistry.EMPTY,
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
restSearchResponse.getEntity().getContent()
)
) {
return SearchResponse.fromXContent(parser);
}
}

/**
* Helper method to invoke the Get Workflow Status Rest Action and assert the provisioning and state status
* @param client the rest client
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,13 @@
import org.opensearch.flowframework.model.WorkflowEdge;
import org.opensearch.flowframework.model.WorkflowNode;
import org.opensearch.flowframework.model.WorkflowState;
import org.opensearch.search.SearchHit;
import org.junit.Before;
import org.junit.ComparisonFailure;

import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
Expand All @@ -56,7 +58,6 @@ public void waitToStart() throws Exception {
}

public void testSearchWorkflows() throws Exception {

// Create a Workflow that has a credential 12345
Template template = TestHelpers.createTemplateFromFile("createconnector-registerremotemodel-deploymodel.json");
Response response = createWorkflow(client(), template);
Expand Down Expand Up @@ -228,7 +229,6 @@ public void testCreateAndProvisionCyclicalTemplate() throws Exception {
}

public void testCreateAndProvisionRemoteModelWorkflow() throws Exception {

// Using a 3 step template to create a connector, register remote model and deploy model
Template template = TestHelpers.createTemplateFromFile("createconnector-registerremotemodel-deploymodel.json");

Expand Down Expand Up @@ -331,6 +331,91 @@ public void testCreateAndProvisionAgentFrameworkWorkflow() throws Exception {
assertBusy(() -> { getAndAssertWorkflowStatusNotFound(client(), workflowId); }, 30, TimeUnit.SECONDS);
}

public void testCreateAndProvisionConnectorToolAgentFrameworkWorkflow() throws Exception {
// Create a Workflow that has a credential 12345
Template template = TestHelpers.createTemplateFromFile("createconnector-createconnectortool-createflowagent.json");

// Hit Create Workflow API to create agent-framework template, with template validation check and provision parameter
Response response = createWorkflowWithProvision(client(), template);
assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response));
Map<String, Object> responseMap = entityAsMap(response);
String workflowId = (String) responseMap.get(WORKFLOW_ID);
// wait and ensure state is completed/done
assertBusy(
() -> { getAndAssertWorkflowStatus(client(), workflowId, State.COMPLETED, ProvisioningProgress.DONE); },
120,
TimeUnit.SECONDS
);

// Hit Search State API with the workflow id created above
String query = "{\"query\":{\"ids\":{\"values\":[\"" + workflowId + "\"]}}}";
SearchResponse searchResponse = searchWorkflowState(client(), query);
assertEquals(1, searchResponse.getHits().getTotalHits().value);
String searchHitSource = searchResponse.getHits().getAt(0).getSourceAsString();
WorkflowState searchHitWorkflowState = WorkflowState.parse(searchHitSource);

// Assert based on the agent-framework template
List<ResourceCreated> resourcesCreated = searchHitWorkflowState.resourcesCreated();
Set<String> expectedStepNames = new HashSet<>();
expectedStepNames.add("create_connector");
expectedStepNames.add("create_flow_agent");
Set<String> stepNames = resourcesCreated.stream().map(ResourceCreated::workflowStepId).collect(Collectors.toSet());

assertEquals(2, resourcesCreated.size());
assertEquals(stepNames, expectedStepNames);
String connectorId = resourcesCreated.getFirst().resourceId();
String agentId = resourcesCreated.get(1).resourceId();
assertNotNull(connectorId);
assertNotNull(agentId);

query = "{\"query\":{\"ids\":{\"values\":[\"" + agentId + "\"]}}}";
searchResponse = searchAgent(client(), query);
assertEquals(1, searchResponse.getHits().getTotalHits().value);
SearchHit searchHit = searchResponse.getHits().getAt(0);
Map<String, Object> searchHitSourceMap = searchHit.getSourceAsMap();
assertTrue(searchHitSourceMap.containsKey("tools"));

@SuppressWarnings("unchecked")
ArrayList<Map<String, Object>> tools = (ArrayList<Map<String, Object>>) searchHitSourceMap.get("tools");
assertEquals(1, tools.size());
Map<String, Object> tool = tools.getFirst();
assertTrue(tool.containsKey("parameters"));
@SuppressWarnings("unchecked")
Map<String, String> toolParameters = (Map<String, String>) tool.get("parameters");
assertEquals(toolParameters, Map.of("connector_id", connectorId));

// Hit Deprovision API
// By design, this may not completely deprovision the first time if it takes >2s to process removals
Response deprovisionResponse = deprovisionWorkflow(client(), workflowId);
try {
assertBusy(
() -> { getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); },
30,
TimeUnit.SECONDS
);
} catch (ComparisonFailure e) {
// 202 return if still processing
assertEquals(RestStatus.ACCEPTED, TestHelpers.restStatus(deprovisionResponse));
}
if (TestHelpers.restStatus(deprovisionResponse) == RestStatus.ACCEPTED) {
// Short wait before we try again
Thread.sleep(10000);
deprovisionResponse = deprovisionWorkflow(client(), workflowId);
assertBusy(
() -> { getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); },
30,
TimeUnit.SECONDS
);
}
assertEquals(RestStatus.OK, TestHelpers.restStatus(deprovisionResponse));
// Hit Delete API
Response deleteResponse = deleteWorkflow(client(), workflowId);
assertEquals(RestStatus.OK, TestHelpers.restStatus(deleteResponse));

// Verify state doc is deleted
assertBusy(() -> { getAndAssertWorkflowStatusNotFound(client(), workflowId); }, 30, TimeUnit.SECONDS);
}

public void testReprovisionWorkflow() throws Exception {
// Begin with a template to register a local pretrained model
Template template = TestHelpers.createTemplateFromFile("registerremotemodel.json");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
{
"name": "createconnector-createconnectortool-createflowagent",
"description": "test case",
"use_case": "TEST_CASE",
"version": {
"template": "1.0.0",
"compatibility": [
"2.15.0",
"3.0.0"
]
},
"workflows": {
"provision": {
"nodes": [
{
"id": "create_connector",
"type": "create_connector",
"user_inputs": {
"name": "OpenAI Chat Connector",
"description": "The connector to public OpenAI model service for GPT 3.5",
"version": "1",
"protocol": "http",
"parameters": {
"endpoint": "api.openai.com",
"model": "gpt-3.5-turbo"
},
"credential": {
"openAI_key": "12345"
},
"actions": [
{
"action_type": "predict",
"method": "POST",
"url": "https://${parameters.endpoint}/v1/chat/completions"
}
]
}
},
{
"id": "create_tool",
"type": "create_tool",
"previous_node_inputs": {
"create_connector": "connector_id"
},
"user_inputs": {
"parameters": {},
"name": "ConnectorTool",
"type": "ConnectorTool"
}
},
{
"id": "create_flow_agent",
"type": "register_agent",
"previous_node_inputs": {
"create_tool": "tools"
},
"user_inputs": {
"parameters": {},
"type": "flow",
"name": "OpenAI Chat Agent"
}
}
],
"edges": [
{
"source": "create_connector",
"dest": "create_tool"
},
{
"source": "create_tool",
"dest": "create_flow_agent"
}
]
}
}
}
Original file line number Diff line number Diff line change
@@ -1,71 +1,71 @@
{
"name": "createconnector-registerremotemodel-deploymodel",
"description": "test case",
"use_case": "TEST_CASE",
"version": {
"template": "1.0.0",
"compatibility": [
"2.12.0",
"3.0.0"
]
},
"workflows": {
"provision": {
"nodes": [
{
"id": "workflow_step_1",
"type": "create_connector",
"user_inputs": {
"name": "OpenAI Chat Connector",
"description": "The connector to public OpenAI model service for GPT 3.5",
"version": "1",
"protocol": "http",
"parameters": {
"endpoint": "api.openai.com",
"model": "gpt-3.5-turbo"
},
"credential": {
"openAI_key": "12345"
},
"actions": [
{
"action_type": "predict",
"method": "POST",
"url": "https://${parameters.endpoint}/v1/chat/completions"
}
]
}
},
{
"id": "workflow_step_2",
"type": "register_remote_model",
"previous_node_inputs": {
"workflow_step_1": "connector_id"
"name": "createconnector-registerremotemodel-deploymodel",
"description": "test case",
"use_case": "TEST_CASE",
"version": {
"template": "1.0.0",
"compatibility": [
"2.12.0",
"3.0.0"
]
},
"workflows": {
"provision": {
"nodes": [
{
"id": "workflow_step_1",
"type": "create_connector",
"user_inputs": {
"name": "OpenAI Chat Connector",
"description": "The connector to public OpenAI model service for GPT 3.5",
"version": "1",
"protocol": "http",
"parameters": {
"endpoint": "api.openai.com",
"model": "gpt-3.5-turbo"
},
"user_inputs": {
"name": "openAI-gpt-3.5-turbo",
"function_name": "remote",
"description": "test model"
}
},
{
"id": "workflow_step_3",
"type": "deploy_model",
"previous_node_inputs": {
"workflow_step_2": "model_id"
}
"credential": {
"openAI_key": "12345"
},
"actions": [
{
"action_type": "predict",
"method": "POST",
"url": "https://${parameters.endpoint}/v1/chat/completions"
}
]
}
],
"edges": [
{
"source": "workflow_step_1",
"dest": "workflow_step_2"
},
{
"id": "workflow_step_2",
"type": "register_remote_model",
"previous_node_inputs": {
"workflow_step_1": "connector_id"
},
{
"source": "workflow_step_2",
"dest": "workflow_step_3"
"user_inputs": {
"name": "openAI-gpt-3.5-turbo",
"function_name": "remote",
"description": "test model"
}
},
{
"id": "workflow_step_3",
"type": "deploy_model",
"previous_node_inputs": {
"workflow_step_2": "model_id"
}
]
}
}
],
"edges": [
{
"source": "workflow_step_1",
"dest": "workflow_step_2"
},
{
"source": "workflow_step_2",
"dest": "workflow_step_3"
}
]
}
}
}

0 comments on commit 11e8197

Please sign in to comment.