diff --git a/jaxmarl/environments/smax/smax_env.py b/jaxmarl/environments/smax/smax_env.py index 51243393..89c199a1 100644 --- a/jaxmarl/environments/smax/smax_env.py +++ b/jaxmarl/environments/smax/smax_env.py @@ -262,12 +262,12 @@ def _get_obs_size(self): def reset(self, key: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]: """Environment-specific reset.""" key, team_0_key, team_1_key = jax.random.split(key, num=3) - team_0_start = jnp.stack([jnp.array([8.0, 16.0])] * self.num_allies) + team_0_start = jnp.stack([jnp.array([self.map_width / 4, self.map_height / 2])] * self.num_allies) team_0_start_noise = jax.random.uniform( team_0_key, shape=(self.num_allies, 2), minval=-2, maxval=2 ) team_0_start = team_0_start + team_0_start_noise - team_1_start = jnp.stack([jnp.array([24.0, 16.0])] * self.num_enemies) + team_1_start = jnp.stack([jnp.array([self.map_width / 4 * 3, self.map_height / 2])] * self.num_enemies) team_1_start_noise = jax.random.uniform( team_1_key, shape=(self.num_enemies, 2), minval=-2, maxval=2 )