diff --git a/pymc_marketing/mlflow.py b/pymc_marketing/mlflow.py index 8fe0ffb6..60c7ffcc 100644 --- a/pymc_marketing/mlflow.py +++ b/pymc_marketing/mlflow.py @@ -138,6 +138,7 @@ def log_arviz_summary( idata: az.InferenceData, path: str | Path, var_names: list[str] | None = None, + **summary_kwargs, ) -> None: """Log the ArviZ summary as an artifact on MLflow. @@ -152,9 +153,11 @@ def log_arviz_summary( var_names : list[str], optional The names of the variables to include in the summary. Default is all the variables in the InferenceData object. + summary_kwargs : dict + Additional keyword arguments to pass to `az.summary`. """ - df_summary = az.summary(idata, var_names=var_names) + df_summary = az.summary(idata, var_names=var_names, **summary_kwargs) df_summary.to_html(path) mlflow.log_artifact(str(path)) os.remove(path) @@ -394,6 +397,7 @@ def autolog( log_datasets: bool = True, log_model_info: bool = True, summary_var_names: list[str] | None = None, + arviz_summary_kwargs: dict | None = None, log_mmm: bool = True, disable: bool = False, silent: bool = False, @@ -417,6 +421,8 @@ def autolog( summary_var_names : list[str], optional The names of the variables to include in the ArviZ summary. Default is all the variables in the InferenceData object. + arviz_summary_kwargs : dict, optional + Additional keyword arguments to pass to `az.summary`. log_mmm : bool, optional Whether to log PyMC-Marketing MMM models. Default is True. disable : bool, optional @@ -500,6 +506,8 @@ def autolog( """ + arviz_summary_kwargs = arviz_summary_kwargs or {} + def patch_sample(sample): @wraps(sample) def new_sample(*args, **kwargs): @@ -510,7 +518,12 @@ def new_sample(*args, **kwargs): if log_sampler_info: log_sample_diagnostics(idata) - log_arviz_summary(idata, "summary.html", var_names=summary_var_names) + log_arviz_summary( + idata, + "summary.html", + var_names=summary_var_names, + **arviz_summary_kwargs, + ) model = pm.modelcontext(kwargs.get("model")) if log_model_info: