Skip to content

Commit

Permalink
dev: introduce dataclass support and reduce magic
Browse files Browse the repository at this point in the history
  • Loading branch information
BillHuang2001 committed Mar 28, 2024
1 parent 29841a3 commit 3419d6f
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 95 deletions.
64 changes: 8 additions & 56 deletions src/evox/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import jax
import jax.numpy as jnp
import dataclasses

from .state import State

Expand Down Expand Up @@ -85,17 +86,8 @@ def default_cond_fun(name: str):
return True


def _class_decorator(cls, wrapper, cond_fun=default_cond_fun):
"""A helper function used to add decorators to methods of a class
Parameters
----------
wrapper
The decorator
ignore
Ignore methods in this list
ignore_prefix
Ignore methods with certain prefix
def jit_class(cls):
"""A helper function used to jit decorators to methods of a class
Returns
-------
Expand All @@ -105,52 +97,15 @@ def _class_decorator(cls, wrapper, cond_fun=default_cond_fun):
for attr_name in dir(cls):
func = getattr(cls, attr_name)
if callable(func) and cond_fun(attr_name):
wrapped = wrapper(func)
if dataclasses.is_dataclass(cls):
wrapped = wrapper(jax.jit(func))
else:
wrapped = wrapper(jit_method(func))
setattr(cls, attr_name, wrapped)
return cls


def jit_class(cls, cond_fun=default_cond_fun):
return _class_decorator(cls, jit_method, cond_fun)


class MetaStatefulModule(type):
"""Meta class used by Module
This meta class will try to wrap methods with use_state,
which allows easy managing of states.
It is recommended to use a single underscore as prefix to prevent a method from being wrapped.
Still, this behavior can be configured by passing ``force_wrap``, ``ignore`` and ``ignore_prefix``.
"""

def __new__(
cls,
name,
bases,
class_dict,
):
wrapped = class_dict
stateful_functions = []
# stateful_functions from parent classes
for base in bases:
if hasattr(base, "stateful_functions"):
stateful_functions += base.stateful_functions

# stateful_functions from current class
if "stateful_functions" in class_dict:
stateful_functions += class_dict["stateful_functions"]

for key, value in class_dict.items():
if key in stateful_functions:
wrapped[key] = use_state(value)

wrapped["stateful_functions"] = stateful_functions

return super().__new__(cls, name, bases, wrapped)


class Stateful(metaclass=MetaStatefulModule):
class Stateful:
"""Base class for all evox modules.
This module allow easy managing of states.
Expand All @@ -162,11 +117,8 @@ class Stateful(metaclass=MetaStatefulModule):
and recursively call ``setup`` methods of all submodules.
"""

stateful_functions = []

def __init__(self) -> None:
self._node_id = None
self._cache_override = set()

def setup(self, key: jax.Array) -> State:
"""Setup mutable state here
Expand Down
46 changes: 17 additions & 29 deletions src/evox/core/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ class State:

EMPTY: dict = {}

def __init__(
self, state_dict: dict = EMPTY, child_states: dict[str, State] = EMPTY, **kwargs
) -> None:
def __init__(self, _dataclass, /, **kwargs) -> None:
"""Construct a ``State`` from dict or keyword arguments
Example::
Expand All @@ -35,7 +33,10 @@ def __init__(
>>> State(x=1, y=2) # from keyword arguments
State ({'x': 1, 'y': 2}, [])
"""
self.__dict__["_state_dict"] = kwargs
if _dataclass is not None:
self.__dict__["_state_dict"] = _dataclass
else:
self.__dict__["_state_dict"] = kwargs
self.__dict__["_child_states"] = State.EMPTY
self.__dict__["_state_id"] = None

Expand Down Expand Up @@ -66,7 +67,7 @@ def _set_state_id_mut(self, state_id) -> State:
self.__dict__["_state_id"] = state_id
return self

def update(self, other: Optional[Union[State, dict]] = None, **kwargs) -> State:
def update(self, **kwargs) -> State:
"""Update the current State with another State or dict and return new State.
This method also accept keyword arguments.
Expand All @@ -78,23 +79,14 @@ def update(self, other: Optional[Union[State, dict]] = None, **kwargs) -> State:
State ({'x': 1, 'y': 3}, [])
>>> state # note that State is immutable, so state isn't modified
State ({'x': 1, 'y': 2}, [])
>>> state | {"y": 4} # use the | operator
State ({'x': 1, 'y': 4}, [])
"""
if other is None:
if dataclass.is_dataclass(self._state_dict):
return copy(self)._set_state_dict_mut(self._state_dict.replace(**kwargs))
else:
return copy(self)._set_state_dict_mut({**self._state_dict, **kwargs})

if isinstance(other, State):
return (
copy(self)
._set_state_dict_mut({**self._state_dict, **other._state_dict})
._set_child_states_mut({**self._child_states, **other._child_states})
)

if isinstance(other, dict):
return copy(self)._set_state_dict_mut({**self._state_dict, **other})

raise ValueError(f"other must be either State or dict, but got {type(other)}.")
def replace(self, **kwargs) -> State:
return self.update(**kwargs)

def has_child(self, name: str) -> bool:
return name in self._child_states
Expand Down Expand Up @@ -140,25 +132,21 @@ def update_path(self, path, new_state):
else:
raise ValueError("Path must be either tuple or int")

def __or__(self, *args, **kwargs) -> State:
"""| operator
Same as the update method.
"""
return self.update(*args, **kwargs)

def __getattr__(self, key: str) -> Any:
if is_magic_method(key):
return super().__getattr__(key)
return self._state_dict[key]

def __getitem__(self, index: Union[str, int]) -> State:
def __getitem__(self, key: str) -> Any:
return getattr(self, key)

def index(self, index: Union[str, int]) -> State:
"""
PyTree index, apply the index to every element in the state.
"""
return tree_map(lambda x: x[index], self)

def __getslice__(self, begin: int, end: int) -> State:
def slice(self, begin: int, end: int) -> State:
"""
PyTree index, apply the index to every element in the state.
"""
Expand Down Expand Up @@ -188,7 +176,7 @@ def __repr__(self) -> str:
def __str__(self) -> str:
return f"State {pformat(self.sprint_tree())}"

def sprint_tree(self) -> Union[dict,str]:
def sprint_tree(self) -> Union[dict, str]:
if self is State.EMPTY:
return "State.empty"
children = {
Expand Down
12 changes: 6 additions & 6 deletions src/evox/workflows/non_jit_workflow.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings
from typing import Callable, List, Optional, Union

from evox import Algorithm, Monitor, Problem, State, Workflow
from evox import Algorithm, Monitor, Problem, State, Workflow, use_state
from evox.utils import algorithm_has_init_ask, parse_opt_direction


Expand Down Expand Up @@ -95,9 +95,9 @@ def step(self, state):
is_init = False

if is_init:
cand_sol, state = self.algorithm.init_ask(state)
cand_sol, state = use_state(self.algorithm.init_ask)(state)
else:
cand_sol, state = self.algorithm.ask(state)
cand_sol, state = use_state(self.algorithm.ask)(state)

for monitor in self.registered_hooks["post_ask"]:
monitor.post_ask(state, cand_sol)
Expand All @@ -109,7 +109,7 @@ def step(self, state):
for monitor in self.registered_hooks["pre_eval"]:
monitor.pre_eval(state, cand_sol, transformed_cand_sol)

fitness, state = self.problem.evaluate(state, transformed_cand_sol)
fitness, state = use_state(self.problem.evaluate)(state, transformed_cand_sol)

fitness = fitness * self.opt_direction

Expand All @@ -126,9 +126,9 @@ def step(self, state):
)

if is_init:
state = self.algorithm.init_tell(state, fitness)
state = use_state(self.algorithm.init_tell)(state, fitness)
else:
state = self.algorithm.tell(state, fitness)
state = use_state(self.algorithm.tell)(state, fitness)

for monitor in self.registered_hooks["post_tell"]:
monitor.post_tell(state)
Expand Down
8 changes: 4 additions & 4 deletions src/evox/workflows/std_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from jax.sharding import PositionalSharding
from jax.tree_util import tree_map

from evox import Algorithm, Problem, State, Workflow, Monitor, jit_method
from evox import Algorithm, Problem, State, Workflow, Monitor, jit_method, use_state
from evox.utils import parse_opt_direction, algorithm_has_init_ask


Expand Down Expand Up @@ -139,7 +139,7 @@ def _proto_step(self, is_init, state):
tell = self.algorithm.tell

# candidate solution
cand_sol, state = ask(state)
cand_sol, state = use_state(ask)(state)
cand_sol_size = cand_sol.shape[0]

for monitor in self.registered_hooks["post_ask"]:
Expand All @@ -159,7 +159,7 @@ def _proto_step(self, is_init, state):

# if the function is jitted
if self.jit_problem:
fitness, state = self.problem.evaluate(state, transformed_cand_sol)
fitness, state = use_state(self.problem.evaluate)(state, transformed_cand_sol)
else:
if self.num_objectives == 1:
fit_shape = (cand_sol_size,)
Expand Down Expand Up @@ -192,7 +192,7 @@ def _proto_step(self, is_init, state):
state, cand_sol, transformed_cand_sol, fitness, transformed_fitness
)

state = tell(state, transformed_fitness)
state = use_state(tell)(state, transformed_fitness)

for monitor in self.registered_hooks["post_tell"]:
monitor.post_tell(state)
Expand Down

0 comments on commit 3419d6f

Please sign in to comment.