Add deterministic behavior for genie tnt finetune (#530)
Summary:
Pull Request resolved: #530

Add the missing settings in the tnt environment and seed utilities to ensure deterministic behavior.
This is needed to match genie finetune runs between genie and d2go.

Reviewed By: krishnakumar-kapil

Differential Revision: D48826358

fbshipit-source-id: ef3cc24aada50853a45f497f6530d5333e1df9cc
Tsahi Glik authored and facebook-github-bot committed Sep 7, 2023
1 parent 68f5eae commit 0c0a715
Showing 4 changed files with 26 additions and 7 deletions.
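For context, a minimal sketch of how the two utilities touched by this commit might be combined in a training script to get reproducible runs. The seed value and deterministic mode below are illustrative assumptions, not part of this commit:

# Hypothetical usage sketch (not part of this commit); the seed value and
# deterministic mode are illustrative assumptions.
from torchtnt.utils.device import maybe_enable_tf32
from torchtnt.utils.env import seed

# Seed the RNGs and request deterministic algorithms; with a deterministic
# mode set, the patched seed() also exports CUBLAS_WORKSPACE_CONFIG.
seed(42, deterministic="error")

# Keep float32 matmuls at full precision so TF32 cannot introduce run-to-run drift.
maybe_enable_tf32("highest")
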
15 changes: 15 additions & 0 deletions tests/utils/test_device.py
@@ -20,6 +20,7 @@
     get_device_from_env,
     get_nvidia_smi_gpu_stats,
     get_psutil_cpu_stats,
+    maybe_enable_tf32,
     record_data_in_stream,
 )

@@ -351,3 +352,17 @@ def test_record_data_in_stream_list(self) -> None:
             record_data_in_stream(data, curr_stream)
             mock_record_stream_a.assert_called_once()
             mock_record_stream_b.assert_called_once()
+
+    @unittest.skipUnless(
+        condition=(cuda_available), reason="This test must run on a GPU host."
+    )
+    def test_maybe_enable_tf32(self) -> None:
+        maybe_enable_tf32("highest")
+        self.assertEqual(torch.get_float32_matmul_precision(), "highest")
+        self.assertFalse(torch.backends.cudnn.allow_tf32)
+        self.assertFalse(torch.backends.cuda.matmul.allow_tf32)
+
+        maybe_enable_tf32("high")
+        self.assertEqual(torch.get_float32_matmul_precision(), "high")
+        self.assertTrue(torch.backends.cudnn.allow_tf32)
+        self.assertTrue(torch.backends.cuda.matmul.allow_tf32)
2 changes: 2 additions & 0 deletions tests/utils/test_env.py
@@ -5,6 +5,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import os
 import unittest

 import numpy as np
@@ -98,6 +99,7 @@ def test_deterministic_true(self) -> None:
             self.assertEqual(
                 warn_only, torch.is_deterministic_algorithms_warn_only_enabled()
             )
+            self.assertEqual(os.environ["CUBLAS_WORKSPACE_CONFIG"], ":4096:8")

     def test_deterministic_false(self) -> None:
         for deterministic in ("default", 0):
14 changes: 7 additions & 7 deletions torchtnt/utils/device.py
Expand Up @@ -323,20 +323,20 @@ def collect_system_stats(device: torch.device) -> Dict[str, Any]:


def maybe_enable_tf32(precision: str = "high") -> None:
"""Conditionally sets the precision of float32 matrix multiplications.
"""Conditionally sets the precision of float32 matrix multiplications and conv operations.
For more information, see the `PyTorch docs <https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html>`_
Args:
precision: The setting to determine which datatypes to use for matrix multiplication.
"""
if not (
is_torch_version_geq_1_12() # API exposed from PyTorch 1.12 onward
and torch.cuda.is_available() # Not relevant for non-CUDA devices
and torch.cuda.get_device_capability()
>= (8, 0) # Available only for Ampere architectures onwards
and torch.get_float32_matmul_precision()
== "highest" # Only change the setting if on highest precision
torch.cuda.is_available() # Not relevant for non-CUDA devices
and is_torch_version_geq_1_12() # API exposed from PyTorch 1.12 onward
):
return
torch.set_float32_matmul_precision(precision)
if precision == "highest":
torch.backends.cudnn.allow_tf32 = False
else:
torch.backends.cudnn.allow_tf32 = True
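A short usage sketch of the reworked helper, mirroring the behavior exercised by the new unit test (assumes a CUDA host and PyTorch >= 1.12):

# Usage sketch, mirroring the new test in tests/utils/test_device.py.
import torch

from torchtnt.utils.device import maybe_enable_tf32

maybe_enable_tf32("high")  # allow TF32 for cuDNN convolutions and CUDA matmuls
print(torch.get_float32_matmul_precision())   # "high"
print(torch.backends.cudnn.allow_tf32)        # True
print(torch.backends.cuda.matmul.allow_tf32)  # True

maybe_enable_tf32("highest")  # back to full float32 precision; TF32 disabled
print(torch.get_float32_matmul_precision())   # "highest"
print(torch.backends.cudnn.allow_tf32)        # False

Note the device-capability check (>= (8, 0)) and the "only change when currently highest" guard were dropped, so the helper now applies the requested precision on any CUDA host running PyTorch 1.12+ and toggles torch.backends.cudnn.allow_tf32 to match.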
2 changes: 2 additions & 0 deletions torchtnt/utils/env.py
@@ -151,3 +151,5 @@ def seed(seed: int, deterministic: Optional[Union[str, int]] = None) -> None:
         _log.debug("Enabling cuDNN deterministic mode")
         torch.backends.cudnn.deterministic = True
         torch.backends.cudnn.benchmark = False
+        # reference: https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
+        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
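A hedged sketch of what a caller now observes after seeding with a deterministic mode. The value "warn" is an assumption for illustration; see the updated test_env.py expectation above for the exact environment variable contents:

# Illustrative sketch: requesting deterministic algorithms via seed() now also
# pins the cuBLAS workspace size so GEMMs are reproducible.
import os

import torch

from torchtnt.utils.env import seed

seed(2023, deterministic="warn")  # deterministic mode value is an assumption

print(torch.backends.cudnn.deterministic)     # True
print(torch.backends.cudnn.benchmark)         # False
print(os.environ["CUBLAS_WORKSPACE_CONFIG"])  # ":4096:8" (newly set by this commit)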
