Skip to content

Commit

Permalink
move register random vars and priors to ModelBuilder
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkusSagen committed Oct 23, 2023
1 parent 439115c commit 0a9da8c
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 24 deletions.
43 changes: 36 additions & 7 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,28 @@ def __init__(
adstock_max_lag=adstock_max_lag,
)

# Define custom priors
self.intercept = self._create_distribution(self.model_config["intercept"])
self.beta_channel = self._create_distribution(self.model_config["beta_channel"])
self.alpha = self._create_distribution(self.model_config["alpha"])
self.lam = self._create_distribution(self.model_config["lam"])
self.sigma = self._create_distribution(self.model_config["sigma"])
self.gamma_control = self._create_distribution(
self.model_config["gamma_control"]
)
self.gamma_fourier = self._create_distribution(
self.model_config["gamma_fourier"]
)
self._process_priors(
self.intercept,
self.beta_channel,
self.alpha,
self.lam,
self.sigma,
self.gamma_control,
self.gamma_fourier,
)

@property
def default_sampler_config(self) -> Dict:
return {}
Expand Down Expand Up @@ -231,11 +253,14 @@ def build_model(
dims="date",
)

intercept = self.register_rv(name="intercept")
beta_channel = self.register_rv(name="beta_channel")
alpha = self.register_rv(name="alpha")
lam = self.register_rv(name="lam")
sigma = self.register_rv(name="sigma")
# FIXME: Need to add the correct dims to `beta_channel`, `alpha`, `lam`,
intercept = self.model.register_rv(self.intercept, name="intercept")
beta_channel = self.model.register_rv(
self.beta_channel, name="beta_channel"
)
alpha = self.model.register_rv(self.alpha, name="alpha")
lam = self.model.register_rv(self.lam, name="lam")
sigma = self.model.register_rv(self.sigma, name="sigma")

# TODO: register the adstock transforms
channel_adstock = pm.Deterministic(
Expand Down Expand Up @@ -269,7 +294,9 @@ def build_model(
for column in self.control_columns
)
):
gamma_control = self.register_rv(name="gamma_control")
gamma_control = self.model.register_rv(
self.gamma_control, name="gamma_control"
)

control_data_ = pm.MutableData(
name="control_data",
Expand All @@ -293,7 +320,9 @@ def build_model(
for column in self.fourier_columns
)
):
gamma_fourier = self.register_rv(name="gamma_fourier")
gamma_fourier = self.model.register_rv(
self.gamma_fourier, name="gamma_fourier"
)

fourier_data_ = pm.MutableData(
name="fourier_data",
Expand Down
17 changes: 0 additions & 17 deletions pymc_marketing/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,23 +447,6 @@ def _process_priors(
prior.str_repr = types.MethodType(str_for_dist, prior) # type: ignore
return priors

def register_rv(
self,
name: str,
observed: Any | None = None,
total_size: Any | None = None,
dims: Any | None = None,
transform: None = None,
):
"""Register random variables and priors from model_config to the pm.model object."""
rv_var = self._create_distribution(self.model_config[name])
self._process_priors(rv_var)
setattr(self, name, rv_var)

dims = self.model_config[name].get("dims")
rv = self.model.register_rv(rv_var, name=name, dims=dims)
return rv

def fit(
self,
X: pd.DataFrame,
Expand Down

0 comments on commit 0a9da8c

Please sign in to comment.