add memory snapshot callback (#2788)
* add memory snapshot callback

* fix check

* fix check

* Update composer/callbacks/memory_snapshot.py

Co-authored-by: Charles Tang <[email protected]>

* address comments

* fix upload filename print

* fix cpu check

* fix cpu check

* add pt version check

* add pt version check

* fix remote upload

* fix test

* fix cpu test

* fix gpu test

* fix gpu test

* fix gpu test

* fix gpu test

* fix gpu test

* do plotting before saving

* fix test

* fix test

* fix test

---------

Co-authored-by: Charles Tang <[email protected]>
Co-authored-by: Mihir Patel <[email protected]>
3 people committed Feb 2, 2024
1 parent 4db3230 commit bf2c408
Showing 6 changed files with 274 additions and 4 deletions.
2 changes: 2 additions & 0 deletions composer/callbacks/__init__.py
@@ -16,6 +16,7 @@
from composer.callbacks.image_visualizer import ImageVisualizer
from composer.callbacks.lr_monitor import LRMonitor
from composer.callbacks.memory_monitor import MemoryMonitor
from composer.callbacks.memory_snapshot import MemorySnapshot
from composer.callbacks.mlperf import MLPerfCallback
from composer.callbacks.nan_monitor import NaNMonitor
from composer.callbacks.optimizer_monitor import OptimizerMonitor
@@ -42,4 +43,5 @@
'SystemMetricsMonitor',
'Generate',
'FreeOutputs',
'MemorySnapshot',
]
182 changes: 182 additions & 0 deletions composer/callbacks/memory_snapshot.py
@@ -0,0 +1,182 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Log memory snapshot during training."""
import logging
import os
import warnings
from typing import Optional, Union

import torch.cuda
from packaging import version

from composer import State
from composer.core import Callback, State, Time, TimeUnit
from composer.loggers import Logger
from composer.utils import ensure_folder_is_empty, format_name_with_dist, format_name_with_dist_and_time, parse_uri

log = logging.getLogger(__name__)

__all__ = ['MemorySnapshot']


class MemorySnapshot(Callback):
"""Logs the memory snapshot of the model.
This callback calls the torch memory snapshot API (see :func:`torch.cuda.memory._snapshot`) to record a model's tensor memory allocation over a user-defined interval (recorded only once, over the window [skip_batches, skip_batches + interval]). This provides a fine-grained GPU memory visualization for debugging GPU OOMs. Captured memory snapshots show memory events, including allocations, frees, and OOMs, along with their stack traces over that interval.
Example:
.. doctest::
>>> from composer import Trainer
>>> from composer.callbacks import MemorySnapshot
>>> # constructing trainer object with this callback
>>> trainer = Trainer(
... model=model,
... train_dataloader=train_dataloader,
... eval_dataloader=eval_dataloader,
... optimizers=optimizer,
... max_duration="1ep",
... callbacks=[MemorySnapshot()],
... )
.. note::
Memory snapshot is only supported for GPU devices.
Args:
skip_batches (int, optional): Number of batches to skip before recording the memory snapshot begins. Defaults to 1.
interval (Union[int, str, Time], optional): Time string specifying how long to record the tensor allocation.
For example, ``interval='3ba'`` means 3 batches are recorded. Default: '3ba'.
max_entries (int, optional): Maximum number of memory alloc/free events to record. Defaults to 100000.
folder (str, optional): A format string describing the folder containing the memory snapshot files.
Defaults to ``'{{run_name}}/torch_traces'``.
filename (str, optional): A format string describing how to name the memory snapshot files.
Defaults to ``'rank{{rank}}.{{batch}}.pt.trace.memory_snapshot.html'``.
remote_file_name (str, optional): A format string for the memory snapshot remote file name.
Defaults to ``'{{run_name}}/torch_memory_traces/rank{{rank}}.{{batch}}.pt.trace.memory_snapshot.html'``.
Whenever a trace file is saved, it is also uploaded as a file according to this format string.
The same format variables as for ``filename`` are available.
.. seealso:: :doc:`Uploading Files</trainer/file_uploading>` for notes for file uploading.
Leading slashes (``'/'``) will be stripped.
To disable uploading trace files, set this parameter to ``None``.
overwrite (bool, optional): Whether to overwrite existing memory snapshots. Defaults to False.
If False, then the trace folder as determined by ``folder`` must be empty.
"""

def __init__(
self,
skip_batches: int = 1,
interval: Union[int, str, Time] = '3ba',
max_entries: int = 100000,
folder: str = '{run_name}/torch_traces',
filename: str = 'rank{rank}.{batch}.pt.trace.memory_snapshot.html',
remote_file_name: Optional[
str] = '{run_name}/torch_memory_traces/rank{rank}.{batch}.pt.trace.memory_snapshot.html',
overwrite: bool = False,
) -> None:
self.batches_left_to_skip = skip_batches
# Check that the interval timestring is parsable and convert into time object
self.interval = Time.from_input(interval, TimeUnit.BATCH)
self.max_entries = max_entries
self.folder = folder
self.folder_name = None
self.filename = filename
self.remote_file_name = remote_file_name
self.overwrite = overwrite
self._start_time = None
if remote_file_name:
self.remote_file_name = remote_file_name
_, _, self.remote_path_in_bucket = parse_uri(remote_file_name)
else:
self.remote_path_in_bucket = None

if version.parse(torch.__version__) > version.parse('2.1.0.dev'): # type: ignore
# memory snapshot is only supported in torch v2.1.0-rc1 or higher
self._enabled = True
else:
self._enabled = False
log.warning('Memory snapshot requires PyTorch 2.1.0 or later. Skipping the memory snapshot callback.')

def init(self, state: State, logger: Logger) -> None:
if not self._enabled:
return
# Not relying on `torch.cuda.is_available()` since the model could be on CPU.
model_device = next(state.model.parameters()).device

if model_device.type not in ('cuda', 'meta'):
warnings.warn(f'The memory snapshot only works on CUDA devices, but the model is on {model_device.type}.')
self._enabled = False
else:
self.folder_name = format_name_with_dist(self.folder, state.run_name)
os.makedirs(self.folder_name, exist_ok=True)
if not self.overwrite:
ensure_folder_is_empty(self.folder_name)

def batch_start(self, state: State, logger: Logger) -> None:
if self._enabled and self._start_time is None and self.batches_left_to_skip == 0:
self.start_record_memory_history()
self._start_time = state.timestamp.get(self.interval.unit).value

def batch_end(self, state: State, logger: Logger) -> None:
if not self._enabled:
return

if self.batches_left_to_skip > 0:
self.batches_left_to_skip -= 1
return
assert self._start_time is not None

if state.timestamp.get(self.interval.unit).value == (self._start_time + self.interval.value):
self.export_memory_snapshot(state, logger)
self.stop_record_memory_history()
self._start_time = None
self._enabled = False

def start_record_memory_history(self) -> None:

log.info('Starting snapshot record_memory_history')
torch.cuda.memory._record_memory_history(
True, # type: ignore
trace_alloc_max_entries=self.max_entries,
trace_alloc_record_context=True)

def stop_record_memory_history(self) -> None:

log.info('Stopping snapshot record_memory_history')
torch.cuda.memory._record_memory_history(False) # type: ignore

def export_memory_snapshot(self, state: State, logger: Logger) -> None:
assert self.filename
assert self.folder_name, 'folder_name must be set in init'
filename = os.path.join(
self.folder_name,
format_name_with_dist_and_time(self.filename, run_name=state.run_name, timestamp=state.timestamp))
try:
log.info(f'Saving memory snapshot to local file: {filename}')
snapshot = torch.cuda.memory._snapshot()
# No data was recorded - avoids a `ValueError` in `trace_plot`
if all(len(t) == 0 for t in snapshot['device_traces']):
log.info('No allocations were recorded in the memory snapshot')
return
with open(filename, 'w+') as fd:
fd.write(torch.cuda._memory_viz.trace_plot(snapshot, device=None, plot_segments=False)) # type: ignore
except Exception as e:
log.error(f'Failed to capture memory snapshot {e}')
return
if self.remote_path_in_bucket is not None:
remote_file_name = format_name_with_dist_and_time(self.remote_path_in_bucket,
run_name=state.run_name,
timestamp=state.timestamp)
remote_file_name = remote_file_name.lstrip('/')
log.info(f'Uploading memory snapshot to remote: {remote_file_name} from {filename}')
try:
logger.upload_file(remote_file_name=remote_file_name, file_path=filename, overwrite=self.overwrite)
except FileExistsError as e:
raise FileExistsError(
f'Uploading memory snapshot failed with error: {e}. overwrite was set to {self.overwrite}. To overwrite memory snapshot with Trainer, set save_overwrite to True.'
) from e
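
For orientation, here is a minimal usage sketch of the new callback based on the defaults in the diff above. It is a hypothetical example, not part of the commit: SimpleModel and RandomClassificationDataset are the repository's test helpers (the same ones used in the new test file below), and it assumes PyTorch >= 2.1.0 and a CUDA device, since the callback disables itself otherwise.

from torch.utils.data import DataLoader

from composer import Trainer
from composer.callbacks import MemorySnapshot
from tests.common import RandomClassificationDataset, SimpleModel  # repo test helpers

# Skip the first batch, then record allocations for a 3-batch window.
# remote_file_name=None keeps the trace local (no uploader is created).
snapshot = MemorySnapshot(skip_batches=1, interval='3ba', remote_file_name=None)

trainer = Trainer(
    model=SimpleModel(),
    train_dataloader=DataLoader(RandomClassificationDataset()),
    max_duration='5ba',
    callbacks=[snapshot],
    device='gpu',
)
trainer.fit()

# The callback writes an interactive HTML trace such as
# <run_name>/torch_traces/rank0.4.pt.trace.memory_snapshot.html,
# which can be opened directly in a browser to inspect allocations,
# frees, and their stack traces.
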
16 changes: 14 additions & 2 deletions composer/trainer/trainer.py
@@ -36,7 +36,7 @@
from torch.utils.data import DataLoader, DistributedSampler
from torchmetrics import Metric

from composer.callbacks import CheckpointSaver, OptimizerMonitor
from composer.callbacks import CheckpointSaver, MemorySnapshot, OptimizerMonitor
from composer.core import (Algorithm, AlgorithmPass, Batch, Callback, DataSpec, Engine, Evaluator, Event, Precision,
State, Time, Timestamp, TimeUnit, TrainerMode, ensure_data_spec, ensure_evaluator,
ensure_time, get_precision_context, validate_eval_automicrobatching)
@@ -1072,6 +1072,15 @@ def __init__(
loggers.append(remote_ud)
self.state.profiler.bind_to_state(self.state)

# MemorySnapshot
for cb in self.state.callbacks:
if isinstance(cb, MemorySnapshot):
if cb.remote_file_name:
remote_ud = maybe_create_remote_uploader_downloader_from_uri(uri=cb.remote_file_name,
loggers=loggers)
if remote_ud is not None:
loggers.append(remote_ud)

if progress_bar and log_to_console:
warnings.warn(
'Setting both `progress_bar` and `log_to_console` both to True is not recommended and will'
@@ -1215,7 +1224,10 @@ def __init__(

# Log hparams.
if self.auto_log_hparams:
self.local_hparams = extract_hparams(locals())
locs = locals()
if 'cb' in locs:
del locs['cb']
self.local_hparams = extract_hparams(locs)
self.logger.log_hyperparameters(self.local_hparams)

# Log composer version
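
The hunk above mirrors the remote-uploader pattern used just above it for profiler traces: when a MemorySnapshot callback carries a URI-style remote_file_name, the Trainer asks maybe_create_remote_uploader_downloader_from_uri to append a matching uploader to the loggers. A hypothetical sketch of relying on that wiring (the s3://my-bucket prefix is illustrative only; any object-store URI the helper understands would work, and SimpleModel / RandomClassificationDataset are again the repo's test helpers):

from torch.utils.data import DataLoader

from composer import Trainer
from composer.callbacks import MemorySnapshot
from tests.common import RandomClassificationDataset, SimpleModel  # repo test helpers

snapshot = MemorySnapshot(
    remote_file_name=(
        's3://my-bucket/{run_name}/torch_memory_traces/'
        'rank{rank}.{batch}.pt.trace.memory_snapshot.html'
    ),
)

trainer = Trainer(
    model=SimpleModel(),
    train_dataloader=DataLoader(RandomClassificationDataset()),
    max_duration='5ba',
    callbacks=[snapshot],
)
# During Trainer.__init__, the loop added above detects the callback, parses the
# s3:// URI, and appends a RemoteUploaderDownloader to the loggers, so each saved
# HTML trace is also uploaded under the formatted remote file name.
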
1 change: 1 addition & 0 deletions docs/source/trainer/callbacks.rst
@@ -50,6 +50,7 @@ components of training.
~lr_monitor.LRMonitor
~optimizer_monitor.OptimizerMonitor
~memory_monitor.MemoryMonitor
~memory_snapshot.MemorySnapshot
~nan_monitor.NaNMonitor
~image_visualizer.ImageVisualizer
~mlperf.MLPerfCallback
8 changes: 6 additions & 2 deletions tests/callbacks/callback_settings.py
@@ -12,8 +12,8 @@
import composer.profiler
from composer import Callback
from composer.callbacks import (EarlyStopper, ExportForInferenceCallback, FreeOutputs, Generate, HealthChecker,
ImageVisualizer, MemoryMonitor, MLPerfCallback, SpeedMonitor, SystemMetricsMonitor,
ThresholdStopper)
ImageVisualizer, MemoryMonitor, MemorySnapshot, MLPerfCallback, SpeedMonitor,
SystemMetricsMonitor, ThresholdStopper)
from composer.loggers import (CometMLLogger, ConsoleLogger, LoggerDestination, MLFlowLogger, ProgressBarLogger,
RemoteUploaderDownloader, TensorboardLogger, WandBLogger)
from composer.models.base import ComposerModel
@@ -128,6 +128,10 @@
pytest.mark.filterwarnings(
r'ignore:The memory monitor only works on CUDA devices, but the model is on cpu:UserWarning')
],
MemorySnapshot: [
pytest.mark.filterwarnings(
r'ignore:The memory snapshot only works on CUDA devices, but the model is on cpu:UserWarning')
],
MLPerfCallback: [pytest.mark.skipif(not _MLPERF_INSTALLED, reason='MLPerf is optional')],
WandBLogger: [
pytest.mark.filterwarnings(r'ignore:unclosed file:ResourceWarning'),
69 changes: 69 additions & 0 deletions tests/callbacks/test_memory_snapshot.py
@@ -0,0 +1,69 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

import pathlib

import pytest
import torch
from packaging import version
from torch.utils.data import DataLoader

from composer import State, Trainer
from composer.callbacks import MemorySnapshot
from composer.loggers import LoggerDestination
from tests.common import RandomClassificationDataset, SimpleModel, device


@device('cpu', 'gpu')
def test_memory_snapshot_warnings_on_cpu_models(device: str):
if version.parse(torch.__version__) <= version.parse('2.1.0.dev'):
# memory snapshot is supported after PyTorch 2.1.0.
return
# Warn if the user sets device=cpu even when CUDA is available
del device # unused. always using cpu
with pytest.warns(UserWarning, match='The memory snapshot only works on CUDA devices'):
Trainer(
model=SimpleModel(),
callbacks=MemorySnapshot(),
device='cpu',
train_dataloader=DataLoader(RandomClassificationDataset()),
max_duration='1ba',
)


class FileUploaderTracker(LoggerDestination):

def __init__(self) -> None:
self.uploaded_files = []

def upload_file(self, state: State, remote_file_name: str, file_path: pathlib.Path, *, overwrite: bool):
del state, overwrite # unused
self.uploaded_files.append((remote_file_name, file_path))


@pytest.mark.gpu
@pytest.mark.parametrize('interval', ['1ba'])
def test_memory_snapshot(interval: str):
if version.parse(torch.__version__) <= version.parse('2.1.0.dev'):
# memory snapshot is supported after PyTorch 2.1.0.
return
# Construct the callbacks
skip_batches = 0
memory_snapshot = MemorySnapshot(skip_batches=skip_batches, interval=interval)

simple_model = SimpleModel()

file_tracker_destination = FileUploaderTracker()

# Construct the trainer and train
trainer = Trainer(
model=simple_model,
loggers=file_tracker_destination,
callbacks=memory_snapshot,
train_dataloader=DataLoader(RandomClassificationDataset()),
max_duration='2ba',
)
trainer.fit()
assert len(file_tracker_destination.uploaded_files) == 1
trainer.close()
