From f6c682587ea710d40af02e7a045d2b5971e92c76 Mon Sep 17 00:00:00 2001 From: Pablo de Roque Date: Thu, 22 Aug 2024 04:02:31 +0200 Subject: [PATCH] Fallback to defaults in `adstock|saturation_from_dict` (#955) * Default saturation_from_dict to default_priors * Default to AdstockTransformation.default priors in adstock_from_dict --- pymc_marketing/mmm/components/adstock.py | 3 ++- pymc_marketing/mmm/components/saturation.py | 7 ++++--- tests/mmm/components/test_adstock.py | 18 ++++++++++++++++++ tests/mmm/components/test_saturation.py | 12 ++++++++++++ 4 files changed, 36 insertions(+), 4 deletions(-) diff --git a/pymc_marketing/mmm/components/adstock.py b/pymc_marketing/mmm/components/adstock.py index 28a4e8ba..81756d1d 100644 --- a/pymc_marketing/mmm/components/adstock.py +++ b/pymc_marketing/mmm/components/adstock.py @@ -405,7 +405,8 @@ def adstock_from_dict(data: dict) -> AdstockTransformation: lookup_name = data.pop("lookup_name") cls = ADSTOCK_TRANSFORMATIONS[lookup_name] - data["priors"] = {k: Prior.from_json(v) for k, v in data["priors"].items()} + if "priors" in data: + data["priors"] = {k: Prior.from_json(v) for k, v in data["priors"].items()} return cls(**data) diff --git a/pymc_marketing/mmm/components/saturation.py b/pymc_marketing/mmm/components/saturation.py index 7cba45aa..767b4d64 100644 --- a/pymc_marketing/mmm/components/saturation.py +++ b/pymc_marketing/mmm/components/saturation.py @@ -465,9 +465,10 @@ def saturation_from_dict(data: dict) -> SaturationTransformation: data = data.copy() cls = SATURATION_TRANSFORMATIONS[data.pop("lookup_name")] - data["priors"] = { - key: Prior.from_json(value) for key, value in data["priors"].items() - } + if "priors" in data: + data["priors"] = { + key: Prior.from_json(value) for key, value in data["priors"].items() + } return cls(**data) diff --git a/tests/mmm/components/test_adstock.py b/tests/mmm/components/test_adstock.py index c932d7e5..407a5f79 100644 --- a/tests/mmm/components/test_adstock.py +++ b/tests/mmm/components/test_adstock.py @@ -178,6 +178,24 @@ def test_adstock_from_dict() -> None: ) +@pytest.mark.parametrize( + "adstock", + adstocks(), +) +def test_adstock_from_dict_without_priors(adstock) -> None: + data = { + "lookup_name": adstock.lookup_name, + "l_max": 10, + "prefix": "test", + "mode": "Before", + } + + adstock = adstock_from_dict(data) + assert adstock.default_priors == { + k: Prior.from_json(v) for k, v in adstock.to_dict()["priors"].items() + } + + def test_register_adstock_transformation() -> None: class NewTransformation(AdstockTransformation): lookup_name: str = "new_transformation" diff --git a/tests/mmm/components/test_saturation.py b/tests/mmm/components/test_saturation.py index cebd9cd6..9768a5e3 100644 --- a/tests/mmm/components/test_saturation.py +++ b/tests/mmm/components/test_saturation.py @@ -258,3 +258,15 @@ def test_saturation_from_dict() -> None: "lam": Prior("HalfNormal", sigma=1), } ) + + +@pytest.mark.parametrize("saturation", saturation_functions()) +def test_saturation_from_dict_without_priors(saturation) -> None: + data = { + "lookup_name": saturation.lookup_name, + } + + saturation = saturation_from_dict(data) + assert saturation.default_priors == { + k: Prior.from_json(v) for k, v in saturation.to_dict()["priors"].items() + }