Skip to content

Commit

Permalink
[ENH] remove sklearn dependency in test_get_params (#212)
Browse files Browse the repository at this point in the history
This PR removes the `sklearn` dependency in `get_test_params`, by
replacing `_check_get_params_invariance` with an `skbase`-native
implementation.

The call to `clone` is replaced with the object's own `clone`method.
  • Loading branch information
fkiraly authored Aug 11, 2023
1 parent e83bcf0 commit 9fbe066
Showing 1 changed file with 7 additions and 11 deletions.
18 changes: 7 additions & 11 deletions skbase/testing/test_all_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,21 +659,17 @@ def test_no_between_test_case_side_effects(self, object_instance, a):
assert not hasattr(object_instance, "test__attr")
object_instance.test__attr = 42

@pytest.mark.skipif(
not _check_soft_dependencies("sklearn", severity="none"),
reason="skip test if sklearn is not available",
) # sklearn is part of the dev dependency set, test should be executed with that
def test_get_params(self, object_instance):
"""Check that get_params works correctly, against sklearn interface."""
from sklearn.utils.estimator_checks import (
check_get_params_invariance as _check_get_params_invariance,
)

params = object_instance.get_params()
assert isinstance(params, dict)
_check_get_params_invariance(
object_instance.__class__.__name__, object_instance
)

e = object_instance.clone()

shallow_params = e.get_params(deep=False)
deep_params = e.get_params(deep=True)

assert all(item in deep_params.items() for item in shallow_params.items())

def test_set_params(self, object_instance):
"""Check that set_params works correctly."""
Expand Down

0 comments on commit 9fbe066

Please sign in to comment.