Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] option to return BaseObject.get_param_names in the same order as in the __init__ #335

Merged
merged 7 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions skbase/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from collections import defaultdict
from copy import deepcopy
from typing import List
from warnings import warn

from skbase._exceptions import NotFittedError
from skbase.base._pretty_printing._object_html_repr import _object_html_repr
Expand Down Expand Up @@ -205,17 +206,43 @@
)
return parameters

# todo 0.10.0: changed sort default to False
# update docstring, and remove warning
@classmethod
def get_param_names(cls):
def get_param_names(cls, sort=None):
"""Get object's parameter names.

Parameters
----------
sort : bool, default=True
Whether to return the parameter names sorted in alphabetical order (True),
or in the order they appear in the class ``__init__`` (False).

Returns
-------
param_names: list[str]
Alphabetically sorted list of parameter names of cls.
List of parameter names of cls.
If ``sort=False``, in same order as they appear in the class ``__init__``.
If ``sort=True``, alphabetically ordered.
"""
if sort is None:
sort = True
Fixed Show fixed Hide fixed
warn(
"In scikit-base BaseObject.get_param_names, the default of parameter"
" 'sorted' will change from True to False in 0.10.0. "
"This will change the order of the output, returning the parameters"
" in the order they appear in the class __init__, "
"rather than alphabetically ordered. "
"To retain previous behaviour in direct calls, "
"set 'sort=True'. To silence this warning, set 'sorted' to "
"either True or False.",
FutureWarning,
)

parameters = cls._get_init_signature()
param_names = sorted([p.name for p in parameters])
param_names = [p.name for p in parameters]
if sorted:
param_names = sorted(param_names)
Fixed Show fixed Hide fixed
return param_names

@classmethod
Expand Down Expand Up @@ -586,7 +613,7 @@
`create_test_instance` uses the first (or only) dictionary in `params`
"""
params_with_defaults = set(cls.get_param_defaults().keys())
all_params = set(cls.get_param_names())
all_params = set(cls.get_param_names(sort=False))
params_without_defaults = all_params - params_with_defaults

# if non-default parameters are required, but none have been found, raise error
Expand Down
11 changes: 8 additions & 3 deletions skbase/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,16 +706,21 @@ def test_get_init_signature_raises_error_for_invalid_signature(
fixture_invalid_init._get_init_signature()


@pytest.mark.parametrize("sorted", [True, False])
def test_get_param_names(
fixture_object: Type[BaseObject],
fixture_class_parent: Type[Parent],
fixture_class_parent_expected_params: Dict[str, Any],
sorted: bool,
):
"""Test that get_param_names returns list of string parameter names."""
param_names = fixture_class_parent.get_param_names()
assert param_names == sorted([*fixture_class_parent_expected_params])
param_names = fixture_class_parent.get_param_names(sort=sorted)
if sorted:
assert param_names == sorted([*fixture_class_parent_expected_params])
else:
assert param_names == [*fixture_class_parent_expected_params]

param_names = fixture_object.get_param_names()
param_names = fixture_object.get_param_names(sort=sorted)
assert param_names == []


Expand Down
Loading