Skip to content

Commit

Permalink
Adding RegisterLocalModelStep, fixing tests, adding input/ouput defin…
Browse files Browse the repository at this point in the history
…itions to workflow step json

Signed-off-by: Joshua Palis <[email protected]>
  • Loading branch information
joshpalis committed Nov 7, 2023
1 parent 9bb74eb commit d93a581
Show file tree
Hide file tree
Showing 10 changed files with 266 additions and 201 deletions.
12 changes: 12 additions & 0 deletions src/main/java/org/opensearch/flowframework/common/CommonValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,20 @@ private CommonValue() {}
public static final String CONNECTOR_ID = "connector_id";
/** Model format field */
public static final String MODEL_FORMAT = "model_format";
/** Model content hash value field */
public static final String MODEL_CONTENT_HASH_VALUE = "model_content_hash_value";
/** URL field */
public static final String URL = "url";
/** Model config field */
public static final String MODEL_CONFIG = "model_config";
/** Model type field */
public static final String MODEL_TYPE = "model_type";
/** Embedding dimension field */
public static final String EMBEDDING_DIMENSION = "embedding_dimension";
/** Framework type field */
public static final String FRAMEWORK_TYPE = "framework_type";
/** All config field */
public static final String ALL_CONFIG = "all_config";
/** Version field */
public static final String VERSION_FIELD = "version";
/** Connector protocol field */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
*/
public class ModelGroupStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(RegisterModelStep.class);
private static final Logger logger = LogManager.getLogger(ModelGroupStep.class);

private MachineLearningNodeClient mlClient;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig.TextEmbeddingModelConfigBuilder;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;

import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.CommonValue.ALL_CONFIG;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.EMBEDDING_DIMENSION;
import static org.opensearch.flowframework.common.CommonValue.FRAMEWORK_TYPE;
import static org.opensearch.flowframework.common.CommonValue.MODEL_CONTENT_HASH_VALUE;
import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT;
import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_ID;
import static org.opensearch.flowframework.common.CommonValue.MODEL_ID;
import static org.opensearch.flowframework.common.CommonValue.MODEL_TYPE;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
import static org.opensearch.flowframework.common.CommonValue.URL;
import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD;

/**
* Step to register a local model
*/
public class RegisterLocalModelStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(RegisterLocalModelStep.class);

private MachineLearningNodeClient mlClient;

static final String NAME = "register_local_model";

/**
* Instantiate this class
* @param mlClient client to instantiate MLClient
*/
public RegisterLocalModelStep(MachineLearningNodeClient mlClient) {
this.mlClient = mlClient;
}

@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {

CompletableFuture<WorkflowData> registerLocalModelFuture = new CompletableFuture<>();

ActionListener<MLRegisterModelResponse> actionListener = new ActionListener<>() {
@Override
public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) {
logger.info("Local Model registration successful");
registerLocalModelFuture.complete(
new WorkflowData(
Map.ofEntries(
Map.entry(MODEL_ID, mlRegisterModelResponse.getModelId()),
Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus())
)
)
);
}

@Override
public void onFailure(Exception e) {
logger.error("Failed to register local model");
registerLocalModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
};

String modelName = null;
String modelVersion = null;
String description = null;
MLModelFormat modelFormat = null;
String modelGroupId = null;
String modelContentHashValue = null;
String modelType = null;
String embeddingDimension = null;
FrameworkType frameworkType = null;
String allConfig = null;
String url = null;

for (WorkflowData workflowData : data) {
Map<String, Object> content = workflowData.getContent();

for (Entry<String, Object> entry : content.entrySet()) {
switch (entry.getKey()) {
case NAME_FIELD:
modelName = (String) content.get(NAME_FIELD);
break;
case VERSION_FIELD:
modelVersion = (String) content.get(VERSION_FIELD);
break;
case DESCRIPTION_FIELD:
description = (String) content.get(DESCRIPTION_FIELD);
break;
case MODEL_FORMAT:
modelFormat = MLModelFormat.from((String) content.get(MODEL_FORMAT));
break;
case MODEL_GROUP_ID:
modelGroupId = (String) content.get(MODEL_GROUP_ID);
break;
case MODEL_TYPE:
modelType = (String) content.get(MODEL_TYPE);
break;
case EMBEDDING_DIMENSION:
embeddingDimension = (String) content.get(EMBEDDING_DIMENSION);
break;
case FRAMEWORK_TYPE:
frameworkType = FrameworkType.from((String) content.get(FRAMEWORK_TYPE));
break;
case ALL_CONFIG:
allConfig = (String) content.get(ALL_CONFIG);
break;

Check warning on line 134 in src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java#L133-L134

Added lines #L133 - L134 were not covered by tests
case MODEL_CONTENT_HASH_VALUE:
modelContentHashValue = (String) content.get(MODEL_CONTENT_HASH_VALUE);
break;
case URL:
url = (String) content.get(URL);
break;
default:
break;

}
}
}

if (Stream.of(
modelName,
modelVersion,
description,
modelFormat,
modelGroupId,
embeddingDimension,
frameworkType,
modelContentHashValue,
url
).allMatch(x -> x != null)) {

// Create model configuration, assuming null pooling mode, null model max length, normalize results set to false
TextEmbeddingModelConfigBuilder builder = TextEmbeddingModelConfig.builder();
if (allConfig != null) {
builder.allConfig(allConfig);

Check warning on line 163 in src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java#L163

Added line #L163 was not covered by tests
}

MLModelConfig modelConfig = builder.modelType(modelType)
.embeddingDimension(Integer.valueOf(embeddingDimension))
.frameworkType(frameworkType)
.poolingMode(null)
.modelMaxLength(null)
.normalizeResult(false)
.build();

MLRegisterModelInput mlInput = MLRegisterModelInput.builder()
.modelName(modelName)
.version(modelVersion)
.modelFormat(modelFormat)
.modelGroupId(modelGroupId)
.hashValue(modelContentHashValue)
.modelConfig(modelConfig)
.url(url)
.build();

mlClient.register(mlInput, actionListener);
} else {
registerLocalModelFuture.completeExceptionally(
new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST)
);
}

return registerLocalModelFuture;
}

@Override
public String getName() {
return NAME;

Check warning on line 196 in src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java#L196

Added line #L196 was not covered by tests
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ public void onFailure(Exception e) {
String description = null;
String connectorId = null;

// TODO : Handle inline connector configuration : https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/149

for (WorkflowData workflowData : data) {

Map<String, Object> content = workflowData.getContent();
Expand Down
Loading

0 comments on commit d93a581

Please sign in to comment.