Skip to content

Commit

Permalink
[Feature] VMAS group map
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Dec 5, 2023
1 parent 615b3ba commit bb04829
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions benchmarl/environments/vmas/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def max_steps(self, env: EnvBase) -> int:
return self.config["max_steps"]

def group_map(self, env: EnvBase) -> Dict[str, List[str]]:
if hasattr(env, "group_map"):
return env.group_map
return {"agents": [agent.name for agent in env.agents]}

def state_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
Expand All @@ -63,15 +65,18 @@ def action_mask_spec(self, env: EnvBase) -> Optional[CompositeSpec]:

def observation_spec(self, env: EnvBase) -> CompositeSpec:
observation_spec = env.unbatched_observation_spec.clone()
if "info" in observation_spec["agents"]:
del observation_spec[("agents", "info")]
for group in self.group_map(env):
if "info" in observation_spec[group]:
del observation_spec[(group, "info")]
return observation_spec

def info_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
info_spec = env.unbatched_observation_spec.clone()
del info_spec[("agents", "observation")]
if "info" in info_spec["agents"]:
return info_spec
for group in self.group_map(env):
del info_spec[(group, "observation")]
for group in self.group_map(env):
if "info" in info_spec[group]:
return info_spec
else:
return None

Expand Down

0 comments on commit bb04829

Please sign in to comment.