From 0a9da8c33f4b942d25d01b7be84766d7d5394163 Mon Sep 17 00:00:00 2001 From: Markus Sagen Date: Mon, 23 Oct 2023 09:56:16 +0200 Subject: [PATCH] move register random vars and priors to ModelBuilder --- pymc_marketing/mmm/delayed_saturated_mmm.py | 43 +++++++++++++++++---- pymc_marketing/model_builder.py | 17 -------- 2 files changed, 36 insertions(+), 24 deletions(-) diff --git a/pymc_marketing/mmm/delayed_saturated_mmm.py b/pymc_marketing/mmm/delayed_saturated_mmm.py index 2af3ebae..c9748efc 100644 --- a/pymc_marketing/mmm/delayed_saturated_mmm.py +++ b/pymc_marketing/mmm/delayed_saturated_mmm.py @@ -118,6 +118,28 @@ def __init__( adstock_max_lag=adstock_max_lag, ) + # Define custom priors + self.intercept = self._create_distribution(self.model_config["intercept"]) + self.beta_channel = self._create_distribution(self.model_config["beta_channel"]) + self.alpha = self._create_distribution(self.model_config["alpha"]) + self.lam = self._create_distribution(self.model_config["lam"]) + self.sigma = self._create_distribution(self.model_config["sigma"]) + self.gamma_control = self._create_distribution( + self.model_config["gamma_control"] + ) + self.gamma_fourier = self._create_distribution( + self.model_config["gamma_fourier"] + ) + self._process_priors( + self.intercept, + self.beta_channel, + self.alpha, + self.lam, + self.sigma, + self.gamma_control, + self.gamma_fourier, + ) + @property def default_sampler_config(self) -> Dict: return {} @@ -231,11 +253,14 @@ def build_model( dims="date", ) - intercept = self.register_rv(name="intercept") - beta_channel = self.register_rv(name="beta_channel") - alpha = self.register_rv(name="alpha") - lam = self.register_rv(name="lam") - sigma = self.register_rv(name="sigma") + # FIXME: Need to add the correct dims to `beta_channel`, `alpha`, `lam`, + intercept = self.model.register_rv(self.intercept, name="intercept") + beta_channel = self.model.register_rv( + self.beta_channel, name="beta_channel" + ) + alpha = self.model.register_rv(self.alpha, name="alpha") + lam = self.model.register_rv(self.lam, name="lam") + sigma = self.model.register_rv(self.sigma, name="sigma") # TODO: register the adstock transforms channel_adstock = pm.Deterministic( @@ -269,7 +294,9 @@ def build_model( for column in self.control_columns ) ): - gamma_control = self.register_rv(name="gamma_control") + gamma_control = self.model.register_rv( + self.gamma_control, name="gamma_control" + ) control_data_ = pm.MutableData( name="control_data", @@ -293,7 +320,9 @@ def build_model( for column in self.fourier_columns ) ): - gamma_fourier = self.register_rv(name="gamma_fourier") + gamma_fourier = self.model.register_rv( + self.gamma_fourier, name="gamma_fourier" + ) fourier_data_ = pm.MutableData( name="fourier_data", diff --git a/pymc_marketing/model_builder.py b/pymc_marketing/model_builder.py index e9b77c04..1a034611 100644 --- a/pymc_marketing/model_builder.py +++ b/pymc_marketing/model_builder.py @@ -447,23 +447,6 @@ def _process_priors( prior.str_repr = types.MethodType(str_for_dist, prior) # type: ignore return priors - def register_rv( - self, - name: str, - observed: Any | None = None, - total_size: Any | None = None, - dims: Any | None = None, - transform: None = None, - ): - """Register random variables and priors from model_config to the pm.model object.""" - rv_var = self._create_distribution(self.model_config[name]) - self._process_priors(rv_var) - setattr(self, name, rv_var) - - dims = self.model_config[name].get("dims") - rv = self.model.register_rv(rv_var, name=name, dims=dims) - return rv - def fit( self, X: pd.DataFrame,