-
Notifications
You must be signed in to change notification settings - Fork 192
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add HSGP prior #415
Add HSGP prior #415
Changes from all commits
475d1da
751aaf0
f4b528f
715eaef
81320a9
27f292a
e38155f
1528150
2309c27
91e07be
f7abae8
f92bc54
dc11f91
b3eabb8
47d6b25
f5d5014
73c47b4
b032e63
ec9798b
35abad0
c90b0da
eb05e91
73d930d
b90ce92
f3dcdfd
6540de9
ab7b086
7875206
de45c1c
e61c168
8b484e2
cf95e19
93b1535
0c4e724
c53d0a4
68052de
d043841
7ae82cd
b9e20b4
4d73e65
6a19c1a
b053ecd
5e48e32
7063fd1
f0128b2
067cf89
466a506
e98d9ed
d309f40
d8e959f
1cd321c
08a766e
2f5a27e
b223fee
5541ba0
e0354d9
c599254
c128aad
87a8a8d
fc808b8
a9c84b6
a18a8d1
7098c10
579d7c1
fce7457
c2dac26
6795916
7dc6971
b12f5d3
6115c64
29273c2
3d00e48
14370e6
f1fc9f1
63cb4fe
516d697
57c8839
d7c79f4
04fb0f4
7c1240a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ | |
import numpy.typing as npt | ||
import pandas as pd | ||
import pymc as pm | ||
import pytensor.tensor as pt | ||
import seaborn as sns | ||
from xarray import DataArray | ||
|
||
|
@@ -187,32 +188,15 @@ def build_model( | |
dims="date", | ||
) | ||
|
||
intercept = pm.Normal( | ||
name="intercept", | ||
mu=model_config["intercept"]["mu"], | ||
sigma=model_config["intercept"]["sigma"], | ||
) | ||
|
||
beta_channel = pm.HalfNormal( | ||
name="beta_channel", | ||
sigma=model_config["beta_channel"]["sigma"], | ||
dims=model_config["beta_channel"]["dims"], | ||
) | ||
alpha = pm.Beta( | ||
name="alpha", | ||
alpha=model_config["alpha"]["alpha"], | ||
beta=model_config["alpha"]["beta"], | ||
dims=model_config["alpha"]["dims"], | ||
) | ||
#Building the priors | ||
priors = self.create_priors_from_config(self.model_config) | ||
|
||
lam = pm.Gamma( | ||
name="lam", | ||
alpha=model_config["lam"]["alpha"], | ||
beta=model_config["lam"]["beta"], | ||
dims=model_config["lam"]["dims"], | ||
) | ||
|
||
sigma = pm.HalfNormal(name="sigma", sigma=model_config["sigma"]["sigma"]) | ||
#Specifying the variables | ||
intercept = priors['intercept'] | ||
beta_channel = priors['beta_channel'] | ||
alpha = priors['alpha'] | ||
lam = priors['lam'] | ||
gamma_control = priors['gamma_control'] | ||
|
||
channel_adstock = pm.Deterministic( | ||
name="channel_adstock", | ||
|
@@ -230,6 +214,7 @@ def build_model( | |
var=logistic_saturation(x=channel_adstock, lam=lam), | ||
dims=("date", "channel"), | ||
) | ||
|
||
channel_contributions = pm.Deterministic( | ||
name="channel_contributions", | ||
var=channel_adstock_saturated * beta_channel, | ||
|
@@ -251,12 +236,8 @@ def build_model( | |
dims=("date", "control"), | ||
) | ||
|
||
gamma_control = pm.Normal( | ||
name="gamma_control", | ||
mu=model_config["gamma_control"]["mu"], | ||
sigma=model_config["gamma_control"]["sigma"], | ||
dims=model_config["gamma_control"]["dims"], | ||
) | ||
print("Shape of control_data_:", control_data_.eval().shape) | ||
print("Shape of gamma_control:", gamma_control.eval().shape) | ||
|
||
control_contributions = pm.Deterministic( | ||
name="control_contributions", | ||
|
@@ -299,33 +280,135 @@ def build_model( | |
name="mu", var=mu_var, dims=model_config["mu"]["dims"] | ||
) | ||
|
||
pm.Normal( | ||
name="likelihood", | ||
mu=mu, | ||
sigma=sigma, | ||
observed=target_, | ||
dims=model_config["likelihood"]["dims"], | ||
) | ||
likelihood = self.create_likelihood(self.model_config, target_, mu) | ||
|
||
def create_priors_from_config(self, model_config):
    """Create the prior variables described by ``model_config``.

    Every entry of ``model_config`` maps a parameter name to a config dict.
    The ``"likelihood"`` entry and entries without a ``"type"`` key are
    skipped. A type of ``"tvp"`` builds a time-varying prior through
    ``gp_wrapper`` / ``create_tvp_priors``; any other type is looked up as
    an attribute of ``pm`` and called with the remaining config items as
    keyword arguments.

    Parameters
    ----------
    model_config : dict
        Mapping of parameter name to prior configuration.

    Returns
    -------
    dict
        Mapping of parameter name to the variable that was created.

    Raises
    ------
    ValueError
        If a ``"type"`` is neither ``"tvp"`` nor an attribute of ``pm``.
    """
    # Sizes of the non-date dimensions a time-varying prior is stacked over.
    # NOTE(review): control_columns is assumed to be optional (None when the
    # model has no controls) - confirm against the class constructor.
    control_columns = self.control_columns
    dimensions = {
        "channel": len(self.channel_columns),
        "control": 0 if control_columns is None else len(control_columns),
    }

    # Parameters that are exponentiated by default when they are time-varying.
    positive_params = {"intercept", "beta_channel", "alpha", "lam", "sigma"}
    # Time-varying parameters modelled by a single GP instead of one per column.
    scalar_tvp_params = ("intercept", "lam", "alpha", "sigma")

    priors = {}
    for param, config in model_config.items():
        if param == "likelihood":
            continue

        prior_type = config.get("type")
        if prior_type is None:
            continue

        # The name-based default can be overridden by an explicit 'positive' key.
        is_positive = config.get("positive", param in positive_params)

        if prior_type == "tvp":
            if param in scalar_tvp_params:
                priors[param] = self.gp_wrapper(
                    name=param,
                    X=np.arange(len(self.X[self.date_column]))[:, None],
                    config=config,
                    positive=is_positive,
                )
                continue

            # The second dim names the axis to stack over; missing, None or
            # one-element dims fall back to a single GP instead of raising.
            dims = config.get("dims") or ()
            stack_dim = dims[1] if len(dims) > 1 else None
            length = dimensions.get(stack_dim, 1)
            priors[param] = self.create_tvp_priors(
                param, config, length, positive=is_positive
            )
            continue

        dist_func = getattr(pm, prior_type, None)
        if not dist_func:
            raise ValueError(f"Invalid distribution type {prior_type}")
        dist_kwargs = {k: v for k, v in config.items() if k != "type"}
        priors[param] = dist_func(name=param, **dist_kwargs)

    return priors
|
||
def create_likelihood(self, model_config, target_, mu):
    """Create the observed likelihood variable from ``model_config["likelihood"]``.

    The likelihood config provides a ``"type"`` (name of a ``pm``
    distribution), optional ``"dims"``, and prior configs for the remaining
    distribution parameters, given either directly or nested under
    ``"params"``.

    Parameters
    ----------
    model_config : dict
        Full model configuration; only its ``"likelihood"`` entry is read.
    target_ :
        Observed data passed to the distribution as ``observed``.
    mu :
        Mean term of the model. It is replaced by ``pt.log(mu)`` when the
        likelihood type is ``"LogNormal"`` or ``"HurdleLogNormal"``.

    Returns
    -------
    The variable returned by the selected ``pm`` distribution, named
    ``"likelihood"``.

    Raises
    ------
    ValueError
        If no likelihood type is configured, if the type is not an attribute
        of ``pm``, or if no prior could be created for a likelihood parameter.
    """
    likelihood_config = model_config.get("likelihood", {})
    likelihood_type = likelihood_config.get("type")
    if not likelihood_type:
        raise ValueError("Likelihood type must be specified in the model config.")

    likelihood_func = getattr(pm, likelihood_type, None)
    if likelihood_func is None:
        raise ValueError(f"Invalid likelihood type {likelihood_type}")

    # Transform mu if the likelihood type is LogNormal or HurdleLogNormal.
    if likelihood_type in ("LogNormal", "HurdleLogNormal"):
        mu = pt.log(mu)

    # Collect (name, config) pairs, flattening the nested 'params' mapping.
    param_configs = []
    for key, value in likelihood_config.items():
        if key in ("type", "dims"):
            continue
        if key == "params":
            param_configs.extend(value.items())
        else:
            param_configs.append((key, value))

    sub_priors = {}
    for name, prior_config in param_configs:
        created = self.create_priors_from_config({name: prior_config})
        if name not in created:
            # create_priors_from_config skips entries it cannot build (for
            # example a config without 'type'); fail with a clear message
            # rather than an opaque KeyError.
            raise ValueError(
                f"Could not create a prior for likelihood parameter {name!r}; "
                "its config must include a 'type'."
            )
        sub_priors[name] = created[name]

    return likelihood_func(
        name="likelihood",
        mu=mu,
        observed=target_,
        dims=likelihood_config.get("dims"),
        **sub_priors,
    )
|
||
def create_tvp_priors(self, param, config, length, positive=False):
    """Build ``length`` time-varying priors and stack them into one variable.

    One GP is created per index ``i`` through ``gp_wrapper`` under the name
    ``f"{param}_{i}"``; the results are stacked along axis 1 and registered
    as a ``pm.Deterministic`` named ``param``.

    Parameters
    ----------
    param : str
        Name of the stacked variable and prefix of the per-column GP names.
    config : dict
        Prior configuration; ``"dims"`` is used for the stacked variable and
        the whole dict is forwarded to ``gp_wrapper``.
    length : int
        Number of GPs to create.
    positive : bool, default False
        Forwarded to ``gp_wrapper``.

    Returns
    -------
    The ``pm.Deterministic`` holding the stacked GPs.
    """
    dims = config.get("dims", None)
    # The time index is the same for every column, so build it once.
    time_index = np.arange(len(self.X[self.date_column]))[:, None]

    # NOTE(review): a reviewer suggested this loop could potentially be
    # vectorized with gp.prior_linearized.
    gp_list = [
        self.gp_wrapper(
            name=f"{param}_{i}", X=time_index, config=config, positive=positive
        )
        for i in range(length)
    ]
    stacked_gp = pt.stack(gp_list, axis=1)
    return pm.Deterministic(f"{param}", stacked_gp, dims=dims)
|
||
|
||
def gp_wrapper(self, name, X, config, positive=False, **kwargs):
    """Delegate to ``gp_coeff`` using a name-first argument order.

    Parameters
    ----------
    name : str
        Name passed to ``gp_coeff``.
    X :
        Input locations passed to ``gp_coeff``.
    config : dict or None
        Prior configuration forwarded as ``config``.
    positive : bool, default False
        Forwarded as ``positive``.
    **kwargs
        Any further keyword arguments accepted by ``gp_coeff``.

    Returns
    -------
    Whatever ``gp_coeff`` returns.
    """
    return self.gp_coeff(X, name, config=config, positive=positive, **kwargs)
|
||
def gp_coeff(self, X, name, mean=0.0, positive=False, config=None):
    """Create a time-varying coefficient backed by an HSGP prior.

    The GP uses a Matern32 covariance over one input dimension with
    lengthscale ``ell_{name}`` and amplitude ``_eta_{name}``, approximated by
    ``pm.gp.HSGP(m=[40], c=4)``. The raw GP draw is registered as
    ``{name}_tvp_raw`` and the final coefficient as a ``pm.Deterministic``
    called ``name`` with the ``"date"`` dim.

    Parameters
    ----------
    X :
        Input locations passed to ``gp.prior``.
    name : str
        Base name for every variable created here.
    mean : float, default 0.0
        Currently unused; kept so existing callers keep working.
    positive : bool, default False
        If true, the raw GP is exponentiated. The offset, if any, is added
        after the exponential.
    config : dict or None
        Optional prior configuration. Its ``"offset"`` entry, when present
        and non-empty, must hold a ``"type"`` naming a ``pm`` distribution
        plus that distribution's keyword arguments.

    Returns
    -------
    The ``pm.Deterministic`` named ``name``.

    Raises
    ------
    ValueError
        If an offset is configured with a missing or unknown ``"type"``.
    """
    # Gamma lengthscale prior whose parameters are solved for by
    # pm.find_constrained_prior with bounds 8 and 12 and mass=0.8.
    ell_params = pm.find_constrained_prior(
        pm.Gamma, 8, 12, init_guess={"alpha": 1, "beta": 1}, mass=0.8
    )
    ell = pm.Gamma(f"ell_{name}", **ell_params)
    eta = pm.Exponential(f"_eta_{name}", lam=1)
    cov = eta ** 2 * pm.gp.cov.Matern32(1, ls=ell)

    gp = pm.gp.HSGP(m=[40], c=4, cov_func=cov)
    f_raw = gp.prior(f"{name}_tvp_raw", X=X)

    # Optional additive offset drawn from a configurable distribution.
    offset_config = config.get("offset", None) if config else None
    if offset_config:
        offset_type = offset_config.get("type")
        offset_dist = getattr(pm, offset_type, None) if offset_type else None
        if offset_dist is None:
            raise ValueError(f"Invalid offset distribution type {offset_type}")
        offset_params = {k: v for k, v in offset_config.items() if k != "type"}
        offset_prior = offset_dist(name=f"{name}_offset", **offset_params)
    else:
        offset_prior = 0

    f_transformed = pt.exp(f_raw) if positive else f_raw
    return pm.Deterministic(f"{name}", f_transformed + offset_prior, dims="date")
|
||
|
||
|
||
@property
def default_model_config(self) -> Dict:
    """Return the default prior and likelihood configuration.

    Entries with a ``"type"`` key name the ``pm`` distribution used for that
    parameter; the other keys are that distribution's keyword arguments.
    The ``"likelihood"`` entry nests the priors of its own parameters under
    ``"params"``.

    Returns
    -------
    Dict
        A freshly built configuration dict.
    """
    model_config: Dict = {
        "intercept": {"type": "Normal", "mu": 0, "sigma": 2},
        "beta_channel": {"type": "HalfNormal", "sigma": 2, "dims": ("channel",)},
        "alpha": {"type": "Beta", "alpha": 1, "beta": 3, "dims": ("channel",)},
        "lam": {"type": "Gamma", "alpha": 3, "beta": 1, "dims": ("channel",)},
        # NOTE(review): the previous default for gamma_control was
        # {"mu": 0, "sigma": 2}; confirm the switch to a Gamma prior is
        # intended.
        "gamma_control": {
            "type": "Gamma",
            "alpha": 2,
            "beta": 1,
            "dims": ("control",),
        },
        "mu": {"dims": ("date",)},
        "likelihood": {
            "type": "Normal",
            "dims": ("date",),
            "params": {
                # NOTE(review): sigma carries the "date" dim here, whereas the
                # previous top-level "sigma" entry had no dims - confirm.
                "sigma": {"type": "HalfNormal", "sigma": 1, "dims": ("date",)},
            },
        },
        "gamma_fourier": {"mu": 0, "b": 1, "dims": "fourier_mode"},
    }
    return model_config
|
||
|
||
|
||
def _get_fourier_models_data(self, X) -> pd.DataFrame: | ||
"""Generates fourier modes to model seasonality. | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can be simplified to: