Skip to content

Commit

Permalink
model.fit doesn't remove prior samples (#741)
Browse files: browse the repository at this point in the history
* type hint only

* more informative errors

* check for attr

* remove type ignore

* check for attr

* check for attr

* reduce indentation

* new error names
  • Loading branch information
wd60622 authored and twiecki committed Sep 10, 2024
1 parent 95bf0c3 commit 1ad7ed3
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 48 deletions.
2 changes: 1 addition & 1 deletion pymc_marketing/clv/models/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _validate_cols(
raise ValueError(f"Column {required_col} has duplicate entries")

def __repr__(self):
if self.model is None:
if not hasattr(self, "model"):
return self._model_type
else:
return f"{self._model_type}\n{self.model.str_repr()}"
Expand Down
12 changes: 8 additions & 4 deletions pymc_marketing/mmm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,14 +267,16 @@ def get_target_transformer(self) -> Pipeline:
def prior(self) -> Dataset:
if self.idata is None or "prior" not in self.idata:
raise RuntimeError(
"The model hasn't been fit yet, call .sample_prior_predictive() with extend_idata=True first"
"The model hasn't been sampled yet, call .sample_prior_predictive() first"
)
return self.idata["prior"]

@property
def prior_predictive(self) -> az.InferenceData:
def prior_predictive(self) -> Dataset:
if self.idata is None or "prior_predictive" not in self.idata:
raise RuntimeError("The model hasn't been fit yet, call .fit() first")
raise RuntimeError(
"The model hasn't been sampled yet, call .sample_prior_predictive() first"
)
return self.idata["prior_predictive"]

@property
Expand All @@ -286,7 +288,9 @@ def fit_result(self) -> Dataset:
@property
def posterior_predictive(self) -> Dataset:
if self.idata is None or "posterior_predictive" not in self.idata:
raise RuntimeError("The model hasn't been fit yet, call .fit() first")
raise RuntimeError(
"The model hasn't been fit yet, call .sample_posterior_predictive() first"
)
return self.idata["posterior_predictive"]

def plot_prior_predictive(
Expand Down
2 changes: 1 addition & 1 deletion pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1829,7 +1829,7 @@ def add_lift_test_measurements(
model.add_lift_test_measurements(df_lift_test)
"""
if self.model is None:
if not hasattr(self, "model"):
raise RuntimeError(
"The model has not been built yet. Please, build the model first."
)
Expand Down
70 changes: 40 additions & 30 deletions pymc_marketing/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(
self.model_config = (
self.default_model_config | model_config
) # parameters for priors etc.
self.model: pm.Model | None = None # Set by build_model
self.model: pm.Model
self.idata: az.InferenceData | None = None # idata is generated during fitting
self.is_fitted_ = False

Expand Down Expand Up @@ -458,19 +458,22 @@ def fit(
if self.X is None or self.y is None:
raise ValueError("X and y must be set before calling build_model!")

if self.model is None:
if not hasattr(self, "model"):
self.build_model(self.X, self.y)

sampler_config = self.sampler_config.copy()
sampler_config["progressbar"] = progressbar
sampler_config["random_seed"] = random_seed
sampler_config.update(**kwargs)

sampler_config.update(**kwargs)
if self.model is not None:
with self.model:
sampler_args = {**self.sampler_config, **kwargs}
self.idata = pm.sample(**sampler_args)
sampler_args = {**self.sampler_config, **kwargs}
with self.model:
idata = pm.sample(**sampler_args)

if self.idata:
self.idata.extend(idata, join="right")
else:
self.idata = idata

X_df = pd.DataFrame(X, columns=X.columns)
combined_data = pd.concat([X_df, y_df], axis=1)
Expand Down Expand Up @@ -537,7 +540,7 @@ def sample_prior_predictive(
X_pred,
y_pred=None,
samples: int | None = None,
extend_idata: bool = False,
extend_idata: bool = True,
combined: bool = True,
**kwargs,
):
Expand All @@ -552,7 +555,7 @@ def sample_prior_predictive(
Number of samples from the prior parameter distributions to generate.
If not set, uses sampler_config['draws'] if that is available, otherwise defaults to 500.
extend_idata : Boolean determining whether the predictions should be added to inference data object.
Defaults to False.
Defaults to True.
combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists.
Defaults to True.
**kwargs: Additional arguments to pass to pymc.sample_prior_predictive
Expand All @@ -567,21 +570,19 @@ def sample_prior_predictive(
if samples is None:
samples = self.sampler_config.get("draws", 500)

if self.model is None:
if not hasattr(self, "model"):
self.build_model(X_pred, y_pred)

self._data_setter(X_pred, y_pred)
if self.model is not None:
with self.model: # sample with new input data
prior_pred: az.InferenceData = pm.sample_prior_predictive(
samples, **kwargs
)
self.set_idata_attrs(prior_pred)
if extend_idata:
if self.idata is not None:
self.idata.extend(prior_pred, join="right")
else:
self.idata = prior_pred
with self.model: # sample with new input data
prior_pred: az.InferenceData = pm.sample_prior_predictive(samples, **kwargs)
self.set_idata_attrs(prior_pred)

if extend_idata:
if self.idata is not None:
self.idata.extend(prior_pred, join="right")
else:
self.idata = prior_pred

prior_predictive_samples = az.extract(
prior_pred, "prior_predictive", combined=combined
Expand All @@ -590,7 +591,11 @@ def sample_prior_predictive(
return prior_predictive_samples

def sample_posterior_predictive(
self, X_pred, extend_idata: bool = True, combined: bool = True, **kwargs
self,
X_pred,
extend_idata: bool = True,
combined: bool = True,
**sample_posterior_predictive_kwargs,
):
"""
Sample from the model's posterior predictive distribution.
Expand All @@ -603,7 +608,7 @@ def sample_posterior_predictive(
Defaults to True.
combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists.
Defaults to True.
**kwargs: Additional arguments to pass to pymc.sample_posterior_predictive
**sample_posterior_predictive_kwargs: Additional arguments to pass to pymc.sample_posterior_predictive
Returns
-------
Expand All @@ -612,16 +617,21 @@ def sample_posterior_predictive(
"""
self._data_setter(X_pred)

with self.model: # type: ignore
post_pred = pm.sample_posterior_predictive(self.idata, **kwargs)
if extend_idata:
self.idata.extend(post_pred, join="right") # type: ignore
with self.model:
post_pred = pm.sample_posterior_predictive(
self.idata, **sample_posterior_predictive_kwargs
)

if extend_idata:
self.idata.extend(post_pred, join="right") # type: ignore

posterior_predictive_samples = az.extract(
post_pred, "posterior_predictive", combined=combined
variable_name = (
"predictions"
if sample_posterior_predictive_kwargs.get("predictions")
else "posterior_predictive"
)

return posterior_predictive_samples
return az.extract(post_pred, variable_name, combined=combined)

def get_params(self, deep=True):
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def set_model_fit(model: CLVModel, fit: InferenceData | Dataset):
assert "posterior" in fit.groups()
else:
fit = InferenceData(posterior=fit)
if model.model is None:
if not hasattr(model, "model"):
model.build_model()
model.idata = fit
model.idata.add_groups(fit_data=model.data.to_xarray())
Expand Down
6 changes: 4 additions & 2 deletions tests/mmm/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,9 @@ def test_calling_prior_predictive_before_fit_raises_error(test_mmm, toy_X, toy_y
test_mmm.idata = None
with pytest.raises(
RuntimeError,
match=re.escape("The model hasn't been fit yet, call .fit() first"),
match=re.escape(
"The model hasn't been sampled yet, call .sample_prior_predictive() first"
),
):
test_mmm.prior_predictive

Expand All @@ -297,7 +299,7 @@ def test_calling_prior_before_sample_prior_predictive_raises_error(
with pytest.raises(
RuntimeError,
match=re.escape(
"The model hasn't been fit yet, call .sample_prior_predictive() with extend_idata=True first"
"The model hasn't been sampled yet, call .sample_prior_predictive() first",
),
):
test_mmm.prior
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ def _save_input_params(self, idata):
def output_var(self):
return "output"

def _data_setter(self, X: pd.Series, y: pd.Series = None):
def _data_setter(self, X: pd.DataFrame, y: pd.Series = None):
with self.model:
pm.set_data({"x": X.values})
pm.set_data({"x": X["input"].values})
if y is not None:
y = y.values if isinstance(y, pd.Series) else y
pm.set_data({"y_data": y})
Expand Down Expand Up @@ -195,8 +195,8 @@ def test_save_load(fitted_model_instance):
assert fitted_model_instance.id == test_builder2.id
x_pred = rng.uniform(low=0, high=1, size=100)
prediction_data = pd.DataFrame({"input": x_pred})
pred1 = fitted_model_instance.predict(prediction_data["input"])
pred2 = test_builder2.predict(prediction_data["input"])
pred1 = fitted_model_instance.predict(prediction_data)
pred2 = test_builder2.predict(prediction_data)
assert pred1.shape == pred2.shape
temp.close()

Expand Down Expand Up @@ -230,9 +230,9 @@ def test_fit(fitted_model_instance):
assert fitted_model_instance.idata.posterior.dims["draw"] == 100

prediction_data = pd.DataFrame({"input": rng.uniform(low=0, high=1, size=100)})
fitted_model_instance.predict(prediction_data["input"])
fitted_model_instance.predict(prediction_data)
post_pred = fitted_model_instance.sample_posterior_predictive(
prediction_data["input"], extend_idata=True, combined=True
prediction_data, extend_idata=True, combined=True
)
assert (
post_pred[fitted_model_instance.output_var].shape[0]
Expand All @@ -256,7 +256,7 @@ def test_predict(fitted_model_instance):
rng = np.random.default_rng(42)
x_pred = rng.uniform(low=0, high=1, size=100)
prediction_data = pd.DataFrame({"input": x_pred})
pred = fitted_model_instance.predict(prediction_data["input"])
pred = fitted_model_instance.predict(prediction_data)
# Perform elementwise comparison using numpy
assert type(pred) == np.ndarray
assert len(pred) > 0
Expand All @@ -269,7 +269,7 @@ def test_sample_posterior_predictive(fitted_model_instance, combined):
x_pred = rng.uniform(low=0, high=1, size=n_pred)
prediction_data = pd.DataFrame({"input": x_pred})
pred = fitted_model_instance.sample_posterior_predictive(
prediction_data["input"], combined=combined, extend_idata=True
prediction_data, combined=combined, extend_idata=True
)
chains = fitted_model_instance.idata.sample_stats.dims["chain"]
draws = fitted_model_instance.idata.sample_stats.dims["draw"]
Expand Down Expand Up @@ -313,7 +313,7 @@ def test_sample_xxx_predictive_keeps_second(
method_name = f"sample_{name}"
method = getattr(fitted_model_instance, method_name)

X_pred = toy_X["input"]
X_pred = toy_X

kwargs = {
"X_pred": X_pred,
Expand All @@ -329,3 +329,26 @@ def test_sample_xxx_predictive_keeps_second(

sample = getattr(fitted_model_instance.idata, name)
xr.testing.assert_allclose(sample, second_sample)


def test_prediction_kwarg(fitted_model_instance, toy_X):
    """Check ``predictions=True`` adds the predictions groups and returns a Dataset."""
    sampling_options = {"extend_idata": True, "predictions": True}
    result = fitted_model_instance.sample_posterior_predictive(
        toy_X, **sampling_options
    )

    # With ``predictions=True`` the draws are expected under the predictions
    # groups of the stored inference data.
    for group_name in ("predictions", "predictions_constant_data"):
        assert group_name in fitted_model_instance.idata

    assert isinstance(result, xr.Dataset)


def test_fit_after_prior_keeps_prior(toy_X, toy_y):
    """Check that fitting after prior sampling keeps the prior groups in idata."""
    prior_groups = ("prior", "prior_predictive")

    model = ModelBuilderTest()
    model.sample_prior_predictive(toy_X)
    for group_name in prior_groups:
        assert group_name in model.idata

    # Re-read ``model.idata`` after fitting: the same groups must still be there.
    fit_kwargs = {"chains": 1, "draws": 100, "tune": 100}
    model.fit(X=toy_X, y=toy_y, **fit_kwargs)
    for group_name in prior_groups:
        assert group_name in model.idata

0 comments on commit 1ad7ed3

Please sign in to comment.