Commit 9f53e31
Merge branch 'main' into mvpatel2000/patch-import
mvpatel2000 authored Oct 13, 2024
2 parents cbecd0a + 5538d2c
Showing 27 changed files with 759 additions and 220 deletions.
2 changes: 1 addition & 1 deletion composer/_version.py
@@ -3,4 +3,4 @@

"""The Composer Version."""

-__version__ = '0.25.0.dev0'
+__version__ = '0.26.0.dev0'
2 changes: 2 additions & 0 deletions composer/callbacks/__init__.py
@@ -13,6 +13,7 @@
from composer.callbacks.free_outputs import FreeOutputs
from composer.callbacks.generate import Generate
from composer.callbacks.image_visualizer import ImageVisualizer
+from composer.callbacks.load_checkpoint import LoadCheckpoint
from composer.callbacks.lr_monitor import LRMonitor
from composer.callbacks.memory_monitor import MemoryMonitor
from composer.callbacks.memory_snapshot import MemorySnapshot
@@ -44,4 +45,5 @@
'FreeOutputs',
'MemorySnapshot',
'OOMObserver',
+'LoadCheckpoint',
]
76 changes: 76 additions & 0 deletions composer/callbacks/load_checkpoint.py
@@ -0,0 +1,76 @@
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Load a checkpoint."""
import logging
from typing import Optional, Union

from composer.core import Callback, State
from composer.core.event import Event
from composer.loggers import Logger
from composer.models.huggingface import HuggingFaceModel
from composer.utils.checkpoint import load_checkpoint
from composer.utils.file_helpers import maybe_create_object_store_from_uri, parse_uri

log = logging.getLogger(__name__)


class LoadCheckpoint(Callback):
"""Callback that loads a checkpoint at the specified event.
Args:
load_path (str): The path to the checkpoint to load.
load_options (Optional[dict]): A dictionary of options to pass to the checkpoint loading function.
event (Union[str, Event]): The event at which to load the checkpoint. Defaults to ``Event.BEFORE_LOAD``.
"""

def __init__(
self,
load_path: str,
load_weights_only: bool = False,
strict_model_weights: bool = True,
ignore_keys: Optional[list[str]] = None,
event: Union[str, Event] = Event.BEFORE_LOAD,
):
super().__init__()
self.load_path = load_path
self.load_object_store = maybe_create_object_store_from_uri(load_path)
_, _, self.parsed_path = parse_uri(load_path)

self.load_weights_only = load_weights_only
self.strict_model_weights = strict_model_weights
self.ignore_keys = ignore_keys

self.event = event if isinstance(event, Event) else Event[event.upper()]

def run_event(self, event: Event, state: State, logger: Logger) -> None:
if event == self.event:
log.info(f'Loading checkpoint from {self.load_path} at {self.event}.')
self._load(state, logger)
log.info(f'Finished loading checkpoint from {self.load_path} at {self.event}.')

return super().run_event(event, state, logger)

def _load(self, state: State, logger: Logger) -> None:

# We need to temporarily disable the `should_save_peft_only` flag on the model
# so that we have access to the full model weights when loading.
model = state.model
original_should_save_peft_only = False
if isinstance(model, HuggingFaceModel):
original_should_save_peft_only = model.should_save_peft_only
model.should_save_peft_only = False

load_checkpoint(
path=self.parsed_path,
state=state,
logger=logger,
object_store=self.load_object_store,
strict_model_weights=self.strict_model_weights,
ignore_keys=self.ignore_keys,
load_weights_only=self.load_weights_only,
)

# Restore the original `should_save_peft_only` flag on the model
if isinstance(model, HuggingFaceModel):
model.should_save_peft_only = original_should_save_peft_only
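
For context, a minimal usage sketch of the new callback. The checkpoint URI, model, and dataloader below are hypothetical placeholders; the arguments mirror the signature above:

```python
from composer import Trainer
from composer.callbacks import LoadCheckpoint
from composer.core.event import Event

# Hypothetical URI; anything maybe_create_object_store_from_uri understands
# (an s3://-style URI or a local path) should work for load_path.
load_cb = LoadCheckpoint(
    load_path='s3://my-bucket/checkpoints/ep2-rank0.pt',
    load_weights_only=True,     # restore model weights only, not optimizer/timestamp state
    strict_model_weights=True,  # require checkpoint keys to match the model exactly
    event=Event.BEFORE_LOAD,    # the default: fire before the Trainer's own load logic
)

trainer = Trainer(
    model=model,                        # assumed: a ComposerModel defined elsewhere
    train_dataloader=train_dataloader,  # assumed: defined elsewhere
    max_duration='1ep',
    callbacks=[load_cb],
)
```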
6 changes: 6 additions & 0 deletions composer/core/state.py
@@ -612,6 +612,12 @@ def _validate_parallelism_configs(self):
'Tensor parallelism (TP) currently requires FSDP with use_orig_params=True, '
'which is the default and recommended setting.',
)
+if self.tp_config.tensor_parallel_degree == 1:
+warnings.warn(
+'Received tensor_parallel_degree of 1, which is a no-op. Tensor parallelism will not be used.',
+UserWarning,
+)
+self.tp_config = None

# Load monolith rank0 only
if self.load_monolith_rank0_only:
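
The practical effect of the new guard, sketched below: a `tensor_parallel_degree` of 1 now warns and the TP config is dropped, rather than silently building a one-way parallel layout. This hypothetical snippet calls the private `_validate_parallelism_configs` directly; in practice the Trainer invokes it during construction:

```python
import warnings

# Hypothetical sketch: `state` is assumed to be a composer State whose
# tp_config was constructed with tensor_parallel_degree=1.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    state._validate_parallelism_configs()

assert any('no-op' in str(w.message) for w in caught)  # the new UserWarning fired
assert state.tp_config is None  # and tensor parallelism was disabled
```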
6 changes: 1 addition & 5 deletions composer/trainer/_patch_pytorch.py
@@ -945,8 +945,7 @@ def unshard_with_sync(self):

if version.parse(torch.__version__) >= version.parse('2.4.0') and version.parse(
torch.__version__,
-) < version.parse('2.4.1'):
-# 2.4.0 only patch
+) < version.parse('2.4.2'):
# PyTorch issue: https://github.com/pytorch/pytorch/issues/133923
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from typing import Mapping, Collection
@@ -1003,9 +1002,6 @@ def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
for key, value in state_dict.items():
_traverse_obj((str(key),), value)

-if version.parse(torch.__version__) >= version.parse('2.4.0') and version.parse(
-torch.__version__,
-) < version.parse('2.4.2'):
# Save original FlatParamHandle.unshard to revert back to when dropping automicrobatching hooks
from torch.distributed.fsdp._flat_param import FlatParamHandle
original_unshard = FlatParamHandle.unshard
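
The consolidation above replaces a 2.4.0-only gate (`< 2.4.1`) plus a second, duplicate gate with a single range check covering 2.4.0 and 2.4.1. A standalone sketch of the gating pattern; `apply_traverse_state_dict_patch` is a hypothetical name standing in for the patch code above:

```python
import torch
from packaging import version

TORCH_VERSION = version.parse(torch.__version__)

# One gate now covers the whole affected range (2.4.0 and 2.4.1).
if version.parse('2.4.0') <= TORCH_VERSION < version.parse('2.4.2'):
    apply_traverse_state_dict_patch()  # hypothetical stand-in for the patch above
```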
5 changes: 4 additions & 1 deletion composer/trainer/trainer.py
@@ -1188,7 +1188,7 @@ def __init__(
+
'which provides similar functionality. Please use the `parallelism_config` parameter instead. Please open '
+ 'a GitHub issue if you need help migrating from DeepSpeed to FSDP.',
-remove_version='0.28.0',
+remove_version='0.27.0',
),
)

@@ -1902,13 +1902,16 @@ def __init__(
log.info('No previous autoresume checkpoint found')
# Actually load the checkpoint from potentially updated arguments
if load_path is not None:
+log.info(f'Loading checkpoint from {load_path}')
if load_object_store is None:
load_object_store = maybe_create_object_store_from_uri(load_path)
+log.debug(f'Created object store from load path: {load_object_store}')
if isinstance(load_object_store, WandBLogger):
import wandb
if wandb.run is None:
load_object_store.init(self.state, self.logger)
_, _, parsed_load_path = parse_uri(load_path)
+log.debug(f'Parsed load path: {parsed_load_path}')

self._rng_state = checkpoint.load_checkpoint(
state=self.state,
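
For readers following the new debug logs: `parse_uri` splits a load path into a backend, a bucket, and an object path. A small sketch of the behavior as I understand it (worth verifying against `composer.utils.file_helpers`):

```python
from composer.utils.file_helpers import parse_uri

backend, bucket, path = parse_uri('s3://my-bucket/checkpoints/latest-rank0.pt')
# backend == 's3', bucket == 'my-bucket', path == 'checkpoints/latest-rank0.pt'

backend, bucket, path = parse_uri('checkpoints/latest-rank0.pt')
# For a plain local path, backend and bucket come back as empty strings.
```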
120 changes: 40 additions & 80 deletions composer/utils/checkpoint.py
@@ -413,23 +413,6 @@ def format(self, state: State, is_deepspeed: bool = False, keep_placeholders: bo
) + extra_suffix


-def is_checkpoint_legacy_sharded(object_store: Optional[Union[LoggerDestination, ObjectStore]], source_path: str):
-if source_path.endswith('.symlink') or os.path.islink(source_path):
-source_path = extract_path_from_symlink(source_path, object_store=object_store)
-metadata_path = str(Path(source_path) / Path('.metadata'))
-if object_store is None:
-return not os.path.exists(metadata_path)
-else:
-try:
-_, _, metadata_path = parse_uri(metadata_path)
-with tempfile.TemporaryDirectory() as temp_dir:
-metadata_destination = os.path.join(str(temp_dir), '.metadata')
-download_object_or_file(metadata_path, metadata_destination, object_store)
-return False
-except FileNotFoundError:
-return True


def load_checkpoint(
path: str,
state: State,
@@ -531,15 +514,9 @@ def load_checkpoint(
:attr:`load_weights_only` is not None. Otherwise, None.
"""
path = partial_format(path, run_name=state.run_name)
-using_legacy_sharded = False
-if state.fsdp_sharded_state_dict_enabled:
-assert object_store is None or isinstance(
-object_store,
-ObjectStore,
-), 'For loading sharded checkpoints load_object_store must be set with the class ObjectStore'
-using_legacy_sharded = is_checkpoint_legacy_sharded(object_store, path)
+log.debug(f'Loading checkpoint from formatted path: {path}')

-if state.fsdp_sharded_state_dict_enabled and not using_legacy_sharded:
+if state.fsdp_sharded_state_dict_enabled:
rng_state_dicts = load_sharded_checkpoint(
source_path=path,
state=state,
@@ -554,26 +531,20 @@
)
else:
# Download the checkpoint to the node-local folder
log.debug('Loading checkpoint at %s', path)
# Each node gets one unique folder to store checkpoints that is shared amongst all local ranks in that node.
-# If fsdp sharded state_dicts is enabled then EVERY rank gets a unique checkpoint folder.
-needs_unique_checkpoint_folder = state.fsdp_sharded_state_dict_enabled or dist.get_local_rank() == 0
-tempdir_ctx = tempfile.TemporaryDirectory() if needs_unique_checkpoint_folder else contextlib.nullcontext(None)
+tempdir_ctx = tempfile.TemporaryDirectory() if dist.get_local_rank() == 0 else contextlib.nullcontext(None)
with tempdir_ctx as tempdir:
try:
# Get the path to the proper checkpoint folder corresponding to the current rank's node.
-# If fsdp_sharded_state_dict_enabled then just use that rank's unique tempdir.
-node_checkpoint_folder = (
-tempdir if state.fsdp_sharded_state_dict_enabled else _get_local_rank_zero_path(tempdir)
-)
-assert node_checkpoint_folder is not None
+node_checkpoint_folder = _get_local_rank_zero_path(tempdir)

composer_states_filepath, extracted_checkpoint_folder, extracted_rank_n = download_checkpoint(
path=path,
node_checkpoint_folder=node_checkpoint_folder,
object_store=object_store,
progress_bar=progress_bar,
-fsdp_sharded_state_dict_enabled=state.fsdp_sharded_state_dict_enabled,
deepspeed_sharded_checkpoint=is_model_deepspeed(state.model),
)
rng_state_dicts = _restore_checkpoint(
@@ -593,6 +564,8 @@
# be a shared resource between nodes.
dist.barrier()
log.info('%s loaded from %s', 'Model weights' if load_weights_only else 'Trainer checkpoint', path)

+# Verify all ranks resumed on same step
step_to_resume_from = state.timestamp.batch.value
max_step_to_resume_from = state.device.tensor_to_device(
torch.tensor(state.timestamp.batch.value, dtype=torch.int64),
@@ -623,50 +596,41 @@ def dist_cp_load(
load_planner: Optional[LoadPlanner] = None,
):
if version.parse(torch.__version__) >= version.parse('2.4.0'):
-if version.parse(torch.__version__) < version.parse('2.4.1'):
-# PyTorch 2.4.0
-from torch.distributed.checkpoint.utils import CheckpointException
-try:
-dist_cp.load(
-state_dict=state_dict,
-storage_reader=storage_reader,
-planner=load_planner,
-)
-except CheckpointException as e:
-checkpoint_metadata = storage_reader.read_metadata().state_dict_metadata
-if 'state.metadata' in checkpoint_metadata and 'state.metadata.composer_env_info.composer_version' not in checkpoint_metadata:
-# Torch 2.4 changed the way how state dict is flattened. It broke backward compatibility.
-# Torch issue: https://github.com/pytorch/pytorch/issues/133923.
-# We override the traverse_state_dict so that the load planner could
-# use the old way of flattening the state dict
-log.debug('Trying to load checkpointing saved before torch 2.4')

-import torch.distributed.checkpoint._nested_dict as nested_dict
-import torch.distributed.checkpoint._sharded_tensor_utils as sharded_tensor_util
-from torch.distributed.checkpoint._traverse import traverse_state_dict as traverse_2_4_0

-from composer.trainer._patch_pytorch import traverse_state_dict as backward_compatible_traverse

-nested_dict.traverse_state_dict = backward_compatible_traverse
-sharded_tensor_util.traverse_state_dict = backward_compatible_traverse

-dist_cp.load(
-state_dict=state_dict,
-storage_reader=storage_reader,
-planner=load_planner,
-)
-# Revert the override
-nested_dict.traverse_state_dict = traverse_2_4_0
-sharded_tensor_util.traverse_state_dict = traverse_2_4_0
-else:
-raise e
-else:
-# PyTorch 2.4.1
+from torch.distributed.checkpoint.utils import CheckpointException
+try:
+dist_cp.load(
+state_dict=state_dict,
+storage_reader=storage_reader,
+planner=load_planner,
+)
+except CheckpointException as e:
+checkpoint_metadata = storage_reader.read_metadata().state_dict_metadata
+if 'state.metadata' in checkpoint_metadata and 'state.metadata.composer_env_info.composer_version' not in checkpoint_metadata:
+# Torch 2.4 changed how the state dict is flattened, which broke backward compatibility.
+# Torch issue: https://github.com/pytorch/pytorch/issues/133923.
+# We override traverse_state_dict so that the load planner can
+# use the old way of flattening the state dict.
+log.debug('Trying to load a checkpoint saved before torch 2.4')

+import torch.distributed.checkpoint._nested_dict as nested_dict
+import torch.distributed.checkpoint._sharded_tensor_utils as sharded_tensor_util
+from torch.distributed.checkpoint._traverse import traverse_state_dict as traverse_2_4_0

+from composer.trainer._patch_pytorch import traverse_state_dict as backward_compatible_traverse

+nested_dict.traverse_state_dict = backward_compatible_traverse
+sharded_tensor_util.traverse_state_dict = backward_compatible_traverse

+dist_cp.load(
+state_dict=state_dict,
+storage_reader=storage_reader,
+planner=load_planner,
+)
+# Revert the override
+nested_dict.traverse_state_dict = traverse_2_4_0
+sharded_tensor_util.traverse_state_dict = traverse_2_4_0
+else:
+raise e
else:
dist_cp.load_state_dict(
state_dict=state_dict,
@@ -808,7 +772,6 @@ def download_checkpoint(
node_checkpoint_folder: str,
object_store: Optional[Union[ObjectStore, LoggerDestination]],
progress_bar: bool,
-fsdp_sharded_state_dict_enabled: bool = False,
deepspeed_sharded_checkpoint: bool = False,
) -> tuple[str, Optional[str], bool]:
"""Download the checkpoint stored at ``path``, potentially in ``object_store``, to ``node_checkpoint_folder``.
@@ -835,9 +798,7 @@
# and only rank zero has this file unless fsdp_sharded_state_dict_enabled then
# every rank has its own file.
extracted_checkpoint_folder = None
-composer_states_filepath = (
-rank_n_checkpoint_filepath if fsdp_sharded_state_dict_enabled else rank_zero_checkpoint_filepath
-)
+composer_states_filepath = rank_zero_checkpoint_filepath

if is_compressed_pt(path):
original_path = path
@@ -847,9 +808,8 @@
with compressor.decompress(original_path) as in_file:
shutil.copyfileobj(in_file, out_file)

-checkpoint_is_sharded = fsdp_sharded_state_dict_enabled or deepspeed_sharded_checkpoint
try:
-if not checkpoint_is_sharded and dist.get_local_rank() == 0:
+if not deepspeed_sharded_checkpoint and dist.get_local_rank() == 0:
# If the checkpoint is not sharded, then local rank 0 on each node needs to download the
# global rank 0 checkpoint
path = _format_path_with_rank_zero(path)
@@ -868,7 +828,7 @@
# the underlying issue is that the checkpoint file does not exist on the disk
# or could not be downloaded
raise RuntimeError(f'Checkpoint {path} does not exist')
-elif checkpoint_is_sharded:
+elif deepspeed_sharded_checkpoint:
# If the checkpoint is sharded, then every rank needs to download its own checkpoint
path = _format_path_with_current_rank(path)
try:
@@ -898,7 +858,7 @@

finally:
# Use busy wait to avoid timeouts on large downloads for non-sharded checkpoints
-if not checkpoint_is_sharded:
+if not deepspeed_sharded_checkpoint:
signal_file_path = os.path.join(
node_checkpoint_folder,
dist.get_node_signal_file_name(),
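
A note on the retry logic in `dist_cp_load` above: it is a temporary monkey-patch, i.e. swap in the backward-compatible `traverse_state_dict`, retry the load, then restore the original. A generic sketch of that pattern, using a context manager so the restore also happens on error (these names are illustrative, not part of Composer's API):

```python
import contextlib


@contextlib.contextmanager
def patched_attr(module, name, replacement):
    """Temporarily replace ``module.<name>``, restoring the original on exit."""
    original = getattr(module, name)
    setattr(module, name, replacement)
    try:
        yield
    finally:
        setattr(module, name, original)


# Mirroring the diff: on a CheckpointException for a pre-2.4 checkpoint, retry
# the load with the old flattening behavior, then put torch's own back.
# with patched_attr(nested_dict, 'traverse_state_dict', backward_compatible_traverse):
#     dist_cp.load(state_dict=state_dict, storage_reader=storage_reader, planner=load_planner)
```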
4 changes: 3 additions & 1 deletion composer/utils/object_store/uc_object_store.py
@@ -28,10 +28,12 @@ def _wrap_errors(uri: str, e: Exception):
"""
# Wrap DatabricksError in ObjectStoreTransientError.
# If the file is not found, raise FileNotFoundError.
-from databricks.sdk.errors import DatabricksError, NotFound
+from databricks.sdk.errors import DatabricksError, NotFound, PermissionDenied
if isinstance(e, DatabricksError):
if isinstance(e, NotFound) or e.error_code == _NOT_FOUND_ERROR_CODE: # type: ignore
raise FileNotFoundError(f'Object {uri} not found') from e
+if isinstance(e, PermissionDenied):
+raise e
raise ObjectStoreTransientError from e

# Wrap ChunkedEncodingError in ObjectStoreTransientError.
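
The practical effect of the `_wrap_errors` change: permission failures from Unity Catalog now propagate as-is instead of being wrapped in `ObjectStoreTransientError`, which retry loops treat as retryable. A rough caller-side sketch; the store instance and volume path are hypothetical:

```python
from databricks.sdk.errors import PermissionDenied

try:
    uc_store.download_object('Volumes/main/my_schema/my_volume/ckpt.pt', 'ckpt.pt')  # hypothetical
except FileNotFoundError:
    ...  # NotFound is still mapped to FileNotFoundError, as before
except PermissionDenied:
    ...  # new: surfaces directly, so callers fail fast instead of retrying
```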
2 changes: 1 addition & 1 deletion composer/utils/reproducibility.py
@@ -127,7 +127,7 @@ def configure_deterministic_mode():
# See https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
# and https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
-warnings.warn('Deterministic mode is activated. This will negatively impact performance.', category=UserWarning)
+log.info('Deterministic mode is activated. This will negatively impact performance.')


def get_random_seed() -> int:
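
Since the message is now logged instead of warned, it only shows up when INFO logging is enabled. A minimal sketch using Composer's reproducibility helpers:

```python
import logging

from composer.utils import reproducibility

logging.basicConfig(level=logging.INFO)  # INFO must be enabled to see the notice now
reproducibility.seed_all(42)
reproducibility.configure_deterministic_mode()
```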
4 changes: 2 additions & 2 deletions docker/README.md
@@ -15,8 +15,8 @@ all dependencies for both NLP and Vision models. They are built on top of the
<!-- BEGIN_COMPOSER_BUILD_MATRIX -->
| Composer Version | CUDA Support | Docker Tag |
|--------------------|----------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| 0.24.1             | Yes            | `mosaicml/composer:latest`, `mosaicml/composer:0.24.1`                                                                                             |
-| 0.24.1             | No             | `mosaicml/composer:latest_cpu`, `mosaicml/composer:0.24.1_cpu`                                                                                     |
+| 0.25.0             | Yes            | `mosaicml/composer:latest`, `mosaicml/composer:0.25.0`                                                                                             |
+| 0.25.0             | No             | `mosaicml/composer:latest_cpu`, `mosaicml/composer:0.25.0_cpu`                                                                                     |
<!-- END_COMPOSER_BUILD_MATRIX -->

**Note**: For a lightweight installation, we recommend using a [MosaicML PyTorch Image](#pytorch-images) and manually