Enable memory snapshot support upload to manifold and zoomer (#709)
Summary:
Pull Request resolved: #709

This change adds support for uploading memory snapshots to manifold and displaying them in zoomer, with the following changes:
1. Add a zoomer-specific memory snapshot profiler wrapper (sketched loosely below);
2. Internally call the memory_snapshot API from `unitrace`.
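The wrapper in (1) is internal-only and not part of this diff, but the new `MemorySnapshotProfilerBase` introduced below is the extension point it plugs into. A loose, hypothetical sketch of its shape — the class name, constructor, and bodies are assumptions, not the internal code:

from torchtnt.utils.memory_snapshot_profiler import MemorySnapshotProfilerBase


class ZoomerMemorySnapshotProfiler(MemorySnapshotProfilerBase):
    # Hypothetical: the real wrapper is internal-only. It records via the
    # memory_snapshot API from `unitrace` and uploads the result.
    def __init__(self, output_dir: str) -> None:
        self.output_dir = output_dir

    def start(self) -> None:
        ...  # begin recording (internally: unitrace memory_snapshot API)

    def stop(self) -> None:
        ...  # stop recording and upload the snapshot to manifold

    def step(self) -> None:
        ...  # hook invoked once per train/eval/predict step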

Reviewed By: aaronenyeshi

Differential Revision: D53997537

fbshipit-source-id: 2af6cec9cba64f43c6321e4efd497b373db10bd3
yoyoyocmu authored and facebook-github-bot committed Feb 22, 2024
1 parent 593180e commit 8790fc4
Showing 3 changed files with 40 additions and 27 deletions.
7 changes: 4 additions & 3 deletions tests/framework/callbacks/test_memory_snapshot.py
@@ -10,13 +10,14 @@

 from torchtnt.framework.callbacks.memory_snapshot import MemorySnapshot
 from torchtnt.framework.state import EntryPoint
+from torchtnt.utils.memory_snapshot_profiler import MemorySnapshotProfiler


 class TestMemorySnapshot(unittest.TestCase):
     def test_on_train_step_end(self) -> None:
         with tempfile.TemporaryDirectory() as temp_dir:
             memory_snapshot = MemorySnapshot(
-                output_dir=temp_dir,
+                memory_snapshot_profiler=MemorySnapshotProfiler(output_dir=temp_dir),
             )
             memory_snapshot.memory_snapshot_profiler = Mock()

@@ -28,7 +29,7 @@ def test_on_train_step_end(self) -> None:
     def test_on_eval_step_end(self) -> None:
         with tempfile.TemporaryDirectory() as temp_dir:
             memory_snapshot = MemorySnapshot(
-                output_dir=temp_dir,
+                memory_snapshot_profiler=MemorySnapshotProfiler(output_dir=temp_dir),
             )
             memory_snapshot.memory_snapshot_profiler = Mock()

@@ -41,7 +42,7 @@ def test_on_eval_step_end(self) -> None:
     def test_on_predict_step_end(self) -> None:
         with tempfile.TemporaryDirectory() as temp_dir:
             memory_snapshot = MemorySnapshot(
-                output_dir=temp_dir,
+                memory_snapshot_profiler=MemorySnapshotProfiler(output_dir=temp_dir),
             )
             memory_snapshot.memory_snapshot_profiler = Mock()
16 changes: 4 additions & 12 deletions torchtnt/framework/callbacks/memory_snapshot.py
@@ -5,15 +5,11 @@
 # LICENSE file in the root directory of this source tree.

 import logging
-from typing import Optional

 from torchtnt.framework.callback import Callback
 from torchtnt.framework.state import State
 from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit
-from torchtnt.utils.memory_snapshot_profiler import (
-    MemorySnapshotParams,
-    MemorySnapshotProfiler,
-)
+from torchtnt.utils.memory_snapshot_profiler import MemorySnapshotProfilerBase

 logger: logging.Logger = logging.getLogger(__name__)

@@ -25,8 +21,7 @@ class MemorySnapshot(Callback):
     Uses `Memory Snapshots <https://zdevito.github.io/2022/08/16/memory-snapshots.html>`.

     Args:
-        output_dir: Directory where to save the memory snapshots.
-        memory_snapshot_params: Instance of MemorySnapshotParams which will be passed to MemorySnapshotProfiler.
+        memory_snapshot_profiler: Instance of MemorySnapshotProfilerBase, controls when and where to save the memory snapshots.

     Note: It is recommended to instantiate this callback **as early as possible** in your training/eval/prediction script,
     ideally before model initialization, to make sure all memory allocation is captured.

@@ -36,12 +31,9 @@ def __init__(
     def __init__(
         self,
         *,
-        output_dir: str,
-        memory_snapshot_params: Optional[MemorySnapshotParams] = None,
+        memory_snapshot_profiler: MemorySnapshotProfilerBase,
     ) -> None:
-        self.memory_snapshot_profiler = MemorySnapshotProfiler(
-            output_dir=output_dir, memory_snapshot_params=memory_snapshot_params
-        )
+        self.memory_snapshot_profiler = memory_snapshot_profiler

     def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
         self.memory_snapshot_profiler.step()
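For reference, a minimal usage sketch of the updated callback API, mirroring the test changes above (the output path is illustrative):

from torchtnt.framework.callbacks.memory_snapshot import MemorySnapshot
from torchtnt.utils.memory_snapshot_profiler import (
    MemorySnapshotParams,
    MemorySnapshotProfiler,
)

# Before this change: MemorySnapshot(output_dir=..., memory_snapshot_params=...)
# After this change: the caller constructs the profiler and injects it.
profiler = MemorySnapshotProfiler(
    output_dir="/tmp/snapshots",  # illustrative path
    memory_snapshot_params=MemorySnapshotParams(enable_oom_observer=True),
)
callback = MemorySnapshot(memory_snapshot_profiler=profiler)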
44 changes: 32 additions & 12 deletions torchtnt/utils/memory_snapshot_profiler.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import logging
+from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from types import TracebackType
 from typing import Optional, Type
@@ -39,7 +40,36 @@ class MemorySnapshotParams:
     enable_oom_observer: bool = True


-class MemorySnapshotProfiler:
+class MemorySnapshotProfilerBase(ABC):
+    """
+    Base class for memory snapshot profiler.
+    """
+
+    def __enter__(self) -> None:
+        self.start()
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> Optional[bool]:
+        self.stop()
+
+    @abstractmethod
+    def start(self) -> None:
+        ...
+
+    @abstractmethod
+    def stop(self) -> None:
+        ...
+
+    @abstractmethod
+    def step(self) -> None:
+        ...
+
+
+class MemorySnapshotProfiler(MemorySnapshotProfilerBase):
     """
     Records a history of memory allocation and free events, and dumps to a
     file which can be visualized offline. It by default keeps track of
@@ -71,6 +101,7 @@ def __init__(
         output_dir: str,
         memory_snapshot_params: Optional[MemorySnapshotParams] = None,
     ) -> None:
+        super().__init__()
         self.output_dir: str = output_dir
         self.params: MemorySnapshotParams = (
             memory_snapshot_params or MemorySnapshotParams()
@@ -115,17 +146,6 @@ def __init__(
             f"Created MemorySnapshotProfiler with MemorySnapshotParams={self.params}."
         )

-    def __enter__(self) -> None:
-        self.start()
-
-    def __exit__(
-        self,
-        exc_type: Optional[Type[BaseException]],
-        exc_value: Optional[BaseException],
-        tb: Optional[TracebackType],
-    ) -> Optional[bool]:
-        self.stop()
-
     def start(self) -> None:
         if not torch.cuda.is_available():
             logger.warn("CUDA unavailable. Not recording memory history.")
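Because `__enter__`/`__exit__` now live on the base class, both the stock `MemorySnapshotProfiler` and any custom subclass can be used as a context manager. A small usage sketch (the path and parameter values are illustrative):

from torchtnt.utils.memory_snapshot_profiler import (
    MemorySnapshotParams,
    MemorySnapshotProfiler,
)

# __enter__ calls start() and __exit__ calls stop(), both inherited
# from MemorySnapshotProfilerBase:
with MemorySnapshotProfiler(
    output_dir="/tmp/snapshots",  # illustrative path
    memory_snapshot_params=MemorySnapshotParams(enable_oom_observer=False),
):
    ...  # run the code whose allocations you want captured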
