Skip to content

Commit

Permalink
Cache utils related cleanup (#553)
Browse files — browse the repository at this point in the history
  • Loading branch information
michaelbenayoun authored Apr 8, 2024
1 parent 09ddd67 commit 1f049e1
Show file tree
Hide file tree
Showing 10 changed files with 74 additions and 2,160 deletions.
4 changes: 2 additions & 2 deletions optimum/neuron/accelerate/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from ..utils import is_neuronx_distributed_available, is_torch_xla_available
from ..utils.torch_xla_and_neuronx_initialization import (
init_process_group,
set_common_neuron_cc_flags,
set_common_flags,
set_neuron_cc_flags_for_torch_amp,
)
from .utils import NeuronDistributedType, NeuronFullyShardedDataParallelPlugin
Expand Down Expand Up @@ -91,7 +91,7 @@ def __init__(self, cpu: bool = False, **kwargs):
torch.cuda.set_device(self.device)
elif is_torch_xla_available() and not cpu:
# It is important to set the environment variables before initializing the process group otherwise they will be ignored by the Neuron compiler.
set_common_neuron_cc_flags()
set_common_flags()
if os.environ.get("ACCELERATE_USE_AMP", "false") == "true":
set_neuron_cc_flags_for_torch_amp()
init_process_group()
Expand Down
433 changes: 0 additions & 433 deletions optimum/neuron/trainer_callback.py

This file was deleted.

31 changes: 14 additions & 17 deletions optimum/neuron/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@
get_hf_hub_cache_repos,
get_model_name_or_path,
get_neuron_cache_path,
get_neuronxcc_version,
get_num_neuron_cores_used,
has_write_access_to_repo,
)
Expand All @@ -96,6 +95,7 @@
skip_first_batches,
torch_xla_safe_save_file,
)
from .utils.version_utils import get_neuronxcc_version


if is_apex_available():
Expand Down Expand Up @@ -1362,14 +1362,13 @@ def train(
ignore_keys_for_eval: Optional[List[str]] = None,
**kwargs,
):
with patch_neuron_cc_wrapper():
with hub_neuronx_cache("training", entry=self.model_cache_entry):
result = super().train(
resume_from_checkpoint=resume_from_checkpoint,
trial=trial,
ignore_keys_for_eval=ignore_keys_for_eval,
**kwargs,
)
with hub_neuronx_cache("training", entry=self.model_cache_entry):
result = super().train(
resume_from_checkpoint=resume_from_checkpoint,
trial=trial,
ignore_keys_for_eval=ignore_keys_for_eval,
**kwargs,
)
if not is_precompilation():
self.synchronize_hub_cache()
return result
Expand All @@ -1380,21 +1379,19 @@ def evaluate(
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
) -> Dict[str, float]:
with patch_neuron_cc_wrapper():
with hub_neuronx_cache("training", entry=self.model_cache_entry):
result = super().evaluate(
eval_dataset=eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
)
with hub_neuronx_cache("training", entry=self.model_cache_entry):
result = super().evaluate(
eval_dataset=eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
)
if not is_precompilation():
self.synchronize_hub_cache()
return result

def predict(
self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test"
) -> PredictionOutput:
with patch_neuron_cc_wrapper():
with hub_neuronx_cache("training", entry=self.model_cache_entry):
result = super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
with hub_neuronx_cache("training", entry=self.model_cache_entry):
result = super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
if not is_precompilation():
self.synchronize_hub_cache()
return result
Expand Down
Loading

0 comments on commit 1f049e1

Please sign in to comment.