-
Notifications
You must be signed in to change notification settings - Fork 415
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add memory snapshot callback (#2788)
* add memory snapshot callback * fix check * fix check * Update composer/callbacks/memory_snapshot.py Co-authored-by: Charles Tang <[email protected]> * address comments * fix upload filename print * fix cpu check * fix cpu check * add pt version check * add pt version check * fix remote upload * fix test * fix cpu test * fix gpu test * fix gpu test * fix gpu test * fix gpu test * fix gpu test * do plotting before saving * fix test * fix test * fix test --------- Co-authored-by: Charles Tang <[email protected]> Co-authored-by: Mihir Patel <[email protected]>
- Loading branch information
1 parent
4db3230
commit bf2c408
Showing
6 changed files
with
274 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
# Copyright 2022 MosaicML Composer authors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""Log memory snapshot during training.""" | ||
import logging | ||
import os | ||
import warnings | ||
from typing import Optional, Union | ||
|
||
import torch.cuda | ||
from packaging import version | ||
|
||
from composer import State | ||
from composer.core import Callback, State, Time, TimeUnit | ||
from composer.loggers import Logger | ||
from composer.utils import ensure_folder_is_empty, format_name_with_dist, format_name_with_dist_and_time, parse_uri | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
__all__ = ['MemorySnapshot'] | ||
|
||
|
||
class MemorySnapshot(Callback):
    """Logs the memory snapshot of the model.

    This callback calls the torch memory snapshot API (see
    :func:`torch.cuda.memory._snapshot`) to record a model's tensor memory
    allocation over a user defined interval (only once through time
    [skip_batches, skip_batches + interval]). This provides a fine-grained GPU
    memory visualization for debugging GPU OOMs. Captured memory snapshots will
    show memory events including allocations, frees and OOMs, along with their
    stack traces over one interval.

    Example:
        .. doctest::

            >>> from composer import Trainer
            >>> from composer.callbacks import MemorySnapshot
            >>> # constructing trainer object with this callback
            >>> trainer = Trainer(
            ...     model=model,
            ...     train_dataloader=train_dataloader,
            ...     eval_dataloader=eval_dataloader,
            ...     optimizers=optimizer,
            ...     max_duration="1ep",
            ...     callbacks=[MemorySnapshot()],
            ... )

    .. note::
        Memory snapshot is only supported for GPU devices.

    Args:
        skip_batches (int, optional): Number of batches to skip before starting recording memory snapshot. Defaults to 1.
        interval (Union[int, str, Time], optional): Time string specifying how long to record the tensor allocation.
            For example, ``interval='3ba'`` means 3 batches are recorded. Default: '3ba'.
        max_entries (int, optional): Maximum number of memory alloc/free events to record. Defaults to 100000.
        folder (str, optional): A format string describing the folder containing the memory snapshot files.
            Defaults to ``'{{run_name}}/torch_traces'``.
        filename (str, optional): A format string describing how to name the memory snapshot files.
            Defaults to ``'rank{{rank}}.{{batch}}.pt.trace.memory_snapshot.html'``.
        remote_file_name (str, optional): A format string for the memory snapshot remote file name.
            Defaults to ``'{{run_name}}/torch_memory_traces/rank{{rank}}.{{batch}}.pt.trace.memory_snapshot.html'``.
            Whenever a trace file is saved, it is also uploaded as a file according to this format string.
            The same format variables as for ``filename`` are available.

            .. seealso:: :doc:`Uploading Files</trainer/file_uploading>` for notes for file uploading.

            Leading slashes (``'/'``) will be stripped.

            To disable uploading trace files, set this parameter to ``None``.
        overwrite (bool, optional): Whether to override existing memory snapshots. Defaults to False.
            If False, then the trace folder as determined by ``folder`` must be empty.
    """

    def __init__(
        self,
        skip_batches: int = 1,
        interval: Union[int, str, Time] = '3ba',
        max_entries: int = 100000,
        folder: str = '{run_name}/torch_traces',
        filename: str = 'rank{rank}.{batch}.pt.trace.memory_snapshot.html',
        remote_file_name: Optional[
            str] = '{run_name}/torch_memory_traces/rank{rank}.{batch}.pt.trace.memory_snapshot.html',
        overwrite: bool = False,
    ) -> None:
        self.batches_left_to_skip = skip_batches
        # Check that the interval timestring is parsable and convert into time object
        self.interval = Time.from_input(interval, TimeUnit.BATCH)
        self.max_entries = max_entries
        self.folder = folder
        # Resolved (rank/run-name formatted) folder; populated in `init`.
        self.folder_name = None
        self.filename = filename
        self.remote_file_name = remote_file_name
        self.overwrite = overwrite
        # Timestamp value at which recording started; None while not recording.
        self._start_time = None
        if remote_file_name:
            _, _, self.remote_path_in_bucket = parse_uri(remote_file_name)
        else:
            self.remote_path_in_bucket = None

        if version.parse(torch.__version__) > version.parse('2.1.0.dev'):  # type: ignore
            # memory snapshot is only supported in torch v2.1.0-rc1 or higher
            self._enabled = True
        else:
            self._enabled = False
            log.warning('Memory snapshot is supported after PyTorch 2.1.0. Skipping memory snapshot callback.')

    def init(self, state: State, logger: Logger) -> None:
        """Validate the model device and prepare the local output folder."""
        if not self._enabled:
            return
        # Not relying on `torch.cuda.is_available()` since the model could be on CPU.
        model_device = next(state.model.parameters()).device

        if model_device.type not in ('cuda', 'meta'):
            warnings.warn(f'The memory snapshot only works on CUDA devices, but the model is on {model_device.type}.')
            self._enabled = False
        else:
            self.folder_name = format_name_with_dist(self.folder, state.run_name)
            os.makedirs(self.folder_name, exist_ok=True)
            if not self.overwrite:
                ensure_folder_is_empty(self.folder_name)

    def batch_start(self, state: State, logger: Logger) -> None:
        """Begin recording once the configured number of batches has been skipped."""
        if self._enabled and self._start_time is None and self.batches_left_to_skip == 0:
            self.start_record_memory_history()
            self._start_time = state.timestamp.get(self.interval.unit).value

    def batch_end(self, state: State, logger: Logger) -> None:
        """Stop recording and export the snapshot once the interval has elapsed."""
        if not self._enabled:
            return

        if self.batches_left_to_skip > 0:
            self.batches_left_to_skip -= 1
            return
        assert self._start_time is not None

        if state.timestamp.get(self.interval.unit).value == (self._start_time + self.interval.value):
            self.export_memory_snapshot(state, logger)
            self.stop_record_memory_history()
            self._start_time = None
            # Snapshot is captured only once; disable after the first interval.
            self._enabled = False

    def start_record_memory_history(self) -> None:
        """Enable torch's allocation-event recording with stack traces."""
        log.info('Starting snapshot record_memory_history')
        torch.cuda.memory._record_memory_history(
            True,  # type: ignore
            trace_alloc_max_entries=self.max_entries,
            trace_alloc_record_context=True)

    def stop_record_memory_history(self) -> None:
        """Disable torch's allocation-event recording."""
        log.info('Stopping snapshot record_memory_history')
        torch.cuda.memory._record_memory_history(False)  # type: ignore

    def export_memory_snapshot(self, state: State, logger: Logger) -> None:
        """Render the captured snapshot to an HTML trace plot and optionally upload it."""
        assert self.filename
        assert self.folder_name, 'folder_name must be set in init'
        filename = os.path.join(
            self.folder_name,
            format_name_with_dist_and_time(self.filename, run_name=state.run_name, timestamp=state.timestamp))
        try:
            log.info(f'Saving memory snapshot to local file: {filename}')
            snapshot = torch.cuda.memory._snapshot()
            # No data was recorded - avoids a `ValueError` in `trace_plot`
            if all(len(t) == 0 for t in snapshot['device_traces']):
                log.info('No allocation is recorded in memory snapshot')
                return
            with open(filename, 'w+') as fd:
                fd.write(torch.cuda._memory_viz.trace_plot(snapshot, device=None, plot_segments=False))  # type: ignore
        except Exception as e:
            log.error(f'Failed to capture memory snapshot {e}')
            return
        if self.remote_path_in_bucket is not None:
            remote_file_name = format_name_with_dist_and_time(self.remote_path_in_bucket,
                                                              run_name=state.run_name,
                                                              timestamp=state.timestamp)
            remote_file_name = remote_file_name.lstrip('/')
            log.info(f'Uploading memory snapshot to remote: {remote_file_name} from {filename}')
            try:
                logger.upload_file(remote_file_name=remote_file_name, file_path=filename, overwrite=self.overwrite)
            except FileExistsError as e:
                raise FileExistsError(
                    f'Uploading memory snapshot failed with error: {e}. overwrite was set to {self.overwrite}. To overwrite memory snapshot with Trainer, set save_overwrite to True.'
                ) from e
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# Copyright 2022 MosaicML Composer authors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import pathlib | ||
|
||
import pytest | ||
import torch | ||
from packaging import version | ||
from torch.utils.data import DataLoader | ||
|
||
from composer import State, Trainer | ||
from composer.callbacks import MemorySnapshot | ||
from composer.loggers import LoggerDestination | ||
from composer.trainer import Trainer | ||
from tests.common import RandomClassificationDataset, SimpleModel, device | ||
|
||
|
||
@device('cpu', 'gpu')
def test_memory_snapshot_warnings_on_cpu_models(device: str):
    """MemorySnapshot must warn when the model lives on a non-CUDA device."""
    if version.parse(torch.__version__) <= version.parse('2.1.0.dev'):
        # memory snapshot is supported after PyTorch 2.1.0; report as skipped
        # rather than silently passing via a bare `return`.
        pytest.skip('Memory snapshot requires PyTorch 2.1.0 or later.')
    # Error if the user sets device=cpu even when cuda is available
    del device  # unused. always using cpu
    with pytest.warns(UserWarning, match='The memory snapshot only works on CUDA devices'):
        Trainer(
            model=SimpleModel(),
            callbacks=MemorySnapshot(),
            device='cpu',
            train_dataloader=DataLoader(RandomClassificationDataset()),
            max_duration='1ba',
        )
|
||
|
||
class FileUploaderTracker(LoggerDestination):
    """Logger destination that records upload requests instead of performing them."""

    def __init__(self) -> None:
        # Each entry is a (remote_file_name, file_path) tuple, in call order.
        self.uploaded_files = []

    def upload_file(self, state: State, remote_file_name: str, file_path: pathlib.Path, *, overwrite: bool):
        del state, overwrite  # unused
        entry = (remote_file_name, file_path)
        self.uploaded_files.append(entry)
|
||
|
||
@pytest.mark.gpu
@pytest.mark.parametrize('interval', ['1ba'])
def test_memory_snapshot(interval: str):
    """Train for two batches on GPU and assert exactly one snapshot is uploaded."""
    if version.parse(torch.__version__) <= version.parse('2.1.0.dev'):
        # memory snapshot is supported after PyTorch 2.1.0; report as skipped
        # rather than silently passing via a bare `return`.
        pytest.skip('Memory snapshot requires PyTorch 2.1.0 or later.')
    # Construct the callbacks
    skip_batches = 0
    memory_snapshot = MemorySnapshot(skip_batches=skip_batches, interval=interval)

    simple_model = SimpleModel()

    file_tracker_destination = FileUploaderTracker()

    # Construct the trainer and train
    trainer = Trainer(
        model=simple_model,
        loggers=file_tracker_destination,
        callbacks=memory_snapshot,
        train_dataloader=DataLoader(RandomClassificationDataset()),
        max_duration='2ba',
    )
    trainer.fit()
    # One interval of one batch starting at batch 0 => exactly one snapshot file.
    assert len(file_tracker_destination.uploaded_files) == 1
    trainer.close()