Skip to content

Commit

Permalink
fix: enable_multi_devicesenable_multi_devices
Browse files Browse the repository at this point in the history
  • Loading branch information
BillHuang2001 committed Apr 14, 2024
1 parent 3ec2b72 commit 528871f
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/evox/workflows/std_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def enable_multi_devices(
if devices is None:
devices = jax.local_devices()
device_count = len(devices)
dummy_pop, _ = jax.eval_shape(self.algorithm.ask, state)
dummy_pop, _ = jax.eval_shape(use_state(self.algorithm.ask), state)
pop_size, dim = dummy_pop.shape
sharding = PositionalSharding(devices).reshape(1, device_count)
state_sharding = self._auto_shard(state, sharding, pop_size, dim)
Expand Down

0 comments on commit 528871f

Please sign in to comment.