Skip to content

Commit

Permalink
dev: metaclass wrap use_state based on a class attr
Browse files Browse the repository at this point in the history
  • Loading branch information
BillHuang2001 committed Mar 26, 2024
1 parent ffa5d3e commit fb710cb
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 24 deletions.
1 change: 0 additions & 1 deletion src/evox/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .core.algorithm import Algorithm
from .core.module import *
from .core.operator import Operator
from .core.problem import Problem
from .core.state import State
from .core.monitor import Monitor
Expand Down
1 change: 1 addition & 0 deletions src/evox/core/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

class Algorithm(Stateful):
"""Base class for all algorithms"""
stateful_functions = ['init_ask', 'init_tell', 'ask', 'tell']

def init_ask(self, state: State) -> Tuple[jax.Array, State]:
"""Ask the algorithm for the initial population
Expand Down
28 changes: 18 additions & 10 deletions src/evox/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,20 +129,24 @@ def __new__(
name,
bases,
class_dict,
force_wrap=["__call__"],
ignore=["init", "setup"],
ignore_prefix="_",
):
wrapped = {}
wrapped = class_dict
stateful_functions = []
# stateful_functions from parent classes
for base in bases:
if hasattr(base, "stateful_functions"):
stateful_functions += base.stateful_functions

# stateful_functions from current class
if "stateful_functions" in class_dict:
stateful_functions += class_dict["stateful_functions"]

for key, value in class_dict.items():
if key in force_wrap:
wrapped[key] = use_state(value)
elif key.startswith(ignore_prefix) or key in ignore:
wrapped[key] = value
elif callable(value):
if key in stateful_functions:
wrapped[key] = use_state(value)

wrapped["stateful_functions"] = stateful_functions

return super().__new__(cls, name, bases, wrapped)


Expand All @@ -158,6 +162,8 @@ class Stateful(metaclass=MetaStatefulModule):
and recursively call ``setup`` methods of all submodules.
"""

stateful_functions = []

def __init__(self) -> None:
self._node_id = None
self._cache_override = set()
Expand Down Expand Up @@ -201,7 +207,9 @@ def _recursive_init(self, key, node_id, module_name) -> Tuple[State, int]:
subkey = None
else:
key, subkey = jax.random.split(key)
submodule_state, node_id = attr._recursive_init(subkey, node_id + 1, attr_name)
submodule_state, node_id = attr._recursive_init(
subkey, node_id + 1, attr_name
)
assert isinstance(
submodule_state, State
), "setup method must return a State"
Expand Down
12 changes: 0 additions & 12 deletions src/evox/core/operator.py

This file was deleted.

1 change: 1 addition & 0 deletions src/evox/core/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

class Problem(Stateful):
"""Base class for all algorithms"""
stateful_functions = ["evalutate"]

def evaluate(self, state: State, pop: jax.Array) -> Tuple[jax.Array, State]:
"""Evaluate the fitness at given points
Expand Down
4 changes: 4 additions & 0 deletions src/evox/core/workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .module import *

class Workflow(Stateful):
stateful_functions = ["step"]
1 change: 0 additions & 1 deletion src/evox/operators/selection/find_pbest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from evox import jit_class, Operator
from evox.core.state import State
import jax
import jax.numpy as jnp
Expand Down

0 comments on commit fb710cb

Please sign in to comment.