Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] option to return BaseObject.get_param_names in the same order as in the __init__ #335

Merged
merged 7 commits into from
Jun 20, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions skbase/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from collections import defaultdict
from copy import deepcopy
from typing import List
from warnings import warn

from skbase._exceptions import NotFittedError
from skbase.base._pretty_printing._object_html_repr import _object_html_repr
Expand Down Expand Up @@ -205,17 +206,43 @@
)
return parameters

# todo 0.10.0: changed sorted default to False
# update docstring, and remove warning
@classmethod
def get_param_names(cls):
def get_param_names(cls, sorted=None):
"""Get object's parameter names.

Parameters
----------
sorted : bool, default=True
Whether to return the parameter names sorted in alphabetical order (True),
or in the order they appear in the class ``__init__`` (False).

Returns
-------
param_names: list[str]
Alphabetically sorted list of parameter names of cls.
List of parameter names of cls.
If ``sorted=False``, in same order as they appear in the class ``__init__``.
If ``sorted=True``, alphabetically ordered.
"""
if sorted is None:
sorted = True
warn(
"In scikit-base BaseObject.get_param_names, the default of parameter"
" 'sorted' will change from True to False in 0.10.0. "
"This will change the order of the output, returning the parameters"
" in the order they appear in the class __init__, "
"rather than alphabetically ordered. "
"To retain previous behaviour in direct calls, "
"set 'sorted=True'. To silence this warning, set 'sorted' to "
"either True or False.",
FutureWarning,
)

parameters = cls._get_init_signature()
param_names = sorted([p.name for p in parameters])
param_names = [p.name for p in parameters]
if sorted:
param_names = sorted(param_names)
Fixed Show fixed Hide fixed
return param_names

@classmethod
Expand Down Expand Up @@ -586,7 +613,7 @@
`create_test_instance` uses the first (or only) dictionary in `params`
"""
params_with_defaults = set(cls.get_param_defaults().keys())
all_params = set(cls.get_param_names())
all_params = set(cls.get_param_names(sorted=False))
params_without_defaults = all_params - params_with_defaults

# if non-default parameters are required, but none have been found, raise error
Expand Down
Loading