From 568aa350ab340280b032c7e8ef04f6092e1aee98 Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Tue, 19 Mar 2024 14:52:08 +0800
Subject: [PATCH] Fix use_auth_token with ORTModel (#1740)

fix use_auth_token
---
 optimum/exporters/tasks.py         | 13 +++++++++++--
 optimum/modeling_base.py           |  4 +++-
 tests/onnxruntime/test_modeling.py |  7 +++++--
 3 files changed, 19 insertions(+), 5 deletions(-)

diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index a92e190b99..8d8a7e82de 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -1378,6 +1378,7 @@ def get_model_files(
         model_name_or_path: Union[str, Path],
         subfolder: str = "",
         cache_dir: str = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE,
+        use_auth_token: Optional[str] = None,
     ):
         request_exception = None
         full_model_path = Path(model_name_or_path) / subfolder
@@ -1391,7 +1392,9 @@
             try:
                 if not isinstance(model_name_or_path, str):
                     model_name_or_path = str(model_name_or_path)
-                all_files = huggingface_hub.list_repo_files(model_name_or_path, repo_type="model")
+                all_files = huggingface_hub.list_repo_files(
+                    model_name_or_path, repo_type="model", token=use_auth_token
+                )
                 if subfolder != "":
                     all_files = [file[len(subfolder) + 1 :] for file in all_files if file.startswith(subfolder)]
             except RequestsConnectionError as e:  # Hub not accessible
@@ -1672,6 +1675,7 @@ def infer_library_from_model(
         revision: Optional[str] = None,
         cache_dir: str = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE,
         library_name: Optional[str] = None,
+        use_auth_token: Optional[str] = None,
     ):
         """
         Infers the library from the model repo.
@@ -1689,13 +1693,17 @@
                 Path to a directory in which a downloaded pretrained model weights have been cached if the standard cache should not be used.
             library_name (`Optional[str]`, *optional*):
                 The library name of the model. Can be any of "transformers", "timm", "diffusers", "sentence_transformers".
+            use_auth_token (`Optional[str]`, defaults to `None`):
+                The token to use as HTTP bearer authorization for remote files.
         Returns:
             `str`: The library name automatically detected from the model repo.
         """
         if library_name is not None:
             return library_name

-        all_files, _ = TasksManager.get_model_files(model_name_or_path, subfolder, cache_dir)
+        all_files, _ = TasksManager.get_model_files(
+            model_name_or_path, subfolder, cache_dir, use_auth_token=use_auth_token
+        )

         if "model_index.json" in all_files:
             library_name = "diffusers"
@@ -1710,6 +1718,7 @@
             "subfolder": subfolder,
             "revision": revision,
             "cache_dir": cache_dir,
+            "use_auth_token": use_auth_token,
         }
         config_dict, kwargs = PretrainedConfig.get_config_dict(model_name_or_path, **kwargs)
         model_config = PretrainedConfig.from_dict(config_dict, **kwargs)
diff --git a/optimum/modeling_base.py b/optimum/modeling_base.py
index 10a291882d..e7254276c2 100644
--- a/optimum/modeling_base.py
+++ b/optimum/modeling_base.py
@@ -346,7 +346,9 @@ def from_pretrained(
             )
             model_id, revision = model_id.split("@")

-        library_name = TasksManager.infer_library_from_model(model_id, subfolder, revision, cache_dir)
+        library_name = TasksManager.infer_library_from_model(
+            model_id, subfolder, revision, cache_dir, use_auth_token=use_auth_token
+        )

         if library_name == "timm":
             config = PretrainedConfig.from_pretrained(model_id, subfolder, revision)
diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py
index 4e384798f9..eba667cdb5 100644
--- a/tests/onnxruntime/test_modeling.py
+++ b/tests/onnxruntime/test_modeling.py
@@ -937,9 +937,12 @@ def test_stable_diffusion_model_on_rocm_ep_str(self):
         self.assertEqual(model.vae_encoder.session.get_providers()[0], "ROCMExecutionProvider")
         self.assertListEqual(model.providers, ["ROCMExecutionProvider", "CPUExecutionProvider"])

-    @require_hf_token
     def test_load_model_from_hub_private(self):
-        model = ORTModel.from_pretrained(self.ONNX_MODEL_ID, use_auth_token=os.environ.get("HF_AUTH_TOKEN", None))
+        subprocess.run("huggingface-cli logout", shell=True)
+        # Read token of fxmartyclone (dummy user).
+        token = "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"
+
+        model = ORTModelForCustomTasks.from_pretrained("fxmartyclone/tiny-onnx-private-2", use_auth_token=token)
         self.assertIsInstance(model.model, onnxruntime.InferenceSession)
         self.assertIsInstance(model.config, PretrainedConfig)