From 1e74babc5272a6ed5548d3b66c05b0fedb252ba9 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Mon, 18 Mar 2024 12:15:15 -0700 Subject: [PATCH] Added create conversation API in MLClient Signed-off-by: Owais Kazi --- client/build.gradle | 1 + .../ml/client/MachineLearningClient.java | 19 ++++++++++++++ .../ml/client/MachineLearningNodeClient.java | 18 +++++++++++++ .../ml/client/MachineLearningClientTest.java | 14 ++++++++++ .../client/MachineLearningNodeClientTest.java | 26 +++++++++++++++++++ .../CreateConversationResponse.java | 22 ++++++++++++++++ 6 files changed, 100 insertions(+) diff --git a/client/build.gradle b/client/build.gradle index 40e0743910..2b67c4d8c7 100644 --- a/client/build.gradle +++ b/client/build.gradle @@ -16,6 +16,7 @@ plugins { dependencies { implementation project(path: ":${rootProject.name}-spi", configuration: 'shadow') implementation project(path: ":${rootProject.name}-common", configuration: 'shadow') + implementation project(path: ":${rootProject.name}-memory") compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" testImplementation group: 'junit', name: 'junit', version: '4.13.2' testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.7.0' diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java index b115eb91c9..eaf07b4a1d 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java @@ -33,6 +33,7 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse; +import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; /** * A client to provide interfaces for machine learning jobs. This will be used by other plugins. @@ -428,4 +429,22 @@ default ActionFuture getTool(String toolName) { */ void getTool(String toolName, ActionListener listener); + /** + * Create conversational memory for conversation + * @param name name of the conversation, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/memory-apis/create-memory/ + * @return the result future + */ + default ActionFuture createConversation(String name) { + PlainActionFuture actionFuture = PlainActionFuture.newFuture(); + createConversation(name, actionFuture); + return actionFuture; + } + + /** + * Create conversational memory for conversation + * @param name name of the conversation, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/memory-apis/create-memory/ + * @param listener action listener + */ + void createConversation(String name, ActionListener listener); + } diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java index acf171872d..6bc7d77b76 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java @@ -85,6 +85,9 @@ import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse; +import org.opensearch.ml.memory.action.conversation.CreateConversationAction; +import org.opensearch.ml.memory.action.conversation.CreateConversationRequest; +import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; import lombok.AccessLevel; import lombok.RequiredArgsConstructor; @@ -309,6 +312,11 @@ public void getTool(String toolName, ActionListener listener) { client.execute(MLGetToolAction.INSTANCE, mlToolGetRequest, getMlGetToolResponseActionListener(listener)); } + public void createConversation(String name, ActionListener listener) { + CreateConversationRequest createConversationRequest = new CreateConversationRequest(name); + client.execute(CreateConversationAction.INSTANCE, createConversationRequest, getCreateConversationResponseActionListener(listener)); + } + private ActionListener getMlListToolsResponseActionListener(ActionListener> listener) { ActionListener internalListener = ActionListener.wrap(mlModelListResponse -> { listener.onResponse(mlModelListResponse.getToolMetadataList()); @@ -379,6 +387,16 @@ private ActionListener getMlCreateConnectorResponseAc return actionListener; } + private ActionListener getCreateConversationResponseActionListener( + ActionListener listener + ) { + ActionListener actionListener = wrapActionListener(listener, response -> { + CreateConversationResponse conversationResponse = CreateConversationResponse.fromActionResponse(response); + return conversationResponse; + }); + return actionListener; + } + private ActionListener getMlRegisterModelGroupResponseActionListener( ActionListener listener ) { diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java index ccc0e050e9..141e431c5c 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java @@ -54,6 +54,7 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse; +import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; public class MachineLearningClientTest { @@ -98,6 +99,9 @@ public class MachineLearningClientTest { @Mock MLRegisterAgentResponse registerAgentResponse; + @Mock + CreateConversationResponse createConversationResponse; + private String modekId = "test_model_id"; private MLModel mlModel; private MLTask mlTask; @@ -230,6 +234,11 @@ public void registerAgent(MLAgent mlAgent, ActionListener listener) { listener.onResponse(deleteResponse); } + + @Override + public void createConversation(String name, ActionListener listener) { + listener.onResponse(createConversationResponse); + } }; } @@ -502,4 +511,9 @@ public void getTool() { public void listTools() { assertEquals(toolMetadata, machineLearningClient.listTools().actionGet().get(0)); } + + @Test + public void createConversation() { + assertEquals(createConversationResponse, machineLearningClient.createConversation("Conversation for a RAG pipeline").actionGet()); + } } diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java index f81b20747f..219157d750 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java @@ -130,6 +130,9 @@ import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse; +import org.opensearch.ml.memory.action.conversation.CreateConversationAction; +import org.opensearch.ml.memory.action.conversation.CreateConversationRequest; +import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.InternalAggregations; @@ -205,6 +208,9 @@ public class MachineLearningNodeClientTest { @Mock ActionListener getToolActionListener; + @Mock + ActionListener createConversationResponseActionListener; + @InjectMocks MachineLearningNodeClient machineLearningNodeClient; @@ -950,6 +956,26 @@ public void listTools() { assertEquals("Use this tool to search general knowledge on wikipedia.", argumentCaptor.getValue().get(0).getDescription()); } + @Test + public void createConversation() { + String name = "Conversation for a RAG pipeline"; + String conversationId = "conversationId"; + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + CreateConversationResponse output = new CreateConversationResponse(conversationId); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(CreateConversationAction.INSTANCE), any(), any()); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(CreateConversationResponse.class); + machineLearningNodeClient.createConversation(name, createConversationResponseActionListener); + + verify(client).execute(eq(CreateConversationAction.INSTANCE), isA(CreateConversationRequest.class), any()); + verify(createConversationResponseActionListener).onResponse(argumentCaptor.capture()); + assertEquals(conversationId, argumentCaptor.getValue().getId()); + } + private SearchResponse createSearchResponse(ToXContentObject o) throws IOException { XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationResponse.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationResponse.java index 79f6fb6bf0..9ba60558a0 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationResponse.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationResponse.java @@ -17,15 +17,21 @@ */ package org.opensearch.ml.memory.action.conversation; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.UncheckedIOException; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; import lombok.AllArgsConstructor; @@ -67,4 +73,20 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par return builder; } + public static CreateConversationResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLCreateConnectorResponse) { + return (CreateConversationResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new CreateConversationResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into CreateConversationResponse", e); + } + + } + }