Skip to content

Commit

Permalink
update tasks
Browse files Browse the repository at this point in the history
Signed-off-by: Matteo Bettini <[email protected]>
  • Loading branch information
matteobettini committed Sep 21, 2023
1 parent 8eeb4b8 commit 66d655e
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 17 deletions.
10 changes: 6 additions & 4 deletions benchmarl/environments/common.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import importlib
import os
import os.path as osp
Expand Down Expand Up @@ -43,7 +45,7 @@ def __new__(cls, *args, **kwargs):
def __init__(self, config: Dict[str, Any]):
self.config = config

def update_config(self, config: Dict[str, Any]):
def update_config(self, config: Dict[str, Any]) -> Task:
if self.config is None:
self.config = config
else:
Expand All @@ -67,7 +69,7 @@ def supports_discrete_actions(self) -> bool:
def max_steps(self, env: EnvBase) -> int:
raise NotImplementedError

def has_render(self) -> bool:
def has_render(self, env: EnvBase) -> bool:
raise NotImplementedError

def group_map(self, env: EnvBase) -> Dict[str, List[str]]:
Expand All @@ -90,7 +92,7 @@ def action_mask_spec(self, env: EnvBase) -> Optional[CompositeSpec]:

@staticmethod
def env_name() -> str:
return "vmas"
raise NotImplementedError

@staticmethod
def log_info(batch: TensorDictBase) -> Dict:
Expand All @@ -108,7 +110,7 @@ def _load_from_yaml(name: str) -> Dict[str, Any]:
yaml_path = Path(__file__).parent.parent / "conf" / "task" / f"{name}.yaml"
return read_yaml_config(str(yaml_path.resolve()))

def get_from_yaml(self, path: Optional[str] = None):
def get_from_yaml(self, path: Optional[str] = None) -> Task:
if path is None:
task_name = self.name.lower()
return self.update_config(
Expand Down
9 changes: 1 addition & 8 deletions benchmarl/environments/smacv2/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def get_env_fun(
continuous_actions: bool,
seed: Optional[int],
) -> Callable[[], EnvBase]:

return lambda: SMACv2Env(categorical_actions=True, seed=seed, **self.config)

def supports_continuous_actions(self) -> bool:
Expand All @@ -27,7 +26,7 @@ def supports_continuous_actions(self) -> bool:
def supports_discrete_actions(self) -> bool:
return True

def has_render(self) -> bool:
def has_render(self, env: EnvBase) -> bool:
return True

def max_steps(self, env: EnvBase) -> bool:
Expand Down Expand Up @@ -84,9 +83,3 @@ def log_info(batch: TensorDictBase) -> Dict:
@staticmethod
def env_name() -> str:
return "smacv2"


if __name__ == "__main__":
print(Smacv2Task.protoss_5_vs_5.get_from_yaml())
env = Smacv2Task.protoss_5_vs_5.get_env_fun(0, False, 0)()
print(env.render(mode="rgb_array"))
6 changes: 1 addition & 5 deletions benchmarl/environments/vmas/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def supports_continuous_actions(self) -> bool:
def supports_discrete_actions(self) -> bool:
return True

def has_render(self) -> bool:
def has_render(self, env: EnvBase) -> bool:
return True

def max_steps(self, env: EnvBase) -> bool:
Expand Down Expand Up @@ -64,7 +64,3 @@ def action_spec(self, env: EnvBase) -> CompositeSpec:
@staticmethod
def env_name() -> str:
return "vmas"


if __name__ == "__main__":
print(VmasTask.BALANCE.get_from_yaml())

0 comments on commit 66d655e

Please sign in to comment.