Skip to content

Commit

Permalink
Addressed feedback
Browse files Browse the repository at this point in the history
Signed-off-by: Arjun kumar Giri <[email protected]>
  • Loading branch information
arjunkumargiri committed Nov 16, 2023
1 parent 50e3592 commit a845401
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.ToolMetadata;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
Expand Down Expand Up @@ -342,19 +342,18 @@ default ActionFuture<ToolMetadata> getTool(String toolName) {

/**
* Registers new agent and returns ActionFuture.
* @param mlRegisterAgentRequest Register agent input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#register-agent
* @param mlAgent Register agent input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#register-agent
* @return the result future
*/
default ActionFuture<MLRegisterAgentResponse> registerAgent(MLRegisterAgentRequest mlRegisterAgentRequest) {
default ActionFuture<MLRegisterAgentResponse> registerAgent(MLAgent mlAgent) {
PlainActionFuture<MLRegisterAgentResponse> actionFuture = PlainActionFuture.newFuture();
registerAgent(mlRegisterAgentRequest, actionFuture);
registerAgent(mlAgent, actionFuture);
return actionFuture;
}

/**
* Registers new agent and returns agent ID in response
* @param mlRegisterAgentRequest Register agent input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#register-agent
* @return the result future
* @param mlAgent Register agent input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#register-agent
*/
void registerAgent(MLRegisterAgentRequest mlRegisterAgentRequest, ActionListener<MLRegisterAgentResponse> listener);
void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentResponse> listener);
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.ToolMetadata;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.output.MLOutput;
Expand Down Expand Up @@ -257,7 +258,8 @@ public void getTool(String toolName, ActionListener<ToolMetadata> listener) {
}

@Override
public void registerAgent(MLRegisterAgentRequest mlRegisterAgentRequest, ActionListener<MLRegisterAgentResponse> listener) {
public void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentResponse> listener) {
MLRegisterAgentRequest mlRegisterAgentRequest = MLRegisterAgentRequest.builder().mlAgent(mlAgent).build();
client
.execute(
MLRegisterAgentAction.INSTANCE,
Expand All @@ -266,6 +268,16 @@ public void registerAgent(MLRegisterAgentRequest mlRegisterAgentRequest, ActionL
);
}

private ActionListener<MLRegisterAgentResponse> getMLRegisterAgentResponseActionListener(
ActionListener<MLRegisterAgentResponse> listener
) {
ActionListener<MLRegisterAgentResponse> actionListener = wrapActionListener(listener, res -> {
MLRegisterAgentResponse mlRegisterAgentResponse = MLRegisterAgentResponse.fromActionResponse(res);
return mlRegisterAgentResponse;
});
return actionListener;
}

private ActionListener<MLToolsListResponse> getMlListToolsResponseActionListener(ActionListener<List<ToolMetadata>> listener) {
ActionListener<MLToolsListResponse> internalListener = ActionListener.wrap(mlModelListResponse -> {
listener.onResponse(mlModelListResponse.getToolMetadataList());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.ToolMetadata;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.input.MLInput;
Expand All @@ -40,7 +41,6 @@
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLTrainingOutput;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
Expand Down Expand Up @@ -185,7 +185,7 @@ public void getTool(String toolName, ActionListener<ToolMetadata> listener) {
}

@Override
public void registerAgent(MLRegisterAgentRequest mlRegisterAgentRequest, ActionListener<MLRegisterAgentResponse> listener) {
public void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentResponse> listener) {
listener.onResponse(registerAgentResponse);
}
};
Expand Down Expand Up @@ -378,7 +378,7 @@ public void createConnector() {

@Test
public void testRegisterAgent() {
MLRegisterAgentRequest registerAgentRequest = MLRegisterAgentRequest.builder().build();
assertEquals(registerAgentResponse, machineLearningClient.registerAgent(registerAgentRequest).actionGet());
MLAgent mlAgent = MLAgent.builder().name("Agent name").build();
assertEquals(registerAgentResponse, machineLearningClient.registerAgent(mlAgent).actionGet());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.input.MLInput;
Expand Down Expand Up @@ -694,9 +695,9 @@ public void testRegisterAgent() {
}).when(client).execute(eq(MLRegisterAgentAction.INSTANCE), any(), any());

ArgumentCaptor<MLRegisterAgentResponse> argumentCaptor = ArgumentCaptor.forClass(MLRegisterAgentResponse.class);
MLRegisterAgentRequest registerAgentRequest = MLRegisterAgentRequest.builder().build();
MLAgent mlAgent = MLAgent.builder().name("Agent name").build();

machineLearningNodeClient.registerAgent(registerAgentRequest, registerAgentResponseActionListener);
machineLearningNodeClient.registerAgent(mlAgent, registerAgentResponseActionListener);

verify(client).execute(eq(MLRegisterAgentAction.INSTANCE), isA(MLRegisterAgentRequest.class), any());
verify(registerAgentResponseActionListener).onResponse(argumentCaptor.capture());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,17 @@

import lombok.Getter;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

@Getter
public class MLRegisterAgentResponse extends ActionResponse implements ToXContentObject {
Expand Down Expand Up @@ -41,4 +46,20 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.endObject();
return builder;
}

public static MLRegisterAgentResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof MLRegisterAgentResponse) {
return (MLRegisterAgentResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLRegisterAgentResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into MLRegisterAgentResponse", e);
}
}
}

0 comments on commit a845401

Please sign in to comment.