diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java index c2176c334..5dd6b2773 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java @@ -21,6 +21,7 @@ import org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType; import org.opensearch.ml.common.model.TextEmbeddingModelConfig.TextEmbeddingModelConfigBuilder; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import java.util.List; @@ -163,15 +164,20 @@ public void onFailure(Exception e) { MLModelConfig modelConfig = modelConfigBuilder.build(); - MLRegisterModelInput mlInput = MLRegisterModelInput.builder() + MLRegisterModelInputBuilder mlInputBuilder = MLRegisterModelInput.builder() .modelName(modelName) .version(modelVersion) .modelFormat(modelFormat) .modelGroupId(modelGroupId) .hashValue(modelContentHashValue) .modelConfig(modelConfig) - .url(url) - .build(); + .url(url); + + if (description != null) { + mlInputBuilder.description(description); + } + + MLRegisterModelInput mlInput = mlInputBuilder.build(); mlClient.register(mlInput, actionListener); } else { diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index 85e46a8cb..4dedc8bf2 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -117,17 +117,19 @@ public void onFailure(Exception e) { if (Stream.of(modelName, functionName, connectorId).allMatch(x -> x != null)) { - MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder(); + MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder() + .functionName(functionName) + .modelName(modelName) + .connectorId(connectorId); if (modelGroupId != null) { builder.modelGroupId(modelGroupId); } + if (description != null) { + builder.description(description); + } + MLRegisterModelInput mlInput = builder.build(); - MLRegisterModelInput mlInput = builder.functionName(functionName) - .modelName(modelName) - .description(description) - .connectorId(connectorId) - .build(); mlClient.register(mlInput, actionListener); } else { registerRemoteModelFuture.completeExceptionally(