diff --git a/tests/test_common.py b/tests/test_common.py index e8dbef39..090a24f0 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -15,9 +15,10 @@ ) MODEL_CONFIG_SAVE_TEST = [ - CategoryEmbeddingModelConfig, - AutoIntConfig, - TabNetModelConfig, + (CategoryEmbeddingModelConfig, dict(layers="10-20")), + (AutoIntConfig, dict(num_heads=1,num_attn_blocks=1,)), + (NodeConfig, dict(num_trees=100, depth=2)), + (TabNetModelConfig, dict(n_a=2, n_d=2)), ] MODEL_CONFIG_FEATURE_EXT_TEST = [ @@ -67,7 +68,8 @@ def test_save_load( continuous_cols=continuous_cols, categorical_cols=categorical_cols, ) - model_config_params = dict(task="regression") + model_config_class, model_config_params = model_config_class + model_config_params['task']="regression" model_config = model_config_class(**model_config_params) trainer_config = TrainerConfig( max_epochs=3, checkpoints=None, early_stopping=None, gpus=0, fast_dev_run=True