-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* update Signed-off-by: Matteo Bettini <[email protected]> * add abstract and headers Signed-off-by: Matteo Bettini <[email protected]> * amend Signed-off-by: Matteo Bettini <[email protected]> * improve examples Signed-off-by: Matteo Bettini <[email protected]> * remove requirements Signed-off-by: Matteo Bettini <[email protected]> * remove nightly Signed-off-by: Matteo Bettini <[email protected]> * install Signed-off-by: Matteo Bettini <[email protected]> * install Signed-off-by: Matteo Bettini <[email protected]> * install Signed-off-by: Matteo Bettini <[email protected]> * install Signed-off-by: Matteo Bettini <[email protected]> * install Signed-off-by: Matteo Bettini <[email protected]> * amend Signed-off-by: Matteo Bettini <[email protected]> * amend Signed-off-by: Matteo Bettini <[email protected]> * amend Signed-off-by: Matteo Bettini <[email protected]> --------- Signed-off-by: Matteo Bettini <[email protected]>
- Loading branch information
1 parent
0e01302
commit 7afff7d
Showing
9 changed files
with
260 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
python benchmarl/run.py -m algorithm=mappo,qmix,masac task=vmas/balance,vmas/sampling seed=0,1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,37 +1,34 @@ | ||
from benchmarl.algorithms import MaddpgConfig, MappoConfig, MasacConfig, QmixConfig | ||
from benchmarl.algorithms import MappoConfig, MasacConfig, QmixConfig | ||
from benchmarl.benchmark import Benchmark | ||
from benchmarl.environments import VmasTask | ||
from benchmarl.experiment import ExperimentConfig | ||
from benchmarl.models.common import SequenceModelConfig | ||
from benchmarl.models.mlp import MlpConfig | ||
from torch import nn | ||
|
||
if __name__ == "__main__": | ||
|
||
# Loads from "benchmarl/conf/experiment/base_experiment.yaml" | ||
experiment_config = ExperimentConfig.get_from_yaml() | ||
tasks = [VmasTask.BALANCE.get_from_yaml()] | ||
|
||
# Loads from "benchmarl/conf/task" | ||
tasks = [VmasTask.BALANCE.get_from_yaml(), VmasTask.SAMPLING.get_from_yaml()] | ||
|
||
# Loads from "benchmarl/conf/algorithm" | ||
algorithm_configs = [ | ||
MappoConfig.get_from_yaml(), | ||
MaddpgConfig.get_from_yaml(), | ||
QmixConfig.get_from_yaml(), | ||
MasacConfig.get_from_yaml(), | ||
] | ||
seeds = {0} | ||
|
||
# Model still need to be refactored for hydra loading | ||
model_config = SequenceModelConfig( | ||
model_configs=[ | ||
MlpConfig.get_from_yaml(), | ||
MlpConfig(num_cells=[256], layer_class=nn.Linear, activation_class=nn.Tanh), | ||
], | ||
intermediate_sizes=[128], | ||
) | ||
# Loads from "benchmarl/conf/model/layers" | ||
model_config = MlpConfig.get_from_yaml() | ||
critic_model_config = MlpConfig.get_from_yaml() | ||
|
||
benchmark = Benchmark( | ||
algorithm_configs=algorithm_configs, | ||
tasks=tasks, | ||
seeds=seeds, | ||
seeds={0, 1}, | ||
experiment_config=experiment_config, | ||
model_config=model_config, | ||
critic_model_config=critic_model_config, | ||
) | ||
benchmark.run_sequential() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
python benchmarl/run.py algorithm=mappo task=vmas/balance |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from benchmarl.algorithms import MappoConfig | ||
|
||
from benchmarl.environments import VmasTask | ||
|
||
from benchmarl.experiment import Experiment, ExperimentConfig | ||
|
||
from benchmarl.models.mlp import MlpConfig | ||
|
||
if __name__ == "__main__": | ||
|
||
# Loads from "benchmarl/conf/experiment/base_experiment.yaml" | ||
experiment_config = ExperimentConfig.get_from_yaml() | ||
|
||
# Loads from "benchmarl/conf/task/vmas/balance.yaml" | ||
task = VmasTask.BALANCE.get_from_yaml() | ||
|
||
# Loads from "benchmarl/conf/algorithm/mappo.yaml" | ||
algorithm_config = MappoConfig.get_from_yaml() | ||
|
||
# Loads from "benchmarl/conf/model/layers/mlp.yaml" | ||
model_config = MlpConfig.get_from_yaml() | ||
critic_model_config = MlpConfig.get_from_yaml() | ||
|
||
experiment = Experiment( | ||
task=task, | ||
algorithm_config=algorithm_config, | ||
model_config=model_config, | ||
critic_model_config=critic_model_config, | ||
seed=0, | ||
config=experiment_config, | ||
) | ||
experiment.run() |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,9 +8,9 @@ | |
author="Matteo Bettini", | ||
author_email="[email protected]", | ||
packages=find_packages(), | ||
install_requires=["torchrl", "tqdm", "hydra-core"], | ||
install_requires=["tqdm", "hydra-core"], | ||
extras_require={ | ||
"tasks": ["vmas>=1.2.10", "pettingzoo[all]>=1.24.1"], | ||
"vmas": ["vmas>=1.2.10"], | ||
"pettingzoo": ["pettingzoo[all]>=1.24.1"], | ||
}, | ||
include_package_data=True, | ||
) |