Refactor save interval and eval interval to share code (mosaicml#2600)
dakinggg committed Oct 5, 2023
1 parent 4934aa5 commit 70422ab
Showing 8 changed files with 203 additions and 148 deletions.
6 changes: 3 additions & 3 deletions composer/callbacks/checkpoint_saver.py
@@ -14,12 +14,12 @@
from pathlib import Path
from typing import Callable, List, Optional, Union

-from composer.callbacks.utils import create_interval_scheduler
from composer.core import Callback, Event, State, Time
from composer.loggers import Logger
from composer.utils import (FORMAT_NAME_WITH_DIST_AND_TIME_TABLE, FORMAT_NAME_WITH_DIST_TABLE, PartialFilePath,
-checkpoint, create_symlink_file, dist, ensure_folder_has_no_conflicting_files,
-format_name_with_dist, format_name_with_dist_and_time, is_model_deepspeed, reproducibility)
+checkpoint, create_interval_scheduler, create_symlink_file, dist,
+ensure_folder_has_no_conflicting_files, format_name_with_dist,
+format_name_with_dist_and_time, is_model_deepspeed, reproducibility)
from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME
from composer.utils.misc import using_torch_2

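With the helper now imported from `composer.utils`, a `CheckpointSaver` is wired up the same way as before; a minimal sketch, assuming composer at this commit (the folder name and `'1000ba'` interval are illustrative):

```python
from composer.callbacks import CheckpointSaver
from composer.utils import create_interval_scheduler

# Save every 1000 batches; include_end_of_training=True (the default) also
# triggers one final save once training completes.
saver = CheckpointSaver(
    folder='checkpoints',  # illustrative path
    save_interval=create_interval_scheduler('1000ba'),
)
```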
3 changes: 1 addition & 2 deletions composer/callbacks/generate.py
@@ -5,11 +5,10 @@

from typing import Any, List, Optional, Union, cast

-from composer.callbacks.utils import create_interval_scheduler
from composer.core import Callback, Event, State, Time, get_precision_context
from composer.loggers import Logger
from composer.models import HuggingFaceModel
-from composer.utils import dist
+from composer.utils import create_interval_scheduler, dist
from composer.utils.import_helpers import MissingConditionalImportError


65 changes: 22 additions & 43 deletions composer/callbacks/utils.py
@@ -1,16 +1,19 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Utilities for callbacks."""
"""Callback utils."""

-import math
-from typing import Callable, Union
+import warnings
+from typing import Callable, Optional, Set, Union

-from composer.core import Event, State, Time, TimeUnit
+from composer.core import Event, State, Time
+from composer.utils.misc import create_interval_scheduler as _create_interval_scheduler


def create_interval_scheduler(interval: Union[str, int, Time],
-include_end_of_training=True) -> Callable[[State, Event], bool]:
+include_end_of_training: bool = True,
+checkpoint_events: bool = True,
+final_events: Optional[Set[Event]] = None) -> Callable[[State, Event], bool]:
"""Helper function to create a scheduler according to a specified interval.
Args:
@@ -19,46 +22,22 @@ def create_interval_scheduler(interval: Union[str, int, Time],
:attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`.
include_end_of_training (bool): If true, the returned callable will return true at the end of training as well.
Otherwise, the returned callable will return true at intervals only.
+checkpoint_events (bool): If true, will use the EPOCH_CHECKPOINT and BATCH_CHECKPOINT events. If False, will use
+the EPOCH_END and BATCH_END events.
+final_events (Optional[Set[Event]]): The set of events to trigger on at the end of training.
Returns:
Callable[[State, Event], bool]: A function that returns true at interval and at the end of training if specified.
For example, it can be passed as the ``save_interval`` argument into the :class:`.CheckpointSaver`.
"""
if isinstance(interval, str):
interval = Time.from_timestring(interval)
if isinstance(interval, int):
interval = Time(interval, TimeUnit.EPOCH)

if interval.unit == TimeUnit.EPOCH:
save_event = Event.EPOCH_CHECKPOINT
elif interval.unit in {TimeUnit.BATCH, TimeUnit.TOKEN, TimeUnit.SAMPLE}:
save_event = Event.BATCH_CHECKPOINT
else:
raise NotImplementedError(
f'Unknown interval: {interval.unit}. Must be TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, or TimeUnit.SAMPLE.'
)

def check_interval(state: State, event: Event):
elapsed_duration = state.get_elapsed_duration()
assert elapsed_duration is not None, 'elapsed_duration is set on the BATCH_CHECKPOINT and EPOCH_CHECKPOINT'

if include_end_of_training and elapsed_duration >= 1.0:
return True

# previous timestamp will only be None if training has not started, but we are returning False
# in this case, just to be safe
if state.previous_timestamp is None:
return False

if interval.unit in {TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, TimeUnit.SAMPLE}:
previous_count = state.previous_timestamp.get(interval.unit)
count = state.timestamp.get(interval.unit)
else:
raise NotImplementedError(
f'Unknown interval: {interval.unit}. Must be TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, or TimeUnit.SAMPLE.'
)

threshold_passed = math.floor(previous_count / interval.value) != math.floor(count / interval.value)
return event == save_event and threshold_passed

return check_interval
warnings.warn(
'`composer.callbacks.utils.create_interval_scheduler` has been moved to `composer.utils.misc.create_interval_scheduler` '
+ 'and will be removed in a future release.',
DeprecationWarning,
)
return _create_interval_scheduler(
interval=interval,
include_end_of_training=include_end_of_training,
checkpoint_events=checkpoint_events,
final_events=final_events,
)
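The old import path keeps working through this shim but now emits a `DeprecationWarning`; a quick sketch of what a caller sees, assuming composer at this commit:

```python
import warnings

# Old location: now a thin wrapper that forwards to composer.utils.misc.create_interval_scheduler
from composer.callbacks.utils import create_interval_scheduler

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    scheduler = create_interval_scheduler('1ep')  # all arguments are passed through unchanged

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```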
97 changes: 6 additions & 91 deletions composer/core/evaluator.py
@@ -5,105 +5,18 @@

from __future__ import annotations

import math
import textwrap
import warnings
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

from composer.core.data_spec import DataSpec, ensure_data_spec
from composer.core.event import Event
from composer.core.state import State
-from composer.core.time import Time, TimeUnit
+from composer.core.time import Time
from composer.devices import Device, DeviceGPU
+from composer.utils import create_interval_scheduler

-__all__ = ['Evaluator', 'evaluate_periodically', 'ensure_evaluator', 'validate_eval_automicrobatching']


def evaluate_periodically(eval_interval: Union[str, Time, int], eval_at_fit_end: bool = True):
"""Helper function to generate an evaluation interval callable.
Args:
eval_interval (str | Time | int): A :class:`.Time` instance or time string, or integer in epochs,
representing how often to evaluate. Set to ``0`` to disable evaluation.
eval_at_fit_end (bool): Whether to evaluate at the end of training, regardless of `eval_interval`.
Default: True
Returns:
(State, Event) -> bool: A callable for the ``eval_interval`` argument of an
:class:`.Evaluator`.
"""
if isinstance(eval_interval, int):
eval_interval = Time(eval_interval, TimeUnit.EPOCH)
if isinstance(eval_interval, str):
eval_interval = Time.from_timestring(eval_interval)

last_batch_seen = -1

def should_eval(state: State, event: Event):
# `TimeUnit.Duration` value is a float from `[0.0, 1.0)`
if not eval_interval.unit == TimeUnit.DURATION and int(eval_interval) <= 0:
return False
nonlocal last_batch_seen # required to use the last_batch_seen from the outer function scope

# if requested, evaluate at the end of training, as long as the length of training is specified.
if eval_at_fit_end and event == Event.FIT_END and state.timestamp.batch != last_batch_seen:
return True

# Previous timestamp will only be None if training has not started, but we are returning False
# in this case, just to be safe
if state.previous_timestamp is None:
return False

if eval_interval.unit in {TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, TimeUnit.SAMPLE}:
previous_count = state.previous_timestamp.get(eval_interval.unit)
count = state.timestamp.get(eval_interval.unit)
# If the eval_interval is a duration, we will track progress in terms of the unit of max_duration
elif eval_interval.unit == TimeUnit.DURATION:
assert state.max_duration is not None
previous_count = state.previous_timestamp.get(state.max_duration.unit)
count = state.timestamp.get(state.max_duration.unit)
else:
raise ValueError(f'Invalid eval_interval unit: {eval_interval.unit}')

threshold_passed = math.floor(previous_count / eval_interval.value) != math.floor(count / eval_interval.value)

if eval_interval.unit == TimeUnit.EPOCH and event == Event.EPOCH_END and threshold_passed:
last_batch_seen = state.timestamp.batch
return True
elif eval_interval.unit in {TimeUnit.BATCH, TimeUnit.TOKEN, TimeUnit.SAMPLE
} and event == Event.BATCH_END and threshold_passed:
last_batch_seen = state.timestamp.batch
return True
elif eval_interval.unit == TimeUnit.DURATION:
assert state.max_duration is not None, 'max_duration should not be None'
if state.dataloader_len is None:
raise RuntimeError(
f'Evaluation interval of type `dur` or {TimeUnit.DURATION} requires the dataloader to be sized.')
if state.max_duration.unit == TimeUnit.EPOCH and int(
state.timestamp.batch) % math.ceil(state.max_duration.value * float(eval_interval) *
state.dataloader_len) == 0 and event == Event.BATCH_END:
last_batch_seen = state.timestamp.batch
return True
elif state.max_duration.unit == TimeUnit.BATCH and int(state.timestamp.batch) % math.ceil(
state.max_duration.value * eval_interval.value) == 0 and event == Event.BATCH_END:
last_batch_seen = state.timestamp.batch
return True
elif state.max_duration.unit == TimeUnit.SAMPLE and event == Event.BATCH_END:
samples_per_interval = math.ceil(state.max_duration.value * eval_interval)
threshold_passed = math.floor(previous_count / samples_per_interval) != math.floor(
count / samples_per_interval)
if threshold_passed:
last_batch_seen = state.timestamp.batch
return True
elif state.max_duration.unit == TimeUnit.TOKEN and event == Event.BATCH_END:
tokens_per_interval = math.ceil(state.max_duration.value * eval_interval)
threshold_passed = math.floor(previous_count / tokens_per_interval) != math.floor(
count / tokens_per_interval)
if threshold_passed:
last_batch_seen = state.timestamp.batch
return True
return False

return should_eval
+__all__ = ['Evaluator', 'ensure_evaluator', 'validate_eval_automicrobatching']


class Evaluator:
@@ -198,7 +111,9 @@ def eval_interval(self, eval_interval: Optional[Union[int, str, Time, Callable[[
if eval_interval is None:
self._eval_interval = None
elif not callable(eval_interval):
-self._eval_interval = evaluate_periodically(eval_interval)
+self._eval_interval = create_interval_scheduler(eval_interval,
+checkpoint_events=False,
+final_events={Event.FIT_END})
else:
self._eval_interval = eval_interval
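For a non-callable `eval_interval`, the setter above now builds the same kind of `(State, Event) -> bool` callable that `evaluate_periodically` used to return. Roughly equivalent, with an illustrative `'500ba'` interval:

```python
from composer.core import Event
from composer.utils import create_interval_scheduler

# What the eval_interval setter constructs for eval_interval='500ba':
eval_check = create_interval_scheduler(
    '500ba',
    checkpoint_events=False,       # fire on BATCH_END / EPOCH_END instead of *_CHECKPOINT events
    final_events={Event.FIT_END},  # and evaluate once more at FIT_END, mirroring eval_at_fit_end=True
)
```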

5 changes: 3 additions & 2 deletions composer/utils/__init__.py
@@ -19,8 +19,8 @@
from composer.utils.import_helpers import MissingConditionalImportError, import_object
from composer.utils.inference import ExportFormat, Transform, export_for_inference, export_with_logger, quantize_dynamic
from composer.utils.iter_helpers import IteratorFileStream, ensure_tuple, map_collection
-from composer.utils.misc import (get_free_tcp_port, is_model_deepspeed, is_model_fsdp, is_notebook, model_eval_mode,
-using_torch_2)
+from composer.utils.misc import (create_interval_scheduler, get_free_tcp_port, is_model_deepspeed, is_model_fsdp,
+is_notebook, model_eval_mode, using_torch_2)
from composer.utils.object_store import (GCSObjectStore, LibcloudObjectStore, ObjectStore, ObjectStoreTransientError,
OCIObjectStore, S3ObjectStore, SFTPObjectStore, UCObjectStore)
from composer.utils.retrying import retry
@@ -83,6 +83,7 @@
'convert_nested_dict_to_flat_dict',
'convert_flat_dict_to_nested_dict',
'using_torch_2',
+'create_interval_scheduler',
'EvalClient',
'LambdaEvalClient',
'LocalEvalClient',
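After the re-export above, the helper is reachable from the `composer.utils` package root; both import paths resolve to the same function:

```python
from composer.utils import create_interval_scheduler
from composer.utils.misc import create_interval_scheduler as _impl

assert create_interval_scheduler is _impl
```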
125 changes: 123 additions & 2 deletions composer/utils/misc.py
@@ -3,19 +3,140 @@

"""Miscellaneous Helpers."""

+import math
import socket
from contextlib import contextmanager
-from typing import Type
+from typing import TYPE_CHECKING, Callable, Optional, Set, Type, Union

import torch
from packaging import version
from torch.nn.parallel import DistributedDataParallel

+if TYPE_CHECKING:
+from composer.core import Event, State, Time

__all__ = [
-'is_model_deepspeed', 'is_model_fsdp', 'is_notebook', 'warning_on_one_line', 'get_free_tcp_port', 'model_eval_mode'
+'is_model_deepspeed',
+'is_model_fsdp',
+'is_notebook',
+'warning_on_one_line',
+'get_free_tcp_port',
+'model_eval_mode',
+'create_interval_scheduler',
]


def create_interval_scheduler(interval: Union[str, int, 'Time'],
include_end_of_training: bool = True,
checkpoint_events: bool = True,
final_events: Optional[Set['Event']] = None) -> Callable[['State', 'Event'], bool]:
"""Helper function to create a scheduler according to a specified interval.
Args:
interval (Union[str, int, :class:`.Time`]): If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`.
Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
:attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`.
include_end_of_training (bool): If true, the returned callable will return true at the end of training as well.
Otherwise, the returned callable will return true at intervals only.
checkpoint_events (bool): If true, will use the EPOCH_CHECKPOINT and BATCH_CHECKPOINT events. If False, will use
the EPOCH_END and BATCH_END events.
final_events (Optional[Set[Event]]): The set of events to trigger on at the end of training.
Returns:
Callable[[State, Event], bool]: A function that returns true at interval and at the end of training if specified.
For example, it can be passed as the ``save_interval`` argument into the :class:`.CheckpointSaver`.
"""
# inlined to avoid circular import
from composer.core import Event, State, Time, TimeUnit

if final_events is None:
final_events = {Event.BATCH_CHECKPOINT, Event.EPOCH_CHECKPOINT}

if isinstance(interval, str):
interval = Time.from_timestring(interval)
if isinstance(interval, int):
interval = Time(interval, TimeUnit.EPOCH)

if interval.unit == TimeUnit.EPOCH:
interval_event = Event.EPOCH_CHECKPOINT if checkpoint_events else Event.EPOCH_END
elif interval.unit in {TimeUnit.BATCH, TimeUnit.TOKEN, TimeUnit.SAMPLE, TimeUnit.DURATION}:
interval_event = Event.BATCH_CHECKPOINT if checkpoint_events else Event.BATCH_END
else:
raise NotImplementedError(
f'Unknown interval: {interval.unit}. Must be TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, TimeUnit.SAMPLE, or TimeUnit.DURATION.'
)

last_batch_seen = -1

def check_interval(state: State, event: Event):
# `TimeUnit.Duration` value is a float from `[0.0, 1.0)`
if not interval.unit == TimeUnit.DURATION and int(interval) <= 0:
return False
nonlocal last_batch_seen # required to use the last_batch_seen from the outer function scope

# Previous timestamp will only be None if training has not started, but we are returning False
# in this case, just to be safe
if state.previous_timestamp is None:
return False

elapsed_duration = state.get_elapsed_duration()
assert elapsed_duration is not None, 'elapsed_duration is set on the BATCH_CHECKPOINT and EPOCH_CHECKPOINT'

if include_end_of_training and event in final_events and elapsed_duration >= 1.0 and state.timestamp.batch != last_batch_seen:
return True

if interval.unit in {TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, TimeUnit.SAMPLE}:
previous_count = state.previous_timestamp.get(interval.unit)
count = state.timestamp.get(interval.unit)
# If the eval_interval is a duration, we will track progress in terms of the unit of max_duration
elif interval.unit == TimeUnit.DURATION:
assert state.max_duration is not None
previous_count = state.previous_timestamp.get(state.max_duration.unit)
count = state.timestamp.get(state.max_duration.unit)
else:
raise NotImplementedError(
f'Unknown interval: {interval.unit}. Must be TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, TimeUnit.SAMPLE, or TimeUnit.DURATION.'
)

threshold_passed = math.floor(previous_count / interval.value) != math.floor(count / interval.value)

if interval.unit != TimeUnit.DURATION and event == interval_event and threshold_passed:
last_batch_seen = state.timestamp.batch
return True
elif interval.unit == TimeUnit.DURATION:
assert state.max_duration is not None, 'max_duration should not be None'
if state.dataloader_len is None:
raise RuntimeError(
f'Interval of type `dur` or {TimeUnit.DURATION} requires the dataloader to be sized.')

if event == interval_event:
if state.max_duration.unit == TimeUnit.EPOCH and int(state.timestamp.batch) % math.ceil(
state.max_duration.value * float(interval) * state.dataloader_len) == 0:
last_batch_seen = state.timestamp.batch
return True
elif state.max_duration.unit == TimeUnit.BATCH and int(state.timestamp.batch) % math.ceil(
state.max_duration.value * interval.value) == 0:
last_batch_seen = state.timestamp.batch
return True
elif state.max_duration.unit == TimeUnit.SAMPLE:
samples_per_interval = math.ceil(state.max_duration.value * interval)
threshold_passed = math.floor(previous_count / samples_per_interval) != math.floor(
count / samples_per_interval)
if threshold_passed:
last_batch_seen = state.timestamp.batch
return True
elif state.max_duration.unit == TimeUnit.TOKEN:
tokens_per_interval = math.ceil(state.max_duration.value * interval)
threshold_passed = math.floor(previous_count / tokens_per_interval) != math.floor(
count / tokens_per_interval)
if threshold_passed:
last_batch_seen = state.timestamp.batch
return True
return False

return check_interval


def is_model_deepspeed(model: torch.nn.Module) -> bool:
"""Whether ``model`` is an instance of a :class:`~deepspeed.DeepSpeedEngine`."""
try:
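Relative to the old callback-local helper, the consolidated version also accepts `TimeUnit.DURATION` intervals and lets callers pick the trigger events. A hedged sketch (the `'0.25dur'` value is illustrative; duration-based intervals additionally require a sized dataloader when the callable runs):

```python
from composer.utils import create_interval_scheduler

# Fires on BATCH_CHECKPOINT (checkpoint_events=True by default) every quarter of max_duration.
quarterly = create_interval_scheduler('0.25dur')

# The Trainer evaluates the result as quarterly(state, event) -> bool; with a duration interval
# it raises a RuntimeError at call time if state.dataloader_len is unknown.
```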
