From 5548c7b95e3ebe86426922ba63f14aff63ca9b9b Mon Sep 17 00:00:00 2001 From: Will Dean Date: Mon, 12 Aug 2024 14:54:02 +0200 Subject: [PATCH] use a module fixure for setup and tear down --- tests/test_mlflow.py | 31 ++++++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/tests/test_mlflow.py b/tests/test_mlflow.py index 431d2662..277ece34 100644 --- a/tests/test_mlflow.py +++ b/tests/test_mlflow.py @@ -20,16 +20,23 @@ import pytest from mlflow.client import MlflowClient -from pymc_marketing.mlflow import autolog +from pymc_marketing.mlflow import autolog, log_model_graph from pymc_marketing.mmm import MMM, GeometricAdstock, LogisticSaturation -uri: str = "sqlite:///mlruns.db" -mlflow.set_tracking_uri(uri=uri) - seed = sum(map(ord, "mlflow-with-pymc")) rng = np.random.default_rng(seed) -autolog() + +@pytest.fixture(scope="module", autouse=True) +def setup_module(): + uri: str = "sqlite:///mlruns.db" + mlflow.set_tracking_uri(uri=uri) + autolog() + + yield + + pm.sample = pm.sample.__wrapped__ + MMM.fit = MMM.fit.__wrapped__ @pytest.fixture(scope="module") @@ -62,6 +69,20 @@ def get_run_data(run_id): return inputs, data.params, data.metrics, tags, artifacts +def test_log_model_graph_no_graphviz(mocker, model) -> None: + mocker.patch( + "pymc.model_to_graphviz", + side_effect=ImportError("No module named 'graphviz'"), + ) + with mlflow.start_run() as run: + log_model_graph(model, "model_graph") + + run_id = run.info.run_id + artifacts = get_run_data(run_id)[-1] + + assert artifacts == [] + + def metric_checks(metrics, nuts_sampler) -> None: assert metrics["total_divergences"] >= 0.0 if nuts_sampler not in ["numpyro", "nutpie", "blackjax"]: