Skip to content

Commit

Permalink
Merge branch 'main' into pyhon-313rc2
Browse files Browse the repository at this point in the history
  • Loading branch information
fkiraly committed Sep 29, 2024
2 parents 45bf64e + 49ff3dc commit 4b1bf2b
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 15 deletions.
7 changes: 7 additions & 0 deletions skbase/base/_pretty_printing/tests/test_pprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
"""Tests for skbase pretty printing functionality."""

import pytest

from skbase.base import BaseObject
from skbase.utils.dependencies import _check_soft_dependencies


class CompositionDummy(BaseObject):
Expand All @@ -15,6 +18,10 @@ def __init__(self, foo, bar=84):
super(CompositionDummy, self).__init__()


@pytest.mark.skipif(
not _check_soft_dependencies("scikit-learn", severity="none"),
reason="skip test if sklearn is not available",
) # sklearn is part of the dev dependency set, test should be executed with that
def test_sklearn_compatibility():
"""Test that the pretty printing functions are compatible with sklearn."""
from sklearn.ensemble import RandomForestRegressor
Expand Down
12 changes: 6 additions & 6 deletions skbase/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,7 +1022,7 @@ def __init__(self, obj, obj_iterable):


@pytest.mark.skipif(
not _check_soft_dependencies("sklearn", severity="none"),
not _check_soft_dependencies("scikit-learn", severity="none"),
reason="skip test if sklearn is not available",
) # sklearn is part of the dev dependency set, test should be executed with that
def test_clone_param_is_none(fixture_class_parent: Type[Parent]):
Expand All @@ -1037,7 +1037,7 @@ def test_clone_param_is_none(fixture_class_parent: Type[Parent]):


@pytest.mark.skipif(
not _check_soft_dependencies("sklearn", severity="none"),
not _check_soft_dependencies("scikit-learn", severity="none"),
reason="skip test if sklearn is not available",
) # sklearn is part of the dev dependency set, test should be executed with that
def test_clone_empty_array(fixture_class_parent: Type[Parent]):
Expand All @@ -1057,7 +1057,7 @@ def test_clone_empty_array(fixture_class_parent: Type[Parent]):


@pytest.mark.skipif(
not _check_soft_dependencies("sklearn", severity="none"),
not _check_soft_dependencies("scikit-learn", severity="none"),
reason="skip test if sklearn is not available",
) # sklearn is part of the dev dependency set, test should be executed with that
def test_clone_sparse_matrix(fixture_class_parent: Type[Parent]):
Expand All @@ -1076,7 +1076,7 @@ def test_clone_sparse_matrix(fixture_class_parent: Type[Parent]):


@pytest.mark.skipif(
not _check_soft_dependencies("sklearn", severity="none"),
not _check_soft_dependencies("scikit-learn", severity="none"),
reason="skip test if sklearn is not available",
) # sklearn is part of the dev dependency set, test should be executed with that
def test_clone_nan(fixture_class_parent: Type[Parent]):
Expand Down Expand Up @@ -1105,7 +1105,7 @@ def test_clone_estimator_types(fixture_class_parent: Type[Parent]):


@pytest.mark.skipif(
not _check_soft_dependencies("sklearn", severity="none"),
not _check_soft_dependencies("scikit-learn", severity="none"),
reason="skip test if sklearn is not available",
) # sklearn is part of the dev dependency set, test should be executed with that
def test_clone_class_rather_than_instance_raises_error(
Expand All @@ -1120,7 +1120,7 @@ def test_clone_class_rather_than_instance_raises_error(


@pytest.mark.skipif(
not _check_soft_dependencies("sklearn", severity="none"),
not _check_soft_dependencies("scikit-learn", severity="none"),
reason="skip test if sklearn is not available",
) # sklearn is part of the dev dependency set, test should be executed with that
def test_clone_sklearn_composite(fixture_class_parent: Type[Parent]):
Expand Down
12 changes: 3 additions & 9 deletions skbase/utils/tests/test_deep_equals.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@

EXAMPLES += [X]

if _check_soft_dependencies(
"scikit-learn", package_import_alias={"scikit-learn": "sklearn"}, severity="none"
):
if _check_soft_dependencies("scikit-learn", severity="none"):
from sklearn.ensemble import RandomForestRegressor

EXAMPLES += [RandomForestRegressor()]
Expand Down Expand Up @@ -115,16 +113,12 @@ def test_deep_equals_negative(fixture1, fixture2):
def copy_except_if_sklearn(obj):
"""Copy obj if it is not a scikit-learn estimator.
We use this functoin as deep_copy should return True for
We use this function as deep_copy should return True for
identical sklearn estimators, but False for different copies.
This is the current status quo, possibly we want to change this in the future.
"""
if not _check_soft_dependencies(
"scikit-learn",
package_import_alias={"scikit-learn": "sklearn"},
severity="none",
):
if not _check_soft_dependencies("scikit-learn", severity="none"):
return deepcopy(obj)
else:
from sklearn.base import BaseEstimator
Expand Down

0 comments on commit 4b1bf2b

Please sign in to comment.