Mixed-precision training with both torch_xla or torch.autocast (#523)
michaelbenayoun authored Apr 3, 2024
1 parent 12b06a3 commit 3005c77
Showing 14 changed files with 335 additions and 218 deletions.
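In short, the commit lets bf16 mixed precision on Neuron run either through torch_xla's environment-variable based casting or through torch.autocast, selected by a new autocast_backend argument. A minimal usage sketch, assuming NeuronAccelerator is exported from optimum.neuron.accelerate and that standard accelerate keyword arguments such as mixed_precision are forwarded unchanged:

from optimum.neuron.accelerate import NeuronAccelerator

# XLA-native bf16: casting is driven by the XLA_USE_BF16 / XLA_DOWNCAST_BF16 environment variables.
accelerator = NeuronAccelerator(mixed_precision="bf16", autocast_backend="xla")

# torch.autocast-based bf16: weights stay in float32 and the forward pass runs under autocast.
# accelerator = NeuronAccelerator(mixed_precision="bf16", autocast_backend="amp")
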
6 changes: 3 additions & 3 deletions examples/language-modeling/run_clm.py
@@ -466,9 +466,9 @@ def main():

# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
model.resize_token_embeddings(len(tokenizer))
# embedding_size = model.get_input_embeddings().weight.shape[0]
# if len(tokenizer) > embedding_size:
# model.resize_token_embeddings(len(tokenizer))

# Preprocessing the datasets.
# First we tokenize all the texts.
127 changes: 77 additions & 50 deletions optimum/neuron/accelerate/accelerator.py
@@ -20,14 +20,16 @@
import os
import re
import shutil
import sys
import warnings
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union

import torch
from accelerate import Accelerator
from accelerate.checkpointing import save_accelerator_state, save_custom_state
from accelerate.utils import DistributedType
from accelerate.utils import AutocastKwargs, DistributedType
from accelerate.utils.operations import gather_object, recursively_apply
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
@@ -41,21 +43,23 @@
is_neuronx_distributed_available,
is_torch_xla_available,
patch_within_function,
patched_finfo,
)
from ..utils.misc import args_and_kwargs_to_kwargs_only, is_main_worker
from ..utils.require_utils import requires_neuronx_distributed, requires_torch_xla
from ..utils.torch_xla_and_neuronx_initialization import check_neuron_cc_flags_for_model
from .optimizer import NeuronAcceleratedOptimizer
from .scheduler import NeuronAcceleratedScheduler
from .state import NeuronAcceleratorState
from .utils import (
AutocastBackend,
ModelParallelismPlugin,
NeuronDistributedType,
NeuronFullyShardedDataParallelPlugin,
get_tied_parameters_dict,
patch_accelerate_is_tpu_available,
tie_parameters,
)
from .utils.misc import create_patched_finfo
from .utils.operations import _xla_gather


@@ -83,22 +87,20 @@
MODEL_PATCHING_SPECS = [
("config.layerdrop", 0),
("no_sync", lambda: contextlib.nullcontext()),
(
"forward",
DynamicPatch(patch_within_function(("torch.finfo", patched_finfo))),
),
]

NxDPPMODEL_PATCHING_SPECS = [
(
"forward",
DynamicPatch(patch_within_function(("torch.finfo", patched_finfo))),
),
]
NxDPPMODEL_PATCHING_SPECS = []


class NeuronAccelerator(Accelerator):
def __init__(self, *args, mp_plugin: Optional[ModelParallelismPlugin] = None, zero_1: bool = False, **kwargs):
def __init__(
self,
*args,
mp_plugin: Optional[ModelParallelismPlugin] = None,
zero_1: bool = False,
autocast_backend: Union[str, AutocastBackend] = "xla",
**kwargs,
):
# Patches accelerate.utils.imports.is_tpu_available to match `is_torch_xla_available`
patch_accelerate_is_tpu_available()

@@ -132,34 +134,23 @@ def __init__(self, *args, mp_plugin: Optional[ModelParallelismPlugin] = None, ze
)
self.fsdp_plugin = fsdp_plugin

use_neuronx_distributed_tp = os.environ.get("ACCELERATE_USE_NEURONX_DISTRIBUTED_TP", "false")
use_neuronx_distributed_pp = os.environ.get("ACCELERATE_USE_NEURONX_DISTRIBUTED_PP", "false")
if mp_plugin is None:
if use_neuronx_distributed_tp == "false":
tp_size = 1
else:
tp_size = int(use_neuronx_distributed_tp)
if use_neuronx_distributed_pp == "false":
pp_size = 1
else:
pp_size = int(use_neuronx_distributed_pp)
mp_plugin = ModelParallelismPlugin(
tensor_parallel_size=tp_size, parallelize_embeddings=True, pipeline_parallel_size=pp_size
)
self._model_cpu_parameters_to_xla = {}

if mp_plugin.tensor_parallel_size > 1:
os.environ["ACCELERATE_USE_NEURONX_DISTRIBUTED_TP"] = "true"
if not isinstance(autocast_backend, AutocastBackend):
autocast_backend = AutocastBackend(autocast_backend)

if mp_plugin.pipeline_parallel_size > 1:
os.environ["ACCELERATE_USE_NEURONX_DISTRIBUTED_PP"] = "true"

patched_accelerator_state = partial(NeuronAcceleratorState, mp_plugin=mp_plugin)
patched_accelerator_state = partial(
NeuronAcceleratorState, mp_plugin=mp_plugin, autocast_backend=autocast_backend
)
with Patcher([("accelerate.accelerator.AcceleratorState", patched_accelerator_state)]):
super().__init__(**full_kwargs)

self.zero_1 = zero_1

if self.autocast_handler is None:
enabled = self.state.mixed_precision == "bf16" and autocast_backend is AutocastBackend.AMP
self.autocast_handler = AutocastKwargs(enabled=enabled)

if self.fsdp_plugin is not None and self.zero_1:
raise ValueError("Either enable XLA ZeRO Stage 1 or XLA FSDP but not both.")

@@ -244,6 +235,7 @@ def _prepare_optimizer_for_mp(self, optimizer: torch.optim.Optimizer, device_pla
optimizer = Parallelizer.optimizer_for_mp(optimizer, cpu_parameters_to_xla)
else:
xla_parameters, _ = Parallelizer.optimizer_cpu_params_to_xla_params(optimizer, cpu_parameters_to_xla)

if hasattr(optimizer, "_args_to_recreate"):
args, kwargs = optimizer._args_to_recreate
args = (xla_parameters,) + args[1:]
@@ -325,12 +317,30 @@ def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_placement:
def prepare_scheduler(self, scheduler: "LRScheduler"):
return super().prepare_scheduler(scheduler)

@staticmethod
def patch_model_for_neuron(
model: "torch.nn.Module", patching_specs: Optional[List[Tuple[str, Any]]] = None
self,
model: "torch.nn.Module",
patching_specs: Optional[List[Tuple[str, Any]]] = None,
) -> "torch.nn.Module":
if patching_specs is None:
patching_specs = MODEL_PATCHING_SPECS

# Working on a copy for safety.
patching_specs = list(patching_specs)

mixed_precision_is_bf16 = self.state.mixed_precision == "bf16"
patched_finfo = create_patched_finfo(
xla_downcast_bf16=mixed_precision_is_bf16 and self.state.downcast_bfloat,
use_amp=mixed_precision_is_bf16 and self.state.autocast_backend is AutocastBackend.AMP,
xla_use_bf16=mixed_precision_is_bf16 and not self.state.downcast_bfloat,
)
patching_specs.append(
(
"forward",
DynamicPatch(patch_within_function(("torch.finfo", patched_finfo))),
),
)

prepared_patching_specs = []
for spec in patching_specs:
prepared_patching_specs.append((model,) + spec)
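
The create_patched_finfo call above builds the torch.finfo replacement that is patched into the model's forward. Its implementation is not shown in this diff; the following is only a hypothetical illustration of the idea that, once weights are (down)cast to bfloat16, torch.finfo(torch.float32) should report bfloat16 limits so that values like finfo(dtype).min remain representable:

import torch

_original_finfo = torch.finfo

def patched_finfo_sketch(dtype=None):
    # Hypothetical: report bfloat16 limits in place of float32 ones while bf16 casting is active.
    if dtype is torch.float32:
        return _original_finfo(torch.bfloat16)
    return _original_finfo(dtype) if dtype is not None else _original_finfo()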
@@ -420,6 +430,7 @@ def _prepare_model_for_mp(
return model

cpu_ids = {name: id(param) for name, param in model.named_parameters()}

tied_parameters_dict = get_tied_parameters_dict(model)
model_main_input_name = getattr(model, "main_input_name", None)
model = self.state.mp_plugin.parallelize_model(model, device=self.device)
@@ -431,39 +442,28 @@
model.local_module = self.patch_model_for_neuron(
model.local_module, patching_specs=NxDPPMODEL_PATCHING_SPECS
)
model_to_cast = model.local_module
else:
model_to_cast = model

# Update CPU ids
original_parameter_names_to_gqa_qkv_names = model._gqa_qkv_metadata["original_names_to_gqa_qkv_names"]
for key in list(cpu_ids.keys()):
cpu_ids[original_parameter_names_to_gqa_qkv_names.get(key, key)] = cpu_ids.pop(key)

model_to_cast = model.local_module if isinstance(model, NxDPPModel) else model
if os.environ.get("XLA_USE_BF16", "0") == "1" or os.environ.get("XLA_DOWNCAST_BF16", "0") == "1":
model_to_cast.to(torch.bfloat16)
else:
model_to_cast.to(torch.float32)

def _tie_or_clone_weights_for_mp(self, output_embeddings, input_embeddings):
"""Tie or clone module weights depending of whether we are using TorchScript or not"""
output_embeddings.weight = input_embeddings.weight
if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
output_embeddings.out_features = input_embeddings.num_embeddings

if isinstance(model, NxDPPModel):
with ModelPatcher(patching_specs=[(model, "_tie_or_clone_weights", _tie_or_clone_weights_for_mp)]):
model.move_model_to_device()
tie_parameters(model, tied_parameters_dict)
model.move_model_to_device()
tie_parameters(model, tied_parameters_dict)
xla_params = dict(model.local_named_parameters())
self._model_cpu_parameters_to_xla[id(model)] = {
cpu_ids[name]: xla_params[name] for name, _ in model.local_named_parameters()
}
else:
with ModelPatcher(patching_specs=[(model, "_tie_or_clone_weights", _tie_or_clone_weights_for_mp)]):
move_model_to_device(model, self.device)
tie_parameters(model, tied_parameters_dict)
move_model_to_device(model, self.device)
tie_parameters(model, tied_parameters_dict)
xla_params = dict(model.named_parameters())

symmetric_diff = set(cpu_ids.keys()).symmetric_difference((xla_params.keys()))
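
For illustration, the dtype decision made just above can be restated as a small helper (cast_for_xla_backend is hypothetical, not part of the commit): with the XLA backend the parameters themselves are cast according to the torch_xla environment variables, while with the AMP backend they stay in float32 and only the forward pass is autocast.

import os
import torch

def cast_for_xla_backend(module: torch.nn.Module) -> torch.nn.Module:
    # Mirrors the branch in _prepare_model_for_mp above.
    if os.environ.get("XLA_USE_BF16", "0") == "1" or os.environ.get("XLA_DOWNCAST_BF16", "0") == "1":
        return module.to(torch.bfloat16)
    return module.to(torch.float32)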
@@ -490,6 +490,10 @@ def prepare_model(
if model in self._models:
return model

# Since it is not possible to set the best compiler flags for a given model because XLA is initialized before
# we get access to the model, we simply check if the flags are the best and notify the user otherwise.
check_neuron_cc_flags_for_model(model)

model = self.patch_model_for_neuron(model)

# We do not want to use the cache, or output unused tensors as it would imply more communication that we do not
@@ -533,6 +537,29 @@ def clip_grad_norm_for_xla_fsdp(self, parameters, max_norm, norm_type: int = 2):
if parameters == list(model.parameters()):
return model.clip_grad_norm_(max_norm, norm_type)

@contextlib.contextmanager
def autocast(self, cache_enabled: bool = False, autocast_handler: Optional[AutocastKwargs] = None):
if cache_enabled:
warnings.warn(
"Passing `cache_enabled=True` to `accelerator.autocast` is deprecated and will be removed in v0.23.0. "
"Please use the `AutocastKwargs` class instead and pass it to the `Accelerator` as a `kwarg_handler`.",
FutureWarning,
)
if self.autocast_handler is not None:
self.autocast_handler.cache_enabled = True
else:
self.autocast_handler = AutocastKwargs(cache_enabled=True)
if autocast_handler is None:
# By default `self.autocast_handler` enables autocast if:
# - `self.state.mixed_precision == "bf16"`
# - `self.state.autocast_backend is AutocastBackend.AMP`
autocast_handler = self.autocast_handler
autocast_kwargs = autocast_handler.to_kwargs()
autocast_context = torch.autocast(dtype=torch.bfloat16, device_type="cuda", **autocast_kwargs)
autocast_context.__enter__()
yield
autocast_context.__exit__(*sys.exc_info())

@requires_neuronx_distributed
def _prepare_clip_grad_norm(self, parameters, max_norm, norm_type: int = 2):
from neuronx_distributed.pipeline import NxDPPModel
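A usage sketch for the new autocast() override, with model, optimizer, and dataloader standing in for objects returned by accelerator.prepare: the forward pass and loss computation run under bfloat16 autocast, while the backward and optimizer steps sit outside the context.

for batch in dataloader:
    optimizer.zero_grad()
    with accelerator.autocast():
        outputs = model(**batch)
        loss = outputs.loss
    accelerator.backward(loss)
    optimizer.step()
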
75 changes: 53 additions & 22 deletions optimum/neuron/accelerate/state.py
@@ -15,6 +15,7 @@
"""Custom PartialState and AcceleratorState for Neuron."""

import os
from typing import Optional, Union

import torch
from accelerate.state import AcceleratorState, PartialState, ThreadLocalSharedDict
@@ -35,8 +36,13 @@

from ...utils import logging
from ..utils import is_neuronx_distributed_available, is_torch_xla_available
from ..utils.torch_xla_and_neuronx_initialization import (
init_process_group,
set_common_neuron_cc_flags,
set_neuron_cc_flags_for_torch_amp,
)
from .utils import NeuronDistributedType, NeuronFullyShardedDataParallelPlugin
from .utils.dataclasses import ModelParallelismPlugin
from .utils.dataclasses import AutocastBackend, ModelParallelismPlugin


if is_torch_xla_available():
@@ -84,6 +90,11 @@ def __init__(self, cpu: bool = False, **kwargs):
self.device = torch.device("cuda", self.local_process_index)
torch.cuda.set_device(self.device)
elif is_torch_xla_available() and not cpu:
# It is important to set the environment variables before initializing the process group otherwise they will be ignored by the Neuron compiler.
set_common_neuron_cc_flags()
if os.environ.get("ACCELERATE_USE_AMP", "false") == "true":
set_neuron_cc_flags_for_torch_amp()
init_process_group()
self.distributed_type = DistributedType.TPU
self.num_processes = xm.xrt_world_size()
self.process_index = xm.get_ordinal()
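
The comment above boils down to an ordering constraint: compiler-related environment variables must be populated before the XLA process group is created, because the Neuron compiler reads them at initialization. A minimal sketch of that ordering with a placeholder flag string (the concrete flags written by set_common_neuron_cc_flags and set_neuron_cc_flags_for_torch_amp are not shown in this diff):

import os
import torch.distributed as dist
import torch_xla.distributed.xla_backend  # noqa: F401  registers the "xla" backend

os.environ["NEURON_CC_FLAGS"] = os.environ.get("NEURON_CC_FLAGS", "") + " <placeholder-flag>"
dist.init_process_group(backend="xla")  # flags changed after this point would be ignored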
@@ -224,17 +235,26 @@ def __init__(
deepspeed_plugin=None,
fsdp_plugin=None,
megatron_lm_plugin=None,
mp_plugin=None,
mp_plugin: Optional[ModelParallelismPlugin] = None,
autocast_backend: Optional[Union[str, AutocastBackend]] = None,
_from_accelerator: bool = False,
**kwargs,
):
self.__dict__ = self._shared_state
if parse_flag_from_env("ACCELERATE_USE_CPU"):
cpu = True

if autocast_backend is None:
autocast_backend = AutocastBackend.XLA
elif not isinstance(autocast_backend, AutocastBackend):
autocast_backend = AutocastBackend(autocast_backend)

if NeuronPartialState._shared_state == {}:
if autocast_backend is AutocastBackend.AMP:
os.environ["ACCELERATE_USE_AMP"] = "true"
NeuronPartialState(cpu, **kwargs)
self.__dict__.update(NeuronPartialState._shared_state)
self._check_initialized(mixed_precision, cpu)
self._check_initialized(mixed_precision, cpu, autocast_backend)
if not self.initialized:
self.deepspeed_plugin = None
self.ipex_plugin = None
@@ -253,34 +273,30 @@
)
# deepspeed handles mixed_precision using deepspeed_config
self._mixed_precision = "no" if self.distributed_type == DistributedType.DEEPSPEED else mixed_precision

self._autocast_backend = autocast_backend

if self.distributed_type == DistributedType.TPU:
if mixed_precision == "bf16":
if os.environ.get("ACCELERATE_DOWNCAST_BF16"):
if autocast_backend is AutocastBackend.AMP:
self.downcast_bfloat = True
elif os.environ.get("ACCELERATE_DOWNCAST_BF16"):
os.environ["XLA_USE_BF16"] = str(0)
os.environ["XLA_DOWNCAST_BF16"] = str(1)
self.downcast_bfloat = True
else:
os.environ["XLA_USE_BF16"] = str(1)
os.environ["XLA_DOWNCAST_BF16"] = str(0)
self.downcast_bfloat = False
if (
os.environ.get("ACCELERATE_USE_NEURONX_DISTRIBUTED_TP", "false") == "true"
or os.environ.get("ACCELERATE_USE_NEURONX_DISTRIBUTED_PP", "false") == "true"
):
if mp_plugin is None:
raise ValueError(
"Could not initialize model parallelism because no `ModelParallelismPlugin` was provided."
)
if mp_plugin.should_parallelize:
self.distributed_type = NeuronDistributedType.MODEL_PARALLELISM
else:
logger.warning(
"Model parallelism is requested but nothing is done because the tensor parallel size and "
"the pipeline parallel size are set to 1."
)
self.mp_plugin = mp_plugin
else:
self.mp_plugin = ModelParallelismPlugin()

if mp_plugin is None:
mp_plugin = ModelParallelismPlugin()

if mp_plugin.should_parallelize:
self.distributed_type = NeuronDistributedType.MODEL_PARALLELISM

self.mp_plugin = mp_plugin
print("MP PLUGIN", self.mp_plugin)

if torch.distributed.is_initialized() and not parallel_state.model_parallel_is_initialized():
parallel_state.initialize_model_parallel(
@@ -323,3 +339,18 @@
):
torch.backends.cuda.matmul.allow_tf32 = True
PartialState._shared_state["distributed_type"] = self.distributed_type

def _check_initialized(self, mixed_precision=None, cpu=None, autocast_backend=None):
"Checks if a modification is trying to be made and the `AcceleratorState` has already been initialized"
super()._check_initialized(mixed_precision=mixed_precision, cpu=cpu)
err = (
"AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and "
"pass `{flag}` to `Accelerator()`."
)
if self.initialized:
if autocast_backend is not None and autocast_backend != self.autocast_backend:
raise ValueError(err.format(flag=f"autocast_backend='{autocast_backend}'"))

@property
def autocast_backend(self):
return self._autocast_backend
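
A short illustration of what the new _check_initialized guard protects against, assuming NeuronAccelerator forwards autocast_backend to this state as shown in accelerator.py:

from optimum.neuron.accelerate import NeuronAccelerator

accelerator = NeuronAccelerator(mixed_precision="bf16", autocast_backend="xla")
# The AcceleratorState is a process-wide singleton; requesting a different backend afterwards is rejected:
# NeuronAccelerator(mixed_precision="bf16", autocast_backend="amp")
# -> ValueError: AcceleratorState has already been initialized and cannot be changed, ...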
11 more changed files not shown.
