Integrate new cache system for training
michaelbenayoun committed Feb 9, 2024
1 parent ab582ce · commit e35e453

Showing 1 changed file with 42 additions and 29 deletions.
optimum/neuron/trainers.py

@@ -74,7 +74,8 @@
     is_torch_xla_available,
     patch_within_function,
 )
-from .utils.cache_utils import get_neuron_cache_path, set_neuron_cache_path
+from .utils.cache_utils import get_hf_hub_cache_repos, get_neuron_cache_path, set_neuron_cache_path
+from .utils.hub_neuronx_cache import hub_neuronx_cache, synchronize_hub_cache
 from .utils.require_utils import requires_neuronx_distributed
 from .utils.training_utils import (
     TRANSFORMERS_MIN_VERSION_USE_ACCELERATE,
@@ -125,25 +126,25 @@
     import torch_xla.distributed.xla_backend as xbn
 
     if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
-        _ORIGINAL_NEURON_CACHE_PATH = get_neuron_cache_path()
-
-        # _ORIGINAL_NEURON_CACHE_PATH is `None` when the `--no-cache` flag is set.
-        if _ORIGINAL_NEURON_CACHE_PATH is not None:
-            if is_precompilation():
-                # During precompilation, we make sure to set the cache path to the compile cache path defined by
-                # the user. If nothing is specified, it is set to the default compile cache used by the Neuron
-                # compiler: /var/tmp/neuron-compile-cache
-                set_neuron_cache_path(_ORIGINAL_NEURON_CACHE_PATH)
-            else:
-                if os.environ["RANK"] == "0":
-                    _TMP_NEURON_CACHE_DIR = NeuronCacheCallback.create_temporary_neuron_cache(get_neuron_cache_path())
-                    store = torch.distributed.TCPStore(_TCP_STORE_ADDRESS, _TCP_STORE_PORT, is_master=True)
-                    store.set("tmp_neuron_cache_path", _TMP_NEURON_CACHE_DIR.name)
-                    _TMP_NEURON_CACHE_PATH = Path(_TMP_NEURON_CACHE_DIR.name)
-                else:
-                    store = torch.distributed.TCPStore(_TCP_STORE_ADDRESS, _TCP_STORE_PORT, is_master=False)
-                    _TMP_NEURON_CACHE_PATH = Path(store.get("tmp_neuron_cache_path").decode("utf-8"))
-                set_neuron_cache_path(_TMP_NEURON_CACHE_PATH)
+        # _ORIGINAL_NEURON_CACHE_PATH = get_neuron_cache_path()
+
+        # # _ORIGINAL_NEURON_CACHE_PATH is `None` when the `--no-cache` flag is set.
+        # if _ORIGINAL_NEURON_CACHE_PATH is not None:
+        #     if is_precompilation():
+        #         # During precompilation, we make sure to set the cache path to the compile cache path defined by
+        #         # the user. If nothing is specified, it is set to the default compile cache used by the Neuron
+        #         # compiler: /var/tmp/neuron-compile-cache
+        #         set_neuron_cache_path(_ORIGINAL_NEURON_CACHE_PATH)
+        #     else:
+        #         if os.environ["RANK"] == "0":
+        #             _TMP_NEURON_CACHE_DIR = NeuronCacheCallback.create_temporary_neuron_cache(get_neuron_cache_path())
+        #             store = torch.distributed.TCPStore(_TCP_STORE_ADDRESS, _TCP_STORE_PORT, is_master=True)
+        #             store.set("tmp_neuron_cache_path", _TMP_NEURON_CACHE_DIR.name)
+        #             _TMP_NEURON_CACHE_PATH = Path(_TMP_NEURON_CACHE_DIR.name)
+        #         else:
+        #             store = torch.distributed.TCPStore(_TCP_STORE_ADDRESS, _TCP_STORE_PORT, is_master=False)
+        #             _TMP_NEURON_CACHE_PATH = Path(store.get("tmp_neuron_cache_path").decode("utf-8"))
+        #         set_neuron_cache_path(_TMP_NEURON_CACHE_PATH)
 
         torch.distributed.init_process_group(backend="xla")
         if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
@@ -197,15 +198,15 @@ def __init__(self, *args, **kwargs):
         push = self.args.local_rank <= 0 and not is_precompilation() and not self.args.skip_cache_push
         fetch = self.args.local_rank <= 0 or self.args.mp_plugin.should_parallelize
 
-        callback = NeuronCacheCallback(
-            tmp_neuron_cache=_TMP_NEURON_CACHE_PATH,
-            original_neuron_cache_path=_ORIGINAL_NEURON_CACHE_PATH,
-            fetch=fetch,
-            push=push,
-            wait_for_everyone_on_fetch=True,
-            wait_for_everyone_on_push=True,
-        )
-        self.add_callback(callback)
+        # callback = NeuronCacheCallback(
+        #     tmp_neuron_cache=_TMP_NEURON_CACHE_PATH,
+        #     original_neuron_cache_path=_ORIGINAL_NEURON_CACHE_PATH,
+        #     fetch=fetch,
+        #     push=push,
+        #     wait_for_everyone_on_fetch=True,
+        #     wait_for_everyone_on_push=True,
+        # )
+        # self.add_callback(callback)
 
         # Make the model Neuron-compatible for generation.
         patch_generation_mixin_to_neuron_generation_mixin(self.model)
@@ -624,6 +625,18 @@ def _load_optimizer_and_scheduler(self, checkpoint):
         else:
             return super()._load_optimizer_and_scheduler(checkpoint)
 
+    def train(
+        self,
+        resume_from_checkpoint: Optional[Union[str, bool]] = None,
+        trial: Union["optuna.Trial", Dict[str, Any]] = None,
+        ignore_keys_for_eval: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        with hub_neuronx_cache():
+            result = super().train(resume_from_checkpoint=resume_from_checkpoint, trial=trial, ignore_keys_for_eval=ignore_keys_for_eval, **kwargs)
+        synchronize_hub_cache(get_hf_hub_cache_repos()[0])
+        return result
+
     @requires_neuronx_distributed
     def _inner_training_loop(
         self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
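
For context on what is being retired: the block commented out in the second hunk implemented a rank-0 rendezvous over torch.distributed.TCPStore. Rank 0 created a temporary Neuron compile cache and published its path under the key "tmp_neuron_cache_path"; every other rank blocked on that key, so all workers compiled into the same directory. Below is a minimal standalone sketch of that pattern, meant to run under torchrun; the endpoint constants are placeholders (not the module's _TCP_STORE_ADDRESS/_TCP_STORE_PORT), and tempfile stands in for NeuronCacheCallback.create_temporary_neuron_cache.

# Sketch of the retired rendezvous pattern (placeholder endpoint, not the
# module's _TCP_STORE_ADDRESS/_TCP_STORE_PORT constants). Run under torchrun
# so that the RANK environment variable is set on every process.
import os
import tempfile
from pathlib import Path

import torch.distributed

STORE_ADDRESS, STORE_PORT = "127.0.0.1", 5584  # placeholder rendezvous endpoint

if os.environ["RANK"] == "0":
    # Rank 0 creates the shared cache directory and advertises its path.
    tmp_dir = tempfile.TemporaryDirectory()  # stand-in for create_temporary_neuron_cache()
    store = torch.distributed.TCPStore(STORE_ADDRESS, STORE_PORT, is_master=True)
    store.set("tmp_neuron_cache_path", tmp_dir.name)
    tmp_cache_path = Path(tmp_dir.name)
else:
    # Other ranks connect to the store and block until the key is published.
    store = torch.distributed.TCPStore(STORE_ADDRESS, STORE_PORT, is_master=False)
    tmp_cache_path = Path(store.get("tmp_neuron_cache_path").decode("utf-8"))

# At this point every rank agrees on tmp_cache_path; the original code then
# handed it to set_neuron_cache_path().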

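The replacement flow added at the bottom of the diff is much narrower: NeuronTrainer.train() now runs entirely inside the hub_neuronx_cache() context manager, so graphs compiled during training are resolved against the hub-backed Neuron compile cache, and synchronize_hub_cache() then pushes newly compiled artifacts to the first repo returned by get_hf_hub_cache_repos(). A minimal sketch of that control flow follows, using only the three helpers imported in the diff; the run_training wrapper is an illustrative stand-in for the overridden method, and the CUSTOM_CACHE_REPO remark reflects the typical optimum-neuron configuration rather than anything in this diff.

# Control flow introduced by the new train() override, lifted into a
# standalone helper for illustration.
from optimum.neuron.utils.cache_utils import get_hf_hub_cache_repos
from optimum.neuron.utils.hub_neuronx_cache import hub_neuronx_cache, synchronize_hub_cache


def run_training(trainer, **train_kwargs):
    # Compilations triggered inside the context manager are looked up in and
    # recorded to the hub-backed Neuron compile cache, not only the local
    # /var/tmp/neuron-compile-cache directory.
    with hub_neuronx_cache():
        result = trainer.train(**train_kwargs)
    # Push artifacts compiled during this run to the first configured cache
    # repo (typically selected via the CUSTOM_CACHE_REPO environment variable,
    # with aws-neuron/optimum-neuron-cache as the usual public default).
    synchronize_hub_cache(get_hf_hub_cache_repos()[0])
    return result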