MMM.load issue matching ids with tvp #814

Closed
AlfredoJF opened this issue Jul 8, 2024 · 4 comments · Fixed by #815
Labels: bug, MMM

Comments

@AlfredoJF

I just ran into this issue when loading a trained model with a time-varying setting set to True (either the intercept or the media parameters). It looks like model.id and idata.attrs["id"], which the load method compares on line 624, differ after the model is built on line 622.

I do not know exactly what is going on under the hood, but issue #757 also occurs when loading a model with time-varying settings.

As a test, I commented out the lines below from the load method to work around the loading error, but when I ran the allocate_budget_to_maximize_response method I still got the same error as in issue #757.

        # All previously used data is in idata.
        if model.id != idata.attrs["id"]:
            error_msg = f"""The file '{fname}' does not contain an inference data of the same model
        or configuration as '{cls._model_type}'"""
            raise ValueError(error_msg)
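
For readers following along, here is a rough sketch of why a check like this can fail. It is a toy model, not the library's code: it assumes the id is some hash of the model configuration, and that a setting missing from the saved attributes falls back to a default on load. The `config_id` function and both dictionaries are hypothetical and only illustrate the idea.

```python
import hashlib
import json


def config_id(config: dict) -> str:
    """Hypothetical stand-in for a model id: a short hash of the configuration."""
    payload = json.dumps(config, sort_keys=True).encode()
    return hashlib.sha256(payload).hexdigest()[:16]


# Configuration of the model that was actually fit.
fit_config = {"adstock_max_lag": 4, "time_varying_intercept": True}
saved_id = config_id(fit_config)

# Attributes written at save time, with the time-varying flag left out.
saved_attrs = {"adstock_max_lag": json.dumps(4), "id": saved_id}

# On load, the missing flag falls back to its default (False).
rebuilt_config = {
    "adstock_max_lag": json.loads(saved_attrs["adstock_max_lag"]),
    "time_varying_intercept": json.loads(
        saved_attrs.get("time_varying_intercept", "false")
    ),
}

# The rebuilt configuration differs from the fitted one, so the ids differ.
ids_match = config_id(rebuilt_config) == saved_attrs["id"]
print(ids_match)  # False
```

Under these assumptions, commenting out the check only hides the mismatch: the rebuilt model still has a different configuration from the one that was fit.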

Do you know if there is a workaround or another way to load a trained model to avoid these 2 issues?

Thanks!

@wd60622 added the bug and MMM labels on Jul 8, 2024
@wd60622
Contributor

wd60622 commented Jul 8, 2024

Hi @AlfredoJF

This seems like a bug. Thank you for providing the information. I have a bug fix and will post a workaround shortly

@wd60622
Contributor

wd60622 commented Jul 8, 2024

I have two solutions for you @AlfredoJF

  1. If you are fitting new models, then this class can be used
import json
from pathlib import Path

import arviz as az
from pymc_marketing.mmm import MMM


class FixedMMM(MMM):
    def _save_input_params(self, idata) -> None:
        """Saves input parameters to the attrs of idata."""
        idata.attrs["date_column"] = json.dumps(self.date_column)
        idata.attrs["adstock"] = json.dumps(self.adstock.lookup_name)
        idata.attrs["saturation"] = json.dumps(self.saturation.lookup_name)
        idata.attrs["adstock_first"] = json.dumps(self.adstock_first)
        idata.attrs["control_columns"] = json.dumps(self.control_columns)
        idata.attrs["channel_columns"] = json.dumps(self.channel_columns)
        idata.attrs["adstock_max_lag"] = json.dumps(self.adstock_max_lag)
        idata.attrs["validate_data"] = json.dumps(self.validate_data)
        idata.attrs["yearly_seasonality"] = json.dumps(self.yearly_seasonality)
        # These were missing
        idata.attrs["time_varying_intercept"] = json.dumps(self.time_varying_intercept)
        idata.attrs["time_varying_media"] = json.dumps(self.time_varying_media)

    @classmethod
    def load(cls, fname: str, **kwargs):
        filepath = Path(fname)
        idata = az.from_netcdf(filepath)
        model_config = cls._model_config_formatting(
            json.loads(idata.attrs["model_config"])
        )
        model = cls(
            date_column=json.loads(idata.attrs["date_column"]),
            control_columns=json.loads(idata.attrs["control_columns"]),
            # Media Transformations
            channel_columns=json.loads(idata.attrs["channel_columns"]),
            adstock_max_lag=json.loads(idata.attrs["adstock_max_lag"]),
            # Fallbacks must be valid JSON text because they are passed to json.loads
            adstock=json.loads(idata.attrs.get("adstock", '"geometric"')),
            saturation=json.loads(idata.attrs.get("saturation", '"logistic"')),
            adstock_first=json.loads(idata.attrs.get("adstock_first", "true")),
            # Seasonality
            yearly_seasonality=json.loads(idata.attrs["yearly_seasonality"]),
            # TVP
            # Fallback "false" is JSON text, so json.loads returns False for older files
            time_varying_intercept=json.loads(
                idata.attrs.get("time_varying_intercept", "false")
            ),
            time_varying_media=json.loads(idata.attrs.get("time_varying_media", "false")),
            # Configurations
            validate_data=json.loads(idata.attrs["validate_data"]),
            model_config=model_config,
            sampler_config=json.loads(idata.attrs["sampler_config"]),
            **kwargs,
        )
        model.idata = idata
        dataset = idata.fit_data.to_dataframe()
        X = dataset.drop(columns=[model.output_var])
        y = dataset[model.output_var].values
        model.build_model(X, y)
        # All previously used data is in idata.
        if model.id != idata.attrs["id"]:
            error_msg = (
                f"The file '{fname}' does not contain "
                "an inference data of the same model or "
                f"configuration as '{cls._model_type}'"
            )
            raise ValueError(error_msg)

        return model

# Usage
mmm = FixedMMM(...)
mmm.fit(X, y)
mmm.save("saved-model")
loaded_mmm = FixedMMM.load("saved-model")
  2. If you already have models saved with the MMM class, you can use this function
import json
from pathlib import Path

import arviz as az
from pymc_marketing.mmm import MMM


def load_already_fit(fname: str, cls=MMM, **missing_init_kwargs):
    filepath = Path(fname)
    idata = az.from_netcdf(filepath)
    model_config = cls._model_config_formatting(json.loads(idata.attrs["model_config"]))
    model = cls(
        date_column=json.loads(idata.attrs["date_column"]),
        control_columns=json.loads(idata.attrs["control_columns"]),
        # Media Transformations
        channel_columns=json.loads(idata.attrs["channel_columns"]),
        adstock_max_lag=json.loads(idata.attrs["adstock_max_lag"]),
        # Fallbacks must be valid JSON text because they are passed to json.loads
        adstock=json.loads(idata.attrs.get("adstock", '"geometric"')),
        saturation=json.loads(idata.attrs.get("saturation", '"logistic"')),
        adstock_first=json.loads(idata.attrs.get("adstock_first", "true")),
        # Seasonality
        yearly_seasonality=json.loads(idata.attrs["yearly_seasonality"]),
        # Configurations
        validate_data=json.loads(idata.attrs["validate_data"]),
        model_config=model_config,
        sampler_config=json.loads(idata.attrs["sampler_config"]),
        **missing_init_kwargs,
    )
    model.idata = idata
    dataset = idata.fit_data.to_dataframe()
    X = dataset.drop(columns=[model.output_var])
    y = dataset[model.output_var].values
    model.build_model(X, y)
    # All previously used data is in idata.
    if model.id != idata.attrs["id"]:
        error_msg = (
            f"The file '{fname}' does not contain "
            "an inference data of the same model or "
            f"configuration as '{cls._model_type}'"
        )
        raise ValueError(error_msg)

    return model

# Usage
mmm = MMM(..., time_varying_intercept=True, time_varying_media=False)
mmm.fit(X, y)
mmm.save("saved-model")
loaded_mmm = load_already_fit("saved-model", time_varying_intercept=True, time_varying_media=False)

Give these a try and if you have any issues with them, let me know
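
A side note on the decoding pattern used in both snippets: the attributes are written with json.dumps and read back with json.loads, and json.loads only accepts JSON text. Any fallback handed to `.get` therefore has to be valid JSON as well, or the missing key has to be handled before decoding. Below is a minimal sketch of the second approach; `read_attr` and `old_attrs` are hypothetical names used only for illustration and are not part of pymc-marketing.

```python
import json

# Attributes as a file saved without the time-varying keys might look.
old_attrs = {
    "adstock": json.dumps("geometric"),
    "adstock_first": json.dumps(True),
}


def read_attr(attrs: dict, key: str, default):
    """Hypothetical helper: decode a JSON-encoded attribute, or return the default as-is."""
    raw = attrs.get(key)
    return default if raw is None else json.loads(raw)


adstock = read_attr(old_attrs, "adstock", "geometric")
tvp_intercept = read_attr(old_attrs, "time_varying_intercept", False)

# A bare Python string used as the fallback does not survive json.loads.
try:
    json.loads(old_attrs.get("saturation", "logistic"))
except json.JSONDecodeError:
    fallback_failed = True
else:
    fallback_failed = False
```

Here `adstock` decodes to "geometric", `tvp_intercept` falls back to False because the key is absent, and the bare "logistic" fallback raises json.JSONDecodeError because it is not quoted JSON.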

@AlfredoJF
Author

Amazing @wd60622! I'll give it a try and let you know how it goes.

@AlfredoJF
Author

Thanks @wd60622! I can confirm the first option fixed this issue and #757 too.
