-
Notifications
You must be signed in to change notification settings - Fork 266
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add callback for enabling Tensorfloat32 (#885)
Summary: Pull Request resolved: #885 This is something that can boost performance quite a bit with float32 training on CUDA, so I figured it'd make sense to package it up into a re-useable callback. Reviewed By: diego-urgell Differential Revision: D61608792 fbshipit-source-id: ccd0712c9022029bf59ee0730a71ad59feea60ae
- Loading branch information
1 parent
ebda066
commit 926b5ec
Showing
3 changed files
with
192 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
import contextlib | ||
import unittest | ||
from typing import Iterator | ||
|
||
import torch | ||
from torchtnt.framework._test_utils import ( | ||
DummyFitUnit, | ||
DummyPredictUnit, | ||
DummyTrainUnit, | ||
generate_random_dataloader, | ||
) | ||
from torchtnt.framework.callback import Callback | ||
from torchtnt.framework.callbacks.tensorfloat32 import EnableTensorFloat32 | ||
from torchtnt.framework.fit import fit | ||
from torchtnt.framework.predict import predict | ||
from torchtnt.framework.state import State | ||
from torchtnt.framework.train import train | ||
from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit | ||
|
||
|
||
class _CheckTensorFloat32Enabled(Callback):
    """Test-helper callback that verifies TF32 is active at every
    train/eval/predict step of the run it is attached to."""

    def __init__(self, testcase: unittest.TestCase) -> None:
        self.testcase = testcase

    def assert_enabled(self) -> None:
        """Fail the owning testcase unless all three TF32 knobs are on."""
        tc = self.testcase
        tc.assertEqual(torch.get_float32_matmul_precision(), "high")
        tc.assertTrue(torch.backends.cudnn.allow_tf32)
        tc.assertTrue(torch.backends.cuda.matmul.allow_tf32)

    def on_train_step_start(self, state: State, unit: TTrainUnit) -> None:
        self.assert_enabled()

    def on_eval_step_start(self, state: State, unit: TEvalUnit) -> None:
        self.assert_enabled()

    def on_predict_step_start(self, state: State, unit: TPredictUnit) -> None:
        self.assert_enabled()
|
||
|
||
class EnableTensorFloat32Test(unittest.TestCase):
    """Verifies that EnableTensorFloat32 turns TF32 on for the duration of a
    train/fit/predict run and restores the previous settings afterwards."""

    @contextlib.contextmanager
    def check_proper_restore(self) -> Iterator[EnableTensorFloat32]:
        """Yield a fresh EnableTensorFloat32 with TF32 disabled, then assert the
        callback restored the disabled state after the entry point finished.

        Also snapshots the process-global TF32 settings on entry and restores
        them on exit, so this test does not leak modified global state (e.g.
        ``torch.backends.cudnn.allow_tf32`` left False) into other tests.
        """
        # Snapshot whatever the process had before this test touched it.
        prior_precision = torch.get_float32_matmul_precision()
        prior_cudnn = torch.backends.cudnn.allow_tf32
        prior_cuda_matmul = torch.backends.cuda.matmul.allow_tf32

        callback = EnableTensorFloat32()

        # Disable TensorFloat32 so the callback's effect is observable.
        torch.set_float32_matmul_precision("highest")
        torch.backends.cudnn.allow_tf32 = False
        torch.backends.cuda.matmul.allow_tf32 = False

        try:
            yield callback

            # The callback must have cleared its saved originals ...
            self.assertIsNone(callback.original_cuda_matmul)
            self.assertIsNone(callback.original_cudnn)
            self.assertIsNone(callback.original_float32_matmul_precision)

            # ... and restored the disabled settings it found on start.
            self.assertEqual(torch.get_float32_matmul_precision(), "highest")
            self.assertFalse(torch.backends.cudnn.allow_tf32)
            self.assertFalse(torch.backends.cuda.matmul.allow_tf32)
        finally:
            # Undo this test's own mutation of process-global state.
            torch.set_float32_matmul_precision(prior_precision)
            torch.backends.cudnn.allow_tf32 = prior_cudnn
            torch.backends.cuda.matmul.allow_tf32 = prior_cuda_matmul

    def test_tensorfloat32_callback_train(self) -> None:
        input_dim = batch_size = max_epochs = 2
        dataset_len = 5

        unit = DummyTrainUnit(input_dim=input_dim)
        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
        with self.check_proper_restore() as callback:
            callbacks: list[Callback] = [callback, _CheckTensorFloat32Enabled(self)]
            train(unit, dataloader, max_epochs=max_epochs, callbacks=callbacks)

    def test_tensorfloat32_callback_fit(self) -> None:
        input_dim = batch_size = max_epochs = 2
        dataset_len = 5

        unit = DummyFitUnit(input_dim=input_dim)
        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
        with self.check_proper_restore() as callback:
            callbacks: list[Callback] = [callback, _CheckTensorFloat32Enabled(self)]
            fit(
                unit,
                dataloader,
                dataloader,
                max_epochs=max_epochs,
                callbacks=callbacks,
            )

    def test_tensorfloat32_callback_predict(self) -> None:
        input_dim = batch_size = 2
        dataset_len = 5

        unit = DummyPredictUnit(input_dim=input_dim)
        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
        with self.check_proper_restore() as callback:
            callbacks: list[Callback] = [callback, _CheckTensorFloat32Enabled(self)]
            predict(unit, dataloader, callbacks=callbacks)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
import logging | ||
from typing import Optional | ||
|
||
import torch | ||
from torchtnt.framework.callback import Callback | ||
from torchtnt.framework.state import EntryPoint, State | ||
from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit | ||
from torchtnt.utils.rank_zero_log import rank_zero_info | ||
|
||
logger: logging.Logger = logging.getLogger(__name__) | ||
|
||
|
||
class EnableTensorFloat32(Callback):
    """
    A callback that enables TensorFloat32 operations on CUDA.

    On entry-point start the current process-global TF32 settings are saved
    and TF32 is enabled; on entry-point end the saved settings are restored.
    During FIT, the eval hooks defer to the train hooks so enable/restore
    happens exactly once.

    Args:
        float32_matmul_precision: precision to use for float32 matmul operations.
            See `torch.set_float32_matmul_precision` for details.

    Raises:
        RuntimeError: if TF32 is enabled twice without an intervening restore,
            which would clobber the saved original settings.
    """

    def __init__(self, float32_matmul_precision: str = "high") -> None:
        self.float32_matmul_precision = float32_matmul_precision

        # Saved process-global settings; None means "nothing saved yet".
        self.original_float32_matmul_precision: Optional[str] = None
        self.original_cuda_matmul: Optional[bool] = None
        self.original_cudnn: Optional[bool] = None

    def _enable(self) -> None:
        """Save the current TF32 settings and enable TF32 globally."""
        rank_zero_info("Enabling TensorFloat32 operations on CUDA", logger=logger)
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, and a silent double-enable would overwrite the saved
        # originals, making _reset restore the wrong (TF32-enabled) settings.
        if (
            self.original_float32_matmul_precision is not None
            or self.original_cuda_matmul is not None
            or self.original_cudnn is not None
        ):
            raise RuntimeError(
                "EnableTensorFloat32 was enabled twice without being reset"
            )

        self.original_float32_matmul_precision = torch.get_float32_matmul_precision()
        self.original_cuda_matmul = torch.backends.cuda.matmul.allow_tf32
        self.original_cudnn = torch.backends.cudnn.allow_tf32

        torch.set_float32_matmul_precision(self.float32_matmul_precision)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    def _reset(self) -> None:
        """Restore any saved TF32 settings and clear the saved state."""
        rank_zero_info(
            "Restoring original TensorFloat32 permissions on CUDA", logger=logger
        )
        if self.original_float32_matmul_precision is not None:
            torch.set_float32_matmul_precision(self.original_float32_matmul_precision)
            self.original_float32_matmul_precision = None

        if self.original_cuda_matmul is not None:
            torch.backends.cuda.matmul.allow_tf32 = self.original_cuda_matmul
            self.original_cuda_matmul = None

        if self.original_cudnn is not None:
            torch.backends.cudnn.allow_tf32 = self.original_cudnn
            self.original_cudnn = None

    def on_train_start(self, state: State, unit: TTrainUnit) -> None:
        self._enable()

    def on_train_end(self, state: State, unit: TTrainUnit) -> None:
        self._reset()

    def on_eval_start(self, state: State, unit: TEvalUnit) -> None:
        if state.entry_point == EntryPoint.FIT:
            return  # if fitting, this is already handled in on_train_start
        self._enable()

    def on_eval_end(self, state: State, unit: TEvalUnit) -> None:
        if state.entry_point == EntryPoint.FIT:
            return  # if fitting, this is already handled in on_train_end
        self._reset()

    def on_predict_start(self, state: State, unit: TPredictUnit) -> None:
        self._enable()

    def on_predict_end(self, state: State, unit: TPredictUnit) -> None:
        self._reset()