
Commit

use a module fixture for setup and teardown
wd60622 committed Aug 12, 2024
1 parent dd6562c commit 5548c7b
Showing 1 changed file with 26 additions and 5 deletions.
31 changes: 26 additions & 5 deletions tests/test_mlflow.py
@@ -20,16 +20,23 @@
import pytest
from mlflow.client import MlflowClient

from pymc_marketing.mlflow import autolog
from pymc_marketing.mlflow import autolog, log_model_graph
from pymc_marketing.mmm import MMM, GeometricAdstock, LogisticSaturation

uri: str = "sqlite:///mlruns.db"
mlflow.set_tracking_uri(uri=uri)

seed = sum(map(ord, "mlflow-with-pymc"))
rng = np.random.default_rng(seed)

autolog()

@pytest.fixture(scope="module", autouse=True)
def setup_module():
    uri: str = "sqlite:///mlruns.db"
    mlflow.set_tracking_uri(uri=uri)
    autolog()

    yield

    pm.sample = pm.sample.__wrapped__
    MMM.fit = MMM.fit.__wrapped__


@pytest.fixture(scope="module")
@@ -62,6 +69,20 @@ def get_run_data(run_id):
    return inputs, data.params, data.metrics, tags, artifacts


def test_log_model_graph_no_graphviz(mocker, model) -> None:
    mocker.patch(
        "pymc.model_to_graphviz",
        side_effect=ImportError("No module named 'graphviz'"),
    )
    with mlflow.start_run() as run:
        log_model_graph(model, "model_graph")

    run_id = run.info.run_id
    artifacts = get_run_data(run_id)[-1]

    assert artifacts == []


def metric_checks(metrics, nuts_sampler) -> None:
    assert metrics["total_divergences"] >= 0.0
    if nuts_sampler not in ["numpyro", "nutpie", "blackjax"]:
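For context on the pattern this commit adopts: a module-scoped, autouse pytest fixture runs everything before its `yield` once before the module's tests and everything after it once as teardown, which is how the tracking URI and `autolog()` patching get applied and then undone in a single place. Below is a minimal standalone sketch of the same idea; the names (`target`, `enable_patching`, `test_target_is_patched`) are hypothetical, and it only assumes that the wrapper installed during setup keeps the original callable on `__wrapped__`, as `functools.wraps` does and as the teardown lines `pm.sample = pm.sample.__wrapped__` rely on.

```python
import functools

import pytest


def target():
    """Stand-in for a function that setup will monkey-patch (e.g. pm.sample)."""
    return "original result"


def enable_patching():
    """Mimic autolog(): replace `target` with a wrapper that adds a side effect.

    functools.wraps stores the original callable on the wrapper as
    __wrapped__, which is what makes the teardown below possible.
    """
    global target
    original = target

    @functools.wraps(original)
    def wrapper(*args, **kwargs):
        print("logging around the call")  # side effect added by the wrapper
        return original(*args, **kwargs)

    target = wrapper


@pytest.fixture(scope="module", autouse=True)
def setup_module():
    # Setup: runs once before any test in this module.
    enable_patching()

    yield  # every test in the module executes here

    # Teardown: runs once after the last test, restoring the unwrapped
    # original so the patching does not leak into other test modules.
    global target
    target = target.__wrapped__


def test_target_is_patched(capsys):
    assert target() == "original result"
    assert "logging around the call" in capsys.readouterr().out
```

The autouse + module scope combination is what removes the need for the module-level `mlflow.set_tracking_uri` and `autolog()` calls in the diff above: every test in the file still sees the patched behaviour, while other test modules get the original `pm.sample` and `MMM.fit` back once this module finishes.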
