Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Environment] SMACv2 #13

Merged
merged 16 commits into from
Sep 19, 2023
59 changes: 59 additions & 0 deletions benchmarl/conf/task/smacv2/protoss_5_vs_5.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
continuing_episode: False
difficulty: "7"
game_version: null
map_name: "10gen_protoss"
move_amount: 2
obs_all_health: True
obs_instead_of_state: False
obs_last_action: False
obs_own_health: True
obs_pathing_grid: False
obs_terrain_height: False
obs_timestep_number: False
reward_death_value: 10
reward_defeat: 0
reward_negative_scale: 0.5
reward_only_positive: True
reward_scale: True
reward_scale_rate: 20
reward_sparse: False
reward_win: 200
replay_dir: ""
replay_prefix: ""
conic_fov: False
use_unit_ranges: True
min_attack_range: 2
obs_own_pos: True
num_fov_actions: 12
capability_config:
n_units: 5
n_enemies: 5
team_gen:
dist_type: "weighted_teams"
unit_types:
- "stalker"
- "zealot"
- "colossus"
weights:
- 0.45
- 0.45
- 0.1
observe: True
start_positions:
dist_type: "surrounded_and_reflect"
p: 0.5
map_x: 32
map_y: 32

# enemy_mask:
# dist_type: "mask"
# mask_probability: 0.5
# n_enemies: 5
state_last_action: True
state_timestep_number: False
step_mul: 8
heuristic_ai: False
# heuristic_rest: False
debug: False
prob_obs_enemy: 1.0
action_mask: True
3 changes: 3 additions & 0 deletions benchmarl/environments/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from .common import Task

from .smacv2.common import Smacv2Task
from .vmas.balance import TaskConfig as BalanceConfig

# Environments
Expand All @@ -11,6 +13,7 @@
"vmas/balance": VmasTask.BALANCE,
"vmas/sampling": VmasTask.SAMPLING,
"vmas/navigation": VmasTask.NAVIGATION,
"smacv2/protoss_5_vs_5": Smacv2Task.protoss_5_vs_5,
}


Expand Down
Empty file.
93 changes: 93 additions & 0 deletions benchmarl/environments/smacv2/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from typing import Callable, Dict, List, Optional

import torch

from tensordict import TensorDictBase
from torchrl.data import CompositeSpec
from torchrl.envs import EnvBase
from torchrl.envs.libs.smacv2 import SMACv2Env

from benchmarl.environments.common import Task


class Smacv2Task(Task):
protoss_5_vs_5 = None

def get_env_fun(
self,
num_envs: int,
continuous_actions: bool,
seed: Optional[int],
) -> Callable[[], EnvBase]:

return lambda: SMACv2Env(categorical_actions=True, seed=seed, **self.config)

def supports_continuous_actions(self) -> bool:
return False

def supports_discrete_actions(self) -> bool:
return True

def has_render(self) -> bool:
return True

def max_steps(self, env: EnvBase) -> bool:
return env.episode_limit

def group_map(self, env: EnvBase) -> Dict[str, List[str]]:
return env.group_map

def state_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
observation_spec = env.observation_spec.clone()
del observation_spec["info"]
del observation_spec["agents"]
return observation_spec

def action_mask_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
observation_spec = env.observation_spec.clone()
del observation_spec["info"]
del observation_spec["state"]
del observation_spec[("agents", "observation")]
return observation_spec

def observation_spec(self, env: EnvBase) -> CompositeSpec:
observation_spec = env.observation_spec.clone()
del observation_spec["info"]
del observation_spec["state"]
del observation_spec[("agents", "action_mask")]
return observation_spec

def info_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
observation_spec = env.observation_spec.clone()
del observation_spec["state"]
del observation_spec["agents"]
return observation_spec

def action_spec(self, env: EnvBase) -> CompositeSpec:
return env.input_spec["full_action_spec"]

@staticmethod
def log_info(batch: TensorDictBase) -> Dict:
done = batch.get(("next", "done")).squeeze(-1)
return {
"collection/info/win_rate": batch.get(("next", "info", "battle_won"))[done]
.to(torch.float)
.mean()
.item(),
"collection/info/episode_limit_rate": batch.get(
("next", "info", "episode_limit")
)[done]
.to(torch.float)
.mean()
.item(),
}

@staticmethod
def env_name() -> str:
return "smacv2"


if __name__ == "__main__":
print(Smacv2Task.protoss_5_vs_5.get_from_yaml())
env = Smacv2Task.protoss_5_vs_5.get_env_fun(0, False, 0)()
print(env.render(mode="rgb_array"))
32 changes: 31 additions & 1 deletion test/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
import pytest
from benchmarl.algorithms import algorithm_config_registry

from benchmarl.environments import VmasTask
from benchmarl.algorithms.common import AlgorithmConfig

from benchmarl.environments import Smacv2Task, VmasTask
from benchmarl.experiment import Experiment
from benchmarl.models.common import SequenceModelConfig
from benchmarl.models.mlp import MlpConfig
from torch import nn


_has_vmas = importlib.util.find_spec("vmas") is not None
_has_smacv2 = importlib.util.find_spec("smacv2") is not None


@pytest.mark.skipif(not _has_vmas, reason="VMAS not found")
Expand All @@ -37,6 +40,33 @@ def test_all_algos_vmas(algo_config, continuous, experiment_config):
experiment.run()


@pytest.mark.skipif(not _has_smacv2, reason="SMACv2 not found")
@pytest.mark.parametrize("algo_config", algorithm_config_registry.values())
def test_all_algos_smac(algo_config: AlgorithmConfig, experiment_config):
if algo_config.supports_discrete_actions():
task = Smacv2Task.protoss_5_vs_5.get_from_yaml()
model_config = SequenceModelConfig(
model_configs=[
MlpConfig(
num_cells=[8], activation_class=nn.Tanh, layer_class=nn.Linear
),
MlpConfig(
num_cells=[4], activation_class=nn.Tanh, layer_class=nn.Linear
),
],
intermediate_sizes=[5],
)

experiment = Experiment(
algorithm_config=algo_config.get_from_yaml(),
model_config=model_config,
seed=0,
config=experiment_config,
task=task,
)
experiment.run()


# @pytest.mark.parametrize("algo_config", algorithm_config_registry.keys())
# def test_all_algos_hydra(algo_config):
# with initialize(version_base=None, config_path="../benchmarl/conf"):
Expand Down