Skip to content

Commit

Permalink
new vmas tasks
Browse files Browse the repository at this point in the history
Signed-off-by: Matteo Bettini <[email protected]>
  • Loading branch information
matteobettini committed Oct 5, 2023
1 parent a4a76ca commit 5160df0
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 4 deletions.
11 changes: 11 additions & 0 deletions benchmarl/conf/task/vmas/transport.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
defaults:
- _self_
- vmas_transport_config


max_steps: 100
n_agents: 4
n_packages: 1
package_width: 0.15
package_length: 0.15
package_mass: 50
9 changes: 9 additions & 0 deletions benchmarl/conf/task/vmas/wheel.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
defaults:
- _self_
- vmas_wheel_config

max_steps: 100
n_agents: 4
line_length: 1
line_mass: 30
desired_velocity: 0.05
4 changes: 4 additions & 0 deletions benchmarl/environments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@
from .vmas.balance import TaskConfig as BalanceConfig
from .vmas.navigation import TaskConfig as NavigationConfig
from .vmas.sampling import TaskConfig as SamplingConfig
from .vmas.transport import TaskConfig as TransportConfig
from .vmas.wheel import TaskConfig as WheelConfig

_task_class_registry = {
"vmas_balance_config": BalanceConfig,
"vmas_sampling_config": SamplingConfig,
"vmas_navigation_config": NavigationConfig,
"vmas_transport_config": TransportConfig,
"vmas_wheel_config": WheelConfig,
"pettingzoo_multiwalker_config": MultiwalkerConfig,
"pettingzoo_simple_tag_config": SimpleTagConfig,
}
10 changes: 8 additions & 2 deletions benchmarl/environments/vmas/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ class VmasTask(Task):
BALANCE = None
SAMPLING = None
NAVIGATION = None
TRANSPORT = None
WHEEL = None

def get_env_fun(
self,
Expand Down Expand Up @@ -53,13 +55,17 @@ def action_mask_spec(self, env: EnvBase) -> Optional[CompositeSpec]:

def observation_spec(self, env: EnvBase) -> CompositeSpec:
observation_spec = env.unbatched_observation_spec.clone()
del observation_spec[("agents", "info")]
if "info" in observation_spec["agents"]:
del observation_spec[("agents", "info")]
return observation_spec

def info_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
info_spec = env.unbatched_observation_spec.clone()
del info_spec[("agents", "observation")]
return info_spec
if "info" in info_spec["agents"]:
return info_spec
else:
return None

def action_spec(self, env: EnvBase) -> CompositeSpec:
return env.unbatched_action_spec
Expand Down
11 changes: 11 additions & 0 deletions benchmarl/environments/vmas/transport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from dataclasses import dataclass, MISSING


@dataclass
class TaskConfig:
max_steps: int = MISSING
n_agents: int = MISSING
n_packages: int = MISSING
package_width: float = MISSING
package_length: float = MISSING
package_mass: float = MISSING
10 changes: 10 additions & 0 deletions benchmarl/environments/vmas/wheel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from dataclasses import dataclass, MISSING


@dataclass
class TaskConfig:
max_steps: int = MISSING
n_agents: int = MISSING
line_length: float = MISSING
line_mass: float = MISSING
desired_velocity: float = MISSING
25 changes: 23 additions & 2 deletions test/test_vmas.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
IppoConfig,
IsacConfig,
MaddpgConfig,
MappoConfig,
MasacConfig,
QmixConfig,
VdnConfig,
Expand All @@ -25,8 +26,8 @@
class TestVmas:
@pytest.mark.parametrize("algo_config", algorithm_config_registry.values())
@pytest.mark.parametrize("prefer_continuous", [True, False])
@pytest.mark.parametrize("task", list(VmasTask))
def test_all_algos_all_tasks(
@pytest.mark.parametrize("task", [VmasTask.BALANCE])
def test_all_algos(
self,
algo_config: AlgorithmConfig,
task: Task,
Expand All @@ -51,6 +52,26 @@ def test_all_algos_all_tasks(
)
experiment.run()

@pytest.mark.parametrize("algo_config", [MappoConfig, QmixConfig])
@pytest.mark.parametrize("task", list(VmasTask))
def test_all_tasks(
self,
algo_config: AlgorithmConfig,
task: Task,
experiment_config,
mlp_sequence_config,
):

task = task.get_from_yaml()
experiment = Experiment(
algorithm_config=algo_config.get_from_yaml(),
model_config=mlp_sequence_config,
seed=0,
config=experiment_config,
task=task,
)
experiment.run()

@pytest.mark.parametrize("algo_config", algorithm_config_registry.values())
@pytest.mark.parametrize("task", [VmasTask.BALANCE])
def test_reloading_trainer(
Expand Down

0 comments on commit 5160df0

Please sign in to comment.