diff --git a/.gitignore b/.gitignore
index 3db98ce0..c50fd26c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -78,6 +78,8 @@ target/
 #mlflow
 /.mlruns
 *.db
+mlruns/
+mlartifacts/
 
 # Jupyter Notebook
 .ipynb_checkpoints
@@ -169,4 +171,7 @@ cython_debug/
 # Mac related
 *.DS_Store
 
+# vscode
+.vscode/
+
 .python-version
diff --git a/numalogic/registry/mlflow_registry.py b/numalogic/registry/mlflow_registry.py
index 43c5b519..62948928 100644
--- a/numalogic/registry/mlflow_registry.py
+++ b/numalogic/registry/mlflow_registry.py
@@ -18,15 +18,17 @@
 import mlflow.pyfunc
 import mlflow.pytorch
 import mlflow.sklearn
+import mlflow
 from mlflow.entities.model_registry import ModelVersion
 from mlflow.exceptions import RestException
 from mlflow.protos.databricks_pb2 import ErrorCode, RESOURCE_DOES_NOT_EXIST
 from mlflow.tracking import MlflowClient
+from sortedcontainers import SortedSet
 
 from numalogic.registry import ArtifactManager, ArtifactData
 from numalogic.registry.artifact import ArtifactCache
 from numalogic.tools.exceptions import ModelVersionError
-from numalogic.tools.types import artifact_t, KEYS, META_VT
+from numalogic.tools.types import KeyedArtifact, artifact_t, KEYS, META_VT
 
 _LOGGER = logging.getLogger(__name__)
 
@@ -187,6 +189,39 @@ def load(
         self._save_in_cache(model_key, artifact_data)
         return artifact_data
 
+    def load_multiple(
+        self,
+        skeys: KEYS,
+        dkeys_list: list[list[str]],
+    ) -> Optional[dict[str, ArtifactData]]:
+        """
+        Load multiple artifacts from the registry for pyfunc models.
+
+        Args:
+            skeys (KEYS): The static keys of the artifacts to load.
+            dkeys_list (list[list[str]]):
+                A list of lists containing the dkeys of the artifacts to load.
+
+        Returns
+        -------
+        Optional[dict[str, ArtifactData]]: A dictionary mapping joined dynamic keys
+            to the loaded artifacts, or None if no artifacts were found.
+        """
+        dkeys = self.__get_sorted_unique_dkeys(dkeys_list)
+        loaded_model = self.load(skeys=skeys, dkeys=dkeys, artifact_type="pyfunc")
+        if loaded_model is not None:
+            metadata = loaded_model.artifact.unwrap_python_model().metadata
+            dict_artifacts = loaded_model.artifact.unwrap_python_model().dict_artifacts
+            artifacts_dict = {}
+            for artifact in dict_artifacts.values():
+                artifact_data = ArtifactData(
+                    artifact=artifact.artifact, metadata=metadata, extras=None
+                )
+                dynamic_key = ":".join(artifact.dkeys)
+                artifacts_dict[dynamic_key] = artifact_data
+        else:
+            artifacts_dict = None
+        return artifacts_dict
+
     @staticmethod
     def __log_mlflow_err(mlflow_err: RestException, model_key: str) -> None:
         if ErrorCode.Value(mlflow_err.error_code) == RESOURCE_DOES_NOT_EXIST:
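Note: `load_multiple` keys its result dict by joining each artifact's `dkeys` with `":"`, so `["AE", "infer"]` maps to `"AE:infer"`. A minimal standalone sketch of just that mapping (placeholder strings stand in for real artifacts):

```python
# How load_multiple builds its result keys: join each artifact's dkeys with ":".
dkeys_list = [["AE", "infer"], ["scaler", "infer"]]
result_keys = {":".join(dkeys): "<artifact>" for dkeys in dkeys_list}
print(result_keys)  # {'AE:infer': '<artifact>', 'scaler:infer': '<artifact>'}
```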
@@ -225,7 +260,10 @@ def save(
         handler = self.handler_from_type(artifact_type)
         try:
             mlflow.start_run(run_id=run_id)
-            handler.log_model(artifact, "model", registered_model_name=model_key)
+            if artifact_type == "pyfunc":
+                handler.log_model("model", python_model=artifact, registered_model_name=model_key)
+            else:
+                handler.log_model(artifact, "model", registered_model_name=model_key)
             if metadata:
                 mlflow.log_params(metadata)
             model_version = self.transition_stage(skeys=skeys, dkeys=dkeys)
@@ -238,6 +276,37 @@ def save(
         finally:
             mlflow.end_run()
 
+    def save_multiple(
+        self,
+        skeys: KEYS,
+        dict_artifacts: dict[str, KeyedArtifact],
+        **metadata: META_VT,
+    ) -> Optional[ModelVersion]:
+        """
+        Saves multiple artifacts into the MLflow registry as a single composite
+        pyfunc model. The last save stores all the artifact versions in the metadata.
+
+        Args:
+        ----
+            skeys: static key fields as list/tuple of strings
+            dict_artifacts: dict of artifacts to save
+            metadata: additional metadata surrounding the artifact that needs to be saved.
+
+        Returns
+        -------
+        mlflow ModelVersion instance
+        """
+        multiple_artifacts = CompositeModels(skeys=skeys, dict_artifacts=dict_artifacts, **metadata)
+        dkeys_list = multiple_artifacts.get_dkeys_list()
+        sorted_dkeys = self.__get_sorted_unique_dkeys(dkeys_list)
+        return self.save(
+            skeys=multiple_artifacts.skeys,
+            dkeys=sorted_dkeys,
+            artifact=multiple_artifacts,
+            artifact_type="pyfunc",
+            metadata=multiple_artifacts.metadata,
+        )
+
     @staticmethod
     def is_artifact_stale(artifact_data: ArtifactData, freq_hr: int) -> bool:
         """Returns whether the given artifact is stale or not, i.e. if
@@ -338,3 +407,63 @@ def __load_artifacts(
             version_info.version,
         )
         return model, metadata
+
+    def __get_sorted_unique_dkeys(self, dkeys_list: list[list[str]]) -> list[str]:
+        """
+        Returns a unique sorted list of all dkeys in the stored artifacts.
+
+        Args:
+        ----
+            dkeys_list: A list of lists containing the destination keys of the artifacts.
+
+        Returns
+        -------
+        list[str]
+            A list of all unique dkeys in the stored artifacts, sorted in ascending order.
+        """
+        return list(SortedSet([dkey for dkeys in dkeys_list for dkey in dkeys]))
+
+
+class CompositeModels(mlflow.pyfunc.PythonModel):
+    """A composite model that represents multiple artifacts.
+
+    This class extends the `mlflow.pyfunc.PythonModel` class and is used to store and load
+    multiple artifacts in the MLflow registry. It provides a convenient way to manage and
+    organize multiple artifacts associated with a single model.
+
+    Args:
+        skeys (KEYS): The static keys of the artifacts.
+        dict_artifacts (dict[str, KeyedArtifact]): A dictionary mapping dynamic keys to
+            `KeyedArtifact` objects.
+        **metadata (META_VT): Additional metadata associated with the artifacts.
+
+    Methods
+    -------
+    get_dkeys_list(): Returns a list of all dynamic keys in the stored artifacts.
+
+    Attributes
+    ----------
+    skeys (KEYS): The static keys of the artifacts.
+    dict_artifacts (dict[str, KeyedArtifact]): A dictionary mapping dynamic keys to
+        `KeyedArtifact` objects.
+    metadata (META_VT): Additional metadata associated with the artifacts.
+    """
+
+    def __init__(self, skeys: KEYS, dict_artifacts: dict[str, KeyedArtifact], **metadata: META_VT):
+        self.skeys = skeys
+        self.dict_artifacts = dict_artifacts
+        self.metadata = metadata
+
+    def get_dkeys_list(self):
+        """
+        Returns a list of all dynamic keys in the stored artifacts.
+
+        Returns
+        -------
+        list[list[str]]: A list of all dynamic keys in the stored artifacts.
+        """
+        dkeys_list = []
+        artifacts = self.dict_artifacts.values()
+        for artifact in artifacts:
+            dkeys_list.append(artifact.dkeys)
+        return dkeys_list
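Note: `__get_sorted_unique_dkeys` above flattens the per-artifact dkey lists, de-duplicates them, and returns them in ascending order. Both `save_multiple` and `load_multiple` derive the registry dkeys this way, which is what keeps the save and load sides addressing the same model key. A runnable sketch of just that expression:

```python
from sortedcontainers import SortedSet

# Flatten, de-duplicate and sort the dkeys, as __get_sorted_unique_dkeys does.
dkeys_list = [["scaler", "infer"], ["AE", "infer"]]
print(list(SortedSet(dkey for dkeys in dkeys_list for dkey in dkeys)))
# ['AE', 'infer', 'scaler']
```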
+ """ + dkeys_list = [] + artifacts = self.dict_artifacts.values() + for artifact in artifacts: + dkeys_list.append(artifact.dkeys) + return dkeys_list diff --git a/tests/registry/_mlflow_utils.py b/tests/registry/_mlflow_utils.py index 4b61eddc..1120c3b7 100644 --- a/tests/registry/_mlflow_utils.py +++ b/tests/registry/_mlflow_utils.py @@ -9,8 +9,12 @@ from mlflow.store.entities import PagedList from sklearn.preprocessing import StandardScaler from torch import tensor +from mlflow.models import Model +from numalogic.models.autoencoder.variants.vanilla import VanillaAE from numalogic.models.threshold import StdDevThreshold +from numalogic.registry.mlflow_registry import CompositeModels +from numalogic.tools.types import KeyedArtifact def create_model(): @@ -135,6 +139,71 @@ def mock_log_model_sklearn(*_, **__): ) +def mock_log_model_pyfunc(*_, **__): + return ModelInfo( + artifact_path="model", + flavors={ + "pyfunc": {"model_data": "data", "pyfunc_version": "1.11.0", "code": None}, + "python_function": { + "pickle_module_name": "mlflow.pyfunc.pickle_module", + "loader_module": "mlflow.pyfunc", + "python_version": "3.8.5", + "data": "data", + "env": "conda.yaml", + }, + }, + model_uri="runs:/a7c0b376530b40d7b23e6ce2081c899c/model", + model_uuid="a7c0b376530b40d7b23e6ce2081c899c", + run_id="a7c0b376530b40d7b23e6ce2081c899c", + saved_input_example_info=None, + signature_dict=None, + utc_time_created="2022-05-23 22:35:59.557372", + mlflow_version="2.0.1", + signature=None, + ) + + +def mock_load_model_pyfunc(*_, **__): + artifact_path = "model" + flavors = { + "python_function": { + "cloudpickle_version": "3.0.0", + "code": None, + "env": {"conda": "conda.yaml", "virtualenv": "python_env.yaml"}, + "loader_module": "mlflow.pyfunc.model", + "python_model": "python_model.pkl", + "python_version": "3.10.14", + "streamable": False, + } + } + model_size_bytes = 8912 + model_uuid = "ae27ecc166c94c01a4f4dccaf84ca5dc" + run_id = "7e85a3fa46d44e668c840f3dddc909c3" + utc_time_created = "2024-09-18 17:12:41.501209" + model = Model( + artifact_path=artifact_path, + flavors=flavors, + model_size_bytes=model_size_bytes, + model_uuid=model_uuid, + run_id=run_id, + utc_time_created=utc_time_created, + mlflow_version="2.16.0", + ) + return mlflow.pyfunc.PyFuncModel( + model_meta=model, + model_impl=TestObject( + python_model=CompositeModels( + skeys=["model"], + dict_artifacts={ + "AE": KeyedArtifact(dkeys=["AE", "infer"], artifact=VanillaAE(10)), + "scaler": KeyedArtifact(dkeys=["scaler", "infer"], artifact=StandardScaler()), + }, + **{"learning_rate": 0.01}, + ) + ), + ) + + def mock_transition_stage(*_, **__): return ModelVersion( creation_timestamp=1653402941169, @@ -303,6 +372,23 @@ def return_sklearn_rundata(): ) +def return_pyfunc_rundata(): + return Run( + run_info=RunInfo( + artifact_uri="mlflow-artifacts:/0/a7c0b376530b40d7b23e6ce2081c899c/artifacts/model", + end_time=None, + experiment_id="0", + lifecycle_stage="active", + run_id="a7c0b376530b40d7b23e6ce2081c899c", + run_uuid="a7c0b376530b40d7b23e6ce2081c899c", + start_time=1658788772612, + status="RUNNING", + user_id="lol", + ), + run_data=RunData(metrics={}, tags={}, params={}), + ) + + def return_pytorch_rundata_dict(): return Run( run_info=RunInfo( @@ -318,3 +404,8 @@ def return_pytorch_rundata_dict(): ), run_data=RunData(metrics={}, tags={}, params=[mlflow.entities.Param("lr", "0.001")]), ) + + +class TestObject(mlflow.pyfunc.PythonModel): + def __init__(self, python_model): + self.python_model = python_model diff --git 
diff --git a/tests/registry/test_mlflow_registry.py b/tests/registry/test_mlflow_registry.py
index 8482c6e8..de2cbd44 100644
--- a/tests/registry/test_mlflow_registry.py
+++ b/tests/registry/test_mlflow_registry.py
@@ -15,7 +15,10 @@
 from numalogic.registry.mlflow_registry import ModelStage
+from numalogic.tools.types import KeyedArtifact
 from tests.registry._mlflow_utils import (
+    mock_load_model_pyfunc,
+    mock_log_model_pyfunc,
     model_sklearn,
     create_model,
     mock_log_model_pytorch,
@@ -23,6 +26,7 @@
     mock_get_model_version,
     mock_transition_stage,
     mock_log_model_sklearn,
+    return_pyfunc_rundata,
     return_pytorch_rundata_dict,
     return_empty_rundata,
     mock_list_of_model_version,
@@ -56,22 +60,68 @@ def test_construct_key(self):
         self.assertEqual("model_:nnet::error1", key)
 
     @patch("mlflow.pytorch.log_model", mock_log_model_pytorch)
-    @patch("mlflow.log_param", mock_log_state_dict)
+    @patch("mlflow.log_params", mock_log_state_dict)
     @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
     @patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
     @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
     @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
-    @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version)
+    @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
     def test_save_model(self):
         ml = MLflowRegistry(TRACKING_URI)
         skeys = self.skeys
         dkeys = self.dkeys
         status = ml.save(
-            skeys=skeys, dkeys=dkeys, artifact=self.model, run_id="1234", artifact_type="pytorch"
+            skeys=skeys,
+            dkeys=dkeys,
+            artifact=self.model,
+            run_id="1234",
+            artifact_type="pytorch",
+            **{"lr": 0.01},
         )
         mock_status = "READY"
         self.assertEqual(mock_status, status.status)
 
+    @patch("mlflow.pyfunc.log_model", mock_log_model_pyfunc)
+    @patch("mlflow.log_params", mock_log_state_dict)
+    @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pyfunc_rundata())))
+    @patch("mlflow.active_run", Mock(return_value=return_pyfunc_rundata()))
+    @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
+    @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
+    @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
+    def test_save_multiple_models_pyfunc(self):
+        ml = MLflowRegistry(TRACKING_URI)
+        status = ml.save_multiple(
+            skeys=self.skeys,
+            dict_artifacts={
+                "AE": KeyedArtifact(dkeys=["AE", "infer"], artifact=VanillaAE(10)),
+                "scaler": KeyedArtifact(dkeys=["scaler", "infer"], artifact=StandardScaler()),
+            },
+            **{"learning_rate": 0.01},
+        )
+        self.assertIsNotNone(status)
+        mock_status = "READY"
+        self.assertEqual(mock_status, status.status)
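Note on the `log_param` → `log_params` patch changes above: `save()` now logs the whole metadata dict in one call via `mlflow.log_params`, so the tests patch the plural API. A minimal sketch of the difference (writes to a local `mlruns/` directory when run):

```python
import mlflow

# log_param records a single key/value; log_params takes a whole dict,
# which is how MLflowRegistry.save() logs its metadata in one call.
with mlflow.start_run():
    mlflow.log_param("lr", 0.01)
    mlflow.log_params({"learning_rate": 0.01, "epochs": 5})
```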
+    @patch("mlflow.pyfunc.log_model", mock_log_model_pyfunc)
+    @patch("mlflow.log_params", mock_log_state_dict)
+    @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pyfunc_rundata())))
+    @patch("mlflow.active_run", Mock(return_value=return_pyfunc_rundata()))
+    @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
+    @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
+    @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
+    @patch("mlflow.pyfunc.load_model", mock_load_model_pyfunc)
+    @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pyfunc_rundata()))
+    def test_load_multiple_models_when_pyfunc_model_exist(self):
+        ml = MLflowRegistry(TRACKING_URI)
+        skeys = self.skeys
+        dkeys_list = [["AE", "infer"], ["scaler", "infer"]]
+        data = ml.load_multiple(skeys=skeys, dkeys_list=dkeys_list)
+        self.assertIsNotNone(data["AE:infer"].metadata)
+        self.assertIsNotNone(data["scaler:infer"].metadata)
+        self.assertIsInstance(data, dict)
+        self.assertIsInstance(data["AE:infer"].artifact, VanillaAE)
+        self.assertIsInstance(data["scaler:infer"].artifact, StandardScaler)
+
     @patch("mlflow.sklearn.log_model", mock_log_model_sklearn)
     @patch("mlflow.log_param", mock_log_state_dict)
     @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_sklearn_rundata())))
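For reviewers, an end-to-end usage sketch of the new API, mirroring the tests above. This is illustrative only: it assumes an MLflow tracking server at the given URI and that `MLflowRegistry` is importable from `numalogic.registry`.

```python
from sklearn.preprocessing import StandardScaler

from numalogic.models.autoencoder.variants.vanilla import VanillaAE
from numalogic.registry import MLflowRegistry
from numalogic.tools.types import KeyedArtifact

# Hypothetical tracking URI; point this at a real MLflow server.
registry = MLflowRegistry("http://localhost:5000")

# Save an autoencoder and its scaler as one composite pyfunc model version.
registry.save_multiple(
    skeys=["mymodel"],
    dict_artifacts={
        "AE": KeyedArtifact(dkeys=["AE", "infer"], artifact=VanillaAE(10)),
        "scaler": KeyedArtifact(dkeys=["scaler", "infer"], artifact=StandardScaler()),
    },
    learning_rate=0.01,
)

# Load them back; result keys are the ":"-joined dkeys.
artifacts = registry.load_multiple(
    skeys=["mymodel"], dkeys_list=[["AE", "infer"], ["scaler", "infer"]]
)
ae = artifacts["AE:infer"].artifact          # VanillaAE
scaler = artifacts["scaler:infer"].artifact  # StandardScaler
```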