diff --git a/pymc_marketing/clv/utils.py b/pymc_marketing/clv/utils.py index 5c60564e..9633a16d 100644 --- a/pymc_marketing/clv/utils.py +++ b/pymc_marketing/clv/utils.py @@ -47,7 +47,10 @@ def customer_lifetime_value( discount_rate: float = 0.00, time_unit: str = "D", ) -> xarray.DataArray: - """Compute the average lifetime value for a group of one or more customers, + """ + Compute customer lifetime value. + + Compute the average lifetime value for a group of one or more customers and apply a discount rate for net present value estimations. Note `future_t` is measured in months regardless of `time_unit` specified. @@ -86,7 +89,21 @@ def customer_lifetime_value( raise ValueError("Required column future_spend missing") def _squeeze_dims(x: xarray.DataArray): - """This utility is required for MAP-fitted model predictions to broadcast properly""" + """ + Squeeze dimensions for MAP-fitted model predictions. + + This utility is required for MAP-fitted model predictions to broadcast properly. + + Parameters + ---------- + x : xarray.DataArray + DataArray to squeeze dimensions for. + + Returns + ------- + xarray.DataArray + DataArray with squeezed dimensions. + """ dims_to_squeeze: tuple[str, ...] = () if "chain" in x.dims and len(x.chain) == 1: dims_to_squeeze += ("chain",) @@ -421,6 +438,7 @@ def rfm_train_test_split( sort_transactions: bool | None = True, ) -> pandas.DataFrame: """Summarize transaction data and split into training and tests datasets for CLV modeling. + This can also be used to evaluate the impact of a time-based intervention like a marketing campaign. This transforms a DataFrame of transaction data of the form: @@ -713,7 +731,23 @@ def rfm_segments( def _rfm_quartile_labels(column_name, max_label_range): - """Called internally by rfm_segments to label quartiles for each variable""" + """ + Label quartiles for each variable. + + Called internally by rfm_segments to label quartiles for each variable. 
+ + Parameters + ---------- + column_name : str + The name of the column to label. + max_label_range : int + Exclusive upper bound of the label values; recency labels run from ``max_label_range - 1`` down to 1. + + Returns + ------- + list[int] + A list of integer labels for the column (in descending order for ``r_quartile``). + """ # recency labels must be reversed because lower values are more desirable if column_name == "r_quartile": return list(range(max_label_range - 1, 0, -1)) diff --git a/pymc_marketing/utils.py b/pymc_marketing/utils.py index 1efd5eff..744fdb0b 100644 --- a/pymc_marketing/utils.py +++ b/pymc_marketing/utils.py @@ -18,6 +18,18 @@ def from_netcdf(filepath: str | Path) -> az.InferenceData: + """Load inference data from a NetCDF file. + + Parameters + ---------- + filepath : str or Path + The path to the NetCDF file. + + Returns + ------- + az.InferenceData + The inference data loaded from the file. + """ with warnings.catch_warnings(): warnings.filterwarnings( "ignore",