Improve logging of MB/GB in oom logger (#616)
Summary:
Pull Request resolved: #616

Improve the readability of the oom logger output.

## Before:
```
Saving memory snapshot device: 0, alloc: 360000000098304, device_alloc: 84990623744, device_free: 82956713984
```
## After:
```
Saving memory snapshot device: 0, alloc: 335276.13 GB, device_alloc: 79.15 GB, device_free: 77.26 GB
```
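
For reference, the new values use binary (1024-based) units. A quick sanity check of the `device_alloc` figure from the example above (a minimal sketch, not part of the change itself):

```python
# 1 GB here means 1 GiB (1024**3 bytes), matching the helper added in this diff.
device_alloc_bytes = 84990623744
print(f"{round(device_alloc_bytes / (1024 ** 3), 2)} GB")  # -> 79.15 GB
```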

Reviewed By: JKSenthil

Differential Revision: D51166528

fbshipit-source-id: ec5ad8ea6f8ad06fb26256545f5a3a5d9f8ebc0f
daniellepintz authored and facebook-github-bot committed Nov 9, 2023
1 parent a3c3eb0 commit df8f543
Showing 2 changed files with 20 additions and 1 deletion.
tests/utils/test_oom.py: 12 additions & 0 deletions
@@ -12,6 +12,7 @@
 import torch
 from torchtnt.utils.device import get_device_from_env
 from torchtnt.utils.oom import (
+    _bytes_to_mb_gb,
     is_out_of_cpu_memory,
     is_out_of_cuda_memory,
     is_out_of_memory_error,
@@ -91,3 +92,14 @@ def test_log_memory_snapshot(self) -> None:
 
         segment_plot_path = os.path.join(save_dir, "segment_plot.html")
         self.assertTrue(os.path.exists(segment_plot_path))
+
+    def test_bytes_to_mb_gb(self) -> None:
+        bytes_to_mb_test_cases = [
+            (0, "0.0 MB"),
+            (100000, "0.1 MB"),
+            (1000000, "0.95 MB"),
+            (1000000000, "0.93 GB"),
+            (1000000000000, "931.32 GB"),
+        ]
+        for inp, expected in bytes_to_mb_test_cases:
+            self.assertEqual(expected, _bytes_to_mb_gb(inp))
torchtnt/utils/oom.py: 8 additions & 1 deletion
@@ -44,6 +44,13 @@ def is_out_of_memory_error(exception: BaseException) -> bool:
     return is_out_of_cpu_memory(exception) or is_out_of_cuda_memory(exception)
 
 
+def _bytes_to_mb_gb(num_bytes: int) -> str:
+    if num_bytes < 1024 * 1024:
+        return f"{round(num_bytes / (1024 * 1024), 2)} MB"
+    else:
+        return f"{round(num_bytes / (1024 * 1024 * 1024), 2)} GB"
+
+
 def _oom_observer(
     output_dir: str,
 ) -> Callable[[Union[int, torch.device], int, int, int], None]:
@@ -57,7 +64,7 @@ def oom_logger(
         Log memory snapshot in the event of CUDA OOM.
         """
         logger.info(
-            f"Saving memory snapshot device: {device}, alloc: {alloc}, device_alloc: {device_alloc}, device_free: {device_free}"
+            f"Saving memory snapshot device: {device}, alloc: {_bytes_to_mb_gb(alloc)}, device_alloc: {_bytes_to_mb_gb(device_alloc)}, device_free: {_bytes_to_mb_gb(device_free)}"
         )
         try:
             log_memory_snapshot(output_dir, "oom")
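
As a usage sketch of the new helper (assuming a torchtnt install where the private `_bytes_to_mb_gb` is importable, as the updated test does), the formatter can be exercised directly with the byte counts from the commit summary:

```python
# Illustration only: _bytes_to_mb_gb is a private helper, imported here the same
# way the new unit test imports it.
from torchtnt.utils.oom import _bytes_to_mb_gb

print(_bytes_to_mb_gb(100000))       # 0.1 MB   (values under 1 MiB are reported in MB)
print(_bytes_to_mb_gb(84990623744))  # 79.15 GB (values at or above 1 MiB are reported in GB)
print(_bytes_to_mb_gb(82956713984))  # 77.26 GB
```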
