From 9e570d0cbc470c49450268ca3979ccb4387c2c68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Mon, 14 Aug 2023 20:47:42 +0100 Subject: [PATCH 1/2] [BUG] fix for `get_fitted_params` in `_HeterogenousMetaEstimator` (#191) Upstream version of `sktime` bugfix https://github.com/sktime/sktime/pull/4633 This fixes a bug in `get_fitted_params` of `_HeterogenousMetaEstimator`, which is at the root of https://github.com/sktime/sktime/issues/4574. `get_fitted_params` of `_HeterogenousMetaEstimator` was accidentally calling the private interface point of components rather than the public interface point that it should have. For reviewers: the first call should be to `_get_fitted_params` of the base class, which is correct. Subsequent calls to components should be to the public interface, `get_fitted_params`. --- skbase/base/_meta.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/skbase/base/_meta.py b/skbase/base/_meta.py index 485c406e..57681a1d 100644 --- a/skbase/base/_meta.py +++ b/skbase/base/_meta.py @@ -188,14 +188,16 @@ def _get_params( """ # Set variables that let us use same code for retrieving params or fitted params if fitted: - method = "_get_fitted_params" + method_shallow = "_get_fitted_params" + method_public = "get_fitted_params" deepkw = {} else: - method = "get_params" + method_shallow = "get_params" + method_public = "get_params" deepkw = {"deep": deep} # Get the direct params/fitted params - out = getattr(super(), method)(**deepkw) + out = getattr(super(), method_shallow)(**deepkw) if deep and hasattr(self, attr): named_objects = getattr(self, attr) @@ -207,8 +209,15 @@ def _get_params( ] out.update(named_objects_) for name, obj in named_objects_: - if hasattr(obj, method): - for key, value in getattr(obj, method)(**deepkw).items(): + # checks estimator has the method we want to call + cond1 = hasattr(obj, method_public) + # checks estimator is fitted if calling get_fitted_params + is_fitted = hasattr(obj, "is_fitted") and obj.is_fitted + # if we call get_params and not get_fitted_params, this is True + cond2 = not fitted or is_fitted + # check both conditions together + if cond1 and cond2: + for key, value in getattr(obj, method_public)(**deepkw).items(): out["%s__%s" % (name, key)] = value return out From 9210fdb9dd9e4f441dda42aa2f61db6d5bd8da28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Mon, 14 Aug 2023 21:12:54 +0100 Subject: [PATCH 2/2] [ENH] `_HeterogenousMetaObject` to accept list of tuples of any length (#206) Mirror PR of https://github.com/sktime/sktime/pull/4793 from `sktime`. This improves the `_HeterogenousMetaObject` by widening its functionality. `_HeterogenousMetaObject` now allows tuples of any length in the `_steps_attr`, as long as the zeroth elements are str names, and the first elements are estimators --- skbase/base/_meta.py | 11 +++++--- skbase/tests/test_meta.py | 57 ++++++++++++++++++++++++++++++++------- 2 files changed, 55 insertions(+), 13 deletions(-) diff --git a/skbase/base/_meta.py b/skbase/base/_meta.py index 57681a1d..a5042484 100644 --- a/skbase/base/_meta.py +++ b/skbase/base/_meta.py @@ -243,8 +243,8 @@ def _set_params(self, attr: str, **params): # 2. Step replacement items = getattr(self, attr) names = [] - if items: - names, _ = zip(*items) + if items and isinstance(items, (list, tuple)): + names = list(zip(*items))[0] for name in list(params.keys()): if "__" not in name and name in names: self._replace_object(attr, name, params.pop(name)) @@ -256,9 +256,12 @@ def _replace_object(self, attr: str, name: str, new_val: Any) -> None: """Replace an object in attribute that contains named objects.""" # assumes `name` is a valid object name new_objects = list(getattr(self, attr)) - for i, (object_name, _) in enumerate(new_objects): + for i, obj_tpl in enumerate(new_objects): + object_name = obj_tpl[0] if object_name == name: - new_objects[i] = (name, new_val) + new_tpl = list(obj_tpl) + new_tpl[1] = new_val + new_objects[i] = tuple(new_tpl) break setattr(self, attr, new_objects) diff --git a/skbase/tests/test_meta.py b/skbase/tests/test_meta.py index 75414992..f2c6c8fa 100644 --- a/skbase/tests/test_meta.py +++ b/skbase/tests/test_meta.py @@ -1,13 +1,8 @@ # -*- coding: utf-8 -*- # copyright: skbase developers, BSD-3-Clause License (see LICENSE file) -"""Tests for BaseMetaObject and BaseMetaEstimator mixins. +"""Tests for BaseMetaObject and BaseMetaEstimator mixins.""" -tests in this module: - - -""" - -__author__ = ["RNKuhns"] +__author__ = ["RNKuhns", "fkiraly"] import inspect import pytest @@ -23,37 +18,51 @@ class MetaObjectTester(BaseMetaObject): - """Class to test meta object functionality.""" + """Class to test meta-object functionality.""" def __init__(self, a=7, b="something", c=None, steps=None): self.a = a self.b = b self.c = c self.steps = steps + super().__init__() class MetaEstimatorTester(BaseMetaEstimator): - """Class to test meta estimator functionality.""" + """Class to test meta-estimator functionality.""" def __init__(self, a=7, b="something", c=None, steps=None): self.a = a self.b = b self.c = c self.steps = steps + super().__init__() + + +class ComponentDummy(BaseObject): + """Class to use as components in meta-estimator.""" + + def __init__(self, a=7, b="something"): + self.a = a + self.b = b + super().__init__() @pytest.fixture def fixture_metaestimator_instance(): + """BaseMetaEstimator instance fixture.""" return BaseMetaEstimator() @pytest.fixture def fixture_meta_object(): + """MetaObjectTester instance fixture.""" return MetaObjectTester() @pytest.fixture def fixture_meta_estimator(): + """MetaEstimatorTester instance fixture.""" return MetaEstimatorTester() @@ -129,3 +138,33 @@ def test_basemetaestimator_check_is_fitted_raises_error_when_unfitted( fixture_metaestimator_instance._is_fitted = True assert fixture_metaestimator_instance.check_is_fitted() is None + + +@pytest.mark.parametrize("long_steps", (True, False)) +def test_metaestimator_composite(long_steps): + """Test composite meta-estimator functionality.""" + if long_steps: + steps = [("foo", ComponentDummy(42)), ("bar", ComponentDummy(24))] + else: + steps = [("foo", ComponentDummy(42), 123), ("bar", ComponentDummy(24), 321)] + + meta_est = MetaEstimatorTester(steps=steps) + + meta_est_params = meta_est.get_params() + assert isinstance(meta_est_params, dict) + expected_keys = [ + "a", + "b", + "c", + "steps", + "foo", + "bar", + "foo__a", + "foo__b", + "bar__a", + "bar__b", + ] + assert set(meta_est_params.keys()) == set(expected_keys) + + meta_est.set_params(bar__b="something else") + assert meta_est.get_params()["bar__b"] == "something else"