Skip to content

Commit

Permalink
Update NeptuneLogger (#3165)
Browse files Browse the repository at this point in the history
* Update NeptuneLogger

* better check symlinks

* Update composer/loggers/neptune_logger.py

Co-authored-by: Sabine <[email protected]>

* use progress bar if possible

* simplify imports

* update oom callback

* fix

* fix typing

* code review

* maybe a fix

* Apply suggestions from code review

Co-authored-by: Siddhant Sadangi <[email protected]>

* format

* Update composer/callbacks/oom_observer.py

* Update tests/loggers/test_neptune_logger.py

* Update tests/loggers/test_neptune_logger.py

* Update tests/loggers/test_neptune_logger.py

---------

Co-authored-by: Sabine <[email protected]>
Co-authored-by: Mihir Patel <[email protected]>
Co-authored-by: Siddhant Sadangi <[email protected]>
  • Loading branch information
4 people authored Apr 15, 2024
1 parent f59d6ef commit 10fd9b8
Show file tree
Hide file tree
Showing 6 changed files with 215 additions and 73 deletions.
59 changes: 38 additions & 21 deletions composer/callbacks/oom_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,20 @@
# SPDX-License-Identifier: Apache-2.0

"""Generate a memory snapshot during an OutOfMemory exception."""
from __future__ import annotations

import dataclasses
import logging
import os
import pickle
import warnings
from typing import Optional
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

import torch.cuda
from packaging import version

from composer import State
from composer.core import Callback, State
from composer.loggers import Logger
from composer.utils import ensure_folder_is_empty, format_name_with_dist, format_name_with_dist_and_time, parse_uri
Expand All @@ -22,6 +25,29 @@
__all__ = ['OOMObserver']


@dataclass(frozen=True)
class SnapshotFileNameConfig:
"""Configuration for the file names of the memory snapshot visualizations."""
snapshot_file: str
trace_plot_file: str
segment_plot_file: str
segment_flamegraph_file: str
memory_flamegraph_file: str

@classmethod
def from_file_name(cls, filename: str) -> 'SnapshotFileNameConfig':
return cls(
snapshot_file=filename + '_snapshot.pickle',
trace_plot_file=filename + '_trace_plot.html',
segment_plot_file=filename + '_segment_plot.html',
segment_flamegraph_file=filename + '_segment_flamegraph.svg',
memory_flamegraph_file=filename + '_memory_flamegraph.svg',
)

def list_filenames(self) -> List[str]:
return [getattr(self, field.name) for field in dataclasses.fields(self)]


class OOMObserver(Callback):
"""Generate visualizations of the state of allocated memory during an OutOfMemory exception.
Expand Down Expand Up @@ -94,6 +120,8 @@ def __init__(
self._enabled = False
warnings.warn('OOMObserver is supported after PyTorch 2.1.0. Disabling OOMObserver callback.')

self.filename_config: Optional[SnapshotFileNameConfig] = None

def init(self, state: State, logger: Logger) -> None:
if not self._enabled:
return
Expand All @@ -117,17 +145,12 @@ def oom_observer(device: int, alloc: int, device_alloc: int, device_free: int):

assert self.filename
assert self.folder_name, 'folder_name must be set in init'
filename = os.path.join(
self.folder_name,
filename = Path(self.folder_name) / Path(
format_name_with_dist_and_time(self.filename, run_name=state.run_name, timestamp=state.timestamp),
)

try:
snapshot_file = filename + '_snapshot.pickle'
trace_plot_file = filename + '_trace_plot.html'
segment_plot_file = filename + '_segment_plot.html'
segment_flamegraph_file = filename + '_segment_flamegraph.svg'
memory_flamegraph_file = filename + '_memory_flamegraph.svg'
self.filename_config = SnapshotFileNameConfig.from_file_name(str(filename))
log.info(f'Dumping OOMObserver visualizations')

snapshot = torch.cuda.memory._snapshot()
Expand All @@ -136,31 +159,25 @@ def oom_observer(device: int, alloc: int, device_alloc: int, device_free: int):
log.info(f'No allocation is recorded in memory snapshot)')
return

with open(snapshot_file, 'wb') as fd:
with open(self.filename_config.snapshot_file, 'wb') as fd:
pickle.dump(snapshot, fd)

with open(trace_plot_file, 'w+') as fd:
with open(self.filename_config.trace_plot_file, 'w+') as fd:
fd.write(torch.cuda._memory_viz.trace_plot(snapshot)) # type: ignore

with open(segment_plot_file, 'w+') as fd:
with open(self.filename_config.segment_plot_file, 'w+') as fd:
fd.write(torch.cuda._memory_viz.segment_plot(snapshot)) # type: ignore

with open(segment_flamegraph_file, 'w+') as fd:
with open(self.filename_config.segment_flamegraph_file, 'w+') as fd:
fd.write(torch.cuda._memory_viz.segments(snapshot)) # type: ignore

with open(memory_flamegraph_file, 'w+') as fd:
with open(self.filename_config.memory_flamegraph_file, 'w+') as fd:
fd.write(torch.cuda._memory_viz.memory(snapshot)) # type: ignore

log.info(f'Saved memory visualizations to local files with prefix = {filename} during OOM')

if self.remote_path_in_bucket is not None:
for f in [
snapshot_file,
trace_plot_file,
segment_plot_file,
segment_flamegraph_file,
memory_flamegraph_file,
]:
for f in self.filename_config.list_filenames():
base_file_name = os.path.basename(f)
remote_file_name = os.path.join(self.remote_path_in_bucket, base_file_name)
remote_file_name = remote_file_name.lstrip('/') # remove leading slashes
Expand Down
132 changes: 87 additions & 45 deletions composer/loggers/neptune_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,30 @@
import pathlib
import warnings
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Set, Union
from importlib.metadata import version
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence, Set, Union

import numpy as np
import torch
from packaging.version import Version

from composer._version import __version__
from composer.loggers import LoggerDestination
from composer.utils import MissingConditionalImportError, dist
from composer.utils import MissingConditionalImportError, VersionedDeprecationWarning, dist

if TYPE_CHECKING:
from composer import Logger
from composer.core import State

NEPTUNE_MODE_TYPE = Literal['async', 'sync', 'offline', 'read-only', 'debug']
NEPTUNE_VERSION_WITH_PROGRESS_BAR = Version('1.9.0')


class NeptuneLogger(LoggerDestination):
"""Log to `neptune.ai <https://neptune.ai/>`_.
For more, see the [Neptune-Composer integration guide](https://docs.neptune.ai/integrations/composer/).
For instructions, see the
`integration guide <https://docs.neptune.ai/integrations/mosaicml_composer/>`_.
Args:
project (str, optional): The name of your Neptune project,
Expand All @@ -36,16 +42,15 @@ class NeptuneLogger(LoggerDestination):
You can leave out this argument if you save your token to the
``NEPTUNE_API_TOKEN`` environment variable (recommended).
You can find your API token in the user menu of the Neptune web app.
rank_zero_only (bool, optional): Whether to log only on the rank-zero process.
(default: ``True``).
upload_artifacts (bool, optional): Whether the logger should upload artifacts to Neptune.
rank_zero_only (bool): Whether to log only on the rank-zero process (default: ``True``).
upload_artifacts (bool, optional): Deprecated. See ``upload_checkpoints``.
upload_checkpoints (bool): Whether the logger should upload checkpoints to Neptune
(default: ``False``).
base_namespace (str, optional): The name of the base namespace to log the metadata to.
(default: "training").
base_namespace (str, optional): The name of the base namespace where the metadata
is logged (default: "training").
neptune_kwargs (Dict[str, Any], optional): Any additional keyword arguments to the
``neptune.init_run()`` function. For options, see the
`Run API reference <https://docs.neptune.ai/api/neptune/#init_run>`_ in the
Neptune docs.
`Run API reference <https://docs.neptune.ai/api/neptune/#init_run>`_.
"""
metric_namespace = 'metrics'
hyperparam_namespace = 'hyperparameters'
Expand All @@ -58,8 +63,10 @@ def __init__(
project: Optional[str] = None,
api_token: Optional[str] = None,
rank_zero_only: bool = True,
upload_artifacts: bool = False,
upload_artifacts: Optional[bool] = None,
upload_checkpoints: bool = False,
base_namespace: str = 'training',
mode: Optional[NEPTUNE_MODE_TYPE] = None,
**neptune_kwargs,
) -> None:
try:
Expand All @@ -74,7 +81,8 @@ def __init__(
verify_type('project', project, (str, type(None)))
verify_type('api_token', api_token, (str, type(None)))
verify_type('rank_zero_only', rank_zero_only, bool)
verify_type('upload_artifacts', upload_artifacts, bool)
verify_type('upload_artifacts', upload_artifacts, (bool, type(None)))
verify_type('upload_checkpoints', upload_checkpoints, bool)
verify_type('base_namespace', base_namespace, str)

if not base_namespace:
Expand All @@ -83,15 +91,19 @@ def __init__(
self._project = project
self._api_token = api_token
self._rank_zero_only = rank_zero_only
self._upload_artifacts = upload_artifacts

if upload_artifacts is not None:
_warn_about_deprecated_upload_artifacts()
self._upload_checkpoints = upload_artifacts
else:
self._upload_checkpoints = upload_checkpoints

self._base_namespace = base_namespace
self._neptune_kwargs = neptune_kwargs

mode = self._neptune_kwargs.pop('mode', 'async')

self._enabled = (not rank_zero_only) or dist.get_global_rank() == 0

self._mode = mode if self._enabled else 'debug'
self._mode: Optional[NEPTUNE_MODE_TYPE] = mode if self._enabled else 'debug'

self._neptune_run = None
self._base_handler = None
Expand All @@ -104,17 +116,8 @@ def __init__(
def neptune_run(self):
"""Gets the Neptune run object from a NeptuneLogger instance.
You can log additional metadata to the run by accessing a path inside the run and assigning metadata to it
with "=" or [Neptune logging methods](https://docs.neptune.ai/logging/methods/).
Example:
from composer import Trainer
from composer.loggers import NeptuneLogger
neptune_logger = NeptuneLogger()
trainer = Trainer(loggers=neptune_logger, ...)
trainer.fit()
neptune_logger.neptune_run["some_metric"] = 1
trainer.close()
To log additional metadata to the run, access a path inside the run and assign metadata
with ``=`` or other `Neptune logging methods <https://docs.neptune.ai/logging/methods/>`_.
"""
from neptune import Run

Expand All @@ -131,19 +134,10 @@ def neptune_run(self):
def base_handler(self):
"""Gets a handler for the base logging namespace.
Use the handler to log extra metadata to the run and organize it under the base namespace (default: "training").
You can operate on it like a run object: Access a path inside the handler and assign metadata to it with "=" or
other [Neptune logging methods](https://docs.neptune.ai/logging/methods/).
Example:
from composer import Trainer
from composer.loggers import NeptuneLogger
neptune_logger = NeptuneLogger()
trainer = Trainer(loggers=neptune_logger, ...)
trainer.fit()
neptune_logger.base_handler["some_metric"] = 1
trainer.close()
Result: The value `1` is organized under "training/some_metric" inside the run.
Use the handler to log extra metadata to the run and organize it under the base namespace
(default: "training"). You can operate on it like a run object: Access a path inside the
handler and assign metadata to it with ``=`` or other
`Neptune logging methods <https://docs.neptune.ai/logging/methods/>`_.
"""
return self.neptune_run[self._base_namespace]

Expand Down Expand Up @@ -213,7 +207,7 @@ def log_traces(self, traces: Dict[str, Any]):

def can_upload_files(self) -> bool:
"""Whether the logger supports uploading files."""
return self._enabled and self._upload_artifacts
return self._enabled and self._upload_checkpoints

def upload_file(
self,
Expand All @@ -226,6 +220,9 @@ def upload_file(
if not self.can_upload_files():
return

if file_path.is_symlink() or file_path.suffix.lower() == '.symlink':
return # skip symlinks

neptune_path = f'{self._base_namespace}/{remote_file_name}'
if self.neptune_run.exists(neptune_path) and not overwrite:

Expand All @@ -236,7 +233,11 @@ def upload_file(
return

del state # unused
self.base_handler[remote_file_name].upload(str(file_path))

from neptune.types import File

with open(str(file_path), 'rb') as fp:
self.base_handler[remote_file_name] = File.from_stream(fp, extension=file_path.suffix)

def download_file(
self,
Expand All @@ -245,7 +246,6 @@ def download_file(
overwrite: bool = False,
progress_bar: bool = True,
):
del progress_bar # not supported

if not self._enabled:
return
Expand All @@ -266,7 +266,11 @@ def download_file(
if not self.neptune_run.exists(file_path):
raise FileNotFoundError(f'File {file_path} not found')

self.base_handler[remote_file_name].download(destination=destination)
if _is_progress_bar_enabled():
self.base_handler[remote_file_name].download(destination=destination, progress_bar=progress_bar)
else:
del progress_bar
self.base_handler[remote_file_name].download(destination=destination)

def log_images(
self,
Expand Down Expand Up @@ -312,4 +316,42 @@ def _validate_image(img: Union[np.ndarray, torch.Tensor], channels_last: bool) -
if not channels_last:
img_numpy = np.moveaxis(img_numpy, 0, -1)

return img_numpy
return _validate_image_value_range(img_numpy)


def _validate_image_value_range(img: np.ndarray) -> np.ndarray:
array_min = img.min()
array_max = img.max()

if (array_min >= 0 and 1 < array_max <= 255) or (array_min >= 0 and array_max <= 1):
return img

from neptune.common.warnings import NeptuneWarning, warn_once

warn_once(
'Image value range is not in the expected range of [0.0, 1.0] or [0, 255]. '
'This might be due to the presence of `transforms.Normalize` in the data pipeline. '
'Logged images may not display correctly in Neptune.',
exception=NeptuneWarning,
)

return _scale_image_to_0_255(img, array_min, array_max)


def _scale_image_to_0_255(img: np.ndarray, array_min: Union[int, float], array_max: Union[int, float]) -> np.ndarray:
scaled_image = 255 * (img - array_min) / (array_max - array_min)
return scaled_image.astype(np.uint8)


def _warn_about_deprecated_upload_artifacts() -> None:
warnings.warn(
VersionedDeprecationWarning(
'The \'upload_artifacts\' parameter is deprecated and will be removed in the next version. '
'Use the \'upload_checkpoints\' parameter instead.',
remove_version='0.23',
),
)


def _is_progress_bar_enabled() -> bool:
return Version(version('neptune')) >= NEPTUNE_VERSION_WITH_PROGRESS_BAR
3 changes: 3 additions & 0 deletions docs/source/doctest_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@
# Disable wandb
os.environ['WANDB_MODE'] = 'disabled'

# Disable neptune
os.environ['NEPTUNE_MODE'] = 'debug'

# Change the cwd to be the tempfile, so we don't pollute the documentation source folder
tmpdir = tempfile.mkdtemp()
cwd = os.path.abspath('.')
Expand Down
Loading

0 comments on commit 10fd9b8

Please sign in to comment.