From 104bd645aefdf3646f24484992f799e3f0bcf916 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Mon, 15 Jan 2024 17:17:08 +0100
Subject: [PATCH] Skip pushing if the user does not have write access to the cache repo (#405)

* Skip pushing if the user does not have write access to the cache repo

* Fix tests

* Fix tests
---
 optimum/neuron/trainer_callback.py | 25 +++++++++++++++++++++++--
 tests/test_cache_utils.py          |  3 +--
 tests/utils.py                     |  7 +++++--
 3 files changed, 29 insertions(+), 6 deletions(-)

diff --git a/optimum/neuron/trainer_callback.py b/optimum/neuron/trainer_callback.py
index 17ab7025a..b2442efa1 100644
--- a/optimum/neuron/trainer_callback.py
+++ b/optimum/neuron/trainer_callback.py
@@ -37,9 +37,11 @@
     NeuronHash,
     create_or_append_to_neuron_parallel_compile_report,
     download_cached_model_from_hub,
+    get_hf_hub_cache_repos,
     get_neuron_cache_path,
     get_neuron_compiler_version_dir_name,
     get_neuron_parallel_compile_report,
+    has_write_access_to_repo,
     list_files_in_neuron_cache,
     path_after_folder,
     push_to_cache_on_hub,
@@ -94,6 +96,19 @@ def __init__(
         self.wait_for_everyone_on_fetch = is_torch_xla_available() and wait_for_everyone_on_fetch
         self.wait_for_everyone_on_push = is_torch_xla_available() and wait_for_everyone_on_push
 
+        cache_repo_ids = get_hf_hub_cache_repos()
+        if cache_repo_ids:
+            self.cache_repo_id = cache_repo_ids[0]
+            has_write_access = has_write_access_to_repo(self.cache_repo_id)
+            if self.push and not has_write_access:
+                logger.warning(
+                    f"Pushing to the remote cache repo {self.cache_repo_id} is disabled because you do not have write "
+                    "access to it."
+                )
+                self.push = False
+        else:
+            self.cache_repo_id = None
+
         # Real Neuron compile cache if it exists.
         if original_neuron_cache_path is None:
             self.neuron_cache_path = get_neuron_cache_path()
@@ -293,7 +308,9 @@ def synchronize_temporary_neuron_cache_state(self) -> List[Path]:
     def synchronize_temporary_neuron_cache(self):
         for neuron_hash, files in self.neuron_hash_to_files.items():
             for path in files:
-                push_to_cache_on_hub(neuron_hash, path, local_path_to_path_in_repo="default")
+                push_to_cache_on_hub(
+                    neuron_hash, path, cache_repo_id=self.cache_repo_id, local_path_to_path_in_repo="default"
+                )
                 if self.use_neuron_cache:
                     path_in_cache = self.full_path_to_path_in_temporary_cache(path)
                     target_file = self.neuron_cache_path / path_in_cache
@@ -364,7 +381,11 @@ def on_train_begin(self, args: "TrainingArguments", state: TrainerState, control
             for path in filenames:
                 try:
                     push_to_cache_on_hub(
-                        neuron_hash, path, local_path_to_path_in_repo="default", fail_when_could_not_push=True
+                        neuron_hash,
+                        path,
+                        cache_repo_id=self.cache_repo_id,
+                        local_path_to_path_in_repo="default",
+                        fail_when_could_not_push=True,
                     )
                 except HfHubHTTPError:
                     # It means that we could not push, so we do not remove this entry from the report.
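For reference, the guard this patch adds to the callback's `__init__` can be read as a standalone helper. The following is a minimal sketch, assuming `get_hf_hub_cache_repos` and `has_write_access_to_repo` resolve from `optimum.neuron.utils.cache_utils` (the patch imports them relatively); the wrapper function and its name are illustrative, not part of the patch.

```python
# Sketch of the write-access guard introduced above. Assumes the two helpers
# are importable from optimum.neuron.utils.cache_utils; the function name
# resolve_cache_repo_and_push is hypothetical.
import logging
from typing import Optional, Tuple

from optimum.neuron.utils.cache_utils import get_hf_hub_cache_repos, has_write_access_to_repo

logger = logging.getLogger(__name__)


def resolve_cache_repo_and_push(push: bool) -> Tuple[Optional[str], bool]:
    """Pick the first configured cache repo; disable pushing if we cannot write to it."""
    cache_repo_ids = get_hf_hub_cache_repos()
    if not cache_repo_ids:
        # No cache repo configured: there is nothing to push to.
        return None, push
    cache_repo_id = cache_repo_ids[0]
    if push and not has_write_access_to_repo(cache_repo_id):
        logger.warning(
            f"Pushing to the remote cache repo {cache_repo_id} is disabled because you do not "
            "have write access to it."
        )
        push = False
    return cache_repo_id, push
```

Downstream, every `push_to_cache_on_hub` call now receives `cache_repo_id` explicitly, so pushes consistently target the same repo the access check was run against.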
diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py
index 6d00cba9a..f92dba1d1 100644
--- a/tests/test_cache_utils.py
+++ b/tests/test_cache_utils.py
@@ -279,14 +279,13 @@ def remove_repo():
         HfFolder.save_token(orig_token)
 
     def test_has_write_access_to_repo(self):
-        orig_token = HfFolder.get_token()
         wrong_token = "random_string"
         HfFolder.save_token(wrong_token)
 
         self.assertFalse(has_write_access_to_repo(self.CUSTOM_CACHE_REPO))
         self.assertFalse(has_write_access_to_repo(self.CUSTOM_PRIVATE_CACHE_REPO))
 
-        HfFolder.save_token(orig_token)
+        HfFolder.save_token(self._staging_token)
 
         self.assertTrue(has_write_access_to_repo(self.CUSTOM_CACHE_REPO))
         self.assertTrue(has_write_access_to_repo(self.CUSTOM_PRIVATE_CACHE_REPO))
diff --git a/tests/utils.py b/tests/utils.py
index be069ddf1..2b6caf8e8 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -24,7 +24,7 @@
 
 import torch
 from datasets import Dataset, DatasetDict
-from huggingface_hub import CommitOperationDelete, HfApi, HfFolder, create_repo, delete_repo
+from huggingface_hub import CommitOperationDelete, HfApi, HfFolder, create_repo, delete_repo, logout
 from huggingface_hub.utils import RepositoryNotFoundError
 from transformers import PretrainedConfig, PreTrainedModel
 from transformers.testing_utils import ENDPOINT_STAGING
@@ -163,7 +163,10 @@ class StagingTestMixin:
     @classmethod
     def set_hf_hub_token(cls, token: str) -> str:
         orig_token = HfFolder.get_token()
-        HfFolder.save_token(token)
+        if token is not None:
+            HfFolder.save_token(token)
+        else:
+            logout()
 
         cls._env = dict(os.environ, HF_ENDPOINT=ENDPOINT_STAGING)
         return orig_token
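The tests/utils.py change handles the case where `set_hf_hub_token` is called with `None` to restore a "no token" state: instead of persisting `None` through `HfFolder.save_token`, the mixin now logs out. A minimal sketch of the revised behavior, with the classmethod flattened into a plain function for illustration (HfFolder and logout are the huggingface_hub APIs the patch itself imports):

```python
# Hedged sketch of the revised StagingTestMixin.set_hf_hub_token logic;
# flattening the classmethod into a plain function is illustrative only.
from typing import Optional

from huggingface_hub import HfFolder, logout


def set_hf_hub_token(token: Optional[str]) -> Optional[str]:
    """Swap the stored Hub token, returning the previous one so it can be restored."""
    orig_token = HfFolder.get_token()
    if token is not None:
        HfFolder.save_token(token)  # persist the new token
    else:
        logout()  # nothing to save: clear the stored credentials instead
    return orig_token  # may itself be None, which the guard above handles on restore
```

This also explains the test_cache_utils.py change: restoring a possibly-None `orig_token` via `HfFolder.save_token` was fragile, so the test now saves the known `self._staging_token` directly.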