From abdcf41a2cef9ab0dbebcb7ba1f06e16143757ce Mon Sep 17 00:00:00 2001 From: Pablo de Roque Date: Mon, 9 Sep 2024 19:05:53 +0200 Subject: [PATCH] Remove `combined` flag from `apply_sklearn_transformer_across_dim` (#1010) * Remove combined flag from apply_sklearn_transformer_across_dim * Add on_missing_core_dim to allow lift_measurements * Comment. Re-trigger pipeline --- pymc_marketing/mmm/mmm.py | 2 -- pymc_marketing/mmm/utils.py | 22 ++++++++++------------ tests/mmm/test_utils.py | 19 ++----------------- 3 files changed, 12 insertions(+), 31 deletions(-) diff --git a/pymc_marketing/mmm/mmm.py b/pymc_marketing/mmm/mmm.py index b9522a1e..af6149ce 100644 --- a/pymc_marketing/mmm/mmm.py +++ b/pymc_marketing/mmm/mmm.py @@ -1403,7 +1403,6 @@ def new_spend_contributions( data=channel_contributions, func=self.get_target_transformer().inverse_transform, dim_name="time_since_spend", - combined=False, ) return channel_contributions @@ -1882,7 +1881,6 @@ def sample_posterior_predictive( data=posterior_predictive_samples, func=self.get_target_transformer().inverse_transform, dim_name="date", - combined=combined, ) return posterior_predictive_samples diff --git a/pymc_marketing/mmm/utils.py b/pymc_marketing/mmm/utils.py index 69920115..32638aef 100644 --- a/pymc_marketing/mmm/utils.py +++ b/pymc_marketing/mmm/utils.py @@ -197,7 +197,6 @@ def apply_sklearn_transformer_across_dim( data: xr.DataArray, func: Callable[[np.ndarray], np.ndarray], dim_name: str, - combined: bool = False, ) -> xr.DataArray: """Apply a scikit-learn transformer across a dimension of an xarray DataArray. @@ -211,8 +210,6 @@ def apply_sklearn_transformer_across_dim( scikit-learn method to apply to the data dim_name : str Name of the dimension to apply the function to - combined : bool, default False - Flag to indicate if the data coords have been combined or not Returns ------- @@ -221,20 +218,21 @@ def apply_sklearn_transformer_across_dim( """ # These are lost during the ufunc attrs = data.attrs + # Cache dims to restore them after the ufunc + dims = data.dims - if combined: - data = xr.apply_ufunc( + data = ( + xr.apply_ufunc( func, - data, - ) - else: - data = xr.apply_ufunc( - func, - data.expand_dims(dim={"_": 1}, axis=1), + data.expand_dims("_"), input_core_dims=[[dim_name, "_"]], output_core_dims=[[dim_name, "_"]], vectorize=True, - ).squeeze(dim="_") + on_missing_core_dim="copy", + ) + .squeeze(dim="_") + .transpose(*dims) + ) data.attrs = attrs diff --git a/tests/mmm/test_utils.py b/tests/mmm/test_utils.py index c76f89b6..04650e0f 100644 --- a/tests/mmm/test_utils.py +++ b/tests/mmm/test_utils.py @@ -123,11 +123,12 @@ def create_mock_mmm_return_data(): def _create_mock_mm_return_data(combined: bool) -> xr.DataArray: dates = pd.date_range(start="2020-01-01", end="2020-01-31", freq="W-MON") data = xr.DataArray( - np.ones(shape=(1, 3, len(dates))), + np.ones(shape=(1, 3, len(dates), 2)), coords={ "chain": [1], "draw": [1, 2, 3], "date": dates, + "channel": ["channel1", "channel2"], }, ) @@ -149,27 +150,11 @@ def test_apply_sklearn_function_across_dim( data, mock_method, dim_name="date", - combined=combined, ) xr.testing.assert_allclose(result, data * 2) -def test_apply_sklearn_function_across_dim_error( - mock_method, - create_mock_mmm_return_data, -) -> None: - data = create_mock_mmm_return_data(combined=False) - - with pytest.raises(ValueError, match="x must be 2-dimensional"): - apply_sklearn_transformer_across_dim( - data, - mock_method, - dim_name="date", - combined=True, - ) - - @pytest.mark.parametrize("constructor", [pd.Series, np.array]) def test_transform_1d_array(constructor): transform = MaxAbsScaler()