From df8f543d1ef2ca176336a6f05daf7901a23e89dc Mon Sep 17 00:00:00 2001
From: Danielle Pintz
Date: Thu, 9 Nov 2023 14:24:37 -0800
Subject: [PATCH] Improve logging of MB/GB in oom logger (#616)

Summary:
Pull Request resolved: https://github.com/pytorch/tnt/pull/616

Improve readability of the OOM logger output.

## Before:
```
Saving memory snapshot device: 0, alloc: 360000000098304, device_alloc: 84990623744, device_free: 82956713984
```

## After:
```
Saving memory snapshot device: 0, alloc: 335276.13 GB, device_alloc: 79.15 GB, device_free: 77.26 GB
```

Reviewed By: JKSenthil

Differential Revision: D51166528

fbshipit-source-id: ec5ad8ea6f8ad06fb26256545f5a3a5d9f8ebc0f
---
 tests/utils/test_oom.py | 12 ++++++++++++
 torchtnt/utils/oom.py   |  9 ++++++++-
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/tests/utils/test_oom.py b/tests/utils/test_oom.py
index 9d6e36e90d..afb4303090 100644
--- a/tests/utils/test_oom.py
+++ b/tests/utils/test_oom.py
@@ -12,6 +12,7 @@
 import torch
 from torchtnt.utils.device import get_device_from_env
 from torchtnt.utils.oom import (
+    _bytes_to_mb_gb,
     is_out_of_cpu_memory,
     is_out_of_cuda_memory,
     is_out_of_memory_error,
@@ -91,3 +92,14 @@ def test_log_memory_snapshot(self) -> None:
 
             segment_plot_path = os.path.join(save_dir, "segment_plot.html")
             self.assertTrue(os.path.exists(segment_plot_path))
+
+    def test_bytes_to_mb_gb(self) -> None:
+        bytes_to_mb_test_cases = [
+            (0, "0.0 MB"),
+            (100000, "0.1 MB"),
+            (1000000, "0.95 MB"),
+            (1000000000, "0.93 GB"),
+            (1000000000000, "931.32 GB"),
+        ]
+        for inp, expected in bytes_to_mb_test_cases:
+            self.assertEqual(expected, _bytes_to_mb_gb(inp))
diff --git a/torchtnt/utils/oom.py b/torchtnt/utils/oom.py
index 1633c4d521..771a8fcfac 100644
--- a/torchtnt/utils/oom.py
+++ b/torchtnt/utils/oom.py
@@ -44,6 +44,13 @@ def is_out_of_memory_error(exception: BaseException) -> bool:
     return is_out_of_cpu_memory(exception) or is_out_of_cuda_memory(exception)
 
 
+def _bytes_to_mb_gb(num_bytes: int) -> str:
+    if num_bytes < 1024 * 1024:
+        return f"{round(num_bytes / (1024 * 1024), 2)} MB"
+    else:
+        return f"{round(num_bytes / (1024 * 1024 * 1024), 2)} GB"
+
+
 def _oom_observer(
     output_dir: str,
 ) -> Callable[[Union[int, torch.device], int, int, int], None]:
@@ -57,7 +64,7 @@ def oom_logger(
         Log memory snapshot in the event of CUDA OOM.
         """
         logger.info(
-            f"Saving memory snapshot device: {device}, alloc: {alloc}, device_alloc: {device_alloc}, device_free: {device_free}"
+            f"Saving memory snapshot device: {device}, alloc: {_bytes_to_mb_gb(alloc)}, device_alloc: {_bytes_to_mb_gb(device_alloc)}, device_free: {_bytes_to_mb_gb(device_free)}"
         )
         try:
             log_memory_snapshot(output_dir, "oom")
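
As a quick sanity check of the conversion shown in the summary, the snippet below feeds the raw byte counts from the "Before" log line through the new helper. This is a minimal sketch that assumes the patch is applied, so `_bytes_to_mb_gb` can be imported from `torchtnt.utils.oom` the same way the new test imports it.

```python
# Minimal sketch: reproduce the "After" log values from the "Before" byte counts.
# Assumes this patch is applied, so `_bytes_to_mb_gb` is importable as in the new test.
from torchtnt.utils.oom import _bytes_to_mb_gb

# Raw byte counts taken from the "Before" log line in the summary above.
before_values = {
    "alloc": 360000000098304,
    "device_alloc": 84990623744,
    "device_free": 82956713984,
}

for name, num_bytes in before_values.items():
    # Values under 1 MiB take the MB branch; everything else is reported in GB,
    # e.g. 84990623744 bytes -> "79.15 GB", matching the "After" log line.
    print(f"{name}: {_bytes_to_mb_gb(num_bytes)}")
```

Note that the helper divides by 1024-based factors even though the labels read MB/GB, which is why the test expects 1,000,000 bytes to render as "0.95 MB" rather than "1.0 MB".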