Added Conversation API in MLClient #2211

Open
wants to merge 10 commits into
base: main
1 change: 1 addition & 0 deletions client/build.gradle
@@ -16,6 +16,7 @@ plugins {
dependencies {
implementation project(path: ":${rootProject.name}-spi", configuration: 'shadow')
implementation project(path: ":${rootProject.name}-common", configuration: 'shadow')
implementation project(path: ":${rootProject.name}-memory")
Collaborator

Just curious to know if we should mark this as shadow?

Member Author

The memory package doesn't have a publishing task; that's why @jngz-es suggested keeping it like this.

compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
testImplementation group: 'junit', name: 'junit', version: '4.13.2'
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.7.0'
@@ -33,6 +33,7 @@
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;

/**
* A client to provide interfaces for machine learning jobs. This will be used by other plugins.
@@ -428,4 +429,22 @@ default ActionFuture<ToolMetadata> getTool(String toolName) {
*/
void getTool(String toolName, ActionListener<ToolMetadata> listener);

/**
* Create conversational memory for conversation
* @param name name of the conversation, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/memory-apis/create-memory/
* @return the result future
*/
default ActionFuture<CreateConversationResponse> createConversation(String name) {
Collaborator

@ylwu-amzn Mar 18, 2024

I don't think we should add this API. We should create a new conversation automatically when the conversation ID is not set.

Member

@amitgalitz Mar 18, 2024

Oh okay, @ylwu-amzn. So if we want users to be able to do one-click creation for the conversation use case, with some defaults for demo/PoC purposes, we can just skip this step and not give the RAG processor the conversation ID?

Collaborator

I think you do want the ability to give conversations human-readable names.

cc: @HenryL27

Member

The main reason I am currently adding this is so we can have default use cases ready to go for conversational search: opensearch-project/flow-framework#588

We want to be able to set up the whole end-to-end flow with just:

http://localhost:9200/_plugins/_flow_framework/workflow?use_case=conversational_search&provision=true
{
    "create_connector.credential.key" : "123"
}

All fields can be overridden, but at least for easy PoC/dev I thought it would be useful.

If we have some nice human-readable name in the processor for memory, will it even be seen? I was thinking this part creates the memory id to give here, not conversation ID:

GET /my_rag_test_data/_search
{
 "query": {
   "match": {
     "text": "What's the population of NYC metro area in 2023"
   }
 },
 "ext": {
   "generative_qa_parameters": {
     "llm_model": "gpt-3.5-turbo",
     "llm_question": "What's the population of NYC metro area in 2023",
     "memory_id": "znCqcI0BfUsSoeNTntd7",
     "context_size": 5,
     "message_size": 5,
     "timeout": 15
   }
 }
}

Collaborator

I was thinking this part creates the memory id to give here, not conversation ID

Memory = conversation; message = interaction. Sorry about the confusion here... we probably need to rename the internal classes to match the API names.

Collaborator

@Zhangxunmt Mar 19, 2024

Yes, we need to follow the new names. Use Memory to replace Conversation, and Message to replace Interaction. Can you update this based on the API names in the doc (https://opensearch.org/docs/latest/ml-commons-plugin/api/memory-apis/create-memory)?
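
To make the suggestion in this thread concrete, here is a minimal Java sketch of "create a memory automatically when no ID is supplied". `MemoryStore`, `resolve`, and the default name are hypothetical and only illustrate the idea; they are not ml-commons classes, and this is not a claim about how the RAG processor behaves.

```java
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;

final class AutoCreateMemoryDemo {

    /** Hypothetical in-memory store mapping memory IDs to names; not the ml-commons storage layer. */
    static final class MemoryStore {
        private final Map<String, String> namesById = new ConcurrentHashMap<>();

        /** Explicit creation, which lets the caller pick a human-readable name. */
        String create(String name) {
            String id = UUID.randomUUID().toString();
            namesById.put(id, name);
            return id;
        }

        /** Returns the supplied ID unchanged, or creates a memory with a default name if none was given. */
        String resolve(String memoryIdOrNull) {
            if (memoryIdOrNull != null && !memoryIdOrNull.isEmpty()) {
                return memoryIdOrNull;
            }
            return create("auto-created memory");
        }

        String nameOf(String id) {
            return namesById.get(id);
        }
    }

    public static void main(String[] args) {
        MemoryStore store = new MemoryStore();
        String auto = store.resolve(null);
        String named = store.create("Conversation for a RAG pipeline");
        System.out.println(store.nameOf(auto));
        System.out.println(store.nameOf(named));
    }
}
```

The trade-off discussed above shows up directly: `resolve` covers the one-click case, while an explicit `create` is still what allows a human-readable name.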

PlainActionFuture<CreateConversationResponse> actionFuture = PlainActionFuture.newFuture();
createConversation(name, actionFuture);
return actionFuture;
}

/**
* Create conversational memory for conversation
* @param name name of the conversation, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/memory-apis/create-memory/
* @param listener action listener
*/
void createConversation(String name, ActionListener<CreateConversationResponse> listener);

}
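
The two `createConversation` overloads above follow the usual pattern in this interface: implementations supply the listener-based method, and the future-returning method is a default that adapts it. The sketch below shows that pattern with plain JDK types. `Listener` and `Client` are hypothetical stand-ins, and `CompletableFuture` is swapped in for `PlainActionFuture`, so this is an illustration of the shape rather than the OpenSearch API.

```java
import java.util.concurrent.CompletableFuture;

final class FutureBridgeDemo {

    /** Hypothetical stand-in for ActionListener; not the real OpenSearch interface. */
    interface Listener<T> {
        void onResponse(T response);

        void onFailure(Exception e);
    }

    /** Hypothetical client with the same two-overload shape as createConversation. */
    interface Client {
        // Callback-based variant: the only method an implementation must provide.
        void createConversation(String name, Listener<String> listener);

        // Future-based variant: adapts the callback variant by completing a future from the listener.
        default CompletableFuture<String> createConversation(String name) {
            CompletableFuture<String> future = new CompletableFuture<>();
            createConversation(name, new Listener<String>() {
                @Override
                public void onResponse(String id) {
                    future.complete(id);
                }

                @Override
                public void onFailure(Exception e) {
                    future.completeExceptionally(e);
                }
            });
            return future;
        }
    }

    public static void main(String[] args) {
        // A toy implementation that answers immediately with a made-up ID.
        Client client = (name, listener) -> listener.onResponse("id-for-" + name);
        System.out.println(client.createConversation("demo").join());
    }
}
```

Because only the listener-based method is abstract, test doubles such as the anonymous client in `MachineLearningClientTest` need to override just one method per operation.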
@@ -85,6 +85,9 @@
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
import org.opensearch.ml.memory.action.conversation.CreateConversationAction;
import org.opensearch.ml.memory.action.conversation.CreateConversationRequest;
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;

import lombok.AccessLevel;
import lombok.RequiredArgsConstructor;
@@ -309,6 +312,11 @@ public void getTool(String toolName, ActionListener<ToolMetadata> listener) {
client.execute(MLGetToolAction.INSTANCE, mlToolGetRequest, getMlGetToolResponseActionListener(listener));
}

public void createConversation(String name, ActionListener<CreateConversationResponse> listener) {
CreateConversationRequest createConversationRequest = new CreateConversationRequest(name);
client.execute(CreateConversationAction.INSTANCE, createConversationRequest, getCreateConversationResponseActionListener(listener));
}

private ActionListener<MLToolsListResponse> getMlListToolsResponseActionListener(ActionListener<List<ToolMetadata>> listener) {
ActionListener<MLToolsListResponse> internalListener = ActionListener.wrap(mlModelListResponse -> {
listener.onResponse(mlModelListResponse.getToolMetadataList());
@@ -379,6 +387,16 @@ private ActionListener<MLCreateConnectorResponse> getMlCreateConnectorResponseAc
return actionListener;
}

private ActionListener<CreateConversationResponse> getCreateConversationResponseActionListener(
ActionListener<CreateConversationResponse> listener
) {
ActionListener<CreateConversationResponse> actionListener = wrapActionListener(listener, response -> {
CreateConversationResponse conversationResponse = CreateConversationResponse.fromActionResponse(response);
return conversationResponse;
});
return actionListener;
}

private ActionListener<MLRegisterModelGroupResponse> getMlRegisterModelGroupResponseActionListener(
ActionListener<MLRegisterModelGroupResponse> listener
) {
@@ -54,6 +54,7 @@
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;

public class MachineLearningClientTest {

@@ -98,6 +99,9 @@ public class MachineLearningClientTest {
@Mock
MLRegisterAgentResponse registerAgentResponse;

@Mock
CreateConversationResponse createConversationResponse;

private String modekId = "test_model_id";
private MLModel mlModel;
private MLTask mlTask;
@@ -230,6 +234,11 @@ public void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentRespons
public void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

@Override
public void createConversation(String name, ActionListener<CreateConversationResponse> listener) {
listener.onResponse(createConversationResponse);
}
};
}

@@ -502,4 +511,9 @@ public void getTool() {
public void listTools() {
assertEquals(toolMetadata, machineLearningClient.listTools().actionGet().get(0));
}

@Test
public void createConversation() {
assertEquals(createConversationResponse, machineLearningClient.createConversation("Conversation for a RAG pipeline").actionGet());
}
}
@@ -130,6 +130,9 @@
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
import org.opensearch.ml.memory.action.conversation.CreateConversationAction;
import org.opensearch.ml.memory.action.conversation.CreateConversationRequest;
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
@@ -205,6 +208,9 @@ public class MachineLearningNodeClientTest {
@Mock
ActionListener<ToolMetadata> getToolActionListener;

@Mock
ActionListener<CreateConversationResponse> createConversationResponseActionListener;

@InjectMocks
MachineLearningNodeClient machineLearningNodeClient;

@@ -950,6 +956,26 @@ public void listTools() {
assertEquals("Use this tool to search general knowledge on wikipedia.", argumentCaptor.getValue().get(0).getDescription());
}

@Test
public void createConversation() {
String name = "Conversation for a RAG pipeline";
String conversationId = "conversationId";

doAnswer(invocation -> {
ActionListener<CreateConversationResponse> actionListener = invocation.getArgument(2);
CreateConversationResponse output = new CreateConversationResponse(conversationId);
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(CreateConversationAction.INSTANCE), any(), any());

ArgumentCaptor<CreateConversationResponse> argumentCaptor = ArgumentCaptor.forClass(CreateConversationResponse.class);
machineLearningNodeClient.createConversation(name, createConversationResponseActionListener);

verify(client).execute(eq(CreateConversationAction.INSTANCE), isA(CreateConversationRequest.class), any());
verify(createConversationResponseActionListener).onResponse(argumentCaptor.capture());
assertEquals(conversationId, argumentCaptor.getValue().getId());
}

private SearchResponse createSearchResponse(ToXContentObject o) throws IOException {
XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);

@@ -17,15 +17,21 @@
*/
package org.opensearch.ml.memory.action.conversation;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;

import lombok.AllArgsConstructor;

@@ -67,4 +73,20 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
return builder;
}

public static CreateConversationResponse fromActionResponse(ActionResponse actionResponse) {
Collaborator

Let's add a test.

if (actionResponse instanceof CreateConversationResponse) {
return (CreateConversationResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new CreateConversationResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into CreateConversationResponse", e);
}

}

}
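
`fromActionResponse` converts a generic response into the concrete type in two steps: return the object directly if it already has the target type, otherwise serialize it and re-read the bytes. The sketch below reproduces that round trip with JDK streams only. `Writable` and `IdResponse` are hypothetical stand-ins for `ActionResponse` and `CreateConversationResponse`, and `DataOutputStream`/`DataInputStream` replace the OpenSearch stream classes, so treat it as an illustration under those assumptions.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

final class RoundTripDemo {

    /** Hypothetical stand-in for ActionResponse: anything that can write itself to a stream. */
    interface Writable {
        void writeTo(DataOutputStream out) throws IOException;
    }

    /** Hypothetical stand-in for CreateConversationResponse: carries a single ID. */
    static final class IdResponse implements Writable {
        final String id;

        IdResponse(String id) {
            this.id = id;
        }

        // Deserializing constructor: reads back exactly what writeTo wrote.
        IdResponse(DataInputStream in) throws IOException {
            this.id = in.readUTF();
        }

        @Override
        public void writeTo(DataOutputStream out) throws IOException {
            out.writeUTF(id);
        }
    }

    static IdResponse fromResponse(Writable response) {
        // Fast path: the object already has the target type, so no copy is needed.
        if (response instanceof IdResponse) {
            return (IdResponse) response;
        }
        // Slow path: serialize to bytes, then re-read those bytes as the target type.
        try (ByteArrayOutputStream bytes = new ByteArrayOutputStream();
             DataOutputStream out = new DataOutputStream(bytes)) {
            response.writeTo(out);
            out.flush();
            try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                return new IdResponse(in);
            }
        } catch (IOException e) {
            throw new UncheckedIOException("failed to parse response into IdResponse", e);
        }
    }

    public static void main(String[] args) {
        IdResponse same = new IdResponse("test-id");
        Writable foreign = out -> out.writeUTF("test-id");
        System.out.println(fromResponse(same) == same);
        System.out.println(fromResponse(foreign).id);
    }
}
```

A test for the fast path can assert identity (the same instance comes back), which distinguishes it from the slow path, where a new but equal object is returned.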
@@ -17,10 +17,16 @@
*/
package org.opensearch.ml.memory.action.conversation;

import static org.junit.Assert.assertEquals;

import java.io.IOException;
import java.io.UncheckedIOException;

import org.junit.Before;
import org.junit.Test;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.BytesStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
@@ -32,8 +38,14 @@

public class CreateConversationResponseTests extends OpenSearchTestCase {

CreateConversationResponse response;

@Before
public void setup() {
response = new CreateConversationResponse("test-id");
}

public void testCreateConversationResponseStreaming() throws IOException {
CreateConversationResponse response = new CreateConversationResponse("test-id");
assert (response.getId().equals("test-id"));
BytesStreamOutput outbytes = new BytesStreamOutput();
StreamOutput osso = new OutputStreamStreamOutput(outbytes);
@@ -44,11 +56,40 @@ public void testCreateConversationResponseStreaming() throws IOException {
}

public void testToXContent() throws IOException {
CreateConversationResponse response = new CreateConversationResponse("createme");
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
- String expected = "{\"memory_id\":\"createme\"}";
+ String expected = "{\"memory_id\":\"test-id\"}";
response.toXContent(builder, ToXContent.EMPTY_PARAMS);
String result = BytesReference.bytes(builder).utf8ToString();
assert (result.equals(expected));
}

@Test
public void fromActionResponseWithCreateConversationResponseSuccess() {
CreateConversationResponse responseFromActionResponse = CreateConversationResponse.fromActionResponse(response);
assertEquals(response.getId(), responseFromActionResponse.getId());
}

@Test
public void fromActionResponseSuccess() {
ActionResponse actionResponse = new ActionResponse() {
@Override
public void writeTo(StreamOutput out) throws IOException {
response.writeTo(out);
}
};
CreateConversationResponse responseFromActionResponse = CreateConversationResponse.fromActionResponse(actionResponse);
assertNotSame(response, responseFromActionResponse);
assertEquals(response.getId(), responseFromActionResponse.getId());
}

@Test(expected = UncheckedIOException.class)
public void fromActionResponseIOException() {
ActionResponse actionResponse = new ActionResponse() {
@Override
public void writeTo(StreamOutput out) throws IOException {
throw new IOException();
}
};
CreateConversationResponse.fromActionResponse(actionResponse);
}
}