From b30219845d743a20654d1e90947259349a237c13 Mon Sep 17 00:00:00 2001 From: Noah Syrkis Date: Wed, 12 Jun 2024 07:58:37 -0300 Subject: [PATCH] Update smax_env.py Start positions are currently hard-coded to assume that the map_width and height is 32. If the map width is set to 128, the units will this start in the top-left corner. Make team_0 and team_1 start y coordinates halfway up the map_height (instead of hard-coded 16). The x coordinate will be a fourth of the map_width, and three fourth of the map_width respectively. --- jaxmarl/environments/smax/smax_env.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jaxmarl/environments/smax/smax_env.py b/jaxmarl/environments/smax/smax_env.py index 51243393..89c199a1 100644 --- a/jaxmarl/environments/smax/smax_env.py +++ b/jaxmarl/environments/smax/smax_env.py @@ -262,12 +262,12 @@ def _get_obs_size(self): def reset(self, key: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]: """Environment-specific reset.""" key, team_0_key, team_1_key = jax.random.split(key, num=3) - team_0_start = jnp.stack([jnp.array([8.0, 16.0])] * self.num_allies) + team_0_start = jnp.stack([jnp.array([self.map_width / 4, self.map_height / 2])] * self.num_allies) team_0_start_noise = jax.random.uniform( team_0_key, shape=(self.num_allies, 2), minval=-2, maxval=2 ) team_0_start = team_0_start + team_0_start_noise - team_1_start = jnp.stack([jnp.array([24.0, 16.0])] * self.num_enemies) + team_1_start = jnp.stack([jnp.array([self.map_width / 4 * 3, self.map_height / 2])] * self.num_enemies) team_1_start_noise = jax.random.uniform( team_1_key, shape=(self.num_enemies, 2), minval=-2, maxval=2 )