diff --git a/docs/source/notebooks/clv/dev/utilities_plotting.ipynb b/docs/source/notebooks/clv/dev/utilities_plotting.ipynb new file mode 100644 index 00000000..a797237e --- /dev/null +++ b/docs/source/notebooks/clv/dev/utilities_plotting.ipynb @@ -0,0 +1,500 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "435ed203-5c3c-4efc-93d1-abac66ce7187", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.\n" + ] + } + ], + "source": [ + "from pymc_marketing.clv import utils\n", + "\n", + "import pandas as pd" + ] + }, + { + "cell_type": "markdown", + "id": "ce561a65-e600-42de-84b6-f3c683729fff", + "metadata": {}, + "source": [ + "Create a simple dataset for testing:" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "id": "7de7f396-1d5b-4457-916b-c29ed90aa132", + "metadata": {}, + "outputs": [], + "source": [ + "d = [\n", + " [1, \"2015-01-01\", 1],\n", + " [1, \"2015-02-06\", 2],\n", + " [2, \"2015-01-01\", 2],\n", + " [3, \"2015-01-01\", 3],\n", + " [3, \"2015-01-02\", 1],\n", + " [3, \"2015-01-05\", 5],\n", + " [4, \"2015-01-16\", 6],\n", + " [4, \"2015-02-02\", 3],\n", + " [4, \"2015-02-05\", 3],\n", + " [4, \"2015-02-05\", 2],\n", + " [5, \"2015-01-16\", 3],\n", + " [5, \"2015-01-17\", 1],\n", + " [5, \"2015-01-18\", 8],\n", + " [6, \"2015-02-02\", 5],\n", + "]\n", + "test_data = pd.DataFrame(d, columns=[\"id\", \"date\", \"monetary_value\"])" + ] + }, + { + "cell_type": "markdown", + "id": "b089a2be-2c3e-4dd1-b96d-ee7c0bd02250", + "metadata": {}, + "source": [ + "Note customer 4 made two purchases on 2015-02-05. \n", + "\n", + "`_find_first_transactions` flags the first purchase each customer has made, which must be excluded for modeling. It is called internally by `rfm_summary`." + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "id": "932e8db6-78cf-49df-aa4a-83ee6584e5dd", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
iddatefirst
012015-01-01True
112015-02-06False
222015-01-01True
332015-01-01True
432015-01-02False
532015-01-05False
642015-01-16True
742015-02-02False
842015-02-05False
1052015-01-16True
1152015-01-17False
1252015-01-18False
1362015-02-02True
\n", + "
" + ], + "text/plain": [ + " id date first\n", + "0 1 2015-01-01 True\n", + "1 1 2015-02-06 False\n", + "2 2 2015-01-01 True\n", + "3 3 2015-01-01 True\n", + "4 3 2015-01-02 False\n", + "5 3 2015-01-05 False\n", + "6 4 2015-01-16 True\n", + "7 4 2015-02-02 False\n", + "8 4 2015-02-05 False\n", + "10 5 2015-01-16 True\n", + "11 5 2015-01-17 False\n", + "12 5 2015-01-18 False\n", + "13 6 2015-02-02 True" + ] + }, + "execution_count": 70, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "utils._find_first_transactions(\n", + " transactions=test_data, \n", + " customer_id_col = \"id\", \n", + " datetime_col = \"date\",\n", + " #monetary_value_col = \"monetary_value\", \n", + " #datetime_format = \"%Y%m%d\",\n", + ").reindex()" + ] + }, + { + "cell_type": "markdown", + "id": "cd77dcbe-6990-4784-9960-9fc2b52e90f0", + "metadata": {}, + "source": [ + "Notice how **9** is missing from the dataframe index. Multiple transactions in the same time period are treated as a single purchase, so the indices for those additional transactions are skipped. \n", + "\n", + "`rfm_summary` is the primary data preprocessing step for CLV modeling in the continuous, non-contractual domain:" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "id": "4c0a7de5-8825-40af-84e5-6cd0ad26a0e3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
customer_idfrequencyrecencyTmonetary_value
011.05.05.02.0
120.00.05.00.0
231.01.05.05.0
341.03.03.08.0
450.00.03.00.0
\n", + "
" + ], + "text/plain": [ + " customer_id frequency recency T monetary_value\n", + "0 1 1.0 5.0 5.0 2.0\n", + "1 2 0.0 0.0 5.0 0.0\n", + "2 3 1.0 1.0 5.0 5.0\n", + "3 4 1.0 3.0 3.0 8.0\n", + "4 5 0.0 0.0 3.0 0.0" + ] + }, + "execution_count": 74, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rfm_df = utils.rfm_summary(\n", + " test_data, \n", + " customer_id_col = \"id\", \n", + " datetime_col = \"date\", \n", + " monetary_value_col = \"monetary_value\",\n", + " observation_period_end = \"2015-02-06\",\n", + " datetime_format = \"%Y-%m-%d\",\n", + " time_unit = \"W\",\n", + " include_first_transaction=False,\n", + ")\n", + "\n", + "rfm_df.head()" + ] + }, + { + "cell_type": "markdown", + "id": "aa8a6479-04fd-48ec-a34a-817b6fdff93c", + "metadata": {}, + "source": [ + "For MAP fits and covariate models, `rfm_train_test_split` can be used to evaluate models on unseen data. It is also useful for identifying the impact of a time-based event like a marketing campaign." + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "id": "761edfe9-1b69-4966-83bf-4f1242eda2d5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
customer_idfrequencyrecencyTmonetary_valuetest_frequencytest_monetary_valuetest_T
010.00.031.00.01.02.05.0
120.00.031.00.00.00.05.0
232.04.031.03.00.00.05.0
340.00.016.00.02.04.05.0
452.02.016.04.50.00.05.0
\n", + "
" + ], + "text/plain": [ + " customer_id frequency recency T monetary_value test_frequency \\\n", + "0 1 0.0 0.0 31.0 0.0 1.0 \n", + "1 2 0.0 0.0 31.0 0.0 0.0 \n", + "2 3 2.0 4.0 31.0 3.0 0.0 \n", + "3 4 0.0 0.0 16.0 0.0 2.0 \n", + "4 5 2.0 2.0 16.0 4.5 0.0 \n", + "\n", + " test_monetary_value test_T \n", + "0 2.0 5.0 \n", + "1 0.0 5.0 \n", + "2 0.0 5.0 \n", + "3 4.0 5.0 \n", + "4 0.0 5.0 " + ] + }, + "execution_count": 76, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_test = utils.rfm_train_test_split(\n", + " test_data, \n", + " customer_id_col = \"id\", \n", + " datetime_col = \"date\", \n", + " train_period_end = \"2015-02-01\",\n", + " monetary_value_col = \"monetary_value\",\n", + ")\n", + "\n", + "train_test.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c7b3f800-8dfb-4e5a-b939-5f908281563c", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pymc_marketing/clv/utils.py b/pymc_marketing/clv/utils.py index 307b82d3..111e8a69 100644 --- a/pymc_marketing/clv/utils.py +++ b/pymc_marketing/clv/utils.py @@ -1,12 +1,18 @@ import warnings -from datetime import datetime +from datetime import date, datetime from typing import Optional, Union import numpy as np import pandas as pd import xarray +from numpy import datetime64 -__all__ = ["to_xarray", "customer_lifetime_value", "rfm_summary"] +__all__ = [ + "to_xarray", + "customer_lifetime_value", + "rfm_summary", + "rfm_train_test_split", +] def to_xarray(customer_id, *arrays, dim: str = "customer_id"): @@ -257,9 +263,9 @@ def _find_first_transactions( period_transactions.loc[first_transactions, "first"] = True select_columns.append("first") # reset datetime_col to period - period_transactions.loc[:, datetime_col] = period_transactions[ - datetime_col - ].dt.to_period(time_unit) + period_transactions[datetime_col] = period_transactions[datetime_col].dt.to_period( + time_unit + ) return period_transactions[select_columns] @@ -410,3 +416,174 @@ def rfm_summary( ) return summary_df + + +def rfm_train_test_split( + transactions: pd.DataFrame, + customer_id_col: str, + datetime_col: str, + train_period_end: Union[Union[float, str], datetime, datetime64, date], + test_period_end: Optional[ + Union[Union[float, str], datetime, datetime64, date] + ] = None, + time_unit: str = "D", + time_scaler: Optional[float] = 1, + datetime_format: Optional[str] = None, + monetary_value_col: Optional[str] = None, + include_first_transaction: Optional[bool] = False, + sort_transactions: Optional[bool] = True, +) -> pd.DataFrame: + """ + Summarize transaction data and split into training and tests datasets for CLV modeling. + This can also be used to evaluate the impact of a time-based intervention like a marketing campaign. 
+
+    This transforms a DataFrame of transaction data of the form:
+        customer_id, datetime [, monetary_value]
+    to a DataFrame of the form:
+        customer_id, frequency, recency, T [, monetary_value], test_frequency [, test_monetary_value], test_T
+
+    Note this function will exclude new customers whose first transactions occurred during the test period.
+
+    Adapted from lifetimes package
+    https://github.com/CamDavidsonPilon/lifetimes/blob/41e394923ad72b17b5da93e88cfabab43f51abe2/lifetimes/utils.py#L27
+
+    Parameters
+    ----------
+    transactions: :obj: DataFrame
+        A Pandas DataFrame that contains the customer_id col and the datetime col.
+    customer_id_col: string
+        Column in the transactions DataFrame that denotes the customer_id.
+    datetime_col: string
+        Column in the transactions DataFrame that denotes the datetime the purchase was made.
+    train_period_end: Union[str, pd.Period, datetime]
+        A string or datetime to denote the final time period for the training data.
+        Events after this time period are used for the test data.
+    test_period_end: Union[str, pd.Period, datetime], optional
+        A string or datetime to denote the final time period of the study.
+        Events after this date are truncated. If not given, defaults to the max of 'datetime_col'.
+    time_unit: string, optional
+        Time granularity for study.
+        Default: 'D' for days. Possible values listed here:
+        https://numpy.org/devdocs/reference/arrays.datetime.html#datetime-units
+    time_scaler: float, optional
+        Default: 1. Useful for scaling recency & T to a different time granularity. Example:
+        With time_unit='D' and time_scaler=1, we get recency=591 and T=632
+        With time_unit='h' and time_scaler=24, we get recency=590.125 and T=631.375
+        This is useful if predictions in months or years are desired,
+        and can also help with model convergence for study periods of many years.
+    datetime_format: string, optional
+        A string that represents the timestamp format. Useful if Pandas can't understand
+        the provided format.
+    monetary_value_col: string, optional
+        Column in the transactions DataFrame that denotes the monetary value of the transaction.
+        Optional; only needed for spend estimation models like the Gamma-Gamma model.
+    include_first_transaction: bool, optional
+        Default: False
+        For predictive CLV modeling, this should be False.
+        Set to True if performing RFM segmentation.
+    sort_transactions: bool, optional
+        Default: True
+        If raw data is already sorted in chronological order, set to `False` to improve computational efficiency.
+
+    Returns
+    -------
+    :obj: DataFrame:
+        customer_id, frequency, recency, T, test_frequency, test_T [, monetary_value, test_monetary_value]
+    """
+
+    if test_period_end is None:
+        test_period_end = transactions[datetime_col].max()
+
+    transaction_cols = [customer_id_col, datetime_col]
+    if monetary_value_col:
+        transaction_cols.append(monetary_value_col)
+    transactions = transactions[transaction_cols].copy()
+
+    transactions[datetime_col] = pd.to_datetime(
+        transactions[datetime_col], format=datetime_format
+    )
+    test_period_end = pd.to_datetime(test_period_end, format=datetime_format)
+    train_period_end = pd.to_datetime(train_period_end, format=datetime_format)
+
+    # create training dataset
+    training_transactions = transactions.loc[
+        transactions[datetime_col] <= train_period_end
+    ]
+
+    if training_transactions.empty:
+        raise ValueError(
+            "No data available. Check `test_transactions` and `train_period_end` and confirm values in `transactions` occur prior to those time periods."
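+            # same message as the empty test-period check below; the tests
+            # match this exact string for both failure modes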
+ ) + + training_rfm_data = rfm_summary( + training_transactions, + customer_id_col, + datetime_col, + monetary_value_col=monetary_value_col, + datetime_format=datetime_format, + observation_period_end=train_period_end, + time_unit=time_unit, + time_scaler=time_scaler, + include_first_transaction=include_first_transaction, + sort_transactions=sort_transactions, + ) + + # create test dataset + test_transactions = transactions.loc[ + (test_period_end >= transactions[datetime_col]) + & (transactions[datetime_col] > train_period_end) + ].copy() + + if test_transactions.empty: + raise ValueError( + "No data available. Check `test_transactions` and `train_period_end` and confirm values in `transactions` occur prior to those time periods." + ) + + test_transactions[datetime_col] = test_transactions[datetime_col].dt.to_period( + time_unit + ) + # create dataframe with customer_id and test_frequency columns + test_rfm_data = ( + test_transactions.groupby([customer_id_col, datetime_col], sort=False)[ + datetime_col + ] + .agg(lambda r: 1) + .groupby(level=customer_id_col) + .count() + ).reset_index() + + test_rfm_data = test_rfm_data.rename( + columns={"id": "customer_id", "date": "test_frequency"} + ) + + if monetary_value_col: + test_monetary_value = ( + test_transactions.groupby([customer_id_col, datetime_col])[ + monetary_value_col + ] + .sum() + .groupby(customer_id_col) + .mean() + ) + + test_rfm_data = test_rfm_data.merge( + test_monetary_value, + left_on="customer_id", + right_on=customer_id_col, + how="inner", + ) + test_rfm_data = test_rfm_data.rename( + columns={monetary_value_col: "test_monetary_value"} + ) + + train_test_rfm_data = training_rfm_data.merge( + test_rfm_data, on="customer_id", how="left" + ) + train_test_rfm_data.fillna(0, inplace=True) + + time_delta = ( + test_period_end.to_period(time_unit) - train_period_end.to_period(time_unit) + ).n + train_test_rfm_data["test_T"] = time_delta / time_scaler # type: ignore + + return train_test_rfm_data diff --git a/tests/clv/test_utils.py b/tests/clv/test_utils.py index ca008de8..91a631f3 100644 --- a/tests/clv/test_utils.py +++ b/tests/clv/test_utils.py @@ -13,6 +13,7 @@ clv_summary, customer_lifetime_value, rfm_summary, + rfm_train_test_split, to_xarray, ) from tests.clv.utils import set_model_fit @@ -137,31 +138,11 @@ def fitted_gg(test_summary_data) -> GammaGammaModel: return model -@pytest.fixture(scope="module") -def transaction_data() -> pd.DataFrame: - d = [ - [1, "2015-01-01", 1], - [1, "2015-02-06", 2], - [2, "2015-01-01", 2], - [3, "2015-01-01", 3], - [3, "2015-01-02", 1], - [3, "2015-01-05", 5], - [4, "2015-01-16", 6], - [4, "2015-02-02", 3], - [4, "2015-02-05", 3], - [5, "2015-01-16", 3], - [5, "2015-01-17", 1], - [5, "2015-01-18", 8], - [6, "2015-02-02", 5], - ] - return pd.DataFrame(d, columns=["id", "date", "monetary_value"]) - - class TestCustomerLifetimeValue: def test_customer_lifetime_value_bg_with_known_values( self, test_summary_data, fitted_bg ): - # Test borrowed from + # Test adapted from # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L527 t = test_summary_data.head() @@ -345,387 +326,522 @@ def test_clv_after_thinning( ) -def test_find_first_transactions_observation_period_end_none(transaction_data): - max_date = transaction_data["date"].max() - pd.testing.assert_frame_equal( - left=_find_first_transactions( - transactions=transaction_data, - customer_id_col="id", - datetime_col="date", - observation_period_end=None, - ), - 
right=_find_first_transactions( - transactions=transaction_data, - customer_id_col="id", - datetime_col="date", - observation_period_end=max_date, - ), - ) - - -@pytest.mark.parametrize( - argnames="today", - argvalues=["2015-02-07", pd.Period("2015-02-07"), datetime(2015, 2, 7), None], - ids=["string", "period", "datetime", "none"], -) -def test_find_first_transactions_returns_correct_results(transaction_data, today): - # Test borrowed from - # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L137 +class TestRFM: + @pytest.fixture(scope="class") + def transaction_data(self) -> pd.DataFrame: + d = [ + [1, "2015-01-01", 1], + [1, "2015-02-06", 2], + [2, "2015-01-01", 2], + [3, "2015-01-01", 3], + [3, "2015-01-02", 1], + [3, "2015-01-05", 5], + [4, "2015-01-16", 6], + [4, "2015-02-02", 3], + [4, "2015-02-05", 3], + [4, "2015-02-05", 6], + [5, "2015-01-16", 3], + [5, "2015-01-17", 1], + [5, "2015-01-18", 8], + [6, "2015-02-02", 5], + ] + return pd.DataFrame(d, columns=["id", "date", "monetary_value"]) + + def test_find_first_transactions_test_period_end_none(self, transaction_data): + max_date = transaction_data["date"].max() + pd.testing.assert_frame_equal( + left=_find_first_transactions( + transactions=transaction_data, + customer_id_col="id", + datetime_col="date", + observation_period_end=None, + ), + right=_find_first_transactions( + transactions=transaction_data, + customer_id_col="id", + datetime_col="date", + observation_period_end=max_date, + ), + ) - actual = _find_first_transactions( - transaction_data, - "id", - "date", - observation_period_end=today, + @pytest.mark.parametrize( + argnames="today", + argvalues=["2015-02-07", pd.Period("2015-02-07"), datetime(2015, 2, 7), None], + ids=["string", "period", "datetime", "none"], ) - expected = pd.DataFrame( - [ - [1, pd.Period("2015-01-01", "D"), True], - [1, pd.Period("2015-02-06", "D"), False], - [2, pd.Period("2015-01-01", "D"), True], - [3, pd.Period("2015-01-01", "D"), True], - [3, pd.Period("2015-01-02", "D"), False], - [3, pd.Period("2015-01-05", "D"), False], - [4, pd.Period("2015-01-16", "D"), True], - [4, pd.Period("2015-02-02", "D"), False], - [4, pd.Period("2015-02-05", "D"), False], - [5, pd.Period("2015-01-16", "D"), True], - [5, pd.Period("2015-01-17", "D"), False], - [5, pd.Period("2015-01-18", "D"), False], - [6, pd.Period("2015-02-02", "D"), True], - ], - columns=["id", "date", "first"], + def test_find_first_transactions_returns_correct_results( + self, transaction_data, today + ): + # Test adapted from + # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L137 + + actual = _find_first_transactions( + transaction_data, + "id", + "date", + observation_period_end=today, + ) + expected = pd.DataFrame( + [ + [1, pd.Period("2015-01-01", "D"), True], + [1, pd.Period("2015-02-06", "D"), False], + [2, pd.Period("2015-01-01", "D"), True], + [3, pd.Period("2015-01-01", "D"), True], + [3, pd.Period("2015-01-02", "D"), False], + [3, pd.Period("2015-01-05", "D"), False], + [4, pd.Period("2015-01-16", "D"), True], + [4, pd.Period("2015-02-02", "D"), False], + [4, pd.Period("2015-02-05", "D"), False], + [5, pd.Period("2015-01-16", "D"), True], + [5, pd.Period("2015-01-17", "D"), False], + [5, pd.Period("2015-01-18", "D"), False], + [6, pd.Period("2015-02-02", "D"), True], + ], + columns=["id", "date", "first"], + index=[0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13], + ) # row indices are skipped for time 
periods with multiple transactions + assert_frame_equal(actual, expected) + + @pytest.mark.parametrize( + argnames="today", + argvalues=["2015-02-07", pd.Period("2015-02-07"), datetime(2015, 2, 7), None], + ids=["string", "period", "datetime", "none"], ) - assert_frame_equal(actual, expected) - - -@pytest.mark.parametrize( - argnames="today", - argvalues=["2015-02-07", pd.Period("2015-02-07"), datetime(2015, 2, 7), None], - ids=["string", "period", "datetime", "none"], -) -def test_find_first_transactions_with_specific_non_daily_frequency( - transaction_data, today -): - # Test borrowed from - # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L161 - - actual = _find_first_transactions( - transaction_data, - "id", - "date", - observation_period_end=today, - time_unit="W", + def test_find_first_transactions_with_specific_non_daily_frequency( + self, transaction_data, today + ): + # Test adapted from + # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L161 + + actual = _find_first_transactions( + transaction_data, + "id", + "date", + observation_period_end=today, + time_unit="W", + ) + expected = pd.DataFrame( + [ + [1, pd.Period("2014-12-29/2015-01-04", "W-SUN"), True], + [1, pd.Period("2015-02-02/2015-02-08", "W-SUN"), False], + [2, pd.Period("2014-12-29/2015-01-04", "W-SUN"), True], + [3, pd.Period("2014-12-29/2015-01-04", "W-SUN"), True], + [3, pd.Period("2015-01-05/2015-01-11", "W-SUN"), False], + [4, pd.Period("2015-01-12/2015-01-18", "W-SUN"), True], + [4, pd.Period("2015-02-02/2015-02-08", "W-SUN"), False], + [5, pd.Period("2015-01-12/2015-01-18", "W-SUN"), True], + [6, pd.Period("2015-02-02/2015-02-08", "W-SUN"), True], + ], + columns=["id", "date", "first"], + index=actual.index, + ) # we shouldn't really care about row ordering or indexing, but assert_frame_equals is strict about it + assert_frame_equal(actual, expected) + + @pytest.mark.parametrize( + argnames="today", + argvalues=["2015-02-07", pd.Period("2015-02-07"), datetime(2015, 2, 7), None], + ids=["string", "period", "datetime", "none"], ) - expected = pd.DataFrame( - [ - [1, pd.Period("2014-12-29/2015-01-04", "W-SUN"), True], - [1, pd.Period("2015-02-02/2015-02-08", "W-SUN"), False], - [2, pd.Period("2014-12-29/2015-01-04", "W-SUN"), True], - [3, pd.Period("2014-12-29/2015-01-04", "W-SUN"), True], - [3, pd.Period("2015-01-05/2015-01-11", "W-SUN"), False], - [4, pd.Period("2015-01-12/2015-01-18", "W-SUN"), True], - [4, pd.Period("2015-02-02/2015-02-08", "W-SUN"), False], - [5, pd.Period("2015-01-12/2015-01-18", "W-SUN"), True], - [6, pd.Period("2015-02-02/2015-02-08", "W-SUN"), True], - ], - columns=["id", "date", "first"], - index=actual.index, - ) # we shouldn't really care about row ordering or indexing, but assert_frame_equals is strict about it - assert_frame_equal(actual, expected) - - -@pytest.mark.parametrize( - argnames="today", - argvalues=["2015-02-07", pd.Period("2015-02-07"), datetime(2015, 2, 7), None], - ids=["string", "period", "datetime", "none"], -) -def test_find_first_transactions_with_monetary_values(transaction_data, today): - # Test borrowed from - # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L184 + def test_find_first_transactions_with_monetary_values( + self, transaction_data, today + ): + # Test adapted from + # 
https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L184 + + actual = _find_first_transactions( + transaction_data, + "id", + "date", + "monetary_value", + observation_period_end=today, + ) + expected = pd.DataFrame( + [ + [1, pd.Period("2015-01-01", "D"), 1, True], + [1, pd.Period("2015-02-06", "D"), 2, False], + [2, pd.Period("2015-01-01", "D"), 2, True], + [3, pd.Period("2015-01-01", "D"), 3, True], + [3, pd.Period("2015-01-02", "D"), 1, False], + [3, pd.Period("2015-01-05", "D"), 5, False], + [4, pd.Period("2015-01-16", "D"), 6, True], + [4, pd.Period("2015-02-02", "D"), 3, False], + [4, pd.Period("2015-02-05", "D"), 9, False], + [5, pd.Period("2015-01-16", "D"), 3, True], + [5, pd.Period("2015-01-17", "D"), 1, False], + [5, pd.Period("2015-01-18", "D"), 8, False], + [6, pd.Period("2015-02-02", "D"), 5, True], + ], + columns=["id", "date", "monetary_value", "first"], + ) + assert_frame_equal(actual, expected) - actual = _find_first_transactions( - transaction_data, - "id", - "date", - "monetary_value", - observation_period_end=today, - ) - expected = pd.DataFrame( - [ - [1, pd.Period("2015-01-01", "D"), 1, True], - [1, pd.Period("2015-02-06", "D"), 2, False], - [2, pd.Period("2015-01-01", "D"), 2, True], - [3, pd.Period("2015-01-01", "D"), 3, True], - [3, pd.Period("2015-01-02", "D"), 1, False], - [3, pd.Period("2015-01-05", "D"), 5, False], - [4, pd.Period("2015-01-16", "D"), 6, True], - [4, pd.Period("2015-02-02", "D"), 3, False], - [4, pd.Period("2015-02-05", "D"), 3, False], - [5, pd.Period("2015-01-16", "D"), 3, True], - [5, pd.Period("2015-01-17", "D"), 1, False], - [5, pd.Period("2015-01-18", "D"), 8, False], - [6, pd.Period("2015-02-02", "D"), 5, True], - ], - columns=["id", "date", "monetary_value", "first"], + @pytest.mark.parametrize( + argnames="today", + argvalues=["2015-02-07", pd.Period("2015-02-07"), datetime(2015, 2, 7), None], + ids=["string", "period", "datetime", "none"], ) - assert_frame_equal(actual, expected) - + def test_find_first_transactions_with_monetary_values_with_specific_non_daily_frequency( + self, transaction_data, today + ): + # Test adapted from + # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L210 + + actual = _find_first_transactions( + transaction_data, + "id", + "date", + "monetary_value", + observation_period_end=today, + time_unit="W", + ) + expected = pd.DataFrame( + [ + [1, pd.Period("2014-12-29/2015-01-04", "W-SUN"), 1, True], + [1, pd.Period("2015-02-02/2015-02-08", "W-SUN"), 2, False], + [2, pd.Period("2014-12-29/2015-01-04", "W-SUN"), 2, True], + [3, pd.Period("2014-12-29/2015-01-04", "W-SUN"), 4, True], + [3, pd.Period("2015-01-05/2015-01-11", "W-SUN"), 5, False], + [4, pd.Period("2015-01-12/2015-01-18", "W-SUN"), 6, True], + [4, pd.Period("2015-02-02/2015-02-08", "W-SUN"), 12, False], + [5, pd.Period("2015-01-12/2015-01-18", "W-SUN"), 12, True], + [6, pd.Period("2015-02-02/2015-02-08", "W-SUN"), 5, True], + ], + columns=["id", "date", "monetary_value", "first"], + ) + assert_frame_equal(actual, expected) -@pytest.mark.parametrize( - argnames="today", - argvalues=["2015-02-07", pd.Period("2015-02-07"), datetime(2015, 2, 7), None], - ids=["string", "period", "datetime", "none"], -) -def test_find_first_transactions_with_monetary_values_with_specific_non_daily_frequency( - transaction_data, today -): - # Test borrowed from - # 
https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L210 + @pytest.mark.parametrize( + argnames="today", + argvalues=["2015-02-07", pd.Period("2015-02-07"), datetime(2015, 2, 7)], + ids=["string", "period", "datetime"], + ) + def test_rfm_summary_returns_correct_results(self, transaction_data, today): + # Test adapted from + # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L239 - actual = _find_first_transactions( + actual = rfm_summary( + transaction_data, "id", "date", observation_period_end=today + ) + expected = pd.DataFrame( + [ + [1, 1.0, 36.0, 37.0], + [2, 0.0, 0.0, 37.0], + [3, 2.0, 4.0, 37.0], + [4, 2.0, 20.0, 22.0], + [5, 2.0, 2.0, 22.0], + [6, 0.0, 0.0, 5.0], + ], + columns=["customer_id", "frequency", "recency", "T"], + ) + assert_frame_equal(actual, expected) + + def test_rfm_summary_works_with_string_customer_ids(self): + # Test adapted from + # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L250 + + d = [ + ["X", "2015-02-01"], + ["X", "2015-02-06"], + ["Y", "2015-01-01"], + ["Y", "2015-01-01"], + ["Y", "2015-01-02"], + ["Y", "2015-01-05"], + ] + df = pd.DataFrame(d, columns=["id", "date"]) + rfm_summary(df, "id", "date") + + def test_rfm_summary_works_with_int_customer_ids_and_doesnt_coerce_to_float(self): + # Test adapted from + # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L263 + + d = [ + [1, "2015-02-01"], + [1, "2015-02-06"], + [1, "2015-01-01"], + [2, "2015-01-01"], + [2, "2015-01-02"], + [2, "2015-01-05"], + ] + df = pd.DataFrame(d, columns=["id", "date"]) + actual = rfm_summary(df, "id", "date") + assert actual.index.dtype == "int64" + + def test_rfm_summary_with_specific_datetime_format( + self, transaction_data, - "id", - "date", - "monetary_value", - observation_period_end=today, - time_unit="W", - ) - expected = pd.DataFrame( - [ - [1, pd.Period("2014-12-29/2015-01-04", "W-SUN"), 1, True], - [1, pd.Period("2015-02-02/2015-02-08", "W-SUN"), 2, False], - [2, pd.Period("2014-12-29/2015-01-04", "W-SUN"), 2, True], - [3, pd.Period("2014-12-29/2015-01-04", "W-SUN"), 4, True], - [3, pd.Period("2015-01-05/2015-01-11", "W-SUN"), 5, False], - [4, pd.Period("2015-01-12/2015-01-18", "W-SUN"), 6, True], - [4, pd.Period("2015-02-02/2015-02-08", "W-SUN"), 6, False], - [5, pd.Period("2015-01-12/2015-01-18", "W-SUN"), 12, True], - [6, pd.Period("2015-02-02/2015-02-08", "W-SUN"), 5, True], - ], - columns=["id", "date", "monetary_value", "first"], - ) - assert_frame_equal(actual, expected) + ): + # Test adapted from + # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L279 + transaction_data["date"] = transaction_data["date"].map( + lambda x: x.replace("-", "") + ) + format = "%Y%m%d" + today = "20150207" + actual = rfm_summary( + transaction_data, + "id", + "date", + observation_period_end=today, + datetime_format=format, + sort_transactions=False, + ) + expected = pd.DataFrame( + [ + [1, 1.0, 36.0, 37.0], + [2, 0.0, 0.0, 37.0], + [3, 2.0, 4.0, 37.0], + [4, 2.0, 20.0, 22.0], + [5, 2.0, 2.0, 22.0], + [6, 0.0, 0.0, 5.0], + ], + columns=["customer_id", "frequency", "recency", "T"], + ) + assert_frame_equal(actual, expected) -@pytest.mark.parametrize( - argnames="today", - argvalues=["2015-02-07", pd.Period("2015-02-07"), datetime(2015, 2, 7)], - ids=["string", 
"period", "datetime"], -) -def test_rfm_summary_returns_correct_results(transaction_data, today): - # Test borrowed from - # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L239 - - actual = rfm_summary(transaction_data, "id", "date", observation_period_end=today) - expected = pd.DataFrame( - [ - [1, 1.0, 36.0, 37.0], - [2, 0.0, 0.0, 37.0], - [3, 2.0, 4.0, 37.0], - [4, 2.0, 20.0, 22.0], - [5, 2.0, 2.0, 22.0], - [6, 0.0, 0.0, 5.0], - ], - columns=["customer_id", "frequency", "recency", "T"], - ) - assert_frame_equal(actual, expected) - - -def test_rfm_summary_works_with_string_customer_ids(): - # Test borrowed from - # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L250 - - d = [ - ["X", "2015-02-01"], - ["X", "2015-02-06"], - ["Y", "2015-01-01"], - ["Y", "2015-01-01"], - ["Y", "2015-01-02"], - ["Y", "2015-01-05"], - ] - df = pd.DataFrame(d, columns=["id", "date"]) - rfm_summary(df, "id", "date") - - -def test_rfm_summary_works_with_int_customer_ids_and_doesnt_coerce_to_float(): - # Test borrowed from - # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L263 - - d = [ - [1, "2015-02-01"], - [1, "2015-02-06"], - [1, "2015-01-01"], - [2, "2015-01-01"], - [2, "2015-01-02"], - [2, "2015-01-05"], - ] - df = pd.DataFrame(d, columns=["id", "date"]) - actual = rfm_summary(df, "id", "date") - assert actual.index.dtype == "int64" - - -def test_rfm_summary_with_specific_datetime_format( - transaction_data, -): - # Test borrowed from - # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L279 - - transaction_data["date"] = transaction_data["date"].map( - lambda x: x.replace("-", "") - ) - format = "%Y%m%d" - today = "20150207" - actual = rfm_summary( + def test_rfm_summary_non_daily_frequency( + self, transaction_data, - "id", - "date", - observation_period_end=today, - datetime_format=format, - sort_transactions=False, - ) - expected = pd.DataFrame( - [ - [1, 1.0, 36.0, 37.0], - [2, 0.0, 0.0, 37.0], - [3, 2.0, 4.0, 37.0], - [4, 2.0, 20.0, 22.0], - [5, 2.0, 2.0, 22.0], - [6, 0.0, 0.0, 5.0], - ], - columns=["customer_id", "frequency", "recency", "T"], - ) - assert_frame_equal(actual, expected) + ): + # Test adapted from + # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L292 + + today = "20150207" + actual = rfm_summary( + transaction_data, + "id", + "date", + observation_period_end=today, + time_unit="W", + ) + expected = pd.DataFrame( + [ + [1, 1.0, 5.0, 5.0], + [2, 0.0, 0.0, 5.0], + [3, 1.0, 1.0, 5.0], + [4, 1.0, 3.0, 3.0], + [5, 0.0, 0.0, 3.0], + [6, 0.0, 0.0, 0.0], + ], + columns=["customer_id", "frequency", "recency", "T"], + ) + assert_frame_equal(actual, expected) + def test_rfm_summary_monetary_values_and_first_transactions( + self, + transaction_data, + ): + # Test adapted from + # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L311 + + today = "20150207" + actual = rfm_summary( + transaction_data, + "id", + "date", + monetary_value_col="monetary_value", + observation_period_end=today, + ) + expected = pd.DataFrame( + [ + [1, 1.0, 36.0, 37.0, 2], + [2, 0.0, 0.0, 37.0, 0], + [3, 2.0, 4.0, 37.0, 3], + [4, 2.0, 20.0, 22.0, 6], + [5, 2.0, 2.0, 22.0, 4.5], + [6, 0.0, 0.0, 5.0, 0], + ], + columns=["customer_id", 
"frequency", "recency", "T", "monetary_value"], + ) + assert_frame_equal(actual, expected) + + actual_first_trans = rfm_summary( + transaction_data, + "id", + "date", + monetary_value_col="monetary_value", + observation_period_end=today, + include_first_transaction=True, + ) + expected_first_trans = pd.DataFrame( + [ + [1, 2.0, 36.0, 37.0, 1.5], + [2, 1.0, 0.0, 37.0, 2], + [3, 3.0, 4.0, 37.0, 3], + [4, 3.0, 20.0, 22.0, 6], + [5, 3.0, 2.0, 22.0, 4], + [6, 1.0, 0.0, 5.0, 5], + ], + columns=["customer_id", "frequency", "recency", "T", "monetary_value"], + ) + assert_frame_equal(actual_first_trans, expected_first_trans) -def test_rfm_summary_non_daily_frequency( - transaction_data, -): - # Test borrowed from - # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L292 + def test_rfm_summary_will_choose_the_correct_first_order_to_drop_in_monetary_transactions( + self, + ): + # Test adapted from + # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L334 - today = "20150207" - actual = rfm_summary( - transaction_data, - "id", - "date", - observation_period_end=today, - time_unit="W", - ) - expected = pd.DataFrame( - [ - [1, 1.0, 5.0, 5.0], - [2, 0.0, 0.0, 5.0], - [3, 1.0, 1.0, 5.0], - [4, 1.0, 3.0, 3.0], - [5, 0.0, 0.0, 3.0], - [6, 0.0, 0.0, 0.0], - ], - columns=["customer_id", "frequency", "recency", "T"], - ) - assert_frame_equal(actual, expected) + cust = pd.Series([2, 2, 2]) + dates_ordered = pd.to_datetime( + pd.Series( + ["2014-03-14 00:00:00", "2014-04-09 00:00:00", "2014-05-21 00:00:00"] + ) + ) + sales = pd.Series([10, 20, 25]) + transaction_data = pd.DataFrame( + {"date": dates_ordered, "id": cust, "sales": sales} + ) + summary_ordered_data = rfm_summary(transaction_data, "id", "date", "sales") + dates_unordered = pd.to_datetime( + pd.Series( + ["2014-04-09 00:00:00", "2014-03-14 00:00:00", "2014-05-21 00:00:00"] + ) + ) + sales = pd.Series([20, 10, 25]) + transaction_data = pd.DataFrame( + {"date": dates_unordered, "id": cust, "sales": sales} + ) + summary_unordered_data = rfm_summary(transaction_data, "id", "date", "sales") -def test_rfm_summary_monetary_values_and_first_transactions( - transaction_data, -): - # Test borrowed from - # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L311 + assert_frame_equal(summary_ordered_data, summary_unordered_data) + assert summary_ordered_data["monetary_value"].loc[0] == 22.5 - today = "20150207" - actual = rfm_summary( - transaction_data, - "id", - "date", - monetary_value_col="monetary_value", - observation_period_end=today, - ) - expected = pd.DataFrame( - [ - [1, 1.0, 36.0, 37.0, 2], - [2, 0.0, 0.0, 37.0, 0], - [3, 2.0, 4.0, 37.0, 3], - [4, 2.0, 20.0, 22.0, 3], - [5, 2.0, 2.0, 22.0, 4.5], - [6, 0.0, 0.0, 5.0, 0], - ], - columns=["customer_id", "frequency", "recency", "T", "monetary_value"], - ) - assert_frame_equal(actual, expected) + def test_rfm_summary_statistics_identical_to_hardie_paper( + self, + cdnow_trans, + ): + # Test adapted from + # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L353 + + # see http://brucehardie.com/papers/rfm_clv_2005-02-16.pdf + # RFM and CLV: Using Iso-value Curves for Customer Base Analysis + summary = rfm_summary( + cdnow_trans, + "id", + "date", + "spent", + observation_period_end="19971001", + datetime_format="%Y%m%d", + ) + results = 
summary[summary["frequency"] > 0]["monetary_value"].describe() - actual_first_trans = rfm_summary( - transaction_data, - "id", - "date", - monetary_value_col="monetary_value", - observation_period_end=today, - include_first_transaction=True, - ) - expected_first_trans = pd.DataFrame( - [ - [1, 2.0, 36.0, 37.0, 1.5], - [2, 1.0, 0.0, 37.0, 2], - [3, 3.0, 4.0, 37.0, 3], - [4, 3.0, 20.0, 22.0, 4], - [5, 3.0, 2.0, 22.0, 4], - [6, 1.0, 0.0, 5.0, 5], - ], - columns=["customer_id", "frequency", "recency", "T", "monetary_value"], - ) - assert_frame_equal(actual_first_trans, expected_first_trans) + assert np.round(results.loc["mean"]) == 35 + assert np.round(results.loc["std"]) == 30 + assert np.round(results.loc["min"]) == 3 + assert np.round(results.loc["50%"]) == 27 + assert np.round(results.loc["max"]) == 300 + assert np.round(results.loc["count"]) == 946 + def test_rfm_summary_squashes_period_purchases_to_one_purchase(self): + # Test adapted from + # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L472 -def test_rfm_summary_will_choose_the_correct_first_order_to_drop_in_monetary_transactions(): - # Test borrowed from - # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L334 + transactions = pd.DataFrame( + [[1, "2015-01-01"], [1, "2015-01-01"]], columns=["id", "t"] + ) + actual = rfm_summary(transactions, "id", "t", time_unit="W") + assert actual.loc[0]["frequency"] == 1.0 - 1.0 - cust = pd.Series([2, 2, 2]) - dates_ordered = pd.to_datetime( - pd.Series(["2014-03-14 00:00:00", "2014-04-09 00:00:00", "2014-05-21 00:00:00"]) - ) - sales = pd.Series([10, 20, 25]) - transaction_data = pd.DataFrame({"date": dates_ordered, "id": cust, "sales": sales}) - summary_ordered_data = rfm_summary(transaction_data, "id", "date", "sales") + def test_clv_summary_warning(self, transaction_data): + with pytest.warns(UserWarning, match="clv_summary was renamed to rfm_summary"): + clv_summary(transaction_data, "id", "date") - dates_unordered = pd.to_datetime( - pd.Series(["2014-04-09 00:00:00", "2014-03-14 00:00:00", "2014-05-21 00:00:00"]) - ) - sales = pd.Series([20, 10, 25]) - transaction_data = pd.DataFrame( - {"date": dates_unordered, "id": cust, "sales": sales} - ) - summary_unordered_data = rfm_summary(transaction_data, "id", "date", "sales") + def test_rfm_train_test_split(self, transaction_data): + # Test adapted from + # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L374 - assert_frame_equal(summary_ordered_data, summary_unordered_data) - assert summary_ordered_data["monetary_value"].loc[0] == 22.5 + train_end = "2015-02-01" + actual = rfm_train_test_split(transaction_data, "id", "date", train_end) + assert actual.loc[0]["test_frequency"] == 1 + assert actual.loc[1]["test_frequency"] == 0 + with pytest.raises(KeyError): + actual.loc[6] -def test_rfm_summary_statistics_identical_to_hardie_paper( - cdnow_trans, -): - # Test borrowed from - # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L353 + @pytest.mark.parametrize("train_end", ("2014-02-07", "2015-02-08")) + def test_rfm_train_test_split_throws_better_error_if_test_period_end_is_too_early( + self, + train_end, + transaction_data, + ): + # Test adapted from + # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L387 - # see 
http://brucehardie.com/papers/rfm_clv_2005-02-16.pdf - # RFM and CLV: Using Iso-value Curves for Customer Base Analysis - summary = rfm_summary( - cdnow_trans, - "id", - "date", - "spent", - observation_period_end="19971001", - datetime_format="%Y%m%d", - ) - results = summary[summary["frequency"] > 0]["monetary_value"].describe() + test_end = "2014-02-07" - assert np.round(results.loc["mean"]) == 35 - assert np.round(results.loc["std"]) == 30 - assert np.round(results.loc["min"]) == 3 - assert np.round(results.loc["50%"]) == 27 - assert np.round(results.loc["max"]) == 300 - assert np.round(results.loc["count"]) == 946 + with pytest.raises( + ValueError, + match="No data available. Check `test_transactions` and `train_period_end` and confirm values in `transactions` occur prior to those time periods.", + ): + rfm_train_test_split( + transaction_data, "id", "date", train_end, test_period_end=test_end + ) + def test_rfm_train_test_split_works_with_specific_frequency(self, transaction_data): + # Test adapted from + # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L412 + + test_end = "2015-02-07" + train_end = "2015-02-01" + actual = rfm_train_test_split( + transaction_data, + "id", + "date", + train_end, + test_period_end=test_end, + time_unit="W", + ) + expected_cols = [ + "customer_id", + "frequency", + "recency", + "T", + "test_frequency", + "test_T", + ] + expected = pd.DataFrame( + [ + [1, 0.0, 0.0, 4.0, 1, 1], + [2, 0.0, 0.0, 4.0, 0, 1], + [3, 1.0, 1.0, 4.0, 0, 1], + [4, 0.0, 0.0, 2.0, 1, 1], + [5, 0.0, 0.0, 2.0, 0, 1], + ], + columns=expected_cols, + ) + assert_frame_equal(actual, expected, check_dtype=False) + + def test_rfm_train_test_split_gives_correct_date_boundaries(self, transaction_data): + # Test adapted from + # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L432 + + actual = rfm_train_test_split( + transaction_data, + "id", + "date", + train_period_end="2015-02-01", + test_period_end="2015-02-04", + ) + assert actual["test_frequency"].loc[1] == 0 + assert actual["test_frequency"].loc[3] == 1 + + def test_rfm_train_test_split_with_monetary_value(self, transaction_data): + # Test adapted from + # https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/tests/test_utils.py#L457 + + test_end = "2015-02-07" + train_end = "2015-02-01" + actual = rfm_train_test_split( + transaction_data, + "id", + "date", + train_end, + test_period_end=test_end, + monetary_value_col="monetary_value", + ) + assert (actual["monetary_value"] == [0, 0, 3, 0, 4.5]).all() + assert (actual["test_monetary_value"] == [2, 0, 0, 6, 0]).all() -def test_clv_summary_warning(transaction_data): - with pytest.warns(UserWarning, match="clv_summary was renamed to rfm_summary"): - clv_summary(transaction_data, "id", "date") + # check test_monetary_value is being aggregated correctly for time periods with multiple purchases
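
As a quick reference for reviewers, the sketch below strings the two new entry points together on the notebook's toy transactions. It is a minimal smoke-test sketch, assuming this branch of pymc-marketing is installed; every name in it comes from the diff above.

```python
import pandas as pd

from pymc_marketing.clv import utils

# Toy transactions from the notebook: one row per purchase, with two purchases
# on the same date for customer 4 to exercise same-period aggregation.
d = [
    [1, "2015-01-01", 1],
    [1, "2015-02-06", 2],
    [2, "2015-01-01", 2],
    [3, "2015-01-01", 3],
    [3, "2015-01-02", 1],
    [3, "2015-01-05", 5],
    [4, "2015-01-16", 6],
    [4, "2015-02-02", 3],
    [4, "2015-02-05", 3],
    [4, "2015-02-05", 2],
    [5, "2015-01-16", 3],
    [5, "2015-01-17", 1],
    [5, "2015-01-18", 8],
    [6, "2015-02-02", 5],
]
transactions = pd.DataFrame(d, columns=["id", "date", "monetary_value"])

# One row per customer: frequency (repeat purchases), recency, T, and mean
# monetary value across repeat purchases, over the full observation window.
rfm = utils.rfm_summary(
    transactions,
    customer_id_col="id",
    datetime_col="date",
    monetary_value_col="monetary_value",
    observation_period_end="2015-02-06",
    datetime_format="%Y-%m-%d",
)

# Split at 2015-02-01: frequency/recency/T/monetary_value cover the training
# window, and the test_* columns cover the holdout window after the cutoff.
train_test = utils.rfm_train_test_split(
    transactions,
    customer_id_col="id",
    datetime_col="date",
    train_period_end="2015-02-01",
    monetary_value_col="monetary_value",
)
print(train_test.head())
```

Note that customer 6, whose first purchase falls after `train_period_end`, is dropped from the split output, as the notebook output above shows and `test_rfm_train_test_split` asserts.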