Add precision config arg for FP8 (mosaicml#2335)
* add precision_config arg to get_precision_context

* add docstrings

* apply pre-commit changes

* default precision_config in context manager

* check for key before access

* import type

* pre-commit formatting

* docstring formatting

* try to fix docs formatting

* let context manager init config

* fix None typing

* user must supply Format enum

* test config inside fp8 context

* reformat test

* fix precision_config type in docs

* revert unnecessary docstring change

* add default values to precision config

* H100 is 9.0

* get_fp8_recipe is in te.fp8
julian-q committed Jun 30, 2023
1 parent b73c776 commit ca2e88a
Showing 7 changed files with 64 additions and 20 deletions.
@@ -279,7 +279,7 @@ def loss(p, y, reduction='none'):
assert self._loss_fn is not None, 'loss_fn should be set on Event.INIT'
return self._loss_fn(p, (torch.Tensor(), y), reduction=reduction)

with get_precision_context(state.precision):
with get_precision_context(state.precision, state.precision_config):
new_input, new_target = select_using_loss(input, target, model, loss, self.keep, self.scale_factor)
state.batch_set_item(self.input_key, new_input)
state.batch_set_item(self.target_key, new_target)
2 changes: 1 addition & 1 deletion composer/algorithms/seq_length_warmup/seq_length_warmup.py
@@ -287,7 +287,7 @@ def _activate_model(self, state: State, logger: Logger) -> None:
try:
# Start by running a forward and backward pass
# of the maximum sequence length to allocate cache.
with get_precision_context(state.precision):
with get_precision_context(state.precision, state.precision_config):
outputs = state.model.forward(model_inputs)
loss = self._original_model.loss(outputs, model_inputs)

19 changes: 12 additions & 7 deletions composer/core/precision.py
@@ -6,7 +6,7 @@
import contextlib
import os
import textwrap
from typing import Generator, Union
from typing import Any, Dict, Generator, Optional, Union

import torch

@@ -38,11 +38,14 @@ class Precision(StringEnum):


@contextlib.contextmanager
def get_precision_context(precision: Union[str, Precision]) -> Generator[None, None, None]:
def get_precision_context(precision: Union[str, Precision],
precision_config: Optional[Dict[str, Any]] = None) -> Generator[None, None, None]:
"""Returns a context manager to automatically cast to a specific precision.
Args:
precision (str | Precision): Precision for the context
precision_config (Optional[Dict[str, Any]]): Config for FP8 scaling strategy. See parameters for
`DelayedScaling <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html?highlight=delayedscaling#transformer_engine.common.recipe.DelayedScaling>`_.
"""
precision = Precision(precision)
if precision == Precision.FP32:
@@ -67,11 +70,13 @@ def get_precision_context(precision: Union[str, Precision]) -> Generator[None, None, None]:
if te_installed and torch.cuda.get_device_capability()[0] > 8:
from transformer_engine.common.recipe import DelayedScaling, Format

# These default values for fp8_recipe are taken from NVidia's docs. We may want to change
# these once we get a chance to do more convergence experiments.
# https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#id1
fp8_format = Format.HYBRID # E4M3 during forward pass, E5M2 during backward pass
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo='max')
if precision_config is None:
precision_config = {
'fp8_format': Format.HYBRID,
'amax_history_len': 16,
'amax_compute_algo': 'max',
}
fp8_recipe = DelayedScaling(**precision_config)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
yield
else:
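For reference, a minimal usage sketch of the new argument (illustrative only, not part of this diff). It assumes TransformerEngine is installed and the GPU reports compute capability 9.0 or higher; the te.Linear layer, tensor shapes, and overridden values are placeholders chosen for the example.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format

from composer.core import Precision, get_precision_context

# Any DelayedScaling keyword may be supplied; unspecified fields keep DelayedScaling's own defaults.
precision_config = {
    'fp8_format': Format.E4M3,   # E4M3 in both passes instead of the HYBRID default above
    'amax_history_len': 32,      # longer amax history than the default of 16
    'amax_compute_algo': 'max',
}

layer = te.Linear(64, 64).cuda()
x = torch.randn(16, 64, device='cuda')

with get_precision_context(Precision.AMP_FP8, precision_config):
    y = layer(x)  # forward pass runs inside te.fp8_autocast with DelayedScaling(**precision_config)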
12 changes: 12 additions & 0 deletions composer/core/state.py
@@ -274,6 +274,8 @@ class State(Serializable):
max_duration (str | Time, optional): The maximum duration to train for. (default: ``None``)
precision (str | Precision): The numerical precision to use for training. See :class:`~.Precision` for
the supported precisions.
precision_config (Optional[Dict[str, Any]]): The config for FP8 scaling strategy. See parameters for
`DelayedScaling <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html?highlight=delayedscaling#transformer_engine.common.recipe.DelayedScaling>`_.
optimizers (torch.optim.Optimizer | Sequence[torch.optim.Optimizer], optional): The optimizer being used to
train the model. Multiple optimizers are not currently supported.
schedulers (types.PyTorchScheduler | Sequence[types.PyTorchScheduler], optional):
@@ -433,6 +435,7 @@ def __init__(

# precision
precision: Union[str, Precision] = Precision.FP32,
precision_config: Optional[Dict[str, Any]] = None,

# optimizers
optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None,
@@ -472,6 +475,7 @@ def __init__(
self.eval_timestamp = Timestamp()
self.predict_timestamp = Timestamp()
self._precision = Precision(precision)
self._precision_config = precision_config

if optimizers is None:
self._optimizers = []
@@ -1409,6 +1413,14 @@ def precision(self):
def precision(self, precision: Union[str, Precision]):
self._precision = Precision(precision)

@property
def precision_config(self):
"""The config for FP8 scaling strategy.
See parameters for `DelayedScaling <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html?highlight=delayedscaling#transformer_engine.common.recipe.DelayedScaling>`_.
"""
return self._precision_config

@property
def is_model_ddp(self):
"""Whether :attr:`model` is an instance of a :class:`.DistributedDataParallel`."""
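Because the config is stored on State, user callbacks and algorithms can read it the same way the algorithm changes above do. A hedged sketch of such a callback follows; the class name and choice of hook are illustrative only.

from composer.core import Callback, State
from composer.loggers import Logger


class LogFP8Recipe(Callback):
    """Illustrative callback: record which FP8 recipe fields were supplied, if any."""

    def fit_start(self, state: State, logger: Logger) -> None:
        # state.precision_config is None unless one was passed to the Trainer;
        # get_precision_context then falls back to the NVIDIA-recommended defaults shown above.
        config = state.precision_config or {}
        logger.log_hyperparameters({f'fp8/{k}': str(v) for k, v in config.items()})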
25 changes: 17 additions & 8 deletions composer/trainer/trainer.py
@@ -308,10 +308,10 @@ def _get_ddp_sync_strategy(ddp_sync_strategy: Optional[Union[str, DDPSyncStrateg
return ddp_sync_strategy


def _get_precision_context(precision: Precision, deepspeed_enabled: bool):
def _get_precision_context(precision: Precision, precision_config: Optional[Dict[str, Any]], deepspeed_enabled: bool):
if deepspeed_enabled:
return contextlib.nullcontext()
return get_precision_context(precision)
return get_precision_context(precision, precision_config)


def _generate_run_name() -> str:
@@ -742,6 +742,8 @@ class Trainer:
precision (Precision | str, optional): Numerical precision to use for training. One of ``fp32``, ``amp_bf16``
or ``amp_fp16`` (recommended). (default: ``Precision.FP32`` if training on CPU; ``Precision.AMP_FP16`` if
training on GPU)
precision_config (Optional[Dict[str, Any]]): The config for FP8 scaling strategy. See parameters for
`DelayedScaling <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html?highlight=delayedscaling#transformer_engine.common.recipe.DelayedScaling>`_.
device_train_microbatch_size (Union[int, str], optional): The number of samples to process on each device per
microbatch during training. Gradients are summed over the microbatches per device. If set to ``auto``,
dynamically decreases device_train_microbatch_size if microbatch is too large for GPU. (default: ``None``)
@@ -872,6 +874,7 @@ def __init__(
# System/Numerics
device: Optional[Union[str, Device]] = None,
precision: Optional[Union[str, Precision]] = None,
precision_config: Optional[Dict[str, Any]] = None,
device_train_microbatch_size: Optional[Union[int, str]] = None,

# Reproducibility
@@ -1010,6 +1013,7 @@ def __init__(
device_train_microbatch_size=device_train_microbatch_size,
auto_microbatching=auto_microbatching,
precision=precision,
precision_config=precision_config,
optimizers=optimizers,
run_name=run_name,
deepspeed_config=deepspeed_config,
@@ -2101,7 +2105,7 @@ def _eval_train_metrics(self, device_batch):

with torch.no_grad(),\
model_eval_mode(self.state.model),\
_get_precision_context(self.state.precision, self.state.deepspeed_enabled):
_get_precision_context(self.state.precision, self.state.precision_config, self.state.deepspeed_enabled):
eval_outputs = self._original_model.eval_forward(device_batch, self.state.outputs)
for _, metric in self.state.train_metrics.items():
self._original_model.update_metric(
@@ -2334,7 +2338,8 @@ def _train_microbatch(self, use_grad_scaling: bool, current_batch_size: int,
# Forward pass
self.engine.run_event(Event.BEFORE_FORWARD)

with _get_precision_context(self.state.precision, self.state.deepspeed_enabled):
with _get_precision_context(self.state.precision, self.state.precision_config,
self.state.deepspeed_enabled):
self.state.outputs = self.state.model(self.state.batch)

self.engine.run_event(Event.AFTER_FORWARD)
@@ -2356,7 +2361,8 @@ def _train_microbatch(self, use_grad_scaling: bool, current_batch_size: int,
# Loss
self.engine.run_event(Event.BEFORE_LOSS)

with _get_precision_context(self.state.precision, self.state.deepspeed_enabled):
with _get_precision_context(self.state.precision, self.state.precision_config,
self.state.deepspeed_enabled):
self.state.loss = self._original_model.loss(self.state.outputs, self.state.batch)

assert self.state.loss is not None
@@ -2527,7 +2533,8 @@ def predict_batch_end(self, state: State, logger: Logger) -> None:
self.engine.run_event(Event.PREDICT_BATCH_START)

self.engine.run_event(Event.PREDICT_BEFORE_FORWARD)
with _get_precision_context(self.state.precision, self.state.deepspeed_enabled):
with _get_precision_context(self.state.precision, self.state.precision_config,
self.state.deepspeed_enabled):
self.state.outputs = self.state.model(self.state.batch)
self.engine.run_event(Event.PREDICT_AFTER_FORWARD)

@@ -2819,7 +2826,8 @@ def _eval_loop(

self.engine.run_event(Event.EVAL_BEFORE_FORWARD)

with _get_precision_context(self.state.precision, self.state.deepspeed_enabled):
with _get_precision_context(self.state.precision, self.state.precision_config,
self.state.deepspeed_enabled):
self.state.outputs = self._original_model.eval_forward(self.state.batch)

self.engine.run_event(Event.EVAL_AFTER_FORWARD)
@@ -2830,7 +2838,8 @@ def _eval_loop(
continue

# Run in same precision context to avoid NaNs
with _get_precision_context(self.state.precision, self.state.deepspeed_enabled):
with _get_precision_context(self.state.precision, self.state.precision_config,
self.state.deepspeed_enabled):
if isinstance(self.state.device, DeviceMPS):
# torchmetrics math has numerical errors on M1 devices
# running the compute on CPU instead
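Putting the pieces together, the new Trainer argument could be wired up as in the sketch below. This is illustrative only: the toy model and dataloader mirror the test further down, the recipe values match the defaults above, and it assumes an FP8-capable GPU with TransformerEngine installed.

from torch.utils.data import DataLoader
from transformer_engine.common.recipe import Format

from composer import Trainer
from composer.models import composer_resnet_cifar
from tests.common import RandomImageDataset  # the repo's toy dataset, as used in the test below

trainer = Trainer(
    model=composer_resnet_cifar('resnet_9'),
    train_dataloader=DataLoader(RandomImageDataset(), batch_size=32, num_workers=0),
    max_duration='1ep',
    device='gpu',
    precision='amp_fp8',
    # Forwarded through State into get_precision_context, where it configures DelayedScaling.
    precision_config={
        'fp8_format': Format.HYBRID,
        'amax_history_len': 16,
        'amax_compute_algo': 'max',
    },
)
trainer.fit()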
23 changes: 20 additions & 3 deletions tests/test_precision.py
@@ -1,25 +1,25 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, Optional

import pytest
import torch
import torch.distributed
from torch.utils.data import DataLoader

from composer import Trainer
from composer.core import Precision
from composer.core import Precision, get_precision_context
from composer.models import composer_resnet_cifar
from tests.common import RandomImageDataset

try:
import transformer_engine.pytorch as te
del te
te_installed = True
except ImportError:
te_installed = False


def get_trainer(precision: Precision) -> Trainer:
def get_trainer(precision: Precision, precision_config: Optional[Dict[str, Any]] = None) -> Trainer:

return Trainer(
model=composer_resnet_cifar('resnet_9'),
@@ -36,6 +36,7 @@ def get_trainer(precision: Precision) -> Trainer:
num_workers=0,
),
precision=precision,
precision_config=precision_config,
max_duration='1ep',
eval_interval='1ep',
train_subset_num_batches=1,
@@ -108,3 +109,19 @@ def test_amp_fp8_path():
else:
with pytest.raises(ImportError, match='AMP_FP8 precision is used but TransformerEngine is not installed'):
trainer.fit()


@pytest.mark.gpu
def test_amp_fp8_config():
if te_installed and torch.cuda.get_device_capability()[0] >= 9:
from transformer_engine.common.recipe import Format
precision_config = {
'fp8_format': Format.HYBRID,
'amax_history_len': 16,
'amax_compute_algo': 'max',
}
trainer = get_trainer(Precision.AMP_FP8, precision_config=precision_config)
with get_precision_context(trainer.state.precision, trainer.state.precision_config):
fp8_recipe = te.fp8.get_fp8_recipe()
for k, v in precision_config.items():
assert getattr(fp8_recipe, k) == v
1 change: 1 addition & 0 deletions tests/utils/test_autolog_hparams.py
@@ -170,6 +170,7 @@ def test_extract_hparams_trainer():
# System/Numerics
'device': 'DeviceCPU',
'precision': 'Precision',
'precision_config': None,
'device_train_microbatch_size': 16,

# Reproducibility
