From 9784c5f8ba79c37bf2ddefb4bd48e22a3c87dc49 Mon Sep 17 00:00:00 2001 From: Louis Magowan <59659198+louismagowan@users.noreply.github.com> Date: Tue, 21 May 2024 11:34:11 +0100 Subject: [PATCH] v0 Streamlit MMM Explainer App (#614) * feat(streamlit_explainer): Pushing files for Streamlit explainer app, to illustrate saturation, adstock and prior concepts in an intuitive, visual way to stakeholders and new MMMers * chore(readme): Adding a readme for the app * fix(env): Updating dependencies to include those needed for the Streamlit app * Drop python 3.9 support (#615) * drop python 3.9 * try python 3.12 * undo try python 3.12 * add lift tests check * Add more content to the Gamma-Gamma Notebook (#573) * improve nb * rm warnings and add link to lifetimes quickstart * address comments * feedback part 3 * remove warnings manually * Add more content to the BG/NBD Notebook (#571) * add more info to the notebook * hide plots code * fix plot y labels * fix plot outputs and remove model build * improve final note probability plots * address comments * use quickstart dataset * feedback part 3 * remowe warnings manually * feedback part 4 * Improve MMM Docs (#612) * improve mmm docs init * add more code examples to docstrings * minor improvemeents * typo * better phrasing * add thomas suggestion * Fix `clv` plotting bugs and edits to Quickstart (#601) * move fixtures to conftest * docstrings and moved set_model_fit to conftest * fixed pandas quickstart warnings * revert to MockModel and add ParetoNBD support * quickstart edit for issue 609 * notebook edit * [pre-commit.ci] pre-commit autoupdate (#616) * improve coords matching (#623) * python 3.12 attempt (#618) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(saturation): Using pymc-marketing saturation functions rather than coding my own: Removing tanh, logistic and michaelis menten * refactor(saturation): Remove Hill and Root saturations, as they aren't supported by pymc-marketing currently * refactor(geometric_adstock): Removing custom adstock and using pymc-marketing adstock function to demo decay. Also updating latex to align with pymc-marketing, where decay factor is represented by alpha rather than beta * refactor(delayed_adstock): Using pymc-marketing delayed geometric function rather than custom one * fix(requirements): Adding pymc-marketing to Streamlit requirements for deployment * Added Dev Container Folder * refactor(weibull_cdf): Using pymc-marketing function for Weibull CDF * fix(weibull_cdf): Fixing incorrect dataframe var name for CDF plotting df * refactor(weibull_pdf): Using pymc-marketing function for WeibullPDF * refactor(custom_functions): Removing adstock_saturation_functions.py file now that it is no longer required * chore: Removing devcontainer created by Streamlit * fix(requirements): Adding preliz to requirements * refactor(prior_viz): Reworking the prior visualisation to use Preliz instead of custom function, as well as remove the tab-design. Prior distributions can now be specified programmatically. * refactor(prior_functions.py): Deleting the draw_samples function and replacing it with a programmatic PreliZ function, such that the distribution object is returned when the user passes in the name of a distribution * fix(requirements): Delete obsolete pymc requirement, which should fix deployment dependency conflicts * chore(readme): Updating with guidelines on how to add additional distributions or transformation functions to the app * refactor(plot_config): Moving height and width specifications into constants at top of Adstock and Saturation files, so the plot sizes are set programmatically --------- Co-authored-by: Juan Orduz Co-authored-by: Colt Allen <10178857+ColtAllen@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> --- environment.yml | 3 + streamlit/mmm-explainer/README.md | 78 +++ streamlit/mmm-explainer/Visualise_Priors.py | 146 +++++ streamlit/mmm-explainer/config.toml | 3 + streamlit/mmm-explainer/pages/Adstock.py | 658 ++++++++++++++++++++ streamlit/mmm-explainer/pages/Saturation.py | 262 ++++++++ streamlit/mmm-explainer/prior_functions.py | 96 +++ streamlit/mmm-explainer/requirements.txt | 9 + 8 files changed, 1255 insertions(+) create mode 100644 streamlit/mmm-explainer/README.md create mode 100644 streamlit/mmm-explainer/Visualise_Priors.py create mode 100644 streamlit/mmm-explainer/config.toml create mode 100644 streamlit/mmm-explainer/pages/Adstock.py create mode 100644 streamlit/mmm-explainer/pages/Saturation.py create mode 100644 streamlit/mmm-explainer/prior_functions.py create mode 100644 streamlit/mmm-explainer/requirements.txt diff --git a/environment.yml b/environment.yml index ecdbacb4..638087bc 100644 --- a/environment.yml +++ b/environment.yml @@ -8,7 +8,9 @@ dependencies: - arviz>=0.13.0 - matplotlib>=3.5.1 - numpy>=1.17 +- scipy>=1.11 - pandas +- streamlit>=1.25.0 - pip # NOTE: Keep minimum pymc version in sync with ci.yml `OLDEST_PYMC_VERSION` - pymc>=5.12.0 @@ -29,6 +31,7 @@ dependencies: - sphinx-notfound-page - sphinx-design - watermark +- typing # lint - mypy - pandas-stubs diff --git a/streamlit/mmm-explainer/README.md b/streamlit/mmm-explainer/README.md new file mode 100644 index 00000000..26b14920 --- /dev/null +++ b/streamlit/mmm-explainer/README.md @@ -0,0 +1,78 @@ +# MMM Visualization App with Streamlit + +## Overview + +This Streamlit application is designed to provide a dynamic and interactive visualization of key Marketing Mix Modeling (MMM) concepts, including adstock, saturation, and the use of Bayesian priors. It aims to help marketers, data scientists, and anyone interested in understanding MMM more deeply. Through this application, users can explore how different parameters affect adstock, saturation, and Bayesian priors. +You may wish to run the app locally too - rather than relying on the [deployment](https://pymc-marketing-app.streamlit.app/). +In this case, you would just need to install the requirements.txt within the streamlit folder and do `streamlit run Visualise_Priors.py` + +## Features + +- **Adstock Transformation Visualization**: Interactive charts that demonstrate how the adstock effect changes with different decay rates and lengths of advertising impact. Users can input their parameters to see how adstocked values are calculated over time. + +- **Saturation Curve Exploration**: Interactive charts that demonstrate saturation curves, which represents the diminishing returns of marketing spend as it increases. Users can adjust parameters and choose from a variety of saturation transformations. + +- **Bayesian Priors**: Interactive charts that demonstrate Bayesian prior distributions, designed to showcase the power of Bayesian methods in handling uncertainty and incorporating prior knowledge into MMM. + +- **Customizable Parameters**: All sections of the app include options to customize parameters, allowing users to experiment with different scenarios and understand their impacts on MMM. + +## Getting Started + +### Deployment + +-Paste link to deployment here once reviewed by PyMC devs, hosted deployment is easy to implement- + + +## Contributing & Adding New Functions + +We welcome contributions from the community! Whether it's adding new features, improving documentation, or reporting bugs, please feel free to make a pull request or open an issue. +It's a good idea to always develop and test your changes to the app by running it locally, before submitting a PR. + +### Adding New Adstock / Saturation Transformers from pymc-marketing + +New transformation functions may be added to pymc-marketing which you may want to have visualised in the app. +To do so, you would just need to add them in the import statements at the top of either `Saturation.py` or `Adstock.py`. +e.g. +``` +from pymc_marketing.mmm.transformers import ( + logistic_saturation, + michaelis_menten, + tanh_saturation, + my_new_saturation_function +) +``` + +Then you would have to create a new Streamlit tab +``` +# Create tabs for plots +tab1, tab2, tab3, tab4 = st.tabs(["Logistic", "Tanh", "Michaelis-Menten", My New Saturation]) +``` + +And then add whatever plotting code you want for your new function! + +### Adding Additional Distributions from PreLiz + +PreliZ contains many, many distributions - not all of which are currently visualised. +Adding new distributions is quite simple. +You would need to firstly modify the dictionary of distributions and the parameters you want the user to be able to play around with. +``` +# Specify the possible distributions and their paramaters you want to visualise +DISTRIBUTIONS_DICT = { + "Beta": ["alpha", "beta"], + "Bernoulli": ["p"], + "Exponential": ["lam"], + "Gamma": ["alpha", "beta"], + "HalfNormal": ["sigma"], + "LogNormal": ["mu", "sigma"], + "Normal": ["mu", "sigma"], + "Poisson": ["mu"], + "StudentT": ["nu", "mu", "sigma"], + "TruncatedNormal": ["mu", "sigma", "lower", "upper"], + "Uniform": ["lower", "upper"], + "Weibull": ["alpha", "beta"], + "MY_NEW_DIST": ["something", "something_else"], +} +``` + +And then create new Streamlit input buttons for your new parameters (unless they are covered by existing parameters in the `for param in params.keys():` block) by adding another `elif` line. +Watch out - certain distributions may share parameters of the same name, but that have different accepted ranges. For example, the `mu` parameter in Poisson has to be >0, whereas for a Normal it can be whatever you want. You may need an additional `elif` block in these edge cases. diff --git a/streamlit/mmm-explainer/Visualise_Priors.py b/streamlit/mmm-explainer/Visualise_Priors.py new file mode 100644 index 00000000..f3fa7c2d --- /dev/null +++ b/streamlit/mmm-explainer/Visualise_Priors.py @@ -0,0 +1,146 @@ +# Copyright 2024 The PyMC Labs Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Import custom functions +import prior_functions as pf + +import streamlit as st + +# Constants +SEED = 42 +N_DRAWS = 50_000 +# Specify the possible distributions and their paramaters you want to visualise +DISTRIBUTIONS_DICT = { + "Beta": ["alpha", "beta"], + "Bernoulli": ["p"], + "Exponential": ["lam"], + "Gamma": ["alpha", "beta"], + "HalfNormal": ["sigma"], + "LogNormal": ["mu", "sigma"], + "Normal": ["mu", "sigma"], + "Poisson": ["mu"], + "StudentT": ["nu", "mu", "sigma"], + "TruncatedNormal": ["mu", "sigma", "lower", "upper"], + "Uniform": ["lower", "upper"], + "Weibull": ["alpha", "beta"], +} +PLOT_HEIGHT = 500 +PLOT_WIDTH = 1000 + +# -------------------------- TOP OF PAGE INFORMATION ------------------------- + +# Set browser / tab config +st.set_page_config( + page_title="MMM App - Prior Distributions Transformations", + page_icon="πŸ’Ž", +) + +# Give some context for what the page displays +st.title("Bayesian Prior Distribution Demonstrator") + +# -------------------------- VISUALISE PRIOR ------------------------- + +# Select the distribution to visualise +dist_name = st.selectbox( + "Please select the distribution you would like to visualise:", + options=DISTRIBUTIONS_DICT.keys(), +) +st.header(f":blue[{dist_name} Distribution]") # header + +# Variables need to be instantiated to avoid error where upper < lower +lower = None +upper = None + +# Initialize parameters with None +params = {param: None for param in DISTRIBUTIONS_DICT[dist_name]} + +# User inputs for distribution parameters +for param in params.keys(): + if param == "lower": + params[param] = st.number_input( + f"Please enter the value for {param.title()}:", key=param, value=0.0 + ) + elif param == "upper": + params[param] = st.number_input( + f"Please enter the value for {param.title()}:", key=param, value=2.0 + ) + elif param == "alpha": + params[param] = st.number_input( + f"Please enter the value for {param.title()}:", + key=param, + value=1.0, + min_value=0.01, + ) + elif param == "beta": + params[param] = st.number_input( + f"Please enter the value for {param.title()}:", + key=param, + value=1.0, + min_value=0.01, + ) + elif param == "sigma": + params[param] = st.number_input( + f"Please enter the value for {param.title()}:", + key=param, + value=1.0, + min_value=0.01, + ) + # Poisson mu must be > 0 + elif param == "mu" and dist_name == "Poisson": + params[param] = st.number_input( + f"Please enter the value for {param.title()}:", + key=param, + value=1.0, + min_value=0.01, + ) + elif param == "mu": + params[param] = st.number_input( + f"Please enter the value for {param.title()}:", key=param, value=0.0 + ) + elif param == "p": + params[param] = st.number_input( + f"Please enter the value for {param.title()}:", + key=param, + value=0.5, + min_value=0.0, + max_value=1.0, + ) + elif param == "lam": + params[param] = st.number_input( + f"Please enter the value for {param.title()}:", + key=param, + value=1.0, + min_value=0.01, + ) + elif param == "nu": + params[param] = st.number_input( + f"Please enter the value for {param.title()}:", + key=param, + value=10.0, + min_value=0.01, + ) + + +# Check to ensure lower < upper +if lower and lower >= upper: + st.error("Error: Lower bound must be less than upper bound.") + +## Create the selected distribution and sample from it +dist = pf.get_distribution(dist_name, **params) +draws = dist.rvs(N_DRAWS, random_state=SEED) + + +# Plot distribution +fig_root = pf.plot_prior_distribution(draws, title=f"{dist_name} Distribution Samples") +fig_root.update_layout(height=PLOT_HEIGHT, width=PLOT_WIDTH) +st.plotly_chart(fig_root, use_container_width=True) diff --git a/streamlit/mmm-explainer/config.toml b/streamlit/mmm-explainer/config.toml new file mode 100644 index 00000000..3d04bb48 --- /dev/null +++ b/streamlit/mmm-explainer/config.toml @@ -0,0 +1,3 @@ +[theme] +base="light" +primaryColor="#d7abf3" diff --git a/streamlit/mmm-explainer/pages/Adstock.py b/streamlit/mmm-explainer/pages/Adstock.py new file mode 100644 index 00000000..4a6597ca --- /dev/null +++ b/streamlit/mmm-explainer/pages/Adstock.py @@ -0,0 +1,658 @@ +# Copyright 2024 The PyMC Labs Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Import custom functions +import numpy as np +import pandas as pd +import plotly.express as px + +import streamlit as st +from pymc_marketing.mmm.transformers import ( + delayed_adstock, + geometric_adstock, + weibull_adstock, +) + +# Constants +PLOT_HEIGHT = 600 +PLOT_WIDTH = 1000 + + +# -------------------------- TOP OF PAGE INFORMATION ------------------------- + +# Set browser / tab config +st.set_page_config( + page_title="MMM App - Adstock Transformations", + page_icon="🧊", +) + +# Give some context for what the page displays +st.title("Adstock Transformations") +st.markdown( + """This page demonstrates the effect of various adstock \ + transformations on a variable. \nFor these examples, let's imagine \ + that we have _some variable that represents a quantity of a particlar_ \ + _advertising channel_. \n\nFor example, this could be the number of impressions\ + we get from Facebook. For an online channel such as this, we might expect the impact of these ads to be immediate: \ + \n ___We see an ad on Facebook - we either click on it, or we don't.___ \n\ + \n :blue[So, at the start of our example (_Week 1_), \ + we could have the impact of **100 impressions from Facebook**.] \ + \n\n Alternatively, for a channel like TV, we may not expect the impact of those ads \n \ + to come through immediately - there may be some delay. \ + \n\n :green[So, at the start of our example (_Week 1_), we may have the impact of **0 Gross Rating Points (TV viewership metric)**, \ + but 7 weeks later those TV ads might reach their full impact of **100 Gross Rating Points**.]\ + \n\n**_:violet[We will use this starting value of 100 for all of our adstock examples]_**. \ + """ # noqa: E501 +) + +st.markdown( + "**Reminder:** \n \ +- Geometric adstock transformations have **_:red[fixed decay]_** \n\ +- Weibull adstock transformations have **_:red[flexible decay]_**" +) + +# Starting value for adstock +initial_impact = 100 + +# Separate the adstock transformations into 3 tabs +tab1, tab2, tab3, tab4 = st.tabs( + ["Geometric", "Delayed Geometric", "Weibull CDF", "Weibull PDF"] +) + +# -------------------------- GEOMETRIC ADSTOCK DISPLAY ------------------------- +with tab1: + st.header(":blue[Geometric Adstock Transformation]") + st.divider() + st.markdown( + """___Geometric adstock is the simplest adstock function, it depends on a single parameter $\\alpha > 0$ which represents the fixed-rate decay.___ \n \ + \n __The geometric adstock function takes the following form :__""" # noqa: E501 + ) + st.latex(r""" + x_t^{\textrm{transf}} = x_t + \alpha x_{t-1}^{\textrm{transf}} + """) + st.divider() + st.markdown( + "**Typical values for geometric adstock:** \n \ +- TV: **:blue[0.3 - 0.8]** - _decays slowly_ \n \ +- OOH/Print/Radio: **:blue[0.1 - 0.4]** - _decays moderately_ \n \ +- Digital: **:blue[0.0 - 0.3]** - _decays quickly_ \n" + ) + st.caption( + ":link: [Values taken from Meta's Analyst's Guide to MMM](https://facebookexperimental.github.io/Robyn/docs/analysts-guide-to-MMM/#feature-engineering)" + ) + + # User inputs + st.subheader(":blue[User Inputs]") + num_periods = st.slider( + "Number of weeks after impressions first received :alarm_clock:", + 1, + 100, + 20, + key="Geometric", + ) + # Set l_max to same length of periods for demo purposes + l_max = num_periods + # Make array zeroes with only the first value as 100 + # to demo the decay purely + inputs = np.zeros(num_periods) + inputs[0] = 100 + + # Let user choose decay rates to plot with + decay_rate_1 = st.slider(":blue[Alpha 1 : ]", 0.0, 1.0, 0.3) + # Add up to 2 more lines if the user wants it + st.markdown("**Would you like to show multiple (3) decay lines on the plot**") + multi_plot = st.checkbox("Okay! :grin:") + # Create a list of decay rates + if multi_plot: + # Let user choose additional decay rates to plot with + decay_rate_2 = st.slider(":red[Alpha 2 : ]", 0.0, 1.0, 0.6) + decay_rate_3 = st.slider(":green[Alpha 3: ]", 0.0, 1.0, 0.9) + decay_rates = [decay_rate_1, decay_rate_2, decay_rate_3] + else: + decay_rates = [decay_rate_1] + + # Create df to store each adstock in + all_adstocks = pd.DataFrame() + # Iterate through decay rates and generate df of values to plot + for i, alpha in enumerate(decay_rates): + # Get geometric adstock values, decayed over time + adstock_df = pd.DataFrame( + { + "Week": range(1, (num_periods + 1)), + ## Calculate adstock values + "Adstock": geometric_adstock( + x=inputs, alpha=alpha, l_max=num_periods, normalize=False + ).eval(), + ## Format adstock labels for neater plotting + "Adstock Labels": [ + f"{x:,.0f}" + for x in geometric_adstock( + x=inputs, alpha=alpha, l_max=num_periods, normalize=False + ).eval() + ], + ## Create column to label each adstock + "Alpha": f"Alpha {i + 1}", + } + ) + + all_adstocks = pd.concat([all_adstocks, adstock_df]) + + # Plot adstock values + # Annotate the plot if user wants it + st.markdown("**Would you like to show the adstock values directly on the plot?**") + annotate = st.checkbox("Yes please! :pray:", key="Geometric Annotate") + if annotate: + fig = px.line( + all_adstocks, + x="Week", + y="Adstock", + text="Adstock Labels", + markers=True, + color="Alpha", + # Replaces default color mapping by value + color_discrete_map={ + "Alpha 1": "#636EFA", + "Alpha 2": "#EF553B", + "Alpha 3": "#00CC96", + }, + ) + fig.update_traces(textposition="bottom left") + else: + fig = px.line( + all_adstocks, + x="Week", + y="Adstock", + markers=True, + color="Alpha", + # Replaces default color mapping by value + color_discrete_map={ + "Alpha 1": "#636EFA", + "Alpha 2": "#EF553B", + "Alpha 3": "#00CC96", + }, + ) + # Format plot + fig.layout.height = PLOT_HEIGHT + fig.layout.width = PLOT_WIDTH + fig.update_layout( + title_text="Geometric Adstock Decayed Over Weeks", title_font=dict(size=30) + ) + st.plotly_chart(fig, theme="streamlit", use_container_width=False) + +# -------------------------- DELAYED GEOMETRIC ADSTOCK DISPLAY ------------------------- +with tab2: + st.header(":red[Delayed Geometric Adstock Transformation]") + st.divider() + st.markdown( + """___Delayed geometric adstock builds on geometric adstock___ \ + ___by adding in a delay $\\theta$ before the maximum adstock is observed (this happens at week 0 for the plain geometric decay).___ \ + \n ___It also adds a maximum duration for the carryover/adstock $L_{max}$, such that adstock after this point is 0.___ \n \ + \n __The delayed geometric adstock function takes the following form :__""" # noqa: E501 + ) + st.latex(r""" + x_t^{\textrm{transf}} = \sum_{i=0}^{L_{\max}-1} \left( \alpha^{|i-\theta|} \cdot x_{t-i} \right) \\""") + st.markdown( + "- $x_t^{\\textrm{transf}}$ refers to the transformed value at time $t$ after applying the delayed adstock transformation" # noqa: E501 + ) + st.markdown("- $\\alpha$ is the retention rate of the ad effect") + st.markdown("- $\\theta$ represents the delay before the peak effect occurs") + st.markdown("- $L_{max}$ is the maximum duration of the carryover effect") + st.divider() + st.markdown( + "**Typical values for geometric adstock:** \n \ +- TV: **:blue[0.3 - 0.8]** - _decays slowly_ \n \ +- OOH/Print/Radio: **:blue[0.1 - 0.4]** - _decays moderately_ \n \ +- Digital: **:blue[0.0 - 0.3]** - _decays quickly_ \n" + ) + st.caption( + ":link: [Values taken from Meta's Analyst's Guide to MMM](https://facebookexperimental.github.io/Robyn/docs/analysts-guide-to-MMM/#feature-engineering)" + ) + + # User inputs + st.subheader(":red[User Inputs]") + max_lag = st.slider( + "Number of weeks after impressions first received :alarm_clock: : ", + 1, + 100, + 30, + key="Delayed Geometric", + ) + max_peak = st.slider( + ":red[Number of weeks after impressions first received that max impact occurs :thermometer: : ]", + 0, + 100, + 10, + key="delayed_geom_L", + ) + # Let user choose decay rates to plot with + decay_rate_1 = st.slider(":red[Alpha 1: ]", 0.0, 1.0, 0.5, key="delay_decay") + + # Add up to 2 more lines if the user wants it + st.markdown("**Would you like to show multiple (3) decay lines on the plot**") + multi_plot = st.checkbox("Okay! :grin:", key="Delay Geom Multi") + # Create a list of decay rates + if multi_plot: + # Let user choose additional decay rates, lags and peaks to plot with + decay_rate_2 = st.slider(":blue[Alpha 2: ]", 0.0, 1.0, 0.6, key="delay_decay2") + max_peak_2 = st.slider( + ":blue[Number of weeks after impressions first received that max impact occurs :thermometer: :]", + 1, + 100, + 5, + key="delayed_geom_L 2", + ) + max_lag_2 = st.slider( + ":blue[Number of weeks after impressions first received :alarm_clock: : ]", + 1, + 100, + 20, + key="Delayed Geometric 2 ", + ) + decay_rate_3 = st.slider(":green[Alpha 3: ]", 0.0, 1.0, 0.9, key="delay_decay3") + max_lag_3 = st.slider( + ":green[Number of weeks after impressions first received :alarm_clock: : ]", + 1, + 100, + 20, + key="Delayed Geometric 3 ", + ) + max_peak_3 = st.slider( + ":green[Number of weeks after impressions first received that max impact occurs :thermometer: :]", + 1, + 100, + 5, + key="delayed_geom_L 3", + ) + + # Put in lists to iterate through later + decay_rates = [decay_rate_1, decay_rate_2, decay_rate_3] + lags = [max_lag, max_lag_2, max_lag_3] + peaks = [max_peak, max_peak_2, max_peak_3] + + else: + decay_rates = [decay_rate_1] + lags = [max_lag] + peaks = [max_peak] + + # Create df to store each adstock in + all_adstocks = pd.DataFrame() + # Iterate through decay rates and generate df of values to plot + for i, alpha in enumerate(decay_rates): + # Make array zeroes with only the max lagged value as 100 + # to demo the decay purely + inputs = np.zeros(lags[i]) + inputs[peaks[i]] = 100 + + # Get geometric adstock values, decayed over time + adstock_df = pd.DataFrame( + { + "Week": range(1, (lags[i] + 1)), + ## Calculate adstock values + "Adstock": delayed_adstock( + x=inputs, alpha=alpha, theta=peaks[i], l_max=lags[i] + ).eval(), + ## Format adstock labels for neater plotting + "Adstock Labels": [ + f"{x:,.0f}" + for x in delayed_adstock( + x=inputs, alpha=alpha, theta=peaks[i], l_max=lags[i] + ).eval() + ], + ## Create column to label each adstock + "Alpha": f"Alpha {i + 1}", + } + ) + + all_adstocks = pd.concat([all_adstocks, adstock_df]) + + # Plot adstock values + # Annotate the plot if user wants it + st.markdown("**Would you like to show the adstock values directly on the plot?**") + annotate = st.checkbox("Yes please! :pray:", key="Delayed Geometric Annotate") + if annotate: + fig = px.line( + all_adstocks, + x="Week", + y="Adstock", + text="Adstock Labels", + markers=True, + color="Alpha", + # Replaces default color mapping by value + color_discrete_map={ + "Alpha 1": "#636EFA", + "Alpha 2": "#EF553B", + "Alpha 3": "#00CC96", + }, + ) + fig.update_traces(textposition="bottom left") + else: + fig = px.line( + all_adstocks, + x="Week", + y="Adstock", + markers=True, + color="Alpha", + # Replaces default color mapping by value + color_discrete_map={ + "Alpha 1": "#636EFA", + "Alpha 2": "#EF553B", + "Alpha 3": "#00CC96", + }, + ) + # Format plot + fig.layout.height = PLOT_HEIGHT + fig.layout.width = PLOT_WIDTH + fig.update_layout( + title_text="Geometric Adstock Decayed Over Weeks", title_font=dict(size=30) + ) + st.plotly_chart(fig, theme="streamlit", use_container_width=False) + +# -------------------------- WEIBULL CDF ADSTOCK DISPLAY ------------------------- +with tab3: + st.header(":green[Weibull CDF Adstock Transformation]") + st.divider() + st.markdown( + """___The Weibull CDF is a function depending on two variables, $k$ (known as the **shape**) and $\\lambda$ (known as the **scale**)___. \n \ + The idea is closely related to geometric adstock but with one important difference : the rate of decay (what we called $\\alpha$ in the geometric adstock equation) \ + is no longer fixed. Instead it’s **time-dependent**. \ + \n \n **The Weibull CDF adstock function therefore takes the form :**""" # noqa: E501 + ) + st.latex(r""" + x_t^{\textrm{transf}} = x_t + \alpha_t x_{t-1}^{\textrm{transf}}""") + st.markdown("- where $\\alpha_t$ is now a function of time $t$") + st.markdown( + "**The Weibull CDF is actually used to build the $\\alpha_t$’s, and it takes the form :**" + ) + st.latex(r""" + F_{k, \lambda}(t) = 1 - e^{-(\frac{t}{\lambda})^k}""") + st.markdown("Then, $\\alpha_t$ is computed as : ") + st.latex(r""" + \alpha_t = 1 - F_{k,\lambda}(t)""") + st.divider() + # User inputs + st.subheader(":green[User Inputs]") + num_periods_2 = st.slider( + "Number of weeks after impressions first received :alarm_clock: :", + 1, + 100, + 20, + key="Weibull CDF Periods", + ) + # Let user choose shape and scale parameters to compare two Weibull PDF decay curves simultaneously + # Params for Line A + shape_parameter_A = st.slider( + ":triangular_ruler: :green[Shape $k$ of Line A]:", + 0.1, + 10.0, + 0.1, + key="Weibull CDF Shape A", + ) + scale_parameter_A = st.slider( + r":green[Scale $\lambda$ of Line A]:", 0.1, 50.0, 0.1, key="Weibull CDF Scale A" + ) + # Make array zeroes with only the first value as 100 + # to demo the decay purely + inputs = np.zeros(num_periods_2) + inputs[0] = 100 + + # Calculate weibull pdf adstock values, decayed over time for both sets of params + adstock_series_A = weibull_adstock( + x=inputs, + lam=scale_parameter_A, + k=shape_parameter_A, + l_max=num_periods_2, + type="CDF", + ).eval() + + # Create df of adstock values, to plot with + adstock_df_A = pd.DataFrame( + { + "Week": range(1, (num_periods_2 + 1)), + "Adstock": adstock_series_A, + "Line": "Line A", + } + ) + + # Create plotting df + weibull_cdf_df = adstock_df_A.copy() + + # Plot 2nd line if user desires values + st.markdown("**Would you like to add a second line to the plot?**") + second_cdf = st.checkbox("Okay! :grin:", key="Add 2nd Weibull CDF") + + if second_cdf: + # Params for Line B + shape_parameter_B = st.slider( + ":triangular_ruler: :red[Shape $k$ of Line B : ]", + 0.1, + 10.0, + 9.0, + key="Weibull CDF Shape B", + ) + scale_parameter_B = st.slider( + r":red[Scale $\lambda$ of Line B : ]", + 0.1, + 50.0, + 0.5, + key="Weibull CDF Scale B", + ) + # Calculate weibull pdf adstock values, decayed over time for both sets of params + adstock_series_B = weibull_adstock( + x=inputs, + lam=scale_parameter_B, + k=shape_parameter_B, + l_max=num_periods_2, + type="CDF", + ).eval() + + # Create df of adstock values, to plot with + adstock_df_B = pd.DataFrame( + { + "Week": range(1, (num_periods_2 + 1)), + "Adstock": adstock_series_B, + "Line": "Line B", + } + ) + # Create plotting df + weibull_cdf_df = pd.concat([adstock_df_A, adstock_df_B]) + + # Multiply by 100 to get back to scale of initial impact (100 FB impressions) + weibull_cdf_df.Adstock = weibull_cdf_df.Adstock + # Format adstock labels for neater plotting + weibull_cdf_df["Adstock Labels"] = weibull_cdf_df.Adstock.map("{:,.0f}".format) + + # Plot adstock values + # Annotate the plot if user wants it + st.markdown("**Would you like to show the adstock values directly on the plot?**") + annotate = st.checkbox("Yes please! :pray:", key="Weibull CDF Annotate") + if annotate: + fig = px.line( + weibull_cdf_df, + x="Week", + y="Adstock", + text="Adstock Labels", + markers=True, + color="Line", + # Replaces default color mapping by value + color_discrete_map={"Line A": "#636EFA", "Line B": "#EF553B"}, + ) + fig.update_traces(textposition="bottom left") + else: + fig = px.line( + weibull_cdf_df, + x="Week", + y="Adstock", + markers=True, + color="Line", + # Replaces default color mapping by value + color_discrete_map={"Line A": "#636EFA", "Line B": "#EF553B"}, + ) + # Format plot + fig.layout.height = PLOT_HEIGHT + fig.layout.width = PLOT_WIDTH + fig.update_layout( + title_text="Weibull CDF Adstock Decayed Over Weeks", title_font=dict(size=30) + ) + st.plotly_chart(fig, theme="streamlit", use_container_width=False) + +# -------------------------- WEIBULL PDF ADSTOCK DISPLAY ------------------------- +with tab4: + st.header(":violet[Weibull PDF Adstock Transformation]") + st.divider() + st.markdown( + """___The Weibull PDF is also a function depending on two variables, $k$ (shape) and $\\lambda$ (scale) \ + and the same remarks for Weibull CDF apply to Weibull PDF.___ \ + \n The key difference is that Weibull PDF \ + allows for lagged effects to be taken into account - the **time delay effect**. \ + \n \n **The Weibull PDF adstock function therefore takes the form :**""" + ) + st.latex(r""" + x_t^{\textrm{transf}} = x_t + \alpha_t x_{t-1}^{\textrm{transf}}""") + st.markdown("- where $\\alpha_t$ is now a function of time $t$") + st.markdown( + "**The Weibull PDF is actually used to build the $\\alpha_t$’s, and it takes the form :**" + ) + st.latex(r""" + G_{k,\lambda}(t) = \frac{k}{\lambda}\Big(\frac{t}{\lambda} \Big)^{k-1}e^{-(\frac{t}{\lambda})^k}""") + st.divider() + + # User inputs + st.subheader(":violet[User Inputs]") + num_periods_3 = st.slider( + "Number of weeks after impressions first received :alarm_clock: : ", + 1, + 100, + 20, + key="Weibull PDF Periods", + ) + # Let user choose shape and scale parameters to compare two Weibull PDF decay curves simultaneously + # Params for Line A + shape_parameter_A = st.slider( + ":triangular_ruler: :blue[Shape $k$ of Line A : ]", + 0.1, + 10.0, + 2.0, + key="Weibull PDF Shape A", + ) + scale_parameter_A = st.slider( + r":blue[Scale $\lambda$ of Line A : ]", + 0.1, + 50.0, + 0.5, + key="Weibull PDF Scale A", + ) + # Make array zeroes with only the first value as 100 + # to demo the decay purely + inputs = np.zeros(num_periods_3) + inputs[0] = 100 + + # Calculate weibull pdf adstock values, decayed over time for both sets of params + adstock_series_A = weibull_adstock( + x=inputs, + lam=scale_parameter_A, + k=shape_parameter_A, + l_max=num_periods_3, + type="PDF", + ).eval() + + # Create df of adstock values, to plot with + adstock_df_A = pd.DataFrame( + { + "Week": range(1, (num_periods_3 + 1)), + "Adstock": adstock_series_A, + "Line": "Line A", + } + ) + + # Create plotting df + weibull_pdf_df = adstock_df_A.copy() + + # Plot 2nd line if user desires values + st.markdown("**Would you like to add a second line to the plot?**") + second_pdf = st.checkbox("Okay! :grin:", key="Add 2nd Weibull PDF") + + if second_pdf: + # Params for Line B + shape_parameter_B = st.slider( + ":triangular_ruler: :red[Shape $k$ of Line B : ]", + 0.1, + 10.0, + 0.5, + key="Weibull PDF Shape B", + ) + scale_parameter_B = st.slider( + r":red[Scale $\lambda$ of Line B : ]", + 0.1, + 50.0, + 0.1, + key="Weibull PDF Scale B", + ) + + # Calculate weibull pdf adstock values, decayed over time for both sets of params + adstock_series_B = weibull_adstock( + x=inputs, + lam=scale_parameter_B, + k=shape_parameter_B, + l_max=num_periods_3, + type="PDF", + ).eval() + + # Create df of adstock values, to plot with + adstock_df_B = pd.DataFrame( + { + "Week": range(1, (num_periods_3 + 1)), + "Adstock": adstock_series_B, + "Line": "Line B", + } + ) + # Create plotting df + weibull_pdf_df = pd.concat([adstock_df_A, adstock_df_B]) + + # Multiply by 100 to get back to scale of initial impact (100 FB impressions) + weibull_pdf_df.Adstock = weibull_pdf_df.Adstock + # Format adstock labels for neater plotting + weibull_pdf_df["Adstock Labels"] = weibull_pdf_df.Adstock.map("{:,.0f}".format) + + # Plot adstock values + # Annotate the plot if user wants it + st.markdown("**Would you like to show the adstock values directly on the plot?**") + annotate = st.checkbox("Yes please! :pray:", key="Weibull PDF Annotate") + if annotate: + fig = px.line( + weibull_pdf_df, + x="Week", + y="Adstock", + text="Adstock Labels", + markers=True, + color="Line", + # Replaces default color mapping by value + color_discrete_map={"Line A": "#636EFA", "Line B": "#EF553B"}, + ) + fig.update_traces(textposition="bottom left") + else: + fig = px.line( + weibull_pdf_df, + x="Week", + y="Adstock", + markers=True, + color="Line", + # Replaces default color mapping by value + color_discrete_map={"Line A": "#636EFA", "Line B": "#EF553B"}, + ) + # Format plot + fig.layout.height = PLOT_HEIGHT + fig.layout.width = PLOT_WIDTH + fig.update_layout( + title_text="Weibull PDF Adstock Decayed Over Weeks", title_font=dict(size=30) + ) + st.plotly_chart(fig, theme="streamlit", use_container_width=False) diff --git a/streamlit/mmm-explainer/pages/Saturation.py b/streamlit/mmm-explainer/pages/Saturation.py new file mode 100644 index 00000000..92193e65 --- /dev/null +++ b/streamlit/mmm-explainer/pages/Saturation.py @@ -0,0 +1,262 @@ +# Copyright 2024 The PyMC Labs Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Import custom functions +import numpy as np +import pandas as pd +import plotly.graph_objects as go + +import streamlit as st +from pymc_marketing.mmm.transformers import ( + logistic_saturation, + michaelis_menten, + tanh_saturation, +) + +# Constants +PLOT_HEIGHT = 500 +PLOT_WIDTH = 1000 + +# -------------------------- TOP OF PAGE INFORMATION ------------------------- + +# Set browser / tab config +st.set_page_config( + page_title="MMM App - Saturation Curves", + page_icon="🧊", +) + +# Give some context for what the page displays +st.title("Saturation Curves") +st.markdown( + "This page demonstrates the forms and shapes of saturation curves for MMM.\ + These curves try to model the relationship between weekly marketing \ + spends for a given channel (holding other channels constant) and \ + the conversions that result from that spend.\ + \n It doesn't need to be conversions, it could be sales or customers acquired\ + - whatever target metric you are interested in.\ + " +) + +st.markdown( + "**Reminder:** \n \ +- Certain saturation functions have **_:red[concave or convex shapes]_** \n\ +- Certain saturation functions have **_:red[S-shapes]_**" +) + +# -------------------------- SATURATION PLOTS ------------------------- + +# Generate simulated marketing data +np.random.seed(42) +num_points = 500 +media_spending = np.linspace(0, 1000, num_points) # x-axis + + +# Generate simulated datasets with noise +dummy_logistic = logistic_saturation( + media_spending, lam=0.01 +).eval() + np.random.normal(0, 0.1, num_points) +dummy_tanh = tanh_saturation(media_spending, b=10, c=20).eval() + np.random.normal( + 0, 0.75, num_points +) +dummy_m_m = michaelis_menten(media_spending, alpha=20, lam=200) + np.random.normal( + 0, 2, num_points +) + +# Create tabs for plots +tab1, tab2, tab3 = st.tabs(["Logistic", "Tanh", "Michaelis-Menten"]) + +# -------------------------- LOGISTIC CURVE ------------------------- +with tab1: + st.subheader(":green[Logistic Curve Saturation]") + st.markdown("___The Logistic function takes the form:___") + st.latex(r""" + x_t^{\textrm{transf}} = \frac{1 - e^{-\lambda x_t}}{1 + e^{-\lambda x_t}} + """) + st.divider() + # User inputs + st.subheader(":green[User Inputs]") + st.markdown("**Try to fit a saturation curve to the generated data!**") + + # User input for Modified Logistic Curve + logistic_lam = st.slider( + ":green[Logistic Curve $\\lambda$ (scaled value):]", + 0, + 1000, + 500, + step=1, + key="logistic_lam", + ) + logistic_lam = logistic_lam / 10000 + + # Calculate the user created response curve + user_logistic = logistic_saturation(media_spending, lam=logistic_lam).eval() + + # Tidy the simulated dataset for plotting + plot_data = pd.DataFrame( + {"Media Spending": np.round(media_spending), "Conversions": dummy_logistic} + ) + # Drop rows with negative conversions, generated by the noise + plot_data = plot_data[plot_data.Conversions >= 0] + + # Plot + fig_root = go.Figure() + # Plot weekly spend and response data, every 5th to make the plot less crowded + fig_root.add_trace( + go.Scatter( + x=plot_data["Media Spending"][::5], + y=plot_data["Conversions"][::5], + mode="markers", + name="Weekly Data", + marker=dict(color="#AB63FA"), + ) + ) + # Plot user-defined curve to match that data + fig_root.add_trace( + go.Scatter( + x=media_spending, + y=user_logistic, + mode="lines", + name="Saturation Curve", + line=dict(color="blue", dash="solid"), + ) + ) + + fig_root.update_layout( + title_text="Logistic Saturation Curve", + xaxis_title="Media Spend", + yaxis_title="Conversions", + height=PLOT_HEIGHT, + width=PLOT_WIDTH, + ) + + st.plotly_chart(fig_root, use_container_width=True) + + +# -------------------------- TANH CURVE ------------------------- +with tab2: + st.subheader(":orange[Tanh Curve Saturation]") + st.markdown("___The Tanh saturation function takes the form:___") + st.latex(r""" + x_t^{\textrm{transf}} = b \tanh \left( \frac{x_t}{bc} \right) + """) + st.divider() + # User inputs + st.subheader(":orange[User Inputs]") + st.markdown("**Try to fit a saturation curve to the generated data!**") + + # User input for Tanh Curve + tanh_b = st.slider(":orange[Tanh Curve $\\text{b}$]:", 0, 20, 5) + tanh_c = st.slider(":orange[Tanh Curve $\\text{c}$]:", 0, 100, 50) + + # Calculate the user created response curve + user_tanh = tanh_saturation(media_spending, b=tanh_b, c=tanh_c).eval() + + # Tidy the simulated dataset for plotting + plot_data = pd.DataFrame( + {"Media Spending": np.round(media_spending), "Conversions": dummy_tanh} + ) + # Drop rows with negative conversions, generated by the noise + plot_data = plot_data[plot_data.Conversions >= 0] + + # Plot + fig_root = go.Figure() + # Plot weekly spend and response data, every 5th to make the plot less crowded + fig_root.add_trace( + go.Scatter( + x=plot_data["Media Spending"][::5], + y=plot_data["Conversions"][::5], + mode="markers", + name="Weekly Data", + marker=dict(color="#AB63FA"), + ) + ) + # Plot user-defined curve to match that data + fig_root.add_trace( + go.Scatter( + x=media_spending, + y=user_tanh, + mode="lines", + name="Saturation Curve", + line=dict(color="blue", dash="solid"), + ) + ) + + fig_root.update_layout( + title_text="Tanh Saturation Curve", + xaxis_title="Media Spend", + yaxis_title="Conversions", + height=PLOT_HEIGHT, + width=PLOT_WIDTH, + ) + + st.plotly_chart(fig_root, use_container_width=True) + + +# -------------------------- MICHAELIS-MENTEN CURVE ------------------------- +with tab3: + st.subheader(":violet[Michaelis-Menten Curve Saturation]") + st.markdown("___The Michaelis-Menten saturation function takes the form:___") + st.latex(r""" + x_t^{\textrm{transf}} = \frac{\alpha \cdot x_t}{\lambda + x_t} + """) + st.divider() + # User inputs + st.subheader(":violet[User Inputs]") + st.markdown("**Try to fit a saturation curve to the generated data!**") + + # User input for Tanh Curve + m_m_alpha = st.slider(":violet[Michaelis-Menten Curve $\\alpha$:]", 0, 50, 25) + m_m_lambda = st.slider(":violet[Michaelis-Menten Curve $\\lambda$:]", 0, 500, 50) + + # Calculate the user created response curve + user_m_m = michaelis_menten(media_spending, alpha=m_m_alpha, lam=m_m_lambda) + + # Tidy the simulated dataset for plotting + plot_data = pd.DataFrame( + {"Media Spending": np.round(media_spending), "Conversions": dummy_m_m} + ) + # Drop rows with negative conversions, generated by the noise + plot_data = plot_data[plot_data.Conversions >= 0] + + # Plot + fig_root = go.Figure() + # Plot weekly spend and response data, every 5th to make the plot less crowded + fig_root.add_trace( + go.Scatter( + x=plot_data["Media Spending"][::5], + y=plot_data["Conversions"][::5], + mode="markers", + name="Weekly Data", + marker=dict(color="#AB63FA"), + ) + ) + # Plot user-defined curve to match that data + fig_root.add_trace( + go.Scatter( + x=media_spending, + y=user_m_m, + mode="lines", + name="Saturation Curve", + line=dict(color="blue", dash="solid"), + ) + ) + + fig_root.update_layout( + title_text="Michaelis-Menten Saturation Curve", + xaxis_title="Media Spend", + yaxis_title="Conversions", + height=PLOT_HEIGHT, + width=PLOT_WIDTH, + ) + + st.plotly_chart(fig_root, use_container_width=True) diff --git a/streamlit/mmm-explainer/prior_functions.py b/streamlit/mmm-explainer/prior_functions.py new file mode 100644 index 00000000..f976b224 --- /dev/null +++ b/streamlit/mmm-explainer/prior_functions.py @@ -0,0 +1,96 @@ +# Copyright 2024 The PyMC Labs Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Imports +import numpy as np +import plotly.express as px +import plotly.graph_objects as go +import preliz as pz +from scipy.stats import gaussian_kde + +import streamlit as st + + +@st.cache_data # πŸ‘ˆ Add the caching decorator, make app run faster +def get_distribution(distribution_name=pz.distributions, **params): + """ + Retrieve and create a distribution instance from the PreliZ library. + + Parameters: + distribution_name (str): The name of the distribution to create. + **params: Variable length dict of parameters and values required by the distribution. + + Returns: + object: An instance of the requested distribution. + """ + try: + # Get the distribution class from preliz + dist_class = getattr(pz, distribution_name) + # Create an instance of the distribution with the provided parameters + return dist_class(**params) + except AttributeError: + raise ValueError(f"Distribution '{distribution_name}' is not found in preliz.") + except TypeError: + raise ValueError( + f"Incorrect parameters for the distribution '{distribution_name}'." + ) + + +def plot_prior_distribution( + draws, nbins=100, opacity=0.1, title="Prior Distribution - Visualised" +): + """ + Plots samples of a prior distribution as a histogram with a KDE (Kernel Density Estimate) overlay + and a violin plot along the top too with quartile values. + + Parameters: + - draws: numpy array of samples from prior distribution. + - nbins: int, the number of bins for the histogram. + - opacity: float, the opacity level for the histogram bars. + - title: str, the title of the plot. + """ + # Create the histogram using Plotly Express + fig = px.histogram( + draws, + x=draws, + nbins=nbins, + title=title, + labels={"x": "Value"}, + histnorm="probability density", + opacity=opacity, + marginal="violin", + color_discrete_sequence=["#0047AB"], + ) + + # Compute the KDE + kde = gaussian_kde(draws) + x_range = np.linspace(min(draws), max(draws), 500) + kde_values = kde(x_range) + + # Add the KDE plot to the histogram figure + fig.add_trace( + go.Scatter( + x=x_range, + y=kde_values, + mode="lines", + name="KDE", + line_color="#DA70D6", + opacity=0.8, + ) + ) + + # Customize the layout + fig.update_layout(xaxis_title="Value of Prior", yaxis_title="Density") + + # Return the plot + return fig diff --git a/streamlit/mmm-explainer/requirements.txt b/streamlit/mmm-explainer/requirements.txt new file mode 100644 index 00000000..96a5cb55 --- /dev/null +++ b/streamlit/mmm-explainer/requirements.txt @@ -0,0 +1,9 @@ +numpy==1.24.3 +pandas==2.0.2 +streamlit==1.25.0 +plotly==5.13.1 +scikit-learn==1.2.2 +scipy==1.11.0 +preliz==0.6.3 +pymc-marketing==0.6.0 +typing==3.7.4.3