From cec2cc2482cbae57341ec18064fb035ca997eb75 Mon Sep 17 00:00:00 2001 From: Colt Allen <10178857+ColtAllen@users.noreply.github.com> Date: Thu, 27 Jun 2024 02:02:23 -0600 Subject: [PATCH] `GammaGammaModel` API Improvements (#758) * utils.customer_lifetime_value * expected_customer_lifetime_value * WIP clv.models.gamma_gamma.py * gamma_gamma API * fixed circular import * gamma_gamma tests * delete tests/datasets/test_summary.csv * clv test_utils.py * remove expected_purchases(future_t=0) * remove monetary_value arg * WIP docstrings * notebooks * docstrings * Revert "notebooks" This reverts commit a3154d9cd84fb795ae45053d3c38e45165746404. * gamma-gamma notebook * docstrings --- docs/source/notebooks/clv/gamma_gamma.ipynb | 962 ++-- pymc_marketing/clv/models/gamma_gamma.py | 468 +- pymc_marketing/clv/utils.py | 186 +- tests/clv/datasets/test_summary_data.csv | 5001 ------------------- tests/clv/models/test_gamma_gamma.py | 67 +- tests/clv/test_utils.py | 174 +- tests/conftest.py | 6 +- 7 files changed, 966 insertions(+), 5898 deletions(-) delete mode 100644 tests/clv/datasets/test_summary_data.csv diff --git a/docs/source/notebooks/clv/gamma_gamma.ipynb b/docs/source/notebooks/clv/gamma_gamma.ipynb index ec02ccc6..c8185ec7 100644 --- a/docs/source/notebooks/clv/gamma_gamma.ipynb +++ b/docs/source/notebooks/clv/gamma_gamma.ipynb @@ -23,10 +23,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "id": "813aa3e6", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], "source": [ "import arviz as az\n", "import matplotlib.pyplot as plt\n", @@ -58,7 +67,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 21, "id": "4039ce96", "metadata": {}, "outputs": [ @@ -87,6 +96,7 @@ "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "idata_map = model.fit(fit_method=\"map\").posterior.to_dataframe()" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 31, "id": "b8f11643", "metadata": {}, "outputs": [ @@ -705,7 +761,7 @@ "0 0 6.248787 3.744591 15.447813" ] }, - "execution_count": 12, + "execution_count": 31, "metadata": {}, "output_type": "execute_result" } @@ -734,10 +790,65 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 32, "id": "ed88b572", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Auto-assigning NUTS sampler...\n", + "Initializing NUTS using jitter+adapt_diag...\n", + "Multiprocess sampling (4 chains in 4 jobs)\n", + "NUTS: [p, q, v]\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a9f2a1d48dd74e0394cb66ec19a33b8f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 9 seconds.\n" + ] + } + ], "source": [ "sampler_kwargs = {\n", " \"draws\": 2_000,\n", @@ -751,7 +862,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 33, "id": "52c3b00e", "metadata": {}, "outputs": [ @@ -766,8 +877,8 @@ "
<xarray.Dataset> Size: 30kB\n", - "Dimensions: (index: 946)\n", + "<xarray.Dataset> Size: 45kB\n", + "Dimensions: (index: 946)\n", "Coordinates:\n", - " * index (index) int64 8kB 0 1 5 6 8 ... 2348 2349 2353 2355\n", + " * index (index) int64 8kB 0 1 5 6 8 10 ... 2347 2348 2349 2353 2355\n", "Data variables:\n", - " customer_id (index) int64 8kB 0 1 5 6 8 ... 2348 2349 2353 2355\n", - " mean_transaction_value (index) float64 8kB 22.35 11.77 ... 44.93 33.32\n", - " frequency (index) int64 8kB 2 1 7 1 2 5 10 1 ... 1 2 7 1 2 5 4