Skip to content

Commit

Permalink
Merge branch 'main' into release-0.5.1
Browse files Browse the repository at this point in the history
  • Loading branch information
fkiraly committed Aug 14, 2023
2 parents 180ac57 + 9210fdb commit cb12356
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 18 deletions.
30 changes: 21 additions & 9 deletions skbase/base/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,14 +188,16 @@ def _get_params(
"""
# Set variables that let us use same code for retrieving params or fitted params
if fitted:
method = "_get_fitted_params"
method_shallow = "_get_fitted_params"
method_public = "get_fitted_params"
deepkw = {}
else:
method = "get_params"
method_shallow = "get_params"
method_public = "get_params"
deepkw = {"deep": deep}

# Get the direct params/fitted params
out = getattr(super(), method)(**deepkw)
out = getattr(super(), method_shallow)(**deepkw)

if deep and hasattr(self, attr):
named_objects = getattr(self, attr)
Expand All @@ -207,8 +209,15 @@ def _get_params(
]
out.update(named_objects_)
for name, obj in named_objects_:
if hasattr(obj, method):
for key, value in getattr(obj, method)(**deepkw).items():
# checks estimator has the method we want to call
cond1 = hasattr(obj, method_public)
# checks estimator is fitted if calling get_fitted_params
is_fitted = hasattr(obj, "is_fitted") and obj.is_fitted
# if we call get_params and not get_fitted_params, this is True
cond2 = not fitted or is_fitted
# check both conditions together
if cond1 and cond2:
for key, value in getattr(obj, method_public)(**deepkw).items():
out["%s__%s" % (name, key)] = value
return out

Expand All @@ -234,8 +243,8 @@ def _set_params(self, attr: str, **params):
# 2. Step replacement
items = getattr(self, attr)
names = []
if items:
names, _ = zip(*items)
if items and isinstance(items, (list, tuple)):
names = list(zip(*items))[0]
for name in list(params.keys()):
if "__" not in name and name in names:
self._replace_object(attr, name, params.pop(name))
Expand All @@ -247,9 +256,12 @@ def _replace_object(self, attr: str, name: str, new_val: Any) -> None:
"""Replace an object in attribute that contains named objects."""
# assumes `name` is a valid object name
new_objects = list(getattr(self, attr))
for i, (object_name, _) in enumerate(new_objects):
for i, obj_tpl in enumerate(new_objects):
object_name = obj_tpl[0]
if object_name == name:
new_objects[i] = (name, new_val)
new_tpl = list(obj_tpl)
new_tpl[1] = new_val
new_objects[i] = tuple(new_tpl)
break
setattr(self, attr, new_objects)

Expand Down
57 changes: 48 additions & 9 deletions skbase/tests/test_meta.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
# -*- coding: utf-8 -*-
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
"""Tests for BaseMetaObject and BaseMetaEstimator mixins.
"""Tests for BaseMetaObject and BaseMetaEstimator mixins."""

tests in this module:
"""

__author__ = ["RNKuhns"]
__author__ = ["RNKuhns", "fkiraly"]
import inspect

import pytest
Expand All @@ -23,37 +18,51 @@


class MetaObjectTester(BaseMetaObject):
"""Class to test meta object functionality."""
"""Class to test meta-object functionality."""

def __init__(self, a=7, b="something", c=None, steps=None):
self.a = a
self.b = b
self.c = c
self.steps = steps
super().__init__()


class MetaEstimatorTester(BaseMetaEstimator):
"""Class to test meta estimator functionality."""
"""Class to test meta-estimator functionality."""

def __init__(self, a=7, b="something", c=None, steps=None):
self.a = a
self.b = b
self.c = c
self.steps = steps
super().__init__()


class ComponentDummy(BaseObject):
"""Class to use as components in meta-estimator."""

def __init__(self, a=7, b="something"):
self.a = a
self.b = b
super().__init__()


@pytest.fixture
def fixture_metaestimator_instance():
"""BaseMetaEstimator instance fixture."""
return BaseMetaEstimator()


@pytest.fixture
def fixture_meta_object():
"""MetaObjectTester instance fixture."""
return MetaObjectTester()


@pytest.fixture
def fixture_meta_estimator():
"""MetaEstimatorTester instance fixture."""
return MetaEstimatorTester()


Expand Down Expand Up @@ -129,3 +138,33 @@ def test_basemetaestimator_check_is_fitted_raises_error_when_unfitted(

fixture_metaestimator_instance._is_fitted = True
assert fixture_metaestimator_instance.check_is_fitted() is None


@pytest.mark.parametrize("long_steps", (True, False))
def test_metaestimator_composite(long_steps):
"""Test composite meta-estimator functionality."""
if long_steps:
steps = [("foo", ComponentDummy(42)), ("bar", ComponentDummy(24))]
else:
steps = [("foo", ComponentDummy(42), 123), ("bar", ComponentDummy(24), 321)]

meta_est = MetaEstimatorTester(steps=steps)

meta_est_params = meta_est.get_params()
assert isinstance(meta_est_params, dict)
expected_keys = [
"a",
"b",
"c",
"steps",
"foo",
"bar",
"foo__a",
"foo__b",
"bar__a",
"bar__b",
]
assert set(meta_est_params.keys()) == set(expected_keys)

meta_est.set_params(bar__b="something else")
assert meta_est.get_params()["bar__b"] == "something else"

0 comments on commit cb12356

Please sign in to comment.