diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index f91367cf0..5f99b7289 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -25,11 +25,12 @@ import java.util.Collection; import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.opensearch.flowframework.common.CommonValue.ACTIONS_FIELD; import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; @@ -415,13 +416,11 @@ public static List getInputByWorkflowType(String workflowStep) throws Fl * @return WorkflowValidator */ public WorkflowValidator getWorkflowValidator() { - Map workflowStepValidators = new HashMap<>(); - - for (WorkflowSteps mapping : WorkflowSteps.values()) { - workflowStepValidators.put(mapping.getWorkflowStepName(), mapping.getWorkflowStepValidator()); - } - - return new WorkflowValidator(workflowStepValidators); + return new WorkflowValidator( + Stream.of(WorkflowSteps.values()) + .filter(w -> !WorkflowProcessSorter.WORKFLOW_STEP_DENYLIST.contains(w.getWorkflowStepName())) + .collect(Collectors.toMap(WorkflowSteps::getWorkflowStepName, WorkflowSteps::getWorkflowStepValidator)) + ); } /** @@ -430,22 +429,20 @@ public WorkflowValidator getWorkflowValidator() { * @return WorkflowValidator */ public WorkflowValidator getWorkflowValidatorByStep(List steps) { - Map workflowStepValidators = new HashMap<>(); - Set invalidSteps = new HashSet<>(steps); - - for (WorkflowSteps mapping : WorkflowSteps.values()) { - String step = mapping.getWorkflowStepName(); - if (steps.contains(step)) { - workflowStepValidators.put(mapping.getWorkflowStepName(), mapping.getWorkflowStepValidator()); - invalidSteps.remove(step); - } - } - + Set validSteps = Stream.of(WorkflowSteps.values()) + .map(WorkflowSteps::getWorkflowStepName) + .filter(name -> !WorkflowProcessSorter.WORKFLOW_STEP_DENYLIST.contains(name)) + .filter(steps::contains) + .collect(Collectors.toSet()); + Set invalidSteps = steps.stream().filter(name -> !validSteps.contains(name)).collect(Collectors.toSet()); if (!invalidSteps.isEmpty()) { throw new FlowFrameworkException("Invalid step name: " + invalidSteps, RestStatus.BAD_REQUEST); } - - return new WorkflowValidator(workflowStepValidators); + return new WorkflowValidator( + Stream.of(WorkflowSteps.values()) + .filter(w -> validSteps.contains(w.getWorkflowStepName())) + .collect(Collectors.toMap(WorkflowSteps::getWorkflowStepName, WorkflowSteps::getWorkflowStepValidator)) + ); } /** diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java index a9df22a82..e685e07b9 100644 --- a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java @@ -11,6 +11,7 @@ import org.opensearch.client.Client; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; import org.opensearch.flowframework.workflow.WorkflowStepFactory.WorkflowSteps; import org.opensearch.ml.client.MachineLearningNodeClient; @@ -70,6 +71,7 @@ public void testWorkflowStepFactoryHasValidators() throws IOException { // Get all registered workflow step types in the workflow step factory List registeredWorkflowStepTypes = new ArrayList(workflowStepFactory.getStepMap().keySet()); + registeredWorkflowStepTypes.removeAll(WorkflowProcessSorter.WORKFLOW_STEP_DENYLIST); // Check if each registered step has a corresponding validator definition assertTrue(registeredWorkflowStepTypes.containsAll(registeredWorkflowValidatorTypes));