diff --git a/pymc_marketing/mlflow.py b/pymc_marketing/mlflow.py index 9626d035..0d3dc33c 100644 --- a/pymc_marketing/mlflow.py +++ b/pymc_marketing/mlflow.py @@ -100,6 +100,7 @@ """ import json +import logging import os from functools import wraps from pathlib import Path @@ -214,12 +215,20 @@ def log_model_graph(model: Model, path: str | Path) -> None: """ try: graph = pm.model_to_graphviz(model) - except ImportError: + except ImportError as e: + msg = ( + "Unable to render the model graph. Please install the graphviz package. " + f"{e}" + ) + logging.info(msg) + return None try: saved_path = graph.render(path) - except Exception: + except Exception as e: + msg = f"Unable to render the model graph. {e}" + logging.info(msg) return None else: mlflow.log_artifact(saved_path) diff --git a/tests/test_mlflow.py b/tests/test_mlflow.py index 929b50b2..1bf97b86 100644 --- a/tests/test_mlflow.py +++ b/tests/test_mlflow.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import json +import logging import arviz as az import mlflow @@ -137,19 +138,35 @@ def test_multi_likelihood_type(multi_likelihood_model) -> None: @pytest.mark.parametrize( - "to_patch, side_effect", + "to_patch, side_effect, expected_info_message", [ - ("pymc.model_to_graphviz", ImportError("No module named 'graphviz'")), - ("graphviz.graphs.Digraph.render", Exception("Unknown error occurred")), + ( + "pymc.model_to_graphviz", + ImportError("No module named 'graphviz'"), + "Unable to render the model graph. Please install the graphviz package. No module named 'graphviz'", + ), + ( + "graphviz.graphs.Digraph.render", + Exception("Unknown error occurred"), + "Unable to render the model graph. Unknown error occurred", + ), ], + ids=["no_graphviz", "render_error"], ) -def test_log_model_graph_no_graphviz(mocker, model, to_patch, side_effect) -> None: +def test_log_model_graph_no_graphviz( + caplog, mocker, model, to_patch, side_effect, expected_info_message +) -> None: mocker.patch( to_patch, side_effect=side_effect, ) with mlflow.start_run() as run: - log_model_graph(model, "model_graph") + with caplog.at_level(logging.INFO): + log_model_graph(model, "model_graph") + + assert caplog.messages == [ + expected_info_message, + ] run_id = run.info.run_id artifacts = get_run_data(run_id)[-1]