diff --git a/docs/source/conf.py b/docs/source/conf.py index 3dea2adf..9a9eeba1 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +"""Sphinx configuration for PyMC Marketing Docs.""" import os diff --git a/docs/source/notebooks/mmm/mmm_time_varying_media_example.ipynb b/docs/source/notebooks/mmm/mmm_time_varying_media_example.ipynb index c6c4e166..daefab7f 100644 --- a/docs/source/notebooks/mmm/mmm_time_varying_media_example.ipynb +++ b/docs/source/notebooks/mmm/mmm_time_varying_media_example.ipynb @@ -4550,8 +4550,7 @@ "def plot_posterior(\n", " posterior, figsize=(15, 8), path_color=\"blue\", hist_color=\"blue\", **kwargs\n", "):\n", - " \"\"\"\n", - " Plot the posterior distribution of a stochastic process.\n", + " \"\"\"Plot the posterior distribution of a stochastic process.\n", "\n", " Parameters\n", " ----------\n", @@ -4565,8 +4564,8 @@ " Color of the histogram.\n", " **kwargs\n", " Additional keyword arguments to pass to the plotting functions.\n", - " \"\"\"\n", "\n", + " \"\"\"\n", " # Calculate the expected value (mean) across all draws and chains for each date\n", " expected_value = posterior.mean(dim=(\"draw\", \"chain\"))\n", "\n", diff --git a/docs/source/uml/classes_clv.png b/docs/source/uml/classes_clv.png index 9bd99752..38a15be5 100644 Binary files a/docs/source/uml/classes_clv.png and b/docs/source/uml/classes_clv.png differ diff --git a/docs/source/uml/classes_mmm.png b/docs/source/uml/classes_mmm.png index 591a9f76..617b30e6 100644 Binary files a/docs/source/uml/classes_mmm.png and b/docs/source/uml/classes_mmm.png differ diff --git a/pymc_marketing/__init__.py b/pymc_marketing/__init__.py index 9e31fd8c..51d9a613 100644 --- a/pymc_marketing/__init__.py +++ b/pymc_marketing/__init__.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""PyMC Marketing.""" + from pymc_marketing import clv, mmm from pymc_marketing.version import __version__ diff --git a/pymc_marketing/clv/__init__.py b/pymc_marketing/clv/__init__.py index 57a4f482..f80ce9a8 100644 --- a/pymc_marketing/clv/__init__.py +++ b/pymc_marketing/clv/__init__.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""CLV models and utilities.""" + from pymc_marketing.clv.models import ( BetaGeoModel, GammaGammaModel, diff --git a/pymc_marketing/clv/distributions.py b/pymc_marketing/clv/distributions.py index 196f9e2d..e8a6f031 100644 --- a/pymc_marketing/clv/distributions.py +++ b/pymc_marketing/clv/distributions.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Distributions for the CLV module.""" + import numpy as np import pymc as pm import pytensor.tensor as pt @@ -80,9 +82,9 @@ def _supp_shape_from_params(*args, **kwargs): class ContNonContract(PositiveContinuous): - r""" - Individual-level model for the customer lifetime value. See equation (3) - from Fader et al. (2005) [1]_. + r"""Individual-level model for the customer lifetime value. + + See equation (3) from Fader et al. (2005) [1]_. .. math:: @@ -100,15 +102,18 @@ class ContNonContract(PositiveContinuous): .. [1] Fader, Peter S., Bruce GS Hardie, and Ka Lok Lee. "“Counting your customers” the easy way: An alternative to the Pareto/NBD model." Marketing science 24.2 (2005): 275-284. + """ rv_op = continuous_non_contractual @classmethod def dist(cls, lam, p, T, **kwargs): + """Get the distribution from the parameters.""" return super().dist([lam, p, T], **kwargs) def logp(value, lam, p, T): + """Log-likelihood of the distribution.""" t_x = value[..., 0] x = value[..., 1] @@ -206,10 +211,10 @@ def _supp_shape_from_params(*args, **kwargs): class ContContract(PositiveContinuous): - r""" - Distribution class of a continuous contractual data-generating process, - that is where purchases can occur at any time point (continuous) and - churning/dropping out is explicit (contractual). + r"""Distribution class of a continuous contractual data-generating process. + + That is where purchases can occur at any time point (continuous) and churning/dropping + out is explicit (contractual). .. math:: @@ -228,9 +233,11 @@ class ContContract(PositiveContinuous): @classmethod def dist(cls, lam, p, T, **kwargs): + """Get the distribution from the parameters.""" return super().dist([lam, p, T], **kwargs) def logp(value, lam, p, T): + """Log-likelihood of the distribution.""" t_x = value[..., 0] x = value[..., 1] churn = value[..., 2] @@ -358,9 +365,9 @@ def _supp_shape_from_params(*args, **kwargs): class ParetoNBD(PositiveContinuous): - r""" - Population-level distribution class for a continuous, non-contractual, Pareto/NBD process, - based on Schmittlein, et al. in [2]_. + r"""Population-level distribution class for a continuous, non-contractual, Pareto/NBD process. + + It is based on Schmittlein, et al. in [2]_. The likelihood function is derived from equations (22) and (23) of [3]_, with terms rearranged for numerical stability. @@ -403,15 +410,18 @@ class ParetoNBD(PositiveContinuous): .. [3] Fader, Peter & G. S. Hardie, Bruce (2005). "A Note on Deriving the Pareto/NBD Model and Related Expressions." http://brucehardie.com/notes/009/pareto_nbd_derivations_2005-11-05.pdf + """ # noqa: E501 rv_op = pareto_nbd @classmethod def dist(cls, r, alpha, s, beta, T, **kwargs): + """Get the distribution from the parameters.""" return super().dist([r, alpha, s, beta, T], **kwargs) def logp(value, r, alpha, s, beta, T): + """Log-likelihood of the distribution.""" t_x = value[..., 0] x = value[..., 1] @@ -555,9 +565,9 @@ def _supp_shape_from_params(*args, **kwargs): class BetaGeoBetaBinom(Discrete): - r""" - Population-level distribution class for a discrete, non-contractual, Beta-Geometric/Beta-Binomial process, - based on equation(5) from Fader, et al. in [1]_. + r"""Population-level distribution class for a discrete, non-contractual, Beta-Geometric/Beta-Binomial process. + + It is based on equation(5) from Fader, et al. in [1]_. .. math:: @@ -584,9 +594,11 @@ class BetaGeoBetaBinom(Discrete): @classmethod def dist(cls, alpha, beta, gamma, delta, T, **kwargs): + """Get the distribution from the parameters.""" return super().dist([alpha, beta, gamma, delta, T], **kwargs) def logp(value, alpha, beta, gamma, delta, T): + """Log-likelihood of the distribution.""" t_x = pt.atleast_1d(value[..., 0]) x = pt.atleast_1d(value[..., 1]) scalar_case = t_x.type.broadcastable == (True,) diff --git a/pymc_marketing/clv/models/__init__.py b/pymc_marketing/clv/models/__init__.py index e7fa1bbf..2c3a5a37 100644 --- a/pymc_marketing/clv/models/__init__.py +++ b/pymc_marketing/clv/models/__init__.py @@ -11,6 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +"""CLV models.""" + from pymc_marketing.clv.models.basic import CLVModel from pymc_marketing.clv.models.beta_geo import BetaGeoModel from pymc_marketing.clv.models.gamma_gamma import ( diff --git a/pymc_marketing/clv/models/basic.py b/pymc_marketing/clv/models/basic.py index 83394b90..cd8ddeda 100644 --- a/pymc_marketing/clv/models/basic.py +++ b/pymc_marketing/clv/models/basic.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""CLV Model base class.""" + import json import warnings from collections.abc import Sequence @@ -32,6 +34,8 @@ class CLVModel(ModelBuilder): + """CLV Model base class.""" + _model_type = "CLVModel" @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) @@ -69,6 +73,7 @@ def _validate_cols( raise ValueError(f"Column {required_col} has duplicate entries") def __repr__(self) -> str: + """Representation of the model.""" if not hasattr(self, "model"): return self._model_type else: @@ -89,7 +94,7 @@ def fit( # type: ignore fit_method: str = "mcmc", **kwargs, ) -> az.InferenceData: - """Infer model posterior + """Infer model posterior. Parameters ---------- @@ -99,8 +104,8 @@ def fit( # type: ignore - "map": Finds maximum a posteriori via `pymc.find_MAP` kwargs: Other keyword arguments passed to the underlying PyMC routines - """ + """ self.build_model() # type: ignore if fit_method == "mcmc": @@ -120,8 +125,8 @@ def fit( # type: ignore return self.idata def _fit_mcmc(self, **kwargs) -> az.InferenceData: - """ - Fit a model using the data passed as a parameter. + """Fit a model using the data passed as a parameter. + Sets attrs to inference data of the model. @@ -138,6 +143,7 @@ def _fit_mcmc(self, **kwargs) -> az.InferenceData: ------- self : az.InferenceData returns inference data of the fitted model. + """ sampler_config = {} if self.sampler_config is not None: @@ -146,7 +152,7 @@ def _fit_mcmc(self, **kwargs) -> az.InferenceData: return pm.sample(**sampler_config, model=self.model) def _fit_MAP(self, **kwargs) -> az.InferenceData: - """Find model maximum a posteriori using scipy optimizer""" + """Find model maximum a posteriori using scipy optimizer.""" model = self.model map_res = pm.find_MAP(model=model, **kwargs) # Filter non-value variables @@ -162,8 +168,8 @@ def _fit_MAP(self, **kwargs) -> az.InferenceData: @classmethod def load(cls, fname: str): - """ - Creates a ModelBuilder instance from a file, + """Create a ModelBuilder instance from a file. + Loads inference data for the model. Parameters @@ -179,12 +185,14 @@ def load(cls, fname: str): ------ ValueError If the inference data that is loaded doesn't match with the model. + Examples -------- >>> class MyModel(ModelBuilder): >>> ... >>> name = './mymodel.nc' >>> imported_model = MyModel.load(name) + """ filepath = Path(str(fname)) idata = from_netcdf(filepath) @@ -242,6 +250,7 @@ def thin_fit_result(self, keep_every: int): @property def default_sampler_config(self) -> dict: + """Default sampler configuration.""" return {} @property @@ -250,6 +259,7 @@ def _serializable_model_config(self) -> dict: @property def fit_result(self) -> Dataset: + """Get the fit result.""" if self.idata is None or "posterior" not in self.idata: raise RuntimeError("The model hasn't been fit yet, call .fit() first") return self.idata["posterior"] @@ -265,6 +275,7 @@ def fit_result(self, res: az.InferenceData) -> None: self.idata.posterior = res def fit_summary(self, **kwargs): + """Compute the summary of the fit result.""" res = self.fit_result # Map fitting only gives one value, so we return it. We use arviz # just to get it nicely into a DataFrame @@ -278,10 +289,13 @@ def fit_summary(self, **kwargs): @property def output_var(self): + """Output variable of the model.""" pass def _generate_and_preprocess_model_data(self, *args, **kwargs): + """Generate and preprocess model data.""" pass def _data_setter(self): + """Set the data for the model.""" pass diff --git a/pymc_marketing/clv/models/beta_geo.py b/pymc_marketing/clv/models/beta_geo.py index 464134b7..ac7c1027 100644 --- a/pymc_marketing/clv/models/beta_geo.py +++ b/pymc_marketing/clv/models/beta_geo.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Beta-Geometric Negative Binomial Distribution (BG/NBD) model for a non-contractual customer population across continuous time.""" # noqa: E501 + import warnings from collections.abc import Sequence @@ -31,8 +33,9 @@ class BetaGeoModel(CLVModel): - r"""Beta-Geometric Negative Binomial Distribution (BG/NBD) model for a non-contractual customer population - across continuous time. First introduced by Fader, Hardie & Lee [1]_, with additional predictive methods + r"""Beta-Geometric Negative Binomial Distribution (BG/NBD) model for a non-contractual customer population across continuous time. + + First introduced by Fader, Hardie & Lee [1]_, with additional predictive methods and enhancements in [2]_ and [3]_. The BG/NBD model assumes dropout probabilities for the customer population are Beta distributed, @@ -133,7 +136,8 @@ class BetaGeoModel(CLVModel): P (alive) using the BG/NBD model." http://www.brucehardie.com/notes/021/palive_for_BGNBD.pdf. .. [3] Fader, P. S. & Hardie, B. G. (2013) "Overcoming the BG/NBD Model's #NUM! Error Problem." http://brucehardie.com/notes/027/bgnbd_num_error.pdf. - """ + + """ # noqa: E501 _model_type = "BG/NBD" # Beta-Geometric Negative Binomial Distribution @@ -156,6 +160,7 @@ def __init__( @property def default_model_config(self) -> ModelConfig: + """Default model configuration.""" return { "a_prior": Prior("HalfFlat"), "b_prior": Prior("HalfFlat"), @@ -164,6 +169,7 @@ def default_model_config(self) -> ModelConfig: } def build_model(self) -> None: # type: ignore[override] + """Build the model.""" coords = {"customer_id": self.data["customer_id"]} with pm.Model(coords=coords) as self.model: a = self.model_config["a_prior"].create_variable("a") @@ -173,6 +179,8 @@ def build_model(self) -> None: # type: ignore[override] def logp(t_x, x, a, b, r, alpha, T): """ + Compute the log-likelihood of the BG/NBD model. + The log-likelihood expression here aligns with expression (4) from [3] due to the possible numerical instability of expression (3). """ @@ -231,8 +239,10 @@ def _extract_predictive_variables( data: pd.DataFrame, customer_varnames: Sequence[str] = (), ) -> xarray.Dataset: - """Utility function assigning default customer arguments - for predictive methods and converting to xarrays. + """ + Extract predictive variables from the data. + + Utility function assigning default customer arguments for predictive methods and converting to xarrays. """ self._validate_cols( data, @@ -273,7 +283,8 @@ def expected_num_purchases( recency: np.ndarray | pd.Series | TensorVariable, T: np.ndarray | pd.Series | TensorVariable, ) -> xarray.DataArray: - r""" + r"""Compute the expected number of purchases for a customer. + This is a deprecated method and will be removed in a future release. Please use `BetaGeoModel.expected_purchases` instead. """ @@ -317,9 +328,9 @@ def expected_purchases( *, future_t: int | np.ndarray | pd.Series | None = None, ) -> xarray.DataArray: - r""" - Predict the expected number of future purchases across *future_t* time periods given *recency*, *frequency*, - and *T* for each customer. *data* parameter is only required for out-of-sample customers. + r"""Compute the expected number of future purchases across *future_t* time periods given *recency*, *frequency*, and *T* for each customer. + + The *data* parameter is only required for out-of-sample customers. Adapted from equation (10) in [1]_, and *lifetimes* package: https://github.com/CamDavidsonPilon/lifetimes/blob/41e394923ad72b17b5da93e88cfabab43f51abe2/lifetimes/fitters/beta_geo_fitter.py#L201 @@ -342,7 +353,8 @@ def expected_purchases( "Counting Your Customers the Easy Way: An Alternative to the Pareto/NBD Model," Marketing Science, 24 (2), 275-84. https://www.brucehardie.com/papers/bgnbd_2004-04-20.pdf - """ + + """ # noqa: E501 if data is None: data = self.data @@ -380,9 +392,9 @@ def expected_probability_alive( self, data: pd.DataFrame | None = None, ) -> xarray.DataArray: - r""" - Estimate probability a customer with history *frequency*, *recency*, and *T* - is currently active. *data* parameter is only required for out-of-sample customers. + r"""Compute the probability a customer with history *frequency*, *recency*, and *T* is currently active. + + The *data* parameter is only required for out-of-sample customers. Adapted from page (2) in Bruce Hardie's notes [1]_, and *lifetimes* package: https://github.com/CamDavidsonPilon/lifetimes/blob/41e394923ad72b17b5da93e88cfabab43f51abe2/lifetimes/fitters/beta_geo_fitter.py#L260 @@ -401,6 +413,7 @@ def expected_probability_alive( ---------- .. [1] Fader, P. S., Hardie, B. G., & Lee, K. L. (2008). Computing P (alive) using the BG/NBD model. http://www.brucehardie.com/notes/021/palive_for_BGNBD.pdf. + """ if data is None: data = self.data @@ -425,7 +438,8 @@ def expected_probability_alive( ) def expected_num_purchases_new_customer(self, *args, **kwargs) -> xarray.DataArray: - """ + """Compute the expected number of purchases for a new customer. + This is a deprecated method and will be removed in a future release. Please use `BetaGeoModel.expected_purchases_new_customer` instead. """ @@ -442,8 +456,7 @@ def expected_purchases_new_customer( *, t: np.ndarray | pd.Series, ) -> xarray.DataArray: - r""" - Expected number of purchases for a new customer across *t* time periods. + r"""Compute the expected number of purchases for a new customer across *t* time periods. Adapted from equation (9) in [1]_, and `lifetimes` library: https://github.com/CamDavidsonPilon/lifetimes/blob/41e394923ad72b17b5da93e88cfabab43f51abe2/lifetimes/fitters/beta_geo_fitter.py#L328 @@ -452,12 +465,14 @@ def expected_purchases_new_customer( ---------- t : array_like Number of time periods over which to estimate purchases. + References ---------- .. [1] Fader, Peter S., Bruce G.S. Hardie, and Ka Lok Lee (2005a), "Counting Your Customers the Easy Way: An Alternative to the Pareto/NBD Model," Marketing Science, 24 (2), 275-84. http://www.brucehardie.com/notes/021/palive_for_BGNBD.pdf + """ # TODO: This is extraneous now, but needed for future covariate support. if data is None: @@ -526,6 +541,7 @@ def distribution_new_customer_dropout( ------- xarray.Dataset Dataset containing the posterior samples for the population-level dropout rate. + """ return self._distribution_new_customers( random_seed=random_seed, @@ -550,6 +566,7 @@ def distribution_new_customer_purchase_rate( ------- xarray.Dataset Dataset containing the posterior samples for the population-level purchase rate. + """ return self._distribution_new_customers( random_seed=random_seed, diff --git a/pymc_marketing/clv/models/gamma_gamma.py b/pymc_marketing/clv/models/gamma_gamma.py index b9d44c74..bd64f143 100644 --- a/pymc_marketing/clv/models/gamma_gamma.py +++ b/pymc_marketing/clv/models/gamma_gamma.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Gamma-Gamma Model for expected future monetary value.""" + import numpy as np import pandas import pymc as pm @@ -25,13 +27,14 @@ class BaseGammaGammaModel(CLVModel): + """Base class for Gamma-Gamma models.""" + def distribution_customer_spend( self, data: pandas.DataFrame, random_seed: RandomState | None = None, ) -> xarray.DataArray: - """ - Posterior distribution of mean spend values for each customer. + """Posterior distribution of mean spend values for each customer. Parameters ---------- @@ -44,8 +47,8 @@ def distribution_customer_spend( random_seed : ~RandomState, optional Optional random seed to fix sampling results. - """ + """ x = data["frequency"] z_mean = data["monetary_value"] @@ -69,7 +72,9 @@ def expected_customer_spend( self, data: pandas.DataFrame, ) -> xarray.DataArray: - """Expected future mean spend value per customer. Based on Eq 5 from [1], p.3. + """Compute the expected future mean spend value per customer. + + The computations are based on Eq 5 from [1], p.3. Adapted from: https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/lifetimes/fitters/gamma_gamma_fitter.py#L117 @@ -84,8 +89,8 @@ def expected_customer_spend( ---------- .. [1] Fader, P. S., & Hardie, B. G. (2013). "The Gamma-Gamma model of monetary value". February, 2, 1-9. https://www.brucehardie.com/notes/025/gamma_gamma.pdf - """ + """ mean_transaction_value, frequency = to_xarray( data["customer_id"], data["monetary_value"], @@ -106,8 +111,7 @@ def expected_customer_spend( def distribution_new_customer_spend( self, n: int = 1, random_seed: RandomState | None = None ) -> xarray.DataArray: - """ - Posterior distribution of mean spend values for new customers. + """Posterior distribution of mean spend values for new customers. Parameters ---------- @@ -116,6 +120,7 @@ def distribution_new_customer_spend( random_seed : ~RandomState, optional Optional random seed to fix sampling results. + """ coords = {"new_customer_id": range(n)} with pm.Model(coords=coords): @@ -133,8 +138,7 @@ def distribution_new_customer_spend( ).posterior_predictive["mean_spend"] def expected_new_customer_spend(self) -> xarray.DataArray: - """Expected mean spend value for a new customer.""" - + """Compute the expected mean spend value for a new customer.""" posterior = self.fit_result p_mean = posterior["p"] q_mean = posterior["q"] @@ -156,9 +160,10 @@ def expected_customer_lifetime_value( discount_rate: float = 0.00, time_unit: str = "D", ) -> xarray.DataArray: - """ - Compute the average lifetime value for a group of one or more customers, - and apply a discount rate for net present value estimations. + """Compute the average lifetime value for a group of one or more customers. + + In addition, it applies a discount rate for net present value estimations. + Note `future_t` is measured in months regardless of `time_unit` specified. Adapted from lifetimes package @@ -190,8 +195,8 @@ def expected_customer_lifetime_value( ------- xarray DataArray containing estimated customer lifetime values - """ + """ # Use Gamma-Gamma estimates for the expected_spend values predicted_monetary_value = self.expected_customer_spend(data=data) data.loc[:, "future_spend"] = predicted_monetary_value.mean( @@ -289,6 +294,7 @@ class GammaGammaModel(BaseGammaGammaModel): Using iso-value curves for customer base analysis”, Journal of Marketing Research, 42 (November), 415-430. https://journals.sagepub.com/doi/pdf/10.1509/jmkr.2005.42.4.415 + """ _model_type = "Gamma-Gamma Model (Mean Transactions)" @@ -310,13 +316,15 @@ def __init__( @property def default_model_config(self) -> ModelConfig: + """Default model configuration.""" return { "p_prior": Prior("HalfFlat"), "q_prior": Prior("HalfFlat"), "v_prior": Prior("HalfFlat"), } - def build_model(self): + def build_model(self) -> None: # type: ignore[override] + """Build the model.""" z_mean = pt.as_tensor_variable(self.data["monetary_value"]) x = pt.as_tensor_variable(self.data["frequency"]) @@ -425,6 +433,7 @@ class GammaGammaModelIndividual(BaseGammaGammaModel): Using iso-value curves for customer base analysis”, Journal of Marketing Research, 42 (November), 415-430. https://journals.sagepub.com/doi/pdf/10.1509/jmkr.2005.42.4.415 + """ _model_type = "Gamma-Gamma Model (Individual Transactions)" @@ -444,13 +453,15 @@ def __init__( @property def default_model_config(self) -> dict: + """Default model configuration.""" return { "p_prior": Prior("HalfFlat"), "q_prior": Prior("HalfFlat"), "v_prior": Prior("HalfFlat"), } - def build_model(self): + def build_model(self) -> None: # type: ignore[override] + """Build the model.""" z = self.data["individual_transaction_value"] coords = { diff --git a/pymc_marketing/clv/models/pareto_nbd.py b/pymc_marketing/clv/models/pareto_nbd.py index 81f8cac7..1a3aab1a 100644 --- a/pymc_marketing/clv/models/pareto_nbd.py +++ b/pymc_marketing/clv/models/pareto_nbd.py @@ -11,6 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +"""Pareto NBD Model.""" + import warnings from collections.abc import Sequence from typing import Literal, cast @@ -188,6 +191,7 @@ class ParetoNBDModel(CLVModel): .. [5] Fader, Peter & G. S. Hardie, Bruce (2007). "Incorporating Time-Invariant Covariates into the Pareto/NBD and BG/NBD Models". https://www.brucehardie.com/notes/019/time_invariant_covariates.pdf + """ _model_type = "Pareto/NBD" # Pareto Negative-Binomial Distribution @@ -224,6 +228,7 @@ def __init__( @property def default_model_config(self) -> ModelConfig: + """Default model configuration.""" return { "r_prior": Prior("Weibull", alpha=2, beta=1), "alpha_prior": Prior("Weibull", alpha=2, beta=10), @@ -236,6 +241,7 @@ def default_model_config(self) -> ModelConfig: } def build_model(self) -> None: # type: ignore[override] + """Build the model.""" coords = { "purchase_covariate": self.purchase_covariate_cols, "dropout_covariate": self.dropout_covariate_cols, @@ -331,8 +337,8 @@ def fit(self, fit_method: str = "map", **kwargs): # type: ignore kwargs : dict Other keyword arguments passed to the underlying PyMC routines - """ + """ mode = get_default_mode() if fit_method == "mcmc": # Include rewrite in mode @@ -360,7 +366,8 @@ def _logp( t_x: xarray.DataArray, T: xarray.DataArray, ) -> xarray.DataArray: - """ + """Log-likelihood of the Pareto/NBD model. + Utility function for using ParetoNBD log-likelihood in predictive methods. """ # Add one dummy dimension to the right of the scalar parameters, so they broadcast with the `T` vector @@ -386,7 +393,10 @@ def _extract_predictive_variables( data: pd.DataFrame, customer_varnames: Sequence[str] = (), ) -> xarray.Dataset: - """Utility function assigning default customer arguments + """ + Extract predictive variables from the data. + + Utility function assigning default customer arguments for predictive methods and converting to xarrays. """ self._validate_cols( @@ -463,6 +473,8 @@ def expected_purchases( future_t: int | np.ndarray | pd.Series | None = None, ) -> xarray.DataArray: """ + Compute expected number of future purchases. + Given *recency*, *frequency*, and *T* for an individual customer, this method predicts the expected number of future purchases across *future_t* time periods. @@ -492,6 +504,7 @@ def expected_purchases( .. [1] Fader, Peter & G. S. Hardie, Bruce (2005). "A Note on Deriving the Pareto/NBD Model and Related Expressions." http://brucehardie.com/notes/009/pareto_nbd_derivations_2005-11-05.pdf + """ if data is None: data = self.data @@ -539,8 +552,11 @@ def expected_probability_alive( future_t: int | np.ndarray | pd.Series | None = None, ) -> xarray.DataArray: """ + Compute expected probability of being alive. + Compute the probability that a customer with history *frequency*, *recency*, and *T* - is currently active. Can also estimate alive probability for *future_t* periods into the future. + is currently active. + Can also estimate alive probability for *future_t* periods into the future. Adapted from equation (18) in Bruce Hardie's notes [1]_. @@ -567,6 +583,7 @@ def expected_probability_alive( .. [1] Fader, Peter & G. S. Hardie, Bruce (2014). "Additional Results for the Pareto/NBD Model." https://www.brucehardie.com/notes/015/additional_pareto_nbd_results.pdf + """ if data is None: data = self.data @@ -610,6 +627,8 @@ def expected_purchase_probability( future_t: int | np.ndarray | pd.Series | None = None, ) -> xarray.DataArray: """ + Compute expected probability of *n_purchases* over *future_t* time periods. + Estimate probability of *n_purchases* over *future_t* time periods, given an individual customer's current *frequency*, *recency*, and *T*. @@ -648,6 +667,7 @@ def expected_purchase_probability( .. [1] Fader, Peter & G. S. Hardie, Bruce (2014). "Deriving the Conditional PMF of the Pareto/NBD Model." https://www.brucehardie.com/notes/028/pareto_nbd_conditional_pmf.pdf + """ if data is None: data = self.data @@ -789,8 +809,7 @@ def expected_purchases_new_customer( *, t: int | np.ndarray | pd.Series | None = None, ) -> xarray.DataArray: - """ - Expected number of purchases for a new customer across *t* time periods. + """Compute the expected number of purchases for a new customer across *t* time periods. In a model with covariates, if `data` is not specified, the dataset used for fitting will be used and a prediction will be computed for a *new customer* with each set of covariates. @@ -818,6 +837,7 @@ def expected_purchases_new_customer( .. [1] Fader, Peter & G. S. Hardie, Bruce (2005). "A Note on Deriving the Pareto/NBD Model and Related Expressions." http://brucehardie.com/notes/009/pareto_nbd_derivations_2005-11-05.pdf + """ if data is None: data = self.data @@ -853,8 +873,7 @@ def distribution_new_customer( "recency_frequency", ), ) -> xarray.Dataset: - """Utility function for posterior predictive sampling of dropout, purchase rate - and frequency/recency of new customers. + """Compute posterior predictive samples of dropout, purchase rate and frequency/recency of new customers. In a model with covariates, if `data` is not specified, the dataset used for fitting will be used and a prediction will be computed for a *new customer* with each set of covariates. @@ -964,6 +983,7 @@ def distribution_new_customer_dropout( ------- ~xarray.Dataset Dataset containing the posterior samples for the population-level dropout rate. + """ return self.distribution_new_customer( data=data, @@ -998,6 +1018,7 @@ def distribution_new_customer_purchase_rate( ------- ~xarray.Dataset Dataset containing the posterior samples for the population-level purchase rate. + """ return self.distribution_new_customer( data=data, @@ -1036,6 +1057,7 @@ def distribution_new_customer_recency_frequency( ------- ~xarray.Dataset Dataset containing the posterior samples for the customer population. + """ return self.distribution_new_customer( data=data, diff --git a/pymc_marketing/clv/models/shifted_beta_geo.py b/pymc_marketing/clv/models/shifted_beta_geo.py index 28e97f0c..98df7851 100644 --- a/pymc_marketing/clv/models/shifted_beta_geo.py +++ b/pymc_marketing/clv/models/shifted_beta_geo.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Shifted Beta Geometric model.""" + from collections.abc import Sequence import numpy as np @@ -25,7 +27,7 @@ class ShiftedBetaGeoModelIndividual(CLVModel): - """Shifted Beta Geometric model + """Shifted Beta Geometric model. Model for customer behavior in a discrete contractual setting. It assumes that: * At the end of each period, a customer has a probability `theta` of renewing the contract @@ -100,6 +102,7 @@ class ShiftedBetaGeoModelIndividual(CLVModel): .. [1] Fader, P. S., & Hardie, B. G. (2007). How to project customer retention. Journal of Interactive Marketing, 21(1), 76-90. https://journals.sagepub.com/doi/pdf/10.1002/dir.20074 + """ _model_type = "Shifted-Beta-Geometric Model (Individual Customers)" @@ -131,12 +134,14 @@ def __init__( @property def default_model_config(self) -> dict: + """Default model configuration.""" return { "alpha_prior": Prior("HalfFlat"), "beta_prior": Prior("HalfFlat"), } - def build_model(self): + def build_model(self) -> None: # type: ignore[override] + """Build the model.""" coords = {"customer_id": self.data["customer_id"]} with pm.Model(coords=coords) as self.model: alpha = self.model_config["alpha_prior"].create_variable("alpha") @@ -164,7 +169,6 @@ def distribution_customer_churn_time( It ignores that some customers may have already cancelled. """ - coords = {"customer_id": customer_id} with pm.Model(coords=coords): alpha = pm.HalfFlat("alpha") diff --git a/pymc_marketing/clv/plotting.py b/pymc_marketing/clv/plotting.py index bb54b898..650af987 100644 --- a/pymc_marketing/clv/plotting.py +++ b/pymc_marketing/clv/plotting.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Plotting functions for the CLV module.""" + from collections.abc import Sequence import matplotlib.pyplot as plt @@ -181,9 +183,7 @@ def plot_frequency_recency_matrix( ax: plt.Axes | None = None, **kwargs, ) -> plt.Axes: - """ - Plot expected transactions in *future_t* time periods as a heatmap - based on customer population *frequency* and *recency*. + """Plot expected transactions in *future_t* time periods as a heatmap based on customer population *frequency* and *recency*. Parameters ---------- @@ -209,7 +209,8 @@ def plot_frequency_recency_matrix( Returns ------- axes: matplotlib.AxesSubplot - """ + + """ # noqa: E501 if max_frequency is None: max_frequency = int(model.data["frequency"].max()) @@ -277,8 +278,7 @@ def plot_probability_alive_matrix( ax: plt.Axes | None = None, **kwargs, ) -> plt.Axes: - """ - Plot probability alive matrix as a heatmap based on customer population *frequency* and *recency*. + """Plot probability alive matrix as a heatmap based on customer population *frequency* and *recency*. Parameters ---------- @@ -302,8 +302,8 @@ def plot_probability_alive_matrix( Returns ------- axes: matplotlib.AxesSubplot - """ + """ if max_frequency is None: max_frequency = int(model.data["frequency"].max()) diff --git a/pymc_marketing/clv/utils.py b/pymc_marketing/clv/utils.py index 5bcb6bd8..d870d807 100644 --- a/pymc_marketing/clv/utils.py +++ b/pymc_marketing/clv/utils.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Utilities for the CLV module.""" + import warnings from datetime import date, datetime @@ -48,7 +50,9 @@ def customer_lifetime_value( time_unit: str = "D", ) -> xarray.DataArray: """ - Compute the average lifetime value for a group of one or more customers, + Compute customer lifetime value. + + Compute the average lifetime value for a group of one or more customers and apply a discount rate for net present value estimations. Note `future_t` is measured in months regardless of `time_unit` specified. @@ -81,13 +85,27 @@ def customer_lifetime_value( ------- xarray DataArray containing estimated customer lifetime values - """ + """ if "future_spend" not in data.columns: raise ValueError("Required column future_spend missing") def _squeeze_dims(x: xarray.DataArray): - """this utility is required for MAP-fitted model predictions to broadcast properly""" + """ + Squeeze dimensions for MAP-fitted model predictions. + + This utility is required for MAP-fitted model predictions to broadcast properly. + + Parameters + ---------- + x : xarray.DataArray + DataArray to squeeze dimensions for. + + Returns + ------- + xarray.DataArray + DataArray with squeezed dimensions. + """ dims_to_squeeze: tuple[str, ...] = () if "chain" in x.dims and len(x.chain) == 1: dims_to_squeeze += ("chain",) @@ -148,8 +166,7 @@ def _find_first_transactions( time_unit: str = "D", sort_transactions: bool | None = True, ) -> pandas.DataFrame: - """ - Return dataframe with first transactions. + """Return dataframe with first transactions. This takes a DataFrame of transaction data of the form: *customer_id, datetime [, monetary_value]* @@ -183,8 +200,8 @@ def _find_first_transactions( sort_transactions : bool, optional Default: True If raw data is already sorted in chronological order, set to `False` to improve computational efficiency. - """ + """ select_columns = [customer_id_col, datetime_col] if observation_period_end is None: @@ -264,8 +281,7 @@ def rfm_summary( include_first_transaction: bool | None = False, sort_transactions: bool | None = True, ) -> pandas.DataFrame: - """ - Summarize transaction data for use in CLV modeling or RFM segmentation. + """Summarize transaction data for use in CLV modeling or RFM segmentation. This transforms a DataFrame of transaction data of the form: *customer_id, datetime [, monetary_value]* @@ -321,8 +337,8 @@ def rfm_summary( DataFrame Dataframe containing summarized RFM data, and test columns for *frequency*, *T*, and *monetary_value* if specified - """ + """ if observation_period_end is None: observation_period_end_ts = ( pandas.to_datetime(transactions[datetime_col].max(), format=datetime_format) @@ -423,8 +439,8 @@ def rfm_train_test_split( include_first_transaction: bool | None = False, sort_transactions: bool | None = True, ) -> pandas.DataFrame: - """ - Summarize transaction data and split into training and tests datasets for CLV modeling. + """Summarize transaction data and split into training and tests datasets for CLV modeling. + This can also be used to evaluate the impact of a time-based intervention like a marketing campaign. This transforms a DataFrame of transaction data of the form: @@ -476,8 +492,8 @@ def rfm_train_test_split( DataFrame Dataframe containing summarized RFM data, and test columns for *frequency*, *T*, and *monetary_value* if specified - """ + """ if test_period_end is None: test_period_end = transactions[datetime_col].max() @@ -588,8 +604,7 @@ def rfm_segments( time_scaler: float | None = 1, sort_transactions: bool | None = True, ) -> pandas.DataFrame: - """ - Assign customers to segments based on spending behavior derived from RFM scores. + """Assign customers to segments based on spending behavior derived from RFM scores. This transforms a DataFrame of transaction data of the form: *customer_id, datetime, monetary_value* @@ -650,8 +665,8 @@ def rfm_segments( ------- DataFrame Dataframe containing summarized RFM data, RFM scores, and segment assignments - """ + """ rfm_data = rfm_summary( transactions, customer_id_col=customer_id_col, @@ -718,7 +733,23 @@ def rfm_segments( def _rfm_quartile_labels(column_name, max_label_range): - """called internally by rfm_segments to label quartiles for each variable""" + """ + Label quartiles for each variable. + + Called internally by rfm_segments to label quartiles for each variable. + + Parameters + ---------- + column_name : str + The name of the column to label. + max_label_range : int + The maximum range of labels to create. + + Returns + ------- + list[int] + A list of labels for the column. + """ # recency labels must be reversed because lower values are more desirable if column_name == "r_quartile": return list(range(max_label_range - 1, 0, -1)) diff --git a/pymc_marketing/constants.py b/pymc_marketing/constants.py index ce062bcc..1a6d6d54 100644 --- a/pymc_marketing/constants.py +++ b/pymc_marketing/constants.py @@ -11,5 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Constants for the pymc_marketing package.""" + DAYS_IN_YEAR: float = 365.25 DAYS_IN_MONTH: float = DAYS_IN_YEAR / 12 diff --git a/pymc_marketing/mlflow.py b/pymc_marketing/mlflow.py index da281a64..df9e8191 100644 --- a/pymc_marketing/mlflow.py +++ b/pymc_marketing/mlflow.py @@ -189,7 +189,6 @@ def log_data(model: Model, idata: az.InferenceData) -> None: The InferenceData object returned by the sampling method. """ - data_vars: list[TensorVariable] = ( _backwards_compatiable_data_vars(model) if not hasattr(model, "data_vars") @@ -527,7 +526,6 @@ def autolog( mlflow.log_figure(fig, "components.png") """ - arviz_summary_kwargs = arviz_summary_kwargs or {} def patch_sample(sample): diff --git a/pymc_marketing/mmm/__init__.py b/pymc_marketing/mmm/__init__.py index c636272a..78c6edcc 100644 --- a/pymc_marketing/mmm/__init__.py +++ b/pymc_marketing/mmm/__init__.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Marketing Mix Models (MMM).""" + from pymc_marketing.mmm import base, delayed_saturated_mmm, preprocessing, validating from pymc_marketing.mmm.base import BaseValidateMMM, MMMModelBuilder from pymc_marketing.mmm.components.adstock import ( diff --git a/pymc_marketing/mmm/base.py b/pymc_marketing/mmm/base.py index d5f2497e..bcc59262 100644 --- a/pymc_marketing/mmm/base.py +++ b/pymc_marketing/mmm/base.py @@ -50,6 +50,8 @@ class MMMModelBuilder(ModelBuilder): + """Base class for Marketing Mix Models (MMM).""" + model: pm.Model _model_type = "BaseMMM" version = "0.0.2" @@ -81,6 +83,7 @@ def __init__( @property def methods(self) -> list[Any]: + """Get all methods of the object.""" maybe_methods = [getattr_static(self, attr) for attr in dir(self)] return [ method @@ -105,8 +108,7 @@ def validation_methods( Callable[["MMMModelBuilder", pd.DataFrame | pd.Series | np.ndarray], None] ], ]: - """ - A property that provides validation methods for features ("X") and the target variable ("y"). + """A property that provides validation methods for features ("X") and the target variable ("y"). This property scans the methods of the object and returns those marked for validation. The methods are marked by having a _tags dictionary attribute,with either "validation_X" or "validation_y" @@ -136,8 +138,7 @@ def validation_methods( def validate( self, target: str, data: pd.DataFrame | pd.Series | np.ndarray ) -> None: - """ - Validates the input data based on the specified target type. + """Validate the input data based on the specified target type. This function loops over the validation methods specified for the target type and applies them to the input data. @@ -154,6 +155,7 @@ def validate( ------ ValueError If the target type is not "X" or "y", a ValueError will be raised. + """ if target not in ["X", "y"]: raise ValueError("Target must be either 'X' or 'y'") @@ -182,8 +184,7 @@ def preprocessing_methods( ] ], ]: - """ - A property that provides preprocessing methods for features ("X") and the target variable ("y"). + """A property that provides preprocessing methods for features ("X") and the target variable ("y"). This property scans the methods of the object and returns those marked for preprocessing. The methods are marked by having a _tags dictionary attribute, with either "preprocessing_X" @@ -195,6 +196,7 @@ def preprocessing_methods( tuple of list of Callable[["MMMModelBuilder", pd.DataFrame], pd.DataFrame] A tuple where the first element is a list of methods for "X" preprocessing, and the second element is a list of methods for "y" preprocessing. + """ return ( [ @@ -212,8 +214,7 @@ def preprocessing_methods( def preprocess( self, target: str, data: pd.DataFrame | pd.Series | np.ndarray ) -> pd.DataFrame | pd.Series | np.ndarray: - """ - Preprocess the provided data according to the specified target. + """Preprocess the provided data according to the specified target. This method applies preprocessing methods to the data ("X" or "y"), which are specified in the preprocessing_methods property of this object. It iteratively applies each method in the appropriate @@ -241,6 +242,7 @@ def preprocess( ------- >>> data = pd.DataFrame({"x1": [1, 2, 3], "y": [4, 5, 6]}) >>> self.preprocess("X", data) + """ data_cp = data.copy() if target == "X": @@ -259,6 +261,7 @@ def get_target_transformer(self) -> Pipeline: Returns ------- Pipeline + """ try: return self.target_transformer # type: ignore @@ -268,6 +271,7 @@ def get_target_transformer(self) -> Pipeline: @property def prior(self) -> Dataset: + """Get the prior data.""" if self.idata is None or "prior" not in self.idata: raise RuntimeError( "The model hasn't been sampled yet, call .sample_prior_predictive() first" @@ -276,6 +280,7 @@ def prior(self) -> Dataset: @property def prior_predictive(self) -> Dataset: + """Get the prior predictive data.""" if self.idata is None or "prior_predictive" not in self.idata: raise RuntimeError( "The model hasn't been sampled yet, call .sample_prior_predictive() first" @@ -284,12 +289,14 @@ def prior_predictive(self) -> Dataset: @property def fit_result(self) -> Dataset: + """Get the posterior data.""" if self.idata is None or "posterior" not in self.idata: raise RuntimeError("The model hasn't been fit yet, call .fit() first") return self.idata["posterior"] @property def posterior_predictive(self) -> Dataset: + """Get the posterior predictive data.""" if self.idata is None or "posterior_predictive" not in self.idata: raise RuntimeError( "The model hasn't been fit yet, call .sample_posterior_predictive() first" @@ -297,6 +304,18 @@ def posterior_predictive(self) -> Dataset: return self.idata["posterior_predictive"] def plot_prior_predictive(self, **plt_kwargs: Any) -> plt.Figure: + """Plot the prior predictive data. + + Parameters + ---------- + **plt_kwargs + Keyword arguments passed to `plt.subplots`. + + Returns + ------- + plt.Figure + + """ prior_predictive_data: az.InferenceData = self.prior_predictive likelihood_hdi_94: DataArray = az.hdi(ary=prior_predictive_data, hdi_prob=0.94)[ @@ -357,6 +376,7 @@ def plot_posterior_predictive( Returns ------- plt.Figure + """ try: posterior_predictive_data: Dataset = self.posterior_predictive @@ -431,6 +451,7 @@ def get_errors(self, original_scale: bool = False) -> DataArray: Returns ------- DataArray + """ try: posterior_predictive_data: Dataset = self.posterior_predictive @@ -491,6 +512,7 @@ def plot_errors( Returns ------- plt.Figure + """ errors = self.get_errors(original_scale=original_scale) @@ -539,8 +561,10 @@ def _format_model_contributions(self, var_contribution: str) -> DataArray: return contributions.sum(contracted_dims) if contracted_dims else contributions def plot_components_contributions(self, **plt_kwargs: Any) -> plt.Figure: - """Plot the target variable and the posterior predictive model components in - the scaled space. + """Plot the target variable and the posterior predictive model components. + + We can plot the target variable and the posterior predictive model components in + the scaled space or the original space. **plt_kwargs Additional keyword arguments to pass to `plt.subplots`. @@ -548,6 +572,7 @@ def plot_components_contributions(self, **plt_kwargs: Any) -> plt.Figure: Returns ------- plt.Figure + """ channel_contributions = self._format_model_contributions( var_contribution="channel_contributions" @@ -648,6 +673,7 @@ def compute_channel_contribution_original_scale(self) -> DataArray: Returns ------- DataArray + """ channel_contribution = az.extract( data=self.fit_result, var_names=["channel_contributions"], combined=False @@ -684,6 +710,7 @@ def compute_mean_contributions_over_time( ------- pd.DataFrame A dataframe with the mean contributions of each channel and control variables over time. + """ contributions_channel_over_time = ( az.extract( @@ -801,8 +828,8 @@ def plot_grouped_contribution_breakdown_over_time( ------- plt.Figure Matplotlib figure with the plot. - """ + """ all_contributions_over_time = self.compute_mean_contributions_over_time( original_scale=original_scale ) @@ -850,6 +877,7 @@ def plot_channel_contribution_share_hdi( Returns ------- plt.Figure + """ channel_contributions_share: DataArray = ( self._get_channel_contributions_share_samples() @@ -867,11 +895,23 @@ def plot_channel_contribution_share_hdi( return fig def graphviz(self, **kwargs): + """Get the graphviz representation of the model. + + Parameters + ---------- + **kwargs + Keyword arguments for the `pm.model_to_graphviz` function + + Returns + ------- + graphviz.Digraph + + """ return pm.model_to_graphviz(self.model, **kwargs) def _process_decomposition_components(self, data: pd.DataFrame) -> pd.DataFrame: - """ - Process data to compute the sum of contributions by component and calculate their percentages. + """Process data to compute the sum of contributions by component and calculate their percentages. + The output dataframe will have columns for "component", "contribution", and "percentage". Parameters @@ -884,8 +924,8 @@ def _process_decomposition_components(self, data: pd.DataFrame) -> pd.DataFrame: pd.DataFrame A dataframe with contributions summed up by component, sorted by contribution in ascending order. With an additional column showing the percentage contribution of each component. - """ + """ dataframe = data.copy() stack_dataframe = dataframe.stack().reset_index() stack_dataframe.columns = pd.Index(["date", "component", "contribution"]) @@ -905,8 +945,9 @@ def plot_waterfall_components_decomposition( figsize: tuple[int, int] = (14, 7), **kwargs, ) -> plt.Figure: - """ - This function creates a waterfall plot. The plot shows the decomposition of the target into its components. + """Create a waterfall plot. + + The plot shows the decomposition of the target into its components. Parameters ---------- @@ -921,8 +962,8 @@ def plot_waterfall_components_decomposition( ------- fig : matplotlib.figure.Figure The matplotlib figure object. - """ + """ dataframe = self.compute_mean_contributions_over_time( original_scale=original_scale ) diff --git a/pymc_marketing/mmm/budget_optimizer.py b/pymc_marketing/mmm/budget_optimizer.py index 96434e97..d90bebe1 100644 --- a/pymc_marketing/mmm/budget_optimizer.py +++ b/pymc_marketing/mmm/budget_optimizer.py @@ -32,8 +32,7 @@ def __init__(self, message: str): class BudgetOptimizer(BaseModel): - """ - A class for optimizing budget allocation in a marketing mix model. + """A class for optimizing budget allocation in a marketing mix model. The goal of this optimization is to maximize the total expected response by allocating the given budget across different marketing channels. The @@ -57,6 +56,7 @@ class BudgetOptimizer(BaseModel): adstock_first : bool, optional Whether to apply adstock transformation first or saturation transformation first. Default is True. + """ adstock: AdstockTransformation = Field( @@ -83,9 +83,9 @@ class BudgetOptimizer(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) def objective(self, budgets: list[float]) -> float: - """ - Calculate the total response during a period of time given the budgets, - considering the saturation and adstock transformations. + """Calculate the total response during a period of time given the budgets. + + It considers the saturation and adstock transformations. Parameters ---------- @@ -96,6 +96,7 @@ def objective(self, budgets: list[float]) -> float: ------- float The negative total response value. + """ total_response = 0 first_transform, second_transform = ( @@ -131,8 +132,7 @@ def allocate_budget( custom_constraints: dict[Any, Any] | None = None, minimize_kwargs: dict[str, Any] | None = None, ) -> tuple[dict[str, float], float]: - """ - Allocate the budget based on the total budget, budget bounds, and custom constraints. + """Allocate the budget based on the total budget, budget bounds, and custom constraints. The default budget bounds are (0, total_budget) for each channel. @@ -167,6 +167,7 @@ def allocate_budget( ------ Exception If the optimization fails, an exception is raised with the reason for the failure. + """ if budget_bounds is None: budget_bounds = {channel: (0, total_budget) for channel in self.parameters} diff --git a/pymc_marketing/mmm/components/__init__.py b/pymc_marketing/mmm/components/__init__.py index 3bd01a8d..3bd2a04d 100644 --- a/pymc_marketing/mmm/components/__init__.py +++ b/pymc_marketing/mmm/components/__init__.py @@ -42,4 +42,5 @@ def function(self, x, b): adstock=adstock, adstock_first=True, ) + """ diff --git a/pymc_marketing/mmm/components/adstock.py b/pymc_marketing/mmm/components/adstock.py index 81756d1d..af1ae250 100644 --- a/pymc_marketing/mmm/components/adstock.py +++ b/pymc_marketing/mmm/components/adstock.py @@ -106,6 +106,7 @@ def __init__( super().__init__(priors=priors, prefix=prefix) def __repr__(self) -> str: + """Representation of the adstock transformation.""" return ( f"{self.__class__.__name__}(" f"prefix={self.prefix!r}, " @@ -146,7 +147,6 @@ def sample_curve( Adstocked version of the amount. """ - time_since = np.arange(0, self.l_max) coords = { "time since exposure": time_since, @@ -187,6 +187,7 @@ class GeometricAdstock(AdstockTransformation): lookup_name = "geometric" def function(self, x, alpha): + """Geometric adstock function.""" return geometric_adstock( x, alpha=alpha, l_max=self.l_max, normalize=self.normalize, mode=self.mode ) @@ -219,6 +220,7 @@ class DelayedAdstock(AdstockTransformation): lookup_name = "delayed" def function(self, x, alpha, theta): + """Delayed adstock function.""" return delayed_adstock( x, alpha=alpha, @@ -259,6 +261,7 @@ class WeibullPDFAdstock(AdstockTransformation): lookup_name = "weibull_pdf" def function(self, x, lam, k): + """Weibull adstock function.""" return weibull_adstock( x=x, lam=lam, @@ -300,6 +303,7 @@ class WeibullCDFAdstock(AdstockTransformation): lookup_name = "weibull_cdf" def function(self, x, lam, k): + """Weibull adstock function.""" return weibull_adstock( x=x, lam=lam, @@ -366,6 +370,7 @@ def __init__( ) def function(self, x, lam, k): + """Weibull adstock function.""" return weibull_adstock( x=x, lam=lam, @@ -414,7 +419,10 @@ def _get_adstock_function( function: str | AdstockTransformation, **kwargs, ) -> AdstockTransformation: - """Helper for use in the MMM to get an adstock function.""" + """Get an adstock function. + + Helper for use in the MMM to get an adstock function from the if registered. + """ if isinstance(function, AdstockTransformation): return function diff --git a/pymc_marketing/mmm/components/base.py b/pymc_marketing/mmm/components/base.py index cc23a6e2..74b8142b 100644 --- a/pymc_marketing/mmm/components/base.py +++ b/pymc_marketing/mmm/components/base.py @@ -120,6 +120,7 @@ def __init__( self.prefix = prefix or self.prefix def __repr__(self) -> str: + """Representation of the transformation.""" return ( f"{self.__class__.__name__}(" f"prefix={self.prefix!r}, " @@ -145,6 +146,7 @@ def to_dict(self) -> dict[str, Any]: } def __eq__(self, other: Any) -> bool: + """Check if two transformations are equal.""" if not isinstance(other, self.__class__): return False @@ -152,6 +154,7 @@ def __eq__(self, other: Any) -> bool: @property def function_priors(self) -> dict[str, Prior]: + """Get the priors for the function.""" return self._function_priors @function_priors.setter @@ -162,7 +165,7 @@ def function_priors(self, priors: dict[str, Any | Prior] | None) -> None: self._function_priors = {**deepcopy(self.default_priors), **priors} def update_priors(self, priors: dict[str, Prior]) -> None: - """Helper to update the priors for a function after initialization. + """Update the priors for a function after initialization. Uses {prefix}_{parameter_name} as the key for the priors instead of the parameter name in order to be used in the larger MMM. @@ -477,7 +480,7 @@ def plot_curve_hdi( ) def apply(self, x: pt.TensorLike, dims: Dims | None = None) -> pt.TensorVariable: - """Called within a model context. + """Call within a model context. Used internally of the MMM to apply the transformation to the data. diff --git a/pymc_marketing/mmm/components/saturation.py b/pymc_marketing/mmm/components/saturation.py index 767b4d64..2c8ba409 100644 --- a/pymc_marketing/mmm/components/saturation.py +++ b/pymc_marketing/mmm/components/saturation.py @@ -100,7 +100,7 @@ class SaturationTransformation(Transformation): By subclassing from this method, lift test integration will come for free! Examples - ---------- + -------- Make a non-saturating saturation transformation .. code-block:: python @@ -199,6 +199,7 @@ class LogisticSaturation(SaturationTransformation): lookup_name = "logistic" def function(self, x, lam, beta): + """Logistic saturation function.""" return beta * logistic_saturation(x, lam) default_priors = { @@ -232,6 +233,7 @@ class InverseScaledLogisticSaturation(SaturationTransformation): lookup_name = "inverse_scaled_logistic" def function(self, x, lam, beta): + """Inverse scaled logistic saturation function.""" return beta * inverse_scaled_logistic_saturation(x, lam) default_priors = { @@ -265,6 +267,7 @@ class TanhSaturation(SaturationTransformation): lookup_name = "tanh" def function(self, x, b, c, beta): + """Tanh saturation function.""" return beta * tanh_saturation(x, b, c) default_priors = { @@ -299,6 +302,7 @@ class TanhSaturationBaselined(SaturationTransformation): lookup_name = "tanh_baselined" def function(self, x, x0, gain, r, beta): + """Tanh saturation function.""" return beta * tanh_saturation_baselined(x, x0, gain, r) default_priors = { @@ -366,6 +370,7 @@ class HillSaturation(SaturationTransformation): lookup_name = "hill" def function(self, x, slope, kappa, beta): + """Hill saturation function.""" return beta * hill_function(x, slope, kappa) default_priors = { @@ -433,6 +438,7 @@ class RootSaturation(SaturationTransformation): lookup_name = "root" def function(self, x, alpha, beta): + """Root saturation function.""" return beta * root_saturation(x, alpha) default_priors = { @@ -457,11 +463,15 @@ def function(self, x, alpha, beta): def register_saturation_transformation(cls: type[SaturationTransformation]) -> None: - """Register a new saturation transformation.""" + """Register a new saturation transformation. + + Helper for use in the MMM to register a new saturation function. + """ SATURATION_TRANSFORMATIONS[cls.lookup_name] = cls def saturation_from_dict(data: dict) -> SaturationTransformation: + """Get a saturation function from a dictionary.""" data = data.copy() cls = SATURATION_TRANSFORMATIONS[data.pop("lookup_name")] @@ -475,7 +485,11 @@ def saturation_from_dict(data: dict) -> SaturationTransformation: def _get_saturation_function( function: str | SaturationTransformation, ) -> SaturationTransformation: - """Helper for use in the MMM to get a saturation function.""" + """ + Get a saturation function. + + Helper for use in the MMM to get a saturation function. + """ if isinstance(function, SaturationTransformation): return function diff --git a/pymc_marketing/mmm/delayed_saturated_mmm.py b/pymc_marketing/mmm/delayed_saturated_mmm.py index a3258ecf..76a66393 100644 --- a/pymc_marketing/mmm/delayed_saturated_mmm.py +++ b/pymc_marketing/mmm/delayed_saturated_mmm.py @@ -15,7 +15,7 @@ import json import warnings -from typing import Annotated, Any +from typing import Annotated, Any, Literal import arviz as az import matplotlib.pyplot as plt @@ -62,12 +62,12 @@ class BaseMMM(BaseValidateMMM): - """ - Base class for a media mix model using Delayed Adstock and Logistic Saturation (see [1]_). + """Base class for a media mix model using Delayed Adstock and Logistic Saturation (see [1]_). References ---------- .. [1] Jin, Yuxue, et al. “Bayesian methods for media mix modeling with carryover and shape effects.” (2017). + """ _model_name: str = "BaseMMM" @@ -125,7 +125,7 @@ def __init__( True, description="Whether to apply adstock first." ), ) -> None: - """Constructor method. + """Define the constructor method. Parameter --------- @@ -217,17 +217,30 @@ def __init__( @property def default_sampler_config(self) -> dict: + """Default sampler configuration for the model. + + Returns + ------- + dict + Empty dictionary. + """ return {} @property - def output_var(self): - """Defines target variable for the model""" + def output_var(self) -> Literal["y"]: + """Define target variable for the model. + + Returns + ------- + str + The target variable for the model. + """ return "y" def _generate_and_preprocess_model_data( # type: ignore self, X: pd.DataFrame | pd.Series, y: pd.Series | np.ndarray ) -> None: - """Applies preprocessing to the data before fitting the model. + """Apply preprocessing to the data before fitting the model. If validate is True, it will check if the data is valid for the model. sets self.model_coords based on provided dataset @@ -254,6 +267,7 @@ def _generate_and_preprocess_model_data( # type: ignore The middle index of the date index. Used by TVP. _time_resolution: int The time resolution of the date index. Used by TVP. + """ try: date_data = pd.to_datetime(X[self.date_column]) @@ -299,6 +313,14 @@ def _generate_and_preprocess_model_data( # type: ignore ).days def create_idata_attrs(self) -> dict[str, str]: + """Create attributes for the inference data. + + Returns + ------- + dict[str, str] + The attributes for the inference data. + + """ attrs = super().create_idata_attrs() attrs["date_column"] = json.dumps(self.date_column) attrs["adstock"] = json.dumps(self.adstock.to_dict()) @@ -317,7 +339,7 @@ def create_idata_attrs(self) -> dict[str, str]: def forward_pass( self, x: pt.TensorVariable | npt.NDArray[np.float64] ) -> pt.TensorVariable: - """Transforms channel input into target contributions of each channel. + """Transform channel input into target contributions of each channel. This method handles the ordering of the adstock and saturation transformations. @@ -327,13 +349,14 @@ def forward_pass( associated with the number of columns of `x`. Parameters - ------------ + ---------- x : pt.TensorVariable | npt.NDArray[np.float64] The channel input which could be spends or impressions Returns - -------- + ------- The contributions associated with the channel input + """ first, second = ( (self.adstock, self.saturation) @@ -349,8 +372,7 @@ def build_model( y: pd.Series | np.ndarray, **kwargs, ) -> None: - """ - Builds a probabilistic model using PyMC for marketing mix modeling. + """Build a probabilistic model using PyMC for marketing mix modeling. The model incorporates channels, control variables, and Fourier components, applying adstock and saturation transformations to the channel data. The final model is @@ -411,7 +433,6 @@ def build_model( ) """ - self._generate_and_preprocess_model_data(X, y) with pm.Model( coords=self.model_coords, @@ -553,6 +574,7 @@ def create_deterministic(x: pt.TensorVariable) -> None: @property def default_model_config(self) -> dict: + """Define the default model configuration.""" base_config = { "intercept": Prior("Normal", mu=0, sigma=2), "likelihood": Prior("Normal", sigma=Prior("HalfNormal", sigma=2)), @@ -604,6 +626,7 @@ def channel_contributions_forward_pass( ------- array-like Transformed channel data. + """ coords = { **self.model_coords, @@ -639,6 +662,14 @@ def ndarray_to_list(d: dict) -> dict: @classmethod def attrs_to_init_kwargs(cls, attrs) -> dict[str, Any]: + """Convert attributes to initialization kwargs. + + Returns + ------- + dict[str, Any] + The initialization kwargs. + + """ return { "model_config": cls._model_config_formatting( json.loads(attrs["model_config"]) @@ -664,8 +695,7 @@ def _data_setter( X: np.ndarray | pd.DataFrame, y: np.ndarray | pd.Series | None = None, ) -> None: - """ - Sets new data in the model. + """Set new data in the model. This function accepts data in various formats and sets them into the model using the PyMC's `set_data` method. The data corresponds to the @@ -693,6 +723,7 @@ def _data_setter( Returns ------- None + """ if not isinstance(X, pd.DataFrame): msg = "X must be a pandas DataFrame in order to access the columns" @@ -753,10 +784,22 @@ def identity(x): @classmethod def _model_config_formatting(cls, model_config: dict) -> dict: - """ + """Format the model configuration. + Because of json serialization, model_config values that were originally tuples or numpy are being encoded as lists. This function converts them back to tuples and numpy arrays to ensure correct id encoding. + + Parameters + ---------- + model_config : dict + The model configuration to format. + + Returns + ------- + dict + The formatted model configuration. + """ def format_nested_dict(d: dict) -> dict: @@ -781,8 +824,7 @@ class MMM( ValidateControlColumns, BaseMMM, ): - """ - Media Mix Model class, Delayed Adstock and logistic saturation as default initialization (see [1]_). + r"""Media Mix Model class, Delayed Adstock and logistic saturation as default initialization (see [1]_). Given a time series target variable :math:`y_{t}` (e.g. sales on conversions), media variables :math:`x_{m, t}` (e.g. impressions, clicks or costs) and a set of control covariates :math:`z_{c, t}` (e.g. holidays, special events) @@ -908,6 +950,7 @@ class MMM( ---------- .. [1] Jin, Yuxue, et al. “Bayesian methods for media mix modeling with carryover and shape effects.” (2017). .. [2] Orduz, J. `"Media Effect Estimation with PyMC: Adstock, Saturation & Diminishing Returns" `_. + """ # noqa: E501 _model_type: str = "MMM" @@ -917,16 +960,19 @@ def channel_contributions_forward_pass( self, channel_data: npt.NDArray[np.float64] ) -> npt.NDArray[np.float64]: """Evaluate the channel contribution for a given channel data and a fitted model, ie. the forward pass. + We return the contribution in the original scale of the target variable. Parameters ---------- channel_data : array-like Input channel data. Result of all the preprocessing steps. + Returns ------- array-like Transformed channel data. + """ channel_contribution_forward_pass = super().channel_contributions_forward_pass( channel_data=channel_data @@ -951,10 +997,12 @@ def get_channel_contributions_forward_pass_grid( End of the grid. It must be greater than start. num : int Number of points in the grid. + Returns ------- DataArray Grid of channel contributions. + """ if start < 0: raise ValueError("start must be greater than or equal to 0.") @@ -981,26 +1029,26 @@ def get_channel_contributions_forward_pass_grid( ) def plot_channel_parameter(self, param_name: str, **plt_kwargs: Any) -> plt.Figure: - """ - Plot the posterior distribution of a specific parameter for each channel. + """Plot the posterior distribution of a specific parameter for each channel. - Parameters: + Parameters ---------- param_name : str The name of the parameter to plot. **plt_kwargs : Any Additional keyword arguments to pass to the `plt.subplots` function. - Returns: + Returns ------- plt.Figure The matplotlib Figure object containing the plot. - Raises: + Raises ------ ValueError If the specified parameter name is invalid or not found in the model saturation or adstock function. + """ saturation: SaturationTransformation = self.saturation adstock: AdstockTransformation = self.adstock @@ -1044,6 +1092,7 @@ def get_ts_contribution_posterior( ------- DataArray The posterior distribution of the time series contributions. + """ contributions = self._format_model_contributions( var_contribution=var_contribution @@ -1061,8 +1110,10 @@ def get_ts_contribution_posterior( def plot_components_contributions( self, original_scale: bool = False, **plt_kwargs: Any ) -> plt.Figure: - """Plot the target variable and the posterior predictive model components in - the scaled space. + """Plot the target variable and the posterior predictive model components. + + We can plot the target variable and the posterior predictive model components in + the scaled space or in the original space. Parameters ---------- @@ -1075,6 +1126,7 @@ def plot_components_contributions( Returns ------- plt.Figure + """ channel_contributions = self.get_ts_contribution_posterior( var_contribution="channel_contributions", original_scale=original_scale @@ -1197,7 +1249,7 @@ def plot_channel_contributions_grid( absolute_xrange: bool = False, **plt_kwargs: Any, ) -> plt.Figure: - """Plots a grid of scaled channel contributions for a given grid of share values. + """Plot a grid of scaled channel contributions for a given grid of share values. Parameters ---------- @@ -1217,6 +1269,7 @@ def plot_channel_contributions_grid( ------- plt.Figure Plot of grid of channel contributions. + """ share_grid = np.linspace(start=start, stop=stop, num=num) contributions = self.get_channel_contributions_forward_pass_grid( @@ -1336,6 +1389,7 @@ def new_spend_contributions( spend = np.ones(n_channels) spend_leading_up = np.ones(n_channels) new_spend_contributions = model.new_spend_contributions(spend=spend, spend_leading_up=spend_leading_up) + """ if spend is None: spend = self.X.loc[:, self.channel_columns].mean().to_numpy() # type: ignore @@ -1502,8 +1556,7 @@ def _channel_map_scales(self) -> dict: def format_recovered_transformation_parameters( self, quantile: float = 0.5 ) -> dict[str, dict[str, dict[str, float]]]: - """ - Format the recovered transformation parameters for each channel. + """Format the recovered transformation parameters for each channel. This function retrieves the quantile of the parameters for each channel and formats them into a dictionary containing the channel name, the saturation parameters, and the adstock parameters. @@ -1543,6 +1596,7 @@ def format_recovered_transformation_parameters( } } } + """ # Retrieve channel names channels = self.fit_result.channel.values @@ -1589,8 +1643,7 @@ def _plot_response_curve_fit( quantile_lower: float = 0.05, quantile_upper: float = 0.95, ) -> None: - """ - Plot the curve fit for the given channel based on the estimation of the parameters by the model. + """Plot the curve fit for the given channel based on the estimation of the parameters by the model. Parameters ---------- @@ -1613,8 +1666,8 @@ def _plot_response_curve_fit( ------- None The function modifies the given axes object in-place and doesn't return any object. - """ + """ if self.X is not None: x_mean = np.max(self.X[channel]) @@ -1682,10 +1735,10 @@ def plot_direct_contribution_curves( quantile_lower: float = 0.05, quantile_upper: float = 0.95, ) -> plt.Figure: - """ - Plots the direct contribution curves for each marketing channel. The term "direct" refers to the fact - we plot costs vs immediate returns and we do not take into account the lagged - effects of the channels e.g. adstock transformations. + """Plot the direct contribution curves for each marketing channel. + + The term "direct" refers to the fact that we plot costs vs immediate returns and + we do not take into account the lagged effects of the channels e.g. adstock transformations. Parameters ---------- @@ -1702,6 +1755,7 @@ def plot_direct_contribution_curves( ------- plt.Figure A matplotlib Figure object with the direct contribution curves. + """ channels_to_plot = self.channel_columns if channels is None else channels @@ -1795,9 +1849,8 @@ def sample_posterior_predictive( include_last_observations: bool = False, original_scale: bool = True, **sample_posterior_predictive_kwargs, - ): - """ - Sample from the model's posterior predictive distribution. + ) -> DataArray: + """Sample from the model's posterior predictive distribution. Parameters ---------- @@ -1821,6 +1874,7 @@ def sample_posterior_predictive( ------- posterior_predictive_samples : DataArray, shape (n_pred, samples) Posterior predictive samples for each input X_pred + """ if include_last_observations: X_pred = pd.concat( @@ -1992,8 +2046,7 @@ def _create_synth_dataset( lag: int, noise_level: float = 0.01, ) -> pd.DataFrame: - """ - Create a synthetic dataset based on the given allocation strategy (Budget) and time granularity. + """Create a synthetic dataset based on the given allocation strategy (Budget) and time granularity. Parameters ---------- @@ -2027,6 +2080,7 @@ def _create_synth_dataset( ------ ValueError If the time granularity is not supported. + """ time_offsets = { "daily": {"days": 1}, @@ -2089,8 +2143,7 @@ def allocate_budget_to_maximize_response( quantile: float = 0.5, noise_level: float = 0.01, ) -> az.InferenceData: - """ - Allocate the given budget to maximize the response over a specified time period. + """Allocate the given budget to maximize the response over a specified time period. This function optimizes the allocation of a given budget across different channels to maximize the response, considering adstock and saturation effects. It scales the @@ -2127,6 +2180,7 @@ def allocate_budget_to_maximize_response( ------ ValueError If the time granularity is not supported. + """ parameters_mid = self.format_recovered_transformation_parameters( quantile=quantile @@ -2176,8 +2230,7 @@ def plot_budget_allocation( ax: plt.Axes | None = None, original_scale: bool = True, ) -> tuple[plt.Figure, plt.Axes]: - """ - Plot the budget allocation and channel contributions. + """Plot the budget allocation and channel contributions. Parameters ---------- @@ -2195,8 +2248,8 @@ def plot_budget_allocation( ------- Tuple[plt.Figure, plt.Axes] The matplotlib figure object and axis containing the plot. - """ + """ if original_scale: channel_contributions = ( samples["channel_contributions"] @@ -2275,8 +2328,7 @@ def plot_allocated_contribution_by_channel( upper_quantile: float = 0.975, original_scale: bool = True, ) -> plt.Figure: - """ - Plot the allocated contribution by channel with uncertainty intervals. + """Plot the allocated contribution by channel with uncertainty intervals. This function visualizes the mean allocated contributions by channel along with the uncertainty intervals defined by the lower and upper quantiles. The contributions @@ -2297,6 +2349,7 @@ def plot_allocated_contribution_by_channel( ------- fig : matplotlib.figure.Figure The matplotlib figure object containing the plot. + """ if original_scale: channel_contributions = ( @@ -2324,6 +2377,8 @@ def plot_allocated_contribution_by_channel( class DelayedSaturatedMMM(MMM): + """Deprecated class for DelayedSaturatedMMM.""" + _model_type: str = "MMM" _model_name: str = "DelayedSaturatedMMM" version: str = "0.0.3" @@ -2344,6 +2399,8 @@ def __init__( adstock_first: bool = True, ) -> None: """ + Define constructor. + Wrapper function for DelayedSaturatedMMM class initializer. Warns that MMM class should be used instead and returns an instance of MMM with diff --git a/pymc_marketing/mmm/fourier.py b/pymc_marketing/mmm/fourier.py index c59ae7db..23bac226 100644 --- a/pymc_marketing/mmm/fourier.py +++ b/pymc_marketing/mmm/fourier.py @@ -287,6 +287,7 @@ class FourierBase(BaseModel): variable_name: str | None = Field(None) def model_post_init(self, __context: Any) -> None: + """Model post initialization for a Pydantic model.""" if self.variable_name is None: self.variable_name = f"{self.prefix}_beta" @@ -308,6 +309,19 @@ def _check_prior_has_right_dimensions(self) -> Self: @field_serializer("prior", when_used="json") def serialize_prior(prior: Prior) -> dict[str, Any]: + """Serialize the prior distribution. + + Parameters + ---------- + prior : Prior + The prior distribution to serialize. + + Returns + ------- + dict[str, Any] + The serialized prior distribution. + + """ return prior.to_json() @property diff --git a/pymc_marketing/mmm/lift_test.py b/pymc_marketing/mmm/lift_test.py index e42f877d..7a308eaa 100644 --- a/pymc_marketing/mmm/lift_test.py +++ b/pymc_marketing/mmm/lift_test.py @@ -74,7 +74,6 @@ def lift_test_indices(df_lift_test: pd.DataFrame, model: pm.Model) -> Indices: If some lift test values are not in the model. """ - columns = df_lift_test.columns.tolist() return { @@ -172,8 +171,8 @@ def indices_from_lift_tests( ------- dict[str, np.ndarray] Dictionary of indices for the lift test results in the model. - """ + """ named_vars_to_dims = { name: dims for name, dims in model.named_vars_to_dims.items() @@ -204,7 +203,7 @@ class NonMonotonicLiftError(Exception): def check_increasing_assumption(df_lift_tests: pd.DataFrame) -> None: - """Checks if the lift test results satisfy the increasing assumption. + """Check if the lift test results satisfy the increasing assumption. If delta_x is positive, delta_y must be positive, and vice versa. """ @@ -454,7 +453,6 @@ def scale_channel_lift_measurements( DataFrame with the scaled lift measurements. """ - # DataFrame with MultiIndex (RangeIndex, channel_col) # columns: x, delta_x df_original = df_lift_test.loc[:, [channel_col, "x", "delta_x"]].set_index( @@ -603,7 +601,10 @@ def add_lift_measurements_to_likelihood_from_saturation( dist: type[pm.Distribution] = pm.Gamma, name: str = "lift_measurements", ) -> None: - """Wrapper around :func:`add_lift_measurements_to_likelihood` to work with + """ + Add lift measurements to the likelihood from a saturation transformation. + + Wrapper around :func:`add_lift_measurements_to_likelihood` to work with SaturationTransformation instances and time-varying variables. Used internally of the :class:`MMM` class. @@ -628,7 +629,6 @@ def add_lift_measurements_to_likelihood_from_saturation( Name of the likelihood, by default "lift_measurements" """ - if time_varying_var_name: saturation_function, variable_mapping = create_time_varying_saturation( saturation=saturation, diff --git a/pymc_marketing/mmm/plot.py b/pymc_marketing/mmm/plot.py index 0a4fd8ba..84f4eaef 100644 --- a/pymc_marketing/mmm/plot.py +++ b/pymc_marketing/mmm/plot.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Plotting functions for the MMM.""" + import warnings from collections.abc import Generator, MutableMapping, Sequence from itertools import product @@ -29,11 +31,39 @@ def get_plot_coords(coords: Coords, non_grid_names: set[str]) -> Coords: + """Get the plot coordinates. + + Parameters + ---------- + coords : Coords + The coordinates to get the plot coordinates from. + non_grid_names : set[str] + The names to exclude from the grid. + + Returns + ------- + Coords + The plot coordinates. + + """ plot_coord_names = list(key for key in coords.keys() if key not in non_grid_names) return {name: np.array(coords[name]) for name in plot_coord_names} def get_total_coord_size(coords: Coords) -> int: + """Get the total size of the coordinates. + + Parameters + ---------- + coords : Coords + The coordinates to get the total size of. + + Returns + ------- + int + The total size of the coordinates. + + """ total_size: int = ( 1 if coords == {} else np.prod([len(values) for values in coords.values()]) # type: ignore ) @@ -47,6 +77,21 @@ def set_subplot_kwargs_defaults( subplot_kwargs: MutableMapping[str, Any], total_size: int, ) -> None: + """Set the defaults for the subplot kwargs. + + Parameters + ---------- + subplot_kwargs : MutableMapping[str, Any] + The subplot kwargs to set the defaults for. + total_size : int + The total size of the coordinates. + + Raises + ------ + ValueError + If both `ncols` and `nrows` are specified. + + """ if "ncols" in subplot_kwargs and "nrows" in subplot_kwargs: raise ValueError("Only specify one") @@ -62,7 +107,19 @@ def set_subplot_kwargs_defaults( def selections( coords: Coords, ) -> Generator[dict[str, Any], None, None]: - """Helper to create generator of selections.""" + """Create generator of selections. + + Parameters + ---------- + coords : Coords + The coordinates to create the selections from. + + Yields + ------ + dict[str, Any] + The selections. + + """ coord_names = coords.keys() for values in product(*coords.values()): yield {name: value for name, value in zip(coord_names, values, strict=True)} @@ -152,6 +209,25 @@ def random_samples( n_chains: int, n_draws: int, ) -> list[tuple[int, int]]: + """Generate random samples from the chains and draws. + + Parameters + ---------- + rng : np.random.Generator + Random number generator + n : int + Number of samples to generate + n_chains : int + Number of chains + n_draws : int + Number of draws + + Returns + ------- + list[tuple[int, int]] + The random samples + + """ combinations = list(product(range(n_chains), range(n_draws))) return [ diff --git a/pymc_marketing/mmm/preprocessing.py b/pymc_marketing/mmm/preprocessing.py index fb3cc956..b499c228 100644 --- a/pymc_marketing/mmm/preprocessing.py +++ b/pymc_marketing/mmm/preprocessing.py @@ -31,6 +31,21 @@ def preprocessing_method_X(method: Callable) -> Callable: + """Tag a method as a preprocessing method for the X data. + + Decorator to mark a method as a preprocessing method for the X data. + + Parameters + ---------- + method : Callable + The method to tag as a preprocessing method for the X data. + + Returns + ------- + Callable + The tagged method. + + """ if not hasattr(method, "_tags"): method._tags = {} # type: ignore method._tags["preprocessing_X"] = True # type: ignore @@ -38,6 +53,21 @@ def preprocessing_method_X(method: Callable) -> Callable: def preprocessing_method_y(method: Callable) -> Callable: + """Tag a method as a preprocessing method for the y data. + + Decorator to mark a method as a preprocessing method for the y data. + + Parameters + ---------- + method : Callable + The method to tag as a preprocessing method for the y data. + + Returns + ------- + Callable + The tagged method. + + """ if not hasattr(method, "_tags"): method._tags = {} # type: ignore method._tags["preprocessing_y"] = True # type: ignore @@ -45,12 +75,27 @@ def preprocessing_method_y(method: Callable) -> Callable: class MaxAbsScaleTarget: + """MaxAbsScaler for the target data.""" + target_transformer: Pipeline @preprocessing_method_y def max_abs_scale_target_data( self, data: pd.Series | np.ndarray ) -> np.ndarray | pd.Series: + """MaxAbsScaler for the target data. + + Parameters + ---------- + data : pd.Series | np.ndarray + The target data to scale. + + Returns + ------- + np.ndarray | pd.Series + The scaled target data. + + """ if isinstance(data, pd.Series): data = data.to_numpy() @@ -63,10 +108,25 @@ def max_abs_scale_target_data( class MaxAbsScaleChannels: + """MaxAbsScaler for the channel data.""" + channel_columns: list[str] | tuple[str] @preprocessing_method_X def max_abs_scale_channel_data(self, data: pd.DataFrame) -> pd.DataFrame: + """MaxAbsScaler for the channel data. + + Parameters + ---------- + data : pd.DataFrame + The channel data to scale. + + Returns + ------- + pd.DataFrame + The scaled channel data. + + """ data_cp = data.copy() channel_data: pd.DataFrame | pd.Series[Any] = data_cp[self.channel_columns] transformers = [("scaler", MaxAbsScaler())] @@ -79,10 +139,25 @@ def max_abs_scale_channel_data(self, data: pd.DataFrame) -> pd.DataFrame: class StandardizeControls: + """StandardScaler for the control data.""" + control_columns: list[str] # TODO: Handle Optional[List[str]] @preprocessing_method_X def standardize_control_data(self, data: pd.DataFrame) -> pd.DataFrame: + """StandardScaler for the control data. + + Parameters + ---------- + data : pd.DataFrame + The control data to scale. + + Returns + ------- + pd.DataFrame + The scaled control data. + + """ control_data: pd.DataFrame = data[self.control_columns] transformers = [("scaler", StandardScaler())] pipeline: Pipeline = Pipeline(steps=transformers) diff --git a/pymc_marketing/mmm/transformers.py b/pymc_marketing/mmm/transformers.py index 2ae28d86..bb12b4b6 100644 --- a/pymc_marketing/mmm/transformers.py +++ b/pymc_marketing/mmm/transformers.py @@ -24,6 +24,8 @@ class ConvMode(str, Enum): + """Convolution mode for the convolution.""" + # TODO: use StrEnum when we upgrade to python 3.11 After = "After" Before = "Before" @@ -31,6 +33,8 @@ class ConvMode(str, Enum): class WeibullType(str, Enum): + """Weibull type for the Weibull adstock.""" + # TODO: use StrEnum when we upgrade to python 3.11 PDF = "PDF" CDF = "CDF" @@ -94,6 +98,7 @@ def batched_convolution( result will match the shape of ``x`` up to broadcasting with ``w``. The convolved axis will show the results of left padding zeros to ``x`` while applying the convolutions. + """ # We move the axis to the last dimension of the array so that it's easier to # reason about parameter broadcasting. We will move the axis back at the end @@ -228,8 +233,8 @@ def geometric_adstock( ---------- .. [1] Jin, Yuxue, et al. "Bayesian methods for media mix modeling with carryover and shape effects." (2017). - """ + """ w = pt.power(pt.as_tensor(alpha)[..., None], pt.arange(l_max, dtype=x.dtype)) w = w / pt.sum(w, axis=-1, keepdims=True) if normalize else w return batched_convolution(x, w, axis=axis, mode=mode) @@ -310,6 +315,7 @@ def delayed_adstock( ---------- .. [1] Jin, Yuxue, et al. "Bayesian methods for media mix modeling with carryover and shape effects." (2017). + """ w = pt.power( pt.as_tensor(alpha)[..., None], @@ -413,6 +419,7 @@ def weibull_adstock( ------- tensor Transformed tensor based on Weibull adstock transformation. + """ lam = pt.as_tensor(lam)[..., None] k = pt.as_tensor(k)[..., None] @@ -474,6 +481,7 @@ def logistic_saturation(x, lam: npt.NDArray[np.float64] | float = 0.5): ------- tensor Transformed tensor. + """ return (1 - pt.exp(-lam * x)) / (1 + pt.exp(-lam * x)) @@ -481,10 +489,11 @@ def logistic_saturation(x, lam: npt.NDArray[np.float64] | float = 0.5): def inverse_scaled_logistic_saturation( x, lam: npt.NDArray[np.float64] | float = 0.5, eps: float = np.log(3) ): - """Inverse scaled logistic saturation transformation. - It offers a more intuitive alternative to logistic_saturation, - allowing for lambda to be interpreted as the half saturation point - when using default value for eps. + r"""Inverse scaled logistic saturation transformation. + + It offers a more intuitive alternative to logistic_saturation, + allowing for lambda to be interpreted as the half saturation point + when using default value for eps. .. math:: f(x) = \\frac{1 - e^{-x*\epsilon/\lambda}}{1 + e^{-x*\epsilon/\lambda}} @@ -523,7 +532,8 @@ def inverse_scaled_logistic_saturation( ------- tensor Transformed tensor. - """ # noqa: W605 + + """ return logistic_saturation(x, eps / lam) @@ -654,6 +664,7 @@ def tanh_saturation( References ---------- See https://www.pymc-labs.io/blog-posts/reducing-customer-acquisition-costs-how-we-helped-optimizing-hellofreshs-marketing-budget/ # noqa: E501 + """ # noqa: E501 return b * pt.tanh(x / (b * c)) @@ -664,8 +675,7 @@ def tanh_saturation_baselined( gain: pt.TensorLike = 0.5, r: pt.TensorLike = 0.5, ) -> pt.TensorVariable: - r""" - Baselined Tanh Saturation. + r"""Baselined Tanh Saturation. This parameterization that is easier than :func:`tanh_saturation` to use for industry applications where domain knowledge is an essence. @@ -806,6 +816,7 @@ def tanh_saturation_baselined( References ---------- Developed by Max Kochurov and Aziz Al-Maeeni doing innovative work in `PyMC Labs `_. + """ return gain * x0 * pt.tanh(x * pt.arctanh(r) / x0) / r @@ -815,8 +826,7 @@ def michaelis_menten( alpha: float | np.ndarray | npt.NDArray[np.float64], lam: float | np.ndarray | npt.NDArray[np.float64], ) -> float | Any: - r""" - Evaluate the Michaelis-Menten function for given values of x, alpha, and lambda. + r"""Evaluate the Michaelis-Menten function for given values of x, alpha, and lambda. The Michaelis-Menten function models enzyme kinetics and describes how the rate of a chemical reaction increases with substrate concentration until it reaches its @@ -894,15 +904,15 @@ def michaelis_menten( ------- float The value of the Michaelis-Menten function given the parameters. - """ + """ return alpha * x / (lam + x) def hill_function( x: pt.TensorLike, slope: pt.TensorLike, kappa: pt.TensorLike ) -> pt.TensorVariable: - r"""Hill Function + r"""Hill Function. .. math:: f(x) = 1 - \frac{\kappa^s}{\kappa^s + x^s} @@ -966,6 +976,7 @@ def hill_function( References ---------- .. [1] Jin, Yuxue, et al. “Bayesian methods for media mix modeling with carryover and shape effects.” (2017). + """ # noqa: E501 return pt.as_tensor_variable( 1 - pt.power(kappa, slope) / (pt.power(kappa, slope) + pt.power(x, slope)) @@ -978,7 +989,7 @@ def hill_saturation_sigmoid( beta: pt.TensorLike, lam: pt.TensorLike, ) -> pt.TensorVariable: - r"""Hill Saturation Sigmoid Function + r"""Hill Saturation Sigmoid Function. .. math:: f(x) = \frac{\sigma}{1 + e^{-\beta(x - \lambda)}} - \frac{\sigma}{1 + e^{\beta\lambda}} @@ -1062,6 +1073,7 @@ def hill_saturation_sigmoid( ------- float or array-like The value of the Hill saturation sigmoid function for each input value of x. + """ return sigma / (1 + pt.exp(-beta * (x - lam))) - sigma / (1 + pt.exp(beta * lam)) diff --git a/pymc_marketing/mmm/tvp.py b/pymc_marketing/mmm/tvp.py index db434252..d5bdb1c3 100644 --- a/pymc_marketing/mmm/tvp.py +++ b/pymc_marketing/mmm/tvp.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Time Varying Gaussian Process Multiplier for Marketing Mix Modeling (MMM). +"""Time Varying Gaussian Process Multiplier for Marketing Mix Modeling (MMM). + Designed to model time-varying effects in marketing mix models (MMM). This module provides a time-varying Gaussian Process (GP) multiplier, @@ -20,7 +20,6 @@ Examples -------- - Create a basic PyMC model using the time-varying GP multiplier: .. code-block:: python @@ -136,8 +135,8 @@ def time_varying_prior( - Solin, A., Sarkka, S. (2019) Hilbert Space Methods for Reduced-Rank Gaussian Process Regression. - """ + """ if hsgp_kwargs is None: hsgp_kwargs = HSGPKwargs() @@ -193,13 +192,14 @@ def create_time_varying_gp_multiplier( Midpoint of the time points. time_resolution : int Resolution of time points. - hsgp_kwargsg : HSGPKwargs + hsgp_kwargs : HSGPKwargs Keyword arguments for the Hilbert Space Gaussian Process (HSGP) component. Returns ------- pt.TensorVariable Time-varying Gaussian Process multiplier for a given variable. + """ if hsgp_kwargs.L is None: hsgp_kwargs.L = time_index_mid + DAYS_IN_YEAR / time_resolution diff --git a/pymc_marketing/mmm/utils.py b/pymc_marketing/mmm/utils.py index e79352b2..69920115 100644 --- a/pymc_marketing/mmm/utils.py +++ b/pymc_marketing/mmm/utils.py @@ -31,8 +31,7 @@ def estimate_menten_parameters( contributions: xr.DataArray | Any, **kwargs, ) -> list[float]: - """ - Estimate the parameters for the Michaelis-Menten function using curve fitting. + """Estimate the parameters for the Michaelis-Menten function using curve fitting. This function extracts the relevant data for the specified channel from both the original_dataframe and contributions DataArray resulting from the model. @@ -48,11 +47,14 @@ def estimate_menten_parameters( The original DataFrame containing the channel data. contributions : xr.DataArray An xarray DataArray containing the contributions data, indexed by channel. + **kwargs : dict + Additional keyword arguments to pass to the curve_fit function. Returns ------- List[float] The estimated parameters of the extended sigmoid function. + """ maxfev = kwargs.get("maxfev", 5000) lam_initial_estimate = kwargs.get("lam_initial_estimate", 0.001) @@ -65,7 +67,7 @@ def estimate_menten_parameters( # Initial guess for L and k initial_guess = [alpha_initial_estimate, lam_initial_estimate] # Curve fitting - popt, pcov = curve_fit(michaelis_menten, x, y, p0=initial_guess, maxfev=maxfev) + popt, _ = curve_fit(michaelis_menten, x, y, p0=initial_guess, maxfev=maxfev) # Save the parameters return popt @@ -77,8 +79,7 @@ def estimate_sigmoid_parameters( contributions: xr.DataArray | Any, **kwargs, ) -> list[float]: - """ - Estimate the parameters for the sigmoid function using curve fitting. + """Estimate the parameters for the sigmoid function using curve fitting. This function extracts the relevant data for the specified channel from both the original_dataframe and contributions DataArray resulting from the model. @@ -99,6 +100,7 @@ def estimate_sigmoid_parameters( ------- List[float] The estimated parameters of the extended sigmoid function. + """ maxfev = kwargs.get("maxfev", 5000) lam_initial_estimate = kwargs.get("lam_initial_estimate", 0.00001) @@ -126,8 +128,7 @@ def compute_sigmoid_second_derivative( alpha: float | np.ndarray | npt.NDArray[np.float64], lam: float | np.ndarray | npt.NDArray[np.float64], ) -> float | Any: - """ - Compute the second derivative of the extended sigmoid function. + """Compute the second derivative of the extended sigmoid function. The second derivative of a function gives us information about the curvature of the function. In the context of the sigmoid function, it helps us identify the inflection point, which is @@ -146,8 +147,8 @@ def compute_sigmoid_second_derivative( ------- float The second derivative of the sigmoid function at the input value. - """ + """ return ( -alpha * lam**2 @@ -161,8 +162,7 @@ def find_sigmoid_inflection_point( alpha: float | np.ndarray | npt.NDArray[np.float64], lam: float | np.ndarray | npt.NDArray[np.float64], ) -> tuple[Any, float]: - """ - Find the inflection point of the extended sigmoid function. + """Find the inflection point of the extended sigmoid function. The inflection point of a function is the point where the function changes its curvature, i.e., it changes from being concave up to concave down, or vice versa. For the sigmoid @@ -179,8 +179,8 @@ def find_sigmoid_inflection_point( ------- tuple The x and y coordinates of the inflection point. - """ + """ # Minimize the negative of the absolute value of the second derivative result = minimize_scalar( lambda x: -abs(compute_sigmoid_second_derivative(x, alpha, lam)) @@ -199,18 +199,25 @@ def apply_sklearn_transformer_across_dim( dim_name: str, combined: bool = False, ) -> xr.DataArray: - """Helper function in order to use scikit-learn functions with the xarray target. + """Apply a scikit-learn transformer across a dimension of an xarray DataArray. + + Helper function in order to use scikit-learn functions with the xarray target. Parameters ---------- - data : - func : scikit-learn method to apply to the data - dim_name : Name of the dimension to apply the function to - combined : Flag to indicate if the data coords have been combined or not + data : xr.DataArray + The input data to transform. + func : Callable[[np.ndarray], np.ndarray] + scikit-learn method to apply to the data + dim_name : str + Name of the dimension to apply the function to + combined : bool, default False + Flag to indicate if the data coords have been combined or not Returns ------- xr.DataArray + """ # These are lost during the ufunc attrs = data.attrs @@ -250,6 +257,7 @@ def transform_1d_array( ------- np.ndarray The transformed data. + """ return transform(np.array(y)[:, None]).flatten() @@ -259,16 +267,19 @@ def sigmoid_saturation( alpha: float | np.ndarray | npt.NDArray[np.float64], lam: float | np.ndarray | npt.NDArray[np.float64], ) -> float | Any: - """ + """Sigmoid saturation function. + Parameters ---------- - alpha + x : float or np.ndarray + The input value for which the function is to be computed. + alpha : float or np.ndarray α (alpha): Represent the Asymptotic Maximum or Ceiling Value. - lam + lam : float or np.ndarray λ (lambda): affects how quickly the function approaches its upper and lower asymptotes. A higher value of lam makes the curve steeper, while a lower value makes it more gradual. - """ + """ if alpha <= 0 or lam <= 0: raise ValueError("alpha and lam must be greater than 0") @@ -322,7 +333,7 @@ def create_new_spend_data( Parameters - --------- + ---------- spend : np.ndarray The spend data for the channels. adstock_max_lag : int @@ -336,6 +347,7 @@ def create_new_spend_data( ------- np.ndarray The new spend data for the channel forward pass. + """ n_channels = len(spend) @@ -364,8 +376,7 @@ def create_new_spend_data( def drop_scalar_coords(curve: xr.DataArray) -> xr.DataArray: - """ - Remove scalar coordinates from an xarray DataArray. + """Remove scalar coordinates from an xarray DataArray. This function identifies and removes scalar coordinates from the given DataArray. Scalar coordinates are those with a single value that are @@ -381,6 +392,7 @@ def drop_scalar_coords(curve: xr.DataArray) -> xr.DataArray: ------- xr.DataArray A new DataArray with the identified scalar coordinates removed. + """ scalar_coords_to_drop = [] for coord, values in curve.coords.items(): diff --git a/pymc_marketing/mmm/validating.py b/pymc_marketing/mmm/validating.py index 2691f554..8eeaf37d 100644 --- a/pymc_marketing/mmm/validating.py +++ b/pymc_marketing/mmm/validating.py @@ -28,6 +28,7 @@ def validation_method_y(method: Callable) -> Callable: + """Tag a method as a validation method for the target column.""" if not hasattr(method, "_tags"): method._tags = {} # type: ignore method._tags["validation_y"] = True # type: ignore @@ -35,6 +36,7 @@ def validation_method_y(method: Callable) -> Callable: def validation_method_X(method: Callable) -> Callable: + """Tag a method as a validation method for the predictor columns.""" if not hasattr(method, "_tags"): method._tags = {} # type: ignore method._tags["validation_X"] = True # type: ignore @@ -42,17 +44,43 @@ def validation_method_X(method: Callable) -> Callable: class ValidateTargetColumn: + """Validate the target column.""" + @validation_method_y def validate_target(self, data: pd.Series) -> None: + """Validate the target column. + + Parameters + ---------- + data : pd.Series + The data to validate. + + Raises + ------ + ValueError: If the target column is not valid. + """ if len(data) == 0: raise ValueError("y must have at least one element") class ValidateDateColumn: + """Validate the date column.""" + date_column: str @validation_method_X def validate_date_col(self, data: pd.DataFrame) -> None: + """Validate the date column. + + Parameters + ---------- + data : pd.DataFrame + The data to validate. + + Raises + ------ + ValueError: If the date column is not valid. + """ if self.date_column not in data.columns: raise ValueError(f"date_col {self.date_column} not in data") if not data[self.date_column].is_unique: @@ -60,10 +88,23 @@ def validate_date_col(self, data: pd.DataFrame) -> None: class ValidateChannelColumns: + """Validate the channel columns.""" + channel_columns: list[str] | tuple[str] @validation_method_X def validate_channel_columns(self, data: pd.DataFrame) -> None: + """Validate the channel columns. + + Parameters + ---------- + data : pd.DataFrame + The data to validate. + + Raises + ------ + ValueError: If the channel columns are not valid. + """ if not isinstance(self.channel_columns, list | tuple): raise ValueError("channel_columns must be a list or tuple") if len(self.channel_columns) == 0: @@ -81,10 +122,23 @@ def validate_channel_columns(self, data: pd.DataFrame) -> None: class ValidateControlColumns: + """Validate the control columns.""" + control_columns: list[str] | None @validation_method_X def validate_control_columns(self, data: pd.DataFrame) -> None: + """Validate the control columns. + + Parameters + ---------- + data : pd.DataFrame + The data to validate. + + Raises + ------ + ValueError: If the control columns are not valid. + """ if self.control_columns is None: return None if not isinstance(self.control_columns, list | tuple): diff --git a/pymc_marketing/model_builder.py b/pymc_marketing/model_builder.py index 496d6321..d243b030 100644 --- a/pymc_marketing/model_builder.py +++ b/pymc_marketing/model_builder.py @@ -39,15 +39,18 @@ except ImportError: def check_X_y(X, y, **kwargs): + """Check if the input data is valid for the model.""" return X, y def check_array(X, **kwargs): + """Check if the input data is valid for the model.""" return X class ModelBuilder(ABC): - """ - ModelBuilder can be used to provide an easy-to-use API (similar to scikit-learn) for models + """Base class for building models with PyMC Marketing. + + It provides an easy-to-use API (similar to scikit-learn) for models and help with deployment. """ @@ -62,8 +65,7 @@ def __init__( model_config: dict | None = None, sampler_config: dict | None = None, ): - """ - Initializes model configuration and sampler configuration for the model + """Initialize model configuration and sampler configuration for the model. Parameters ---------- @@ -79,6 +81,7 @@ def __init__( >>> class MyModel(ModelBuilder): >>> ... >>> model = MyModel(model_config, sampler_config) + """ if sampler_config is None: sampler_config = {} @@ -108,8 +111,7 @@ def _data_setter( X: np.ndarray | pd.DataFrame, y: np.ndarray | pd.Series | None = None, ) -> None: - """ - Sets new data in the model. + """Set new data in the model. Parameters ---------- @@ -118,8 +120,8 @@ def _data_setter( y : array, shape (n_obs,) The target values (real numbers). - Returns: - ---------- + Returns + ------- None Examples @@ -137,20 +139,21 @@ def _data_setter( @property @abstractmethod def output_var(self) -> str: - """ - Returns the name of the output variable of the model. + """Returns the name of the output variable of the model. Returns ------- output_var : str Name of the output variable of the model. + """ @property @abstractmethod def default_model_config(self) -> dict: - """ - Returns a class default config dict for model builder if no model_config is provided on class initialization + """Return a class default configuration dictionary. + + For model builder if no model_config is provided on class initialization Useful for understanding structure of required model_config to allow its customization by users Examples @@ -173,14 +176,17 @@ def default_model_config(self) -> dict: ------- model_config : dict A set of default parameters for predictor distributions that allow to save and recreate the model. + """ @property @abstractmethod def default_sampler_config(self) -> dict: - """ - Returns a class default sampler dict for model builder if no sampler_config is provided on class initialization + """Return a class default sampler configuration dictionary. + + For model builder if no sampler_config is provided on class initialization Useful for understanding structure of required sampler_config to allow its customization by users + Examples -------- >>> @classmethod @@ -196,21 +202,23 @@ def default_sampler_config(self) -> dict: ------- sampler_config : dict A set of default settings for used by model in fit process. + """ @abstractmethod def _generate_and_preprocess_model_data( self, X: pd.DataFrame | pd.Series, y: np.ndarray ) -> None: - """ - Applies preprocessing to the data before fitting the model. + """Apply preprocessing to the data before fitting the model. + if validate is True, it will check if the data is valid for the model. sets self.model_coords based on provided dataset In case of optional parameters being passed into the model, this method should implement the conditional logic responsible for correct handling of the optional parameters, and including them into the dataset. - Parameters: + Parameters + ---------- X : array, shape (n_obs, n_features) y : array, shape (n_obs,) @@ -237,9 +245,9 @@ def build_model( y: pd.Series | np.ndarray, **kwargs, ) -> None: - """ - Creates an instance of pm.Model based on provided data and model_config, and - attaches it to self. + """Create an instance of `pm.Model` based on provided data and model_config. + + It attaches the model to self.model. Parameters ---------- @@ -263,9 +271,18 @@ def build_model( Returns ------- None + """ def create_idata_attrs(self) -> dict[str, str]: + """Create attributes for the inference data. + + Returns + ------- + dict[str, str] + A dictionary of attributes for the inference data. + """ + def default(x): if isinstance(x, Prior): return x.to_json() @@ -289,8 +306,7 @@ def default(x): def set_idata_attrs( self, idata: az.InferenceData | None = None ) -> az.InferenceData: - """ - Set attributes on an InferenceData object. + """Set attributes on an InferenceData object. Parameters ---------- @@ -355,8 +371,7 @@ def set_idata_attrs( return idata def save(self, fname: str) -> None: - """ - Save the model's inference data to a file. + """Save the model's inference data to a file. Parameters ---------- @@ -383,6 +398,7 @@ def save(self, fname: str) -> None: >>> model = MyModel() >>> model.fit(X,y) >>> model.save('model_results.nc') # This will call the overridden method in MyModel + """ if self.idata is not None and "posterior" in self.idata: file = Path(str(fname)) @@ -392,7 +408,8 @@ def save(self, fname: str) -> None: @classmethod def _model_config_formatting(cls, model_config: dict) -> dict: - """ + """Format the model configuration. + Because of json serialization, model_config values that were originally tuples or numpy are being encoded as lists. This function converts them back to tuples and numpy arrays to ensure correct id encoding. @@ -415,6 +432,7 @@ def _model_config_formatting(cls, model_config: dict) -> dict: @classmethod def attrs_to_init_kwargs(cls, attrs) -> dict[str, Any]: + """Convert the model configuration and sampler configuration from the attributes to keyword arguments.""" return { "model_config": cls._model_config_formatting( json.loads(attrs["model_config"]) @@ -424,8 +442,7 @@ def attrs_to_init_kwargs(cls, attrs) -> dict[str, Any]: @classmethod def load(cls, fname: str): - """ - Creates a ModelBuilder instance from a file, + """Create a ModelBuilder instance from a file. Loads inference data for the model. @@ -489,8 +506,8 @@ def fit( random_seed: RandomState | None = None, **kwargs: Any, ) -> az.InferenceData: - """ - Fit a model using the data passed as a parameter. + """Fit a model using the data passed as a parameter. + Sets attrs to inference data of the model. Parameters @@ -521,6 +538,7 @@ def fit( >>> idata = model.fit(X,y) Auto-assigning NUTS sampler... Initializing NUTS using jitter+adapt_diag... + """ if predictor_names is None: predictor_names = [] @@ -573,9 +591,9 @@ def predict( extend_idata: bool = True, **kwargs, ) -> np.ndarray: - """ - Uses model to predict on unseen data and return point prediction of all the samples. The point prediction - for each input row is the expected output value, computed as the mean of MCMC samples. + """Use a model to predict on unseen data and return point prediction of all the samples. + + The point prediction for each input row is the expected output value, computed as the mean of MCMC samples. Parameters ---------- @@ -598,8 +616,8 @@ def predict( >>> x_pred = [] >>> prediction_data = pd.DataFrame({'input':x_pred}) >>> pred_mean = model.predict(prediction_data) - """ + """ posterior_predictive_samples = self.sample_posterior_predictive( X_pred, extend_idata, combined=False, **kwargs ) @@ -623,8 +641,7 @@ def sample_prior_predictive( combined: bool = True, **kwargs, ): - """ - Sample from the model's prior predictive distribution. + """Sample from the model's prior predictive distribution. Parameters ---------- @@ -645,6 +662,7 @@ def sample_prior_predictive( ------- prior_predictive_samples : DataArray, shape (n_pred, samples) Prior predictive samples for each input X_pred + """ if y_pred is None: y_pred = np.zeros(len(X_pred)) @@ -677,8 +695,7 @@ def sample_posterior_predictive( combined: bool = True, **sample_posterior_predictive_kwargs, ): - """ - Sample from the model's posterior predictive distribution. + """Sample from the model's posterior predictive distribution. Parameters ---------- @@ -696,6 +713,7 @@ def sample_posterior_predictive( ------- posterior_predictive_samples : DataArray, shape (n_pred, samples) Posterior predictive samples for each input X_pred + """ self._data_setter(X_pred) @@ -716,32 +734,29 @@ def sample_posterior_predictive( return az.extract(post_pred, variable_name, combined=combined) def get_params(self, deep=True): - """ - Get all the model parameters needed to instantiate a copy of the model, not including training data. - """ + """Get all the model parameters needed to instantiate a copy of the model, not including training data.""" return { "model_config": self.model_config, "sampler_config": self.sampler_config, } def set_params(self, **params): - """ - Set all the model parameters needed to instantiate the model, not including training data. - """ + """Set all the model parameters needed to instantiate the model, not including training data.""" self.model_config = params["model_config"] self.sampler_config = params["sampler_config"] @property @abstractmethod def _serializable_model_config(self) -> dict[str, int | float | dict]: - """ - Converts non-serializable values from model_config to their serializable reversable equivalent. + """Converts non-serializable values from model_config to their serializable reversable equivalent. + Data types like pandas DataFrame, Series or datetime aren't JSON serializable, so in order to save the model they need to be formatted. Returns ------- model_config: dict + """ def predict_proba( @@ -761,8 +776,7 @@ def predict_posterior( combined: bool = True, **kwargs, ) -> xr.DataArray: - """ - Generate posterior predictive samples on unseen data. + """Generate posterior predictive samples on unseen data. Parameters ---------- @@ -781,8 +795,8 @@ def predict_posterior( y_pred : DataArray Posterior predictive samples for each input X_pred. Shape is (n_pred, chains * draws) if combined is True, otherwise (chains, draws, n_pred). - """ + """ X_pred = self._validate_data(X_pred) posterior_predictive_samples = self.sample_posterior_predictive( X_pred, extend_idata, combined, **kwargs @@ -797,8 +811,7 @@ def predict_posterior( @property def id(self) -> str: - """ - Generate a unique hash value for the model. + """Generate a unique hash value for the model. The hash value is created using the last 16 characters of the SHA256 hash encoding, based on the model configuration, version, and model type. @@ -813,6 +826,7 @@ def id(self) -> str: >>> model = MyModel() >>> model.id '0123456789abcdef' + """ hasher = hashlib.sha256() hasher.update(str(self.model_config.values()).encode()) diff --git a/pymc_marketing/paths.py b/pymc_marketing/paths.py index 2b63b421..dcc8f22e 100644 --- a/pymc_marketing/paths.py +++ b/pymc_marketing/paths.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Paths for the project.""" + from pyprojroot import here root = here() diff --git a/pymc_marketing/prior.py b/pymc_marketing/prior.py index 18cd419f..b2ad0106 100644 --- a/pymc_marketing/prior.py +++ b/pymc_marketing/prior.py @@ -172,7 +172,7 @@ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVa def create_dim_handler(desired_dims: Dims) -> DimHandler: - """Wrapper to act like the previous `create_dim_handler` function.""" + """Wrap the `handle_dims` function to act like the previous `create_dim_handler` function.""" def func(x: pt.TensorLike, dims: Dims) -> pt.TensorVariable: return handle_dims(x, dims, desired_dims) @@ -312,6 +312,7 @@ def dims(self, dims) -> None: self._unique_dims() def __getitem__(self, key: str) -> Prior | Any: + """Return the parameter of the prior.""" return self.parameters[key] def _checks(self) -> None: @@ -395,6 +396,7 @@ def _param_dims_work(self) -> None: ) def __str__(self) -> str: + """Return a string representation of the prior.""" param_str = ", ".join( [f"{param}={value}" for param, value in self.parameters.items()] ) @@ -406,6 +408,7 @@ def __str__(self) -> str: return f'Prior("{self.distribution}"{param_str}{dim_str}{centered_str}{transform_str})' def __repr__(self) -> str: + """Return a string representation of the prior.""" return f"{self}" def _create_parameter(self, param, value, name): @@ -706,7 +709,6 @@ def constrain(self, lower: float, upper: float, **kwargs) -> Prior: ).constrain(lower=0.5, upper=0.8) """ - if self.transform: raise ValueError("Can't constrain a transformed variable") @@ -727,6 +729,7 @@ def constrain(self, lower: float, upper: float, **kwargs) -> Prior: ) def __eq__(self, other) -> bool: + """Check if two priors are equal.""" if not isinstance(other, Prior): return False @@ -790,6 +793,7 @@ def sample_prior( return pm.sample_prior_predictive(**sample_prior_predictive_kwargs).prior def __deepcopy__(self, memo) -> Prior: + """Return a deep copy of the prior.""" if id(self) in memo: return memo[id(self)] diff --git a/pymc_marketing/utils.py b/pymc_marketing/utils.py index 1efd5eff..de2c8805 100644 --- a/pymc_marketing/utils.py +++ b/pymc_marketing/utils.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Utility functions for PyMC Marketing.""" + import warnings from pathlib import Path @@ -18,6 +20,18 @@ def from_netcdf(filepath: str | Path) -> az.InferenceData: + """Load inference data from a netcdf file. + + Parameters + ---------- + filepath : str or Path + The path to the netcdf file. + + Returns + ------- + az.InferenceData + The inference data. + """ with warnings.catch_warnings(): warnings.filterwarnings( "ignore", diff --git a/pymc_marketing/version.py b/pymc_marketing/version.py index dc954610..0da3b9e8 100644 --- a/pymc_marketing/version.py +++ b/pymc_marketing/version.py @@ -11,12 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Version of the package.""" + import os here = os.path.dirname(os.path.realpath(__file__)) -def read_version(): +def read_version() -> str: + """Read the version from the version file.""" version_file = os.path.join(here, "version.txt") with open(version_file, encoding="utf-8") as buff: return buff.read().splitlines()[0] diff --git a/pyproject.toml b/pyproject.toml index 1c452ccc..5c8efd50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,7 +99,8 @@ repository = "https://github.com/pymc-labs/pymc-marketing" #changelog = "" [tool.ruff.lint] -select = ["B", "E", "F", "I", "RUF", "S", "UP", "W"] +select = ["B", "D", "DOC", "E", "F", "I", "RUF", "S", "UP", "W"] +pydocstyle.convention = "numpy" ignore = [ "B008", # Do not perform calls in argument defaults (this is ok with Field from pydantic) "B904", # raise-without-from-inside-except @@ -110,11 +111,14 @@ ignore = [ [tool.ruff.lint.per-file-ignores] "docs/source/notebooks/*" = [ "B018", # Checks for "useless" expressions. Not useful for notebooks. + "D103", # Missing docstring in public function. ] "tests/*" = [ "B018", # Checks for "useless" expressions. This is useful for tests. + "D", "S101", # Use of assert ] +"scripts/*" = ["D"] [tool.ruff.lint.pycodestyle] max-line-length = 120 diff --git a/streamlit/mmm-explainer/Visualise_Priors.py b/streamlit/mmm-explainer/Visualise_Priors.py index f3fa7c2d..0f85ab8b 100644 --- a/streamlit/mmm-explainer/Visualise_Priors.py +++ b/streamlit/mmm-explainer/Visualise_Priors.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# Import custom functions +"""Streamlit page for visualising priors.""" + import prior_functions as pf import streamlit as st diff --git a/streamlit/mmm-explainer/pages/Adstock.py b/streamlit/mmm-explainer/pages/Adstock.py index 4a6597ca..24dbb9f2 100644 --- a/streamlit/mmm-explainer/pages/Adstock.py +++ b/streamlit/mmm-explainer/pages/Adstock.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# Import custom functions +"""Streamlit page for adstock transformations.""" + import numpy as np import pandas as pd import plotly.express as px diff --git a/streamlit/mmm-explainer/pages/Saturation.py b/streamlit/mmm-explainer/pages/Saturation.py index 92193e65..80b66454 100644 --- a/streamlit/mmm-explainer/pages/Saturation.py +++ b/streamlit/mmm-explainer/pages/Saturation.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# Import custom functions +"""Streamlit page for saturation curves.""" + import numpy as np import pandas as pd import plotly.graph_objects as go diff --git a/streamlit/mmm-explainer/prior_functions.py b/streamlit/mmm-explainer/prior_functions.py index f976b224..5778ae1d 100644 --- a/streamlit/mmm-explainer/prior_functions.py +++ b/streamlit/mmm-explainer/prior_functions.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Functions for plotting prior distributions.""" + # Imports import numpy as np import plotly.express as px @@ -23,15 +25,17 @@ @st.cache_data # 👈 Add the caching decorator, make app run faster def get_distribution(distribution_name=pz.distributions, **params): - """ - Retrieve and create a distribution instance from the PreliZ library. + """Retrieve and create a distribution instance from the PreliZ library. - Parameters: + Parameters + ---------- distribution_name (str): The name of the distribution to create. **params: Variable length dict of parameters and values required by the distribution. - Returns: + Returns + ------- object: An instance of the requested distribution. + """ try: # Get the distribution class from preliz @@ -49,15 +53,18 @@ def get_distribution(distribution_name=pz.distributions, **params): def plot_prior_distribution( draws, nbins=100, opacity=0.1, title="Prior Distribution - Visualised" ): - """ - Plots samples of a prior distribution as a histogram with a KDE (Kernel Density Estimate) overlay - and a violin plot along the top too with quartile values. + """Plot samples of a prior distribution as a histogram. + + It uses a KDE (Kernel Density Estimate) overlay and a violin plot along the top too + with quartile values. - Parameters: + Parameters + ---------- - draws: numpy array of samples from prior distribution. - nbins: int, the number of bins for the histogram. - opacity: float, the opacity level for the histogram bars. - title: str, the title of the plot. + """ # Create the histogram using Plotly Express fig = px.histogram( diff --git a/tests/clv/models/test_beta_geo.py b/tests/clv/models/test_beta_geo.py index 8aed0925..f643170e 100644 --- a/tests/clv/models/test_beta_geo.py +++ b/tests/clv/models/test_beta_geo.py @@ -184,9 +184,7 @@ def test_customer_id_duplicate(self): def test_numerically_stable_logp( self, frequency, recency, logp_value, model_config ): - """ - See Solution #2 on pages 3 and 4 of http://brucehardie.com/notes/027/bgnbd_num_error.pdf - """ + """See Solution #2 on pages 3 and 4 of http://brucehardie.com/notes/027/bgnbd_num_error.pdf""" model_config = { "a_prior": Prior("Flat"), "b_prior": Prior("Flat"), diff --git a/tests/conftest.py b/tests/conftest.py index 696f6235..0473e3cc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -56,8 +56,7 @@ def pytest_collection_modifyitems(config, items): @pytest.fixture(scope="module") def cdnow_trans() -> pd.DataFrame: - """ - Load CDNOW_sample transaction data into a Pandas dataframe. + """Load CDNOW_sample transaction data into a Pandas dataframe. Data source: https://www.brucehardie.com/datasets/ """ diff --git a/tests/mmm/test_preprocessing.py b/tests/mmm/test_preprocessing.py index 47172918..42458c31 100644 --- a/tests/mmm/test_preprocessing.py +++ b/tests/mmm/test_preprocessing.py @@ -52,7 +52,7 @@ def test_preprocessing_method(): assert vf.__name__ == f.__name__ def f2(x): - """bla""" + """Bla""" return x vf = preprocessing_method_X(f2) @@ -63,14 +63,14 @@ def f2(x): class F: @preprocessing_method_X def f3(self, x): - """bla""" + """Bla""" return x vf = F().f3 assert getattr(vf, "_tags", {}).get("preprocessing_X", False) assert F.f3.__doc__ == vf.__doc__ assert F.f3.__name__ == vf.__name__ - assert vf.__doc__ == "bla" + assert vf.__doc__ == "Bla" assert vf.__name__ == "f3" diff --git a/tests/mmm/test_validating.py b/tests/mmm/test_validating.py index b3f5a44a..36b09036 100644 --- a/tests/mmm/test_validating.py +++ b/tests/mmm/test_validating.py @@ -54,7 +54,7 @@ def test_validation_method(): assert vf.__name__ == f.__name__ def f2(x): - """bla""" + """Bla""" return x vf = validation_method_X(f2) @@ -65,14 +65,14 @@ def f2(x): class F: @validation_method_X def f3(self, x): - """bla""" + """Bla""" return x vf = F().f3 assert getattr(vf, "_tags", {}).get("validation_X", False) assert F.f3.__doc__ == vf.__doc__ assert F.f3.__name__ == vf.__name__ - assert vf.__doc__ == "bla" + assert vf.__doc__ == "Bla" assert vf.__name__ == "f3"