diff --git a/docs/source/notebooks/clv/dev/utilities_plotting.ipynb b/docs/source/notebooks/clv/dev/utilities_plotting.ipynb
new file mode 100644
index 00000000..a797237e
--- /dev/null
+++ b/docs/source/notebooks/clv/dev/utilities_plotting.ipynb
@@ -0,0 +1,500 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "435ed203-5c3c-4efc-93d1-abac66ce7187",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.\n"
+ ]
+ }
+ ],
+ "source": [
+ "from pymc_marketing.clv import utils\n",
+ "\n",
+ "import pandas as pd"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ce561a65-e600-42de-84b6-f3c683729fff",
+ "metadata": {},
+ "source": [
+ "Create a simple dataset for testing:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 69,
+ "id": "7de7f396-1d5b-4457-916b-c29ed90aa132",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "d = [\n",
+ " [1, \"2015-01-01\", 1],\n",
+ " [1, \"2015-02-06\", 2],\n",
+ " [2, \"2015-01-01\", 2],\n",
+ " [3, \"2015-01-01\", 3],\n",
+ " [3, \"2015-01-02\", 1],\n",
+ " [3, \"2015-01-05\", 5],\n",
+ " [4, \"2015-01-16\", 6],\n",
+ " [4, \"2015-02-02\", 3],\n",
+ " [4, \"2015-02-05\", 3],\n",
+ " [4, \"2015-02-05\", 2],\n",
+ " [5, \"2015-01-16\", 3],\n",
+ " [5, \"2015-01-17\", 1],\n",
+ " [5, \"2015-01-18\", 8],\n",
+ " [6, \"2015-02-02\", 5],\n",
+ "]\n",
+ "test_data = pd.DataFrame(d, columns=[\"id\", \"date\", \"monetary_value\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b089a2be-2c3e-4dd1-b96d-ee7c0bd02250",
+ "metadata": {},
+ "source": [
+ "Note customer 4 made two purchases on 2015-02-05. \n",
+ "\n",
+ "`_find_first_transactions` flags the first purchase each customer has made, which must be excluded for modeling. It is called internally by `rfm_summary`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 70,
+ "id": "932e8db6-78cf-49df-aa4a-83ee6584e5dd",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " id | \n",
+ " date | \n",
+ " first | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 2015-01-01 | \n",
+ " True | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 2015-02-06 | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 2 | \n",
+ " 2015-01-01 | \n",
+ " True | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 3 | \n",
+ " 2015-01-01 | \n",
+ " True | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 3 | \n",
+ " 2015-01-02 | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 3 | \n",
+ " 2015-01-05 | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 4 | \n",
+ " 2015-01-16 | \n",
+ " True | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 4 | \n",
+ " 2015-02-02 | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 4 | \n",
+ " 2015-02-05 | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 5 | \n",
+ " 2015-01-16 | \n",
+ " True | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 5 | \n",
+ " 2015-01-17 | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 5 | \n",
+ " 2015-01-18 | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 6 | \n",
+ " 2015-02-02 | \n",
+ " True | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " id date first\n",
+ "0 1 2015-01-01 True\n",
+ "1 1 2015-02-06 False\n",
+ "2 2 2015-01-01 True\n",
+ "3 3 2015-01-01 True\n",
+ "4 3 2015-01-02 False\n",
+ "5 3 2015-01-05 False\n",
+ "6 4 2015-01-16 True\n",
+ "7 4 2015-02-02 False\n",
+ "8 4 2015-02-05 False\n",
+ "10 5 2015-01-16 True\n",
+ "11 5 2015-01-17 False\n",
+ "12 5 2015-01-18 False\n",
+ "13 6 2015-02-02 True"
+ ]
+ },
+ "execution_count": 70,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "utils._find_first_transactions(\n",
+ " transactions=test_data, \n",
+ " customer_id_col = \"id\", \n",
+ " datetime_col = \"date\",\n",
+ " #monetary_value_col = \"monetary_value\", \n",
+ " #datetime_format = \"%Y%m%d\",\n",
+ ").reindex()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cd77dcbe-6990-4784-9960-9fc2b52e90f0",
+ "metadata": {},
+ "source": [
+ "Notice how **9** is missing from the dataframe index. Multiple transactions in the same time period are treated as a single purchase, so the indices for those additional transactions are skipped. \n",
+ "\n",
+ "`rfm_summary` is the primary data preprocessing step for CLV modeling in the continuous, non-contractual domain:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 74,
+ "id": "4c0a7de5-8825-40af-84e5-6cd0ad26a0e3",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " customer_id | \n",
+ " frequency | \n",
+ " recency | \n",
+ " T | \n",
+ " monetary_value | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 1.0 | \n",
+ " 5.0 | \n",
+ " 5.0 | \n",
+ " 2.0 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 2 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 5.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 3 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 5.0 | \n",
+ " 5.0 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 4 | \n",
+ " 1.0 | \n",
+ " 3.0 | \n",
+ " 3.0 | \n",
+ " 8.0 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 5 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 3.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " customer_id frequency recency T monetary_value\n",
+ "0 1 1.0 5.0 5.0 2.0\n",
+ "1 2 0.0 0.0 5.0 0.0\n",
+ "2 3 1.0 1.0 5.0 5.0\n",
+ "3 4 1.0 3.0 3.0 8.0\n",
+ "4 5 0.0 0.0 3.0 0.0"
+ ]
+ },
+ "execution_count": 74,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "rfm_df = utils.rfm_summary(\n",
+ " test_data, \n",
+ " customer_id_col = \"id\", \n",
+ " datetime_col = \"date\", \n",
+ " monetary_value_col = \"monetary_value\",\n",
+ " observation_period_end = \"2015-02-06\",\n",
+ " datetime_format = \"%Y-%m-%d\",\n",
+ " time_unit = \"W\",\n",
+ " include_first_transaction=False,\n",
+ ")\n",
+ "\n",
+ "rfm_df.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "aa8a6479-04fd-48ec-a34a-817b6fdff93c",
+ "metadata": {},
+ "source": [
+ "For MAP fits and covariate models, `rfm_train_test_split` can be used to evaluate models on unseen data. It is also useful for identifying the impact of a time-based event like a marketing campaign."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 76,
+ "id": "761edfe9-1b69-4966-83bf-4f1242eda2d5",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " customer_id | \n",
+ " frequency | \n",
+ " recency | \n",
+ " T | \n",
+ " monetary_value | \n",
+ " test_frequency | \n",
+ " test_monetary_value | \n",
+ " test_T | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 31.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 2.0 | \n",
+ " 5.0 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 2 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 31.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 5.0 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 3 | \n",
+ " 2.0 | \n",
+ " 4.0 | \n",
+ " 31.0 | \n",
+ " 3.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 5.0 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 4 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 16.0 | \n",
+ " 0.0 | \n",
+ " 2.0 | \n",
+ " 4.0 | \n",
+ " 5.0 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 5 | \n",
+ " 2.0 | \n",
+ " 2.0 | \n",
+ " 16.0 | \n",
+ " 4.5 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 5.0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " customer_id frequency recency T monetary_value test_frequency \\\n",
+ "0 1 0.0 0.0 31.0 0.0 1.0 \n",
+ "1 2 0.0 0.0 31.0 0.0 0.0 \n",
+ "2 3 2.0 4.0 31.0 3.0 0.0 \n",
+ "3 4 0.0 0.0 16.0 0.0 2.0 \n",
+ "4 5 2.0 2.0 16.0 4.5 0.0 \n",
+ "\n",
+ " test_monetary_value test_T \n",
+ "0 2.0 5.0 \n",
+ "1 0.0 5.0 \n",
+ "2 0.0 5.0 \n",
+ "3 4.0 5.0 \n",
+ "4 0.0 5.0 "
+ ]
+ },
+ "execution_count": 76,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "train_test = utils.rfm_train_test_split(\n",
+ " test_data, \n",
+ " customer_id_col = \"id\", \n",
+ " datetime_col = \"date\", \n",
+ " train_period_end = \"2015-02-01\",\n",
+ " monetary_value_col = \"monetary_value\",\n",
+ ")\n",
+ "\n",
+ "train_test.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c7b3f800-8dfb-4e5a-b939-5f908281563c",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.18"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/pymc_marketing/clv/utils.py b/pymc_marketing/clv/utils.py
index 307b82d3..111e8a69 100644
--- a/pymc_marketing/clv/utils.py
+++ b/pymc_marketing/clv/utils.py
@@ -1,12 +1,18 @@
import warnings
-from datetime import datetime
+from datetime import date, datetime
from typing import Optional, Union
import numpy as np
import pandas as pd
import xarray
+from numpy import datetime64
-__all__ = ["to_xarray", "customer_lifetime_value", "rfm_summary"]
+__all__ = [
+ "to_xarray",
+ "customer_lifetime_value",
+ "rfm_summary",
+ "rfm_train_test_split",
+]
def to_xarray(customer_id, *arrays, dim: str = "customer_id"):
@@ -257,9 +263,9 @@ def _find_first_transactions(
period_transactions.loc[first_transactions, "first"] = True
select_columns.append("first")
# reset datetime_col to period
- period_transactions.loc[:, datetime_col] = period_transactions[
- datetime_col
- ].dt.to_period(time_unit)
+ period_transactions[datetime_col] = period_transactions[datetime_col].dt.to_period(
+ time_unit
+ )
return period_transactions[select_columns]
@@ -410,3 +416,174 @@ def rfm_summary(
)
return summary_df
+
+
+def rfm_train_test_split(
+ transactions: pd.DataFrame,
+ customer_id_col: str,
+ datetime_col: str,
+ train_period_end: Union[Union[float, str], datetime, datetime64, date],
+ test_period_end: Optional[
+ Union[Union[float, str], datetime, datetime64, date]
+ ] = None,
+ time_unit: str = "D",
+ time_scaler: Optional[float] = 1,
+ datetime_format: Optional[str] = None,
+ monetary_value_col: Optional[str] = None,
+ include_first_transaction: Optional[bool] = False,
+ sort_transactions: Optional[bool] = True,
+) -> pd.DataFrame:
+ """
+    Summarize transaction data and split into training and test datasets for CLV modeling.
+ This can also be used to evaluate the impact of a time-based intervention like a marketing campaign.
+
+ This transforms a DataFrame of transaction data of the form:
+ customer_id, datetime [, monetary_value]
+ to a DataFrame of the form:
+ customer_id, frequency, recency, T [, monetary_value], test_frequency [, test_monetary_value], test_T
+
+ Note this function will exclude new customers whose first transactions occurred during the test period.
+
+ Adapted from lifetimes package
+ https://github.com/CamDavidsonPilon/lifetimes/blob/41e394923ad72b17b5da93e88cfabab43f51abe2/lifetimes/utils.py#L27
+
+ Parameters
+ ----------
+ transactions: :obj: DataFrame
+ A Pandas DataFrame that contains the customer_id col and the datetime col.
+ customer_id_col: string
+ Column in the transactions DataFrame that denotes the customer_id.
+ datetime_col: string
+ Column in the transactions DataFrame that denotes the datetime the purchase was made.
+    train_period_end: Union[str, pd.Period, datetime]
+ A string or datetime to denote the final time period for the training data.
+ Events after this time period are used for the test data.
+ test_period_end: Union[str, pd.Period, datetime], optional
+ A string or datetime to denote the final time period of the study.
+ Events after this date are truncated. If not given, defaults to the max of 'datetime_col'.
+ time_unit: string, optional
+ Time granularity for study.
+ Default: 'D' for days. Possible values listed here:
+ https://numpy.org/devdocs/reference/arrays.datetime.html#datetime-units
+    time_scaler: float, optional
+        Default: 1. Useful for scaling recency & T to a different time granularity. Example:
+        With time_unit='D' and time_scaler=1, we get recency=591 and T=632
+        With time_unit='h' and time_scaler=24, we get recency=590.125 and T=631.375
+ This is useful if predictions in months or years are desired,
+ and can also help with model convergence for study periods of many years.
+ datetime_format: string, optional
+ A string that represents the timestamp format. Useful if Pandas can't understand
+ the provided format.
+ monetary_value_col: string, optional
+ Column in the transactions DataFrame that denotes the monetary value of the transaction.
+ Optional; only needed for spend estimation models like the Gamma-Gamma model.
+ include_first_transaction: bool, optional
+ Default: False
+ For predictive CLV modeling, this should be False.
+ Set to True if performing RFM segmentation.
+ sort_transactions: bool, optional
+ Default: True
+ If raw data is already sorted in chronological order, set to `False` to improve computational efficiency.
+
+ Returns
+ -------
+ :obj: DataFrame:
+ customer_id, frequency, recency, T, test_frequency, test_T [, monetary_value, test_monetary_value]
+ """
+
+ if test_period_end is None:
+ test_period_end = transactions[datetime_col].max()
+
+ transaction_cols = [customer_id_col, datetime_col]
+ if monetary_value_col:
+ transaction_cols.append(monetary_value_col)
+ transactions = transactions[transaction_cols].copy()
+
+ transactions[datetime_col] = pd.to_datetime(
+ transactions[datetime_col], format=datetime_format
+ )
+ test_period_end = pd.to_datetime(test_period_end, format=datetime_format)
+ train_period_end = pd.to_datetime(train_period_end, format=datetime_format)
+
+ # create training dataset
+ training_transactions = transactions.loc[
+ transactions[datetime_col] <= train_period_end
+ ]
+
+ if training_transactions.empty:
+ raise ValueError(
+            "No data available. Check `train_period_end` and confirm values in `transactions` occur prior to this time period."
+ )
+
+ training_rfm_data = rfm_summary(
+ training_transactions,
+ customer_id_col,
+ datetime_col,
+ monetary_value_col=monetary_value_col,
+ datetime_format=datetime_format,
+ observation_period_end=train_period_end,
+ time_unit=time_unit,
+ time_scaler=time_scaler,
+ include_first_transaction=include_first_transaction,
+ sort_transactions=sort_transactions,
+ )
+
+ # create test dataset
+ test_transactions = transactions.loc[
+ (test_period_end >= transactions[datetime_col])
+ & (transactions[datetime_col] > train_period_end)
+ ].copy()
+
+ if test_transactions.empty:
+ raise ValueError(
+            "No data available. Check `train_period_end` and `test_period_end` and confirm values in `transactions` occur between those time periods."
+ )
+
+ test_transactions[datetime_col] = test_transactions[datetime_col].dt.to_period(
+ time_unit
+ )
+ # create dataframe with customer_id and test_frequency columns
+ test_rfm_data = (
+ test_transactions.groupby([customer_id_col, datetime_col], sort=False)[
+ datetime_col
+ ]
+ .agg(lambda r: 1)
+ .groupby(level=customer_id_col)
+ .count()
+ ).reset_index()
+
+ test_rfm_data = test_rfm_data.rename(
+        columns={customer_id_col: "customer_id", datetime_col: "test_frequency"}
+ )
+
+ if monetary_value_col:
+ test_monetary_value = (
+ test_transactions.groupby([customer_id_col, datetime_col])[
+ monetary_value_col
+ ]
+ .sum()
+ .groupby(customer_id_col)
+ .mean()
+ )
+
+ test_rfm_data = test_rfm_data.merge(
+ test_monetary_value,
+ left_on="customer_id",
+ right_on=customer_id_col,
+ how="inner",
+ )
+ test_rfm_data = test_rfm_data.rename(
+ columns={monetary_value_col: "test_monetary_value"}
+ )
+
+ train_test_rfm_data = training_rfm_data.merge(
+ test_rfm_data, on="customer_id", how="left"
+ )
+ train_test_rfm_data.fillna(0, inplace=True)
+
+ time_delta = (
+ test_period_end.to_period(time_unit) - train_period_end.to_period(time_unit)
+ ).n
+ train_test_rfm_data["test_T"] = time_delta / time_scaler # type: ignore
+
+ return train_test_rfm_data
diff --git a/tests/clv/test_utils.py b/tests/clv/test_utils.py
index ca008de8..91a631f3 100644
--- a/tests/clv/test_utils.py
+++ b/tests/clv/test_utils.py
@@ -13,6 +13,7 @@
clv_summary,
customer_lifetime_value,
rfm_summary,
+ rfm_train_test_split,
to_xarray,
)
from tests.clv.utils import set_model_fit
@@ -137,31 +138,11 @@ def fitted_gg(test_summary_data) -> GammaGammaModel:
return model
-@pytest.fixture(scope="module")
-def transaction_data() -> pd.DataFrame:
- d = [
- [1, "2015-01-01", 1],
- [1, "2015-02-06", 2],
- [2, "2015-01-01", 2],
- [3, "2015-01-01", 3],
- [3, "2015-01-02", 1],
- [3, "2015-01-05", 5],
- [4, "2015-01-16", 6],
- [4, "2015-02-02", 3],
- [4, "2015-02-05", 3],
- [5, "2015-01-16", 3],
- [5, "2015-01-17", 1],
- [5, "2015-01-18", 8],
- [6, "2015-02-02", 5],
- ]
- return pd.DataFrame(d, columns=["id", "date", "monetary_value"])
-
-
class TestCustomerLifetimeValue:
def test_customer_lifetime_value_bg_with_known_values(
self, test_summary_data, fitted_bg
):
- # Test borrowed from
+ # Test adapted from
# https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L527
t = test_summary_data.head()
@@ -345,387 +326,522 @@ def test_clv_after_thinning(
)
-def test_find_first_transactions_observation_period_end_none(transaction_data):
- max_date = transaction_data["date"].max()
- pd.testing.assert_frame_equal(
- left=_find_first_transactions(
- transactions=transaction_data,
- customer_id_col="id",
- datetime_col="date",
- observation_period_end=None,
- ),
- right=_find_first_transactions(
- transactions=transaction_data,
- customer_id_col="id",
- datetime_col="date",
- observation_period_end=max_date,
- ),
- )
-
-
-@pytest.mark.parametrize(
- argnames="today",
- argvalues=["2015-02-07", pd.Period("2015-02-07"), datetime(2015, 2, 7), None],
- ids=["string", "period", "datetime", "none"],
-)
-def test_find_first_transactions_returns_correct_results(transaction_data, today):
- # Test borrowed from
- # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L137
+class TestRFM:
+ @pytest.fixture(scope="class")
+ def transaction_data(self) -> pd.DataFrame:
+ d = [
+ [1, "2015-01-01", 1],
+ [1, "2015-02-06", 2],
+ [2, "2015-01-01", 2],
+ [3, "2015-01-01", 3],
+ [3, "2015-01-02", 1],
+ [3, "2015-01-05", 5],
+ [4, "2015-01-16", 6],
+ [4, "2015-02-02", 3],
+ [4, "2015-02-05", 3],
+ [4, "2015-02-05", 6],
+ [5, "2015-01-16", 3],
+ [5, "2015-01-17", 1],
+ [5, "2015-01-18", 8],
+ [6, "2015-02-02", 5],
+ ]
+ return pd.DataFrame(d, columns=["id", "date", "monetary_value"])
+
+ def test_find_first_transactions_test_period_end_none(self, transaction_data):
+ max_date = transaction_data["date"].max()
+ pd.testing.assert_frame_equal(
+ left=_find_first_transactions(
+ transactions=transaction_data,
+ customer_id_col="id",
+ datetime_col="date",
+ observation_period_end=None,
+ ),
+ right=_find_first_transactions(
+ transactions=transaction_data,
+ customer_id_col="id",
+ datetime_col="date",
+ observation_period_end=max_date,
+ ),
+ )
- actual = _find_first_transactions(
- transaction_data,
- "id",
- "date",
- observation_period_end=today,
+ @pytest.mark.parametrize(
+ argnames="today",
+ argvalues=["2015-02-07", pd.Period("2015-02-07"), datetime(2015, 2, 7), None],
+ ids=["string", "period", "datetime", "none"],
)
- expected = pd.DataFrame(
- [
- [1, pd.Period("2015-01-01", "D"), True],
- [1, pd.Period("2015-02-06", "D"), False],
- [2, pd.Period("2015-01-01", "D"), True],
- [3, pd.Period("2015-01-01", "D"), True],
- [3, pd.Period("2015-01-02", "D"), False],
- [3, pd.Period("2015-01-05", "D"), False],
- [4, pd.Period("2015-01-16", "D"), True],
- [4, pd.Period("2015-02-02", "D"), False],
- [4, pd.Period("2015-02-05", "D"), False],
- [5, pd.Period("2015-01-16", "D"), True],
- [5, pd.Period("2015-01-17", "D"), False],
- [5, pd.Period("2015-01-18", "D"), False],
- [6, pd.Period("2015-02-02", "D"), True],
- ],
- columns=["id", "date", "first"],
+ def test_find_first_transactions_returns_correct_results(
+ self, transaction_data, today
+ ):
+ # Test adapted from
+ # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L137
+
+ actual = _find_first_transactions(
+ transaction_data,
+ "id",
+ "date",
+ observation_period_end=today,
+ )
+ expected = pd.DataFrame(
+ [
+ [1, pd.Period("2015-01-01", "D"), True],
+ [1, pd.Period("2015-02-06", "D"), False],
+ [2, pd.Period("2015-01-01", "D"), True],
+ [3, pd.Period("2015-01-01", "D"), True],
+ [3, pd.Period("2015-01-02", "D"), False],
+ [3, pd.Period("2015-01-05", "D"), False],
+ [4, pd.Period("2015-01-16", "D"), True],
+ [4, pd.Period("2015-02-02", "D"), False],
+ [4, pd.Period("2015-02-05", "D"), False],
+ [5, pd.Period("2015-01-16", "D"), True],
+ [5, pd.Period("2015-01-17", "D"), False],
+ [5, pd.Period("2015-01-18", "D"), False],
+ [6, pd.Period("2015-02-02", "D"), True],
+ ],
+ columns=["id", "date", "first"],
+ index=[0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13],
+ ) # row indices are skipped for time periods with multiple transactions
+ assert_frame_equal(actual, expected)
+
+ @pytest.mark.parametrize(
+ argnames="today",
+ argvalues=["2015-02-07", pd.Period("2015-02-07"), datetime(2015, 2, 7), None],
+ ids=["string", "period", "datetime", "none"],
)
- assert_frame_equal(actual, expected)
-
-
-@pytest.mark.parametrize(
- argnames="today",
- argvalues=["2015-02-07", pd.Period("2015-02-07"), datetime(2015, 2, 7), None],
- ids=["string", "period", "datetime", "none"],
-)
-def test_find_first_transactions_with_specific_non_daily_frequency(
- transaction_data, today
-):
- # Test borrowed from
- # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L161
-
- actual = _find_first_transactions(
- transaction_data,
- "id",
- "date",
- observation_period_end=today,
- time_unit="W",
+ def test_find_first_transactions_with_specific_non_daily_frequency(
+ self, transaction_data, today
+ ):
+ # Test adapted from
+ # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L161
+
+ actual = _find_first_transactions(
+ transaction_data,
+ "id",
+ "date",
+ observation_period_end=today,
+ time_unit="W",
+ )
+ expected = pd.DataFrame(
+ [
+ [1, pd.Period("2014-12-29/2015-01-04", "W-SUN"), True],
+ [1, pd.Period("2015-02-02/2015-02-08", "W-SUN"), False],
+ [2, pd.Period("2014-12-29/2015-01-04", "W-SUN"), True],
+ [3, pd.Period("2014-12-29/2015-01-04", "W-SUN"), True],
+ [3, pd.Period("2015-01-05/2015-01-11", "W-SUN"), False],
+ [4, pd.Period("2015-01-12/2015-01-18", "W-SUN"), True],
+ [4, pd.Period("2015-02-02/2015-02-08", "W-SUN"), False],
+ [5, pd.Period("2015-01-12/2015-01-18", "W-SUN"), True],
+ [6, pd.Period("2015-02-02/2015-02-08", "W-SUN"), True],
+ ],
+ columns=["id", "date", "first"],
+ index=actual.index,
+ ) # we shouldn't really care about row ordering or indexing, but assert_frame_equals is strict about it
+ assert_frame_equal(actual, expected)
+
+ @pytest.mark.parametrize(
+ argnames="today",
+ argvalues=["2015-02-07", pd.Period("2015-02-07"), datetime(2015, 2, 7), None],
+ ids=["string", "period", "datetime", "none"],
)
- expected = pd.DataFrame(
- [
- [1, pd.Period("2014-12-29/2015-01-04", "W-SUN"), True],
- [1, pd.Period("2015-02-02/2015-02-08", "W-SUN"), False],
- [2, pd.Period("2014-12-29/2015-01-04", "W-SUN"), True],
- [3, pd.Period("2014-12-29/2015-01-04", "W-SUN"), True],
- [3, pd.Period("2015-01-05/2015-01-11", "W-SUN"), False],
- [4, pd.Period("2015-01-12/2015-01-18", "W-SUN"), True],
- [4, pd.Period("2015-02-02/2015-02-08", "W-SUN"), False],
- [5, pd.Period("2015-01-12/2015-01-18", "W-SUN"), True],
- [6, pd.Period("2015-02-02/2015-02-08", "W-SUN"), True],
- ],
- columns=["id", "date", "first"],
- index=actual.index,
- ) # we shouldn't really care about row ordering or indexing, but assert_frame_equals is strict about it
- assert_frame_equal(actual, expected)
-
-
-@pytest.mark.parametrize(
- argnames="today",
- argvalues=["2015-02-07", pd.Period("2015-02-07"), datetime(2015, 2, 7), None],
- ids=["string", "period", "datetime", "none"],
-)
-def test_find_first_transactions_with_monetary_values(transaction_data, today):
- # Test borrowed from
- # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L184
+ def test_find_first_transactions_with_monetary_values(
+ self, transaction_data, today
+ ):
+ # Test adapted from
+ # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L184
+
+ actual = _find_first_transactions(
+ transaction_data,
+ "id",
+ "date",
+ "monetary_value",
+ observation_period_end=today,
+ )
+ expected = pd.DataFrame(
+ [
+ [1, pd.Period("2015-01-01", "D"), 1, True],
+ [1, pd.Period("2015-02-06", "D"), 2, False],
+ [2, pd.Period("2015-01-01", "D"), 2, True],
+ [3, pd.Period("2015-01-01", "D"), 3, True],
+ [3, pd.Period("2015-01-02", "D"), 1, False],
+ [3, pd.Period("2015-01-05", "D"), 5, False],
+ [4, pd.Period("2015-01-16", "D"), 6, True],
+ [4, pd.Period("2015-02-02", "D"), 3, False],
+ [4, pd.Period("2015-02-05", "D"), 9, False],
+ [5, pd.Period("2015-01-16", "D"), 3, True],
+ [5, pd.Period("2015-01-17", "D"), 1, False],
+ [5, pd.Period("2015-01-18", "D"), 8, False],
+ [6, pd.Period("2015-02-02", "D"), 5, True],
+ ],
+ columns=["id", "date", "monetary_value", "first"],
+ )
+ assert_frame_equal(actual, expected)
- actual = _find_first_transactions(
- transaction_data,
- "id",
- "date",
- "monetary_value",
- observation_period_end=today,
- )
- expected = pd.DataFrame(
- [
- [1, pd.Period("2015-01-01", "D"), 1, True],
- [1, pd.Period("2015-02-06", "D"), 2, False],
- [2, pd.Period("2015-01-01", "D"), 2, True],
- [3, pd.Period("2015-01-01", "D"), 3, True],
- [3, pd.Period("2015-01-02", "D"), 1, False],
- [3, pd.Period("2015-01-05", "D"), 5, False],
- [4, pd.Period("2015-01-16", "D"), 6, True],
- [4, pd.Period("2015-02-02", "D"), 3, False],
- [4, pd.Period("2015-02-05", "D"), 3, False],
- [5, pd.Period("2015-01-16", "D"), 3, True],
- [5, pd.Period("2015-01-17", "D"), 1, False],
- [5, pd.Period("2015-01-18", "D"), 8, False],
- [6, pd.Period("2015-02-02", "D"), 5, True],
- ],
- columns=["id", "date", "monetary_value", "first"],
+ @pytest.mark.parametrize(
+ argnames="today",
+ argvalues=["2015-02-07", pd.Period("2015-02-07"), datetime(2015, 2, 7), None],
+ ids=["string", "period", "datetime", "none"],
)
- assert_frame_equal(actual, expected)
-
+ def test_find_first_transactions_with_monetary_values_with_specific_non_daily_frequency(
+ self, transaction_data, today
+ ):
+ # Test adapted from
+ # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L210
+
+ actual = _find_first_transactions(
+ transaction_data,
+ "id",
+ "date",
+ "monetary_value",
+ observation_period_end=today,
+ time_unit="W",
+ )
+ expected = pd.DataFrame(
+ [
+ [1, pd.Period("2014-12-29/2015-01-04", "W-SUN"), 1, True],
+ [1, pd.Period("2015-02-02/2015-02-08", "W-SUN"), 2, False],
+ [2, pd.Period("2014-12-29/2015-01-04", "W-SUN"), 2, True],
+ [3, pd.Period("2014-12-29/2015-01-04", "W-SUN"), 4, True],
+ [3, pd.Period("2015-01-05/2015-01-11", "W-SUN"), 5, False],
+ [4, pd.Period("2015-01-12/2015-01-18", "W-SUN"), 6, True],
+ [4, pd.Period("2015-02-02/2015-02-08", "W-SUN"), 12, False],
+ [5, pd.Period("2015-01-12/2015-01-18", "W-SUN"), 12, True],
+ [6, pd.Period("2015-02-02/2015-02-08", "W-SUN"), 5, True],
+ ],
+ columns=["id", "date", "monetary_value", "first"],
+ )
+ assert_frame_equal(actual, expected)
-@pytest.mark.parametrize(
- argnames="today",
- argvalues=["2015-02-07", pd.Period("2015-02-07"), datetime(2015, 2, 7), None],
- ids=["string", "period", "datetime", "none"],
-)
-def test_find_first_transactions_with_monetary_values_with_specific_non_daily_frequency(
- transaction_data, today
-):
- # Test borrowed from
- # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L210
+ @pytest.mark.parametrize(
+ argnames="today",
+ argvalues=["2015-02-07", pd.Period("2015-02-07"), datetime(2015, 2, 7)],
+ ids=["string", "period", "datetime"],
+ )
+ def test_rfm_summary_returns_correct_results(self, transaction_data, today):
+ # Test adapted from
+ # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L239
- actual = _find_first_transactions(
+ actual = rfm_summary(
+ transaction_data, "id", "date", observation_period_end=today
+ )
+ expected = pd.DataFrame(
+ [
+ [1, 1.0, 36.0, 37.0],
+ [2, 0.0, 0.0, 37.0],
+ [3, 2.0, 4.0, 37.0],
+ [4, 2.0, 20.0, 22.0],
+ [5, 2.0, 2.0, 22.0],
+ [6, 0.0, 0.0, 5.0],
+ ],
+ columns=["customer_id", "frequency", "recency", "T"],
+ )
+ assert_frame_equal(actual, expected)
+
+ def test_rfm_summary_works_with_string_customer_ids(self):
+ # Test adapted from
+ # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L250
+
+ d = [
+ ["X", "2015-02-01"],
+ ["X", "2015-02-06"],
+ ["Y", "2015-01-01"],
+ ["Y", "2015-01-01"],
+ ["Y", "2015-01-02"],
+ ["Y", "2015-01-05"],
+ ]
+ df = pd.DataFrame(d, columns=["id", "date"])
+ rfm_summary(df, "id", "date")
+
+ def test_rfm_summary_works_with_int_customer_ids_and_doesnt_coerce_to_float(self):
+ # Test adapted from
+ # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L263
+
+ d = [
+ [1, "2015-02-01"],
+ [1, "2015-02-06"],
+ [1, "2015-01-01"],
+ [2, "2015-01-01"],
+ [2, "2015-01-02"],
+ [2, "2015-01-05"],
+ ]
+ df = pd.DataFrame(d, columns=["id", "date"])
+ actual = rfm_summary(df, "id", "date")
+ assert actual.index.dtype == "int64"
+
+ def test_rfm_summary_with_specific_datetime_format(
+ self,
transaction_data,
- "id",
- "date",
- "monetary_value",
- observation_period_end=today,
- time_unit="W",
- )
- expected = pd.DataFrame(
- [
- [1, pd.Period("2014-12-29/2015-01-04", "W-SUN"), 1, True],
- [1, pd.Period("2015-02-02/2015-02-08", "W-SUN"), 2, False],
- [2, pd.Period("2014-12-29/2015-01-04", "W-SUN"), 2, True],
- [3, pd.Period("2014-12-29/2015-01-04", "W-SUN"), 4, True],
- [3, pd.Period("2015-01-05/2015-01-11", "W-SUN"), 5, False],
- [4, pd.Period("2015-01-12/2015-01-18", "W-SUN"), 6, True],
- [4, pd.Period("2015-02-02/2015-02-08", "W-SUN"), 6, False],
- [5, pd.Period("2015-01-12/2015-01-18", "W-SUN"), 12, True],
- [6, pd.Period("2015-02-02/2015-02-08", "W-SUN"), 5, True],
- ],
- columns=["id", "date", "monetary_value", "first"],
- )
- assert_frame_equal(actual, expected)
+ ):
+ # Test adapted from
+ # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L279
+ transaction_data["date"] = transaction_data["date"].map(
+ lambda x: x.replace("-", "")
+ )
+ format = "%Y%m%d"
+ today = "20150207"
+ actual = rfm_summary(
+ transaction_data,
+ "id",
+ "date",
+ observation_period_end=today,
+ datetime_format=format,
+ sort_transactions=False,
+ )
+ expected = pd.DataFrame(
+ [
+ [1, 1.0, 36.0, 37.0],
+ [2, 0.0, 0.0, 37.0],
+ [3, 2.0, 4.0, 37.0],
+ [4, 2.0, 20.0, 22.0],
+ [5, 2.0, 2.0, 22.0],
+ [6, 0.0, 0.0, 5.0],
+ ],
+ columns=["customer_id", "frequency", "recency", "T"],
+ )
+ assert_frame_equal(actual, expected)
-@pytest.mark.parametrize(
- argnames="today",
- argvalues=["2015-02-07", pd.Period("2015-02-07"), datetime(2015, 2, 7)],
- ids=["string", "period", "datetime"],
-)
-def test_rfm_summary_returns_correct_results(transaction_data, today):
- # Test borrowed from
- # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L239
-
- actual = rfm_summary(transaction_data, "id", "date", observation_period_end=today)
- expected = pd.DataFrame(
- [
- [1, 1.0, 36.0, 37.0],
- [2, 0.0, 0.0, 37.0],
- [3, 2.0, 4.0, 37.0],
- [4, 2.0, 20.0, 22.0],
- [5, 2.0, 2.0, 22.0],
- [6, 0.0, 0.0, 5.0],
- ],
- columns=["customer_id", "frequency", "recency", "T"],
- )
- assert_frame_equal(actual, expected)
-
-
-def test_rfm_summary_works_with_string_customer_ids():
- # Test borrowed from
- # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L250
-
- d = [
- ["X", "2015-02-01"],
- ["X", "2015-02-06"],
- ["Y", "2015-01-01"],
- ["Y", "2015-01-01"],
- ["Y", "2015-01-02"],
- ["Y", "2015-01-05"],
- ]
- df = pd.DataFrame(d, columns=["id", "date"])
- rfm_summary(df, "id", "date")
-
-
-def test_rfm_summary_works_with_int_customer_ids_and_doesnt_coerce_to_float():
- # Test borrowed from
- # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L263
-
- d = [
- [1, "2015-02-01"],
- [1, "2015-02-06"],
- [1, "2015-01-01"],
- [2, "2015-01-01"],
- [2, "2015-01-02"],
- [2, "2015-01-05"],
- ]
- df = pd.DataFrame(d, columns=["id", "date"])
- actual = rfm_summary(df, "id", "date")
- assert actual.index.dtype == "int64"
-
-
-def test_rfm_summary_with_specific_datetime_format(
- transaction_data,
-):
- # Test borrowed from
- # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L279
-
- transaction_data["date"] = transaction_data["date"].map(
- lambda x: x.replace("-", "")
- )
- format = "%Y%m%d"
- today = "20150207"
- actual = rfm_summary(
+ def test_rfm_summary_non_daily_frequency(
+ self,
transaction_data,
- "id",
- "date",
- observation_period_end=today,
- datetime_format=format,
- sort_transactions=False,
- )
- expected = pd.DataFrame(
- [
- [1, 1.0, 36.0, 37.0],
- [2, 0.0, 0.0, 37.0],
- [3, 2.0, 4.0, 37.0],
- [4, 2.0, 20.0, 22.0],
- [5, 2.0, 2.0, 22.0],
- [6, 0.0, 0.0, 5.0],
- ],
- columns=["customer_id", "frequency", "recency", "T"],
- )
- assert_frame_equal(actual, expected)
+ ):
+ # Test adapted from
+ # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L292
+
+ today = "20150207"
+ actual = rfm_summary(
+ transaction_data,
+ "id",
+ "date",
+ observation_period_end=today,
+ time_unit="W",
+ )
+ expected = pd.DataFrame(
+ [
+ [1, 1.0, 5.0, 5.0],
+ [2, 0.0, 0.0, 5.0],
+ [3, 1.0, 1.0, 5.0],
+ [4, 1.0, 3.0, 3.0],
+ [5, 0.0, 0.0, 3.0],
+ [6, 0.0, 0.0, 0.0],
+ ],
+ columns=["customer_id", "frequency", "recency", "T"],
+ )
+ assert_frame_equal(actual, expected)
+ def test_rfm_summary_monetary_values_and_first_transactions(
+ self,
+ transaction_data,
+ ):
+ # Test adapted from
+ # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L311
+
+ today = "20150207"
+ actual = rfm_summary(
+ transaction_data,
+ "id",
+ "date",
+ monetary_value_col="monetary_value",
+ observation_period_end=today,
+ )
+ expected = pd.DataFrame(
+ [
+ [1, 1.0, 36.0, 37.0, 2],
+ [2, 0.0, 0.0, 37.0, 0],
+ [3, 2.0, 4.0, 37.0, 3],
+ [4, 2.0, 20.0, 22.0, 6],
+ [5, 2.0, 2.0, 22.0, 4.5],
+ [6, 0.0, 0.0, 5.0, 0],
+ ],
+ columns=["customer_id", "frequency", "recency", "T", "monetary_value"],
+ )
+ assert_frame_equal(actual, expected)
+
+ actual_first_trans = rfm_summary(
+ transaction_data,
+ "id",
+ "date",
+ monetary_value_col="monetary_value",
+ observation_period_end=today,
+ include_first_transaction=True,
+ )
+ expected_first_trans = pd.DataFrame(
+ [
+ [1, 2.0, 36.0, 37.0, 1.5],
+ [2, 1.0, 0.0, 37.0, 2],
+ [3, 3.0, 4.0, 37.0, 3],
+ [4, 3.0, 20.0, 22.0, 6],
+ [5, 3.0, 2.0, 22.0, 4],
+ [6, 1.0, 0.0, 5.0, 5],
+ ],
+ columns=["customer_id", "frequency", "recency", "T", "monetary_value"],
+ )
+ assert_frame_equal(actual_first_trans, expected_first_trans)
-def test_rfm_summary_non_daily_frequency(
- transaction_data,
-):
- # Test borrowed from
- # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L292
+ def test_rfm_summary_will_choose_the_correct_first_order_to_drop_in_monetary_transactions(
+ self,
+ ):
+ # Test adapted from
+ # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L334
- today = "20150207"
- actual = rfm_summary(
- transaction_data,
- "id",
- "date",
- observation_period_end=today,
- time_unit="W",
- )
- expected = pd.DataFrame(
- [
- [1, 1.0, 5.0, 5.0],
- [2, 0.0, 0.0, 5.0],
- [3, 1.0, 1.0, 5.0],
- [4, 1.0, 3.0, 3.0],
- [5, 0.0, 0.0, 3.0],
- [6, 0.0, 0.0, 0.0],
- ],
- columns=["customer_id", "frequency", "recency", "T"],
- )
- assert_frame_equal(actual, expected)
+ cust = pd.Series([2, 2, 2])
+ dates_ordered = pd.to_datetime(
+ pd.Series(
+ ["2014-03-14 00:00:00", "2014-04-09 00:00:00", "2014-05-21 00:00:00"]
+ )
+ )
+ sales = pd.Series([10, 20, 25])
+ transaction_data = pd.DataFrame(
+ {"date": dates_ordered, "id": cust, "sales": sales}
+ )
+ summary_ordered_data = rfm_summary(transaction_data, "id", "date", "sales")
+ dates_unordered = pd.to_datetime(
+ pd.Series(
+ ["2014-04-09 00:00:00", "2014-03-14 00:00:00", "2014-05-21 00:00:00"]
+ )
+ )
+ sales = pd.Series([20, 10, 25])
+ transaction_data = pd.DataFrame(
+ {"date": dates_unordered, "id": cust, "sales": sales}
+ )
+ summary_unordered_data = rfm_summary(transaction_data, "id", "date", "sales")
-def test_rfm_summary_monetary_values_and_first_transactions(
- transaction_data,
-):
- # Test borrowed from
- # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L311
+ assert_frame_equal(summary_ordered_data, summary_unordered_data)
+ assert summary_ordered_data["monetary_value"].loc[0] == 22.5
- today = "20150207"
- actual = rfm_summary(
- transaction_data,
- "id",
- "date",
- monetary_value_col="monetary_value",
- observation_period_end=today,
- )
- expected = pd.DataFrame(
- [
- [1, 1.0, 36.0, 37.0, 2],
- [2, 0.0, 0.0, 37.0, 0],
- [3, 2.0, 4.0, 37.0, 3],
- [4, 2.0, 20.0, 22.0, 3],
- [5, 2.0, 2.0, 22.0, 4.5],
- [6, 0.0, 0.0, 5.0, 0],
- ],
- columns=["customer_id", "frequency", "recency", "T", "monetary_value"],
- )
- assert_frame_equal(actual, expected)
+ def test_rfm_summary_statistics_identical_to_hardie_paper(
+ self,
+ cdnow_trans,
+ ):
+ # Test adapted from
+ # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L353
+
+ # see http://brucehardie.com/papers/rfm_clv_2005-02-16.pdf
+ # RFM and CLV: Using Iso-value Curves for Customer Base Analysis
+ summary = rfm_summary(
+ cdnow_trans,
+ "id",
+ "date",
+ "spent",
+ observation_period_end="19971001",
+ datetime_format="%Y%m%d",
+ )
+ results = summary[summary["frequency"] > 0]["monetary_value"].describe()
- actual_first_trans = rfm_summary(
- transaction_data,
- "id",
- "date",
- monetary_value_col="monetary_value",
- observation_period_end=today,
- include_first_transaction=True,
- )
- expected_first_trans = pd.DataFrame(
- [
- [1, 2.0, 36.0, 37.0, 1.5],
- [2, 1.0, 0.0, 37.0, 2],
- [3, 3.0, 4.0, 37.0, 3],
- [4, 3.0, 20.0, 22.0, 4],
- [5, 3.0, 2.0, 22.0, 4],
- [6, 1.0, 0.0, 5.0, 5],
- ],
- columns=["customer_id", "frequency", "recency", "T", "monetary_value"],
- )
- assert_frame_equal(actual_first_trans, expected_first_trans)
+ assert np.round(results.loc["mean"]) == 35
+ assert np.round(results.loc["std"]) == 30
+ assert np.round(results.loc["min"]) == 3
+ assert np.round(results.loc["50%"]) == 27
+ assert np.round(results.loc["max"]) == 300
+ assert np.round(results.loc["count"]) == 946
+ def test_rfm_summary_squashes_period_purchases_to_one_purchase(self):
+ # Test adapted from
+ # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L472
-def test_rfm_summary_will_choose_the_correct_first_order_to_drop_in_monetary_transactions():
- # Test borrowed from
- # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L334
+ transactions = pd.DataFrame(
+ [[1, "2015-01-01"], [1, "2015-01-01"]], columns=["id", "t"]
+ )
+ actual = rfm_summary(transactions, "id", "t", time_unit="W")
+ assert actual.loc[0]["frequency"] == 1.0 - 1.0
- cust = pd.Series([2, 2, 2])
- dates_ordered = pd.to_datetime(
- pd.Series(["2014-03-14 00:00:00", "2014-04-09 00:00:00", "2014-05-21 00:00:00"])
- )
- sales = pd.Series([10, 20, 25])
- transaction_data = pd.DataFrame({"date": dates_ordered, "id": cust, "sales": sales})
- summary_ordered_data = rfm_summary(transaction_data, "id", "date", "sales")
+ def test_clv_summary_warning(self, transaction_data):
+ with pytest.warns(UserWarning, match="clv_summary was renamed to rfm_summary"):
+ clv_summary(transaction_data, "id", "date")
- dates_unordered = pd.to_datetime(
- pd.Series(["2014-04-09 00:00:00", "2014-03-14 00:00:00", "2014-05-21 00:00:00"])
- )
- sales = pd.Series([20, 10, 25])
- transaction_data = pd.DataFrame(
- {"date": dates_unordered, "id": cust, "sales": sales}
- )
- summary_unordered_data = rfm_summary(transaction_data, "id", "date", "sales")
+ def test_rfm_train_test_split(self, transaction_data):
+ # Test adapted from
+ # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L374
- assert_frame_equal(summary_ordered_data, summary_unordered_data)
- assert summary_ordered_data["monetary_value"].loc[0] == 22.5
+ train_end = "2015-02-01"
+ actual = rfm_train_test_split(transaction_data, "id", "date", train_end)
+ assert actual.loc[0]["test_frequency"] == 1
+ assert actual.loc[1]["test_frequency"] == 0
+ with pytest.raises(KeyError):
+ actual.loc[6]
-def test_rfm_summary_statistics_identical_to_hardie_paper(
- cdnow_trans,
-):
- # Test borrowed from
- # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L353
+ @pytest.mark.parametrize("train_end", ("2014-02-07", "2015-02-08"))
+ def test_rfm_train_test_split_throws_better_error_if_test_period_end_is_too_early(
+ self,
+ train_end,
+ transaction_data,
+ ):
+ # Test adapted from
+ # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L387
- # see http://brucehardie.com/papers/rfm_clv_2005-02-16.pdf
- # RFM and CLV: Using Iso-value Curves for Customer Base Analysis
- summary = rfm_summary(
- cdnow_trans,
- "id",
- "date",
- "spent",
- observation_period_end="19971001",
- datetime_format="%Y%m%d",
- )
- results = summary[summary["frequency"] > 0]["monetary_value"].describe()
+ test_end = "2014-02-07"
- assert np.round(results.loc["mean"]) == 35
- assert np.round(results.loc["std"]) == 30
- assert np.round(results.loc["min"]) == 3
- assert np.round(results.loc["50%"]) == 27
- assert np.round(results.loc["max"]) == 300
- assert np.round(results.loc["count"]) == 946
+ with pytest.raises(
+ ValueError,
+ match="No data available. Check `test_transactions` and `train_period_end` and confirm values in `transactions` occur prior to those time periods.",
+ ):
+ rfm_train_test_split(
+ transaction_data, "id", "date", train_end, test_period_end=test_end
+ )
+ def test_rfm_train_test_split_works_with_specific_frequency(self, transaction_data):
+ # Test adapted from
+ # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L412
+
+ test_end = "2015-02-07"
+ train_end = "2015-02-01"
+ actual = rfm_train_test_split(
+ transaction_data,
+ "id",
+ "date",
+ train_end,
+ test_period_end=test_end,
+ time_unit="W",
+ )
+ expected_cols = [
+ "customer_id",
+ "frequency",
+ "recency",
+ "T",
+ "test_frequency",
+ "test_T",
+ ]
+ expected = pd.DataFrame(
+ [
+ [1, 0.0, 0.0, 4.0, 1, 1],
+ [2, 0.0, 0.0, 4.0, 0, 1],
+ [3, 1.0, 1.0, 4.0, 0, 1],
+ [4, 0.0, 0.0, 2.0, 1, 1],
+ [5, 0.0, 0.0, 2.0, 0, 1],
+ ],
+ columns=expected_cols,
+ )
+ assert_frame_equal(actual, expected, check_dtype=False)
+
+ def test_rfm_train_test_split_gives_correct_date_boundaries(self, transaction_data):
+ # Test adapted from
+ # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L432
+
+ actual = rfm_train_test_split(
+ transaction_data,
+ "id",
+ "date",
+ train_period_end="2015-02-01",
+ test_period_end="2015-02-04",
+ )
+ assert actual["test_frequency"].loc[1] == 0
+ assert actual["test_frequency"].loc[3] == 1
+
+ def test_rfm_train_test_split_with_monetary_value(self, transaction_data):
+ # Test adapted from
+ # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L457
+
+ test_end = "2015-02-07"
+ train_end = "2015-02-01"
+ actual = rfm_train_test_split(
+ transaction_data,
+ "id",
+ "date",
+ train_end,
+ test_period_end=test_end,
+ monetary_value_col="monetary_value",
+ )
+ assert (actual["monetary_value"] == [0, 0, 3, 0, 4.5]).all()
+ assert (actual["test_monetary_value"] == [2, 0, 0, 6, 0]).all()
-def test_clv_summary_warning(transaction_data):
- with pytest.warns(UserWarning, match="clv_summary was renamed to rfm_summary"):
- clv_summary(transaction_data, "id", "date")
+ # check test_monetary_value is being aggregated correctly for time periods with multiple purchases