diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java index 4d81448362..4577cab7b7 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java @@ -19,6 +19,8 @@ import org.opensearch.ml.common.ToolMetadata; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.MLOutput; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; @@ -337,4 +339,22 @@ default ActionFuture getTool(String toolName) { * @param listener action listener */ void getTool(String toolName, ActionListener listener); + + /** + * Registers new agent and returns ActionFuture. + * @param mlRegisterAgentRequest Register agent input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#register-agent + * @return the result future + */ + default ActionFuture registerAgent(MLRegisterAgentRequest mlRegisterAgentRequest) { + PlainActionFuture actionFuture = PlainActionFuture.newFuture(); + registerAgent(mlRegisterAgentRequest, actionFuture); + return actionFuture; + } + + /** + * Registers new agent and returns agent ID in response + * @param mlRegisterAgentRequest Register agent input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#register-agent + * @return the result future + */ + void registerAgent(MLRegisterAgentRequest mlRegisterAgentRequest, ActionListener listener); } diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java index 5d4d868b61..760fa5b6b9 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java @@ -32,6 +32,9 @@ import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest; @@ -253,6 +256,16 @@ public void getTool(String toolName, ActionListener listener) { client.execute(MLGetToolAction.INSTANCE, mlToolGetRequest, getMlGetToolResponseActionListener(listener)); } + @Override + public void registerAgent(MLRegisterAgentRequest mlRegisterAgentRequest, ActionListener listener) { + client + .execute( + MLRegisterAgentAction.INSTANCE, + mlRegisterAgentRequest, + ActionListener.wrap(listener::onResponse, listener::onFailure) + ); + } + private ActionListener getMlListToolsResponseActionListener(ActionListener> listener) { ActionListener internalListener = ActionListener.wrap(mlModelListResponse -> { listener.onResponse(mlModelListResponse.getToolMetadataList()); diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java index 4b137ac685..cd3bb4f04c 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java @@ -40,6 +40,8 @@ import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLTrainingOutput; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; @@ -82,6 +84,9 @@ public class MachineLearningClientTest { @Mock MLRegisterModelGroupResponse registerModelGroupResponse; + @Mock + MLRegisterAgentResponse registerAgentResponse; + private String modekId = "test_model_id"; private MLModel mlModel; private MLTask mlTask; @@ -178,6 +183,11 @@ public void listTools(ActionListener> listener) { public void getTool(String toolName, ActionListener listener) { listener.onResponse(null); } + + @Override + public void registerAgent(MLRegisterAgentRequest mlRegisterAgentRequest, ActionListener listener) { + listener.onResponse(registerAgentResponse); + } }; } @@ -365,4 +375,10 @@ public void createConnector() { assertEquals(createConnectorResponse, machineLearningClient.createConnector(mlCreateConnectorInput).actionGet()); } + + @Test + public void testRegisterAgent() { + MLRegisterAgentRequest registerAgentRequest = MLRegisterAgentRequest.builder().build(); + assertEquals(registerAgentResponse, machineLearningClient.registerAgent(registerAgentRequest).actionGet()); + } } diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java index ccdf812195..4f77ffa4ae 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java @@ -66,6 +66,9 @@ import org.opensearch.ml.common.output.MLPredictionOutput; import org.opensearch.ml.common.output.MLTrainingOutput; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest; @@ -152,6 +155,9 @@ public class MachineLearningNodeClientTest { @Mock ActionListener registerModelGroupResponseActionListener; + @Mock + ActionListener registerAgentResponseActionListener; + @InjectMocks MachineLearningNodeClient machineLearningNodeClient; @@ -676,6 +682,27 @@ public void createConnector() { } + @Test + public void testRegisterAgent() { + String agentId = "agentId"; + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLRegisterAgentResponse output = new MLRegisterAgentResponse(agentId); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(MLRegisterAgentAction.INSTANCE), any(), any()); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterAgentResponse.class); + MLRegisterAgentRequest registerAgentRequest = MLRegisterAgentRequest.builder().build(); + + machineLearningNodeClient.registerAgent(registerAgentRequest, registerAgentResponseActionListener); + + verify(client).execute(eq(MLRegisterAgentAction.INSTANCE), isA(MLRegisterAgentRequest.class), any()); + verify(registerAgentResponseActionListener).onResponse(argumentCaptor.capture()); + assertEquals(agentId, (argumentCaptor.getValue()).getAgentId()); + } + private SearchResponse createSearchResponse(ToXContentObject o) throws IOException { XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); @@ -701,4 +728,5 @@ private SearchResponse createSearchResponse(ToXContentObject o) throws IOExcepti SearchResponse.Clusters.EMPTY ); } + }