diff --git a/src/evox/__init__.py b/src/evox/__init__.py index 37dd5420..82ea8a75 100644 --- a/src/evox/__init__.py +++ b/src/evox/__init__.py @@ -1,3 +1,4 @@ +from .core.workflow import Workflow from .core.algorithm import Algorithm from .core.module import * from .core.problem import Problem diff --git a/src/evox/workflows/distributed.py b/src/evox/workflows/distributed.py index adb9ee4f..8c820b3d 100644 --- a/src/evox/workflows/distributed.py +++ b/src/evox/workflows/distributed.py @@ -4,16 +4,14 @@ import jax import jax.numpy as jnp -import numpy as np import ray -from jax import jit -from jax.tree_util import tree_flatten -from evox import Algorithm, Monitor, Problem, State, Stateful, jit_class +from evox import Algorithm, Problem, State, Workflow from evox.utils import algorithm_has_init_ask, parse_opt_direction -class WorkerWorkflow(Stateful): +class WorkerWorkflow(Workflow): + stateful_functions = ["step1", "step2"] def __init__( self, algorithm: Algorithm, @@ -228,7 +226,7 @@ def get_call_queue(self): return call_queue -class RayDistributedWorkflow(Stateful): +class RayDistributedWorkflow(Workflow): def __init__( self, algorithm: Algorithm, diff --git a/src/evox/workflows/non_jit_workflow.py b/src/evox/workflows/non_jit_workflow.py index 424c04b4..e6956ae3 100644 --- a/src/evox/workflows/non_jit_workflow.py +++ b/src/evox/workflows/non_jit_workflow.py @@ -1,11 +1,11 @@ import warnings -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, List, Optional, Union -from evox import Algorithm, Monitor, Problem, State, Stateful +from evox import Algorithm, Monitor, Problem, State, Workflow from evox.utils import algorithm_has_init_ask, parse_opt_direction -class NonJitWorkflow(Stateful): +class NonJitWorkflow(Workflow): def __init__( self, algorithm: Algorithm, diff --git a/src/evox/workflows/std_workflow.py b/src/evox/workflows/std_workflow.py index 406b570c..2f0b3ca1 100644 --- a/src/evox/workflows/std_workflow.py +++ b/src/evox/workflows/std_workflow.py @@ -1,6 +1,6 @@ import warnings from functools import partial -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, List, Optional, Union import jax import jax.numpy as jnp @@ -8,11 +8,11 @@ from jax.sharding import PositionalSharding from jax.tree_util import tree_map -from evox import Algorithm, Problem, State, Stateful, Monitor, jit_method +from evox import Algorithm, Problem, State, Workflow, Monitor, jit_method from evox.utils import parse_opt_direction, algorithm_has_init_ask -class StdWorkflow(Stateful): +class StdWorkflow(Workflow): """Experimental unified workflow, designed to provide unparallel performance for EC workflow.