Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Jul 31, 2024
1 parent 72392b9 commit da733d5
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 13 deletions.
8 changes: 4 additions & 4 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,10 @@ def test_models_forward_shape(
):
pytest.skip("gnn model needs agent dim as input")
if (
packaging.version.parse(torchrl.__version__).base_version <= "0.5.0"
packaging.version.parse(torchrl.__version__).local is None
and "gru" in model_name
):
pytest.skip("gru model needs torchrl > 0.5.0")
pytest.skip("gru model needs torchrl from github")

torch.manual_seed(0)

Expand Down Expand Up @@ -245,10 +245,10 @@ def test_share_params_between_models(
):
pytest.skip("gnn model needs agent dim as input")
if (
packaging.version.parse(torchrl.__version__).base_version <= "0.5.0"
packaging.version.parse(torchrl.__version__).local is None
and "gru" in model_name
):
pytest.skip("gru model needs torchrl > 0.5.0")
pytest.skip("gru model needs torchrl from github")
torch.manual_seed(1)

input_spec, output_spec = _get_input_and_output_specs(
Expand Down
4 changes: 1 addition & 3 deletions test/test_pettingzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,7 @@ def test_gnn(
"algo_config", [IddpgConfig, MaddpgConfig, IppoConfig, MappoConfig, QmixConfig]
)
@pytest.mark.parametrize("task", [PettingZooTask.SIMPLE_TAG])
@pytest.mark.skipif(
packaging.version.parse(torchrl.__version__).base_version <= "0.5.0"
)
@pytest.mark.skipif(packaging.version.parse(torchrl.__version__).local is None)
def test_gru(
self,
algo_config: AlgorithmConfig,
Expand Down
4 changes: 1 addition & 3 deletions test/test_smacv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,7 @@ def test_gnn(

@pytest.mark.parametrize("algo_config", [QmixConfig])
@pytest.mark.parametrize("task", [Smacv2Task.PROTOSS_5_VS_5])
@pytest.mark.skipif(
packaging.version.parse(torchrl.__version__).base_version <= "0.5.0"
)
@pytest.mark.skipif(packaging.version.parse(torchrl.__version__).local is None)
def test_gru(
self,
algo_config,
Expand Down
4 changes: 1 addition & 3 deletions test/test_vmas.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,7 @@ def test_gnn(
)
@pytest.mark.parametrize("share_params", [True, False])
@pytest.mark.parametrize("task", [VmasTask.NAVIGATION])
@pytest.mark.skipif(
packaging.version.parse(torchrl.__version__).base_version <= "0.5.0"
)
@pytest.mark.skipif(packaging.version.parse(torchrl.__version__).local is None)
def test_gru(
self,
algo_config: AlgorithmConfig,
Expand Down

0 comments on commit da733d5

Please sign in to comment.