Skip to content

Commit

Permalink
Register agent API support for MLClient
Browse files Browse the repository at this point in the history
Signed-off-by: Arjun kumar Giri <[email protected]>
  • Loading branch information
arjunkumargiri committed Nov 16, 2023
1 parent d8684ce commit 50e3592
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import org.opensearch.ml.common.ToolMetadata;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
Expand Down Expand Up @@ -337,4 +339,22 @@ default ActionFuture<ToolMetadata> getTool(String toolName) {
* @param listener action listener
*/
void getTool(String toolName, ActionListener<ToolMetadata> listener);

/**
* Registers new agent and returns ActionFuture.
* @param mlRegisterAgentRequest Register agent input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#register-agent
* @return the result future
*/
default ActionFuture<MLRegisterAgentResponse> registerAgent(MLRegisterAgentRequest mlRegisterAgentRequest) {
PlainActionFuture<MLRegisterAgentResponse> actionFuture = PlainActionFuture.newFuture();
registerAgent(mlRegisterAgentRequest, actionFuture);
return actionFuture;
}

/**
* Registers new agent and returns agent ID in response
* @param mlRegisterAgentRequest Register agent input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#register-agent
* @return the result future
*/
void registerAgent(MLRegisterAgentRequest mlRegisterAgentRequest, ActionListener<MLRegisterAgentResponse> listener);
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest;
Expand Down Expand Up @@ -253,6 +256,16 @@ public void getTool(String toolName, ActionListener<ToolMetadata> listener) {
client.execute(MLGetToolAction.INSTANCE, mlToolGetRequest, getMlGetToolResponseActionListener(listener));
}

@Override
public void registerAgent(MLRegisterAgentRequest mlRegisterAgentRequest, ActionListener<MLRegisterAgentResponse> listener) {
client
.execute(
MLRegisterAgentAction.INSTANCE,
mlRegisterAgentRequest,
ActionListener.wrap(listener::onResponse, listener::onFailure)
);
}

private ActionListener<MLToolsListResponse> getMlListToolsResponseActionListener(ActionListener<List<ToolMetadata>> listener) {
ActionListener<MLToolsListResponse> internalListener = ActionListener.wrap(mlModelListResponse -> {
listener.onResponse(mlModelListResponse.getToolMetadataList());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLTrainingOutput;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
Expand Down Expand Up @@ -82,6 +84,9 @@ public class MachineLearningClientTest {
@Mock
MLRegisterModelGroupResponse registerModelGroupResponse;

@Mock
MLRegisterAgentResponse registerAgentResponse;

private String modekId = "test_model_id";
private MLModel mlModel;
private MLTask mlTask;
Expand Down Expand Up @@ -178,6 +183,11 @@ public void listTools(ActionListener<List<ToolMetadata>> listener) {
public void getTool(String toolName, ActionListener<ToolMetadata> listener) {
listener.onResponse(null);
}

@Override
public void registerAgent(MLRegisterAgentRequest mlRegisterAgentRequest, ActionListener<MLRegisterAgentResponse> listener) {
listener.onResponse(registerAgentResponse);
}
};
}

Expand Down Expand Up @@ -365,4 +375,10 @@ public void createConnector() {

assertEquals(createConnectorResponse, machineLearningClient.createConnector(mlCreateConnectorInput).actionGet());
}

@Test
public void testRegisterAgent() {
MLRegisterAgentRequest registerAgentRequest = MLRegisterAgentRequest.builder().build();
assertEquals(registerAgentResponse, machineLearningClient.registerAgent(registerAgentRequest).actionGet());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.ml.common.output.MLTrainingOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest;
Expand Down Expand Up @@ -152,6 +155,9 @@ public class MachineLearningNodeClientTest {
@Mock
ActionListener<MLRegisterModelGroupResponse> registerModelGroupResponseActionListener;

@Mock
ActionListener<MLRegisterAgentResponse> registerAgentResponseActionListener;

@InjectMocks
MachineLearningNodeClient machineLearningNodeClient;

Expand Down Expand Up @@ -676,6 +682,27 @@ public void createConnector() {

}

@Test
public void testRegisterAgent() {
String agentId = "agentId";

doAnswer(invocation -> {
ActionListener<MLRegisterAgentResponse> actionListener = invocation.getArgument(2);
MLRegisterAgentResponse output = new MLRegisterAgentResponse(agentId);
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLRegisterAgentAction.INSTANCE), any(), any());

ArgumentCaptor<MLRegisterAgentResponse> argumentCaptor = ArgumentCaptor.forClass(MLRegisterAgentResponse.class);
MLRegisterAgentRequest registerAgentRequest = MLRegisterAgentRequest.builder().build();

machineLearningNodeClient.registerAgent(registerAgentRequest, registerAgentResponseActionListener);

verify(client).execute(eq(MLRegisterAgentAction.INSTANCE), isA(MLRegisterAgentRequest.class), any());
verify(registerAgentResponseActionListener).onResponse(argumentCaptor.capture());
assertEquals(agentId, (argumentCaptor.getValue()).getAgentId());
}

private SearchResponse createSearchResponse(ToXContentObject o) throws IOException {
XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);

Expand All @@ -701,4 +728,5 @@ private SearchResponse createSearchResponse(ToXContentObject o) throws IOExcepti
SearchResponse.Clusters.EMPTY
);
}

}

0 comments on commit 50e3592

Please sign in to comment.