Skip to content

Commit

Permalink
-- added testcase for saving NODE
Browse files Browse the repository at this point in the history
  • Loading branch information
manujosephv committed Mar 18, 2021
1 parent 6b892b6 commit 82a30fe
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
)

MODEL_CONFIG_SAVE_TEST = [
CategoryEmbeddingModelConfig,
AutoIntConfig,
TabNetModelConfig,
(CategoryEmbeddingModelConfig, dict(layers="10-20")),
(AutoIntConfig, dict(num_heads=1,num_attn_blocks=1,)),
(NodeConfig, dict(num_trees=100, depth=2)),
(TabNetModelConfig, dict(n_a=2, n_d=2)),
]

MODEL_CONFIG_FEATURE_EXT_TEST = [
Expand Down Expand Up @@ -67,7 +68,8 @@ def test_save_load(
continuous_cols=continuous_cols,
categorical_cols=categorical_cols,
)
model_config_params = dict(task="regression")
model_config_class, model_config_params = model_config_class
model_config_params['task']="regression"
model_config = model_config_class(**model_config_params)
trainer_config = TrainerConfig(
max_epochs=3, checkpoints=None, early_stopping=None, gpus=0, fast_dev_run=True
Expand Down

0 comments on commit 82a30fe

Please sign in to comment.