Skip to content

Commit

Permalink
Add callback for enabling Tensorfloat32 (#885)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #885

This is something that can boost performance quite a bit with float32 training on CUDA, so I figured it'd make sense to package it up into a reusable callback.

Reviewed By: diego-urgell

Differential Revision: D61608792

fbshipit-source-id: ccd0712c9022029bf59ee0730a71ad59feea60ae
  • Loading branch information
alanhdu authored and facebook-github-bot committed Aug 26, 2024
1 parent ebda066 commit 926b5ec
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 0 deletions.
103 changes: 103 additions & 0 deletions tests/framework/callbacks/test_tensorfloat32.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import contextlib
import unittest
from typing import Iterator

import torch
from torchtnt.framework._test_utils import (
DummyFitUnit,
DummyPredictUnit,
DummyTrainUnit,
generate_random_dataloader,
)
from torchtnt.framework.callback import Callback
from torchtnt.framework.callbacks.tensorfloat32 import EnableTensorFloat32
from torchtnt.framework.fit import fit
from torchtnt.framework.predict import predict
from torchtnt.framework.state import State
from torchtnt.framework.train import train
from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit


class _CheckTensorFloat32Enabled(Callback):
    """Callback that verifies TF32 settings are active at every step start."""

    def __init__(self, testcase: unittest.TestCase) -> None:
        # Keep a handle to the running test so we can reuse its assertion helpers.
        self.testcase = testcase

    def assert_enabled(self) -> None:
        # All three global knobs must reflect the TF32-enabled configuration.
        tc = self.testcase
        tc.assertEqual(torch.get_float32_matmul_precision(), "high")
        tc.assertTrue(torch.backends.cudnn.allow_tf32)
        tc.assertTrue(torch.backends.cuda.matmul.allow_tf32)

    def on_train_step_start(self, state: State, unit: TTrainUnit) -> None:
        self.assert_enabled()

    def on_eval_step_start(self, state: State, unit: TEvalUnit) -> None:
        self.assert_enabled()

    def on_predict_step_start(self, state: State, unit: TPredictUnit) -> None:
        self.assert_enabled()


class EnableTensorFloat32Test(unittest.TestCase):
    @contextlib.contextmanager
    def check_proper_restore(self) -> Iterator[EnableTensorFloat32]:
        """Yield a fresh callback with TF32 disabled; verify full restoration on exit."""
        callback = EnableTensorFloat32()

        # Start from a fully TF32-disabled configuration so the callback's
        # effect (and its restore) is observable.
        torch.set_float32_matmul_precision("highest")
        torch.backends.cudnn.allow_tf32 = False
        torch.backends.cuda.matmul.allow_tf32 = False

        yield callback

        # The callback must have cleared its saved snapshots ...
        self.assertIsNone(callback.original_cuda_matmul)
        self.assertIsNone(callback.original_cudnn)
        self.assertIsNone(callback.original_float32_matmul_precision)

        # ... and restored the pre-existing (disabled) settings.
        self.assertEqual(torch.get_float32_matmul_precision(), "highest")
        self.assertFalse(torch.backends.cudnn.allow_tf32)
        self.assertFalse(torch.backends.cuda.matmul.allow_tf32)

    def test_tensorfloat32_callback_train(self) -> None:
        input_dim = batch_size = max_epochs = 2
        dataset_len = 5

        unit = DummyTrainUnit(input_dim=input_dim)
        loader = generate_random_dataloader(dataset_len, input_dim, batch_size)
        with self.check_proper_restore() as callback:
            cbs: list[Callback] = [callback, _CheckTensorFloat32Enabled(self)]
            train(unit, loader, max_epochs=max_epochs, callbacks=cbs)

    def test_tensorfloat32_callback_fit(self) -> None:
        input_dim = batch_size = max_epochs = 2
        dataset_len = 5

        unit = DummyFitUnit(input_dim=input_dim)
        loader = generate_random_dataloader(dataset_len, input_dim, batch_size)
        with self.check_proper_restore() as callback:
            cbs: list[Callback] = [callback, _CheckTensorFloat32Enabled(self)]
            fit(
                unit,
                loader,
                loader,
                max_epochs=max_epochs,
                callbacks=cbs,
            )

    def test_tensorfloat32_callback_predict(self) -> None:
        input_dim = batch_size = 2
        dataset_len = 5

        unit = DummyPredictUnit(input_dim=input_dim)
        loader = generate_random_dataloader(dataset_len, input_dim, batch_size)
        with self.check_proper_restore() as callback:
            cbs: list[Callback] = [callback, _CheckTensorFloat32Enabled(self)]
            predict(unit, loader, callbacks=cbs)
2 changes: 2 additions & 0 deletions torchtnt/framework/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .slow_rank_detector import SlowRankDetector
from .system_resources_monitor import SystemResourcesMonitor
from .tensorboard_parameter_monitor import TensorBoardParameterMonitor
from .tensorfloat32 import EnableTensorFloat32
from .throughput_logger import ThroughputLogger
from .time_limit_interrupter import TimeLimitInterrupter
from .time_wait_for_batch_logger import TimeWaitForBatchLogger
Expand All @@ -34,6 +35,7 @@
"BaseCSVWriter",
"EarlyStopping",
"EmptyCudaCache",
"EnableTensorFloat32",
"GarbageCollector",
"IterationTimeLogger",
"Lambda",
Expand Down
87 changes: 87 additions & 0 deletions torchtnt/framework/callbacks/tensorfloat32.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging
from typing import Optional

import torch
from torchtnt.framework.callback import Callback
from torchtnt.framework.state import EntryPoint, State
from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit
from torchtnt.utils.rank_zero_log import rank_zero_info

logger: logging.Logger = logging.getLogger(__name__)


class EnableTensorFloat32(Callback):
    """
    A callback that enables TensorFloat32 operations on CUDA for the duration of
    the entry point, restoring the previous global settings when it finishes.

    Args:
        float32_matmul_precision: precision to use for float32 matmul operations.
            See `torch.set_float32_matmul_precision` for details.
    """

    def __init__(self, float32_matmul_precision: str = "high") -> None:
        self.float32_matmul_precision = float32_matmul_precision

        # Snapshots of the global settings, captured by _enable() and cleared
        # by _reset(); None means "nothing saved".
        self.original_float32_matmul_precision: Optional[str] = None
        self.original_cuda_matmul: Optional[bool] = None
        self.original_cudnn: Optional[bool] = None

    def _enable(self) -> None:
        """Save the current TF32-related settings, then turn TF32 on."""
        rank_zero_info("Enabling TensorFloat32 operations on CUDA", logger=logger)

        # Invariant: _enable() is never called twice without an intervening
        # _reset(); otherwise the saved originals would be clobbered.
        assert self.original_float32_matmul_precision is None
        assert self.original_cuda_matmul is None
        assert self.original_cudnn is None

        self.original_float32_matmul_precision = torch.get_float32_matmul_precision()
        self.original_cuda_matmul = torch.backends.cuda.matmul.allow_tf32
        self.original_cudnn = torch.backends.cudnn.allow_tf32

        torch.set_float32_matmul_precision(self.float32_matmul_precision)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    def _reset(self) -> None:
        """Restore each saved setting (if any) and clear the snapshot."""
        rank_zero_info(
            "Restoring original TensorFloat32 permissions on CUDA", logger=logger
        )

        precision = self.original_float32_matmul_precision
        if precision is not None:
            torch.set_float32_matmul_precision(precision)
            self.original_float32_matmul_precision = None

        cuda_matmul = self.original_cuda_matmul
        if cuda_matmul is not None:
            torch.backends.cuda.matmul.allow_tf32 = cuda_matmul
            self.original_cuda_matmul = None

        cudnn = self.original_cudnn
        if cudnn is not None:
            torch.backends.cudnn.allow_tf32 = cudnn
            self.original_cudnn = None

    def on_train_start(self, state: State, unit: TTrainUnit) -> None:
        self._enable()

    def on_train_end(self, state: State, unit: TTrainUnit) -> None:
        self._reset()

    def on_eval_start(self, state: State, unit: TEvalUnit) -> None:
        # During fit, on_train_start already enabled TF32; don't double-save.
        if state.entry_point == EntryPoint.FIT:
            return
        self._enable()

    def on_eval_end(self, state: State, unit: TEvalUnit) -> None:
        # During fit, on_train_end will perform the restore.
        if state.entry_point == EntryPoint.FIT:
            return
        self._reset()

    def on_predict_start(self, state: State, unit: TPredictUnit) -> None:
        self._enable()

    def on_predict_end(self, state: State, unit: TPredictUnit) -> None:
        self._reset()

0 comments on commit 926b5ec

Please sign in to comment.