Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed static fields initialization in WorkflowStepFactory #532

Merged
merged 1 commit into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,6 @@ public Collection<Object> createComponents(
);
WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(
threadPool,
clusterService,
client,
mlClient,
flowFrameworkIndicesHandler,
flowFrameworkSettings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.common.Strings;
import org.opensearch.core.rest.RestStatus;
Expand Down Expand Up @@ -61,37 +59,47 @@

private final Map<String, Supplier<WorkflowStep>> stepMap = new HashMap<>();
private static final Logger logger = LogManager.getLogger(WorkflowStepFactory.class);
private static ThreadPool threadPool;
private static MachineLearningNodeClient mlClient;
private static FlowFrameworkIndicesHandler flowFrameworkIndicesHandler;
private static FlowFrameworkSettings flowFrameworkSettings;

/**
* Instantiate this class.
*
* @param threadPool The OpenSearch thread pool
* @param clusterService The OpenSearch cluster service
* @param client The OpenSearch client steps can use
* @param mlClient Machine Learning client to perform ml operations
* @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices
* @param flowFrameworkSettings common settings of the plugin
*/
public WorkflowStepFactory(
ThreadPool threadPool,
ClusterService clusterService,
Client client,
MachineLearningNodeClient mlClient,
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler,
FlowFrameworkSettings flowFrameworkSettings
) {
this.threadPool = threadPool;
this.mlClient = mlClient;
this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
this.flowFrameworkSettings = flowFrameworkSettings;
// Initialize the WorkflowSteps enum inside the constructor
for (WorkflowSteps workflowStep : WorkflowSteps.values()) {
stepMap.put(workflowStep.getWorkflowStepName(), workflowStep.step());
}
stepMap.put(NoOpStep.NAME, NoOpStep::new);
stepMap.put(
RegisterLocalCustomModelStep.NAME,
() -> new RegisterLocalCustomModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings)
);
stepMap.put(
RegisterLocalSparseEncodingModelStep.NAME,
() -> new RegisterLocalSparseEncodingModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings)

Check warning on line 84 in src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java#L84

Added line #L84 was not covered by tests
);
stepMap.put(
RegisterLocalPretrainedModelStep.NAME,
() -> new RegisterLocalPretrainedModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings)

Check warning on line 88 in src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java#L88

Added line #L88 was not covered by tests
);
stepMap.put(RegisterRemoteModelStep.NAME, () -> new RegisterRemoteModelStep(mlClient, flowFrameworkIndicesHandler));
stepMap.put(DeleteModelStep.NAME, () -> new DeleteModelStep(mlClient));
stepMap.put(
DeployModelStep.NAME,
() -> new DeployModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings)
);
stepMap.put(UndeployModelStep.NAME, () -> new UndeployModelStep(mlClient));
stepMap.put(CreateConnectorStep.NAME, () -> new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler));
stepMap.put(DeleteConnectorStep.NAME, () -> new DeleteConnectorStep(mlClient));
stepMap.put(RegisterModelGroupStep.NAME, () -> new RegisterModelGroupStep(mlClient, flowFrameworkIndicesHandler));
stepMap.put(ToolStep.NAME, ToolStep::new);
stepMap.put(RegisterAgentStep.NAME, () -> new RegisterAgentStep(mlClient, flowFrameworkIndicesHandler));
stepMap.put(DeleteAgentStep.NAME, () -> new DeleteAgentStep(mlClient));
}

/**
Expand All @@ -101,16 +109,15 @@
public enum WorkflowSteps {

/** Noop Step */
NOOP("noop", Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), null, NoOpStep::new),
NOOP("noop", Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), null),

/** Create Connector Step */
CREATE_CONNECTOR(
CreateConnectorStep.NAME,
List.of(NAME_FIELD, DESCRIPTION_FIELD, VERSION_FIELD, PROTOCOL_FIELD, PARAMETERS_FIELD, CREDENTIAL_FIELD, ACTIONS_FIELD),
List.of(CONNECTOR_ID),
List.of(OPENSEARCH_ML),
TimeValue.timeValueSeconds(60),
() -> new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)
TimeValue.timeValueSeconds(60)
),

/** Register Local Custom Model Step */
Expand All @@ -129,8 +136,7 @@
),
List.of(MODEL_ID, REGISTER_MODEL_STATUS),
List.of(OPENSEARCH_ML),
TimeValue.timeValueSeconds(60),
() -> new RegisterLocalCustomModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings)
TimeValue.timeValueSeconds(60)
),

/** Register Local Sparse Encoding Model Step */
Expand All @@ -139,8 +145,7 @@
List.of(NAME_FIELD, VERSION_FIELD, MODEL_FORMAT, FUNCTION_NAME, MODEL_CONTENT_HASH_VALUE, URL),
List.of(MODEL_ID, REGISTER_MODEL_STATUS),
List.of(OPENSEARCH_ML),
TimeValue.timeValueSeconds(60),
() -> new RegisterLocalSparseEncodingModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings)
TimeValue.timeValueSeconds(60)
),

/** Register Local Pretrained Model Step */
Expand All @@ -149,8 +154,7 @@
List.of(NAME_FIELD, VERSION_FIELD, MODEL_FORMAT),
List.of(MODEL_ID, REGISTER_MODEL_STATUS),
List.of(OPENSEARCH_ML),
TimeValue.timeValueSeconds(60),
() -> new RegisterLocalPretrainedModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings)
TimeValue.timeValueSeconds(60)
),

/** Register Remote Model Step */
Expand All @@ -159,8 +163,7 @@
List.of(NAME_FIELD, CONNECTOR_ID),
List.of(MODEL_ID, REGISTER_MODEL_STATUS),
List.of(OPENSEARCH_ML),
null,
() -> new RegisterRemoteModelStep(mlClient, flowFrameworkIndicesHandler)
null
),

/** Register Model Group Step */
Expand All @@ -169,94 +172,42 @@
List.of(NAME_FIELD),
List.of(MODEL_GROUP_ID, MODEL_GROUP_STATUS),
List.of(OPENSEARCH_ML),
null,
() -> new RegisterModelGroupStep(mlClient, flowFrameworkIndicesHandler)
null
),

/** Deploy Model Step */
DEPLOY_MODEL(
DeployModelStep.NAME,
List.of(MODEL_ID),
List.of(MODEL_ID),
List.of(OPENSEARCH_ML),
TimeValue.timeValueSeconds(15),
() -> new DeployModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings)
),
DEPLOY_MODEL(DeployModelStep.NAME, List.of(MODEL_ID), List.of(MODEL_ID), List.of(OPENSEARCH_ML), TimeValue.timeValueSeconds(15)),

/** Undeploy Model Step */
UNDEPLOY_MODEL(
UndeployModelStep.NAME,
List.of(MODEL_ID),
List.of(SUCCESS),
List.of(OPENSEARCH_ML),
null,
() -> new UndeployModelStep(mlClient)
),
UNDEPLOY_MODEL(UndeployModelStep.NAME, List.of(MODEL_ID), List.of(SUCCESS), List.of(OPENSEARCH_ML), null),

/** Delete Model Step */
DELETE_MODEL(
DeleteModelStep.NAME,
List.of(MODEL_ID),
List.of(MODEL_ID),
List.of(OPENSEARCH_ML),
null,
() -> new DeleteModelStep(mlClient)
),
DELETE_MODEL(DeleteModelStep.NAME, List.of(MODEL_ID), List.of(MODEL_ID), List.of(OPENSEARCH_ML), null),

/** Delete Connector Step */
DELETE_CONNECTOR(
DeleteConnectorStep.NAME,
List.of(CONNECTOR_ID),
List.of(CONNECTOR_ID),
List.of(OPENSEARCH_ML),
null,
() -> new DeleteConnectorStep(mlClient)
),
DELETE_CONNECTOR(DeleteConnectorStep.NAME, List.of(CONNECTOR_ID), List.of(CONNECTOR_ID), List.of(OPENSEARCH_ML), null),

/** Register Agent Step */
REGISTER_AGENT(
RegisterAgentStep.NAME,
List.of(NAME_FIELD, TYPE),
List.of(AGENT_ID),
List.of(OPENSEARCH_ML),
null,
() -> new RegisterAgentStep(mlClient, flowFrameworkIndicesHandler)
),
REGISTER_AGENT(RegisterAgentStep.NAME, List.of(NAME_FIELD, TYPE), List.of(AGENT_ID), List.of(OPENSEARCH_ML), null),

/** Delete Agent Step */
DELETE_AGENT(
DeleteAgentStep.NAME,
List.of(AGENT_ID),
List.of(AGENT_ID),
List.of(OPENSEARCH_ML),
null,
() -> new DeleteAgentStep(mlClient)
),
DELETE_AGENT(DeleteAgentStep.NAME, List.of(AGENT_ID), List.of(AGENT_ID), List.of(OPENSEARCH_ML), null),

/** Create Tool Step */
CREATE_TOOL(ToolStep.NAME, List.of(TYPE), List.of(TOOLS_FIELD), List.of(OPENSEARCH_ML), null, ToolStep::new);
CREATE_TOOL(ToolStep.NAME, List.of(TYPE), List.of(TOOLS_FIELD), List.of(OPENSEARCH_ML), null);

private final String workflowStepName;
private final List<String> inputs;
private final List<String> outputs;
private final List<String> requiredPlugins;
private final TimeValue timeout;
private final Supplier<WorkflowStep> workflowStep;

WorkflowSteps(
String workflowStepName,
List<String> inputs,
List<String> outputs,
List<String> requiredPlugins,
TimeValue timeout,
Supplier<WorkflowStep> workflowStep
) {

WorkflowSteps(String workflowStepName, List<String> inputs, List<String> outputs, List<String> requiredPlugins, TimeValue timeout) {
this.workflowStepName = workflowStepName;
this.inputs = List.copyOf(inputs);
this.outputs = List.copyOf(outputs);
this.requiredPlugins = requiredPlugins;
this.timeout = timeout;
this.workflowStep = workflowStep;
}

/**
Expand Down Expand Up @@ -299,14 +250,6 @@
return timeout;
}

/**
* Get the step
* @return the step
*/
public Supplier<WorkflowStep> step() {
return workflowStep;
}

/**
* Get the workflow step validator object
* @return the WorkflowStepValidator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,6 @@
*/
package org.opensearch.flowframework.model;

import org.opensearch.client.AdminClient;
import org.opensearch.client.Client;
import org.opensearch.client.ClusterAdminClient;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.flowframework.common.FlowFrameworkSettings;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
Expand All @@ -27,14 +20,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.TASK_REQUEST_RETRY_DURATION;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -174,26 +160,11 @@ public void testParseWorkflowValidator() throws IOException {
public void testWorkflowStepFactoryHasValidators() throws IOException {

ThreadPool threadPool = mock(ThreadPool.class);
ClusterService clusterService = mock(ClusterService.class);
ClusterAdminClient clusterAdminClient = mock(ClusterAdminClient.class);
AdminClient adminClient = mock(AdminClient.class);
Client client = mock(Client.class);
when(client.admin()).thenReturn(adminClient);
when(adminClient.cluster()).thenReturn(clusterAdminClient);
MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class);
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class);

final Set<Setting<?>> settingsSet = Stream.concat(
ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(),
Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, TASK_REQUEST_RETRY_DURATION)
).collect(Collectors.toSet());
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, settingsSet);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);

WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(
threadPool,
clusterService,
client,
mlClient,
flowFrameworkIndicesHandler,
flowFrameworkSettings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,7 @@ public static void setup() throws IOException {
FLOW_FRAMEWORK_THREAD_POOL_PREFIX + DEPROVISION_WORKFLOW_THREAD_POOL
)
);
WorkflowStepFactory factory = new WorkflowStepFactory(
testThreadPool,
clusterService,
client,
mlClient,
flowFrameworkIndicesHandler,
flowFrameworkSettings
);
WorkflowStepFactory factory = new WorkflowStepFactory(testThreadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings);
workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool, clusterService, client, flowFrameworkSettings);
}

Expand Down
Loading