diff --git a/test/test_algorithms.py b/test/test_algorithms.py index fcacdceb..b925a124 100644 --- a/test/test_algorithms.py +++ b/test/test_algorithms.py @@ -1,5 +1,4 @@ import pytest - from benchmarl.algorithms import algorithm_config_registry from benchmarl.environments import VmasTask from benchmarl.experiment import Experiment, ExperimentConfig @@ -7,6 +6,7 @@ from benchmarl.models.common import SequenceModelConfig from benchmarl.models.mlp import MlpConfig from hydra import compose, initialize +from torch import nn @pytest.mark.parametrize("algo_config", algorithm_config_registry.values()) @@ -15,8 +15,8 @@ def test_all_algos_balance(algo_config, continuous): task = VmasTask.BALANCE.get_from_yaml() model_config = SequenceModelConfig( model_configs=[ - MlpConfig(num_cells=[8]), - MlpConfig(num_cells=[4]), + MlpConfig(num_cells=[8], activation_class=nn.Tanh, layer_class=nn.Linear), + MlpConfig(num_cells=[4], activation_class=nn.Tanh, layer_class=nn.Linear), ], intermediate_sizes=[5], )