Skip to content

Commit

Permalink
Remove combined flag from apply_sklearn_transformer_across_dim (#1010)
Browse files Browse the repository at this point in the history

* Remove combined flag from apply_sklearn_transformer_across_dim

* Add on_missing_core_dim to allow lift_measurements

* Comment. Re-trigger pipeline
  • Loading branch information
PabloRoque authored and twiecki committed Sep 10, 2024
1 parent 98f2921 commit abdcf41
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 31 deletions.
2 changes: 0 additions & 2 deletions pymc_marketing/mmm/mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1403,7 +1403,6 @@ def new_spend_contributions(
data=channel_contributions,
func=self.get_target_transformer().inverse_transform,
dim_name="time_since_spend",
combined=False,
)

return channel_contributions
Expand Down Expand Up @@ -1882,7 +1881,6 @@ def sample_posterior_predictive(
data=posterior_predictive_samples,
func=self.get_target_transformer().inverse_transform,
dim_name="date",
combined=combined,
)

return posterior_predictive_samples
Expand Down
22 changes: 10 additions & 12 deletions pymc_marketing/mmm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,6 @@ def apply_sklearn_transformer_across_dim(
data: xr.DataArray,
func: Callable[[np.ndarray], np.ndarray],
dim_name: str,
combined: bool = False,
) -> xr.DataArray:
"""Apply a scikit-learn transformer across a dimension of an xarray DataArray.
Expand All @@ -211,8 +210,6 @@ def apply_sklearn_transformer_across_dim(
scikit-learn method to apply to the data
dim_name : str
Name of the dimension to apply the function to
combined : bool, default False
Flag to indicate if the data coords have been combined or not
Returns
-------
Expand All @@ -221,20 +218,21 @@ def apply_sklearn_transformer_across_dim(
"""
# These are lost during the ufunc
attrs = data.attrs
# Cache dims to restore them after the ufunc
dims = data.dims

if combined:
data = xr.apply_ufunc(
data = (
xr.apply_ufunc(
func,
data,
)
else:
data = xr.apply_ufunc(
func,
data.expand_dims(dim={"_": 1}, axis=1),
data.expand_dims("_"),
input_core_dims=[[dim_name, "_"]],
output_core_dims=[[dim_name, "_"]],
vectorize=True,
).squeeze(dim="_")
on_missing_core_dim="copy",
)
.squeeze(dim="_")
.transpose(*dims)
)

data.attrs = attrs

Expand Down
19 changes: 2 additions & 17 deletions tests/mmm/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,12 @@ def create_mock_mmm_return_data():
def _create_mock_mm_return_data(combined: bool) -> xr.DataArray:
dates = pd.date_range(start="2020-01-01", end="2020-01-31", freq="W-MON")
data = xr.DataArray(
np.ones(shape=(1, 3, len(dates))),
np.ones(shape=(1, 3, len(dates), 2)),
coords={
"chain": [1],
"draw": [1, 2, 3],
"date": dates,
"channel": ["channel1", "channel2"],
},
)

Expand All @@ -149,27 +150,11 @@ def test_apply_sklearn_function_across_dim(
data,
mock_method,
dim_name="date",
combined=combined,
)

xr.testing.assert_allclose(result, data * 2)


def test_apply_sklearn_function_across_dim_error(
mock_method,
create_mock_mmm_return_data,
) -> None:
data = create_mock_mmm_return_data(combined=False)

with pytest.raises(ValueError, match="x must be 2-dimensional"):
apply_sklearn_transformer_across_dim(
data,
mock_method,
dim_name="date",
combined=True,
)


@pytest.mark.parametrize("constructor", [pd.Series, np.array])
def test_transform_1d_array(constructor):
transform = MaxAbsScaler()
Expand Down

0 comments on commit abdcf41

Please sign in to comment.