
Commit

fix: code patch test cases
Signed-off-by: Leila Wang <[email protected]>
yleilawang committed Sep 26, 2024
1 parent 69caf21 commit 049d686
Showing 3 changed files with 133 additions and 4 deletions.
12 changes: 9 additions & 3 deletions numalogic/registry/mlflow_registry.py
@@ -207,11 +207,11 @@ def load_multiple(
        loaded_model = self.load(skeys=skeys, dkeys=dkeys, artifact_type="pyfunc")
        if loaded_model is None:
            return None
        if loaded_model.artifact.loader_module != "mlflow.pyfunc.model":
            raise TypeError("The loaded model is not a valid pyfunc Python model.")

        try:
            unwrapped_composite_model = loaded_model.artifact.unwrap_python_model()
        except mlflow.exceptions.MlflowException as e:
            raise TypeError("The loaded model is not a valid pyfunc Python model.") from e
        except AttributeError:
            _LOGGER.exception("The loaded model does not have an unwrap_python_model method")
            return None
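
For context, here is a minimal caller-side sketch of how the new loader_module guard surfaces to users of load_multiple; the tracking URI and keys are placeholders, while the constructor and argument names mirror the tests below:

    from numalogic.registry import MLflowRegistry

    registry = MLflowRegistry("http://localhost:5000")  # placeholder tracking URI
    try:
        artifacts = registry.load_multiple(skeys=["service"], dkeys=["unique", "sorted"])
    except TypeError:
        # Raised when the registered model is not a pyfunc composite model,
        # e.g. its loader_module is "mlflow.pytorch.model" rather than "mlflow.pyfunc.model".
        artifacts = None
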
@@ -448,4 +448,10 @@ def __init__(self, skeys: KEYS, dict_artifacts: dict[str, artifact_t], **metadata):
        self.metadata = metadata

    def predict(self):
        raise NotImplementedError()
        """
        Predict method is not implemented for our use case.
        The CompositeModel class is designed to store and load multiple artifacts,
        and the predict method is not required for this functionality.
        """
        raise NotImplementedError("The predict method is not implemented for CompositeModel.")
31 changes: 31 additions & 0 deletions tests/registry/_mlflow_utils.py
@@ -204,6 +204,37 @@ def mock_load_model_pyfunc(*_, **__):
    )


def mock_load_model_pyfunc_type_error(*_, **__):
    artifact_path = "model"
    flavors = {
        "python_function": {
            "cloudpickle_version": "3.0.0",
            "code": None,
            "env": {"conda": "conda.yaml", "virtualenv": "python_env.yaml"},
            # A non-pyfunc loader module, so load_multiple hits the TypeError path.
            "loader_module": "mlflow.pytorch.model",
            "python_model": "python_model.pkl",
            "python_version": "3.10.14",
            "streamable": False,
        }
    }
    model_size_bytes = 8912
    model_uuid = "ae27ecc166c94c01a4f4dccaf84ca5dc"
    run_id = "7e85a3fa46d44e668c840f3dddc909c3"
    utc_time_created = "2024-09-18 17:12:41.501209"
    model = Model(
        artifact_path=artifact_path,
        flavors=flavors,
        model_size_bytes=model_size_bytes,
        model_uuid=model_uuid,
        run_id=run_id,
        utc_time_created=utc_time_created,
        mlflow_version="2.16.0",
    )
    return mlflow.pyfunc.PyFuncModel(
        model_meta=model, model_impl=mlflow.pytorch._PyTorchWrapper(VanillaAE(10), device="cpu")
    )


def mock_transition_stage(*_, **__):
    return ModelVersion(
        creation_timestamp=1653402941169,
94 changes: 93 additions & 1 deletion tests/registry/test_mlflow_registry.py
@@ -15,9 +15,10 @@
from numalogic.registry import MLflowRegistry, ArtifactData, LocalLRUCache


from numalogic.registry.mlflow_registry import ModelStage
from numalogic.registry.mlflow_registry import CompositeModel, ModelStage
from tests.registry._mlflow_utils import (
    mock_load_model_pyfunc,
    mock_load_model_pyfunc_type_error,
    mock_log_model_pyfunc,
    model_sklearn,
    create_model,
@@ -104,6 +105,25 @@ def test_save_multiple_models_pyfunc(self):
mock_status = "READY"
self.assertEqual(mock_status, status.status)

@patch("mlflow.pyfunc.log_model", mock_log_model_pyfunc)
@patch("mlflow.log_params", mock_log_state_dict)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pyfunc_rundata())))
@patch("mlflow.active_run", Mock(return_value=return_pyfunc_rundata()))
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
def test_save_multiple_models_when_only_one_model(self):
ml = MLflowRegistry(TRACKING_URI)
with self.assertLogs(level="WARNING"):
ml.save_multiple(
skeys=self.skeys,
dict_artifacts={
"inference": VanillaAE(10),
},
dkeys=["unique", "sorted"],
**{"learning_rate": 0.01},
)

@patch("mlflow.pyfunc.log_model", mock_log_model_pyfunc)
@patch("mlflow.log_params", mock_log_state_dict)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pyfunc_rundata())))
@@ -134,6 +154,78 @@ def test_load_multiple_models_when_pyfunc_model_exist(self):
self.assertIsInstance(data.artifact["inference"], VanillaAE)
self.assertIsInstance(data.artifact["precrocessing"], StandardScaler)

@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch(
"mlflow.tracking.MlflowClient.get_latest_versions",
Mock(return_value=PagedList(items=[], token=None)),
)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.pyfunc.load_model", Mock(return_value=None))
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_empty_rundata()))
def test_load_model_when_no_model_pyfunc(self):
fake_skeys = ["Fakemodel_"]
fake_dkeys = ["error"]
ml = MLflowRegistry(TRACKING_URI)
with self.assertLogs(level="ERROR") as log:
o = ml.load_multiple(skeys=fake_skeys, dkeys=fake_dkeys)
self.assertIsNone(o)
self.assertTrue(log.output)

@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.get_model_version", mock_get_model_version_obj)
@patch(
"mlflow.pyfunc.load_model",
Mock(
return_value=CompositeModel(
skeys=["error"],
dict_artifacts={
"inference": VanillaAE(10),
"precrocessing": StandardScaler(),
"threshold": StdDevThreshold(),
},
**{"learning_rate": 0.01},
)
),
)
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pyfunc_rundata()))
def test_load_multiple_attribute_error(self):
ml = MLflowRegistry(TRACKING_URI)
skeys = self.skeys
dkeys = ["unique", "sorted"]
with self.assertLogs(level="ERROR") as log:
result = ml.load_multiple(skeys=skeys, dkeys=dkeys)
self.assertIsNone(result)
self.assertTrue(
any(
"The loaded model does not have an unwrap_python_model method" in message
for message in log.output
)
)

@patch("mlflow.pytorch.log_model", mock_log_model_pytorch)
@patch("mlflow.log_params", mock_log_state_dict)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.pyfunc.load_model", mock_load_model_pyfunc_type_error)
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pyfunc_rundata()))
def test_load_multiple_type_error(self):
ml = MLflowRegistry(TRACKING_URI)
ml.save(
skeys=self.skeys,
dkeys=self.dkeys,
artifact=self.model,
artifact_type="pytorch",
**{"lr": 0.01},
)
with self.assertRaises(TypeError):
ml.load_multiple(skeys=self.skeys, dkeys=self.dkeys)

@patch("mlflow.sklearn.log_model", mock_log_model_sklearn)
@patch("mlflow.log_param", mock_log_state_dict)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_sklearn_rundata())))
