Skip to content

Commit

Permalink
consolidate tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 committed Dec 20, 2023
2 parents 2cd0b31 + 890c469 commit f99600b
Show file tree
Hide file tree
Showing 14 changed files with 2,020 additions and 1,618 deletions.
37 changes: 37 additions & 0 deletions .github/pull_request_template.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
<!-- !! Thank your for opening a PR !! -->

<!--- Provide a self-contained summary of your changes in the Title above -->
<!--- This is what will be shown in the automatic release notes: https://github.com/pymc-labs/pymc-marketing/releases -->

## Description
<!--- Describe your changes in detail -->

## Related Issue
<!--- It is good practice to first open an issue explaining the bug / new feature that is addressed by this PR -->
<!--- Please type an `x` in one of the boxes below and provide the issue number after the # sign: -->
- [ ] Closes issue: #
- [ ] Related issue (not closed by this PR): #

## Checklist
<!--- Make sure you have completed the following steps before submitting your PR -->
<!--- Feel free to type an `x` in all the boxes below to let us know you have completed the steps: -->
- [ ] Checked that [the pre-commit linting/style checks pass](https://docs.pymc.io/en/latest/contributing/python_style.html)
- [ ] Included tests that prove the fix is effective or that the new feature works
- [ ] Added necessary documentation (docstrings and/or example notebooks)
- [ ] If you are a pro: each commit corresponds to a [relevant logical change](https://wiki.openstack.org/wiki/GitCommitMessages#Structural_split_of_changes)
<!--- You may find this guide helpful: https://mainmatter.com/blog/2021/05/26/keeping-a-clean-git-history/ -->

## Modules affected
<!--- Please list the modules that are affected by this PR by typing an `x` in the boxes below: -->
- [ ] MMM
- [ ] CLV
<!--- Additionally, if you are a maintainer or reviewer, please make sure that the appropriate labels are added to this PR -->

## Type of change
<!--- Select one of the categories below by typing an `x` in the box -->
- [ ] New feature / enhancement
- [ ] Bug fix
- [ ] Documentation
- [ ] Maintenance
- [ ] Other (please specify):
<!--- Additionally, if you are a maintainer or reviewer, please make sure that the appropriate labels are added to this PR -->
2,733 changes: 1,375 additions & 1,358 deletions docs/source/notebooks/mmm/mmm_example.ipynb

Large diffs are not rendered by default.

158 changes: 61 additions & 97 deletions pymc_marketing/clv/models/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
import types
import warnings
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union
from typing import Dict, Optional, Tuple, cast

import arviz as az
import numpy as np
import pandas as pd
import pymc as pm
from pymc import str_for_dist
from pymc import Model, str_for_dist
from pymc.backends import NDArray
from pymc.backends.base import MultiTrace
from pytensor.tensor import TensorVariable
Expand All @@ -22,14 +21,27 @@ class CLVModel(ModelBuilder):

def __init__(
self,
data: Optional[pd.DataFrame] = None,
*,
model_config: Optional[Dict] = None,
sampler_config: Optional[Dict] = None,
):
super().__init__(model_config, sampler_config)
self.data = data

def __repr__(self):
return f"{self._model_type}\n{self.model.str_repr()}"

def _add_fit_data_group(self, data: pd.DataFrame) -> None:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=UserWarning,
message="The group fit_data is not defined in the InferenceData scheme",
)
assert self.idata is not None
self.idata.add_groups(fit_data=data.to_xarray())

def fit( # type: ignore
self,
fit_method: str = "mcmc",
Expand All @@ -50,21 +62,18 @@ def fit( # type: ignore
self.build_model() # type: ignore

if fit_method == "mcmc":
self._fit_mcmc(**kwargs)
idata = self._fit_mcmc(**kwargs)
elif fit_method == "map":
self._fit_MAP(**kwargs)
idata = self._fit_MAP(**kwargs)
else:
raise ValueError(
f"Fit method options are ['mcmc', 'map'], got: {fit_method}"
)

with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=UserWarning,
message="The group fit_data is not defined in the InferenceData scheme",
)
self.idata.add_groups(fit_data=self.data.to_xarray()) # type: ignore
self.idata = idata
self.set_idata_attrs(self.idata)
if self.data is not None:
self._add_fit_data_group(self.data)

return self.idata

Expand Down Expand Up @@ -92,58 +101,22 @@ def _fit_mcmc(self, **kwargs) -> az.InferenceData:
if self.sampler_config is not None:
sampler_config = self.sampler_config.copy()
sampler_config.update(**kwargs)
self.idata = self.sample_model(**sampler_config)
return self.idata

def sample_model(self, **kwargs):
"""
Sample from the PyMC model.
Parameters
----------
**kwargs : dict
Additional keyword arguments to pass to the PyMC sampler.
Returns
-------
xarray.Dataset
The PyMC samples dataset.
Raises
------
RuntimeError
If the PyMC model hasn't been built yet.
"""
if self.model is None:
raise RuntimeError(
"The model hasn't been built yet, call .build_model() first or call .fit() instead."
)

with self.model:
sampler_args = {**self.sampler_config, **kwargs}
idata = pm.sample(**sampler_args)

self.set_idata_attrs(idata)
return idata
return pm.sample(**sampler_config, model=self.model)

def _fit_MAP(self, **kwargs):
def _fit_MAP(self, **kwargs) -> az.InferenceData:
"""Find model maximum a posteriori using scipy optimizer"""
model = self.model
map_res = pm.find_MAP(model=model, **kwargs)
# Filter non-value variables
value_vars_names = set(v.name for v in model.value_vars)
value_vars_names = set(v.name for v in cast(Model, model).value_vars)
map_res = {k: v for k, v in map_res.items() if k in value_vars_names}
# Convert map result to InferenceData
map_strace = NDArray(model=model)
map_strace.setup(draws=1, chain=0)
map_strace.record(map_res)
map_strace.close()
trace = MultiTrace([map_strace])
idata = pm.to_inference_data(trace, model=model)
self.set_idata_attrs(idata)
self.idata = idata
return self.idata
return pm.to_inference_data(trace, model=model)

@classmethod
def load(cls, fname: str):
Expand Down Expand Up @@ -173,24 +146,52 @@ def load(cls, fname: str):
"""
filepath = Path(str(fname))
idata = az.from_netcdf(filepath)
dataset = idata.fit_data.to_dataframe()
return cls._build_with_idata(idata)

@classmethod
def _build_with_idata(cls, idata: az.InferenceData):
dataset = idata.fit_data.to_dataframe()
model = cls(
dataset,
model_config=json.loads(idata.attrs["model_config"]), # type: ignore
sampler_config=json.loads(idata.attrs["sampler_config"]),
)
model.idata = idata

model.build_model() # type: ignore

if model.id != idata.attrs["id"]:
raise ValueError(
f"The file '{fname}' does not contain an inference data of the same model or configuration as '{cls._model_type}'"
raise ValueError(f"Inference data not compatible with {cls._model_type}")
return model

def thin_fit_result(self, keep_every: int):
"""Return a copy of the model with a thinned fit result.
This is useful when computing summary statistics that may require too much memory per posterior draw.
Examples
--------
.. code-block:: python
fitted_gg = ...
fitted bg = ...
fitted_gg_thinned = fitted_gg.thin_fit_result(keep_every=10)
fitted_bg_thinned = fitted_bg.thin_fit_result(keep_every=10)
clv_thinned = fitted_gg_thinned.expected_customer_lifetime_value(
transaction_model=fitted_bg_thinned,
customer_id=t.index,
frequency=t["frequency"],
recency=t["recency"],
T=t["T"],
mean_transaction_value=t["monetary_value"],
)
# All previously used data is in idata.
return model
"""
self.fit_result # Raise Error if fit didn't happen yet
assert self.idata is not None
new_idata = self.idata.isel(draw=slice(None, None, keep_every)).copy()
return type(self)._build_with_idata(new_idata)

@staticmethod
def _check_prior_ndim(prior, ndim: int = 0):
Expand Down Expand Up @@ -228,39 +229,6 @@ def default_sampler_config(self) -> Dict:
def _serializable_model_config(self) -> Dict:
return self.model_config

def sample_prior_predictive( # type: ignore
self,
samples: int = 1000,
extend_idata: bool = True,
combined: bool = True,
**kwargs,
):
if self.model is not None:
with self.model: # sample with new input data
prior_pred: az.InferenceData = pm.sample_prior_predictive(
samples, **kwargs
)
self.set_idata_attrs(prior_pred)
if extend_idata:
if self.idata is not None:
self.idata.extend(prior_pred)
else:
self.idata = prior_pred

prior_predictive_samples = az.extract(
prior_pred, "prior_predictive", combined=combined
)

return prior_predictive_samples

@property
def prior_predictive(self) -> az.InferenceData:
if self.idata is None or "prior_predictive" not in self.idata:
raise RuntimeError(
"No prior predictive samples available, call sample_prior_predictive() first"
)
return self.idata["prior_predictive"]

@property
def fit_result(self) -> Dataset:
if self.idata is None or "posterior" not in self.idata:
Expand Down Expand Up @@ -293,11 +261,7 @@ def fit_summary(self, **kwargs):
def output_var(self):
pass

def _generate_and_preprocess_model_data(
self,
X: Union[pd.DataFrame, pd.Series],
y: Union[pd.Series, np.ndarray[Any, Any]],
) -> None:
def _generate_and_preprocess_model_data(self, *args, **kwargs):
pass

def _data_setter(self):
Expand Down
2 changes: 1 addition & 1 deletion pymc_marketing/clv/models/beta_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,10 @@ def __init__(
except KeyError:
raise KeyError("T column is missing from data")
super().__init__(
data=data,
model_config=model_config,
sampler_config=sampler_config,
)
self.data = data
self.a_prior = self._create_distribution(self.model_config["a_prior"])
self.b_prior = self._create_distribution(self.model_config["b_prior"])
self.alpha_prior = self._create_distribution(self.model_config["alpha_prior"])
Expand Down
6 changes: 4 additions & 2 deletions pymc_marketing/clv/models/gamma_gamma.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@ class BaseGammaGammaModel(CLVModel):
def __init__(
self,
data: pd.DataFrame,
*,
model_config: Optional[Dict] = None,
sampler_config: Optional[Dict] = None,
):
super().__init__(model_config, sampler_config)
self.data = data
super().__init__(
data=data, model_config=model_config, sampler_config=sampler_config
)
self.p_prior = self._create_distribution(self.model_config["p_prior"])
self.q_prior = self._create_distribution(self.model_config["q_prior"])
self.v_prior = self._create_distribution(self.model_config["v_prior"])
Expand Down
Loading

0 comments on commit f99600b

Please sign in to comment.