Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Aug 20, 2024
1 parent e57073f commit 60ddf3e
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 3 deletions.
40 changes: 37 additions & 3 deletions pymc_marketing/clv/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def customer_lifetime_value(
discount_rate: float = 0.00,
time_unit: str = "D",
) -> xarray.DataArray:
"""Compute the average lifetime value for a group of one or more customers,
"""
Compute customer lifetime value.
Compute the average lifetime value for a group of one or more customers
and apply a discount rate for net present value estimations.
Note that `future_t` is measured in months regardless of the `time_unit` specified.
Expand Down Expand Up @@ -86,7 +89,21 @@ def customer_lifetime_value(
raise ValueError("Required column future_spend missing")

def _squeeze_dims(x: xarray.DataArray):
"""This utility is required for MAP-fitted model predictions to broadcast properly"""
"""
Squeeze dimensions for MAP-fitted model predictions.
This utility is required for MAP-fitted model predictions to broadcast properly.
Parameters
----------
x : xarray.DataArray
DataArray to squeeze dimensions for.
Returns
-------
xarray.DataArray
DataArray with squeezed dimensions.
"""
dims_to_squeeze: tuple[str, ...] = ()
if "chain" in x.dims and len(x.chain) == 1:
dims_to_squeeze += ("chain",)
Expand Down Expand Up @@ -421,6 +438,7 @@ def rfm_train_test_split(
sort_transactions: bool | None = True,
) -> pandas.DataFrame:
"""Summarize transaction data and split into training and test datasets for CLV modeling.
This can also be used to evaluate the impact of a time-based intervention like a marketing campaign.
This transforms a DataFrame of transaction data of the form:
Expand Down Expand Up @@ -713,7 +731,23 @@ def rfm_segments(


def _rfm_quartile_labels(column_name, max_label_range):
"""Called internally by rfm_segments to label quartiles for each variable"""
"""
Label quartiles for each variable.
Called internally by rfm_segments to label quartiles for each variable.
Parameters
----------
column_name : str
The name of the column to label.
max_label_range : int
The maximum range of labels to create.
Returns
-------
list[int]
A list of labels for the column.
"""
# recency labels must be reversed because lower values are more desirable
if column_name == "r_quartile":
return list(range(max_label_range - 1, 0, -1))
Expand Down
12 changes: 12 additions & 0 deletions pymc_marketing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,18 @@


def from_netcdf(filepath: str | Path) -> az.InferenceData:
"""Load inference data from a netcdf file.
Parameters
----------
filepath : str or Path
The path to the netcdf file.
Returns
-------
az.InferenceData
The inference data.
"""
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
Expand Down

0 comments on commit 60ddf3e

Please sign in to comment.