Skip to content

Commit

Permalink
Adds Register and Deploy Model Step for remote model (#52)
Browse files Browse the repository at this point in the history
* Initial UploadModel integration

Signed-off-by: Owais Kazi <[email protected]>

* Implemented Register Model Step

Signed-off-by: Owais Kazi <[email protected]>

* Integrated register for remote model

Signed-off-by: Owais Kazi <[email protected]>

* Integrated deploy model

Signed-off-by: Owais Kazi <[email protected]>

* Separated Register and Deploy Steps

Signed-off-by: Owais Kazi <[email protected]>

* Added tests

Signed-off-by: Owais Kazi <[email protected]>

* Added NodeClient

Signed-off-by: Owais Kazi <[email protected]>

* Added javadocs

Signed-off-by: Owais Kazi <[email protected]>

* Addressed PR comments

Signed-off-by: Owais Kazi <[email protected]>

* Addressed PR comments

Signed-off-by: Owais Kazi <[email protected]>

* Addressed PR comments - 2

Signed-off-by: Owais Kazi <[email protected]>

* Fixed test failure

Signed-off-by: Owais Kazi <[email protected]>

* Addressed PR comments

Signed-off-by: Owais Kazi <[email protected]>

---------

Signed-off-by: Owais Kazi <[email protected]>
  • Loading branch information
owaiskazi19 committed Oct 12, 2023
1 parent 446c32d commit 9b10b23
Show file tree
Hide file tree
Showing 16 changed files with 686 additions and 102 deletions.
3 changes: 3 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ publishing {
allprojects {
group = opensearch_group
version = "${opensearch_build}"
}

java {
targetCompatibility = JavaVersion.VERSION_11
sourceCompatibility = JavaVersion.VERSION_11
}
Expand Down
38 changes: 38 additions & 0 deletions gradle.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#
# SPDX-License-Identifier: Apache-2.0
#
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
#
# Modifications Copyright OpenSearch Contributors. See
# GitHub history for details.
#


# Enable build caching
org.gradle.caching=true
org.gradle.warning.mode=none
org.gradle.parallel=true
# Workaround for https://github.com/diffplug/spotless/issues/834
org.gradle.jvmargs=-Xmx3g -XX:+HeapDumpOnOutOfMemoryError -Xss2m \
--add-exports jdk.compiler/com.sun.tools.javac.api=ALL-UNNAMED \
--add-exports jdk.compiler/com.sun.tools.javac.file=ALL-UNNAMED \
--add-exports jdk.compiler/com.sun.tools.javac.parser=ALL-UNNAMED \
--add-exports jdk.compiler/com.sun.tools.javac.tree=ALL-UNNAMED \
--add-exports jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED
options.forkOptions.memoryMaximumSize=2g

# Disable duplicate project id detection
# See https://docs.gradle.org/current/userguide/upgrading_version_6.html#duplicate_project_names_may_cause_publication_to_fail
systemProp.org.gradle.dependency.duplicate.project.detection=false

# Enforce the build to fail on deprecated gradle api usage
systemProp.org.gradle.warning.mode=fail

# forcing to use TLS1.2 to avoid failure in vault
# see https://github.com/hashicorp/vault/issues/8750#issuecomment-631236121
systemProp.jdk.tls.client.protocols=TLSv1.2

# jvm args for faster test execution by default
systemProp.tests.jvm.argline=-XX:TieredStopAtLevel=1 -XX:ReservedCodeCacheSize=64m
2 changes: 1 addition & 1 deletion src/main/java/demo/DemoWorkflowStep.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
CompletableFuture.runAsync(() -> {
try {
Thread.sleep(this.delay);
future.complete(null);
future.complete(WorkflowData.EMPTY);
} catch (InterruptedException e) {
future.completeExceptionally(e);
}
Expand Down
1 change: 0 additions & 1 deletion src/main/java/demo/TemplateParseDemo.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ public static void main(String[] args) throws IOException {
}
ClusterService clusterService = new ClusterService(null, null, null);
Client client = new NodeClient(null, null);

WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client);
ThreadPool threadPool = new ThreadPool(Settings.EMPTY);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
*/
package org.opensearch.flowframework.client;

import org.opensearch.client.node.NodeClient;
import org.opensearch.client.Client;
import org.opensearch.ml.client.MachineLearningNodeClient;

/**
Expand All @@ -22,12 +22,12 @@ private MLClient() {}
/**
* Creates machine learning client.
*
* @param nodeClient node client of OpenSearch.
* @param client client of OpenSearch.
* @return machine learning client from ml-commons.
*/
public static MachineLearningNodeClient createMLClient(NodeClient nodeClient) {
public static MachineLearningNodeClient createMLClient(Client client) {
if (INSTANCE == null) {
INSTANCE = new MachineLearningNodeClient(nodeClient);
INSTANCE = new MachineLearningNodeClient(client);
}
return INSTANCE;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,13 @@ public class CommonValue {
public static final String GLOBAL_CONTEXT_INDEX = ".plugins-ai-global-context";
public static final String GLOBAL_CONTEXT_INDEX_MAPPING = "mappings/global-context.json";
public static final Integer GLOBAL_CONTEXT_INDEX_VERSION = 1;
public static final String MODEL_ID = "model_id";
public static final String FUNCTION_NAME = "function_name";
public static final String MODEL_NAME = "name";
public static final String MODEL_VERSION = "model_version";
public static final String MODEL_GROUP_ID = "model_group_id";
public static final String DESCRIPTION = "description";
public static final String CONNECTOR_ID = "connector_id";
public static final String MODEL_FORMAT = "model_format";
public static final String MODEL_CONFIG = "model_config";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.client.MLClient;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;

import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

import static org.opensearch.flowframework.common.CommonValue.MODEL_ID;

/**
* Step to deploy a model
*/
public class DeployModelStep implements WorkflowStep {
private static final Logger logger = LogManager.getLogger(DeployModelStep.class);

private Client client;
static final String NAME = "deploy_model";

/**
* Instantiate this class
* @param client client to instantiate MLClient
*/
public DeployModelStep(Client client) {
this.client = client;
}

@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {

CompletableFuture<WorkflowData> deployModelFuture = new CompletableFuture<>();

MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(client);

ActionListener<MLDeployModelResponse> actionListener = new ActionListener<>() {
@Override
public void onResponse(MLDeployModelResponse mlDeployModelResponse) {
logger.info("Model deployment state {}", mlDeployModelResponse.getStatus());
deployModelFuture.complete(
new WorkflowData(Map.ofEntries(Map.entry("deploy_model_status", mlDeployModelResponse.getStatus())))
);
}

@Override
public void onFailure(Exception e) {
logger.error("Model deployment failed");
deployModelFuture.completeExceptionally(e);
}
};

String modelId = null;

for (WorkflowData workflowData : data) {
if (workflowData.getContent().containsKey(MODEL_ID)) {
modelId = (String) workflowData.getContent().get(MODEL_ID);
break;
}
}
machineLearningNodeClient.deploy(modelId, actionListener);
return deployModelFuture;
}

@Override
public String getName() {
return NAME;
}
}
65 changes: 65 additions & 0 deletions src/main/java/org/opensearch/flowframework/workflow/GetTask.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.common.SuppressForbidden;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.transport.task.MLTaskGetResponse;

/**
* Step to get modelID of a registered local model
*/
@SuppressForbidden(reason = "This class is for the future work of registering local model")
public class GetTask {

private static final Logger logger = LogManager.getLogger(GetTask.class);
private MachineLearningNodeClient machineLearningNodeClient;
private String taskId;

/**
* Instantiate this class
* @param machineLearningNodeClient client to instantiate ml-commons APIs
* @param taskId taskID of the model
*/
public GetTask(MachineLearningNodeClient machineLearningNodeClient, String taskId) {
this.machineLearningNodeClient = machineLearningNodeClient;
this.taskId = taskId;
}

/**
* Invokes get task API of ml-commons
*/
public void getTask() {

ActionListener<MLTask> actionListener = new ActionListener<>() {
@Override
public void onResponse(MLTask mlTask) {
if (mlTask.getState() == MLTaskState.COMPLETED) {
logger.info("Model registration successful");
MLTaskGetResponse response = MLTaskGetResponse.builder().mlTask(mlTask).build();
logger.info("Response from task {}", response);
}
}

@Override
public void onFailure(Exception e) {
logger.error("Model registration failed");
}
};

machineLearningNodeClient.getTask(taskId, actionListener);

}

}
Loading

0 comments on commit 9b10b23

Please sign in to comment.