From ca2e88ab2fcb643c3451a1a7010e1b12bb96a3d6 Mon Sep 17 00:00:00 2001
From: Julian Quevedo
Date: Fri, 30 Jun 2023 09:07:33 -0700
Subject: [PATCH] Add precision config arg for FP8 (#2335)

* add precision_config arg to get_precision_context
* add docstrings
* apply pre-commit changes
* default precision_config in context manager
* check for key before access
* import type
* pre-commit formatting
* docstring formatting
* try to fix docs formatting
* let context manager init config
* fix None typing
* user must supply Format enum
* test config inside fp8 context
* reformat test
* fix precision_config type in docs
* revert unnecessary docstring change
* add default values to precision config
* H100 is 9.0
* get_fp8_recipe is in te.fp8
---
 .../selective_backprop/selective_backprop.py |  2 +-
 .../seq_length_warmup/seq_length_warmup.py   |  2 +-
 composer/core/precision.py                   | 19 ++++++++------
 composer/core/state.py                       | 12 +++++++++
 composer/trainer/trainer.py                  | 25 +++++++++++++------
 tests/test_precision.py                      | 23 ++++++++++++++---
 tests/utils/test_autolog_hparams.py          |  1 +
 7 files changed, 64 insertions(+), 20 deletions(-)

diff --git a/composer/algorithms/selective_backprop/selective_backprop.py b/composer/algorithms/selective_backprop/selective_backprop.py
index 64062fc38f..06b054bb6f 100644
--- a/composer/algorithms/selective_backprop/selective_backprop.py
+++ b/composer/algorithms/selective_backprop/selective_backprop.py
@@ -279,7 +279,7 @@ def loss(p, y, reduction='none'):
                 assert self._loss_fn is not None, 'loss_fn should be set on Event.INIT'
                 return self._loss_fn(p, (torch.Tensor(), y), reduction=reduction)
 
-            with get_precision_context(state.precision):
+            with get_precision_context(state.precision, state.precision_config):
                 new_input, new_target = select_using_loss(input, target, model, loss, self.keep, self.scale_factor)
             state.batch_set_item(self.input_key, new_input)
             state.batch_set_item(self.target_key, new_target)
diff --git a/composer/algorithms/seq_length_warmup/seq_length_warmup.py b/composer/algorithms/seq_length_warmup/seq_length_warmup.py
index 2c19d91cd1..be0ec5e447 100644
--- a/composer/algorithms/seq_length_warmup/seq_length_warmup.py
+++ b/composer/algorithms/seq_length_warmup/seq_length_warmup.py
@@ -287,7 +287,7 @@ def _activate_model(self, state: State, logger: Logger) -> None:
         try:
             # Start by running a forward and backward pass
             # of the maximum sequence length to allocate cache.
-            with get_precision_context(state.precision):
+            with get_precision_context(state.precision, state.precision_config):
                 outputs = state.model.forward(model_inputs)
                 loss = self._original_model.loss(outputs, model_inputs)
 
diff --git a/composer/core/precision.py b/composer/core/precision.py
index 1b0f185eec..56333e6f85 100644
--- a/composer/core/precision.py
+++ b/composer/core/precision.py
@@ -6,7 +6,7 @@
 import contextlib
 import os
 import textwrap
-from typing import Generator, Union
+from typing import Any, Dict, Generator, Optional, Union
 
 import torch
 
@@ -38,11 +38,14 @@ class Precision(StringEnum):
 
 
 @contextlib.contextmanager
-def get_precision_context(precision: Union[str, Precision]) -> Generator[None, None, None]:
+def get_precision_context(precision: Union[str, Precision],
+                          precision_config: Optional[Dict[str, Any]] = None) -> Generator[None, None, None]:
     """Returns a context manager to automatically cast to a specific precision.
 
     Args:
         precision (str | Precision): Precision for the context
+        precision_config (Optional[Dict[str, Any]]): Config for FP8 scaling strategy. See parameters for
+            `DelayedScaling `_.
     """
     precision = Precision(precision)
     if precision == Precision.FP32:
@@ -67,11 +70,13 @@ def get_precision_context(precision: Union[str, Precision]) -> Generator[None, N
         if te_installed and torch.cuda.get_device_capability()[0] > 8:
             from transformer_engine.common.recipe import DelayedScaling, Format
 
-            # These default values for fp8_recipe are taken from NVidia's docs. We may want to change
-            # these once we get a chance to do more convergence experiments.
-            # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#id1
-            fp8_format = Format.HYBRID  # E4M3 during forward pass, E5M2 during backward pass
-            fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo='max')
+            if precision_config is None:
+                precision_config = {
+                    'fp8_format': Format.HYBRID,
+                    'amax_history_len': 16,
+                    'amax_compute_algo': 'max',
+                }
+            fp8_recipe = DelayedScaling(**precision_config)
             with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
                 yield
         else:
diff --git a/composer/core/state.py b/composer/core/state.py
index 6bc819485a..ff01d34630 100644
--- a/composer/core/state.py
+++ b/composer/core/state.py
@@ -274,6 +274,8 @@ class State(Serializable):
         max_duration (str | Time, optional): The maximum duration to train for. (default: ``None``)
         precision (str | Precision): The numerical precision to use for training. See :class:`~.Precision` for the
            supported precisions.
+        precision_config (Optional[Dict[str, Any]]): The config for FP8 scaling strategy. See parameters for
+            `DelayedScaling `_.
         optimizers (torch.optim.Optimizer | Sequence[torch.optim.Optimizer], optional): The optimizer being used to
            train the model. Multiple optimizers are not currently supported.
         schedulers (types.PyTorchScheduler | Sequence[types.PyTorchScheduler], optional):
@@ -433,6 +435,7 @@ def __init__(
 
         # precision
         precision: Union[str, Precision] = Precision.FP32,
+        precision_config: Optional[Dict[str, Any]] = None,
 
         # optimizers
         optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None,
@@ -472,6 +475,7 @@ def __init__(
         self.eval_timestamp = Timestamp()
         self.predict_timestamp = Timestamp()
         self._precision = Precision(precision)
+        self._precision_config = precision_config
 
         if optimizers is None:
            self._optimizers = []
@@ -1409,6 +1413,14 @@ def precision(self):
     def precision(self, precision: Union[str, Precision]):
         self._precision = Precision(precision)
 
+    @property
+    def precision_config(self):
+        """The config for FP8 scaling strategy.
+
+        See parameters for `DelayedScaling `_.
+        """
+        return self._precision_config
+
     @property
     def is_model_ddp(self):
         """Whether :attr:`model` is an instance of a :class:`.DistributedDataParallel`."""
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index c4779d958d..87d233467b 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -308,10 +308,10 @@ def _get_ddp_sync_strategy(ddp_sync_strategy: Optional[Union[str, DDPSyncStrateg
     return ddp_sync_strategy
 
 
-def _get_precision_context(precision: Precision, deepspeed_enabled: bool):
+def _get_precision_context(precision: Precision, precision_config: Optional[Dict[str, Any]], deepspeed_enabled: bool):
     if deepspeed_enabled:
         return contextlib.nullcontext()
-    return get_precision_context(precision)
+    return get_precision_context(precision, precision_config)
 
 
 def _generate_run_name() -> str:
@@ -742,6 +742,8 @@ class Trainer:
         precision (Precision | str, optional): Numerical precision to use for training. One of ``fp32``, ``amp_bf16``
            or ``amp_fp16`` (recommended). (default: ``Precision.FP32`` if training on CPU; ``Precision.AMP_FP16`` if
            training on GPU)
+        precision_config (Optional[Dict[str, Any]]): The config for FP8 scaling strategy. See parameters for
+            `DelayedScaling `_.
         device_train_microbatch_size (Union[int, str), optional): The number of samples to process on each device per
            microbatch during training. Gradients are summed over the microbatches per device. If set to ``auto``,
            dynamically decreases device_train_microbatch_size if microbatch is too large for GPU. (default: ``None``)
@@ -872,6 +874,7 @@ def __init__(
         # System/Numerics
         device: Optional[Union[str, Device]] = None,
         precision: Optional[Union[str, Precision]] = None,
+        precision_config: Optional[Dict[str, Any]] = None,
         device_train_microbatch_size: Optional[Union[int, str]] = None,
 
         # Reproducibility
@@ -1010,6 +1013,7 @@ def __init__(
             device_train_microbatch_size=device_train_microbatch_size,
             auto_microbatching=auto_microbatching,
             precision=precision,
+            precision_config=precision_config,
             optimizers=optimizers,
             run_name=run_name,
             deepspeed_config=deepspeed_config,
@@ -2101,7 +2105,7 @@ def _eval_train_metrics(self, device_batch):
 
         with torch.no_grad(),\
                 model_eval_mode(self.state.model),\
-                _get_precision_context(self.state.precision, self.state.deepspeed_enabled):
+                _get_precision_context(self.state.precision, self.state.precision_config, self.state.deepspeed_enabled):
             eval_outputs = self._original_model.eval_forward(device_batch, self.state.outputs)
             for _, metric in self.state.train_metrics.items():
                 self._original_model.update_metric(
@@ -2334,7 +2338,8 @@ def _train_microbatch(self, use_grad_scaling: bool, current_batch_size: int,
             # Forward pass
             self.engine.run_event(Event.BEFORE_FORWARD)
 
-            with _get_precision_context(self.state.precision, self.state.deepspeed_enabled):
+            with _get_precision_context(self.state.precision, self.state.precision_config,
+                                        self.state.deepspeed_enabled):
                 self.state.outputs = self.state.model(self.state.batch)
 
             self.engine.run_event(Event.AFTER_FORWARD)
@@ -2356,7 +2361,8 @@ def _train_microbatch(self, use_grad_scaling: bool, current_batch_size: int,
             # Loss
             self.engine.run_event(Event.BEFORE_LOSS)
 
-            with _get_precision_context(self.state.precision, self.state.deepspeed_enabled):
+            with _get_precision_context(self.state.precision, self.state.precision_config,
+                                        self.state.deepspeed_enabled):
                 self.state.loss = self._original_model.loss(self.state.outputs, self.state.batch)
 
             assert self.state.loss is not None
@@ -2527,7 +2533,8 @@ def predict_batch_end(self, state: State, logger: Logger) -> None:
                 self.engine.run_event(Event.PREDICT_BATCH_START)
 
                 self.engine.run_event(Event.PREDICT_BEFORE_FORWARD)
-                with _get_precision_context(self.state.precision, self.state.deepspeed_enabled):
+                with _get_precision_context(self.state.precision, self.state.precision_config,
+                                            self.state.deepspeed_enabled):
                     self.state.outputs = self.state.model(self.state.batch)
 
                 self.engine.run_event(Event.PREDICT_AFTER_FORWARD)
@@ -2819,7 +2826,8 @@ def _eval_loop(
 
                     self.engine.run_event(Event.EVAL_BEFORE_FORWARD)
 
-                    with _get_precision_context(self.state.precision, self.state.deepspeed_enabled):
+                    with _get_precision_context(self.state.precision, self.state.precision_config,
+                                                self.state.deepspeed_enabled):
                         self.state.outputs = self._original_model.eval_forward(self.state.batch)
 
                     self.engine.run_event(Event.EVAL_AFTER_FORWARD)
@@ -2830,7 +2838,8 @@ def _eval_loop(
                             continue
 
                         # Run in same precision context to avoid NaNs
-                        with _get_precision_context(self.state.precision, self.state.deepspeed_enabled):
+                        with _get_precision_context(self.state.precision, self.state.precision_config,
+                                                    self.state.deepspeed_enabled):
                             if isinstance(self.state.device, DeviceMPS):
                                 # torchmetrics math has numerical errors on M1 devices
                                 # running the compute on CPU instead
diff --git a/tests/test_precision.py b/tests/test_precision.py
index 412da3dc1e..46571529c6 100644
--- a/tests/test_precision.py
+++ b/tests/test_precision.py
@@ -1,5 +1,6 @@
 # Copyright 2022 MosaicML Composer authors
 # SPDX-License-Identifier: Apache-2.0
+from typing import Any, Dict, Optional
 
 import pytest
 import torch
@@ -7,19 +8,18 @@
 from torch.utils.data import DataLoader
 
 from composer import Trainer
-from composer.core import Precision
+from composer.core import Precision, get_precision_context
 from composer.models import composer_resnet_cifar
 from tests.common import RandomImageDataset
 
 try:
     import transformer_engine.pytorch as te
-    del te
     te_installed = True
 except ImportError:
     te_installed = False
 
 
-def get_trainer(precision: Precision) -> Trainer:
+def get_trainer(precision: Precision, precision_config: Optional[Dict[str, Any]] = None) -> Trainer:
     return Trainer(
         model=composer_resnet_cifar('resnet_9'),
@@ -36,6 +36,7 @@
             num_workers=0,
         ),
         precision=precision,
+        precision_config=precision_config,
         max_duration='1ep',
         eval_interval='1ep',
         train_subset_num_batches=1,
@@ -108,3 +109,19 @@ def test_amp_fp8_path():
     else:
         with pytest.raises(ImportError, match='AMP_FP8 precision is used but TransformerEngine is not installed'):
             trainer.fit()
+
+
+@pytest.mark.gpu
+def test_amp_fp8_config():
+    if te_installed and torch.cuda.get_device_capability()[0] >= 9:
+        from transformer_engine.common.recipe import Format
+        precision_config = {
+            'fp8_format': Format.HYBRID,
+            'amax_history_len': 16,
+            'amax_compute_algo': 'max',
+        }
+        trainer = get_trainer(Precision.AMP_FP8, precision_config=precision_config)
+        with get_precision_context(trainer.state.precision, trainer.state.precision_config):
+            fp8_recipe = te.fp8.get_fp8_recipe()
+            for k, v in precision_config.items():
+                assert getattr(fp8_recipe, k) == v
diff --git a/tests/utils/test_autolog_hparams.py b/tests/utils/test_autolog_hparams.py
index fc7cf5ecf8..66cab9e9b8 100644
--- a/tests/utils/test_autolog_hparams.py
+++ b/tests/utils/test_autolog_hparams.py
@@ -170,6 +170,7 @@ def test_extract_hparams_trainer():
         # System/Numerics
         'device': 'DeviceCPU',
         'precision': 'Precision',
+        'precision_config': None,
         'device_train_microbatch_size': 16,
 
         # Reproducibility
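
Below is a minimal usage sketch of the new ``precision_config`` argument, mirroring the ``test_amp_fp8_config`` test added in this patch. It assumes TransformerEngine is installed and a GPU with compute capability 9.0 or later (e.g. an H100) is available; the dictionary keys are forwarded verbatim to TransformerEngine's ``DelayedScaling`` recipe.

    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import Format

    from composer.core import Precision, get_precision_context

    # Keys are passed straight through as DelayedScaling(**precision_config).
    precision_config = {
        'fp8_format': Format.HYBRID,  # E4M3 in the forward pass, E5M2 in the backward pass
        'amax_history_len': 16,
        'amax_compute_algo': 'max',
    }

    with get_precision_context(Precision.AMP_FP8, precision_config):
        # Inside the context, TransformerEngine modules run under the supplied FP8 recipe.
        assert te.fp8.get_fp8_recipe().amax_history_len == 16

The same dictionary can be passed to ``Trainer(..., precision='amp_fp8', precision_config=precision_config)``; the Trainer stores it on ``State.precision_config``, from where the forward, loss, predict, and eval precision contexts in the diff above pick it up.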