diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py
index c066ae797..d4d4779e3 100755
--- a/optimum/neuron/trainers.py
+++ b/optimum/neuron/trainers.py
@@ -74,7 +74,8 @@
     is_torch_xla_available,
     patch_within_function,
 )
-from .utils.cache_utils import get_neuron_cache_path, set_neuron_cache_path
+from .utils.cache_utils import get_hf_hub_cache_repos, get_neuron_cache_path, set_neuron_cache_path
+from .utils.hub_neuronx_cache import hub_neuronx_cache, synchronize_hub_cache
 from .utils.require_utils import requires_neuronx_distributed
 from .utils.training_utils import (
     TRANSFORMERS_MIN_VERSION_USE_ACCELERATE,
@@ -125,25 +126,25 @@
     import torch_xla.distributed.xla_backend as xbn
 
     if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
-        _ORIGINAL_NEURON_CACHE_PATH = get_neuron_cache_path()
-
-        # _ORIGINAL_NEURON_CACHE_PATH is `None` when the `--no-cache` flag is set.
-        if _ORIGINAL_NEURON_CACHE_PATH is not None:
-            if is_precompilation():
-                # During precompilation, we make sure to set the cache path to the defined compile cache path by the
-                # user. If nothing is specified, it is set to the default compile cache used by the Neuron compiler:
-                # /var/tmp/neuron-compile-cache
-                set_neuron_cache_path(_ORIGINAL_NEURON_CACHE_PATH)
-            else:
-                if os.environ["RANK"] == "0":
-                    _TMP_NEURON_CACHE_DIR = NeuronCacheCallback.create_temporary_neuron_cache(get_neuron_cache_path())
-                    store = torch.distributed.TCPStore(_TCP_STORE_ADDRESS, _TCP_STORE_PORT, is_master=True)
-                    store.set("tmp_neuron_cache_path", _TMP_NEURON_CACHE_DIR.name)
-                    _TMP_NEURON_CACHE_PATH = Path(_TMP_NEURON_CACHE_DIR.name)
-                else:
-                    store = torch.distributed.TCPStore(_TCP_STORE_ADDRESS, _TCP_STORE_PORT, is_master=False)
-                    _TMP_NEURON_CACHE_PATH = Path(store.get("tmp_neuron_cache_path").decode("utf-8"))
-                set_neuron_cache_path(_TMP_NEURON_CACHE_PATH)
+        # _ORIGINAL_NEURON_CACHE_PATH = get_neuron_cache_path()
+
+        # # _ORIGINAL_NEURON_CACHE_PATH is `None` when the `--no-cache` flag is set.
+        # if _ORIGINAL_NEURON_CACHE_PATH is not None:
+        #     if is_precompilation():
+        #         # During precompilation, we make sure to set the cache path to the defined compile cache path by the
+        #         # user. If nothing is specified, it is set to the default compile cache used by the Neuron compiler:
+        #         # /var/tmp/neuron-compile-cache
+        #         set_neuron_cache_path(_ORIGINAL_NEURON_CACHE_PATH)
+        #     else:
+        #         if os.environ["RANK"] == "0":
+        #             _TMP_NEURON_CACHE_DIR = NeuronCacheCallback.create_temporary_neuron_cache(get_neuron_cache_path())
+        #             store = torch.distributed.TCPStore(_TCP_STORE_ADDRESS, _TCP_STORE_PORT, is_master=True)
+        #             store.set("tmp_neuron_cache_path", _TMP_NEURON_CACHE_DIR.name)
+        #             _TMP_NEURON_CACHE_PATH = Path(_TMP_NEURON_CACHE_DIR.name)
+        #         else:
+        #             store = torch.distributed.TCPStore(_TCP_STORE_ADDRESS, _TCP_STORE_PORT, is_master=False)
+        #             _TMP_NEURON_CACHE_PATH = Path(store.get("tmp_neuron_cache_path").decode("utf-8"))
+        #         set_neuron_cache_path(_TMP_NEURON_CACHE_PATH)
 
     torch.distributed.init_process_group(backend="xla")
     if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
@@ -197,15 +198,15 @@ def __init__(self, *args, **kwargs):
         push = self.args.local_rank <= 0 and not is_precompilation() and not self.args.skip_cache_push
         fetch = self.args.local_rank <= 0 or self.args.mp_plugin.should_parallelize
 
-        callback = NeuronCacheCallback(
-            tmp_neuron_cache=_TMP_NEURON_CACHE_PATH,
-            original_neuron_cache_path=_ORIGINAL_NEURON_CACHE_PATH,
-            fetch=fetch,
-            push=push,
-            wait_for_everyone_on_fetch=True,
-            wait_for_everyone_on_push=True,
-        )
-        self.add_callback(callback)
+        # callback = NeuronCacheCallback(
+        #     tmp_neuron_cache=_TMP_NEURON_CACHE_PATH,
+        #     original_neuron_cache_path=_ORIGINAL_NEURON_CACHE_PATH,
+        #     fetch=fetch,
+        #     push=push,
+        #     wait_for_everyone_on_fetch=True,
+        #     wait_for_everyone_on_push=True,
+        # )
+        # self.add_callback(callback)
 
         # Make the model Neuron-compatible for generation.
         patch_generation_mixin_to_neuron_generation_mixin(self.model)
@@ -624,6 +625,18 @@ def _load_optimizer_and_scheduler(self, checkpoint):
         else:
             return super()._load_optimizer_and_scheduler(checkpoint)
 
+    def train(
+        self,
+        resume_from_checkpoint: Optional[Union[str, bool]] = None,
+        trial: Union["optuna.Trial", Dict[str, Any]] = None,
+        ignore_keys_for_eval: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        with hub_neuronx_cache():
+            result = super().train(resume_from_checkpoint=resume_from_checkpoint, trial=trial, ignore_keys_for_eval=ignore_keys_for_eval, **kwargs)
+        synchronize_hub_cache(get_hf_hub_cache_repos()[0])
+        return result
+
     @requires_neuronx_distributed
     def _inner_training_loop(
         self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
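
Usage sketch (not part of the patch): with the `train` override above, an ordinary `NeuronTrainer` run is wrapped in `hub_neuronx_cache()`, and once training returns, `synchronize_hub_cache(get_hf_hub_cache_repos()[0])` pushes newly compiled artifacts to the Hub cache repo. This assumes the cache repo id is configured beforehand (for example through the `CUSTOM_CACHE_REPO` environment variable); the repo id, model, dataset, and hyperparameters below are placeholders.

import os

import torch
from torch.utils.data import Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from optimum.neuron import NeuronTrainer, NeuronTrainingArguments


class TinyDataset(Dataset):
    """A handful of identical tokenized examples, just enough to trigger compilation."""

    def __init__(self, tokenizer, num_examples: int = 8):
        encodings = tokenizer(
            ["neuron hub cache demo"] * num_examples,
            padding="max_length",
            max_length=128,
            truncation=True,
        )
        self.examples = [
            {
                "input_ids": torch.tensor(encodings["input_ids"][i]),
                "attention_mask": torch.tensor(encodings["attention_mask"][i]),
                "labels": torch.tensor(0),
            }
            for i in range(num_examples)
        ]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]


# Hypothetical cache repo id; in practice it would be configured ahead of time
# (e.g. via the CUSTOM_CACHE_REPO environment variable or the optimum CLI).
os.environ["CUSTOM_CACHE_REPO"] = "my-org/optimum-neuron-cache"

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

trainer = NeuronTrainer(
    model=model,
    args=NeuronTrainingArguments(
        output_dir="neuron_cache_demo",
        max_steps=2,
        per_device_train_batch_size=4,
    ),
    train_dataset=TinyDataset(tokenizer),
)

# With this patch, train() runs inside hub_neuronx_cache() and then calls
# synchronize_hub_cache(get_hf_hub_cache_repos()[0]), so graphs compiled during
# this run end up in the Hub cache repo rather than in the rank-0 temporary
# cache previously managed by NeuronCacheCallback.
trainer.train()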