Skip to content

Commit

Permalink
[ENH] merge sktime BaseEstimator into skbase BaseEstimator (#370
Browse files Browse the repository at this point in the history
)

This merges some minor improvements from `sktime` `BaseEstimator` into
`skbase` `BaseEstimator`:

* docstring formatting
* clearer error messages
* deduplicating logic with `check_is_fitted` calls

Accompanies sktime/sktime#7213, see there for
explanation of merge and deduplication.
  • Loading branch information
fkiraly authored Oct 4, 2024
1 parent 2c2acdb commit c71dfe8
Showing 1 changed file with 37 additions and 23 deletions.
60 changes: 37 additions & 23 deletions skbase/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1445,13 +1445,13 @@ class BaseEstimator(BaseObject):
def __init__(self):
"""Construct BaseEstimator."""
self._is_fitted = False
super(BaseEstimator, self).__init__()
super().__init__()

@property
def is_fitted(self):
"""Whether `fit` has been called.
"""Whether ``fit`` has been called.
Inspects object's `_is_fitted` attribute that should initialize to False
Inspects object's ``_is_fitted` attribute that should initialize to ``False``
during object construction, and be set to True in calls to an object's
`fit` method.
Expand All @@ -1460,25 +1460,43 @@ def is_fitted(self):
bool
Whether the estimator has been `fit`.
"""
return self._is_fitted
if hasattr(self, "_is_fitted"):
return self._is_fitted
else:
return False

def check_is_fitted(self):
def check_is_fitted(self, method_name=None):
"""Check if the estimator has been fitted.
Inspects object's `_is_fitted` attribute that should initialize to False
during object construction, and be set to True in calls to an object's
`fit` method.
Check if ``_is_fitted`` attribute is present and ``True``.
The ``is_fitted``
attribute should be set to ``True`` in calls to an object's ``fit`` method.
If not, raises a ``NotFittedError``.
Parameters
----------
method_name : str, optional
Name of the method that called this function. If provided, the error
message will include this information.
Raises
------
NotFittedError
If the estimator has not been fitted yet.
"""
if not self.is_fitted:
raise NotFittedError(
f"This instance of {self.__class__.__name__} has not been fitted yet. "
f"Please call `fit` first."
)
if method_name is None:
msg = (
f"This instance of {self.__class__.__name__} has not been fitted "
f"yet. Please call `fit` first."
)
else:
msg = (
f"This instance of {self.__class__.__name__} has not been fitted "
f"yet. Please call `fit` before calling `{method_name}`."
)
raise NotFittedError(msg)

def get_fitted_params(self, deep=True):
"""Get fitted parameters.
Expand All @@ -1503,19 +1521,15 @@ def get_fitted_params(self, deep=True):
Dictionary of fitted parameters, paramname : paramvalue
keys-value pairs include:
* always: all fitted parameters of this object, as via `get_param_names`
* always: all fitted parameters of this object, as via ``get_param_names``
values are fitted parameter value for that key, of this object
* if `deep=True`, also contains keys/value pairs of component parameters
parameters of components are indexed as `[componentname]__[paramname]`
all parameters of `componentname` appear as `paramname` with its value
* if `deep=True`, also contains arbitrary levels of component recursion,
e.g., `[componentname]__[componentcomponentname]__[paramname]`, etc
* if ``deep=True``, also contains keys/value pairs of component parameters
parameters of components are indexed as ``[componentname]__[paramname]``
all parameters of ``componentname`` appear as ``paramname`` with its value
* if ``deep=True``, also contains arbitrary levels of component recursion,
e.g., ``[componentname]__[componentcomponentname]__[paramname]``, etc
"""
if not self.is_fitted:
raise NotFittedError(
f"estimator of type {type(self).__name__} has not been "
"fitted yet, please call fit on data before get_fitted_params"
)
self.check_is_fitted(method_name="get_fitted_params")

# collect non-nested fitted params of self
fitted_params = self._get_fitted_params()
Expand Down

0 comments on commit c71dfe8

Please sign in to comment.