diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 779438bac..49f3bcce9 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -121,8 +121,6 @@ public Collection createComponents( ); WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory( threadPool, - clusterService, - client, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 7f54b17cc..a580c41f7 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -10,8 +10,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.client.Client; -import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.common.Strings; import org.opensearch.core.rest.RestStatus; @@ -61,37 +59,47 @@ public class WorkflowStepFactory { private final Map> stepMap = new HashMap<>(); private static final Logger logger = LogManager.getLogger(WorkflowStepFactory.class); - private static ThreadPool threadPool; - private static MachineLearningNodeClient mlClient; - private static FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; - private static FlowFrameworkSettings flowFrameworkSettings; /** * Instantiate this class. * * @param threadPool The OpenSearch thread pool - * @param clusterService The OpenSearch cluster service - * @param client The OpenSearch client steps can use * @param mlClient Machine Learning client to perform ml operations * @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices * @param flowFrameworkSettings common settings of the plugin */ public WorkflowStepFactory( ThreadPool threadPool, - ClusterService clusterService, - Client client, MachineLearningNodeClient mlClient, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler, FlowFrameworkSettings flowFrameworkSettings ) { - this.threadPool = threadPool; - this.mlClient = mlClient; - this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; - this.flowFrameworkSettings = flowFrameworkSettings; - // Initialize the WorkflowSteps enum inside the constructor - for (WorkflowSteps workflowStep : WorkflowSteps.values()) { - stepMap.put(workflowStep.getWorkflowStepName(), workflowStep.step()); - } + stepMap.put(NoOpStep.NAME, NoOpStep::new); + stepMap.put( + RegisterLocalCustomModelStep.NAME, + () -> new RegisterLocalCustomModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings) + ); + stepMap.put( + RegisterLocalSparseEncodingModelStep.NAME, + () -> new RegisterLocalSparseEncodingModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings) + ); + stepMap.put( + RegisterLocalPretrainedModelStep.NAME, + () -> new RegisterLocalPretrainedModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings) + ); + stepMap.put(RegisterRemoteModelStep.NAME, () -> new RegisterRemoteModelStep(mlClient, flowFrameworkIndicesHandler)); + stepMap.put(DeleteModelStep.NAME, () -> new DeleteModelStep(mlClient)); + stepMap.put( + DeployModelStep.NAME, + () -> new DeployModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings) + ); + stepMap.put(UndeployModelStep.NAME, () -> new UndeployModelStep(mlClient)); + stepMap.put(CreateConnectorStep.NAME, () -> new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)); + stepMap.put(DeleteConnectorStep.NAME, () -> new DeleteConnectorStep(mlClient)); + stepMap.put(RegisterModelGroupStep.NAME, () -> new RegisterModelGroupStep(mlClient, flowFrameworkIndicesHandler)); + stepMap.put(ToolStep.NAME, ToolStep::new); + stepMap.put(RegisterAgentStep.NAME, () -> new RegisterAgentStep(mlClient, flowFrameworkIndicesHandler)); + stepMap.put(DeleteAgentStep.NAME, () -> new DeleteAgentStep(mlClient)); } /** @@ -101,7 +109,7 @@ public WorkflowStepFactory( public enum WorkflowSteps { /** Noop Step */ - NOOP("noop", Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), null, NoOpStep::new), + NOOP("noop", Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), null), /** Create Connector Step */ CREATE_CONNECTOR( @@ -109,8 +117,7 @@ public enum WorkflowSteps { List.of(NAME_FIELD, DESCRIPTION_FIELD, VERSION_FIELD, PROTOCOL_FIELD, PARAMETERS_FIELD, CREDENTIAL_FIELD, ACTIONS_FIELD), List.of(CONNECTOR_ID), List.of(OPENSEARCH_ML), - TimeValue.timeValueSeconds(60), - () -> new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler) + TimeValue.timeValueSeconds(60) ), /** Register Local Custom Model Step */ @@ -129,8 +136,7 @@ public enum WorkflowSteps { ), List.of(MODEL_ID, REGISTER_MODEL_STATUS), List.of(OPENSEARCH_ML), - TimeValue.timeValueSeconds(60), - () -> new RegisterLocalCustomModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings) + TimeValue.timeValueSeconds(60) ), /** Register Local Sparse Encoding Model Step */ @@ -139,8 +145,7 @@ public enum WorkflowSteps { List.of(NAME_FIELD, VERSION_FIELD, MODEL_FORMAT, FUNCTION_NAME, MODEL_CONTENT_HASH_VALUE, URL), List.of(MODEL_ID, REGISTER_MODEL_STATUS), List.of(OPENSEARCH_ML), - TimeValue.timeValueSeconds(60), - () -> new RegisterLocalSparseEncodingModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings) + TimeValue.timeValueSeconds(60) ), /** Register Local Pretrained Model Step */ @@ -149,8 +154,7 @@ public enum WorkflowSteps { List.of(NAME_FIELD, VERSION_FIELD, MODEL_FORMAT), List.of(MODEL_ID, REGISTER_MODEL_STATUS), List.of(OPENSEARCH_ML), - TimeValue.timeValueSeconds(60), - () -> new RegisterLocalPretrainedModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings) + TimeValue.timeValueSeconds(60) ), /** Register Remote Model Step */ @@ -159,8 +163,7 @@ public enum WorkflowSteps { List.of(NAME_FIELD, CONNECTOR_ID), List.of(MODEL_ID, REGISTER_MODEL_STATUS), List.of(OPENSEARCH_ML), - null, - () -> new RegisterRemoteModelStep(mlClient, flowFrameworkIndicesHandler) + null ), /** Register Model Group Step */ @@ -169,94 +172,42 @@ public enum WorkflowSteps { List.of(NAME_FIELD), List.of(MODEL_GROUP_ID, MODEL_GROUP_STATUS), List.of(OPENSEARCH_ML), - null, - () -> new RegisterModelGroupStep(mlClient, flowFrameworkIndicesHandler) + null ), /** Deploy Model Step */ - DEPLOY_MODEL( - DeployModelStep.NAME, - List.of(MODEL_ID), - List.of(MODEL_ID), - List.of(OPENSEARCH_ML), - TimeValue.timeValueSeconds(15), - () -> new DeployModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings) - ), + DEPLOY_MODEL(DeployModelStep.NAME, List.of(MODEL_ID), List.of(MODEL_ID), List.of(OPENSEARCH_ML), TimeValue.timeValueSeconds(15)), /** Undeploy Model Step */ - UNDEPLOY_MODEL( - UndeployModelStep.NAME, - List.of(MODEL_ID), - List.of(SUCCESS), - List.of(OPENSEARCH_ML), - null, - () -> new UndeployModelStep(mlClient) - ), + UNDEPLOY_MODEL(UndeployModelStep.NAME, List.of(MODEL_ID), List.of(SUCCESS), List.of(OPENSEARCH_ML), null), /** Delete Model Step */ - DELETE_MODEL( - DeleteModelStep.NAME, - List.of(MODEL_ID), - List.of(MODEL_ID), - List.of(OPENSEARCH_ML), - null, - () -> new DeleteModelStep(mlClient) - ), + DELETE_MODEL(DeleteModelStep.NAME, List.of(MODEL_ID), List.of(MODEL_ID), List.of(OPENSEARCH_ML), null), /** Delete Connector Step */ - DELETE_CONNECTOR( - DeleteConnectorStep.NAME, - List.of(CONNECTOR_ID), - List.of(CONNECTOR_ID), - List.of(OPENSEARCH_ML), - null, - () -> new DeleteConnectorStep(mlClient) - ), + DELETE_CONNECTOR(DeleteConnectorStep.NAME, List.of(CONNECTOR_ID), List.of(CONNECTOR_ID), List.of(OPENSEARCH_ML), null), /** Register Agent Step */ - REGISTER_AGENT( - RegisterAgentStep.NAME, - List.of(NAME_FIELD, TYPE), - List.of(AGENT_ID), - List.of(OPENSEARCH_ML), - null, - () -> new RegisterAgentStep(mlClient, flowFrameworkIndicesHandler) - ), + REGISTER_AGENT(RegisterAgentStep.NAME, List.of(NAME_FIELD, TYPE), List.of(AGENT_ID), List.of(OPENSEARCH_ML), null), /** Delete Agent Step */ - DELETE_AGENT( - DeleteAgentStep.NAME, - List.of(AGENT_ID), - List.of(AGENT_ID), - List.of(OPENSEARCH_ML), - null, - () -> new DeleteAgentStep(mlClient) - ), + DELETE_AGENT(DeleteAgentStep.NAME, List.of(AGENT_ID), List.of(AGENT_ID), List.of(OPENSEARCH_ML), null), /** Create Tool Step */ - CREATE_TOOL(ToolStep.NAME, List.of(TYPE), List.of(TOOLS_FIELD), List.of(OPENSEARCH_ML), null, ToolStep::new); + CREATE_TOOL(ToolStep.NAME, List.of(TYPE), List.of(TOOLS_FIELD), List.of(OPENSEARCH_ML), null); private final String workflowStepName; private final List inputs; private final List outputs; private final List requiredPlugins; private final TimeValue timeout; - private final Supplier workflowStep; - - WorkflowSteps( - String workflowStepName, - List inputs, - List outputs, - List requiredPlugins, - TimeValue timeout, - Supplier workflowStep - ) { + + WorkflowSteps(String workflowStepName, List inputs, List outputs, List requiredPlugins, TimeValue timeout) { this.workflowStepName = workflowStepName; this.inputs = List.copyOf(inputs); this.outputs = List.copyOf(outputs); this.requiredPlugins = requiredPlugins; this.timeout = timeout; - this.workflowStep = workflowStep; } /** @@ -299,14 +250,6 @@ public TimeValue timeout() { return timeout; } - /** - * Get the step - * @return the step - */ - public Supplier step() { - return workflowStep; - } - /** * Get the workflow step validator object * @return the WorkflowStepValidator diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java index 73fb3186c..678435707 100644 --- a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java @@ -8,13 +8,6 @@ */ package org.opensearch.flowframework.model; -import org.opensearch.client.AdminClient; -import org.opensearch.client.Client; -import org.opensearch.client.ClusterAdminClient; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.ClusterSettings; -import org.opensearch.common.settings.Setting; -import org.opensearch.common.settings.Settings; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.workflow.WorkflowStepFactory; @@ -27,14 +20,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; -import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; -import static org.opensearch.flowframework.common.FlowFrameworkSettings.TASK_REQUEST_RETRY_DURATION; -import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; + import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -174,26 +160,11 @@ public void testParseWorkflowValidator() throws IOException { public void testWorkflowStepFactoryHasValidators() throws IOException { ThreadPool threadPool = mock(ThreadPool.class); - ClusterService clusterService = mock(ClusterService.class); - ClusterAdminClient clusterAdminClient = mock(ClusterAdminClient.class); - AdminClient adminClient = mock(AdminClient.class); - Client client = mock(Client.class); - when(client.admin()).thenReturn(adminClient); - when(adminClient.cluster()).thenReturn(clusterAdminClient); MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); - final Set> settingsSet = Stream.concat( - ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), - Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, TASK_REQUEST_RETRY_DURATION) - ).collect(Collectors.toSet()); - ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, settingsSet); - when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory( threadPool, - clusterService, - client, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index 7a2250ca5..6e06f252c 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -118,14 +118,7 @@ public static void setup() throws IOException { FLOW_FRAMEWORK_THREAD_POOL_PREFIX + DEPROVISION_WORKFLOW_THREAD_POOL ) ); - WorkflowStepFactory factory = new WorkflowStepFactory( - testThreadPool, - clusterService, - client, - mlClient, - flowFrameworkIndicesHandler, - flowFrameworkSettings - ); + WorkflowStepFactory factory = new WorkflowStepFactory(testThreadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings); workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool, clusterService, client, flowFrameworkSettings); }