Skip to content

Commit

Permalink
Skip pushing if the user does not have write access to the cache repo (#405)
Browse files Browse the repository at this point in the history

* Skip pushing if the user does not have write access to the cache repo

* Fix tests

* Fix tests
  • Loading branch information
michaelbenayoun authored Jan 15, 2024
1 parent 923398e commit 104bd64
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 6 deletions.
25 changes: 23 additions & 2 deletions optimum/neuron/trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,11 @@
NeuronHash,
create_or_append_to_neuron_parallel_compile_report,
download_cached_model_from_hub,
get_hf_hub_cache_repos,
get_neuron_cache_path,
get_neuron_compiler_version_dir_name,
get_neuron_parallel_compile_report,
has_write_access_to_repo,
list_files_in_neuron_cache,
path_after_folder,
push_to_cache_on_hub,
Expand Down Expand Up @@ -94,6 +96,19 @@ def __init__(
self.wait_for_everyone_on_fetch = is_torch_xla_available() and wait_for_everyone_on_fetch
self.wait_for_everyone_on_push = is_torch_xla_available() and wait_for_everyone_on_push

cache_repo_ids = get_hf_hub_cache_repos()
if cache_repo_ids:
self.cache_repo_id = cache_repo_ids[0]
has_write_access = has_write_access_to_repo(self.cache_repo_id)
if self.push and not has_write_access:
logger.warning(
f"Pushing to the remote cache repo {self.cache_repo_id} is disabled because you do not have write "
"access to it."
)
self.push = False
else:
self.cache_repo_id = None

# Real Neuron compile cache if it exists.
if original_neuron_cache_path is None:
self.neuron_cache_path = get_neuron_cache_path()
Expand Down Expand Up @@ -293,7 +308,9 @@ def synchronize_temporary_neuron_cache_state(self) -> List[Path]:
def synchronize_temporary_neuron_cache(self):
for neuron_hash, files in self.neuron_hash_to_files.items():
for path in files:
push_to_cache_on_hub(neuron_hash, path, local_path_to_path_in_repo="default")
push_to_cache_on_hub(
neuron_hash, path, cache_repo_id=self.cache_repo_id, local_path_to_path_in_repo="default"
)
if self.use_neuron_cache:
path_in_cache = self.full_path_to_path_in_temporary_cache(path)
target_file = self.neuron_cache_path / path_in_cache
Expand Down Expand Up @@ -364,7 +381,11 @@ def on_train_begin(self, args: "TrainingArguments", state: TrainerState, control
for path in filenames:
try:
push_to_cache_on_hub(
neuron_hash, path, local_path_to_path_in_repo="default", fail_when_could_not_push=True
neuron_hash,
path,
cache_repo_id=self.cache_repo_id,
local_path_to_path_in_repo="default",
fail_when_could_not_push=True,
)
except HfHubHTTPError:
# It means that we could not push, so we do not remove this entry from the report.
Expand Down
3 changes: 1 addition & 2 deletions tests/test_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,14 +279,13 @@ def remove_repo():
HfFolder.save_token(orig_token)

def test_has_write_access_to_repo(self):
    """Check that `has_write_access_to_repo` correctly reports write access.

    With a bogus token saved, write access must be denied for both the public
    and the private custom cache repos; with the staging token (which owns the
    repos) it must be granted.

    NOTE(review): this reconstructs the post-commit version of the test — the
    scraped diff interleaved the deleted ``orig_token`` save/restore lines with
    the new ``self._staging_token`` line. Token restoration is assumed to be
    handled by the surrounding fixture (e.g. ``tearDown``) — confirm.
    """
    # Any string that is not a valid HF token; the Hub rejects it, so the
    # write-access probe must come back False.
    wrong_token = "random_string"
    HfFolder.save_token(wrong_token)

    self.assertFalse(has_write_access_to_repo(self.CUSTOM_CACHE_REPO))
    self.assertFalse(has_write_access_to_repo(self.CUSTOM_PRIVATE_CACHE_REPO))

    # Switch to the staging token that owns both repos.
    HfFolder.save_token(self._staging_token)

    self.assertTrue(has_write_access_to_repo(self.CUSTOM_CACHE_REPO))
    self.assertTrue(has_write_access_to_repo(self.CUSTOM_PRIVATE_CACHE_REPO))
Expand Down
7 changes: 5 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import torch
from datasets import Dataset, DatasetDict
from huggingface_hub import CommitOperationDelete, HfApi, HfFolder, create_repo, delete_repo
from huggingface_hub import CommitOperationDelete, HfApi, HfFolder, create_repo, delete_repo, logout
from huggingface_hub.utils import RepositoryNotFoundError
from transformers import PretrainedConfig, PreTrainedModel
from transformers.testing_utils import ENDPOINT_STAGING
Expand Down Expand Up @@ -163,7 +163,10 @@ class StagingTestMixin:
@classmethod
def set_hf_hub_token(cls, token: "str | None") -> "str | None":
    """Install *token* as the active Hugging Face Hub token and return the previous one.

    Args:
        token: The token to save. If ``None``, the current credentials are
            removed via ``huggingface_hub.logout()`` instead of saving.

    Returns:
        The previously stored token, or ``None`` if no token was stored
        (``HfFolder.get_token()`` returns ``None`` in that case — hence the
        ``str | None`` annotations, which the pre-commit ``str`` hints got wrong).

    Side effects:
        Also primes ``cls._env`` with a copy of the current environment pointing
        ``HF_ENDPOINT`` at the staging endpoint, for subprocesses spawned by tests.
    """
    orig_token = HfFolder.get_token()
    if token is not None:
        HfFolder.save_token(token)
    else:
        # No replacement token: clear credentials entirely so "logged out"
        # can be represented (saving None would raise).
        logout()
    cls._env = dict(os.environ, HF_ENDPOINT=ENDPOINT_STAGING)
    return orig_token

Expand Down

0 comments on commit 104bd64

Please sign in to comment.