Skip to content

Commit

Permalink
improve tests mmm utils (#738)
Browse files Browse the repository at this point in the history
* improve tests

* empty commit

* remove reduntant function
  • Loading branch information
juanitorduz authored and twiecki committed Sep 10, 2024
1 parent bbf7adf commit 95bf0c3
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 65 deletions.
28 changes: 0 additions & 28 deletions pymc_marketing/mmm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,34 +663,6 @@ def compute_channel_contribution_original_scale(self) -> DataArray:
coords=channel_contribution.coords,
)

def _get_distribution_from_dict(self, dist: dict) -> Callable:
"""
Retrieve a PyMC distribution callable based on the provided dictionary.
Parameters
----------
dist : Dict
A dictionary containing the key 'dist' which should correspond to the
name of a PyMC distribution.
Returns
-------
Callable
A PyMC distribution callable that can be used to instantiate a random
variable.
Raises
------
ValueError
If the specified distribution name in the dictionary does not correspond
to any distribution in PyMC.
"""
try:
prior_distribution = getattr(pm, dist["dist"])
except AttributeError:
raise ValueError(f"Distribution {dist['dist']} does not exist in PyMC")
return prior_distribution

def compute_mean_contributions_over_time(
self, original_scale: bool = False
) -> pd.DataFrame:
Expand Down
27 changes: 0 additions & 27 deletions pymc_marketing/mmm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
"""Utility functions for the Marketing Mix Modeling module."""

import re
from collections.abc import Callable
from typing import Any

Expand Down Expand Up @@ -227,32 +226,6 @@ def find_sigmoid_inflection_point(
return x_inflection, y_inflection


def standardize_scenarios_dict_keys(d: dict, keywords: list[str]):
"""
Standardize the keys in a dictionary based on a list of keywords.
This function iterates over the keys in the dictionary and the keywords.
If a keyword is found in a key (case-insensitive), the key is replaced with the keyword.
Parameters
----------
d : dict
The dictionary whose keys are to be standardized.
keywords : list
The list of keywords to standardize the keys to.
Returns
-------
None
The function modifies the given dictionary in-place and doesn't return any object.
"""
for keyword in keywords:
for key in list(d.keys()):
if re.search(keyword, key, re.IGNORECASE):
d[keyword] = d.pop(key)
break


def apply_sklearn_transformer_across_dim(
data: xr.DataArray,
func: Callable[[np.ndarray], np.ndarray],
Expand Down
10 changes: 0 additions & 10 deletions tests/mmm/test_delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,16 +863,6 @@ def test_new_data_predict_method(
# assert lower < toy_y.mean() < upper


def test_get_valid_distribution(mmm):
normal_dist = mmm._get_distribution_from_dict({"dist": "Normal"})
assert normal_dist is pm.Normal


def test_get_invalid_distribution(mmm):
with pytest.raises(ValueError, match="does not exist in PyMC"):
mmm._get_distribution_from_dict({"dist": "NonExistentDist"})


def test_invalid_likelihood_type(mmm):
with pytest.raises(
ValueError,
Expand Down
32 changes: 32 additions & 0 deletions tests/mmm/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
# limitations under the License.
import numpy as np
import pandas as pd
import pymc as pm
import pytest
import xarray as xr
from sklearn.preprocessing import MaxAbsScaler

from pymc_marketing.mmm.utils import (
_get_distribution_from_dict,
apply_sklearn_transformer_across_dim,
compute_sigmoid_second_derivative,
create_new_spend_data,
Expand Down Expand Up @@ -344,3 +346,33 @@ def test_create_new_spend_data(
new_spend_data,
np.array(expected_result),
)


def test_create_new_spend_data_value_errors() -> None:
with pytest.raises(
ValueError, match="spend_leading_up must be the same length as the spend"
):
create_new_spend_data(
spend=np.array([1, 2]),
adstock_max_lag=2,
one_time=True,
spend_leading_up=np.array([3, 4, 5]),
)


@pytest.mark.parametrize(
argnames="distribution_dict, expected",
argvalues=[
({"dist": "Normal"}, pm.Normal),
({"dist": "Gamma"}, pm.Gamma),
({"dist": "StudentT"}, pm.StudentT),
],
ids=["Normal", "Gamma", "StudentT"],
)
def test_get_distribution_from_dict(distribution_dict, expected):
assert _get_distribution_from_dict(distribution_dict) == expected


def test_get_distribution_from_dict_value_error():
with pytest.raises(ValueError):
_get_distribution_from_dict({"dist": "InvalidDistribution"})

0 comments on commit 95bf0c3

Please sign in to comment.