Skip to content

Commit

Permalink
Add start_step to MemorySnapshotProfiler (#610)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #610

Allows the user to configure a `start_step` for when to start recording memory history. This addresses YanjunChen329's feedback that they would like to record memory history for a specific range of steps, e.g. steps 500-503.

Also:
* improve docstrings
* add validation for params
* MemorySnapshot callback no longer needs to call `start` and `stop`

Reviewed By: YanjunChen329

Differential Revision: D51048945

fbshipit-source-id: 2848113f305253e394ba2abbf04715388558f0d7
  • Loading branch information
daniellepintz authored and facebook-github-bot committed Nov 7, 2023
1 parent 25fa5ac commit 83cdb29
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 16 deletions.
59 changes: 57 additions & 2 deletions tests/utils/test_memory_snapshot_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,8 @@ def test_stop_step(self) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
memory_snapshot_profiler = MemorySnapshotProfiler(
output_dir=temp_dir,
memory_snapshot_params=MemorySnapshotParams(stop_step=2),
memory_snapshot_params=MemorySnapshotParams(start_step=0, stop_step=2),
)
memory_snapshot_profiler.start()

# initialize device & allocate memory for tensors
device = get_device_from_env()
Expand All @@ -64,3 +63,59 @@ def test_stop_step(self) -> None:
self.assertTrue(os.path.exists(pickle_dump_path))
self.assertTrue(os.path.exists(trace_path))
self.assertTrue(os.path.exists(segment_plot_path))

@unittest.skipUnless(
    condition=torch_version_geq_2_0,
    reason="This test needs changes from PyTorch 2.0 to run.",
)
def test_validation(self) -> None:
    """Test parameter validation.

    Each case pairs an invalid ``MemorySnapshotParams`` combination with the
    ``ValueError`` message (regex) that ``MemorySnapshotProfiler`` is expected
    to raise for it.
    """
    invalid_cases = [
        (
            "start_step must be nonnegative.",
            MemorySnapshotParams(start_step=-1, stop_step=0),
        ),
        (
            "stop_step must be specified when start_step is set.",
            MemorySnapshotParams(start_step=2, stop_step=None),
        ),
        (
            "start_step must be < stop_step.",
            MemorySnapshotParams(start_step=2, stop_step=0),
        ),
        (
            "stop_step must be positive.",
            MemorySnapshotParams(stop_step=0),
        ),
        (
            "stop_step must be enabled with either start_step or enable_oom_observer.",
            MemorySnapshotParams(stop_step=2, enable_oom_observer=False),
        ),
        (
            "At least one of start_step/stop_step or enable_oom_observer must be set.",
            MemorySnapshotParams(
                start_step=None, stop_step=None, enable_oom_observer=False
            ),
        ),
    ]
    with tempfile.TemporaryDirectory() as temp_dir:
        for expected_msg, params in invalid_cases:
            with self.assertRaisesRegex(ValueError, expected_msg):
                _ = MemorySnapshotProfiler(
                    output_dir=temp_dir,
                    memory_snapshot_params=params,
                )
14 changes: 1 addition & 13 deletions torchtnt/framework/callbacks/memory_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Optional

from torchtnt.framework.callback import Callback
from torchtnt.framework.state import EntryPoint, State
from torchtnt.framework.state import State
from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit
from torchtnt.utils.memory_snapshot_profiler import (
MemorySnapshotParams,
Expand Down Expand Up @@ -42,24 +42,12 @@ def __init__(
self.memory_snapshot_profiler = MemorySnapshotProfiler(
output_dir=output_dir, memory_snapshot_params=memory_snapshot_params
)
self.memory_snapshot_profiler.start()

def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
    """Advance the memory snapshot profiler once per finished train step.

    The profiler itself decides, from its configured start/stop steps,
    whether this step begins recording or dumps a snapshot.
    """
    self.memory_snapshot_profiler.step()

def on_train_end(self, state: State, unit: TTrainUnit) -> None:
self.memory_snapshot_profiler.stop()

def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
    """Advance the memory snapshot profiler once per finished eval step.

    The profiler itself decides, from its configured start/stop steps,
    whether this step begins recording or dumps a snapshot.
    """
    self.memory_snapshot_profiler.step()

def on_eval_end(self, state: State, unit: TEvalUnit) -> None:
# if in fit do nothing since the profiler will be stopped in on_train_end
if state.entry_point == EntryPoint.EVALUATE:
self.memory_snapshot_profiler.stop()

def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
    """Advance the memory snapshot profiler once per finished predict step.

    The profiler itself decides, from its configured start/stop steps,
    whether this step begins recording or dumps a snapshot.
    """
    self.memory_snapshot_profiler.step()

def on_predict_end(self, state: State, unit: TPredictUnit) -> None:
self.memory_snapshot_profiler.stop()
54 changes: 53 additions & 1 deletion torchtnt/utils/memory_snapshot_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,18 @@ class MemorySnapshotParams:
Memory snapshot parameters.
Args:
stop_step: Number of steps after which to dump memory snapshot, and stop recording memory history.
start_step: Step from which to start recording memory history.
stop_step: Step after which to dump memory snapshot, and stop recording memory history.
max_entries: Maximum number of events to keep in memory.
enable_oom_observer: Whether to attach an observer to record OOM events. If stop_step is set, the
OOM observer will only be active until stop_step is reached.
Note: If you set enable_oom_observer to True, you don't necessarily have to set a start_step as attach_oom_observer
will start recording memory history. Note that if you don't set a stop_step, it will continue recording memory
history until the program exits, which may incur a slight performance cost.
"""

start_step: Optional[int] = None
stop_step: Optional[int] = 2
max_entries: int = 100000
enable_oom_observer: bool = True
Expand All @@ -44,6 +50,20 @@ class MemorySnapshotProfiler:
Args:
output_dir: Directory where to save the memory snapshots.
memory_snapshot_params: Instance of MemorySnapshotParams.
Raises:
ValueError: If `start_step` is negative, or `stop_step` is less than or equal to zero.
ValueError: If `start_step` is greater than or equal to `stop_step`.
ValueError: If `start_step` is set and `stop_step` is not set.
ValueError: If `stop_step` is set and neither `start_step` nor `enable_oom_observer` are set.
ValueError: If `enable_oom_observer` is False and neither `start_step` nor `stop_step` is set
Examples::
memory_snapshot_params = MemorySnapshotParams(start_step=5, stop_step=10, enable_oom_observer=True)
memory_snapshot_profiler = MemorySnapshotProfiler(output_dir="/tmp", memory_snapshot_params=memory_snapshot_params)
for batch in dataloader:
...
memory_snapshot_profiler.step()
"""

def __init__(
Expand All @@ -55,6 +75,31 @@ def __init__(
self.params: MemorySnapshotParams = (
memory_snapshot_params or MemorySnapshotParams()
)
start_step = self.params.start_step
stop_step = self.params.stop_step
if start_step is not None:
if start_step < 0:
raise ValueError("start_step must be nonnegative.")
elif stop_step is None:
raise ValueError("stop_step must be specified when start_step is set.")
elif start_step >= stop_step:
raise ValueError("start_step must be < stop_step.")
if stop_step is not None:
if stop_step <= 0:
raise ValueError("stop_step must be positive.")
elif start_step is None and not self.params.enable_oom_observer:
raise ValueError(
"stop_step must be enabled with either start_step or enable_oom_observer."
)
if (
start_step is None
and stop_step is None
and not self.params.enable_oom_observer
):
raise ValueError(
"At least one of start_step/stop_step or enable_oom_observer must be set."
)

self.step_num: int = 0

if not is_torch_version_geq_2_0():
Expand All @@ -63,6 +108,8 @@ def __init__(
attach_oom_observer(
output_dir=output_dir, trace_max_entries=self.params.max_entries
)
if start_step is not None and start_step == 0:
self.start()

logger.info(
f"Created MemorySnapshotProfiler with MemorySnapshotParams={self.params}."
Expand Down Expand Up @@ -97,6 +144,11 @@ def stop(self) -> None:

def step(self) -> None:
    """Advance the profiler's internal step counter by one.

    When the counter reaches ``start_step`` (if set), memory-history
    recording is started; when it reaches ``stop_step`` (if set), a memory
    snapshot is written to ``output_dir`` and recording is stopped.
    """
    self.step_num += 1
    start, stop = self.params.start_step, self.params.stop_step
    if start is not None and self.step_num == start:
        self.start()
    if stop is not None and self.step_num == stop:
        log_memory_snapshot(output_dir=self.output_dir)
        self.stop()

0 comments on commit 83cdb29

Please sign in to comment.