Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support numpy arrays in model_config #351

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 34 additions & 5 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,12 +362,17 @@

@property
def _serializable_model_config(self) -> Dict[str, Any]:
def ndarray_to_list(d: Dict) -> Dict:
new_d = d.copy() # Copy the dictionary to avoid mutating the original one
for key, value in new_d.items():
if isinstance(value, np.ndarray):
new_d[key] = value.tolist()
elif isinstance(value, dict):
new_d[key] = ndarray_to_list(value)
return new_d
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved

serializable_config = self.model_config.copy()
if type(serializable_config["beta_channel"]["sigma"]) == np.ndarray:
serializable_config["beta_channel"]["sigma"] = serializable_config[
"beta_channel"
]["sigma"].tolist()
return serializable_config
return ndarray_to_list(serializable_config)

@classmethod
def load(cls, fname: str):
Expand Down Expand Up @@ -481,6 +486,30 @@
with self.model:
pm.set_data(data)

@classmethod
def _model_config_formatting(cls, model_config: Dict) -> Dict:
"""
Because of json serialization, model_config values that were originally tuples or numpy are being encoded as lists.
This function converts them back to tuples and numpy arrays to ensure correct id encoding.
"""
for key in model_config:
if isinstance(model_config[key], dict):
michaelraczycki marked this conversation as resolved.
Show resolved Hide resolved
for sub_key in model_config[key]:
if isinstance(model_config[key][sub_key], list):
# Check if "dims" key to convert it to tuple
if sub_key == "dims":
model_config[key][sub_key] = tuple(
model_config[key][sub_key]
)
# Convert all other lists to numpy arrays
else:
model_config[key][sub_key] = np.array(
model_config[key][sub_key]
)
elif isinstance(model_config[key], list):
model_config[key] = np.array(model_config[key])

Check warning on line 510 in pymc_marketing/mmm/delayed_saturated_mmm.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/delayed_saturated_mmm.py#L509-L510

Added lines #L509 - L510 were not covered by tests
return model_config


class DelayedSaturatedMMM(
MaxAbsScaleTarget,
Expand Down
67 changes: 67 additions & 0 deletions tests/mmm/test_delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,33 @@ def toy_X() -> pd.DataFrame:
)


@pytest.fixture(scope="class")
def model_config_requiring_serialization() -> dict:
model_config = {
"intercept": {"mu": 0, "sigma": 2},
"beta_channel": {
"sigma": np.array([0.4533017, 0.25488063]),
"dims": ("channel",),
},
"alpha": {
"alpha": np.array([3, 3]),
"beta": np.array([3.55001301, 2.87092431]),
"dims": ("channel",),
},
"lam": {
"alpha": np.array([3, 3]),
"beta": np.array([4.12231653, 5.02896872]),
"dims": ("channel",),
},
"sigma": {"sigma": 2},
"gamma_control": {"mu": 0, "sigma": 2, "dims": ("control",)},
"mu": {"dims": ("date",)},
"likelihood": {"dims": ("date",)},
"gamma_fourier": {"mu": 0, "b": 1, "dims": "fourier_mode"},
}
return model_config


@pytest.fixture(scope="class")
def toy_y(toy_X: pd.DataFrame) -> pd.Series:
return pd.Series(data=rng.integers(low=0, high=100, size=toy_X.shape[0]))
Expand All @@ -62,6 +89,46 @@ def mmm_fitted(


class TestDelayedSaturatedMMM:
def test_save_load_with_not_serializable_model_config(
self, model_config_requiring_serialization, toy_X, toy_y
):
def deep_equal(dict1, dict2):
for key, value in dict1.items():
if key not in dict2:
return False
if isinstance(value, dict):
if not deep_equal(value, dict2[key]):
return False
elif isinstance(value, np.ndarray):
if not np.array_equal(value, dict2[key]):
return False
else:
if value != dict2[key]:
return False
return True

model = DelayedSaturatedMMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
adstock_max_lag=4,
model_config=model_config_requiring_serialization,
)
model.fit(
toy_X, toy_y, target_accept=0.81, draws=100, chains=2, random_seed=rng
)
model.save("test_save_load")
model2 = DelayedSaturatedMMM.load("test_save_load")
assert model.date_column == model2.date_column
assert model.control_columns == model2.control_columns
assert model.channel_columns == model2.channel_columns
assert model.adstock_max_lag == model2.adstock_max_lag
assert model.validate_data == model2.validate_data
assert model.yearly_seasonality == model2.yearly_seasonality
assert deep_equal(model.model_config, model2.model_config)

assert model.sampler_config == model2.sampler_config
os.remove("test_save_load")

@pytest.mark.parametrize(
argnames="adstock_max_lag",
argvalues=[1, 4],
Expand Down