-
Notifications
You must be signed in to change notification settings - Fork 192
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
MMM.load
Issue matching model IDs when using time-varying parameters (TVP)
#814
Comments
Hi @AlfredoJF, this seems like a bug. Thank you for providing the information. I have a bug fix and will post a workaround shortly.
13 tasks
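I have two solutions for you, @AlfredoJF: (1) subclass `MMM` so that the time-varying flags are saved and restored, or (2) load an already-saved model with a helper that lets you pass the missing constructor arguments.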
class FixedMMM(MMM):
    """``MMM`` subclass that persists the time-varying (TVP) flags across save/load.

    The base ``_save_input_params`` did not store ``time_varying_intercept`` or
    ``time_varying_media``. ``load`` therefore rebuilt the model without them,
    and the rebuilt ``model.id`` no longer matched ``idata.attrs["id"]``.
    """

    @staticmethod
    def _json_attr_or_default(attrs, key, default):
        """Return ``json.loads(attrs[key])``, or ``default`` unchanged if ``key`` is absent.

        ``default`` is an already-decoded Python value, not a JSON string.
        Passing a non-string such as ``True`` to ``json.loads`` raises
        ``TypeError``. A bare string such as ``"geometric"`` is not valid JSON
        either. So the default must bypass decoding.
        """
        if key not in attrs:
            return default
        return json.loads(attrs[key])

    def _save_input_params(self, idata) -> None:
        """Save the constructor arguments to ``idata.attrs`` as JSON strings.

        Every key written here is read back by ``load``.
        """
        idata.attrs["date_column"] = json.dumps(self.date_column)
        idata.attrs["adstock"] = json.dumps(self.adstock.lookup_name)
        idata.attrs["saturation"] = json.dumps(self.saturation.lookup_name)
        idata.attrs["adstock_first"] = json.dumps(self.adstock_first)
        idata.attrs["control_columns"] = json.dumps(self.control_columns)
        idata.attrs["channel_columns"] = json.dumps(self.channel_columns)
        idata.attrs["adstock_max_lag"] = json.dumps(self.adstock_max_lag)
        idata.attrs["validate_data"] = json.dumps(self.validate_data)
        idata.attrs["yearly_seasonality"] = json.dumps(self.yearly_seasonality)
        # Not saved by the base MMM; required so ``load`` rebuilds the same model.
        idata.attrs["time_varying_intercept"] = json.dumps(self.time_varying_intercept)
        idata.attrs["time_varying_media"] = json.dumps(self.time_varying_media)

    @classmethod
    def load(cls, fname: str, **kwargs):
        """Load a model written by ``save`` and rebuild it on the stored data.

        Parameters
        ----------
        fname : str
            Path to the NetCDF file written by ``save``.
        **kwargs
            Extra keyword arguments forwarded to the constructor.

        Returns
        -------
        FixedMMM
            The reconstructed model, with ``idata`` attached and the model
            built on ``idata.fit_data``.

        Raises
        ------
        ValueError
            If the rebuilt model's ``id`` does not match ``idata.attrs["id"]``.
        """
        filepath = Path(fname)
        idata = az.from_netcdf(filepath)
        attrs = idata.attrs
        # Keys that older files may lack are read with an already-decoded default.
        optional = cls._json_attr_or_default
        model_config = cls._model_config_formatting(json.loads(attrs["model_config"]))
        model = cls(
            date_column=json.loads(attrs["date_column"]),
            control_columns=json.loads(attrs["control_columns"]),
            # Media transformations
            channel_columns=json.loads(attrs["channel_columns"]),
            adstock_max_lag=json.loads(attrs["adstock_max_lag"]),
            adstock=optional(attrs, "adstock", "geometric"),
            saturation=optional(attrs, "saturation", "logistic"),
            adstock_first=optional(attrs, "adstock_first", True),
            # Seasonality
            yearly_seasonality=json.loads(attrs["yearly_seasonality"]),
            # TVP flags: absent from files written by the base MMM.
            time_varying_intercept=optional(attrs, "time_varying_intercept", False),
            time_varying_media=optional(attrs, "time_varying_media", False),
            # Configurations
            validate_data=json.loads(attrs["validate_data"]),
            model_config=model_config,
            sampler_config=json.loads(attrs["sampler_config"]),
            **kwargs,
        )
        model.idata = idata
        # Rebuild the model on the data it was originally fitted with.
        dataset = idata.fit_data.to_dataframe()
        X = dataset.drop(columns=[model.output_var])
        y = dataset[model.output_var].values
        model.build_model(X, y)
        # Reject files whose stored id does not match the rebuilt model.
        if model.id != attrs["id"]:
            error_msg = (
                f"The file '{fname}' does not contain "
                "an inference data of the same model or "
                f"configuration as '{cls._model_type}'"
            )
            raise ValueError(error_msg)
        return model
# Usage: fit and save with the patched class, then reload through it.
# `...` stands for your usual MMM constructor arguments.
model_path = "saved-model"
mmm = FixedMMM(...)
mmm.fit(X, y)
mmm.save(model_path)
loaded_mmm = FixedMMM.load(model_path)
def load_already_fit(fname: str, cls=MMM, **missing_init_kwargs):
    """Load a fitted model whose saved attrs lack some constructor arguments.

    Use this for files written by a model whose ``_save_input_params`` did not
    store every constructor argument, such as the time-varying flags. Pass the
    missing values as keyword arguments.

    Parameters
    ----------
    fname : str
        Path to the NetCDF file written by ``save``.
    cls : type, optional
        Model class to instantiate; defaults to ``MMM``.
    **missing_init_kwargs
        Constructor arguments not stored in the file. They take precedence
        over stored values with the same name.

    Returns
    -------
    object
        An instance of ``cls``, with ``idata`` attached and the model built on
        ``idata.fit_data``.

    Raises
    ------
    ValueError
        If the rebuilt model's ``id`` does not match ``idata.attrs["id"]``.
    """
    idata = az.from_netcdf(Path(fname))
    attrs = idata.attrs

    def optional(key, default):
        """Return the JSON-decoded ``attrs[key]``, or ``default`` if the key is absent."""
        # ``default`` is already a Python value: json.loads(True) raises TypeError,
        # and a bare "geometric" is not valid JSON.
        return json.loads(attrs[key]) if key in attrs else default

    init_kwargs = {
        "date_column": json.loads(attrs["date_column"]),
        "control_columns": json.loads(attrs["control_columns"]),
        # Media transformations
        "channel_columns": json.loads(attrs["channel_columns"]),
        "adstock_max_lag": json.loads(attrs["adstock_max_lag"]),
        "adstock": optional("adstock", "geometric"),
        "saturation": optional("saturation", "logistic"),
        "adstock_first": optional("adstock_first", True),
        # Seasonality
        "yearly_seasonality": json.loads(attrs["yearly_seasonality"]),
        # Configurations
        "validate_data": json.loads(attrs["validate_data"]),
        "model_config": cls._model_config_formatting(json.loads(attrs["model_config"])),
        "sampler_config": json.loads(attrs["sampler_config"]),
    }
    # Use the TVP flags when the file stores them, so the rebuilt id matches.
    for key in ("time_varying_intercept", "time_varying_media"):
        if key in attrs:
            init_kwargs[key] = json.loads(attrs[key])
    # Caller-supplied arguments fill gaps and override stored values
    # (passing them alongside explicit keywords would raise TypeError).
    init_kwargs.update(missing_init_kwargs)

    model = cls(**init_kwargs)
    model.idata = idata
    # Rebuild the model on the data it was originally fitted with.
    dataset = idata.fit_data.to_dataframe()
    X = dataset.drop(columns=[model.output_var])
    y = dataset[model.output_var].values
    model.build_model(X, y)
    # Reject files whose stored id does not match the rebuilt model.
    if model.id != attrs["id"]:
        error_msg = (
            f"The file '{fname}' does not contain "
            "an inference data of the same model or "
            f"configuration as '{cls._model_type}'"
        )
        raise ValueError(error_msg)
    return model
# Usage: the base MMM does not save the TVP flags, so keep them in one place
# and pass the same values again when loading with load_already_fit.
tvp_settings = {"time_varying_intercept": True, "time_varying_media": False}
mmm = MMM(..., **tvp_settings)
mmm.fit(X, y)
mmm.save("saved-model")
loaded_mmm = load_already_fit("saved-model", time_varying_intercept=True, time_varying_media=False)

Give these a try, and if you have any issues with them, let me know.
Amazing, @wd60622! I'll give it a try and let you know how it goes.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I just encountered this issue when loading a trained model with time-varying parameters set to True (either the intercept or the media parameters). It looks like `model.id` and `idata.attrs["id"]` in the `load` method (line 624) differ after the model is built on line 622. I don't know exactly what is going on under the hood, but issue #757 also occurs when loading a model with time-varying settings.
As a test, I commented out the lines below in the `load` method to work around the loading error. However, when I ran the `allocate_budget_to_maximize_response` method, I still got the same error as in issue #757.
Do you know if there is a workaround, or another way to load a trained model, that avoids these two issues?
Thanks!
The text was updated successfully, but these errors were encountered: