diff --git a/pymc_marketing/clv/plotting.py b/pymc_marketing/clv/plotting.py index 2834ffda..35277efb 100644 --- a/pymc_marketing/clv/plotting.py +++ b/pymc_marketing/clv/plotting.py @@ -1,3 +1,5 @@ +from typing import Optional + import matplotlib.pyplot as plt import numpy as np @@ -10,11 +12,12 @@ def plot_frequency_recency_matrix( model, t=1, - max_frequency=None, - max_recency=None, - title=None, - xlabel="Historical Frequency", - ylabel="Recency", + max_frequency: Optional[int] = None, + max_recency: Optional[int] = None, + title: Optional[str] = None, + xlabel: str = "Customer's Historical Frequency", + ylabel: str = "Customer's Recency", + ax: Optional[plt.Axes] = None, **kwargs, ) -> plt.Axes: """ @@ -64,11 +67,10 @@ def plot_frequency_recency_matrix( .mean(("draw", "chain")) .values.reshape(mesh_recency.shape) ) + if ax is None: + ax = plt.subplot(111) - ax = plt.subplot(111) pcm = ax.imshow(Z, **kwargs) - plt.xlabel(xlabel) - plt.ylabel(ylabel) if title is None: title = ( "Expected Number of Future Purchases for {} Unit{} of Time,".format( @@ -76,7 +78,12 @@ def plot_frequency_recency_matrix( ) + "\nby Frequency and Recency of a Customer" ) - plt.title(title) + + ax.set( + xlabel=xlabel, + ylabel=ylabel, + title=title, + ) force_aspect(ax) @@ -88,11 +95,12 @@ def plot_frequency_recency_matrix( def plot_probability_alive_matrix( model, - max_frequency=None, - max_recency=None, - title="Probability Customer is Alive,\nby Frequency and Recency of a Customer", - xlabel="Customer's Historical Frequency", - ylabel="Customer's Recency", + max_frequency: Optional[int] = None, + max_recency: Optional[int] = None, + title: str = "Probability Customer is Alive,\nby Frequency and Recency of a Customer", + xlabel: str = "Customer's Historical Frequency", + ylabel: str = "Customer's Recency", + ax: Optional[plt.Axes] = None, **kwargs, ) -> plt.Axes: """ @@ -143,12 +151,16 @@ def plot_probability_alive_matrix( interpolation = kwargs.pop("interpolation", "none") - ax = plt.subplot(111) + if ax is None: + ax = plt.subplot(111) + pcm = ax.imshow(Z, interpolation=interpolation, **kwargs) - plt.xlabel(xlabel) - plt.ylabel(ylabel) - plt.title(title) + ax.set( + xlabel=xlabel, + ylabel=ylabel, + title=title, + ) force_aspect(ax) # plot colorbar beside matrix