Skip to content

Commit

Permalink
Deepcopy of posterior to allow second fit call (#790)
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 authored and twiecki committed Sep 10, 2024
1 parent 6fd2a47 commit 53e2f6d
Show file tree
Hide file tree
Showing 5 changed files with 332 additions and 370 deletions.
680 changes: 312 additions & 368 deletions docs/source/notebooks/mmm/mmm_lift_test.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1819,6 +1819,7 @@ def add_lift_test_measurements(
time_varying_var_name=time_varying_var_name,
model=self.model,
dist=dist,
name=name,
)

def _create_synth_dataset(
Expand Down
5 changes: 5 additions & 0 deletions pymc_marketing/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@ def fit(
idata = pm.sample(**sampler_args)

if self.idata:
self.idata = self.idata.copy()
self.idata.extend(idata, join="right")
else:
self.idata = idata
Expand All @@ -479,6 +480,10 @@ def fit(
combined_data = pd.concat([X_df, y_df], axis=1)
if not all(combined_data.columns):
raise ValueError("All columns must have non-empty names")

if "fit_data" in self.idata:
del self.idata.fit_data

with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
Expand Down
2 changes: 0 additions & 2 deletions tests/mmm/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,6 @@ def transform(self, X):


def test_validate_and_preprocess(toy_X, toy_y, test_mmm):
test_mmm

test_mmm.validate("X", toy_X)
test_mmm.mock_method1.assert_called_once_with(test_mmm, toy_X)

Expand Down
14 changes: 14 additions & 0 deletions tests/test_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,3 +352,17 @@ def test_fit_after_prior_keeps_prior(toy_X, toy_y):
model.fit(X=toy_X, y=toy_y, chains=1, draws=100, tune=100)
assert "prior" in model.idata
assert "prior_predictive" in model.idata


def test_second_fit(toy_X, toy_y):
model = ModelBuilderTest()

model.fit(X=toy_X, y=toy_y, chains=1, draws=100, tune=100)
assert "posterior" in model.idata
id_before = id(model.idata)
assert "fit_data" in model.idata

model.fit(X=toy_X, y=toy_y, chains=1, draws=100, tune=100)
id_after = id(model.idata)

assert id_before != id_after

0 comments on commit 53e2f6d

Please sign in to comment.