Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Jul 31, 2024
1 parent 1191274 commit 72392b9
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 3 deletions.
12 changes: 12 additions & 0 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
import contextlib
from typing import List

import packaging
import pytest
import torch
import torch_geometric.nn
import torchrl

from benchmarl.hydra_config import load_model_config_from_hydra
from benchmarl.models import GnnConfig, model_config_registry
Expand Down Expand Up @@ -155,6 +157,11 @@ def test_models_forward_shape(
or (isinstance(model_name, list) and model_name[0] != "gnn")
):
pytest.skip("gnn model needs agent dim as input")
if (
packaging.version.parse(torchrl.__version__).base_version <= "0.5.0"
and "gru" in model_name
):
pytest.skip("gru model needs torchrl > 0.5.0")

torch.manual_seed(0)

Expand Down Expand Up @@ -237,6 +244,11 @@ def test_share_params_between_models(
or (isinstance(model_name, list) and model_name[0] != "gnn")
):
pytest.skip("gnn model needs agent dim as input")
if (
packaging.version.parse(torchrl.__version__).base_version <= "0.5.0"
and "gru" in model_name
):
pytest.skip("gru model needs torchrl > 0.5.0")
torch.manual_seed(1)

input_spec, output_spec = _get_input_and_output_specs(
Expand Down
6 changes: 5 additions & 1 deletion test/test_pettingzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
#


import packaging
import pytest

import torchrl
from benchmarl.algorithms import (
algorithm_config_registry,
IddpgConfig,
Expand Down Expand Up @@ -110,6 +111,9 @@ def test_gnn(
"algo_config", [IddpgConfig, MaddpgConfig, IppoConfig, MappoConfig, QmixConfig]
)
@pytest.mark.parametrize("task", [PettingZooTask.SIMPLE_TAG])
@pytest.mark.skipif(
packaging.version.parse(torchrl.__version__).base_version <= "0.5.0"
)
def test_gru(
self,
algo_config: AlgorithmConfig,
Expand Down
6 changes: 5 additions & 1 deletion test/test_smacv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
# LICENSE file in the root directory of this source tree.
#


import packaging
import pytest
import torchrl

from benchmarl.algorithms import algorithm_config_registry, MappoConfig, QmixConfig
from benchmarl.algorithms.common import AlgorithmConfig
Expand Down Expand Up @@ -80,6 +81,9 @@ def test_gnn(

@pytest.mark.parametrize("algo_config", [QmixConfig])
@pytest.mark.parametrize("task", [Smacv2Task.PROTOSS_5_VS_5])
@pytest.mark.skipif(
packaging.version.parse(torchrl.__version__).base_version <= "0.5.0"
)
def test_gru(
self,
algo_config,
Expand Down
6 changes: 5 additions & 1 deletion test/test_vmas.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
# LICENSE file in the root directory of this source tree.
#

import packaging
import pytest

import torchrl
from benchmarl.algorithms import (
algorithm_config_registry,
IddpgConfig,
Expand Down Expand Up @@ -119,6 +120,9 @@ def test_gnn(
)
@pytest.mark.parametrize("share_params", [True, False])
@pytest.mark.parametrize("task", [VmasTask.NAVIGATION])
@pytest.mark.skipif(
packaging.version.parse(torchrl.__version__).base_version <= "0.5.0"
)
def test_gru(
self,
algo_config: AlgorithmConfig,
Expand Down

0 comments on commit 72392b9

Please sign in to comment.