Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support numpy arrays in model_config #351

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,12 +362,17 @@ def channel_contributions_forward_pass(

@property
def _serializable_model_config(self) -> Dict[str, Any]:
    """Return a JSON-serializable copy of ``self.model_config``.

    Every ``numpy.ndarray`` found at any nesting depth is replaced by the
    plain list returned by ``ndarray.tolist()``, and every numpy scalar by
    the Python scalar returned by ``.item()``. All other values are passed
    through unchanged.

    Dictionaries are rebuilt level by level instead of being edited in
    place, so neither ``self.model_config`` nor any of its nested
    dictionaries is modified by reading this property.

    Returns
    -------
    Dict[str, Any]
        A new dictionary with the same keys as ``self.model_config``.
    """

    def ndarray_to_list(d: Dict) -> Dict:
        """Rebuild ``d`` recursively with numpy values made JSON-friendly."""
        converted: Dict[str, Any] = {}
        for key, value in d.items():
            if isinstance(value, np.ndarray):
                converted[key] = value.tolist()
            elif isinstance(value, np.generic):
                # numpy scalars such as np.int64 are rejected by json.dumps.
                converted[key] = value.item()
            elif isinstance(value, dict):
                converted[key] = ndarray_to_list(value)
            else:
                converted[key] = value
        return converted

    return ndarray_to_list(self.model_config)

@classmethod
def load(cls, fname: str):
Expand Down Expand Up @@ -481,6 +486,28 @@ def _data_setter(
with self.model:
pm.set_data(data)

@classmethod
def _model_config_formatting(cls, model_config: Dict) -> Dict:
    """Restore tuples and numpy arrays in a JSON-decoded ``model_config``.

    JSON serialization encodes both tuples and numpy arrays as lists. This
    method converts them back: a list stored under the key ``"dims"``
    becomes a tuple, and every other list becomes a ``numpy.ndarray``.
    Nested dictionaries are processed recursively; all other values are
    passed through unchanged. The original note on this method states that
    the conversion is needed to ensure correct id encoding.

    The input is never modified: a new dictionary is built at every
    nesting level, so the caller's nested dictionaries keep their lists.

    Parameters
    ----------
    model_config : Dict
        Configuration as decoded from JSON.

    Returns
    -------
    Dict
        A new dictionary with the same keys and the converted values.
    """

    def format_nested_dict(d: Dict) -> Dict:
        """Rebuild ``d`` recursively, converting list values on the way."""
        formatted = {}
        for key, value in d.items():
            if isinstance(value, dict):
                formatted[key] = format_nested_dict(value)
            elif isinstance(value, list):
                # "dims" must be a tuple; every other list was an array.
                formatted[key] = tuple(value) if key == "dims" else np.array(value)
            else:
                formatted[key] = value
        return formatted

    return format_nested_dict(model_config)


class DelayedSaturatedMMM(
MaxAbsScaleTarget,
Expand Down
67 changes: 67 additions & 0 deletions tests/mmm/test_delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,33 @@ def toy_X() -> pd.DataFrame:
)


@pytest.fixture(scope="class")
def model_config_requiring_serialization() -> dict:
    """Provide a model configuration that holds numpy arrays and tuples.

    The per-channel priors use ``numpy.ndarray`` values and most ``dims``
    entries are tuples, so the configuration cannot be dumped to JSON
    without a conversion step.
    """
    per_channel = ("channel",)
    per_date = ("date",)

    intercept_prior = {"mu": 0, "sigma": 2}
    beta_channel_prior = {
        "sigma": np.array([0.4533017, 0.25488063]),
        "dims": per_channel,
    }
    alpha_prior = {
        "alpha": np.array([3, 3]),
        "beta": np.array([3.55001301, 2.87092431]),
        "dims": per_channel,
    }
    lam_prior = {
        "alpha": np.array([3, 3]),
        "beta": np.array([4.12231653, 5.02896872]),
        "dims": per_channel,
    }
    gamma_control_prior = {"mu": 0, "sigma": 2, "dims": ("control",)}
    gamma_fourier_prior = {"mu": 0, "b": 1, "dims": "fourier_mode"}

    return {
        "intercept": intercept_prior,
        "beta_channel": beta_channel_prior,
        "alpha": alpha_prior,
        "lam": lam_prior,
        "sigma": {"sigma": 2},
        "gamma_control": gamma_control_prior,
        "mu": {"dims": per_date},
        "likelihood": {"dims": per_date},
        "gamma_fourier": gamma_fourier_prior,
    }


@pytest.fixture(scope="class")
def toy_y(toy_X: pd.DataFrame) -> pd.Series:
    """Provide a random integer target with one value per row of ``toy_X``.

    Values are drawn from the module-level ``rng`` in the half-open range
    ``[0, 100)``.
    """
    n_rows = toy_X.shape[0]
    target_values = rng.integers(low=0, high=100, size=n_rows)
    return pd.Series(data=target_values)
Expand All @@ -62,6 +89,46 @@ def mmm_fitted(


class TestDelayedSaturatedMMM:
def test_save_load_with_not_serializable_model_config(
    self, model_config_requiring_serialization, toy_X, toy_y
):
    """Check that a model whose config holds numpy arrays survives save/load.

    The model is fitted, saved, and loaded again; the constructor
    arguments, the sampler configuration, and the (nested) model
    configuration of both instances must match. The saved file is removed
    even when loading or an assertion fails.
    """

    def deep_equal(dict1, dict2):
        """Return True when both nested dicts have equal keys and values.

        Arrays are compared with ``np.array_equal`` because ``!=`` on an
        array yields an array, not a single boolean.
        """
        if not isinstance(dict2, dict) or dict1.keys() != dict2.keys():
            return False
        for key, value in dict1.items():
            other = dict2[key]
            if isinstance(value, dict):
                if not deep_equal(value, other):
                    return False
            elif isinstance(value, np.ndarray) or isinstance(other, np.ndarray):
                if not np.array_equal(value, other):
                    return False
            elif value != other:
                return False
        return True

    file_name = "test_save_load"
    model = DelayedSaturatedMMM(
        date_column="date",
        channel_columns=["channel_1", "channel_2"],
        adstock_max_lag=4,
        model_config=model_config_requiring_serialization,
    )
    model.fit(
        toy_X, toy_y, target_accept=0.81, draws=100, chains=2, random_seed=rng
    )
    try:
        model.save(file_name)
        model2 = DelayedSaturatedMMM.load(file_name)
        assert model.date_column == model2.date_column
        assert model.control_columns == model2.control_columns
        assert model.channel_columns == model2.channel_columns
        assert model.adstock_max_lag == model2.adstock_max_lag
        assert model.validate_data == model2.validate_data
        assert model.yearly_seasonality == model2.yearly_seasonality
        assert deep_equal(model.model_config, model2.model_config)

        assert model.sampler_config == model2.sampler_config
    finally:
        # Always clean up so a failing run does not leave the file behind.
        if os.path.exists(file_name):
            os.remove(file_name)

@pytest.mark.parametrize(
argnames="adstock_max_lag",
argvalues=[1, 4],
Expand Down