From 3005c77071bd9a132c1509355bec34546daedede Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Wed, 3 Apr 2024 15:02:06 +0200
Subject: [PATCH] Mixed-precision training with both `torch_xla` or `torch.autocast` (#523)

---
 examples/language-modeling/run_clm.py          |   6 +-
 optimum/neuron/accelerate/accelerator.py       | 127 +++++++++++-------
 optimum/neuron/accelerate/state.py             |  75 ++++++++---
 optimum/neuron/accelerate/utils/__init__.py    |   7 +-
 .../neuron/accelerate/utils/dataclasses.py     |   9 ++
 optimum/neuron/accelerate/utils/misc.py        |  26 ++++
 optimum/neuron/trainers.py                     |  67 +++++----
 optimum/neuron/training_args.py                |  55 +++-----
 optimum/neuron/utils/__init__.py               |   2 -
 optimum/neuron/utils/hub_neuronx_cache.py      |   4 +-
 .../torch_xla_and_neuronx_initialization.py    |  94 +++++++++++++
 optimum/neuron/utils/training_utils.py         |  56 --------
 .../distributed/test_model_parallelization.py  |  23 ++--
 tests/test_trainer_callback.py                 |   2 +
 14 files changed, 335 insertions(+), 218 deletions(-)
 create mode 100644 optimum/neuron/utils/torch_xla_and_neuronx_initialization.py

diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py
index a31b2456a..bedf48ec9 100755
--- a/examples/language-modeling/run_clm.py
+++ b/examples/language-modeling/run_clm.py
@@ -466,9 +466,9 @@ def main():
 
     # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
     # on a small vocab and want a smaller embedding size, remove this test.
-    embedding_size = model.get_input_embeddings().weight.shape[0]
-    if len(tokenizer) > embedding_size:
-        model.resize_token_embeddings(len(tokenizer))
+    # embedding_size = model.get_input_embeddings().weight.shape[0]
+    # if len(tokenizer) > embedding_size:
+    #     model.resize_token_embeddings(len(tokenizer))
 
     # Preprocessing the datasets.
     # First we tokenize all the texts.
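Note on the two mixed-precision paths named in the subject line: with the "xla" backend, bf16 is driven by the XLA_USE_BF16 / XLA_DOWNCAST_BF16 environment variables, while the "amp" backend wraps computation in torch.autocast with bfloat16 and sets ACCELERATE_USE_AMP. The sketch below is only illustrative; the configure_bf16 helper is hypothetical and merely mirrors the NeuronPartialState / NeuronAcceleratorState logic further down in this patch.

import os

# Hypothetical helper mirroring the backend selection added in this patch.
def configure_bf16(backend: str) -> None:
    if backend == "xla":
        # torch_xla casts based on these environment variables.
        os.environ["XLA_USE_BF16"] = "1"
        os.environ["XLA_DOWNCAST_BF16"] = "0"
    elif backend == "amp":
        # torch.autocast performs the casting, so XLA-level casting stays off
        # and the Neuron compiler is told not to auto-cast on its own.
        os.environ["ACCELERATE_USE_AMP"] = "true"
    else:
        raise ValueError(f"Unknown autocast backend: {backend}")

configure_bf16("xla")
print(os.environ["XLA_USE_BF16"])  # "1"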
diff --git a/optimum/neuron/accelerate/accelerator.py b/optimum/neuron/accelerate/accelerator.py index 2330aa5f3..a25a23a26 100644 --- a/optimum/neuron/accelerate/accelerator.py +++ b/optimum/neuron/accelerate/accelerator.py @@ -20,6 +20,8 @@ import os import re import shutil +import sys +import warnings from functools import partial from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union @@ -27,7 +29,7 @@ import torch from accelerate import Accelerator from accelerate.checkpointing import save_accelerator_state, save_custom_state -from accelerate.utils import DistributedType +from accelerate.utils import AutocastKwargs, DistributedType from accelerate.utils.operations import gather_object, recursively_apply from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler @@ -41,14 +43,15 @@ is_neuronx_distributed_available, is_torch_xla_available, patch_within_function, - patched_finfo, ) from ..utils.misc import args_and_kwargs_to_kwargs_only, is_main_worker from ..utils.require_utils import requires_neuronx_distributed, requires_torch_xla +from ..utils.torch_xla_and_neuronx_initialization import check_neuron_cc_flags_for_model from .optimizer import NeuronAcceleratedOptimizer from .scheduler import NeuronAcceleratedScheduler from .state import NeuronAcceleratorState from .utils import ( + AutocastBackend, ModelParallelismPlugin, NeuronDistributedType, NeuronFullyShardedDataParallelPlugin, @@ -56,6 +59,7 @@ patch_accelerate_is_tpu_available, tie_parameters, ) +from .utils.misc import create_patched_finfo from .utils.operations import _xla_gather @@ -83,22 +87,20 @@ MODEL_PATCHING_SPECS = [ ("config.layerdrop", 0), ("no_sync", lambda: contextlib.nullcontext()), - ( - "forward", - DynamicPatch(patch_within_function(("torch.finfo", patched_finfo))), - ), ] -NxDPPMODEL_PATCHING_SPECS = [ - ( - "forward", - DynamicPatch(patch_within_function(("torch.finfo", patched_finfo))), - ), -] +NxDPPMODEL_PATCHING_SPECS = [] class NeuronAccelerator(Accelerator): - def __init__(self, *args, mp_plugin: Optional[ModelParallelismPlugin] = None, zero_1: bool = False, **kwargs): + def __init__( + self, + *args, + mp_plugin: Optional[ModelParallelismPlugin] = None, + zero_1: bool = False, + autocast_backend: Union[str, AutocastBackend] = "xla", + **kwargs, + ): # Patches accelerate.utils.imports.is_tpu_available to match `is_torch_xla_available` patch_accelerate_is_tpu_available() @@ -132,34 +134,23 @@ def __init__(self, *args, mp_plugin: Optional[ModelParallelismPlugin] = None, ze ) self.fsdp_plugin = fsdp_plugin - use_neuronx_distributed_tp = os.environ.get("ACCELERATE_USE_NEURONX_DISTRIBUTED_TP", "false") - use_neuronx_distributed_pp = os.environ.get("ACCELERATE_USE_NEURONX_DISTRIBUTED_PP", "false") - if mp_plugin is None: - if use_neuronx_distributed_tp == "false": - tp_size = 1 - else: - tp_size = int(use_neuronx_distributed_tp) - if use_neuronx_distributed_pp == "false": - pp_size = 1 - else: - pp_size = int(use_neuronx_distributed_pp) - mp_plugin = ModelParallelismPlugin( - tensor_parallel_size=tp_size, parallelize_embeddings=True, pipeline_parallel_size=pp_size - ) self._model_cpu_parameters_to_xla = {} - if mp_plugin.tensor_parallel_size > 1: - os.environ["ACCELERATE_USE_NEURONX_DISTRIBUTED_TP"] = "true" + if not isinstance(autocast_backend, AutocastBackend): + autocast_backend = AutocastBackend(autocast_backend) - if mp_plugin.pipeline_parallel_size > 1: - os.environ["ACCELERATE_USE_NEURONX_DISTRIBUTED_PP"] = "true" - - 
patched_accelerator_state = partial(NeuronAcceleratorState, mp_plugin=mp_plugin) + patched_accelerator_state = partial( + NeuronAcceleratorState, mp_plugin=mp_plugin, autocast_backend=autocast_backend + ) with Patcher([("accelerate.accelerator.AcceleratorState", patched_accelerator_state)]): super().__init__(**full_kwargs) self.zero_1 = zero_1 + if self.autocast_handler is None: + enabled = self.state.mixed_precision == "bf16" and autocast_backend is AutocastBackend.AMP + self.autocast_handler = AutocastKwargs(enabled=enabled) + if self.fsdp_plugin is not None and self.zero_1: raise ValueError("Either enable XLA ZeRO Stage 1 or XLA FSDP but not both.") @@ -244,6 +235,7 @@ def _prepare_optimizer_for_mp(self, optimizer: torch.optim.Optimizer, device_pla optimizer = Parallelizer.optimizer_for_mp(optimizer, cpu_parameters_to_xla) else: xla_parameters, _ = Parallelizer.optimizer_cpu_params_to_xla_params(optimizer, cpu_parameters_to_xla) + if hasattr(optimizer, "_args_to_recreate"): args, kwargs = optimizer._args_to_recreate args = (xla_parameters,) + args[1:] @@ -325,12 +317,30 @@ def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_placement: def prepare_scheduler(self, scheduler: "LRScheduler"): return super().prepare_scheduler(scheduler) - @staticmethod def patch_model_for_neuron( - model: "torch.nn.Module", patching_specs: Optional[List[Tuple[str, Any]]] = None + self, + model: "torch.nn.Module", + patching_specs: Optional[List[Tuple[str, Any]]] = None, ) -> "torch.nn.Module": if patching_specs is None: patching_specs = MODEL_PATCHING_SPECS + + # Working on a copy for safety. + patching_specs = list(patching_specs) + + mixed_precision_is_bf16 = self.state.mixed_precision == "bf16" + patched_finfo = create_patched_finfo( + xla_downcast_bf16=mixed_precision_is_bf16 and self.state.downcast_bfloat, + use_amp=mixed_precision_is_bf16 and self.state.autocast_backend is AutocastBackend.AMP, + xla_use_bf16=mixed_precision_is_bf16 and not self.state.downcast_bfloat, + ) + patching_specs.append( + ( + "forward", + DynamicPatch(patch_within_function(("torch.finfo", patched_finfo))), + ), + ) + prepared_patching_specs = [] for spec in patching_specs: prepared_patching_specs.append((model,) + spec) @@ -420,6 +430,7 @@ def _prepare_model_for_mp( return model cpu_ids = {name: id(param) for name, param in model.named_parameters()} + tied_parameters_dict = get_tied_parameters_dict(model) model_main_input_name = getattr(model, "main_input_name", None) model = self.state.mp_plugin.parallelize_model(model, device=self.device) @@ -431,21 +442,12 @@ def _prepare_model_for_mp( model.local_module = self.patch_model_for_neuron( model.local_module, patching_specs=NxDPPMODEL_PATCHING_SPECS ) - model_to_cast = model.local_module - else: - model_to_cast = model # Update CPU ids original_parameter_names_to_gqa_qkv_names = model._gqa_qkv_metadata["original_names_to_gqa_qkv_names"] for key in list(cpu_ids.keys()): cpu_ids[original_parameter_names_to_gqa_qkv_names.get(key, key)] = cpu_ids.pop(key) - model_to_cast = model.local_module if isinstance(model, NxDPPModel) else model - if os.environ.get("XLA_USE_BF16", "0") == "1" or os.environ.get("XLA_DOWNCAST_BF16", "0") == "1": - model_to_cast.to(torch.bfloat16) - else: - model_to_cast.to(torch.float32) - def _tie_or_clone_weights_for_mp(self, output_embeddings, input_embeddings): """Tie or clone module weights depending of whether we are using TorchScript or not""" output_embeddings.weight = input_embeddings.weight @@ -453,17 +455,15 @@ def 
_tie_or_clone_weights_for_mp(self, output_embeddings, input_embeddings): output_embeddings.out_features = input_embeddings.num_embeddings if isinstance(model, NxDPPModel): - with ModelPatcher(patching_specs=[(model, "_tie_or_clone_weights", _tie_or_clone_weights_for_mp)]): - model.move_model_to_device() - tie_parameters(model, tied_parameters_dict) + model.move_model_to_device() + tie_parameters(model, tied_parameters_dict) xla_params = dict(model.local_named_parameters()) self._model_cpu_parameters_to_xla[id(model)] = { cpu_ids[name]: xla_params[name] for name, _ in model.local_named_parameters() } else: - with ModelPatcher(patching_specs=[(model, "_tie_or_clone_weights", _tie_or_clone_weights_for_mp)]): - move_model_to_device(model, self.device) - tie_parameters(model, tied_parameters_dict) + move_model_to_device(model, self.device) + tie_parameters(model, tied_parameters_dict) xla_params = dict(model.named_parameters()) symmetric_diff = set(cpu_ids.keys()).symmetric_difference((xla_params.keys())) @@ -490,6 +490,10 @@ def prepare_model( if model in self._models: return model + # Since it is not possible to set the best compiler flags for a given model because XLA is initialized before + # we get access to the model, we simply check if the flags are the best and notify the user otherwise. + check_neuron_cc_flags_for_model(model) + model = self.patch_model_for_neuron(model) # We do not want to use the cache, or output unused tensors as it would imply more communication that we do not @@ -533,6 +537,29 @@ def clip_grad_norm_for_xla_fsdp(self, parameters, max_norm, norm_type: int = 2): if parameters == list(model.parameters()): return model.clip_grad_norm_(max_norm, norm_type) + @contextlib.contextmanager + def autocast(self, cache_enabled: bool = False, autocast_handler: Optional[AutocastKwargs] = None): + if cache_enabled: + warnings.warn( + "Passing `cache_enabled=True` to `accelerator.autocast` is deprecated and will be removed in v0.23.0. 
" + "Please use the `AutocastKwargs` class instead and pass it to the `Accelerator` as a `kwarg_handler`.", + FutureWarning, + ) + if self.autocast_handler is not None: + self.autocast_handler.cache_enabled = True + else: + self.autocast_handler = AutocastKwargs(cache_enabled=True) + if autocast_handler is None: + # By default `self.autocast_handler` enables autocast if: + # - `self.state.mixed_precision == "bf16"` + # - `self.state.autocast_backend is AutocastBackend.AMP` + autocast_handler = self.autocast_handler + autocast_kwargs = autocast_handler.to_kwargs() + autocast_context = torch.autocast(dtype=torch.bfloat16, device_type="cuda", **autocast_kwargs) + autocast_context.__enter__() + yield + autocast_context.__exit__(*sys.exc_info()) + @requires_neuronx_distributed def _prepare_clip_grad_norm(self, parameters, max_norm, norm_type: int = 2): from neuronx_distributed.pipeline import NxDPPModel diff --git a/optimum/neuron/accelerate/state.py b/optimum/neuron/accelerate/state.py index 6ba710445..6a2b98ec5 100644 --- a/optimum/neuron/accelerate/state.py +++ b/optimum/neuron/accelerate/state.py @@ -15,6 +15,7 @@ """Custom PartialState and AcceleratorState for Neuron.""" import os +from typing import Optional, Union import torch from accelerate.state import AcceleratorState, PartialState, ThreadLocalSharedDict @@ -35,8 +36,13 @@ from ...utils import logging from ..utils import is_neuronx_distributed_available, is_torch_xla_available +from ..utils.torch_xla_and_neuronx_initialization import ( + init_process_group, + set_common_neuron_cc_flags, + set_neuron_cc_flags_for_torch_amp, +) from .utils import NeuronDistributedType, NeuronFullyShardedDataParallelPlugin -from .utils.dataclasses import ModelParallelismPlugin +from .utils.dataclasses import AutocastBackend, ModelParallelismPlugin if is_torch_xla_available(): @@ -84,6 +90,11 @@ def __init__(self, cpu: bool = False, **kwargs): self.device = torch.device("cuda", self.local_process_index) torch.cuda.set_device(self.device) elif is_torch_xla_available() and not cpu: + # It is important to set the environment variables before initializing the process group otherwise they will be ignored by the Neuron compiler. 
+ set_common_neuron_cc_flags() + if os.environ.get("ACCELERATE_USE_AMP", "false") == "true": + set_neuron_cc_flags_for_torch_amp() + init_process_group() self.distributed_type = DistributedType.TPU self.num_processes = xm.xrt_world_size() self.process_index = xm.get_ordinal() @@ -224,17 +235,26 @@ def __init__( deepspeed_plugin=None, fsdp_plugin=None, megatron_lm_plugin=None, - mp_plugin=None, + mp_plugin: Optional[ModelParallelismPlugin] = None, + autocast_backend: Optional[Union[str, AutocastBackend]] = None, _from_accelerator: bool = False, **kwargs, ): self.__dict__ = self._shared_state if parse_flag_from_env("ACCELERATE_USE_CPU"): cpu = True + + if autocast_backend is None: + autocast_backend = AutocastBackend.XLA + elif not isinstance(autocast_backend, AutocastBackend): + autocast_backend = AutocastBackend(autocast_backend) + if NeuronPartialState._shared_state == {}: + if autocast_backend is AutocastBackend.AMP: + os.environ["ACCELERATE_USE_AMP"] = "true" NeuronPartialState(cpu, **kwargs) self.__dict__.update(NeuronPartialState._shared_state) - self._check_initialized(mixed_precision, cpu) + self._check_initialized(mixed_precision, cpu, autocast_backend) if not self.initialized: self.deepspeed_plugin = None self.ipex_plugin = None @@ -253,9 +273,14 @@ def __init__( ) # deepspeed handles mixed_precision using deepspeed_config self._mixed_precision = "no" if self.distributed_type == DistributedType.DEEPSPEED else mixed_precision + + self._autocast_backend = autocast_backend + if self.distributed_type == DistributedType.TPU: if mixed_precision == "bf16": - if os.environ.get("ACCELERATE_DOWNCAST_BF16"): + if autocast_backend is AutocastBackend.AMP: + self.downcast_bfloat = True + elif os.environ.get("ACCELERATE_DOWNCAST_BF16"): os.environ["XLA_USE_BF16"] = str(0) os.environ["XLA_DOWNCAST_BF16"] = str(1) self.downcast_bfloat = True @@ -263,24 +288,15 @@ def __init__( os.environ["XLA_USE_BF16"] = str(1) os.environ["XLA_DOWNCAST_BF16"] = str(0) self.downcast_bfloat = False - if ( - os.environ.get("ACCELERATE_USE_NEURONX_DISTRIBUTED_TP", "false") == "true" - or os.environ.get("ACCELERATE_USE_NEURONX_DISTRIBUTED_PP", "false") == "true" - ): - if mp_plugin is None: - raise ValueError( - "Could not initialize model parallelism because no `ModelParallelismPlugin` was provided." - ) - if mp_plugin.should_parallelize: - self.distributed_type = NeuronDistributedType.MODEL_PARALLELISM - else: - logger.warning( - "Model parallelism is requested but nothing is done because the tensor parallel size and " - "the pipeline parallel size are set to 1." 
- ) - self.mp_plugin = mp_plugin - else: - self.mp_plugin = ModelParallelismPlugin() + + if mp_plugin is None: + mp_plugin = ModelParallelismPlugin() + + if mp_plugin.should_parallelize: + self.distributed_type = NeuronDistributedType.MODEL_PARALLELISM + + self.mp_plugin = mp_plugin + print("MP PLUGIN", self.mp_plugin) if torch.distributed.is_initialized() and not parallel_state.model_parallel_is_initialized(): parallel_state.initialize_model_parallel( @@ -323,3 +339,18 @@ def __init__( ): torch.backends.cuda.matmul.allow_tf32 = True PartialState._shared_state["distributed_type"] = self.distributed_type + + def _check_initialized(self, mixed_precision=None, cpu=None, autocast_backend=None): + "Checks if a modification is trying to be made and the `AcceleratorState` has already been initialized" + super()._check_initialized(mixed_precision=mixed_precision, cpu=cpu) + err = ( + "AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and " + "pass `{flag}` to `Accelerator()`." + ) + if self.initialized: + if autocast_backend is not None and autocast_backend != self.autocast_backend: + raise ValueError(err.format(flag=f"autocast_backend='{autocast_backend}'")) + + @property + def autocast_backend(self): + return self._autocast_backend diff --git a/optimum/neuron/accelerate/utils/__init__.py b/optimum/neuron/accelerate/utils/__init__.py index 211d33cf0..49cea8cf6 100644 --- a/optimum/neuron/accelerate/utils/__init__.py +++ b/optimum/neuron/accelerate/utils/__init__.py @@ -13,5 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .dataclasses import ModelParallelismPlugin, NeuronDistributedType, NeuronFullyShardedDataParallelPlugin +from .dataclasses import ( + AutocastBackend, + ModelParallelismPlugin, + NeuronDistributedType, + NeuronFullyShardedDataParallelPlugin, +) from .misc import get_tied_parameters_dict, patch_accelerate_is_tpu_available, tie_parameters diff --git a/optimum/neuron/accelerate/utils/dataclasses.py b/optimum/neuron/accelerate/utils/dataclasses.py index 1461d6c9f..325b7a088 100644 --- a/optimum/neuron/accelerate/utils/dataclasses.py +++ b/optimum/neuron/accelerate/utils/dataclasses.py @@ -49,6 +49,15 @@ class NeuronDistributedType(str, enum.Enum): MODEL_PARALLELISM = "MODEL_PARALLELISM" +class AutocastBackend(str, enum.Enum): + """ + Represents the backend to use for mixed-precision training. + """ + + XLA = "xla" + AMP = "amp" + + @dataclass class NeuronFullyShardedDataParallelPlugin(FullyShardedDataParallelPlugin): # TODO: redefine the post init to do checks on which option is supported. 
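Since AutocastBackend is new in this patch, here is a hedged usage sketch based on the NeuronAccelerator.__init__ signature and the half_precision_backend training argument added elsewhere in this diff. It assumes a Trainium host with torch_xla and an optimum-neuron build containing this change; "clm_output" is just a placeholder directory.

from optimum.neuron import NeuronTrainingArguments
from optimum.neuron.accelerate.utils import AutocastBackend

# String values are normalized to the enum by NeuronAccelerator.__init__.
assert AutocastBackend("amp") is AutocastBackend.AMP
assert AutocastBackend("xla") is AutocastBackend.XLA

# User-facing path: enable bf16 and pick the backend. "xla" relies on the
# XLA_USE_BF16/XLA_DOWNCAST_BF16 environment variables, while "amp" wraps the
# forward and loss computation in torch.autocast(dtype=torch.bfloat16).
training_args = NeuronTrainingArguments(
    output_dir="clm_output",
    bf16=True,
    half_precision_backend="amp",  # or "xla" (the default)
)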
diff --git a/optimum/neuron/accelerate/utils/misc.py b/optimum/neuron/accelerate/utils/misc.py index 773649474..3ae153475 100644 --- a/optimum/neuron/accelerate/utils/misc.py +++ b/optimum/neuron/accelerate/utils/misc.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Dict, Union import torch +from transformers.modeling_utils import get_parameter_dtype from ...distributed.utils import named_parameters from ...utils import is_torch_neuronx_available, is_torch_xla_available, patch_everywhere @@ -39,6 +40,31 @@ def patch_accelerate_is_tpu_available(): patch_everywhere("is_tpu_available", is_tpu_available, module_name_prefix="accelerate") +_ORIG_TORCH_FINFO = torch.finfo + + +def create_patched_finfo(xla_downcast_bf16: bool = False, use_amp: bool = False, xla_use_bf16: bool = False): + + def patched_finfo(dtype): + if xla_downcast_bf16 or use_amp or xla_use_bf16: + return _ORIG_TORCH_FINFO(torch.bfloat16) + return _ORIG_TORCH_FINFO(dtype) + + return patched_finfo + + +def create_patched_get_parameter_dtype( + xla_downcast_bf16: bool = False, use_amp: bool = False, xla_use_bf16: bool = False +): + def patched_get_parameter_dtype(module): + dtype = get_parameter_dtype(module) + if xla_downcast_bf16 or use_amp or xla_use_bf16: + return torch.bfloat16 + return dtype + + return patched_get_parameter_dtype + + @requires_neuronx_distributed def get_tied_parameters_dict(model: Union["torch.nn.Module", "NxDPPModel"]) -> Dict[str, str]: from neuronx_distributed.pipeline import NxDPPModel diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py index 5faac8c80..73e05065b 100755 --- a/optimum/neuron/trainers.py +++ b/optimum/neuron/trainers.py @@ -28,6 +28,7 @@ import numpy as np import torch from accelerate import __version__ as accelerate_version +from accelerate.utils import AutocastKwargs from packaging import version from torch.utils.data import Dataset from transformers import PreTrainedModel, Seq2SeqTrainer, Trainer, TrainingArguments @@ -92,9 +93,6 @@ is_precompilation, is_topology_supported, patch_generation_mixin_to_neuron_generation_mixin, - prepare_environment_for_neuron, - set_neuron_cc_flags_for_model, - set_neuron_cc_optlevel_for_model, skip_first_batches, torch_xla_safe_save_file, ) @@ -131,15 +129,6 @@ if KEEP_HF_HUB_PROGRESS_BARS is None: os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" - -if os.environ.get("TORCHELASTIC_RUN_ID"): - import torch_xla.distributed.xla_backend as xbn - - if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla): - torch.distributed.init_process_group(backend="xla") - if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla): - raise AssertionError("Failed to initialize torch.distributed process group using XLA backend.") - transformers_get_optimizer_cls_and_kwargs = Trainer.get_optimizer_cls_and_kwargs @@ -158,10 +147,11 @@ def __init__(self, *args, **kwargs): if training_args is None and len(args) >= 2: training_args = args[1] + self.use_amp = False if training_args is not None: if training_args.bf16: - training_args.bf16 = False - os.environ["XLA_USE_BF16"] = "1" + if training_args.half_precision_backend == "amp": + self.use_amp = True self.validate_args(training_args) if is_precompilation(): @@ -172,7 +162,6 @@ def __init__(self, *args, **kwargs): transformers.trainer.Accelerator = NeuronAccelerator - prepare_environment_for_neuron() super().__init__(*args, **kwargs) # We need to change which process can be seen as "world process zero" to make sure the proper metrics @@ -195,14 +184,10 @@ def __init__(self, *args, 
**kwargs): # Make the model Neuron-compatible for generation. patch_generation_mixin_to_neuron_generation_mixin(self.model) - set_neuron_cc_optlevel_for_model(self.model, optlevel=self.args.neuron_cc_optlevel) - set_neuron_cc_flags_for_model(self.model) - # Model cache entry management. model_name_or_path_for_cache_entry = get_model_name_or_path(self.model.config) model_config_for_cache_entry = copy.deepcopy(self.model.config) - use_bf16 = os.environ.get("XLA_USE_BF16", False) or os.environ.get("XLA_DOWNCAST_BF16", False) - precision = "bfloat16" if use_bf16 else "float32" + precision = "bfloat16" if self.accelerator.state.mixed_precision == "bf16" else "float32" neuron_config_for_cache_entry = { "model_class": self.model.__class__.__name__, "precision": precision, @@ -242,14 +227,17 @@ def mp_enabled(self): return self.accelerator.distributed_type is NeuronDistributedType.MODEL_PARALLELISM def prepare_args_for_precompilation(self, args: "TrainingArguments"): - if is_main_worker() and args.num_train_epochs != 1: - logger.info("Setting the number of epochs for precompilation to 1.") + if args.num_train_epochs != 1: + if is_main_worker(): + logger.info("Setting the number of epochs for precompilation to 1.") args.num_train_epochs = 1 - if is_main_worker() and args.do_eval is True: - logger.info("Disabling evaluation during precompilation as this is not well supported yet.") + if args.do_eval: + if is_main_worker(): + logger.info("Disabling evaluation during precompilation as this is not well supported yet.") args.do_eval = False - if is_main_worker() and args.do_predict is True: - logger.info("Disabling prediction during precompilation as this is not well supported yet.") + if args.do_predict: + if is_main_worker(): + logger.info("Disabling prediction during precompilation as this is not well supported yet.") args.do_predict = False def validate_args(self, args: "TrainingArguments"): @@ -262,6 +250,8 @@ def create_accelerator_and_postprocess(self): gradient_accumulation_steps=self.args.gradient_accumulation_steps, mp_plugin=self.args.mp_plugin, zero_1=self.args.zero_1, + mixed_precision="bf16" if self.args.bf16 else "no", + autocast_backend=self.args.half_precision_backend, ) # deepspeed and accelerate flags covering both trainer args and accelerate launcher @@ -356,6 +346,18 @@ def compute_loss(self, model, inputs, return_outputs: bool = False): loss = super().compute_loss(model, inputs, return_outputs=return_outputs) return loss + def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True): + """ + A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired + arguments, depending on the situation. 
+ """ + + autocast_handler = AutocastKwargs( + enabled=self.accelerator.autocast_handler.enabled, + cache_enabled=cache_enabled, + ) + return self.accelerator.autocast(autocast_handler=autocast_handler) + def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: from neuronx_distributed.pipeline import NxDPPModel @@ -369,7 +371,7 @@ def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Te loss = self.compute_loss(model, inputs) if get_pipeline_model_parallel_rank() != get_pipeline_model_parallel_size() - 1: - use_bf16 = os.environ.get("XLA_USE_BF16", False) or os.environ.get("XLA_DOWNCAST_BF16", False) + use_bf16 = self.accelerator.state.mixed_precision == "bf16" dtype = torch.bfloat16 if use_bf16 else torch.float32 loss = torch.tensor(0, dtype=dtype).to(xm.xla_device()) else: @@ -394,7 +396,7 @@ def prediction_step( raise ValueError("Only the prediction loss can be returned when doing pipeline parallelism.") loss = model.run_eval(**inputs) if loss is None: - use_bf16 = os.environ.get("XLA_USE_BF16", False) or os.environ.get("XLA_DOWNCAST_BF16", False) + use_bf16 = self.accelerator.state.mixed_precision == "bf16" dtype = torch.bfloat16 if use_bf16 else torch.float32 loss = torch.tensor(0, dtype=dtype).to(xm.xla_device()) return (loss, None, None) @@ -466,7 +468,6 @@ def _save_xla(self, output_dir: Optional[str] = None): if is_main_worker(): logger.info(f"Saving model checkpoint to {output_dir}") - if xm.is_master_ordinal(): os.makedirs(output_dir, exist_ok=True) torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) @@ -1017,13 +1018,7 @@ def _inner_training_loop( if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): self.lr_scheduler.step() - # It should be equivalent but prefer to use the `zero_grad` method from the optimizer when doing - # pipeline parallelism. - if isinstance(model, NxDPPModel): - self.optimizer.zero_grad() - else: - model.zero_grad() - + self.optimizer.zero_grad() xm.mark_step() self.state.global_step += 1 diff --git a/optimum/neuron/training_args.py b/optimum/neuron/training_args.py index c56f2fe1d..051b8289c 100644 --- a/optimum/neuron/training_args.py +++ b/optimum/neuron/training_args.py @@ -14,8 +14,6 @@ # limitations under the License. 
"""Defines a TrainingArguments class compatible with Neuron.""" -import io -import json import os import warnings from dataclasses import dataclass, field @@ -35,12 +33,12 @@ requires_backends, ) -from ..utils import check_if_transformers_greater, logging +from ..utils import logging from .accelerate import NeuronAcceleratorState, NeuronPartialState from .accelerate.utils import ModelParallelismPlugin, patch_accelerate_is_tpu_available from .utils import is_accelerate_available, is_main_worker, is_torch_xla_available from .utils.patching import Patcher -from .utils.training_utils import TRANSFORMERS_MIN_VERSION_FOR_XLA_FSDP +from .utils.torch_xla_and_neuronx_initialization import set_neuron_cc_optlevel if is_sagemaker_mp_enabled(): @@ -57,6 +55,13 @@ class NeuronTrainingArgumentsMixin: skip_cache_push: bool = field( default=False, metadata={"help": "Whether to skip pushing Neuron artifacts to hub cache"} ) + half_precision_backend: str = field( + default="xla", + metadata={ + "help": "The backend to be used for half precision.", + "choices": ["xla", "amp"], + }, + ) zero_1: bool = field(default=False, metadata={"help": "Whether to use ZeRO Stage 1 Optimization."}) tensor_parallel_size: int = field( default=1, metadata={"help": "The number of replicas the model will be sharded on."} @@ -75,10 +80,10 @@ class NeuronTrainingArgumentsMixin: default=False, metadata={"help": "Whether or not to disable sequence parallelism."}, ) - neuron_cc_optlevel: str = field( - default="auto", + neuron_cc_optlevel: int = field( + default=2, metadata={ - "choices": ["auto", "1", "2", "3"], + "choices": [1, 2, 3], "help": "Specify the level of optimization the Neuron compiler should perform.", }, ) @@ -117,31 +122,9 @@ def __post_init__(self): if self.fsdp != "": # Disabling FSDP until next release because it is still very experimental and not validated. raise RuntimeError("FSDP is not supported yet.") - if self.fsdp_config is None: - self.fsdp_config = {"xla": True} - elif isinstance(self.fsdp_config, str): - with io.open(self.fsdp_config, "r", encoding="utf-8") as f: - self.fsdp_config = json.load(f) - - if "xla" in self.fsdp_config and not self.fsdp_config["xla"]: - raise ValueError( - "XLA FSDP is the only supported FSDP implementation by `optimum-neuron` but the provided FSDP " - "config specified it should not be used." - ) - else: - self.fsdp_config["xla"] = True - os.environ["ACCELERATE_USE_FSDP"] = "true" - - if not check_if_transformers_greater(TRANSFORMERS_MIN_VERSION_FOR_XLA_FSDP): - import transformers - - raise RuntimeError( - "The minimal required Transformers version to perform XLA FSDP is " - f"{TRANSFORMERS_MIN_VERSION_FOR_XLA_FSDP} but {transformers.__version__} is installed." - ) - if self.neuron_cc_optlevel != "auto": - self.neuron_cc_optlevel = f"-O{self.neuron_cc_optlevel}" + if self.fp16: + raise ValueError("The fp16 data type is not supported in Neuron, please use bf16 instead.") resume_from_checkpoint = self.resume_from_checkpoint if resume_from_checkpoint is None and os.path.isdir(self.output_dir): @@ -184,15 +167,19 @@ def __post_init__(self): num_ranks_per_loading_step=self.num_ranks_per_loading_step, ) + if self.bf16 and self.half_precision_backend == "amp": + os.environ["ACCELERATE_USE_AMP"] = "true" + else: + os.environ["ACCELERATE_USE_AMP"] = "false" + + set_neuron_cc_optlevel(self.neuron_cc_optlevel) + # This is required to be able to use bf16, otherwise a check in super().__post_init__() fails. 
with Patcher([("transformers.training_args.get_xla_device_type", lambda _: "GPU")]): super().__post_init__() - # Needed only to specialize the warning message for FSDP. @cached_property def _setup_devices(self) -> "torch.device": - if not check_if_transformers_greater("4.30.0"): - return super()._setup_devices requires_backends(self, ["torch"]) logger.info("PyTorch: setting up devices") diff --git a/optimum/neuron/utils/__init__.py b/optimum/neuron/utils/__init__.py index 17ac6890c..c3ac7920e 100644 --- a/optimum/neuron/utils/__init__.py +++ b/optimum/neuron/utils/__init__.py @@ -49,6 +49,4 @@ is_model_officially_supported, is_precompilation, patch_transformers_for_neuron_sdk, - patched_finfo, - prepare_environment_for_neuron, ) diff --git a/optimum/neuron/utils/hub_neuronx_cache.py b/optimum/neuron/utils/hub_neuronx_cache.py index 05b4a0963..3a000c14e 100644 --- a/optimum/neuron/utils/hub_neuronx_cache.py +++ b/optimum/neuron/utils/hub_neuronx_cache.py @@ -28,7 +28,7 @@ from transformers import AutoConfig, PretrainedConfig from ..version import __version__ -from .cache_utils import load_custom_cache_repo_name_from_hf_home +from .cache_utils import get_neuron_cache_path, load_custom_cache_repo_name_from_hf_home from .import_utils import is_neuronx_available from .patching import patch_everywhere from .require_utils import requires_torch_neuronx, requires_torch_xla @@ -334,6 +334,8 @@ def hf_create_compile_cache(cache_url): return create_compile_cache(cache_url) try: + if mode == "training" and cache_dir is None: + cache_dir = get_neuron_cache_path() if isinstance(cache_dir, Path): cache_dir = cache_dir.as_posix() default_cache = create_compile_cache(CacheUrl.get_cache_url(cache_dir=cache_dir)) diff --git a/optimum/neuron/utils/torch_xla_and_neuronx_initialization.py b/optimum/neuron/utils/torch_xla_and_neuronx_initialization.py new file mode 100644 index 000000000..ea0a34660 --- /dev/null +++ b/optimum/neuron/utils/torch_xla_and_neuronx_initialization.py @@ -0,0 +1,94 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Utilities related to initialization of `torch_xla` and `torch_neuronx`""" + +import os +import re +from typing import TYPE_CHECKING + +import torch + +from ...utils import logging +from .misc import is_main_worker +from .require_utils import requires_torch_xla + + +if TYPE_CHECKING: + from transformers import PreTrainedModel + +logger = logging.get_logger() + + +@requires_torch_xla +def init_process_group(): + if os.environ.get("TORCHELASTIC_RUN_ID"): + import torch_xla.distributed.xla_backend as xbn + + if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla): + torch.distributed.init_process_group(backend="xla") + if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla): + raise AssertionError("Failed to initialize torch.distributed process group using XLA backend.") + + +def set_common_neuron_cc_flags(): + """ + Sets environment variables for transformer-based models training with AWS Neuron. + """ + # Set compiler flag to compile for transformer model type + os.environ["NEURON_CC_FLAGS"] = os.environ.get("NEURON_CC_FLAGS", "") + " --model-type=transformer" + # Setting MALLOC_ARENA_MAX is needed because of a memory issue in XLA/glic, otherwise OOM can happen during + # checkpointing. More information here: + # https://awsdocs-neuron.readthedocs-hosted.com/en/latest/release-notes/torch/torch-neuronx/index.html#memory-leaking-in-glibc + os.environ["MALLOC_ARENA_MAX"] = "64" + + +def set_neuron_cc_flags_for_torch_amp(): + """ + Sets the proper compiler flags needed when using PyTorch Autocast. + """ + torch.cuda.is_bf16_supported = lambda: True + neuron_cc_flags = os.environ.get("NEURON_CC_FLAGS", "") + match_ = re.search(r"--auto-cast\s?\=?\s?\w+", neuron_cc_flags) + if match_ is not None: + neuron_cc_flags = neuron_cc_flags[: match_.start(0)] + neuron_cc_flags[match_.end(0) :] + os.environ["NEURON_CC_FLAGS"] = f"{neuron_cc_flags} --auto-cast=none" + + +def set_neuron_cc_optlevel(optlevel: int = 2): + """ + Sets the Neuron compiler optimization level. + """ + assert 1 <= optlevel <= 3 + neuron_cc_flags = os.environ.get("NEURON_CC_FLAGS", "") + match_ = re.search(r"-O[123]", neuron_cc_flags) + if match_: + neuron_cc_flags = neuron_cc_flags[: match_.start(0)] + f"-O{optlevel}" + neuron_cc_flags[match_.end(0) + 1 :] + else: + neuron_cc_flags += f" -O{optlevel}" + os.environ["NEURON_CC_FLAGS"] = neuron_cc_flags + + +def check_neuron_cc_flags_for_model(model: "PreTrainedModel"): + """ + Sets flags for the Neuron compiler depending on the model. + """ + neuron_cc_flags = os.environ.get("NEURON_CC_FLAGS", "") + if "ForCausalLM" or "ForConditionalGeneration" in model.__class__.__name__: + distribution_strategy = "--distribution-strategy=llm-training" + if is_main_worker() and distribution_strategy not in neuron_cc_flags: + logger.warning( + f"No distribution strategy was set. For {model.__class__.__name__} it is possible to set the following " + 'optimization: NEURON_CC_FLAGS=" --distribution-strategy=llm-training".' 
+ ) diff --git a/optimum/neuron/utils/training_utils.py b/optimum/neuron/utils/training_utils.py index 81cbe15e4..64158534a 100644 --- a/optimum/neuron/utils/training_utils.py +++ b/optimum/neuron/utils/training_utils.py @@ -15,7 +15,6 @@ """Training utilities""" import os -import re from typing import TYPE_CHECKING, Dict, List, Optional, Union import torch @@ -127,12 +126,6 @@ def _generate_supported_model_class_names( _SUPPORTED_MODEL_NAMES.update(_generate_supported_model_class_names(*model_type)) -_MODEL_TYPE_TO_OPTLEVEL: Dict[str, str] = { - "default": "-O2", - "llama": "-O1", -} - - def is_precompilation() -> bool: return os.environ.get("NEURON_PARALLEL_COMPILE") == "1" @@ -231,15 +224,6 @@ def __len__(self): return len(self.samples) -orig_finfo = torch.finfo - - -def patched_finfo(dtype): - if dtype is torch.float32: - return orig_finfo(torch.bfloat16) - return orig_finfo(dtype) - - def patch_generation_mixin_to_neuron_generation_mixin(model: "PreTrainedModel"): """ Changes the vanilla `GenerationMixin` class from Transformers to `NeuronGenerationMixin` in the model's @@ -292,46 +276,6 @@ def patch_generation_mixin_to_general_neuron_generation_mixin(model: "PreTrained cls.__bases__ = tuple(new_bases) -def prepare_environment_for_neuron(): - """ - Prepares the system environment for Transformers models training on AWS Neuron. - """ - # Set compiler flag to compile for transformer model type - os.environ["NEURON_CC_FLAGS"] = os.environ.get("NEURON_CC_FLAGS", "") + " --model-type=transformer" - # Setting MALLOC_ARENA_MAX is needed because of a memory issue in XLA/glic, otherwise OOM can happen during - # checkpointing. More information here: - # https://awsdocs-neuron.readthedocs-hosted.com/en/latest/release-notes/torch/torch-neuronx/index.html#memory-leaking-in-glibc - os.environ["MALLOC_ARENA_MAX"] = "64" - - -def set_neuron_cc_optlevel_for_model(model: "PreTrainedModel", optlevel: str = "auto"): - """ - Sets the Neuron compiler optimization level considering both `model` and `optlevel`. - If `optlevel` is different than `"auto"`, it will be set to that value, otherwise the default value for a given - model is used. - """ - if optlevel == "auto": - optlevel = _MODEL_TYPE_TO_OPTLEVEL.get(model.config.model_type, _MODEL_TYPE_TO_OPTLEVEL["default"]) - neuron_cc_flags = os.environ.get("NEURON_CC_FLAGS", "") - match_ = re.search(r"-O[123]", neuron_cc_flags) - if match_: - neuron_cc_flags = neuron_cc_flags[: match_.start(0)] + f"{optlevel}" + neuron_cc_flags[match_.end(0) + 1 :] - else: - neuron_cc_flags += f" {optlevel} " - os.environ["NEURON_CC_FLAGS"] = neuron_cc_flags - - -def set_neuron_cc_flags_for_model(model: "PreTrainedModel"): - """ - Sets flags for the Neuron compiler depending on the model. 
- """ - neuron_cc_flags = os.environ.get("NEURON_CC_FLAGS", "") - if "ForCausalLM" or "ForConditionalGeneration" in model.__class__.__name__: - distribution_strategy = "--distribution-strategy=llm-training" - if distribution_strategy not in neuron_cc_flags: - os.environ["NEURON_CC_FLAGS"] = neuron_cc_flags + f" {distribution_strategy}" - - def set_verbosity(verbosity: int): set_verbosity_transformers(verbosity) set_verbosity_optimum(verbosity) diff --git a/tests/distributed/test_model_parallelization.py b/tests/distributed/test_model_parallelization.py index 8eaceda2c..849ab1b6d 100644 --- a/tests/distributed/test_model_parallelization.py +++ b/tests/distributed/test_model_parallelization.py @@ -43,7 +43,6 @@ ) import optimum -from optimum.neuron.accelerate.accelerator import NeuronAccelerator from optimum.neuron.distributed.parallelizers_manager import ParallelizersManager from optimum.neuron.distributed.utils import compute_query_indices_for_rank from optimum.neuron.utils.cache_utils import ( @@ -55,7 +54,6 @@ is_torch_xla_available, ) from optimum.neuron.utils.testing_utils import is_trainium_test -from optimum.neuron.utils.training_utils import set_neuron_cc_optlevel_for_model from .distributed import DistributedTest from .utils import SEED, create_accelerator_for_mp, get_model, get_model_inputs @@ -297,13 +295,20 @@ def _parallel_model_matches_original_model( config_overwrite=config_overwrite, use_static_seed_patcher=True, ) - orig_model = NeuronAccelerator.patch_model_for_neuron(orig_model) + + accelerator = create_accelerator_for_mp( + tp_size, + pp_size, + parallelize_embeddings=parallelize_embeddings, + sequence_parallel_enabled=sequence_parallel_enabled, + ) + + # It is ok to use this accelerator because `patch_model_for_neuron` does not depend on the TP or PP size. + orig_model = accelerator.patch_model_for_neuron(orig_model) # TODO: enable that again once it's working, seems to be an AWS issue. orig_model.config.use_cache = False - set_neuron_cc_optlevel_for_model(orig_model) - move_model_to_device(orig_model, xm.xla_device()) orig_model = orig_model.eval() @@ -344,13 +349,6 @@ def _parallel_model_matches_original_model( use_static_seed_patcher=True, ) - accelerator = create_accelerator_for_mp( - tp_size, - pp_size, - parallelize_embeddings=parallelize_embeddings, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - from .utils import create_static_seed_patcher static_seed_patcher = create_static_seed_patcher(model.__class__, SEED) @@ -359,7 +357,6 @@ def _parallel_model_matches_original_model( xm.mark_step() - model = accelerator.patch_model_for_neuron(model) with torch.no_grad(): if pp_size == 1: # This is set to False by `accelerator.prepare`, which we want in the general case, but here let's diff --git a/tests/test_trainer_callback.py b/tests/test_trainer_callback.py index 082d9aa90..1bd9996dd 100644 --- a/tests/test_trainer_callback.py +++ b/tests/test_trainer_callback.py @@ -18,6 +18,7 @@ from tempfile import TemporaryDirectory from unittest import TestCase +import pytest import torch from huggingface_hub import HfApi from transformers.testing_utils import is_staging_test @@ -37,6 +38,7 @@ @is_trainium_test @is_staging_test +@pytest.mark.skip("Not used anymore, will be removed in cleaning PR.") class NeuronCacheCallbackTestCase(StagingTestMixin, TestCase): def test_neuron_hash_for_model(self): with TemporaryDirectory() as tmpdirname: