Skip to content

Commit

Permalink
Fallback to defaults in adstock|saturation_from_dict (#955)
Browse files Browse the repository at this point in the history
* Default saturation_from_dict to default_priors

* Default to AdstockTransformation.default priors in adstock_from_dict
  • Loading branch information
PabloRoque authored Aug 22, 2024
1 parent 3126ae0 commit f6c6825
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 4 deletions.
3 changes: 2 additions & 1 deletion pymc_marketing/mmm/components/adstock.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,8 @@ def adstock_from_dict(data: dict) -> AdstockTransformation:
lookup_name = data.pop("lookup_name")
cls = ADSTOCK_TRANSFORMATIONS[lookup_name]

data["priors"] = {k: Prior.from_json(v) for k, v in data["priors"].items()}
if "priors" in data:
data["priors"] = {k: Prior.from_json(v) for k, v in data["priors"].items()}
return cls(**data)


Expand Down
7 changes: 4 additions & 3 deletions pymc_marketing/mmm/components/saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,9 +465,10 @@ def saturation_from_dict(data: dict) -> SaturationTransformation:
data = data.copy()
cls = SATURATION_TRANSFORMATIONS[data.pop("lookup_name")]

data["priors"] = {
key: Prior.from_json(value) for key, value in data["priors"].items()
}
if "priors" in data:
data["priors"] = {
key: Prior.from_json(value) for key, value in data["priors"].items()
}
return cls(**data)


Expand Down
18 changes: 18 additions & 0 deletions tests/mmm/components/test_adstock.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,24 @@ def test_adstock_from_dict() -> None:
)


@pytest.mark.parametrize(
"adstock",
adstocks(),
)
def test_adstock_from_dict_without_priors(adstock) -> None:
data = {
"lookup_name": adstock.lookup_name,
"l_max": 10,
"prefix": "test",
"mode": "Before",
}

adstock = adstock_from_dict(data)
assert adstock.default_priors == {
k: Prior.from_json(v) for k, v in adstock.to_dict()["priors"].items()
}


def test_register_adstock_transformation() -> None:
class NewTransformation(AdstockTransformation):
lookup_name: str = "new_transformation"
Expand Down
12 changes: 12 additions & 0 deletions tests/mmm/components/test_saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,3 +258,15 @@ def test_saturation_from_dict() -> None:
"lam": Prior("HalfNormal", sigma=1),
}
)


@pytest.mark.parametrize("saturation", saturation_functions())
def test_saturation_from_dict_without_priors(saturation) -> None:
data = {
"lookup_name": saturation.lookup_name,
}

saturation = saturation_from_dict(data)
assert saturation.default_priors == {
k: Prior.from_json(v) for k, v in saturation.to_dict()["priors"].items()
}

0 comments on commit f6c6825

Please sign in to comment.