Skip to content

Commit

Permalink
explicitly set workflow state dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
ZaberKo committed Jul 9, 2024
1 parent f8d328e commit 3c6f7ba
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
1 change: 0 additions & 1 deletion src/evox/core/pytree_dataclass.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from jax.tree_util import register_pytree_node
import dataclasses
from typing import Annotated, Any, Callable, Optional, Tuple, TypeVar, get_type_hints

from typing_extensions import (
dataclass_transform, # pytype: disable=not-supported-yet
Expand Down
20 changes: 12 additions & 8 deletions src/evox/workflows/std_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def __init__(
problem: Problem,
monitors: Sequence[Monitor] = (),
opt_direction: Union[str, Sequence[str]] = "min",
candidate_transforms: Sequence[Callable[[jax.Array],jax.Array]] = (),
fitness_transforms: Sequence[Callable[[jax.Array],jax.Array]] = (),
candidate_transforms: Sequence[Callable[[jax.Array], jax.Array]] = (),
fitness_transforms: Sequence[Callable[[jax.Array], jax.Array]] = (),
jit_step: bool = True,
external_problem: bool = False,
num_objectives: Optional[int] = None,
Expand Down Expand Up @@ -116,9 +116,7 @@ def __init__(
self.external_problem = external_problem
self.num_objectives = num_objectives
if self.external_problem is True and self.num_objectives is None:
raise ValueError(
("Using external problem, but num_objectives isn't set ")
)
raise ValueError("Using external problem, but num_objectives isn't set ")

def _ask(self, state):
if has_init_ask(self.algorithm) and state.first_step:
Expand Down Expand Up @@ -248,7 +246,12 @@ def _step(self, state):

def setup(self, key):
return State(
StdWorkflowState(generation=0, first_step=True, rank=0, world_size=1)
StdWorkflowState(
generation=jnp.zeros((), dtype=jnp.uint32),
first_step=True,
rank=jnp.zeros((), dtype=jnp.int32),
world_size=1,
)
)

def step(self, state):
Expand Down Expand Up @@ -277,7 +280,7 @@ def enable_multi_devices(self, state: State, pmap_axis_name=POP_AXIS_NAME) -> St

self.pmap_axis_name = pmap_axis_name
self._step = jax.pmap(
self._step, axis_name=pmap_axis_name, static_broadcasted_argnums=0
self._step, axis_name=pmap_axis_name, static_broadcasted_argnums=(0,)
)

# multi-node case
Expand All @@ -288,7 +291,8 @@ def enable_multi_devices(self, state: State, pmap_axis_name=POP_AXIS_NAME) -> St

state = jax.device_put_replicated(state, self.devices)
state = state.replace(
rank=jax.device_put_sharded(tuple(ranks), self.devices), world_size=num_devices
rank=jax.device_put_sharded(tuple(ranks), self.devices),
world_size=num_devices,
)

return state
Expand Down

0 comments on commit 3c6f7ba

Please sign in to comment.