From 1f049e15fd52bdeb683da464125138ad122fe5b9 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Mon, 8 Apr 2024 16:10:49 +0200 Subject: [PATCH] Cache utils related cleanup (#553) --- optimum/neuron/accelerate/state.py | 4 +- optimum/neuron/trainer_callback.py | 433 ---------- optimum/neuron/trainers.py | 31 +- optimum/neuron/utils/cache_utils.py | 808 +----------------- optimum/neuron/utils/hub_neuronx_cache.py | 36 +- .../torch_xla_and_neuronx_initialization.py | 5 +- tests/test_cache_utils.py | 617 +------------ tests/test_trainer_callback.py | 210 ----- tests/test_trainers.py | 35 +- tests/utils.py | 55 +- 10 files changed, 74 insertions(+), 2160 deletions(-) delete mode 100644 optimum/neuron/trainer_callback.py delete mode 100644 tests/test_trainer_callback.py diff --git a/optimum/neuron/accelerate/state.py b/optimum/neuron/accelerate/state.py index 0da4ae002..a03a53707 100644 --- a/optimum/neuron/accelerate/state.py +++ b/optimum/neuron/accelerate/state.py @@ -38,7 +38,7 @@ from ..utils import is_neuronx_distributed_available, is_torch_xla_available from ..utils.torch_xla_and_neuronx_initialization import ( init_process_group, - set_common_neuron_cc_flags, + set_common_flags, set_neuron_cc_flags_for_torch_amp, ) from .utils import NeuronDistributedType, NeuronFullyShardedDataParallelPlugin @@ -91,7 +91,7 @@ def __init__(self, cpu: bool = False, **kwargs): torch.cuda.set_device(self.device) elif is_torch_xla_available() and not cpu: # It is important to set the environment variables before initializing the process group otherwise they will be ignored by the Neuron compiler. - set_common_neuron_cc_flags() + set_common_flags() if os.environ.get("ACCELERATE_USE_AMP", "false") == "true": set_neuron_cc_flags_for_torch_amp() init_process_group() diff --git a/optimum/neuron/trainer_callback.py b/optimum/neuron/trainer_callback.py deleted file mode 100644 index b2442efa1..000000000 --- a/optimum/neuron/trainer_callback.py +++ /dev/null @@ -1,433 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Defines custom Trainer callbacks specific to AWS Neuron instances.""" - -import inspect -import json -import os -import shutil -import subprocess -from collections import defaultdict -from dataclasses import asdict, dataclass -from pathlib import Path -from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple - -import torch -from huggingface_hub.utils import HfHubHTTPError -from packaging import version -from transformers import TrainerCallback, TrainerState - -from ..utils import logging -from .distributed.utils import TENSOR_PARALLEL_SHARDS_DIR_NAME -from .utils import is_torch_xla_available -from .utils.cache_utils import ( - NeuronHash, - create_or_append_to_neuron_parallel_compile_report, - download_cached_model_from_hub, - get_hf_hub_cache_repos, - get_neuron_cache_path, - get_neuron_compiler_version_dir_name, - get_neuron_parallel_compile_report, - has_write_access_to_repo, - list_files_in_neuron_cache, - path_after_folder, - push_to_cache_on_hub, - remove_entries_in_neuron_parallel_compile_report, - set_neuron_cache_path, -) -from .utils.training_utils import is_precompilation -from .version import __version__ - - -if TYPE_CHECKING: - from transformers import PreTrainedModel, TrainerControl, TrainingArguments - - from .training_args import NeuronTrainingArguments - - -if is_torch_xla_available(): - import torch_xla.core.xla_model as xm - -logger = logging.get_logger(__name__) - - -@dataclass -class NeuronTrainerState(TrainerState): - last_inputs: Optional[Dict[str, Any]] = None - - def __post_init__(self): - super().__post_init__() - if self.last_inputs is None: - self.last_inputs = {} - - @classmethod - def from_trainer_state(cls, state: TrainerState) -> "NeuronTrainerState": - neuron_trainer_state = cls(asdict(state)) - neuron_trainer_state.last_inputs = getattr(state, "last_inputs", {}) - return neuron_trainer_state - - -class NeuronCacheCallback(TrainerCallback): - def __init__( - self, - tmp_neuron_cache: Optional[Path] = None, - original_neuron_cache_path: Optional[Path] = None, - fetch: bool = True, - push: bool = True, - wait_for_everyone_on_fetch: bool = True, - wait_for_everyone_on_push: bool = True, - ): - super().__init__() - self.fetch = fetch - self.push = push - self.wait_for_everyone_on_fetch = is_torch_xla_available() and wait_for_everyone_on_fetch - self.wait_for_everyone_on_push = is_torch_xla_available() and wait_for_everyone_on_push - - cache_repo_ids = get_hf_hub_cache_repos() - if cache_repo_ids: - self.cache_repo_id = cache_repo_ids[0] - has_write_access = has_write_access_to_repo(self.cache_repo_id) - if self.push and not has_write_access: - logger.warning( - f"Pushing to the remote cache repo {self.cache_repo_id} is disabled because you do not have write " - "access to it." - ) - self.push = False - else: - self.cache_repo_id = None - - # Real Neuron compile cache if it exists. - if original_neuron_cache_path is None: - self.neuron_cache_path = get_neuron_cache_path() - else: - self.neuron_cache_path = original_neuron_cache_path - self.use_neuron_cache = self.neuron_cache_path is not None - self.neuron_cache_path.mkdir(parents=True, exist_ok=True) - - # Temporary Neuron compile cache. - if is_precompilation(): - # When doing precompilation, the graph will be compiled after than the script is done. - # By setting `self.tmp_neuron_cache` to `self.neuron_cache_path`, `neuron_parallel_compile` will extract - # the very same graphs than the one created during real training, while not doing any synchronization - # during training since the compiled files will not be there yet. - self.tmp_neuron_cache_path = self.neuron_cache_path - elif tmp_neuron_cache is None: - # To keep an instance of the TemporaryDirectory as long as the callback lives. - self._tmp_neuron_cache = self.create_temporary_neuron_cache(self.neuron_cache_path) - self.tmp_neuron_cache_path = Path(self._tmp_neuron_cache.name) - else: - self.tmp_neuron_cache_path = tmp_neuron_cache - - self.tmp_neuron_cache_state = list_files_in_neuron_cache(self.tmp_neuron_cache_path, only_relevant_files=True) - self.fetch_files = set() - - # Keys are of format: - # (model, input_shapes, data_type, tensor_parallel_size) - self.neuron_hashes: Dict[ - Tuple["PreTrainedModel", Tuple[Tuple[str, Tuple[int]], ...], torch.dtype, int], NeuronHash - ] = {} - self.neuron_hash_to_files: Dict[NeuronHash, List[Path]] = defaultdict(list) - - def prepare_state(self, state: TrainerState): - if isinstance(state, NeuronTrainerState): - return state - return NeuronTrainerState.from_trainer_state(state) - - @staticmethod - def get_dir_size(path: Path) -> int: - if not path.is_dir(): - raise ValueError(f"{path} is not a directory.") - proc = subprocess.Popen(["du", "-s", path.as_posix()], stdout=subprocess.PIPE) - stdout, _ = proc.communicate() - stdout = stdout.decode("utf-8") - return int(stdout.split()[0]) - - @classmethod - def _load_cache_stats(cls, neuron_cache_path: Path) -> Dict[str, Dict[str, Any]]: - cache_stats_path = neuron_cache_path / "cache_stats.json" - if cache_stats_path.exists(): - with open(neuron_cache_path / "cache_stats.json", "r") as fp: - cache_stats = json.load(fp) - else: - cache_stats = {} - return cache_stats - - @classmethod - def _insert_in_cache_stats(cls, cache_stats: Dict[str, Dict[str, Any]], full_path: Path, path_in_cache: Path): - cache_key = path_in_cache.parts[0] - item = cache_stats.get(cache_key, {}) - if full_path.parent.as_posix() in item: - return - item[full_path.parent.as_posix()] = {"used_time": 1, "size": cls.get_dir_size(full_path.parent)} - cache_stats[cache_key] = item - - @classmethod - def _update_cache_stats(cls, neuron_cache_path: Path): - cache_stats = cls._load_cache_stats(neuron_cache_path) - for path in list_files_in_neuron_cache(neuron_cache_path): - cls._insert_in_cache_stats(cache_stats, path, neuron_cache_path) - with open(neuron_cache_path / "cache_stats.json", "w") as fp: - json.dump(cache_stats, fp) - - @classmethod - def create_temporary_neuron_cache(cls, neuron_cache_path: Optional[Path]) -> TemporaryDirectory: - tmp_neuron_cache = TemporaryDirectory() - tmp_neuron_cache_path = Path(tmp_neuron_cache.name) - if neuron_cache_path is not None: - neuron_cache_files = list_files_in_neuron_cache(neuron_cache_path) - else: - neuron_cache_files = [] - - # Setting the Neuron compilation cache to be the temporary Neuron compilation cache. - set_neuron_cache_path(tmp_neuron_cache_path) - - cache_stats_exists = False - if neuron_cache_path is not None: - cache_stats = cls._load_cache_stats(neuron_cache_path) - else: - cache_stats = {} - - for cache_file in neuron_cache_files: - if cache_file.name == "cache_stats.json": - continue - try: - path_in_neuron_cache = path_after_folder( - cache_file, - get_neuron_compiler_version_dir_name(), - include_folder=True, - fail_when_folder_not_found=True, - ) - except Exception: - # Here only when the folder `get_neuron_compiler_version_dir_name()` was not in the path of - # `cache_file`. In this case, no symlink is created because it is interpreted as not being a - # compilation file. - continue - tmp_cache_file = tmp_neuron_cache_path / path_in_neuron_cache - tmp_cache_file.parent.mkdir(parents=True, exist_ok=True) - # TODO: investigate why it is needed. Minor issue. - if not tmp_cache_file.exists(): - tmp_cache_file.symlink_to(cache_file) - - cls._insert_in_cache_stats(cache_stats, cache_file, path_in_neuron_cache) - - if not cache_stats_exists: - with open(tmp_neuron_cache_path / "cache_stats.json", "w") as fp: - json.dump(cache_stats, fp) - - return tmp_neuron_cache - - def neuron_hash_for_model( - self, - args: "NeuronTrainingArguments", - model: "PreTrainedModel", - inputs: Dict[str, Any], - try_to_fetch_cached_model: bool = False, - ) -> NeuronHash: - input_names = inspect.signature(model.forward).parameters.keys() - input_shapes = tuple( - (input_name, tuple(input_.shape)) for input_name, input_ in inputs.items() if input_name in input_names - ) - - # For backward compatibility, to not break the cache for users for now. - if version.parse(__version__) <= version.parse("0.0.14"): - use_bf16 = args.bf16 - else: - use_bf16 = ( - args.bf16 - or os.environ.get("XLA_USE_BF16", "0") == "1" - or os.environ.get("XLA_DOWNCAST_BF16", "0") == "1" - ) - if args.fp16: - data_type = torch.float16 - elif use_bf16: - data_type = torch.bfloat16 - else: - data_type = torch.float32 - - key_args = (model, input_shapes, data_type) - key_kwargs = {"tensor_parallel_size": args.tensor_parallel_size} - key = key_args + tuple(key_kwargs.values()) - neuron_hash = self.neuron_hashes.get(key, None) - if neuron_hash is None: - neuron_hash = NeuronHash(*key_args, **key_kwargs) - self.neuron_hashes[key] = neuron_hash - if try_to_fetch_cached_model: - self.try_to_fetch_cached_model(neuron_hash) - return neuron_hash - - def full_path_to_path_in_temporary_cache(self, path: Path): - return path_after_folder(path, self.tmp_neuron_cache_path.name) - - def try_to_fetch_cached_model(self, neuron_hash: NeuronHash) -> bool: - # TODO: needs to be called ONLY when absolutely needed. - files_before_fetching = list_files_in_neuron_cache(self.tmp_neuron_cache_path, only_relevant_files=True) - - found_in_cache = download_cached_model_from_hub( - neuron_hash, - target_directory=self.tmp_neuron_cache_path, - path_in_repo_to_path_in_target_directory="default", - ) - - if found_in_cache: - files_after_fetching = list_files_in_neuron_cache(self.tmp_neuron_cache_path, only_relevant_files=True) - diff = [f for f in files_after_fetching if f not in files_before_fetching] - # The fetched files should not be synchronized with the Hub. - self.tmp_neuron_cache_state += diff - if self.use_neuron_cache: - for path in diff: - path_in_cache = self.full_path_to_path_in_temporary_cache(path) - path_in_original_cache = self.neuron_cache_path / path_in_cache - path_in_original_cache.parent.mkdir(parents=True, exist_ok=True) - if path_in_original_cache.exists(): - continue - shutil.copy(path, path_in_original_cache) - - return found_in_cache - - def synchronize_temporary_neuron_cache_state(self) -> List[Path]: - current_files_in_neuron_cache = list_files_in_neuron_cache( - self.tmp_neuron_cache_path, only_relevant_files=True - ) - diff = [p for p in current_files_in_neuron_cache if p not in self.tmp_neuron_cache_state] - self.tmp_neuron_cache_state = current_files_in_neuron_cache - return diff - - def synchronize_temporary_neuron_cache(self): - for neuron_hash, files in self.neuron_hash_to_files.items(): - for path in files: - push_to_cache_on_hub( - neuron_hash, path, cache_repo_id=self.cache_repo_id, local_path_to_path_in_repo="default" - ) - if self.use_neuron_cache: - path_in_cache = self.full_path_to_path_in_temporary_cache(path) - target_file = self.neuron_cache_path / path_in_cache - target_file.parent.mkdir(parents=True, exist_ok=True) - shutil.copy(path, self.neuron_cache_path / path_in_cache) - - if self.use_neuron_cache: - self._update_cache_stats(self.neuron_cache_path) - - for neuron_hash in self.neuron_hash_to_files: - self.neuron_hash_to_files[neuron_hash] = [] - - def on_step_middle(self, args: "TrainingArguments", state: TrainerState, control: "TrainerControl", **kwargs): - if self.fetch: - model = kwargs["model"] - self.neuron_hash_for_model(args, model, state.last_inputs, try_to_fetch_cached_model=True) - if self.wait_for_everyone_on_fetch: - xm.rendezvous("wait for everyone after fetching") - - def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): - """ - Event called at the end of a training step. If using gradient accumulation, one training step might take - several inputs. - """ - - if self.push or (xm.get_local_ordinal() == 0 and is_precompilation()): - model = kwargs["model"] - state = self.prepare_state(state) - neuron_hash = self.neuron_hash_for_model(args, model, state.last_inputs, try_to_fetch_cached_model=True) - diff = self.synchronize_temporary_neuron_cache_state() - self.neuron_hash_to_files[neuron_hash].extend(diff) - - def on_prediction_step(self, args: "TrainingArguments", state: TrainerState, control: "TrainerControl", **kwargs): - """ - Event called after a prediction step. - """ - self.on_step_end(args, state, control, **kwargs) - - def on_save(self, args: "TrainingArguments", state: TrainerState, control: "TrainerControl", **kwargs): - """ - Event called after a checkpoint save. - """ - if xm.get_local_ordinal() == 0 and is_precompilation() and self.tmp_neuron_cache_path is not None: - create_or_append_to_neuron_parallel_compile_report(self.tmp_neuron_cache_path, self.neuron_hash_to_files) - for neuron_hash in self.neuron_hash_to_files: - self.neuron_hash_to_files[neuron_hash] = [] - if self.push: - self.synchronize_temporary_neuron_cache() - if self.wait_for_everyone_on_push: - xm.rendezvous("wait for everyone after pushing") - - def on_train_begin(self, args: "TrainingArguments", state: TrainerState, control: "TrainerControl", **kwargs): - """ - Event called at the beginning of training. - """ - if is_precompilation() or self.neuron_cache_path is None: - return - if self.push: - neuron_parallel_compile_report = get_neuron_parallel_compile_report( - self.neuron_cache_path, as_neuron_hash=True - ) - entries_to_remove = [] - for entry in neuron_parallel_compile_report: - neuron_hash = entry["neuron_hash"] - path = entry["directory"] - filenames = list_files_in_neuron_cache(path, only_relevant_files=True) - success = True - for path in filenames: - try: - push_to_cache_on_hub( - neuron_hash, - path, - cache_repo_id=self.cache_repo_id, - local_path_to_path_in_repo="default", - fail_when_could_not_push=True, - ) - except HfHubHTTPError: - # It means that we could not push, so we do not remove this entry from the report. - success = False - if success: - entries_to_remove.append(entry) - - # Removing the entries that were uploaded. - remove_entries_in_neuron_parallel_compile_report(self.neuron_cache_path, entries_to_remove) - if self.wait_for_everyone_on_push: - xm.rendezvous("wait for everyone after pushing") - - def on_train_end(self, args: "TrainingArguments", state: TrainerState, control: "TrainerControl", **kwargs): - """ - Event called at the end of training. - """ - self.on_save(args, state, control, **kwargs) - if is_precompilation(): - if xm.get_local_ordinal() == 0: - output_dir = Path(args.output_dir) - for file_or_dir in output_dir.glob("**/*"): - if file_or_dir.is_file(): - continue - if ( - file_or_dir.name.startswith("checkpoint-") - or file_or_dir.name == TENSOR_PARALLEL_SHARDS_DIR_NAME - ): - logger.info( - f"Removing {file_or_dir} since the weights were produced by `neuron_parallel_compile`, " - "thus cannot be used." - ) - shutil.rmtree(file_or_dir, ignore_errors=True) - xm.rendezvous("wait for everyone after end of training cleanup during precompilation") - - def on_evaluate(self, args: "TrainingArguments", state: TrainerState, control: "TrainerControl", **kwargs): - """ - Event called after an evaluation phase. - """ - self.on_save(args, state, control, **kwargs) - - def on_predict(self, args: "TrainingArguments", state: TrainerState, control: "TrainerControl", metrics, **kwargs): - """ - Event called after a successful prediction. - """ - self.on_save(args, state, control, **kwargs) diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py index 73e05065b..014e229ad 100755 --- a/optimum/neuron/trainers.py +++ b/optimum/neuron/trainers.py @@ -77,7 +77,6 @@ get_hf_hub_cache_repos, get_model_name_or_path, get_neuron_cache_path, - get_neuronxcc_version, get_num_neuron_cores_used, has_write_access_to_repo, ) @@ -96,6 +95,7 @@ skip_first_batches, torch_xla_safe_save_file, ) +from .utils.version_utils import get_neuronxcc_version if is_apex_available(): @@ -1362,14 +1362,13 @@ def train( ignore_keys_for_eval: Optional[List[str]] = None, **kwargs, ): - with patch_neuron_cc_wrapper(): - with hub_neuronx_cache("training", entry=self.model_cache_entry): - result = super().train( - resume_from_checkpoint=resume_from_checkpoint, - trial=trial, - ignore_keys_for_eval=ignore_keys_for_eval, - **kwargs, - ) + with hub_neuronx_cache("training", entry=self.model_cache_entry): + result = super().train( + resume_from_checkpoint=resume_from_checkpoint, + trial=trial, + ignore_keys_for_eval=ignore_keys_for_eval, + **kwargs, + ) if not is_precompilation(): self.synchronize_hub_cache() return result @@ -1380,11 +1379,10 @@ def evaluate( ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "eval", ) -> Dict[str, float]: - with patch_neuron_cc_wrapper(): - with hub_neuronx_cache("training", entry=self.model_cache_entry): - result = super().evaluate( - eval_dataset=eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix - ) + with hub_neuronx_cache("training", entry=self.model_cache_entry): + result = super().evaluate( + eval_dataset=eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix + ) if not is_precompilation(): self.synchronize_hub_cache() return result @@ -1392,9 +1390,8 @@ def evaluate( def predict( self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test" ) -> PredictionOutput: - with patch_neuron_cc_wrapper(): - with hub_neuronx_cache("training", entry=self.model_cache_entry): - result = super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) + with hub_neuronx_cache("training", entry=self.model_cache_entry): + result = super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) if not is_precompilation(): self.synchronize_hub_cache() return result diff --git a/optimum/neuron/utils/cache_utils.py b/optimum/neuron/utils/cache_utils.py index d8ced265a..e87ed63e5 100644 --- a/optimum/neuron/utils/cache_utils.py +++ b/optimum/neuron/utils/cache_utils.py @@ -1,5 +1,3 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,39 +12,24 @@ # limitations under the License. """Utilities for caching.""" -import functools -import hashlib -import io -import json import os import re -import shutil -import tempfile -from dataclasses import InitVar, asdict, dataclass, field from pathlib import Path -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union +from typing import List, Optional, Union -import huggingface_hub -import numpy as np -import torch from huggingface_hub import ( - CommitOperationAdd, HfApi, RepoUrl, create_repo, get_token, - hf_hub_download, whoami, ) -from huggingface_hub.utils import EntryNotFoundError, HfHubHTTPError, RepositoryNotFoundError -from packaging import version -from transformers import PretrainedConfig, PreTrainedModel +from huggingface_hub.utils import RepositoryNotFoundError +from transformers import PretrainedConfig from ...utils import logging from ...utils.logging import warn_once from .misc import is_main_worker, string_to_bool -from .require_utils import requires_neuronx_distributed -from .version_utils import get_neuronxcc_version logger = logging.get_logger() @@ -67,16 +50,6 @@ else: HF_HUB_CACHE_REPOS = [f"aws-neuron/{CACHE_REPO_NAME}"] -HASH_FILENAME = "pytorch_model.bin" -REGISTRY_FILENAME = "registry.json" -NEURON_PARALLEL_COMPILE_REPORT_FILENAME = "neuron_parallel_compile_report.json" - -_IP_PATTERN = re.compile(r"ip-([0-9]{1,3}-){4}") -_HF_HUB_HTTP_ERROR_REQUEST_ID_PATTERN = re.compile(r"\(Request ID: Root=[\w-]+\)") - -_REGISTRY_FILE_EXISTS: Dict[str, bool] = {} -_ADDED_IN_REGISTRY: Dict[Tuple[str, "NeuronHash"], bool] = {} - # For testing purposes. _DISABLE_IS_PRIVATE_REPO_CHECK: bool = string_to_bool( os.environ.get("OPTIMUM_NEURON_DISABLE_IS_PRIVATE_REPO_CHECK", "false") @@ -130,7 +103,6 @@ def delete_custom_cache_repo_name_from_hf_home(hf_home_cache_repo_file: str = HF def create_custom_cache_repo(repo_id: str = CACHE_REPO_NAME, private: bool = True) -> RepoUrl: repo_url = create_repo(repo_id, private=private, repo_type="model") - create_registry_file_if_does_not_exist(repo_url.repo_id) set_custom_cache_repo_name_in_hf_home(repo_url.repo_id) return repo_url @@ -187,7 +159,14 @@ def has_write_access_to_repo(repo_id: str) -> bool: return has_write_access_in_org -def get_hf_hub_cache_repos(): +def get_hf_hub_cache_repos(log_warnings: bool = False) -> List[str]: + """ + Retrieves the name of the Hugging Face Hub model repo to use as remote cache. + Priority: + - If a repo is provided via the `CUSTOM_CACHE_REPO` environment variable, it will be used, + - Else, if a custom cache repo has been set locally, it will be used, + - Otherwise, it uses the default cache repo (on which most people do not have write access) + """ # Default hub repos. hf_hub_repos = HF_HUB_CACHE_REPOS @@ -201,7 +180,7 @@ def get_hf_hub_cache_repos(): if custom_cache_repo is not None and custom_cache_repo not in hf_hub_repos: hf_hub_repos = [custom_cache_repo] + hf_hub_repos - if is_main_worker() and saved_custom_cache_repo is None and custom_cache_repo is None: + if log_warnings and is_main_worker() and saved_custom_cache_repo is None and custom_cache_repo is None: warn_once( logger, "No Neuron cache name is saved locally. This means that only the official Neuron cache will be used. You " @@ -210,7 +189,7 @@ def get_hf_hub_cache_repos(): "set -n [name]`.", ) - if is_main_worker() and hf_hub_repos and not has_write_access_to_repo(hf_hub_repos[0]): + if log_warnings and is_main_worker() and hf_hub_repos and not has_write_access_to_repo(hf_hub_repos[0]): warn_once( logger, f"You do not have write access to {hf_hub_repos[0]} so you will not be able to push any cached compilation " @@ -219,6 +198,10 @@ def get_hf_hub_cache_repos(): return hf_hub_repos +def get_hf_hub_cache_repo(log_warnings: bool = False) -> str: + return get_hf_hub_cache_repos(log_warnings=log_warnings)[0] + + def get_neuron_cache_path() -> Optional[Path]: # NEURON_CC_FLAGS is the environment variable read by the neuron compiler. # Among other things, this is where the cache directory is specified. @@ -269,13 +252,7 @@ def get_num_neuron_cores() -> int: def get_num_neuron_cores_used() -> int: - return int(os.environ.get("LOCAL_WORLD_SIZE", "1")) - - -def get_neuron_compiler_version_dir_name(neuron_compiler_version: Optional[str] = None) -> str: - if neuron_compiler_version is None: - neuron_compiler_version = get_neuronxcc_version() - return f"neuronxcc-{neuron_compiler_version}" + return int(os.environ.get("WORLD_SIZE", "1")) def list_files_in_neuron_cache(neuron_cache_path: Union[str, Path], only_relevant_files: bool = False) -> List[Path]: @@ -287,31 +264,6 @@ def list_files_in_neuron_cache(neuron_cache_path: Union[str, Path], only_relevan return files -def path_after_folder( - path: Path, folder: Union[str, Path], include_folder: bool = False, fail_when_folder_not_found: bool = False -) -> Path: - if isinstance(folder, Path): - folder = folder.name - try: - index = path.parts.index(folder) - except ValueError as e: - if fail_when_folder_not_found: - raise e - index = len(path.parts) - index = index + 1 if not include_folder else index - return Path("").joinpath(*path.parts[index:]) - - -def path_after_neuron_compiler_version_dir( - path: Path, neuron_compiler_version: str, include_folder: bool = False -) -> Path: - return path_after_folder(path, f"neuronxcc-{neuron_compiler_version}", include_folder=include_folder) - - -def remove_ip_adress_from_path(path: Path) -> Path: - return Path().joinpath(*(re.sub(_IP_PATTERN, "", part) for part in path.parts)) - - def get_model_name_or_path(config: "PretrainedConfig") -> Optional[str]: attribute_names_to_try = ["_model_name_or_path", "_name_or_path"] model_name_or_path = None @@ -323,727 +275,3 @@ def get_model_name_or_path(config: "PretrainedConfig") -> Optional[str]: if model_name_or_path == "": model_name_or_path = None return model_name_or_path - - -def get_neuron_parallel_compile_report( - neuron_cache_path: Union[str, Path], as_neuron_hash: bool = False -) -> List[Dict[str, Any]]: - report_file = Path(neuron_cache_path) / NEURON_PARALLEL_COMPILE_REPORT_FILENAME - report_content = [] - if report_file.is_file(): - try: - with open(report_file) as fp: - report_content = json.load(fp) - except json.JSONDecodeError: - pass - if as_neuron_hash: - for entry in report_content: - entry["neuron_hash"] = NeuronHash.from_neuron_compile_report(entry.pop("neuron_hash")) - return report_content - - -def create_or_append_to_neuron_parallel_compile_report( - neuron_cache_path: Union[str, Path], neuron_hash_to_files: Dict["NeuronHash", List[Path]] -): - report_content = get_neuron_parallel_compile_report(neuron_cache_path) - inserted = set() - for neuron_hash, filenames in neuron_hash_to_files.items(): - for filename in filenames: - directory = filename.parent - if directory in inserted: - continue - report_content.append( - {"neuron_hash": neuron_hash.to_dict_for_neuron_compile_report(), "directory": directory.as_posix()} - ) - inserted.add(directory) - - report_file = Path(neuron_cache_path) / NEURON_PARALLEL_COMPILE_REPORT_FILENAME - with open(report_file, "w") as fp: - json.dump(report_content, fp) - - -def remove_entries_in_neuron_parallel_compile_report( - neuron_cache_path: Union[str, Path], entries_to_remove: List[Dict[str, Any]] -): - report = get_neuron_parallel_compile_report(neuron_cache_path, as_neuron_hash=False) - new_report = [] - for entry in report: - entry_neuron_hash = entry["neuron_hash"] - entry_directory = entry["directory"] - should_keep = True - for entry_to_remove in entries_to_remove: - neuron_hash = entry_to_remove["neuron_hash"] - if isinstance(neuron_hash, NeuronHash): - overall_hash = neuron_hash._hash.overall_hash - else: - overall_hash = neuron_hash["overall_hash"] - directory = entry_to_remove["directory"] - if entry_neuron_hash["overall_hash"] == overall_hash and entry_directory == directory: - should_keep = False - if should_keep: - new_report.append(entry) - - report_file = Path(neuron_cache_path) / NEURON_PARALLEL_COMPILE_REPORT_FILENAME - with open(report_file, "w") as fp: - json.dump(new_report, fp) - - -def create_registry_file_if_does_not_exist(repo_id: str): - was_created = _REGISTRY_FILE_EXISTS.get(repo_id, False) - if was_created: - return - file_exists = True - try: - hf_hub_download(repo_id, REGISTRY_FILENAME, force_download=True) - except EntryNotFoundError: - file_exists = False - if file_exists: - return - with tempfile.NamedTemporaryFile() as tmpfile: - with open(tmpfile.name, "w") as fp: - json.dump({}, fp) - tmpfilename = Path(tmpfile.name) - add_registry_file = CommitOperationAdd(REGISTRY_FILENAME, tmpfilename.as_posix()) - HfApi().create_commit(repo_id, operations=[add_registry_file], commit_message="Create cache registry file") - - _REGISTRY_FILE_EXISTS[repo_id] = True - - -def add_in_registry(repo_id: str, neuron_hash: "NeuronHash"): - was_added = _ADDED_IN_REGISTRY.get((repo_id, neuron_hash), False) - if was_added: - return - model_name_or_path = neuron_hash._model_name_or_path - if model_name_or_path is None: - model_name_or_path = "null" - - model_hash, overall_hash = neuron_hash.compute_hash() - - with tempfile.TemporaryDirectory() as tmpdirname: - keep_going = True - while keep_going: - tmpdirpath = Path(tmpdirname) - head = HfApi().model_info(repo_id).sha - hf_hub_download( - repo_id, - REGISTRY_FILENAME, - revision=head, - local_dir=tmpdirpath, - local_dir_use_symlinks=False, - ) - registry_path = tmpdirpath / REGISTRY_FILENAME - with open(registry_path, "r") as fp: - registry = json.load(fp) - - orig_registry = registry - if neuron_hash.neuron_compiler_version not in registry: - registry[neuron_hash.neuron_compiler_version] = {} - registry = registry[neuron_hash.neuron_compiler_version] - - key = model_name_or_path if model_name_or_path != "null" else model_hash - if model_name_or_path not in registry: - registry[key] = {"model_name_or_path": model_name_or_path, "model_hash": model_hash} - registry = registry[key] - - if "features" not in registry: - registry["features"] = [] - - exists_already = False - for feature in registry["features"]: - if feature["neuron_hash"] == overall_hash: - exists_already = True - - if not exists_already: - data = { - "input_shapes": neuron_hash.input_shapes, - "precision": str(neuron_hash.data_type), - "num_neuron_cores": neuron_hash.num_neuron_cores, - "neuron_hash": overall_hash, - } - registry["features"].append(data) - - with open(registry_path, "w") as fp: - json.dump(orig_registry, fp) - - add_model_in_registry = CommitOperationAdd(REGISTRY_FILENAME, registry_path.as_posix()) - try: - HfApi().create_commit( - repo_id, - operations=[add_model_in_registry], - commit_message=f"Add {model_name_or_path} in registry for NeuronHash {overall_hash}", - parent_commit=head, - ) - except Exception as e: - if "A commit has happened since" in str(e): - if is_main_worker(): - logger.info( - "A commit has happened in cache repository since we tried to update the registry, starting " - "again..." - ) - else: - raise e - else: - keep_going = False - - _ADDED_IN_REGISTRY[(repo_id, neuron_hash)] = True - - -def _list_in_registry_dict( - registry: Dict[str, Any], - model_name_or_path_or_hash: Optional[str] = None, - neuron_compiler_version: Optional[str] = None, -) -> List[str]: - entries = [] - if neuron_compiler_version is not None: - registry = registry.get(neuron_compiler_version, {}) - else: - for version_ in registry: - entries += _list_in_registry_dict( - registry, model_name_or_path_or_hash=model_name_or_path_or_hash, neuron_compiler_version=version_ - ) - return entries - - def validate_features_input_shapes(input_shapes: Tuple[Tuple[str, Tuple[int, ...]], ...]) -> bool: - return len(input_shapes) > 0 and all(len(entry) == 2 for entry in input_shapes) - - # model_key is either a model name or path or a model hash. - for model_key in registry: - data = registry[model_key] - if model_name_or_path_or_hash is not None and not ( - data["model_name_or_path"].startswith(model_name_or_path_or_hash) - or data["model_hash"].startswith(model_name_or_path_or_hash) - ): - continue - - for features in data["features"]: - if not validate_features_input_shapes(features["input_shapes"]): - continue - if len(features["input_shapes"]) > 1: - inputs = "\n\t- ".join(f"{x[0]} => {x[1]}" for x in features["input_shapes"]) - inputs = f"\t- {inputs}" - else: - x = features["input_shapes"][0] - inputs = f"\t- {x[0]} => {x[1]}" - information = [ - f"Model name:\t{data['model_name_or_path']}", - f"Model hash:\t{data['model_hash']}", - f"Global hash:\t{features['neuron_hash']}", - f"Precision:\t{features['precision']}", - f"Neuron X Compiler version:\t{neuron_compiler_version}", - f"Num of neuron cores:\t{features['num_neuron_cores']}", - f"Input shapes:\n{inputs}", - ] - entries.append("\n".join(information)) - return entries - - -def list_in_registry( - repo_id: str, model_name_or_path_or_hash: Optional[str] = None, neuron_compiler_version: Optional[str] = None -): - with tempfile.TemporaryDirectory() as tmpdirname: - hf_hub_download(repo_id, REGISTRY_FILENAME, local_dir=tmpdirname, local_dir_use_symlinks=False) - registry_filename = Path(tmpdirname) / REGISTRY_FILENAME - with open(registry_filename, "r") as fp: - registry = json.load(fp) - - return _list_in_registry_dict( - registry, - model_name_or_path_or_hash=model_name_or_path_or_hash, - neuron_compiler_version=neuron_compiler_version, - ) - - -class StaticTemporaryDirectory: - def __init__(self, dirname: Union[str, Path]): - if isinstance(dirname, str): - dirname = Path(dirname) - if dirname.exists(): - raise FileExistsError( - f"{dirname} already exists, cannot create a static temporary directory with this name." - ) - self.dirname = dirname - - def __enter__(self): - self.dirname.mkdir(parents=True) - return self.dirname - - def __exit__(self, *exc): - shutil.rmtree(self.dirname) - - -@dataclass -class _MutableHashAttribute: - model_hash: str = "" - overall_hash: str = "" - - @property - def is_empty(self): - return (not self.model_hash) or (not self.overall_hash) - - def __hash__(self): - return hash(f"{self.model_hash}_{self.overall_hash}") - - -@dataclass(frozen=True) -class _UnspecifiedHashAttribute: - min_optimum_neuron_version: Optional[str] = None - min_neuron_compiler_version: Optional[str] = None - default: Optional[Any] = None - - @classmethod - def with_args( - cls, - min_optimum_neuron_version: Optional[str] = None, - min_neuron_compiler_version: Optional[str] = None, - default: Optional[Any] = None, - ) -> Callable[[], "_UnspecifiedHashAttribute"]: - def constructor(): - return cls( - min_optimum_neuron_version=min_optimum_neuron_version, - min_neuron_compiler_version=min_neuron_compiler_version, - default=default, - ) - - return constructor - - def check_requirements_are_met(self, neuron_compiler_version: str): - if self.should_be_inserted_in_hash_dict(neuron_compiler_version) and self.default is None: - raise ValueError("A default value must be specified.") - # from ..version import __version__ - - # optimum_neuron_requirement = True - # if self.min_optimum_neuron_version is not None: - # if version.parse(__version__) >= version.parse(self.min_optimum_neuron_version): - # optimum_neuron_requirement = self.default is not None - - # neuron_compiler_requirement = True - # if self.min_neuron_compiler_version is not None: - # if version.parse(neuron_compiler_version) >= version.parse(self.min_neuron_compiler_version): - # neuron_compiler_requirement = self.default is not None - - # if not optimum_neuron_requirement or not neuron_compiler_requirement: - # raise ValueError("A default value must be specified.") - - def should_be_inserted_in_hash_dict(self, neuron_compiler_version: str) -> bool: - from ..version import __version__ - - optimum_neuron_requirement = False - if self.min_optimum_neuron_version is not None: - optimum_neuron_requirement = version.parse(__version__) >= version.parse(self.min_optimum_neuron_version) - - neuron_compiler_requirement = False - if self.min_neuron_compiler_version is not None: - neuron_compiler_requirement = version.parse(neuron_compiler_version) >= version.parse( - self.min_neuron_compiler_version - ) - - return optimum_neuron_requirement or neuron_compiler_requirement - - -@dataclass(frozen=True) -class NeuronHash: - model: InitVar["PreTrainedModel"] - input_shapes: Tuple[Tuple[str, Tuple[int, ...]], ...] - data_type: torch.dtype - num_neuron_cores: int = field(default_factory=get_num_neuron_cores_used) - neuron_compiler_version: str = field(default_factory=get_neuronxcc_version) - fsdp: Union[int, _UnspecifiedHashAttribute] = field( - default_factory=_UnspecifiedHashAttribute.with_args(min_optimum_neuron_version="0.0.8", default=False) - ) - tensor_parallel_size: Union[int, _UnspecifiedHashAttribute] = field( - default_factory=_UnspecifiedHashAttribute.with_args(min_optimum_neuron_version="0.0.8", default=1) - ) - pipeline_parallel_size: Union[int, _UnspecifiedHashAttribute] = field( - default_factory=_UnspecifiedHashAttribute.with_args(min_optimum_neuron_version="0.0.17", default=1) - ) - _model_name_or_path: Optional[str] = None - _is_private: Optional[bool] = None - _model_type: Optional[str] = None - _hash: _MutableHashAttribute = field(default_factory=_MutableHashAttribute) - - def __post_init__(self, model: "PreTrainedModel"): - for attr in self.__dict__.values(): - if isinstance(attr, _UnspecifiedHashAttribute): - attr.check_requirements_are_met(self.neuron_compiler_version) - - # Checking whether the model is private or not. - is_private = None - model_name_or_path = get_model_name_or_path(model.config) - if model_name_or_path is None: - is_private = True - elif Path(model_name_or_path).exists(): - is_private = True - else: - is_private = is_private_repo(model_name_or_path) - - # Using object.__setattr__ to change the field value because NeuronHash is supposed to be frozen. - # Not very clean, but it should work here. - super().__setattr__("_model_name_or_path", model_name_or_path) - super().__setattr__("_is_private", is_private) - super().__setattr__("_model_type", model.config.model_type) - - self.compute_hash(model) - - def to_dict_for_neuron_compile_report(self) -> Dict[str, Any]: - return { - "model_hash": self._hash.model_hash, - "overall_hash": self._hash.overall_hash, - "neuron_compiler_version": self.neuron_compiler_version, - "model_name_or_path": self._model_name_or_path, - "is_private": self._is_private, - "model_type": self._model_type, - } - - @classmethod - def from_neuron_compile_report(cls, data: Dict[str, Any]) -> "NeuronHash": - # Creating a dummy neuron hash. - neuron_hash = cls(PreTrainedModel(PretrainedConfig()), (), torch.float32) - # Populate it with data. - super(cls, neuron_hash).__setattr__( - "_hash", _MutableHashAttribute(model_hash=data["model_hash"], overall_hash=data["overall_hash"]) - ) - super(cls, neuron_hash).__setattr__("neuron_compiler_version", data["neuron_compiler_version"]) - super(cls, neuron_hash).__setattr__("_model_name_or_path", data["model_name_or_path"]) - super(cls, neuron_hash).__setattr__("_is_private", data["is_private"]) - super(cls, neuron_hash).__setattr__("_model_type", data["model_type"]) - return neuron_hash - - def _insert_potential_unspecified_hash_attribute( - self, attribute_name: str, attribute: Any, hash_dict: Dict[str, Any] - ): - """ - Inserts `attribute` in `hash_dict` only if it is a specified attribute or if it has a default value. - """ - if isinstance(attribute, _UnspecifiedHashAttribute) and attribute.should_be_inserted_in_hash_dict: - hash_dict[attribute_name] = attribute.default - else: - hash_dict[attribute_name] = attribute - - def state_dict_to_bytes(self, state_dict: Dict[str, torch.Tensor]) -> bytes: - cast_to_mapping = { - torch.bfloat16: torch.float16, - } - bytes_to_join = [] - for name, tensor in state_dict.items(): - memfile = io.BytesIO() - # It is actually important to first move the tensor to CPU then cast, because all XLA tensor operations, - # and in particular `to()` behave differently when doing `neuron_parallel_compile`. - np.save(memfile, tensor.cpu().to(cast_to_mapping.get(tensor.dtype, tensor.dtype)).numpy()) - bytes_to_join.append(name.encode("utf-8")) - bytes_to_join.append(memfile.getvalue()) - return b"".join(bytes_to_join) - - def compute_sha512_hash(self, *buffers: bytes) -> str: - hash_ = hashlib.sha512() - for buffer in buffers: - hash_.update(buffer) - return hash_.hexdigest() - - @requires_neuronx_distributed - def compute_hash(self, model: Optional["PreTrainedModel"] = None) -> Tuple[str, str]: - if self._hash.is_empty: - if model is None: - raise ValueError("A model must be specified the first time the hash is computed.") - - from neuronx_distributed.pipeline import NxDPPModel - - if isinstance(model, NxDPPModel): - state_dict = model.local_state_dict() - else: - state_dict = model.state_dict() - model_hash = self.compute_sha512_hash(self.state_dict_to_bytes(state_dict)) - - hash_dict = asdict(self) - hash_dict["model"] = model_hash - hash_dict["_model_class"] = model.__class__ - hash_dict["_is_model_training"] = model.training - hash_dict.pop("_is_private") - hash_dict.pop("_model_type") - hash_dict.pop("_hash") - - self._insert_potential_unspecified_hash_attribute( - "tensor_parallel_size", self.tensor_parallel_size, hash_dict - ) - self._insert_potential_unspecified_hash_attribute( - "pipeline_parallel_size", self.tensor_parallel_size, hash_dict - ) - self._insert_potential_unspecified_hash_attribute("fsdp", self.fsdp, hash_dict) - - hash_dict["data_type"] = str(hash_dict["data_type"]).split(".")[1] - - buffers = [name.encode("utf-8") + str(value).encode("utf-8") for name, value in hash_dict.items()] - - overal_hash = self.compute_sha512_hash(*buffers) - self._hash.model_hash = model_hash - self._hash.overall_hash = overal_hash - - return self._hash.model_hash, self._hash.overall_hash - - @property - def folders(self) -> List[str]: - if self._model_type is None: - raise ValueError("Model type was not set.") - model_hash, overall_hash = self.compute_hash() - return [ - self.neuron_compiler_version, - self._model_type, - model_hash, - overall_hash, - ] - - @property - def cache_path(self) -> Path: - return Path().joinpath(*self.folders) - - @property - def neuron_compiler_version_dir_name(self): - return get_neuron_compiler_version_dir_name(self.neuron_compiler_version) - - @property - def is_private(self): - return self._is_private - - -@dataclass -class CachedModelOnTheHub: - repo_id: str - folder: Union[str, Path] - revision: str = "main" - files_on_the_hub: List[str] = field(default_factory=list) - - def __post_init__(self): - if isinstance(self.folder, Path): - self.folder = self.folder.as_posix() - - -def get_cached_model_on_the_hub(neuron_hash: NeuronHash) -> Optional[CachedModelOnTheHub]: - target_directory = neuron_hash.cache_path - - cache_repo_id = None - cache_revision = None - - for repo_id in get_hf_hub_cache_repos(): - if isinstance(repo_id, tuple): - repo_id, revision = repo_id - else: - revision = "main" - try: - repo_filenames = HfApi().list_repo_files(repo_id, revision=revision, token=get_token()) - except Exception: - continue - model_files_on_the_hub = [] - was_found_in_repo = False - for repo_filename in repo_filenames: - if repo_filename.startswith(target_directory.as_posix()): - if cache_repo_id is None: - cache_repo_id = repo_id - cache_revision = revision - was_found_in_repo = True - model_files_on_the_hub.append(repo_filename) - if was_found_in_repo: - break - - if cache_repo_id is None: - cached_model = None - else: - cached_model = CachedModelOnTheHub( - cache_repo_id, target_directory, revision=cache_revision, files_on_the_hub=model_files_on_the_hub - ) - - return cached_model - - -def default_path_in_repo_to_path_in_target_directory(path: Path, neuron_hash: NeuronHash): - cache_path = neuron_hash.cache_path - # The last part of cache_path is the overall hash. - return Path(neuron_hash.neuron_compiler_version_dir_name) / path_after_folder(path, cache_path.name) - - -def default_local_path_to_path_in_repo(path: Path, neuron_hash: NeuronHash): - return path_after_neuron_compiler_version_dir(path, neuron_hash.neuron_compiler_version) - - -def download_cached_model_from_hub( - neuron_hash: NeuronHash, - target_directory: Optional[Union[str, Path]] = None, - path_in_repo_to_path_in_target_directory: Optional[Union[Literal["default"], Callable[[Path], Path]]] = None, -) -> bool: - if target_directory is None: - target_directory = get_neuron_cache_path() - if target_directory is None: - raise ValueError("A target directory must be specified when no caching directory is used.") - elif isinstance(target_directory, str): - target_directory = Path(target_directory) - - if path_in_repo_to_path_in_target_directory == "default": - path_in_repo_to_path_in_target_directory = functools.partial( - default_path_in_repo_to_path_in_target_directory, neuron_hash=neuron_hash - ) - - if path_in_repo_to_path_in_target_directory is None: - - def path_in_repo_to_path_in_target_directory(x): - return x - - cached_model = get_cached_model_on_the_hub(neuron_hash) - if cached_model is not None: - folder = cached_model.folder - - ignore_patterns = [] - for filename in cached_model.files_on_the_hub: - path_in_repo = Path(filename) - if path_in_repo_to_path_in_target_directory is not None: - potential_local_path = target_directory / path_in_repo_to_path_in_target_directory(path_in_repo) - else: - potential_local_path = target_directory / path_in_repo - - potential_local_path = remove_ip_adress_from_path(potential_local_path) - - if potential_local_path.exists(): - ignore_patterns.append(filename) - - needs_to_download = cached_model.files_on_the_hub and len(ignore_patterns) != len( - cached_model.files_on_the_hub - ) - - if needs_to_download: - files_before_downloading = [f for f in (target_directory / folder).glob("**/*") if f.is_file()] - huggingface_hub.snapshot_download( - repo_id=cached_model.repo_id, - revision=cached_model.revision, - repo_type="model", - local_dir=target_directory, - local_dir_use_symlinks=False, - allow_patterns=f"{folder}/**", - ignore_patterns=ignore_patterns, - tqdm_class=None, - ) - - local_folder = target_directory / folder - for path in local_folder.glob("**/*"): - if path.is_dir(): - continue - if path in files_before_downloading: - continue - target_path = target_directory / path_in_repo_to_path_in_target_directory(path) - target_path.parent.mkdir(parents=True, exist_ok=True) - shutil.move(path, target_path) - # TODO: remove old directories. - - return cached_model is not None - - -def push_to_cache_on_hub( - neuron_hash: NeuronHash, - local_cache_dir_or_file: Path, - cache_repo_id: Optional[str] = None, - overwrite_existing: bool = False, - local_path_to_path_in_repo: Optional[Union[Literal["default"], Callable[[Path], Path]]] = None, - fail_when_could_not_push: bool = False, -) -> Optional[CachedModelOnTheHub]: - if cache_repo_id is None: - cache_repo_id = get_hf_hub_cache_repos()[0] - - if not has_write_access_to_repo(cache_repo_id): - error_message = ( - f"Could not push the cached model to {cache_repo_id} because you do not have write access to this repo." - ) - if fail_when_could_not_push: - raise ValueError(error_message) - if is_main_worker(): - logger.warning(error_message) - return - - try: - create_registry_file_if_does_not_exist(cache_repo_id) - _REGISTRY_FILE_EXISTS[cache_repo_id] = True - except HfHubHTTPError: - pass - - is_cache_repo_private = is_private_repo(cache_repo_id) - if neuron_hash.is_private and not is_cache_repo_private: - error_message = ( - f"Could not push the cached model to {cache_repo_id} because this repo is not private but the original " - "model is coming from private repo." - ) - if fail_when_could_not_push: - raise ValueError(error_message) - if is_main_worker(): - logger.warning(error_message) - return - - if local_path_to_path_in_repo == "default": - local_path_to_path_in_repo = functools.partial(default_local_path_to_path_in_repo, neuron_hash=neuron_hash) - - if local_path_to_path_in_repo is not None: - path_in_repo = local_path_to_path_in_repo(local_cache_dir_or_file) - else: - path_in_repo = local_cache_dir_or_file - - # Joining a path to a absolute path ignores the original path, so we remove the root directory "/" in this case. - if path_in_repo.is_absolute(): - path_in_repo = Path().joinpath(*path_in_repo.parts[1:]) - path_in_repo = neuron_hash.cache_path / path_in_repo - - repo_filenames = HfApi().list_repo_files(cache_repo_id, token=get_token()) - path_in_repo_str = path_in_repo.as_posix() - if local_cache_dir_or_file.is_dir(): - exists = any(filename.startswith(path_in_repo_str) for filename in repo_filenames) - else: - exists = any(filename == path_in_repo_str for filename in repo_filenames) - if is_main_worker() and exists: - if not overwrite_existing: - logger.info( - f"Did not push the cached model located at {local_cache_dir_or_file} to the repo named {cache_repo_id} " - "because it already exists there. Use overwrite_existing=True if you want to overwrite the cache on the " - "Hub." - ) - else: - logger.warning( - "Overwriting the already existing cached model on the Hub by the one located at " - f"{local_cache_dir_or_file}" - ) - - could_not_push_message = "Could not push the cached model to the repo {cache_repo_id}. Error message:\n{error}." - success = True - if local_cache_dir_or_file.is_dir(): - try: - HfApi().upload_folder( - folder_path=local_cache_dir_or_file.as_posix(), - path_in_repo=path_in_repo.as_posix(), - repo_id=cache_repo_id, - repo_type="model", - ) - except HfHubHTTPError as e: - if fail_when_could_not_push: - raise e - msg = could_not_push_message.format(cache_repo_id=cache_repo_id, error=e) - msg = re.sub(_HF_HUB_HTTP_ERROR_REQUEST_ID_PATTERN, "", msg) - if is_main_worker(): - warn_once(logger, msg) - success = False - else: - try: - HfApi().upload_file( - path_or_fileobj=local_cache_dir_or_file.as_posix(), - path_in_repo=path_in_repo.as_posix(), - repo_id=cache_repo_id, - repo_type="model", - ) - except HfHubHTTPError as e: - if fail_when_could_not_push: - raise e - msg = could_not_push_message.format(cache_repo_id=cache_repo_id, error=e) - msg = re.sub(_HF_HUB_HTTP_ERROR_REQUEST_ID_PATTERN, "", msg) - if is_main_worker(): - warn_once(logger, msg) - success = False - - # Adding the model to the registry if the upload was successful. - if success: - try: - add_in_registry(cache_repo_id, neuron_hash) - except HfHubHTTPError: - pass - - return CachedModelOnTheHub(cache_repo_id, path_in_repo) diff --git a/optimum/neuron/utils/hub_neuronx_cache.py b/optimum/neuron/utils/hub_neuronx_cache.py index c578cca4d..4ea89f490 100644 --- a/optimum/neuron/utils/hub_neuronx_cache.py +++ b/optimum/neuron/utils/hub_neuronx_cache.py @@ -18,7 +18,7 @@ import logging import os import shutil -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from enum import Enum from pathlib import Path from tempfile import TemporaryDirectory @@ -28,10 +28,10 @@ from transformers import AutoConfig, PretrainedConfig from ..version import __version__ -from .cache_utils import get_neuron_cache_path, load_custom_cache_repo_name_from_hf_home +from .cache_utils import get_hf_hub_cache_repo, get_neuron_cache_path from .import_utils import is_neuronx_available from .patching import patch_everywhere -from .require_utils import requires_torch_neuronx, requires_torch_xla +from .require_utils import requires_torch_neuronx if is_neuronx_available(): @@ -78,6 +78,8 @@ def create_compile_cache(): ] NEURON_CONFIG_WHITE_LIST = ["input_names", "output_names", "model_type"] +DEFAULT_PATH_FOR_NEURON_CC_WRAPPER = Path(__file__).parent.as_posix() + class CompileCacheHfProxy(CompileCache): """A HuggingFace Hub proxy cache implementing the CompileCache API. @@ -238,15 +240,6 @@ def download_file_to_string(self, filename: str, limit: int = None): return s -def get_hub_cache(): - HUB_CACHE = "aws-neuron/optimum-neuron-cache" - custom_hub_cache = load_custom_cache_repo_name_from_hf_home() - if custom_hub_cache is not None and len(custom_hub_cache) > 0: - return custom_hub_cache - else: - return os.getenv("CUSTOM_CACHE_REPO", HUB_CACHE) - - def create_hub_compile_cache_proxy( cache_url: Optional[CacheUrl] = None, cache_repo_id: Optional[str] = None, @@ -254,7 +247,7 @@ def create_hub_compile_cache_proxy( if cache_url is None: cache_url = CacheUrl.get_cache_url() if cache_repo_id is None: - cache_repo_id = get_hub_cache() + cache_repo_id = get_hf_hub_cache_repo() default_cache = CompileCacheS3(cache_url) if cache_url.is_s3() else CompileCacheFs(cache_url) # Reevaluate endpoint and token (needed for tests altering the environment) endpoint = os.getenv("HF_ENDPOINT") @@ -366,21 +359,23 @@ def hf_create_compile_cache(cache_url): patch_everywhere("create_compile_cache", create_compile_cache, "libneuronxla") -@requires_torch_neuronx -@requires_torch_xla @contextmanager -def patch_neuron_cc_wrapper(): +def patch_neuron_cc_wrapper( + directory: Optional[Union[str, Path]] = DEFAULT_PATH_FOR_NEURON_CC_WRAPPER, restore_path: bool = True +): """ Patches the `neuron_cc_wrapper` file to force it use our own version of it which essentially makes sure that it uses our caching system. """ + context_manager = TemporaryDirectory() if directory is None else nullcontext(enter_result=directory) tmpdirname = "" try: - with TemporaryDirectory() as dirname: + with context_manager as dirname: tmpdirname = dirname src = Path(__file__).parent / "neuron_cc_wrapper" dst = Path(tmpdirname) / "neuron_cc_wrapper" - shutil.copy(src, dst) + if src != dst: + shutil.copy(src, dst) path = os.environ["PATH"] os.environ["PATH"] = f"{tmpdirname}:{path}" @@ -389,7 +384,8 @@ def patch_neuron_cc_wrapper(): except Exception as e: raise e finally: - os.environ["PATH"] = os.environ["PATH"].replace(f"{tmpdirname}:", "") + if restore_path: + os.environ["PATH"] = os.environ["PATH"].replace(f"{tmpdirname}:", "") @requires_torch_neuronx @@ -418,7 +414,7 @@ def get_hub_cached_entries( model_id: str, mode: Union[Literal["training"], Literal["inference"], Mode], cache_repo_id: Optional[str] = None ): if cache_repo_id is None: - cache_repo_id = get_hub_cache() + cache_repo_id = get_hf_hub_cache_repo() # Allocate a Hub API with refreshed information (required for tests altering the env) endpoint = os.getenv("HF_ENDPOINT") token = get_token() diff --git a/optimum/neuron/utils/torch_xla_and_neuronx_initialization.py b/optimum/neuron/utils/torch_xla_and_neuronx_initialization.py index ea0a34660..8100d5421 100644 --- a/optimum/neuron/utils/torch_xla_and_neuronx_initialization.py +++ b/optimum/neuron/utils/torch_xla_and_neuronx_initialization.py @@ -21,6 +21,7 @@ import torch from ...utils import logging +from .hub_neuronx_cache import patch_neuron_cc_wrapper from .misc import is_main_worker from .require_utils import requires_torch_xla @@ -42,7 +43,7 @@ def init_process_group(): raise AssertionError("Failed to initialize torch.distributed process group using XLA backend.") -def set_common_neuron_cc_flags(): +def set_common_flags(): """ Sets environment variables for transformer-based models training with AWS Neuron. """ @@ -52,6 +53,8 @@ def set_common_neuron_cc_flags(): # checkpointing. More information here: # https://awsdocs-neuron.readthedocs-hosted.com/en/latest/release-notes/torch/torch-neuronx/index.html#memory-leaking-in-glibc os.environ["MALLOC_ARENA_MAX"] = "64" + # Setting the path to use our patched version of the `neuron_cc_wrapper`. + patch_neuron_cc_wrapper(restore_path=False).__enter__() def set_neuron_cc_flags_for_torch_amp(): diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py index 0c83e97ff..9e7d45370 100644 --- a/tests/test_cache_utils.py +++ b/tests/test_cache_utils.py @@ -14,50 +14,32 @@ # limitations under the License. """Tests for the cache utilities.""" -import json import logging import os import random -from dataclasses import FrozenInstanceError from pathlib import Path from tempfile import TemporaryDirectory from typing import List from unittest import TestCase import huggingface_hub -import pytest -import torch -from huggingface_hub import HfApi, create_repo, delete_repo, get_token, hf_hub_download, login -from transformers import BertConfig, BertModel, set_seed -from transformers.testing_utils import TOKEN as TRANSFORMERS_TOKEN -from transformers.testing_utils import USER as TRANSFORMERS_USER +from huggingface_hub import create_repo, delete_repo, get_token, login from transformers.testing_utils import is_staging_test from optimum.neuron.utils.cache_utils import ( CACHE_REPO_FILENAME, - REGISTRY_FILENAME, - NeuronHash, - _list_in_registry_dict, - add_in_registry, - create_registry_file_if_does_not_exist, - download_cached_model_from_hub, - get_cached_model_on_the_hub, get_neuron_cache_path, get_num_neuron_cores_used, has_write_access_to_repo, list_files_in_neuron_cache, - list_in_registry, load_custom_cache_repo_name_from_hf_home, - path_after_folder, - push_to_cache_on_hub, - remove_ip_adress_from_path, set_custom_cache_repo_name_in_hf_home, set_neuron_cache_path, ) from optimum.neuron.utils.testing_utils import is_trainium_test from optimum.utils.testing_utils import TOKEN, USER -from .utils import MyTinyModel, StagingTestMixin, TrainiumTestMixin, get_random_string +from .utils import StagingTestMixin, TrainiumTestMixin, get_random_string DUMMY_COMPILER_VERSION = "1.2.3" @@ -116,7 +98,7 @@ def test_get_num_neuron_cores_used(self): self.assertEqual(get_num_neuron_cores_used(), 1) randon_num_cores = random.randint(1, 32) - os.environ["LOCAL_WORLD_SIZE"] = str(randon_num_cores) + os.environ["WORLD_SIZE"] = str(randon_num_cores) self.assertEqual(get_num_neuron_cores_used(), randon_num_cores) def _create_random_neuron_cache( @@ -160,90 +142,6 @@ def test_list_files_in_neuron_cache(self): filenames = self._create_random_neuron_cache(Path(tmpdirname), return_only_relevant_files=True) self.assertSetEqual(set(filenames), set(list_files_in_neuron_cache(tmpdirname, only_relevant_files=True))) - def test_list_in_registry_dict(self): - registry = { - "2.1.0": { - "model_1": { - "model_name_or_path": "model_1", - "model_hash": "my model hash", - "features": [ - { - "input_shapes": [["x", [1, 2]], ["y", [2, 3]]], - "precision": "torch.float32", - "num_neuron_cores": 16, - "neuron_hash": "neuron hash 1", - }, - { - "input_shapes": [["x", [3, 2]], ["y", [7, 3]]], - "precision": "torch.float32", - "num_neuron_cores": 8, - "neuron_hash": "neuron hash 2", - }, - ], - }, - "model_2": { - "model_name_or_path": "null", - "model_hash": "my model hash 2", - "features": [ - { - "input_shapes": [["x", [1, 2]], ["y", [2, 3]]], - "precision": "torch.float16", - "num_neuron_cores": 16, - "neuron_hash": "neuron hash 3", - }, - { - "input_shapes": [["x", [3, 2]], ["y", [7, 3]]], - "precision": "torch.float32", - "num_neuron_cores": 8, - "neuron_hash": "neuron hash 4", - }, - ], - }, - }, - "2.5.0": { - "model_1": { - "model_name_or_path": "model_1", - "model_hash": "my model hash", - "features": [ - { - "input_shapes": [["x", [1, 2]], ["y", [2, 3]]], - "precision": "torch.float32", - "num_neuron_cores": 16, - "neuron_hash": "neuron hash 5", - }, - { - "input_shapes": [["x", [3, 2]], ["y", [7, 3]]], - "precision": "torch.float32", - "num_neuron_cores": 8, - "neuron_hash": "neuron hash 6", - }, - ], - }, - }, - } - - result = _list_in_registry_dict(registry) - self.assertEqual(len(result), 6) - self.assertTrue(result[-1].startswith("Model name:\tmodel_1")) - - result = _list_in_registry_dict(registry, model_name_or_path_or_hash="model_1") - self.assertEqual(len(result), 4) - self.assertTrue(result[0].startswith("Model name:\tmodel_1")) - - result = _list_in_registry_dict(registry, model_name_or_path_or_hash="my model hash 2") - self.assertEqual(len(result), 2) - self.assertTrue(result[0].startswith("Model name:\tnull")) - - result = _list_in_registry_dict(registry, neuron_compiler_version="2.5.0") - self.assertEqual(len(result), 2) - self.assertTrue(result[0].startswith("Model name:\tmodel_1")) - - result = _list_in_registry_dict(registry, model_name_or_path_or_hash="random bad string") - self.assertEqual(len(result), 0) - - result = _list_in_registry_dict(registry, neuron_compiler_version="-1.2") - self.assertEqual(len(result), 0) - @is_staging_test class StagingNeuronUtilsTestCase(StagingTestMixin, TestCase): @@ -295,512 +193,3 @@ def test_has_write_access_to_repo(self): self.assertTrue(has_write_access_to_repo(self.CUSTOM_CACHE_REPO)) self.assertTrue(has_write_access_to_repo(self.CUSTOM_PRIVATE_CACHE_REPO)) - - @is_trainium_test - def test_list_in_registry(self): - def _test_list_in_registry(use_private_cache_repo: bool): - if use_private_cache_repo: - cache_repo = self.CUSTOM_PRIVATE_CACHE_REPO - else: - cache_repo = self.CUSTOM_CACHE_REPO - create_registry_file_if_does_not_exist(cache_repo) - entries = list_in_registry(cache_repo) - self.assertEqual(len(entries), 0) - - bert_model = BertModel(BertConfig()) - neuron_hash = NeuronHash( - bert_model, - (("x", (4, 12)), ("y", (4, 12))), - torch.float32, - 2, - neuron_compiler_version=DUMMY_COMPILER_VERSION, - ) - add_in_registry(cache_repo, neuron_hash) - entries = list_in_registry(cache_repo) - self.assertEqual(len(entries), 1) - - bert_model = BertModel(BertConfig()) - neuron_hash = NeuronHash( - bert_model, - (("x", (4, 8)), ("y", (4, 12))), - torch.float32, - 2, - neuron_compiler_version=DUMMY_COMPILER_VERSION, - ) - add_in_registry(cache_repo, neuron_hash) - entries = list_in_registry(cache_repo) - self.assertEqual(len(entries), 2) - - model_hash = neuron_hash.compute_hash()[0] - entries = list_in_registry(cache_repo, model_name_or_path_or_hash=model_hash) - self.assertEqual(len(entries), 1) - - entries = list_in_registry(cache_repo, model_name_or_path_or_hash="dummy hash") - self.assertEqual(len(entries), 0) - - entries = list_in_registry(cache_repo, neuron_compiler_version=DUMMY_COMPILER_VERSION) - self.assertEqual(len(entries), 2) - - entries = list_in_registry(cache_repo, neuron_compiler_version="Bad version") - self.assertEqual(len(entries), 0) - - _test_list_in_registry(False) - _test_list_in_registry(True) - - -@is_trainium_test -class NeuronHashTestCase(TestCase): - def test_neuron_hash_is_not_mutable(self): - bert_model = BertModel(BertConfig()) - neuron_hash = NeuronHash( - bert_model, - (("x", (4, 12)), ("y", (4, 12))), - torch.float32, - 2, - neuron_compiler_version=DUMMY_COMPILER_VERSION, - ) - - with self.assertRaises(FrozenInstanceError): - neuron_hash.model = bert_model - - with self.assertRaises(FrozenInstanceError): - neuron_hash.input_shapes = (("x", (2, 32)), ("y", (2, 32))) - - with self.assertRaises(FrozenInstanceError): - neuron_hash.num_neuron_cores = 32 - - def _test_neuron_hash( - self, - model_a, - input_shapes_a, - dtype_a, - num_neuron_cores_a, - model_b, - input_shapes_b, - dtype_b, - num_neuron_cores_b, - should_be_equal, - ): - neuron_hash_a = NeuronHash( - model_a, - input_shapes_a, - dtype_a, - num_neuron_cores=num_neuron_cores_a, - neuron_compiler_version=DUMMY_COMPILER_VERSION, - ) - neuron_hash_b = NeuronHash( - model_b, - input_shapes_b, - dtype_b, - num_neuron_cores=num_neuron_cores_b, - neuron_compiler_version=DUMMY_COMPILER_VERSION, - ) - if should_be_equal: - self.assertEqual(neuron_hash_a.compute_hash(), neuron_hash_b.compute_hash()) - else: - self.assertNotEqual(neuron_hash_a.compute_hash(), neuron_hash_b.compute_hash()) - - def test_computed_hash_is_same_for_same_models(self): - set_seed(42) - bert_model = BertModel(BertConfig()) - set_seed(42) - same_bert_model = BertModel(BertConfig()) - - return self._test_neuron_hash( - bert_model, - ((1, 2), (2, 3)), - torch.bfloat16, - 19, - same_bert_model, - ((1, 2), (2, 3)), - torch.bfloat16, - 19, - True, - ) - - def test_computed_hash_is_different_for_different_models(self): - set_seed(42) - bert_model = BertModel(BertConfig()) - set_seed(38) - different_bert_model = BertModel(BertConfig()) - - return self._test_neuron_hash( - bert_model, - ((1, 2), (2, 3)), - torch.bfloat16, - 19, - different_bert_model, - ((1, 2), (2, 3)), - torch.bfloat16, - 19, - False, - ) - - def test_computed_hash_is_different_for_different_parameters_but_same_model(self): - bert_model = BertModel(BertConfig()) - parameters = [[((1, 2), (2, 3)), ((2, 3), (3, 4))], [torch.float32, torch.float16], [32, 2]] - params_a = [p[0] for p in parameters] - for i in range(len(parameters)): - params_b = [p[int(i == j)] for j, p in enumerate(parameters)] - self._test_neuron_hash(bert_model, *params_a, bert_model, *params_b, False) - - def test_neuron_hash_folders(self): - bert_model = BertModel(BertConfig()) - input_shapes = (("x", (1, 2)), ("y", (2, 3))) - data_type = torch.float32 - num_neuron_cores = 32 - - neuron_hash = NeuronHash( - bert_model, - input_shapes, - data_type, - num_neuron_cores=num_neuron_cores, - neuron_compiler_version=DUMMY_COMPILER_VERSION, - ) - hashes = neuron_hash.compute_hash() - expected_folders = [DUMMY_COMPILER_VERSION, "bert"] + list(hashes) - self.assertListEqual(neuron_hash.folders, expected_folders) - - def test_neuron_hash_is_private(self): - input_shapes = (("x", (1, 2)), ("y", (2, 3))) - data_type = torch.float32 - - bert_model = BertModel(BertConfig()) - neuron_hash = NeuronHash(bert_model, input_shapes, data_type, neuron_compiler_version=DUMMY_COMPILER_VERSION) - self.assertTrue(neuron_hash.is_private) - - bert_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert") - neuron_hash = NeuronHash(bert_model, input_shapes, data_type, neuron_compiler_version=DUMMY_COMPILER_VERSION) - self.assertFalse(neuron_hash.is_private) - - with TemporaryDirectory() as tmpdirname: - bert_model.save_pretrained(tmpdirname) - local_bert_model = BertModel.from_pretrained(tmpdirname) - neuron_hash = NeuronHash( - local_bert_model, input_shapes, data_type, neuron_compiler_version=DUMMY_COMPILER_VERSION - ) - self.assertTrue(neuron_hash.is_private) - - -@is_trainium_test -@is_staging_test -@pytest.mark.skip("This is not needed anymore and will be removed.") -class CachedModelOnTheHubTestCase(StagingTestMixin, TestCase): - def test_push_to_hub_fails_with_private_model_and_public_repo(self): - with TemporaryDirectory() as tmpdirname: - set_neuron_cache_path(tmpdirname) - - input_shapes = (("x", (1,)),) - data_type = torch.float32 - tiny_model = self.create_and_run_tiny_pretrained_model(random_num_linears=True) - neuron_hash = NeuronHash(tiny_model, input_shapes, data_type) - - cached_files = list_files_in_neuron_cache(tmpdirname) - - # The model being loaded locally is assumed to be private, push to hub should prevent from pushing to a - # public repo. - with self.assertRaisesRegex(ValueError, "Could not push the cached model"): - push_to_cache_on_hub( - neuron_hash, cached_files[0], self.CUSTOM_CACHE_REPO, fail_when_could_not_push=True - ) - - # It should work when using a private repo. - cached_model_on_the_hub = push_to_cache_on_hub( - neuron_hash, cached_files[0], self.CUSTOM_PRIVATE_CACHE_REPO - ) - self.assertIsNotNone(cached_model_on_the_hub) - - def test_push_to_hub_without_specifying_a_cache_repo_id(self): - with TemporaryDirectory() as tmpdirname: - set_neuron_cache_path(tmpdirname) - - input_shapes = (("x", (1,)),) - data_type = torch.float32 - tiny_model = self.create_and_run_tiny_pretrained_model(random_num_linears=True) - neuron_hash = NeuronHash(tiny_model, input_shapes, data_type) - - cached_files = list_files_in_neuron_cache(tmpdirname) - - set_custom_cache_repo_name_in_hf_home(self.CUSTOM_PRIVATE_CACHE_REPO) - push_to_cache_on_hub(neuron_hash, cached_files[0]) - - def test_push_to_hub_overwrite_existing(self): - with TemporaryDirectory() as tmpdirname: - set_neuron_cache_path(tmpdirname) - - input_shapes = (("x", (1,)),) - data_type = torch.float32 - tiny_model = self.create_and_run_tiny_pretrained_model(random_num_linears=True) - neuron_hash = NeuronHash(tiny_model, input_shapes, data_type) - - cache_dir = Path(tmpdirname) - cached_files = list_files_in_neuron_cache(cache_dir) - - push_to_cache_on_hub(neuron_hash, cached_files[0], self.CUSTOM_PRIVATE_CACHE_REPO) - - # With a file - with self.assertLogs("optimum", level="INFO") as cm: - push_to_cache_on_hub(neuron_hash, cached_files[0], self.CUSTOM_PRIVATE_CACHE_REPO) - self.assertIn("Did not push the cached model located at", cm.output[0]) - - with self.assertLogs("optimum", level="WARNING") as cm: - push_to_cache_on_hub( - neuron_hash, cached_files[0], self.CUSTOM_PRIVATE_CACHE_REPO, overwrite_existing=True - ) - self.assertIn( - "Overwriting the already existing cached model on the Hub by the one located at", cm.output[0] - ) - - # With a directory - with self.assertLogs("optimum", level="INFO") as cm: - push_to_cache_on_hub(neuron_hash, cache_dir, self.CUSTOM_PRIVATE_CACHE_REPO) - self.assertIn("Did not push the cached model located at", cm.output[0]) - - with self.assertLogs("optimum", level="WARNING") as cm: - push_to_cache_on_hub(neuron_hash, cache_dir, self.CUSTOM_PRIVATE_CACHE_REPO, overwrite_existing=True) - self.assertIn( - "Overwriting the already existing cached model on the Hub by the one located at", cm.output[0] - ) - - def test_push_to_hub_local_path_in_repo(self): - with TemporaryDirectory() as tmpdirname: - set_neuron_cache_path(tmpdirname) - - input_shapes = (("x", (1,)),) - data_type = torch.float32 - tiny_model = self.create_and_run_tiny_pretrained_model(random_num_linears=True) - neuron_hash = NeuronHash(tiny_model, input_shapes, data_type) - - cache_dir = Path(tmpdirname) - cached_files = list_files_in_neuron_cache(cache_dir) - - def local_path_to_path_in_repo(path): - return Path("my/awesome/new/path") / path.name - - cached_file = cached_files[0] - - # With a file - push_to_cache_on_hub( - neuron_hash, - cached_file, - self.CUSTOM_PRIVATE_CACHE_REPO, - local_path_to_path_in_repo=local_path_to_path_in_repo, - ) - files_in_repo = HfApi().list_repo_files(repo_id=self.CUSTOM_PRIVATE_CACHE_REPO) - anonymous_cached_file = remove_ip_adress_from_path(cached_file) - path_in_repo = f"{neuron_hash.cache_path}/my/awesome/new/path/{anonymous_cached_file.name}" - self.assertIn(path_in_repo, files_in_repo) - - def another_local_path_to_path_in_repo(path): - return Path("my/another/awesome/new/path") / path.name - - # With a directory - push_to_cache_on_hub( - neuron_hash, - cache_dir, - self.CUSTOM_PRIVATE_CACHE_REPO, - local_path_to_path_in_repo=another_local_path_to_path_in_repo, - ) - files_in_repo = HfApi().list_repo_files(repo_id=self.CUSTOM_PRIVATE_CACHE_REPO) - for filename in cache_dir.glob("**/*"): - if filename.is_file(): - path_in_cache_dir = path_after_folder(filename, cache_dir, include_folder=True) - anonymous_path_in_cache_dir = remove_ip_adress_from_path(path_in_cache_dir) - path_in_repo = ( - f"{neuron_hash.cache_path}/my/another/awesome/new/path/{anonymous_path_in_cache_dir}" - ) - self.assertIn(path_in_repo, files_in_repo) - - def test_push_to_hub_without_writing_rights(self): - with TemporaryDirectory() as tmpdirname: - import torch_xla.core.xla_model as xm - - set_neuron_cache_path(tmpdirname) - - input_shapes = (("x", (1,)),) - data_type = torch.float32 - tiny_model = self.create_and_run_tiny_pretrained_model(random_num_linears=True) - tiny_model.push_to_hub(f"tiny-public-model-{self.seed}") - public_tiny_model = MyTinyModel.from_pretrained(f"{USER}/tiny-public-model-{self.seed}") - neuron_hash = NeuronHash(public_tiny_model, input_shapes, data_type) - - public_tiny_model = public_tiny_model.to("xla") - input_ = torch.rand((32, 1)).to("xla") - public_tiny_model(input_) - xm.mark_step() - - # This should work because we do have writing access to this repo. - set_custom_cache_repo_name_in_hf_home(self.CUSTOM_CACHE_REPO) - push_to_cache_on_hub(neuron_hash, get_neuron_cache_path()) - - # Creating a repo under the Transformers user. - orig_token = self.set_hf_hub_token(TRANSFORMERS_TOKEN) - repo_name = f"optimum-neuron-cache-{self.seed}" - create_repo(repo_name, repo_type="model", exist_ok=True) - self.set_hf_hub_token(orig_token) - - set_custom_cache_repo_name_in_hf_home(f"{TRANSFORMERS_USER}/{repo_name}") - with self.assertLogs("optimum", "WARNING") as cm: - push_to_cache_on_hub(neuron_hash, get_neuron_cache_path()) - self.assertTrue(any("Could not push the cached model to" in output for output in cm.output)) - - self.set_hf_hub_token(TRANSFORMERS_TOKEN) - delete_repo(repo_name, repo_type="model") - self.set_hf_hub_token(orig_token) - - def _test_push_to_hub_create_and_add_registry(self, with_model_name_or_path: bool): - with TemporaryDirectory() as tmpdirname: - set_neuron_cache_path(tmpdirname) - - input_shapes = (("x", (1,)),) - data_type = torch.float32 - data_type = torch.float32 - tiny_model = self.create_and_run_tiny_pretrained_model(random_num_linears=True) - model_name = f"dummy_model-{self.seed}" - if with_model_name_or_path: - tiny_model.push_to_hub(model_name) - model_name = f"{USER}/{model_name}" - tiny_model.config._model_name_or_path = model_name - neuron_hash = NeuronHash(tiny_model, input_shapes, data_type) - - set_custom_cache_repo_name_in_hf_home(self.CUSTOM_PRIVATE_CACHE_REPO) - files_in_repo = HfApi().list_repo_files(repo_id=self.CUSTOM_PRIVATE_CACHE_REPO) - files_in_repo = [filename for filename in files_in_repo if not filename.startswith(".")] - self.assertListEqual(files_in_repo, [], "Repo should be empty") - - cached_files = list_files_in_neuron_cache(tmpdirname) - push_to_cache_on_hub(neuron_hash, cached_files[0]) - files_in_repo = HfApi().list_repo_files(repo_id=self.CUSTOM_PRIVATE_CACHE_REPO) - - self.assertIn(REGISTRY_FILENAME, files_in_repo) - hf_hub_download( - self.CUSTOM_PRIVATE_CACHE_REPO, - REGISTRY_FILENAME, - force_download=True, - local_dir=tmpdirname, - local_dir_use_symlinks=False, - ) - with open(Path(tmpdirname) / REGISTRY_FILENAME, "r") as fp: - registry = json.load(fp) - - neuron_compiler_version = list(registry.keys())[0] - model_key = list(registry[neuron_compiler_version].keys())[0] - expected_value = model_name if with_model_name_or_path else neuron_hash.compute_hash()[0] - self.assertEqual(model_key, expected_value) - - def test_push_to_hub_create_and_add_registry_without_model_name_or_path(self): - return self._test_push_to_hub_create_and_add_registry(False) - - def test_push_to_hub_create_and_add_registry_with_model_name_or_path(self): - return self._test_push_to_hub_create_and_add_registry(True) - - def test_download_cached_model_from_hub(self): - set_custom_cache_repo_name_in_hf_home(self.CUSTOM_PRIVATE_CACHE_REPO) - neuron_hash = self.push_tiny_pretrained_model_cache_to_hub(self.CUSTOM_PRIVATE_CACHE_REPO) - - neuron_cc_flags = os.environ["NEURON_CC_FLAGS"] - - with self.assertRaisesRegex( - ValueError, "A target directory must be specified when no caching directory is used" - ): - os.environ["NEURON_CC_FLAGS"] = "--no-cache" - self.assertTrue(download_cached_model_from_hub(neuron_hash)) - - os.environ["NEURON_CC_FLAGS"] = neuron_cc_flags - self.assertTrue(download_cached_model_from_hub(neuron_hash)) - - def test_download_cached_model_from_hub_with_target_directory(self): - set_custom_cache_repo_name_in_hf_home(self.CUSTOM_PRIVATE_CACHE_REPO) - neuron_hash = self.push_tiny_pretrained_model_cache_to_hub(self.CUSTOM_PRIVATE_CACHE_REPO) - - cached_model_on_the_hub = get_cached_model_on_the_hub(neuron_hash) - if cached_model_on_the_hub is None: - self.fail("Could not find the model on the Hub, but it should be there.") - - repo_files = set(cached_model_on_the_hub.files_on_the_hub) - - if len(repo_files) == 0: - self.fail("Could not find any file in the Hub.") - - # With a target directory specified as a string. - with TemporaryDirectory() as tmpdirname: - success = download_cached_model_from_hub(neuron_hash, target_directory=tmpdirname) - self.assertTrue(success) - - tmpdir = Path(tmpdirname) - target_directory_files = {str(path_after_folder(f, tmpdir)) for f in tmpdir.glob("**/*") if f.is_file()} - self.assertSetEqual(target_directory_files, repo_files) - - # With a target directory specified as a Path. - with TemporaryDirectory() as tmpdirname: - tmpdir = Path(tmpdirname) - success = download_cached_model_from_hub(neuron_hash, target_directory=tmpdir) - self.assertTrue(success) - - target_directory_files = {str(path_after_folder(f, tmpdir)) for f in tmpdir.glob("**/*") if f.is_file()} - self.assertSetEqual(target_directory_files, repo_files) - - def test_download_cached_model_from_hub_with_path_in_repo_to_path_in_target_directory(self): - set_custom_cache_repo_name_in_hf_home(self.CUSTOM_PRIVATE_CACHE_REPO) - neuron_hash = self.push_tiny_pretrained_model_cache_to_hub(self.CUSTOM_PRIVATE_CACHE_REPO) - - cached_model_on_the_hub = get_cached_model_on_the_hub(neuron_hash) - if cached_model_on_the_hub is None: - self.fail("Could not find the model on the Hub, but it should be there.") - - def path_in_repo_to_path_in_target_directory(path): - return Path("custom_folder") / path.name - - repo_files = { - path_in_repo_to_path_in_target_directory(Path(f)) for f in cached_model_on_the_hub.files_on_the_hub - } - - if len(repo_files) == 0: - self.fail("Could not find any file in the Hub.") - - # With a target directory specified as a string. - with TemporaryDirectory() as tmpdirname: - success = download_cached_model_from_hub( - neuron_hash, - target_directory=tmpdirname, - path_in_repo_to_path_in_target_directory=path_in_repo_to_path_in_target_directory, - ) - self.assertTrue(success) - - tmpdir = Path(tmpdirname) - target_directory_files = {Path("custom_folder") / f.name for f in tmpdir.glob("**/*") if f.is_file()} - self.assertSetEqual(target_directory_files, repo_files) - - # Check the the original download directories do not exist since we specified a - # path_in_repo_to_path_in_target_directory function. - # self.assertListEqual([f.name for f in tmpdir.iterdir()], ["custom_folder"]) - - # TODO: not passing yet, to fix ASAP. - # def test_download_cached_model_from_hub_needs_to_download(self): - # os.environ["CUSTOM_CACHE_REPO"] = self.CUSTOM_PRIVATE_CACHE_REPO - - # with TemporaryDirectory() as tmpdirname: - # neuron_hash = self._push_tiny_pretrained_model_cache_to_hub(self.CUSTOM_PRIVATE_CACHE_REPO, cache_dir=tmpdirname) - - # with patch("huggingface_hub.snapshot_download") as mock_snapshot_download: - # # All the files are already there, should not download anything. - # download_cached_model_from_hub(neuron_hash, target_directory=tmpdirname) - # self.assertFalse(mock_snapshot_download.called, "No downloading should be peformed since all the files are already in the cache.") - # mock_snapshot_download.reset_mock() - # - # # All the files but one are there, should trigger downloading. - # for path in Path(tmpdirname).glob("**/*"): - # if path.is_file(): - # if path.suffix in [".json", ".txt"]: - # continue - # path.unlink() - # break - - # download_cached_model_from_hub(neuron_hash, target_directory=tmpdirname) - # self.assertTrue(mock_snapshot_download.called, "Downloading should be peformed since one file is missing in the cache.") - # mock_snapshot_download.reset_mock() - - # # No file at all, should download. - # with TemporaryDirectory() as another_tmpdirname: - # download_cached_model_from_hub(neuron_hash, target_directory=another_tmpdirname) - # self.assertTrue(mock_snapshot_download.called, "Downloading should be peformed since no file is in the cache.") diff --git a/tests/test_trainer_callback.py b/tests/test_trainer_callback.py deleted file mode 100644 index 1bd9996dd..000000000 --- a/tests/test_trainer_callback.py +++ /dev/null @@ -1,210 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from pathlib import Path -from tempfile import TemporaryDirectory -from unittest import TestCase - -import pytest -import torch -from huggingface_hub import HfApi -from transformers.testing_utils import is_staging_test - -from optimum.neuron.trainer_callback import NeuronCacheCallback -from optimum.neuron.training_args import NeuronTrainingArguments -from optimum.neuron.utils.cache_utils import ( - NeuronHash, - list_files_in_neuron_cache, - push_to_cache_on_hub, - set_neuron_cache_path, -) -from optimum.neuron.utils.testing_utils import is_trainium_test - -from .utils import StagingTestMixin - - -@is_trainium_test -@is_staging_test -@pytest.mark.skip("Not used anymore, will be removed in cleaning PR.") -class NeuronCacheCallbackTestCase(StagingTestMixin, TestCase): - def test_neuron_hash_for_model(self): - with TemporaryDirectory() as tmpdirname: - args = NeuronTrainingArguments(tmpdirname) - model = self.create_tiny_pretrained_model(random_num_linears=True) - inputs = { - "x": torch.rand((1,)), - } - - callback = NeuronCacheCallback() - - # We first check that no hashes is in the hash cache already. - self.assertFalse(callback.neuron_hashes) - - callback.neuron_hash_for_model(args, model, inputs) - neuron_hash = callback.neuron_hashes[(model, (("x", tuple(inputs["x"].shape)),), torch.float32, 1)] - - same_neuron_hash = callback.neuron_hash_for_model(args, model, inputs) - - self.assertEqual(neuron_hash, same_neuron_hash, "Neuron hashes should be equal") - self.assertEqual(len(callback.neuron_hashes.keys()), 1, "There should be only one entry in neuron_hashes.") - - def test_try_to_fetch_cached_model(self): - import torch_xla.core.xla_model as xm - - os.environ["CUSTOM_CACHE_REPO"] = self.CUSTOM_PRIVATE_CACHE_REPO - model = self.create_tiny_pretrained_model(random_num_linears=True).to("xla") - - with TemporaryDirectory() as tmpdirname: - set_neuron_cache_path(tmpdirname) - args = NeuronTrainingArguments(tmpdirname) - inputs = {"x": torch.rand((8, 1)).to("xla")} - output = model(**inputs) - xm.mark_step() - print(output) - neuron_hash = NeuronHash(model, (("x", (8, 1)),), torch.float32) - push_to_cache_on_hub(neuron_hash, Path(tmpdirname) / neuron_hash.neuron_compiler_version_dir_name) - - with TemporaryDirectory() as tmpdirname: - set_neuron_cache_path(tmpdirname) - callback = NeuronCacheCallback() - args = NeuronTrainingArguments(tmpdirname) - inputs = {"x": torch.rand((24, 1))} - neuron_hash = callback.neuron_hash_for_model(args, model, inputs) - - found_in_cache = callback.try_to_fetch_cached_model(neuron_hash) - self.assertFalse(found_in_cache, "No model should have been fetched.") - - inputs = {"x": torch.rand((8, 1))} - neuron_hash = callback.neuron_hash_for_model(args, model, inputs) - - files_before_fetching = list_files_in_neuron_cache( - callback.tmp_neuron_cache_path, only_relevant_files=True - ) - tmp_neuron_cache_state = list(callback.tmp_neuron_cache_state) - neuron_cache_state = list_files_in_neuron_cache(Path(tmpdirname), only_relevant_files=True) - - found_in_cache = callback.try_to_fetch_cached_model(neuron_hash) - self.assertTrue(found_in_cache, "A model should have been fetched.") - - files_after_fetching = list_files_in_neuron_cache(callback.tmp_neuron_cache_path, only_relevant_files=True) - new_tmp_neuron_cache_state = list(callback.tmp_neuron_cache_state) - new_neuron_cache_state = list_files_in_neuron_cache(Path(tmpdirname), only_relevant_files=True) - - files_diff = [f for f in files_after_fetching if f not in files_before_fetching] - state_diff = [f for f in new_tmp_neuron_cache_state if f not in tmp_neuron_cache_state] - neuron_cache_files_diff = [f for f in new_neuron_cache_state if f not in neuron_cache_state] - - self.assertNotEqual(files_diff, []) - self.assertListEqual(files_diff, state_diff) - self.assertEqual(len(files_diff), len(neuron_cache_files_diff)) - - def test_synchronize_temporary_neuron_cache_state(self): - import torch_xla.core.xla_model as xm - - with TemporaryDirectory() as tmpdirname: - set_neuron_cache_path(tmpdirname) - callback = NeuronCacheCallback() - - diff = callback.synchronize_temporary_neuron_cache_state() - self.assertListEqual(diff, [], "The diff should be empty.") - - model = self.create_tiny_pretrained_model(random_num_linears=True).to("xla") - inputs = {"x": torch.rand((8, 1)).to("xla")} - output = model(**inputs) - xm.mark_step() - print(output) - diff = callback.synchronize_temporary_neuron_cache_state() - self.assertNotEqual(diff, [], "The diff should not be empty.") - - diff = callback.synchronize_temporary_neuron_cache_state() - self.assertListEqual( - diff, [], "The diff should be empty because nothing happened since last synchronization" - ) - - def test_synchronize_temporary_neuron_cache(self): - import torch_xla.core.xla_model as xm - - os.environ["CUSTOM_CACHE_REPO"] = self.CUSTOM_PRIVATE_CACHE_REPO - model = self.create_tiny_pretrained_model(random_num_linears=True).to("xla") - - with TemporaryDirectory() as tmpdirname: - set_neuron_cache_path(tmpdirname) - args = NeuronTrainingArguments(tmpdirname) - callback = NeuronCacheCallback() - - callback.synchronize_temporary_neuron_cache() - files_in_repo = HfApi().list_repo_files(repo_id=self.CUSTOM_PRIVATE_CACHE_REPO) - files_in_repo = [f for f in files_in_repo if not f.startswith(".")] - files_in_cache = list_files_in_neuron_cache(callback.neuron_cache_path, only_relevant_files=True) - self.assertListEqual(files_in_repo, [], "Repo should be empty.") - self.assertListEqual(files_in_cache, [], "Cache should be empty.") - - # Running some compilation. - for _ in range(3): - inputs = {"x": torch.rand((8, 1)).to("xla")} - output = model(**inputs) - xm.mark_step() - - xm.mark_step() - print(output) - - neuron_hash = callback.neuron_hash_for_model(args, model, inputs) - diff = callback.synchronize_temporary_neuron_cache_state() - callback.neuron_hash_to_files[neuron_hash].extend(diff) - - callback.synchronize_temporary_neuron_cache() - - files_in_repo = HfApi().list_repo_files(repo_id=self.CUSTOM_PRIVATE_CACHE_REPO) - files_in_repo = [f for f in files_in_repo if not f.startswith(".")] - files_in_cache = list_files_in_neuron_cache(callback.neuron_cache_path, only_relevant_files=True) - self.assertNotEqual(files_in_repo, [], "Repo should not be empty.") - self.assertNotEqual(files_in_cache, [], "Cache should not be empty.") - - # Using the same inputs, nothing should be uploaded. - inputs = {"x": torch.rand((8, 1)).to("xla")} - output = model(**inputs) - xm.mark_step() - print(output) - - neuron_hash = callback.neuron_hash_for_model(args, model, inputs) - diff = callback.synchronize_temporary_neuron_cache_state() - callback.neuron_hash_to_files[neuron_hash].extend(diff) - - callback.synchronize_temporary_neuron_cache() - - new_files_in_repo = HfApi().list_repo_files(repo_id=self.CUSTOM_PRIVATE_CACHE_REPO) - new_files_in_repo = [f for f in new_files_in_repo if not f.startswith(".")] - new_files_in_cache = list_files_in_neuron_cache(callback.neuron_cache_path, only_relevant_files=True) - self.assertListEqual(files_in_repo, new_files_in_repo, "No new file should be in the Hub.") - self.assertListEqual(files_in_cache, new_files_in_cache, "No new file should be in the cache.") - - # New shape, should upload. - inputs = {"x": torch.rand((24, 1)).to("xla")} - output = model(**inputs) - xm.mark_step() - print(output) - - neuron_hash = callback.neuron_hash_for_model(args, model, inputs) - diff = callback.synchronize_temporary_neuron_cache_state() - callback.neuron_hash_to_files[neuron_hash].extend(diff) - - callback.synchronize_temporary_neuron_cache() - - files_in_repo = HfApi().list_repo_files(repo_id=self.CUSTOM_PRIVATE_CACHE_REPO) - files_in_repo = [f for f in files_in_repo if not f.startswith(".")] - files_in_cache = list_files_in_neuron_cache(callback.neuron_cache_path, only_relevant_files=True) - self.assertNotEqual(files_in_repo, new_files_in_repo, "New files should be in the Hub.") - self.assertNotEqual(files_in_cache, new_files_in_cache, "New files should be in the cache.") diff --git a/tests/test_trainers.py b/tests/test_trainers.py index 09a5e1671..d863e8db8 100644 --- a/tests/test_trainers.py +++ b/tests/test_trainers.py @@ -35,7 +35,6 @@ from optimum.neuron.utils.cache_utils import ( get_neuron_cache_path, list_files_in_neuron_cache, - remove_ip_adress_from_path, set_neuron_cache_path, ) from optimum.neuron.utils.testing_utils import is_trainium_test @@ -140,16 +139,15 @@ def test_train_and_eval(self): last_files_in_repo = HfApi().list_repo_files(repo_id=self.CUSTOM_PRIVATE_CACHE_REPO) last_files_in_repo = [f for f in last_files_in_repo if not f.startswith(".")] last_files_in_cache = list_files_in_neuron_cache(get_neuron_cache_path(), only_relevant_files=True) - last_files_in_cache = [remove_ip_adress_from_path(p) for p in last_files_in_cache] # TODO: investigate that, not urgent. - # self.assertListEqual( - # files_in_repo, last_files_in_repo, "No file should have been added to the Hub after first training." - # ) - # self.assertListEqual( - # files_in_cache, - # last_files_in_cache, - # "No file should have been added to the cache after first training.", - # ) + self.assertListEqual( + files_in_repo, last_files_in_repo, "No file should have been added to the Hub after first training." + ) + self.assertListEqual( + files_in_cache, + last_files_in_cache, + "No file should have been added to the cache after first training.", + ) self.assertTrue( second_training_duration < first_training_duration, @@ -295,16 +293,15 @@ def test_train_and_eval_multiple_workers(self): last_files_in_repo = HfApi().list_repo_files(repo_id=self.CUSTOM_PRIVATE_CACHE_REPO) last_files_in_repo = [f for f in last_files_in_repo if not f.startswith(".")] last_files_in_cache = list_files_in_neuron_cache(get_neuron_cache_path(), only_relevant_files=True) - last_files_in_cache = [remove_ip_adress_from_path(p) for p in last_files_in_cache] # TODO: investigate that, not urgent. - # self.assertListEqual( - # files_in_repo, last_files_in_repo, "No file should have been added to the Hub after first training." - # ) - # self.assertListEqual( - # files_in_cache, - # last_files_in_cache, - # "No file should have been added to the cache after first training.", - # ) + self.assertListEqual( + files_in_repo, last_files_in_repo, "No file should have been added to the Hub after first training." + ) + self.assertListEqual( + files_in_cache, + last_files_in_cache, + "No file should have been added to the cache after first training.", + ) self.assertTrue( second_training_duration < first_training_duration, diff --git a/tests/utils.py b/tests/utils.py index f4b584e8c..1d5a7387c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -16,11 +16,8 @@ import os import random -import shutil import string -from pathlib import Path -from tempfile import TemporaryDirectory -from typing import Dict, Optional, Set, Tuple, Union +from typing import Dict, Optional, Set, Tuple import torch from datasets import Dataset, DatasetDict @@ -30,15 +27,9 @@ from transformers.testing_utils import ENDPOINT_STAGING from optimum.neuron.utils.cache_utils import ( - _ADDED_IN_REGISTRY, - _REGISTRY_FILE_EXISTS, - NeuronHash, delete_custom_cache_repo_name_from_hf_home, load_custom_cache_repo_name_from_hf_home, - path_after_folder, - push_to_cache_on_hub, set_custom_cache_repo_name_in_hf_home, - set_neuron_cache_path, ) from optimum.utils import logging from optimum.utils.testing_utils import TOKEN, USER @@ -220,14 +211,6 @@ def tearDown(self): self.remove_all_files_in_repo(self.CUSTOM_CACHE_REPO) self.remove_all_files_in_repo(self.CUSTOM_PRIVATE_CACHE_REPO) - keys = list(_REGISTRY_FILE_EXISTS.keys()) - for key in keys: - _REGISTRY_FILE_EXISTS.pop(key) - - keys = list(_ADDED_IN_REGISTRY.keys()) - for key in keys: - _ADDED_IN_REGISTRY.pop(key) - def create_tiny_pretrained_model(self, num_linears: int = 1, random_num_linears: bool = False): return create_tiny_pretrained_model( num_linears=num_linears, @@ -241,39 +224,3 @@ def create_and_run_tiny_pretrained_model(self, num_linears: int = 1, random_num_ random_input = torch.rand(1, device="xla") print(tiny_model(random_input)) return tiny_model - - def push_tiny_pretrained_model_cache_to_hub( - self, repo_id: str, cache_dir: Optional[Union[str, Path]] = None - ) -> NeuronHash: - neuron_hash = None - orig_repo_id = load_custom_cache_repo_name_from_hf_home() - set_custom_cache_repo_name_in_hf_home(repo_id) - with TemporaryDirectory() as tmpdirname: - set_neuron_cache_path(tmpdirname) - - input_shapes = (("x", (1,)),) - data_type = torch.float32 - tiny_model = self.create_and_run_tiny_pretrained_model(random_num_linears=True) - neuron_hash = NeuronHash(tiny_model, input_shapes, data_type) - - tmp_cache_dir = Path(tmpdirname) / neuron_hash.neuron_compiler_version_dir_name - push_to_cache_on_hub( - neuron_hash, - tmp_cache_dir, - fail_when_could_not_push=True, - ) - if cache_dir is not None: - for file_or_dir in tmp_cache_dir.iterdir(): - if file_or_dir.is_file(): - shutil.copy( - file_or_dir, - cache_dir / path_after_folder(file_or_dir, neuron_hash.neuron_compiler_version_dir_name), - ) - else: - shutil.copytree( - file_or_dir, - cache_dir / path_after_folder(file_or_dir, neuron_hash.neuron_compiler_version_dir_name), - ) - if orig_repo_id is not None: - set_custom_cache_repo_name_in_hf_home(orig_repo_id) - return neuron_hash