Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add method thin_fit_result to CLV models #393

Merged
merged 3 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 61 additions & 97 deletions pymc_marketing/clv/models/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
import types
import warnings
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union
from typing import Dict, Optional, Tuple, cast

import arviz as az
import numpy as np
import pandas as pd
import pymc as pm
from pymc import str_for_dist
from pymc import Model, str_for_dist
from pymc.backends import NDArray
from pymc.backends.base import MultiTrace
from pytensor.tensor import TensorVariable
Expand All @@ -22,14 +21,27 @@ class CLVModel(ModelBuilder):

def __init__(
self,
data: Optional[pd.DataFrame] = None,
*,
model_config: Optional[Dict] = None,
sampler_config: Optional[Dict] = None,
):
super().__init__(model_config, sampler_config)
self.data = data

def __repr__(self):
return f"{self._model_type}\n{self.model.str_repr()}"

def _add_fit_data_group(self, data: pd.DataFrame) -> None:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=UserWarning,
message="The group fit_data is not defined in the InferenceData scheme",
)
assert self.idata is not None
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved
self.idata.add_groups(fit_data=data.to_xarray())

def fit( # type: ignore
self,
fit_method: str = "mcmc",
Expand All @@ -50,21 +62,18 @@ def fit( # type: ignore
self.build_model() # type: ignore

if fit_method == "mcmc":
self._fit_mcmc(**kwargs)
idata = self._fit_mcmc(**kwargs)
elif fit_method == "map":
self._fit_MAP(**kwargs)
idata = self._fit_MAP(**kwargs)
else:
raise ValueError(
f"Fit method options are ['mcmc', 'map'], got: {fit_method}"
)

with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=UserWarning,
message="The group fit_data is not defined in the InferenceData scheme",
)
self.idata.add_groups(fit_data=self.data.to_xarray()) # type: ignore
self.idata = idata
self.set_idata_attrs(self.idata)
if self.data is not None:
self._add_fit_data_group(self.data)

return self.idata

Expand Down Expand Up @@ -92,58 +101,22 @@ def _fit_mcmc(self, **kwargs) -> az.InferenceData:
if self.sampler_config is not None:
sampler_config = self.sampler_config.copy()
sampler_config.update(**kwargs)
self.idata = self.sample_model(**sampler_config)
return self.idata

def sample_model(self, **kwargs):
"""
Sample from the PyMC model.

Parameters
----------
**kwargs : dict
Additional keyword arguments to pass to the PyMC sampler.

Returns
-------
xarray.Dataset
The PyMC samples dataset.

Raises
------
RuntimeError
If the PyMC model hasn't been built yet.

"""
if self.model is None:
raise RuntimeError(
"The model hasn't been built yet, call .build_model() first or call .fit() instead."
)

with self.model:
sampler_args = {**self.sampler_config, **kwargs}
idata = pm.sample(**sampler_args)

self.set_idata_attrs(idata)
return idata
return pm.sample(**sampler_config, model=self.model)

def _fit_MAP(self, **kwargs):
def _fit_MAP(self, **kwargs) -> az.InferenceData:
"""Find model maximum a posteriori using scipy optimizer"""
model = self.model
map_res = pm.find_MAP(model=model, **kwargs)
# Filter non-value variables
value_vars_names = set(v.name for v in model.value_vars)
value_vars_names = set(v.name for v in cast(Model, model).value_vars)
map_res = {k: v for k, v in map_res.items() if k in value_vars_names}
# Convert map result to InferenceData
map_strace = NDArray(model=model)
map_strace.setup(draws=1, chain=0)
map_strace.record(map_res)
map_strace.close()
trace = MultiTrace([map_strace])
idata = pm.to_inference_data(trace, model=model)
self.set_idata_attrs(idata)
self.idata = idata
return self.idata
return pm.to_inference_data(trace, model=model)

@classmethod
def load(cls, fname: str):
Expand Down Expand Up @@ -173,24 +146,52 @@ def load(cls, fname: str):
"""
filepath = Path(str(fname))
idata = az.from_netcdf(filepath)
dataset = idata.fit_data.to_dataframe()
return cls._build_with_idata(idata)

@classmethod
def _build_with_idata(cls, idata: az.InferenceData):
dataset = idata.fit_data.to_dataframe()
model = cls(
dataset,
model_config=json.loads(idata.attrs["model_config"]), # type: ignore
sampler_config=json.loads(idata.attrs["sampler_config"]),
)
model.idata = idata

model.build_model() # type: ignore

if model.id != idata.attrs["id"]:
raise ValueError(
f"The file '{fname}' does not contain an inference data of the same model or configuration as '{cls._model_type}'"
raise ValueError(f"Inference data not compatible with {cls._model_type}")
return model

def thin_fit_result(self, keep_every: int):
"""Return a copy of the model with a thinned fit result.

This is useful when computing summary statistics that may require too much memory per posterior draw.

Examples
--------

.. code-block:: python

fitted_gg = ...
fitted bg = ...

fitted_gg_thinned = fitted_gg.thin_fit_result(keep_every=10)
fitted_bg_thinned = fitted_bg.thin_fit_result(keep_every=10)

clv_thinned = fitted_gg_thinned.expected_customer_lifetime_value(
transaction_model=fitted_bg_thinned,
customer_id=t.index,
frequency=t["frequency"],
recency=t["recency"],
T=t["T"],
mean_transaction_value=t["monetary_value"],
)
# All previously used data is in idata.

return model
"""
self.fit_result # Raise Error if fit didn't happen yet
assert self.idata is not None
ColtAllen marked this conversation as resolved.
Show resolved Hide resolved
new_idata = self.idata.isel(draw=slice(None, None, keep_every)).copy()
return type(self)._build_with_idata(new_idata)

@staticmethod
def _check_prior_ndim(prior, ndim: int = 0):
Expand Down Expand Up @@ -228,39 +229,6 @@ def default_sampler_config(self) -> Dict:
def _serializable_model_config(self) -> Dict:
return self.model_config

def sample_prior_predictive( # type: ignore
self,
samples: int = 1000,
extend_idata: bool = True,
combined: bool = True,
**kwargs,
):
if self.model is not None:
with self.model: # sample with new input data
prior_pred: az.InferenceData = pm.sample_prior_predictive(
samples, **kwargs
)
self.set_idata_attrs(prior_pred)
if extend_idata:
if self.idata is not None:
self.idata.extend(prior_pred)
else:
self.idata = prior_pred

prior_predictive_samples = az.extract(
prior_pred, "prior_predictive", combined=combined
)

return prior_predictive_samples

@property
def prior_predictive(self) -> az.InferenceData:
if self.idata is None or "prior_predictive" not in self.idata:
raise RuntimeError(
"No prior predictive samples available, call sample_prior_predictive() first"
)
return self.idata["prior_predictive"]

ColtAllen marked this conversation as resolved.
Show resolved Hide resolved
@property
def fit_result(self) -> Dataset:
if self.idata is None or "posterior" not in self.idata:
Expand Down Expand Up @@ -293,11 +261,7 @@ def fit_summary(self, **kwargs):
def output_var(self):
pass

def _generate_and_preprocess_model_data(
self,
X: Union[pd.DataFrame, pd.Series],
y: Union[pd.Series, np.ndarray[Any, Any]],
) -> None:
def _generate_and_preprocess_model_data(self, *args, **kwargs):
pass

def _data_setter(self):
Expand Down
2 changes: 1 addition & 1 deletion pymc_marketing/clv/models/beta_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,10 @@ def __init__(
except KeyError:
raise KeyError("T column is missing from data")
super().__init__(
data=data,
model_config=model_config,
sampler_config=sampler_config,
)
self.data = data
self.a_prior = self._create_distribution(self.model_config["a_prior"])
self.b_prior = self._create_distribution(self.model_config["b_prior"])
self.alpha_prior = self._create_distribution(self.model_config["alpha_prior"])
Expand Down
6 changes: 4 additions & 2 deletions pymc_marketing/clv/models/gamma_gamma.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@ class BaseGammaGammaModel(CLVModel):
def __init__(
self,
data: pd.DataFrame,
*,
model_config: Optional[Dict] = None,
sampler_config: Optional[Dict] = None,
):
super().__init__(model_config, sampler_config)
self.data = data
super().__init__(
data=data, model_config=model_config, sampler_config=sampler_config
)
self.p_prior = self._create_distribution(self.model_config["p_prior"])
self.q_prior = self._create_distribution(self.model_config["q_prior"])
self.v_prior = self._create_distribution(self.model_config["v_prior"])
Expand Down
Loading