From a84540180c37f48b73f54743c5d3cbc2d650191a Mon Sep 17 00:00:00 2001 From: Arjun kumar Giri Date: Thu, 16 Nov 2023 14:40:12 -0800 Subject: [PATCH] Addressed feedback Signed-off-by: Arjun kumar Giri --- .../ml/client/MachineLearningClient.java | 13 ++++++------ .../ml/client/MachineLearningNodeClient.java | 14 ++++++++++++- .../ml/client/MachineLearningClientTest.java | 8 +++---- .../client/MachineLearningNodeClientTest.java | 5 +++-- .../agent/MLRegisterAgentResponse.java | 21 +++++++++++++++++++ 5 files changed, 47 insertions(+), 14 deletions(-) diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java index 4577cab7b7..07b8b20b22 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java @@ -17,9 +17,9 @@ import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.ToolMetadata; +import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.MLOutput; -import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest; import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; @@ -342,19 +342,18 @@ default ActionFuture getTool(String toolName) { /** * Registers new agent and returns ActionFuture. - * @param mlRegisterAgentRequest Register agent input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#register-agent + * @param mlAgent Register agent input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#register-agent * @return the result future */ - default ActionFuture registerAgent(MLRegisterAgentRequest mlRegisterAgentRequest) { + default ActionFuture registerAgent(MLAgent mlAgent) { PlainActionFuture actionFuture = PlainActionFuture.newFuture(); - registerAgent(mlRegisterAgentRequest, actionFuture); + registerAgent(mlAgent, actionFuture); return actionFuture; } /** * Registers new agent and returns agent ID in response - * @param mlRegisterAgentRequest Register agent input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#register-agent - * @return the result future + * @param mlAgent Register agent input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#register-agent */ - void registerAgent(MLRegisterAgentRequest mlRegisterAgentRequest, ActionListener listener); + void registerAgent(MLAgent mlAgent, ActionListener listener); } diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java index 760fa5b6b9..14e0ecd235 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java @@ -28,6 +28,7 @@ import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.ToolMetadata; +import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.output.MLOutput; @@ -257,7 +258,8 @@ public void getTool(String toolName, ActionListener listener) { } @Override - public void registerAgent(MLRegisterAgentRequest mlRegisterAgentRequest, ActionListener listener) { + public void registerAgent(MLAgent mlAgent, ActionListener listener) { + MLRegisterAgentRequest mlRegisterAgentRequest = MLRegisterAgentRequest.builder().mlAgent(mlAgent).build(); client .execute( MLRegisterAgentAction.INSTANCE, @@ -266,6 +268,16 @@ public void registerAgent(MLRegisterAgentRequest mlRegisterAgentRequest, ActionL ); } + private ActionListener getMLRegisterAgentResponseActionListener( + ActionListener listener + ) { + ActionListener actionListener = wrapActionListener(listener, res -> { + MLRegisterAgentResponse mlRegisterAgentResponse = MLRegisterAgentResponse.fromActionResponse(res); + return mlRegisterAgentResponse; + }); + return actionListener; + } + private ActionListener getMlListToolsResponseActionListener(ActionListener> listener) { ActionListener internalListener = ActionListener.wrap(mlModelListResponse -> { listener.onResponse(mlModelListResponse.getToolMetadataList()); diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java index cd3bb4f04c..984d8dd332 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java @@ -31,6 +31,7 @@ import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.ToolMetadata; +import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.input.MLInput; @@ -40,7 +41,6 @@ import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLTrainingOutput; -import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest; import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; @@ -185,7 +185,7 @@ public void getTool(String toolName, ActionListener listener) { } @Override - public void registerAgent(MLRegisterAgentRequest mlRegisterAgentRequest, ActionListener listener) { + public void registerAgent(MLAgent mlAgent, ActionListener listener) { listener.onResponse(registerAgentResponse); } }; @@ -378,7 +378,7 @@ public void createConnector() { @Test public void testRegisterAgent() { - MLRegisterAgentRequest registerAgentRequest = MLRegisterAgentRequest.builder().build(); - assertEquals(registerAgentResponse, machineLearningClient.registerAgent(registerAgentRequest).actionGet()); + MLAgent mlAgent = MLAgent.builder().name("Agent name").build(); + assertEquals(registerAgentResponse, machineLearningClient.registerAgent(mlAgent).actionGet()); } } diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java index 4f77ffa4ae..1b32c47d38 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java @@ -56,6 +56,7 @@ import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; +import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.input.MLInput; @@ -694,9 +695,9 @@ public void testRegisterAgent() { }).when(client).execute(eq(MLRegisterAgentAction.INSTANCE), any(), any()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterAgentResponse.class); - MLRegisterAgentRequest registerAgentRequest = MLRegisterAgentRequest.builder().build(); + MLAgent mlAgent = MLAgent.builder().name("Agent name").build(); - machineLearningNodeClient.registerAgent(registerAgentRequest, registerAgentResponseActionListener); + machineLearningNodeClient.registerAgent(mlAgent, registerAgentResponseActionListener); verify(client).execute(eq(MLRegisterAgentAction.INSTANCE), isA(MLRegisterAgentRequest.class), any()); verify(registerAgentResponseActionListener).onResponse(argumentCaptor.capture()); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponse.java index e4fc073f9c..90b1401e4d 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponse.java @@ -7,12 +7,17 @@ import lombok.Getter; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.UncheckedIOException; @Getter public class MLRegisterAgentResponse extends ActionResponse implements ToXContentObject { @@ -41,4 +46,20 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.endObject(); return builder; } + + public static MLRegisterAgentResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLRegisterAgentResponse) { + return (MLRegisterAgentResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLRegisterAgentResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into MLRegisterAgentResponse", e); + } + } }