Skip to content

Commit

Permalink
dev: workflow adopt the new stateful module
Browse files Browse the repository at this point in the history
  • Loading branch information
BillHuang2001 committed Mar 26, 2024
1 parent ef95eee commit 4eed085
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 12 deletions.
1 change: 1 addition & 0 deletions src/evox/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .core.workflow import Workflow
from .core.algorithm import Algorithm
from .core.module import *
from .core.problem import Problem
Expand Down
10 changes: 4 additions & 6 deletions src/evox/workflows/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,14 @@

import jax
import jax.numpy as jnp
import numpy as np
import ray
from jax import jit
from jax.tree_util import tree_flatten

from evox import Algorithm, Monitor, Problem, State, Stateful, jit_class
from evox import Algorithm, Problem, State, Workflow
from evox.utils import algorithm_has_init_ask, parse_opt_direction


class WorkerWorkflow(Stateful):
class WorkerWorkflow(Workflow):
stateful_functions = ["step1", "step2"]
def __init__(
self,
algorithm: Algorithm,
Expand Down Expand Up @@ -228,7 +226,7 @@ def get_call_queue(self):
return call_queue


class RayDistributedWorkflow(Stateful):
class RayDistributedWorkflow(Workflow):
def __init__(
self,
algorithm: Algorithm,
Expand Down
6 changes: 3 additions & 3 deletions src/evox/workflows/non_jit_workflow.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import warnings
from typing import Callable, Dict, List, Optional, Union
from typing import Callable, List, Optional, Union

from evox import Algorithm, Monitor, Problem, State, Stateful
from evox import Algorithm, Monitor, Problem, State, Workflow
from evox.utils import algorithm_has_init_ask, parse_opt_direction


class NonJitWorkflow(Stateful):
class NonJitWorkflow(Workflow):
def __init__(
self,
algorithm: Algorithm,
Expand Down
6 changes: 3 additions & 3 deletions src/evox/workflows/std_workflow.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import warnings
from functools import partial
from typing import Callable, Dict, List, Optional, Union
from typing import Callable, List, Optional, Union

import jax
import jax.numpy as jnp
from jax import jit, lax, pmap, pure_callback
from jax.sharding import PositionalSharding
from jax.tree_util import tree_map

from evox import Algorithm, Problem, State, Stateful, Monitor, jit_method
from evox import Algorithm, Problem, State, Workflow, Monitor, jit_method
from evox.utils import parse_opt_direction, algorithm_has_init_ask


class StdWorkflow(Stateful):
class StdWorkflow(Workflow):
"""Experimental unified workflow,
designed to provide unparallel performance for EC workflow.
Expand Down

0 comments on commit 4eed085

Please sign in to comment.