Skip to content

Commit

Permalink
expose arviz summary kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 committed Aug 14, 2024
1 parent 8003ddf commit df13f6e
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions pymc_marketing/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def log_arviz_summary(
idata: az.InferenceData,
path: str | Path,
var_names: list[str] | None = None,
**summary_kwargs,
) -> None:
"""Log the ArviZ summary as an artifact on MLflow.
Expand All @@ -152,9 +153,11 @@ def log_arviz_summary(
var_names : list[str], optional
The names of the variables to include in the summary. Default is
all the variables in the InferenceData object.
summary_kwargs : dict
Additional keyword arguments to pass to `az.summary`.
"""
df_summary = az.summary(idata, var_names=var_names)
df_summary = az.summary(idata, var_names=var_names, **summary_kwargs)
df_summary.to_html(path)
mlflow.log_artifact(str(path))
os.remove(path)
Expand Down Expand Up @@ -394,6 +397,7 @@ def autolog(
log_datasets: bool = True,
log_model_info: bool = True,
summary_var_names: list[str] | None = None,
arviz_summary_kwargs: dict | None = None,
log_mmm: bool = True,
disable: bool = False,
silent: bool = False,
Expand All @@ -417,6 +421,8 @@ def autolog(
summary_var_names : list[str], optional
The names of the variables to include in the ArviZ summary. Default is
all the variables in the InferenceData object.
arviz_summary_kwargs : dict, optional
Additional keyword arguments to pass to `az.summary`.
log_mmm : bool, optional
Whether to log PyMC-Marketing MMM models. Default is True.
disable : bool, optional
Expand Down Expand Up @@ -500,6 +506,8 @@ def autolog(
"""

arviz_summary_kwargs = arviz_summary_kwargs or {}

def patch_sample(sample):
@wraps(sample)
def new_sample(*args, **kwargs):
Expand All @@ -510,7 +518,12 @@ def new_sample(*args, **kwargs):

if log_sampler_info:
log_sample_diagnostics(idata)
log_arviz_summary(idata, "summary.html", var_names=summary_var_names)
log_arviz_summary(
idata,
"summary.html",
var_names=summary_var_names,
**arviz_summary_kwargs,
)

model = pm.modelcontext(kwargs.get("model"))
if log_model_info:
Expand Down

0 comments on commit df13f6e

Please sign in to comment.