Skip to content

Commit

Permalink
fix: test cases mlflow registry (#415)
Browse files Browse the repository at this point in the history
There existed errors in test logs while test cases still passed. One
example is the ConnectionError due to not mocking mlflow function calls
properly.

Fixed through modifying a few syntax for patch mock calls. The test
cases should pass with no logging errors now.

---------

Signed-off-by: Leila <[email protected]>
  • Loading branch information
yleilawang authored Sep 19, 2024
1 parent ed81ef0 commit b18b2d2
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 26 deletions.
1 change: 1 addition & 0 deletions numalogic/registry/mlflow_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import mlflow.pyfunc
import mlflow.pytorch
import mlflow.sklearn
from mlflow.entities.model_registry import ModelVersion
from mlflow.exceptions import RestException
from mlflow.protos.databricks_pb2 import ErrorCode, RESOURCE_DOES_NOT_EXIST
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "numalogic"
version = "0.13.1"
version = "0.13.2"
description = "Collection of operational Machine Learning models and tools."
authors = ["Numalogic Developers"]
packages = [{ include = "numalogic" }]
Expand Down
52 changes: 27 additions & 25 deletions tests/registry/test_mlflow_registry.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import OrderedDict
import unittest
from contextlib import contextmanager
from unittest.mock import patch, Mock
Expand All @@ -11,8 +12,9 @@

from numalogic.models.autoencoder.variants import VanillaAE
from numalogic.registry import MLflowRegistry, ArtifactData, LocalLRUCache


from numalogic.registry.mlflow_registry import ModelStage
from numalogic.tools.exceptions import ModelVersionError
from tests.registry._mlflow_utils import (
model_sklearn,
create_model,
Expand Down Expand Up @@ -71,26 +73,29 @@ def test_save_model(self):
self.assertEqual(mock_status, status.status)

@patch("mlflow.sklearn.log_model", mock_log_model_sklearn)
@patch("mlflow.log_param", mock_log_state_dict)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_sklearn_rundata())))
@patch("mlflow.active_run", Mock(return_value=return_sklearn_rundata()))
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version)
def test_save_model_sklearn(self):
model = self.model_sklearn
ml = MLflowRegistry(TRACKING_URI)
skeys = self.skeys
dkeys = self.dkeys
status = ml.save(skeys=skeys, dkeys=dkeys, artifact=model, artifact_type="sklearn")

mock_status = "READY"
self.assertEqual(mock_status, status.status)

@patch("mlflow.pytorch.log_model", mock_log_model_pytorch())
@patch("mlflow.pytorch.log_model", mock_log_model_pytorch)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@patch("mlflow.log_params", {"lr": 0.01})
@patch("mlflow.log_params", Mock(return_value=OrderedDict([("learning_rate", 0.01)])))
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10)))
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pytorch_rundata_dict()))
def test_load_model_when_pytorch_model_exist1(self):
Expand All @@ -103,11 +108,12 @@ def test_load_model_when_pytorch_model_exist1(self):
self.assertIsNotNone(data.metadata)
self.assertIsInstance(data.artifact, VanillaAE)

@patch("mlflow.pytorch.log_model", mock_log_model_pytorch())
@patch("mlflow.pytorch.log_model", mock_log_model_pytorch)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10)))
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_empty_rundata()))
def test_load_model_when_pytorch_model_exist2(self):
Expand Down Expand Up @@ -147,12 +153,13 @@ def test_load_model_when_sklearn_model_exist(self):
self.assertIsInstance(data.artifact, StandardScaler)
self.assertEqual(data.metadata, {})

@patch("mlflow.pytorch.log_model", mock_log_model_pytorch())
@patch("mlflow.pytorch.log_model", mock_log_model_pytorch)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_empty_rundata())))
@patch("mlflow.active_run", Mock(return_value=return_empty_rundata()))
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_model_version", mock_get_model_version_obj)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10)))
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_empty_rundata()))
def test_load_model_with_version(self):
Expand All @@ -177,12 +184,16 @@ def test_staging_model_load_error(self):
ml = MLflowRegistry(TRACKING_URI, model_stage=ModelStage.STAGE)
skeys = self.skeys
dkeys = self.dkeys
ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch")
self.assertRaises(ModelVersionError)
with self.assertLogs(level="ERROR") as log:
result = ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch")
self.assertIsNone(result) # Ensure the result is None
self.assertTrue(
any("No Model found" in message for message in log.output)
) # Check that the expected log was made

@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version())
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10)))
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_empty_rundata()))
def test_both_version_latest_model_with_version(self):
Expand Down Expand Up @@ -254,7 +265,7 @@ def test_no_implementation(self):
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.tracking.MlflowClient.delete_model_version", None)
@patch("mlflow.tracking.MlflowClient.delete_model_version", Mock(return_value=None))
@patch("mlflow.pytorch.load_model", Mock(side_effect=RuntimeError))
def test_delete_model_when_model_exist(self):
model = self.model
Expand Down Expand Up @@ -321,12 +332,13 @@ def test_load_other_mlflow_err(self):
dkeys = self.dkeys
self.assertIsNone(ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch"))

@patch("mlflow.pytorch.log_model", mock_log_model_pytorch())
@patch("mlflow.pytorch.log_model", mock_log_model_pytorch)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@patch("mlflow.log_params", {"lr": 0.01})
@patch("mlflow.log_params", Mock(return_value=OrderedDict([("learning_rate", 0.01)])))
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10)))
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pytorch_rundata_dict()))
def test_is_model_stale_true(self):
Expand All @@ -342,12 +354,13 @@ def test_is_model_stale_true(self):
data = ml.load(skeys=self.skeys, dkeys=self.dkeys, artifact_type="pytorch")
self.assertTrue(ml.is_artifact_stale(data, 12))

@patch("mlflow.pytorch.log_model", mock_log_model_pytorch())
@patch("mlflow.pytorch.log_model", mock_log_model_pytorch)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@patch("mlflow.log_params", {"lr": 0.01})
@patch("mlflow.log_params", Mock(return_value=OrderedDict([("learning_rate", 0.01)])))
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10)))
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pytorch_rundata_dict()))
def test_is_model_stale_false(self):
Expand Down Expand Up @@ -381,29 +394,18 @@ def test_cache(self):
self.assertIsNotNone(registry._load_from_cache("key"))
self.assertIsNotNone(registry._clear_cache("key"))

@patch("mlflow.pytorch.log_model", mock_log_model_pytorch())
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@patch("mlflow.log_params", {"lr": 0.01})
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10)))
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pytorch_rundata_dict()))
def test_cache_loading(self):
cache_registry = LocalLRUCache(ttl=50000)
ml = MLflowRegistry(TRACKING_URI, cache_registry=cache_registry)
ml.save(
skeys=self.skeys,
dkeys=self.dkeys,
artifact=self.model,
**{"lr": 0.01},
artifact_type="pytorch",
)
ml.load(skeys=self.skeys, dkeys=self.dkeys, artifact_type="pytorch")
key = MLflowRegistry.construct_key(self.skeys, self.dkeys)
self.assertIsNotNone(ml._load_from_cache(key))
data = ml.load(skeys=self.skeys, dkeys=self.dkeys, artifact_type="pytorch")
self.assertIsNotNone(data)


if __name__ == "__main__":
Expand Down

0 comments on commit b18b2d2

Please sign in to comment.