diff --git a/numalogic/registry/mlflow_registry.py b/numalogic/registry/mlflow_registry.py
index 65c6812c..1eb07939 100644
--- a/numalogic/registry/mlflow_registry.py
+++ b/numalogic/registry/mlflow_registry.py
@@ -207,11 +207,11 @@ def load_multiple(
         loaded_model = self.load(skeys=skeys, dkeys=dkeys, artifact_type="pyfunc")
         if loaded_model is None:
             return None
-        if loaded_model.artifact.loader_module != "mlflow.pyfunc.model":
-            raise TypeError("The loaded model is not a valid pyfunc Python model.")
 
         try:
             unwrapped_composite_model = loaded_model.artifact.unwrap_python_model()
+        except mlflow.exceptions.MlflowException as e:
+            raise TypeError("The loaded model is not a valid pyfunc Python model.") from e
         except AttributeError:
             _LOGGER.exception("The loaded model does not have an unwrap_python_model method")
             return None
@@ -448,4 +448,10 @@ def __init__(self, skeys: KEYS, dict_artifacts: dict[str, artifact_t], **metadat
         self.metadata = metadata
 
     def predict(self):
-        raise NotImplementedError()
+        """
+        Predict method is not implemented for our use case.
+
+        The CompositeModel class is designed to store and load multiple artifacts,
+        and the predict method is not required for this functionality.
+        """
+        raise NotImplementedError("The predict method is not implemented for CompositeModel.")
diff --git a/tests/registry/_mlflow_utils.py b/tests/registry/_mlflow_utils.py
index ca82f246..3afbc96a 100644
--- a/tests/registry/_mlflow_utils.py
+++ b/tests/registry/_mlflow_utils.py
@@ -204,6 +204,37 @@ def mock_load_model_pyfunc(*_, **__):
     )
 
 
+def mock_load_model_pyfunc_type_error(*_, **__):
+    artifact_path = "model"
+    flavors = {
+        "python_function": {
+            "cloudpickle_version": "3.0.0",
+            "code": None,
+            "env": {"conda": "conda.yaml", "virtualenv": "python_env.yaml"},
+            "loader_module": "mlflow.pytorch.model",
+            "python_model": "python_model.pkl",
+            "python_version": "3.10.14",
+            "streamable": False,
+        }
+    }
+    model_size_bytes = 8912
+    model_uuid = "ae27ecc166c94c01a4f4dccaf84ca5dc"
+    run_id = "7e85a3fa46d44e668c840f3dddc909c3"
+    utc_time_created = "2024-09-18 17:12:41.501209"
+    model = Model(
+        artifact_path=artifact_path,
+        flavors=flavors,
+        model_size_bytes=model_size_bytes,
+        model_uuid=model_uuid,
+        run_id=run_id,
+        utc_time_created=utc_time_created,
+        mlflow_version="2.16.0",
+    )
+    return mlflow.pyfunc.PyFuncModel(
+        model_meta=model, model_impl=mlflow.pytorch._PyTorchWrapper(VanillaAE(10), device="cpu")
+    )
+
+
 def mock_transition_stage(*_, **__):
     return ModelVersion(
         creation_timestamp=1653402941169,
diff --git a/tests/registry/test_mlflow_registry.py b/tests/registry/test_mlflow_registry.py
index e0fa686d..13345853 100644
--- a/tests/registry/test_mlflow_registry.py
+++ b/tests/registry/test_mlflow_registry.py
@@ -15,9 +15,10 @@
 
 
 from numalogic.registry import MLflowRegistry, ArtifactData, LocalLRUCache
-from numalogic.registry.mlflow_registry import ModelStage
+from numalogic.registry.mlflow_registry import CompositeModel, ModelStage
 from tests.registry._mlflow_utils import (
     mock_load_model_pyfunc,
+    mock_load_model_pyfunc_type_error,
     mock_log_model_pyfunc,
     model_sklearn,
     create_model,
@@ -104,6 +105,25 @@ def test_save_multiple_models_pyfunc(self):
         mock_status = "READY"
         self.assertEqual(mock_status, status.status)
 
+    @patch("mlflow.pyfunc.log_model", mock_log_model_pyfunc)
+    @patch("mlflow.log_params", mock_log_state_dict)
+    @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pyfunc_rundata())))
+    @patch("mlflow.active_run", Mock(return_value=return_pyfunc_rundata()))
+    @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
+    @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
+    @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
+    def test_save_multiple_models_when_only_one_model(self):
+        ml = MLflowRegistry(TRACKING_URI)
+        with self.assertLogs(level="WARNING"):
+            ml.save_multiple(
+                skeys=self.skeys,
+                dict_artifacts={
+                    "inference": VanillaAE(10),
+                },
+                dkeys=["unique", "sorted"],
+                **{"learning_rate": 0.01},
+            )
+
     @patch("mlflow.pyfunc.log_model", mock_log_model_pyfunc)
     @patch("mlflow.log_params", mock_log_state_dict)
     @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pyfunc_rundata())))
@@ -134,6 +154,77 @@ def test_load_multiple_models_when_pyfunc_model_exist(self):
         self.assertIsInstance(data.artifact["inference"], VanillaAE)
         self.assertIsInstance(data.artifact["precrocessing"], StandardScaler)
 
+    @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
+    @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
+    @patch(
+        "mlflow.tracking.MlflowClient.get_latest_versions",
+        Mock(return_value=PagedList(items=[], token=None)),
+    )
+    @patch("mlflow.pyfunc.load_model", Mock(return_value=None))
+    @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_empty_rundata()))
+    def test_load_model_when_no_model_pyfunc(self):
+        fake_skeys = ["Fakemodel_"]
+        fake_dkeys = ["error"]
+        ml = MLflowRegistry(TRACKING_URI)
+        with self.assertLogs(level="ERROR") as log:
+            o = ml.load_multiple(skeys=fake_skeys, dkeys=fake_dkeys)
+            self.assertIsNone(o)
+            self.assertTrue(log.output)
+
+    @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
+    @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
+    @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
+    @patch("mlflow.tracking.MlflowClient.get_model_version", mock_get_model_version_obj)
+    @patch(
+        "mlflow.pyfunc.load_model",
+        Mock(
+            return_value=CompositeModel(
+                skeys=["error"],
+                dict_artifacts={
+                    "inference": VanillaAE(10),
+                    "precrocessing": StandardScaler(),
+                    "threshold": StdDevThreshold(),
+                },
+                **{"learning_rate": 0.01},
+            )
+        ),
+    )
+    @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pyfunc_rundata()))
+    def test_load_multiple_attribute_error(self):
+        ml = MLflowRegistry(TRACKING_URI)
+        skeys = self.skeys
+        dkeys = ["unique", "sorted"]
+        with self.assertLogs(level="ERROR") as log:
+            result = ml.load_multiple(skeys=skeys, dkeys=dkeys)
+        self.assertIsNone(result)
+        self.assertTrue(
+            any(
+                "The loaded model does not have an unwrap_python_model method" in message
+                for message in log.output
+            )
+        )
+
+    @patch("mlflow.pytorch.log_model", mock_log_model_pytorch)
+    @patch("mlflow.log_params", mock_log_state_dict)
+    @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
+    @patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
+    @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
+    @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
+    @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
+    @patch("mlflow.pyfunc.load_model", mock_load_model_pyfunc_type_error)
+    @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pyfunc_rundata()))
+    def test_load_multiple_type_error(self):
+        ml = MLflowRegistry(TRACKING_URI)
+        ml.save(
+            skeys=self.skeys,
+            dkeys=self.dkeys,
+            artifact=self.model,
+            artifact_type="pytorch",
+            **{"lr": 0.01},
+        )
+        with self.assertRaises(TypeError):
+            ml.load_multiple(skeys=self.skeys, dkeys=self.dkeys)
+
     @patch("mlflow.sklearn.log_model", mock_log_model_sklearn)
     @patch("mlflow.log_param", mock_log_state_dict)
     @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_sklearn_rundata())))