-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding RegisterLocalModelStep, fixing tests, adding input/ouput defin…
…itions to workflow step json Signed-off-by: Joshua Palis <[email protected]>
- Loading branch information
Showing
10 changed files
with
266 additions
and
201 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
198 changes: 198 additions & 0 deletions
198
src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
*/ | ||
package org.opensearch.flowframework.workflow; | ||
|
||
import org.apache.logging.log4j.LogManager; | ||
import org.apache.logging.log4j.Logger; | ||
import org.opensearch.ExceptionsHelper; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.core.rest.RestStatus; | ||
import org.opensearch.flowframework.exception.FlowFrameworkException; | ||
import org.opensearch.ml.client.MachineLearningNodeClient; | ||
import org.opensearch.ml.common.model.MLModelConfig; | ||
import org.opensearch.ml.common.model.MLModelFormat; | ||
import org.opensearch.ml.common.model.TextEmbeddingModelConfig; | ||
import org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType; | ||
import org.opensearch.ml.common.model.TextEmbeddingModelConfig.TextEmbeddingModelConfigBuilder; | ||
import org.opensearch.ml.common.transport.register.MLRegisterModelInput; | ||
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Map.Entry; | ||
import java.util.concurrent.CompletableFuture; | ||
import java.util.stream.Stream; | ||
|
||
import static org.opensearch.flowframework.common.CommonValue.ALL_CONFIG; | ||
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; | ||
import static org.opensearch.flowframework.common.CommonValue.EMBEDDING_DIMENSION; | ||
import static org.opensearch.flowframework.common.CommonValue.FRAMEWORK_TYPE; | ||
import static org.opensearch.flowframework.common.CommonValue.MODEL_CONTENT_HASH_VALUE; | ||
import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT; | ||
import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_ID; | ||
import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; | ||
import static org.opensearch.flowframework.common.CommonValue.MODEL_TYPE; | ||
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; | ||
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; | ||
import static org.opensearch.flowframework.common.CommonValue.URL; | ||
import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; | ||
|
||
/** | ||
* Step to register a local model | ||
*/ | ||
public class RegisterLocalModelStep implements WorkflowStep { | ||
|
||
private static final Logger logger = LogManager.getLogger(RegisterLocalModelStep.class); | ||
|
||
private MachineLearningNodeClient mlClient; | ||
|
||
static final String NAME = "register_local_model"; | ||
|
||
/** | ||
* Instantiate this class | ||
* @param mlClient client to instantiate MLClient | ||
*/ | ||
public RegisterLocalModelStep(MachineLearningNodeClient mlClient) { | ||
this.mlClient = mlClient; | ||
} | ||
|
||
@Override | ||
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) { | ||
|
||
CompletableFuture<WorkflowData> registerLocalModelFuture = new CompletableFuture<>(); | ||
|
||
ActionListener<MLRegisterModelResponse> actionListener = new ActionListener<>() { | ||
@Override | ||
public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { | ||
logger.info("Local Model registration successful"); | ||
registerLocalModelFuture.complete( | ||
new WorkflowData( | ||
Map.ofEntries( | ||
Map.entry(MODEL_ID, mlRegisterModelResponse.getModelId()), | ||
Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus()) | ||
) | ||
) | ||
); | ||
} | ||
|
||
@Override | ||
public void onFailure(Exception e) { | ||
logger.error("Failed to register local model"); | ||
registerLocalModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); | ||
} | ||
}; | ||
|
||
String modelName = null; | ||
String modelVersion = null; | ||
String description = null; | ||
MLModelFormat modelFormat = null; | ||
String modelGroupId = null; | ||
String modelContentHashValue = null; | ||
String modelType = null; | ||
String embeddingDimension = null; | ||
FrameworkType frameworkType = null; | ||
String allConfig = null; | ||
String url = null; | ||
|
||
for (WorkflowData workflowData : data) { | ||
Map<String, Object> content = workflowData.getContent(); | ||
|
||
for (Entry<String, Object> entry : content.entrySet()) { | ||
switch (entry.getKey()) { | ||
case NAME_FIELD: | ||
modelName = (String) content.get(NAME_FIELD); | ||
break; | ||
case VERSION_FIELD: | ||
modelVersion = (String) content.get(VERSION_FIELD); | ||
break; | ||
case DESCRIPTION_FIELD: | ||
description = (String) content.get(DESCRIPTION_FIELD); | ||
break; | ||
case MODEL_FORMAT: | ||
modelFormat = MLModelFormat.from((String) content.get(MODEL_FORMAT)); | ||
break; | ||
case MODEL_GROUP_ID: | ||
modelGroupId = (String) content.get(MODEL_GROUP_ID); | ||
break; | ||
case MODEL_TYPE: | ||
modelType = (String) content.get(MODEL_TYPE); | ||
break; | ||
case EMBEDDING_DIMENSION: | ||
embeddingDimension = (String) content.get(EMBEDDING_DIMENSION); | ||
break; | ||
case FRAMEWORK_TYPE: | ||
frameworkType = FrameworkType.from((String) content.get(FRAMEWORK_TYPE)); | ||
break; | ||
case ALL_CONFIG: | ||
allConfig = (String) content.get(ALL_CONFIG); | ||
break; | ||
Check warning on line 134 in src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java Codecov / codecov/patchsrc/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java#L133-L134
|
||
case MODEL_CONTENT_HASH_VALUE: | ||
modelContentHashValue = (String) content.get(MODEL_CONTENT_HASH_VALUE); | ||
break; | ||
case URL: | ||
url = (String) content.get(URL); | ||
break; | ||
default: | ||
break; | ||
|
||
} | ||
} | ||
} | ||
|
||
if (Stream.of( | ||
modelName, | ||
modelVersion, | ||
description, | ||
modelFormat, | ||
modelGroupId, | ||
embeddingDimension, | ||
frameworkType, | ||
modelContentHashValue, | ||
url | ||
).allMatch(x -> x != null)) { | ||
|
||
// Create model configuration, assuming null pooling mode, null model max length, normalize results set to false | ||
TextEmbeddingModelConfigBuilder builder = TextEmbeddingModelConfig.builder(); | ||
if (allConfig != null) { | ||
builder.allConfig(allConfig); | ||
} | ||
|
||
MLModelConfig modelConfig = builder.modelType(modelType) | ||
.embeddingDimension(Integer.valueOf(embeddingDimension)) | ||
.frameworkType(frameworkType) | ||
.poolingMode(null) | ||
.modelMaxLength(null) | ||
.normalizeResult(false) | ||
.build(); | ||
|
||
MLRegisterModelInput mlInput = MLRegisterModelInput.builder() | ||
.modelName(modelName) | ||
.version(modelVersion) | ||
.modelFormat(modelFormat) | ||
.modelGroupId(modelGroupId) | ||
.hashValue(modelContentHashValue) | ||
.modelConfig(modelConfig) | ||
.url(url) | ||
.build(); | ||
|
||
mlClient.register(mlInput, actionListener); | ||
} else { | ||
registerLocalModelFuture.completeExceptionally( | ||
new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST) | ||
); | ||
} | ||
|
||
return registerLocalModelFuture; | ||
} | ||
|
||
@Override | ||
public String getName() { | ||
return NAME; | ||
} | ||
} |
154 changes: 0 additions & 154 deletions
154
src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.