diff --git a/src/evox/__init__.py b/src/evox/__init__.py index 712b6bcb..37dd5420 100644 --- a/src/evox/__init__.py +++ b/src/evox/__init__.py @@ -1,6 +1,5 @@ from .core.algorithm import Algorithm from .core.module import * -from .core.operator import Operator from .core.problem import Problem from .core.state import State from .core.monitor import Monitor diff --git a/src/evox/core/algorithm.py b/src/evox/core/algorithm.py index 61c145c4..273ca899 100644 --- a/src/evox/core/algorithm.py +++ b/src/evox/core/algorithm.py @@ -9,6 +9,7 @@ class Algorithm(Stateful): """Base class for all algorithms""" + stateful_functions = ['init_ask', 'init_tell', 'ask', 'tell'] def init_ask(self, state: State) -> Tuple[jax.Array, State]: """Ask the algorithm for the initial population diff --git a/src/evox/core/module.py b/src/evox/core/module.py index 6ad9f2d1..359804b8 100644 --- a/src/evox/core/module.py +++ b/src/evox/core/module.py @@ -129,20 +129,24 @@ def __new__( name, bases, class_dict, - force_wrap=["__call__"], - ignore=["init", "setup"], - ignore_prefix="_", ): - wrapped = {} + wrapped = class_dict + stateful_functions = [] + # stateful_functions from parent classes + for base in bases: + if hasattr(base, "stateful_functions"): + stateful_functions += base.stateful_functions + + # stateful_functions from current class + if "stateful_functions" in class_dict: + stateful_functions += class_dict["stateful_functions"] for key, value in class_dict.items(): - if key in force_wrap: - wrapped[key] = use_state(value) - elif key.startswith(ignore_prefix) or key in ignore: - wrapped[key] = value - elif callable(value): + if key in stateful_functions: wrapped[key] = use_state(value) + wrapped["stateful_functions"] = stateful_functions + return super().__new__(cls, name, bases, wrapped) @@ -158,6 +162,8 @@ class Stateful(metaclass=MetaStatefulModule): and recursively call ``setup`` methods of all submodules. """ + stateful_functions = [] + def __init__(self) -> None: self._node_id = None self._cache_override = set() @@ -201,7 +207,9 @@ def _recursive_init(self, key, node_id, module_name) -> Tuple[State, int]: subkey = None else: key, subkey = jax.random.split(key) - submodule_state, node_id = attr._recursive_init(subkey, node_id + 1, attr_name) + submodule_state, node_id = attr._recursive_init( + subkey, node_id + 1, attr_name + ) assert isinstance( submodule_state, State ), "setup method must return a State" diff --git a/src/evox/core/operator.py b/src/evox/core/operator.py deleted file mode 100644 index 419fe47b..00000000 --- a/src/evox/core/operator.py +++ /dev/null @@ -1,12 +0,0 @@ -from typing import Tuple - -import jax -import jax.numpy as jnp - -from .module import * -from .state import State - - -class Operator(Stateful): - def __call__(self, state: State, pop: jax.Array) -> Tuple[jax.Array, State]: - return jnp.empty(0), State() diff --git a/src/evox/core/problem.py b/src/evox/core/problem.py index 8c2606e9..918de677 100644 --- a/src/evox/core/problem.py +++ b/src/evox/core/problem.py @@ -8,6 +8,7 @@ class Problem(Stateful): """Base class for all algorithms""" + stateful_functions = ["evalutate"] def evaluate(self, state: State, pop: jax.Array) -> Tuple[jax.Array, State]: """Evaluate the fitness at given points diff --git a/src/evox/core/workflow.py b/src/evox/core/workflow.py new file mode 100644 index 00000000..adac8474 --- /dev/null +++ b/src/evox/core/workflow.py @@ -0,0 +1,4 @@ +from .module import * + +class Workflow(Stateful): + stateful_functions = ["step"] \ No newline at end of file diff --git a/src/evox/operators/selection/find_pbest.py b/src/evox/operators/selection/find_pbest.py index 6f37470b..cb3873b1 100644 --- a/src/evox/operators/selection/find_pbest.py +++ b/src/evox/operators/selection/find_pbest.py @@ -1,4 +1,3 @@ -from evox import jit_class, Operator from evox.core.state import State import jax import jax.numpy as jnp