Add HSGP prior #415

Closed: wants to merge 80 commits

Changes from all commits (80 commits)
475d1da
Option to set the prior type for variables
nialloulton Oct 24, 2023
751aaf0
A fix to the prior type option
nialloulton Oct 24, 2023
f4b528f
fix to prior type 2
nialloulton Oct 24, 2023
715eaef
fix 3 prior type
nialloulton Oct 24, 2023
81320a9
flexible likelihood
nialloulton Oct 24, 2023
27f292a
fixing likelihood
nialloulton Oct 24, 2023
e38155f
fixing likelihood attempt 2
nialloulton Oct 24, 2023
1528150
Time Varying Prior
nialloulton Oct 25, 2023
2309c27
TVP Fix
nialloulton Oct 25, 2023
91e07be
Time Varying Prior fix 2
nialloulton Oct 25, 2023
f7abae8
TVP Fix 3
nialloulton Oct 25, 2023
f92bc54
TVP Fix 4
nialloulton Oct 25, 2023
dc11f91
TVP Fix 5
nialloulton Oct 25, 2023
b3eabb8
TVP Fix 6
nialloulton Oct 25, 2023
47d6b25
TVP Fix 7
nialloulton Oct 25, 2023
f5d5014
TVP Fix Channel
nialloulton Oct 25, 2023
73c47b4
TVP Fix Channel 2
nialloulton Oct 25, 2023
b032e63
tvp fix
nialloulton Oct 25, 2023
ec9798b
tvp control checks
nialloulton Oct 25, 2023
35abad0
control checks
nialloulton Oct 25, 2023
c90b0da
control fix 2
nialloulton Oct 25, 2023
eb05e91
control fix 3
nialloulton Oct 25, 2023
73d930d
control fix 4
nialloulton Oct 25, 2023
b90ce92
control fix 5
nialloulton Oct 25, 2023
f3dcdfd
fix 6
nialloulton Oct 25, 2023
6540de9
fix 7
nialloulton Oct 25, 2023
ab7b086
TVP option for all variables
nialloulton Oct 26, 2023
7875206
tvp fix 8
nialloulton Oct 26, 2023
de45c1c
TVP FIX 9
nialloulton Oct 26, 2023
e61c168
TVP_fix 10
nialloulton Oct 26, 2023
8b484e2
tvp fix 11
nialloulton Oct 26, 2023
cf95e19
tvp fix 12
nialloulton Oct 26, 2023
93b1535
tvp fix 13
nialloulton Oct 26, 2023
0c4e724
tvp fix 14
nialloulton Oct 26, 2023
c53d0a4
tvp fix 15
nialloulton Oct 26, 2023
68052de
tvp positive constraint
nialloulton Oct 26, 2023
d043841
tvp positive constraint fix
nialloulton Oct 26, 2023
7ae82cd
tvp intercept fix
nialloulton Oct 26, 2023
b9e20b4
tvp fix 2
nialloulton Oct 27, 2023
4d73e65
tvp fix 3
nialloulton Oct 27, 2023
6a19c1a
tvp fix 4
nialloulton Oct 27, 2023
b053ecd
tvp fix 5
nialloulton Oct 27, 2023
5e48e32
tvp fix 6
nialloulton Oct 27, 2023
7063fd1
tvp fix 7
nialloulton Oct 27, 2023
f0128b2
tvp fix 8
nialloulton Oct 28, 2023
067cf89
tvp fix 9
nialloulton Oct 28, 2023
466a506
tvp fix 10
nialloulton Oct 28, 2023
e98d9ed
tvp fix 11
nialloulton Oct 28, 2023
d309f40
tvp fix 12
nialloulton Oct 28, 2023
d8e959f
tvp fix 14
nialloulton Oct 28, 2023
1cd321c
tvp fix 15
nialloulton Oct 28, 2023
08a766e
tvp fix 16
nialloulton Oct 28, 2023
2f5a27e
fix tvp 17
nialloulton Oct 28, 2023
b223fee
fix tvp 18
nialloulton Oct 28, 2023
5541ba0
fix tvp 19
nialloulton Oct 28, 2023
e0354d9
fix tvp 20
nialloulton Oct 28, 2023
c599254
resetting tvp
nialloulton Oct 28, 2023
c128aad
tvp fix 21
nialloulton Oct 28, 2023
87a8a8d
tvp fix 22
nialloulton Oct 28, 2023
fc808b8
tvp fix 23
nialloulton Oct 28, 2023
a9c84b6
tvp fix 24
nialloulton Oct 28, 2023
a18a8d1
tvp fix 25
nialloulton Oct 28, 2023
7098c10
loglikelihood
nialloulton Oct 28, 2023
579d7c1
loglikelihood2
nialloulton Oct 28, 2023
fce7457
tvp int 3
nialloulton Oct 28, 2023
c2dac26
tvp int 4
nialloulton Oct 28, 2023
6795916
tvp int 4
nialloulton Oct 28, 2023
7dc6971
tvp int 5
nialloulton Oct 28, 2023
b12f5d3
tvp offset addition
nialloulton Oct 29, 2023
6115c64
tvp offset addition fix
nialloulton Oct 29, 2023
29273c2
Update delayed_saturated_mmm.py
nialloulton Nov 13, 2023
3d00e48
Update delayed_saturated_mmm.py
nialloulton Nov 24, 2023
14370e6
Update delayed_saturated_mmm.py
nialloulton Nov 24, 2023
f1fc9f1
Update delayed_saturated_mmm.py
nialloulton Nov 24, 2023
63cb4fe
Update delayed_saturated_mmm.py
nialloulton Nov 24, 2023
516d697
Update delayed_saturated_mmm.py
nialloulton Nov 24, 2023
57c8839
Update delayed_saturated_mmm.py
nialloulton Nov 24, 2023
d7c79f4
Update delayed_saturated_mmm.py
nialloulton Nov 25, 2023
04fb0f4
Update delayed_saturated_mmm.py
nialloulton Nov 25, 2023
7c1240a
Update delayed_saturated_mmm.py
nialloulton Nov 25, 2023
181 changes: 132 additions & 49 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
@@ -8,6 +8,7 @@
import numpy.typing as npt
import pandas as pd
import pymc as pm
import pytensor.tensor as pt
import seaborn as sns
from xarray import DataArray

@@ -187,32 +188,15 @@ def build_model(
dims="date",
)

intercept = pm.Normal(
name="intercept",
mu=model_config["intercept"]["mu"],
sigma=model_config["intercept"]["sigma"],
)

beta_channel = pm.HalfNormal(
name="beta_channel",
sigma=model_config["beta_channel"]["sigma"],
dims=model_config["beta_channel"]["dims"],
)
alpha = pm.Beta(
name="alpha",
alpha=model_config["alpha"]["alpha"],
beta=model_config["alpha"]["beta"],
dims=model_config["alpha"]["dims"],
)
# Building the priors
priors = self.create_priors_from_config(self.model_config)

lam = pm.Gamma(
name="lam",
alpha=model_config["lam"]["alpha"],
beta=model_config["lam"]["beta"],
dims=model_config["lam"]["dims"],
)

sigma = pm.HalfNormal(name="sigma", sigma=model_config["sigma"]["sigma"])
# Specifying the variables
intercept = priors['intercept']
beta_channel = priors['beta_channel']
alpha = priors['alpha']
lam = priors['lam']
gamma_control = priors['gamma_control']

channel_adstock = pm.Deterministic(
name="channel_adstock",
@@ -230,6 +214,7 @@
var=logistic_saturation(x=channel_adstock, lam=lam),
dims=("date", "channel"),
)

channel_contributions = pm.Deterministic(
name="channel_contributions",
var=channel_adstock_saturated * beta_channel,
@@ -251,12 +236,8 @@
dims=("date", "control"),
)

gamma_control = pm.Normal(
name="gamma_control",
mu=model_config["gamma_control"]["mu"],
sigma=model_config["gamma_control"]["sigma"],
dims=model_config["gamma_control"]["dims"],
)
print("Shape of control_data_:", control_data_.eval().shape)
print("Shape of gamma_control:", gamma_control.eval().shape)

control_contributions = pm.Deterministic(
name="control_contributions",
@@ -299,33 +280,135 @@ def build_model(
name="mu", var=mu_var, dims=model_config["mu"]["dims"]
)

pm.Normal(
name="likelihood",
mu=mu,
sigma=sigma,
observed=target_,
dims=model_config["likelihood"]["dims"],
)
likelihood = self.create_likelihood(self.model_config, target_, mu)

def create_priors_from_config(self, model_config):
priors, dimensions = {}, {"channel": len(self.channel_columns), "control": len(self.control_columns)}
stacked_priors = {}

positive_params = {"intercept", "beta_channel", "alpha", "lam", "sigma"} # Set of params that need positive=True

for param, config in model_config.items():
if param == "likelihood": continue

prior_type = config.get("type")
if prior_type is not None:

# Initial value based on parameter name
is_positive = param in positive_params

# Override if the config explicitly sets the 'positive' key
if 'positive' in config:
is_positive = config.get('positive')
Review comment on lines +297 to +302: This can be simplified to:

    is_positive = config.get('positive', param in positive_params)


if prior_type == "tvp":
if param in ["intercept", "lam", "alpha", "sigma"]:
priors[param] = self.gp_wrapper(name=param, X=np.arange(len(self.X[self.date_column]))[:, None], config=config, positive=is_positive)
continue

length = dimensions.get(config.get("dims", [None, None])[1], 1)

Review comment: Needlessly terse; this can be simplified to:

    dimensions.get(config.get("dims", None), 1)

Follow-up from the reviewer: My apologies, my suggestion won't work and this is actually decent code.

priors[param] = self.create_tvp_priors(param, config, length, positive=is_positive)
continue

dist_func = getattr(pm, prior_type, None)
if not dist_func: raise ValueError(f"Invalid distribution type {prior_type}")
config_copy = {k: v for k, v in config.items() if k != "type"}
priors[param] = dist_func(name=param, **config_copy)

return priors
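For the non-TVP branch above, the "type" string maps directly onto a PyMC distribution class. A minimal standalone sketch of that dispatch, using a hypothetical config (the coords and values are illustrative, not from this PR):

import pymc as pm

config = {"type": "Gamma", "alpha": 3, "beta": 1, "dims": ("channel",)}

with pm.Model(coords={"channel": ["tv", "radio"]}):
    dist_func = getattr(pm, config["type"])  # resolves to pm.Gamma
    kwargs = {k: v for k, v in config.items() if k != "type"}
    # Equivalent to pm.Gamma("lam", alpha=3, beta=1, dims=("channel",))
    lam = dist_func(name="lam", **kwargs)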

def create_likelihood(self, model_config, target_, mu):
likelihood_config = model_config.get("likelihood", {})
likelihood_type = likelihood_config.get("type")
dims = likelihood_config.get("dims")

if not likelihood_type:
raise ValueError("Likelihood type must be specified in the model config.")

likelihood_func = getattr(pm, likelihood_type, None)
if likelihood_func is None:
raise ValueError(f"Invalid likelihood type {likelihood_type}")

# Transform mu if the likelihood type is LogNormal or HurdleLogNormal
if likelihood_type in ['LogNormal', 'HurdleLogNormal']:
mu = pt.log(mu)

# Create sub-priors
sub_priors = {}
for param, config in likelihood_config.items():
if param not in ['type', 'dims']: # Skip 'type' and 'dims'
if param == 'params': # Handle nested 'params'
for sub_param, sub_config in config.items():
sub_priors[sub_param] = self.create_priors_from_config({sub_param: sub_config})[sub_param]
else:
sub_priors[param] = self.create_priors_from_config({param: config})[param]

return likelihood_func(name="likelihood", mu=mu, observed=target_, dims=dims, **sub_priors)
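To illustrate the likelihood dispatch, here is a hypothetical model_config excerpt (not one of this PR's defaults) that swaps in a Student-t likelihood; the "type" string is resolved via getattr(pm, ...), and each entry under "params" is turned into a sub-prior by create_priors_from_config:

custom_config = {
    "likelihood": {
        "type": "StudentT",  # resolved to pm.StudentT
        "dims": ("date",),
        "params": {
            # Sub-priors passed to the likelihood as sigma=... and nu=...
            "sigma": {"type": "HalfNormal", "sigma": 1},
            "nu": {"type": "Gamma", "alpha": 2, "beta": 0.1},
        },
    },
}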

def create_tvp_priors(self, param, config, length, positive=False):
dims = config.get("dims", None) # Extracting dims from the config
print(dims)
gp_list = [self.gp_wrapper(name=f"{param}_{i}", X=np.arange(len(self.X[self.date_column]))[:, None], config=config, positive=positive) for i in range(length)]
Review comment (Contributor): This can potentially be vectorized with gp.prior_linearized; see the sketch after this function.

stacked_gp = pt.stack(gp_list, axis=1)
return pm.Deterministic(f"{param}", stacked_gp, dims=dims)
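A minimal sketch of the vectorization suggested in the comment above: instead of one gp.prior(...) per column, evaluate the HSGP basis once with prior_linearized and draw a matrix of basis coefficients. This assumes a PyMC version in which HSGP.prior_linearized takes the centered inputs and returns the basis and the square root of the power spectral density; names such as n_dates, n_channels, and hsgp_coefs are illustrative, not from the PR:

import numpy as np
import pymc as pm

n_dates, n_channels, m = 104, 3, 40
X = np.arange(n_dates, dtype=float)[:, None]

with pm.Model(coords={"date": np.arange(n_dates), "channel": np.arange(n_channels)}):
    ell = pm.Gamma("ell", alpha=2, beta=1)
    eta = pm.Exponential("eta", lam=1)
    cov = eta**2 * pm.gp.cov.Matern32(1, ls=ell)

    gp = pm.gp.HSGP(m=[m], c=4, cov_func=cov)
    # One shared basis evaluation; newer PyMC versions center X internally.
    phi, sqrt_psd = gp.prior_linearized(X - X.mean())
    # Independent basis coefficients per channel; first dim must equal m.
    hsgp_coefs = pm.Normal("hsgp_coefs", size=(m, n_channels))
    # (n_dates, m) @ (m, n_channels) -> one smooth curve per channel
    f = pm.Deterministic("f", phi @ (hsgp_coefs * sqrt_psd[:, None]), dims=("date", "channel"))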


def gp_wrapper(self, name, X, config, positive=False, **kwargs):
return self.gp_coeff(X, name, config=config, positive=positive, **kwargs)

def gp_coeff(self, X, name, mean=0.0, positive=False, config=None):
params = pm.find_constrained_prior(pm.Gamma, 8, 12, init_guess={"alpha": 1, "beta": 1}, mass=0.8)
ell = pm.Gamma(f"ell_{name}", **params)
eta = pm.Exponential(f"_eta_{name}", lam=1)
# cov = eta ** 2 * pm.gp.cov.ExpQuad(1, ls=ell)

cov = eta ** 2 * pm.gp.cov.Matern32(1, ls=ell)
Review comment (@ulfaslakprecis, Jan 16, 2024): I don't get this. eta ** 2 is going to blow up for negative x. Plotting eta ** 2 * pm.gp.cov.Matern32(1, ls=ell) for x' = 0 and lam = 1: [screenshot]. Is the point to only let the coef value at t = 0 be influenced by coef values at t < 0?

Reply (Contributor Author): Hey, apologies, I've been a bit slow on this one. There have been quite a few changes to main with the custom priors/likelihood PR, and shortly more with the out-of-sample prediction. @bwengals is going to be helping out implementing this one too.

Reply (@ulfaslakprecis): No worries. For context, I'm here because I'm adding TVC to my company's MMM, and thought it would make sense to drop review comments for clarification's sake (and to give back: thanks for the beautiful lib). But if any of these comments are out of place, please just ignore them.


gp = pm.gp.HSGP(m=[40], c=4, cov_func=cov)
f_raw = gp.prior(f"{name}_tvp_raw", X=X)


# Offset prior (optional, from the config)
offset_config = config.get('offset', None) if config else None
if offset_config:
offset_type = offset_config.get('type')
offset_params = {k: v for k, v in offset_config.items() if k != 'type'}
offset_prior = getattr(pm, offset_type)(name=f"{name}_offset", **offset_params)
else:
offset_prior = 0

if positive:
f_output = pm.Deterministic(f"{name}", (pt.exp(f_raw)) + offset_prior, dims=("date"))
else:
f_output = pm.Deterministic(f"{name}", f_raw + offset_prior, dims=("date"))

return f_output
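On the thread above: as far as I can tell, eta is the GP amplitude hyperparameter rather than the kernel input, so eta ** 2 scales the marginal variance of the coefficient process rather than blowing up at any x. A standalone sketch with fixed, illustrative values:

import numpy as np
import pymc as pm

ell, eta = 10.0, 2.0
cov = eta**2 * pm.gp.cov.Matern32(1, ls=ell)

X = np.linspace(0, 52, 5)[:, None]
K = cov(X).eval()
print(np.round(np.diag(K), 3))  # every diagonal entry is eta**2 = 4.0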



@property
def default_model_config(self) -> Dict:
model_config: Dict = {
"intercept": {"mu": 0, "sigma": 2},
"beta_channel": {"sigma": 2, "dims": ("channel",)},
"alpha": {"alpha": 1, "beta": 3, "dims": ("channel",)},
"lam": {"alpha": 3, "beta": 1, "dims": ("channel",)},
"sigma": {"sigma": 2},
"gamma_control": {
"mu": 0,
"sigma": 2,
"dims": ("control",),
},
"intercept": {"type": "Normal", "mu": 0, "sigma": 2},
"beta_channel": {"type": "HalfNormal", "sigma": 2, "dims": ("channel",)},
"alpha": {"type": "Beta", "alpha": 1, "beta": 3, "dims": ("channel",)},
"lam": {"type": "Gamma", "alpha": 3, "beta": 1, "dims": ("channel",)},
"gamma_control": {'type': 'Gamma', 'alpha': 2, 'beta': 1, 'dims': ('control',)},
"mu": {"dims": ("date",)},
"likelihood": {"dims": ("date",)},
"likelihood": {
"type": "Normal",
"dims": ("date",),
"params": {
"sigma": {"type": "HalfNormal", "sigma": 1, 'dims': ('date',)},
}
},
"gamma_fourier": {"mu": 0, "b": 1, "dims": "fourier_mode"},
}
return model_config
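As a usage sketch of this schema (hypothetical values, following the keys create_priors_from_config understands), a channel coefficient could be made time-varying by setting its "type" to "tvp", optionally with a positivity constraint and an offset prior:

custom_model_config = {
    "beta_channel": {
        "type": "tvp",  # routed through create_tvp_priors / gp_coeff
        "dims": ("date", "channel"),
        "positive": True,  # exponentiates the raw HSGP draw
        "offset": {"type": "HalfNormal", "sigma": 1},  # added to the draw
    },
    # ... remaining entries as in default_model_config ...
}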



def _get_fourier_models_data(self, X) -> pd.DataFrame:
"""Generates fourier modes to model seasonality.
