From 4cdc2cdd311c163ff801cb43d2a90f199f0d0d77 Mon Sep 17 00:00:00 2001
From: bigning
Date: Mon, 23 Sep 2024 13:04:27 -0700
Subject: [PATCH 01/17] fix 2.4.1ckpt (#3629)

---
 composer/trainer/_patch_pytorch.py |  6 +--
 composer/utils/checkpoint.py       | 69 +++++++++++++-----------------
 2 files changed, 31 insertions(+), 44 deletions(-)

diff --git a/composer/trainer/_patch_pytorch.py b/composer/trainer/_patch_pytorch.py
index fcca94d73a..77c4d733f7 100644
--- a/composer/trainer/_patch_pytorch.py
+++ b/composer/trainer/_patch_pytorch.py
@@ -945,8 +945,7 @@ def unshard_with_sync(self):

 if version.parse(torch.__version__) >= version.parse('2.4.0') and version.parse(
     torch.__version__,
-) < version.parse('2.4.1'):
-    # 2.4.0 only patch
+) < version.parse('2.4.2'):
     # PyTorch issue: https://github.com/pytorch/pytorch/issues/133923
     from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
     from typing import Mapping, Collection
@@ -1003,9 +1002,6 @@ def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
         for key, value in state_dict.items():
             _traverse_obj((str(key),), value)

-if version.parse(torch.__version__) >= version.parse('2.4.0') and version.parse(
-    torch.__version__,
-) < version.parse('2.4.2'):
     # Save original FlatParamHandle.unshard to revert back to when dropping automicrobatching hooks
     from torch.distributed.fsdp._flat_param import FlatParamHandle
     original_unshard = FlatParamHandle.unshard
diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py
index b966c918c5..11e79bff10 100644
--- a/composer/utils/checkpoint.py
+++ b/composer/utils/checkpoint.py
@@ -623,50 +623,41 @@ def dist_cp_load(
     load_planner: Optional[LoadPlanner] = None,
 ):
     if version.parse(torch.__version__) >= version.parse('2.4.0'):
-        if version.parse(torch.__version__) < version.parse('2.4.1'):
-            # PyTorch 2.4.0
-            from torch.distributed.checkpoint.utils import CheckpointException
-            try:
-                dist_cp.load(
-                    state_dict=state_dict,
-                    storage_reader=storage_reader,
-                    planner=load_planner,
-                )
-            except CheckpointException as e:
-                checkpoint_metadata = storage_reader.read_metadata().state_dict_metadata
-                if 'state.metadata' in checkpoint_metadata and 'state.metadata.composer_env_info.composer_version' not in checkpoint_metadata:
-                    # Torch 2.4 changed the way how state dict is flattened. It broke backward compatibility.
-                    # Torch issue: https://github.com/pytorch/pytorch/issues/133923.
- # We override the traverse_state_dict so that the load planner could - # use the old way of flattening the state dict - log.debug('Trying to load checkpointing saved before torch 2.4') - - import torch.distributed.checkpoint._nested_dict as nested_dict - import torch.distributed.checkpoint._sharded_tensor_utils as sharded_tensor_util - from torch.distributed.checkpoint._traverse import traverse_state_dict as traverse_2_4_0 - - from composer.trainer._patch_pytorch import traverse_state_dict as backward_compatible_traverse - - nested_dict.traverse_state_dict = backward_compatible_traverse - sharded_tensor_util.traverse_state_dict = backward_compatible_traverse - - dist_cp.load( - state_dict=state_dict, - storage_reader=storage_reader, - planner=load_planner, - ) - # Revert the override - nested_dict.traverse_state_dict = traverse_2_4_0 - sharded_tensor_util.traverse_state_dict = traverse_2_4_0 - else: - raise e - else: - # PyTorch 2.4.1 + from torch.distributed.checkpoint.utils import CheckpointException + try: dist_cp.load( state_dict=state_dict, storage_reader=storage_reader, planner=load_planner, ) + except CheckpointException as e: + checkpoint_metadata = storage_reader.read_metadata().state_dict_metadata + if 'state.metadata' in checkpoint_metadata and 'state.metadata.composer_env_info.composer_version' not in checkpoint_metadata: + # Torch 2.4 changed the way how state dict is flattened. It broke backward compatibility. + # Torch issue: https://github.com/pytorch/pytorch/issues/133923. + # We override the traverse_state_dict so that the load planner could + # use the old way of flattening the state dict + log.debug('Trying to load checkpointing saved before torch 2.4') + + import torch.distributed.checkpoint._nested_dict as nested_dict + import torch.distributed.checkpoint._sharded_tensor_utils as sharded_tensor_util + from torch.distributed.checkpoint._traverse import traverse_state_dict as traverse_2_4_0 + + from composer.trainer._patch_pytorch import traverse_state_dict as backward_compatible_traverse + + nested_dict.traverse_state_dict = backward_compatible_traverse + sharded_tensor_util.traverse_state_dict = backward_compatible_traverse + + dist_cp.load( + state_dict=state_dict, + storage_reader=storage_reader, + planner=load_planner, + ) + # Revert the override + nested_dict.traverse_state_dict = traverse_2_4_0 + sharded_tensor_util.traverse_state_dict = traverse_2_4_0 + else: + raise e else: dist_cp.load_state_dict( state_dict=state_dict, From 17304a0b54daa8f982107cdf138afb3b0e0e6da8 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Mon, 23 Sep 2024 16:46:29 -0400 Subject: [PATCH 02/17] More checkpoint debug logs (#3632) --- composer/trainer/trainer.py | 3 +++ composer/utils/checkpoint.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 815aa50001..32a9426523 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -1902,13 +1902,16 @@ def __init__( log.info('No previous autoresume checkpoint found') # Actually load the checkpoint from potentially updated arguments if load_path is not None: + log.info(f'Loading checkpoint from {load_path}') if load_object_store is None: load_object_store = maybe_create_object_store_from_uri(load_path) + log.debug(f'Created object store from load path: {load_object_store}') if isinstance(load_object_store, WandBLogger): import wandb if wandb.run is None: load_object_store.init(self.state, self.logger) _, _, parsed_load_path = parse_uri(load_path) + 
log.debug(f'Parsed load path: {parsed_load_path}')

             self._rng_state = checkpoint.load_checkpoint(
                 state=self.state,
diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py
index 11e79bff10..ccefee4e60 100644
--- a/composer/utils/checkpoint.py
+++ b/composer/utils/checkpoint.py
@@ -417,6 +417,7 @@ def is_checkpoint_legacy_sharded(object_store: Optional[Union[LoggerDestination,
     if source_path.endswith('.symlink') or os.path.islink(source_path):
         source_path = extract_path_from_symlink(source_path, object_store=object_store)
     metadata_path = str(Path(source_path) / Path('.metadata'))
+    log.debug(f'Checking if checkpoint is legacy sharded by checking for metadata file at {metadata_path}.')
     if object_store is None:
         return not os.path.exists(metadata_path)
     else:
@@ -531,6 +532,7 @@ def load_checkpoint(
             :attr:`load_weights_only` is not None. Otherwise, None.
     """
     path = partial_format(path, run_name=state.run_name)
+    log.debug(f'Loading checkpoint from formatted path: {path}')
     using_legacy_sharded = False
     if state.fsdp_sharded_state_dict_enabled:
         assert object_store is None or isinstance(
@@ -538,6 +540,7 @@ def load_checkpoint(
             ObjectStore,
         ), 'For loading sharded checkpoints load_object_store must be set with the class ObjectStore'
         using_legacy_sharded = is_checkpoint_legacy_sharded(object_store, path)
+        log.info(f'Using legacy sharded checkpoint: {using_legacy_sharded}')

     if state.fsdp_sharded_state_dict_enabled and not using_legacy_sharded:
         rng_state_dicts = load_sharded_checkpoint(

From c261d132649b2c7faf1d6a26069d19cd313d0294 Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Mon, 23 Sep 2024 18:18:12 -0400
Subject: [PATCH 03/17] Lower DeepSpeed deprecation version (#3634)

---
 composer/trainer/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 32a9426523..a118b1cf60 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -1188,7 +1188,7 @@ def __init__(
                     + 'which provides similar functionality. Please use the `parallelism_config` parameter instead. Please open '
                     + 'a GitHub issue if you need help migrating from DeepSpeed to FSDP.',
-                    remove_version='0.28.0',
+                    remove_version='0.27.0',
                 ),
             )

From 546f7e87a163171afb2bbc010898e0dc47a1e0b0 Mon Sep 17 00:00:00 2001
From: Daniel King <43149077+dakinggg@users.noreply.github.com>
Date: Mon, 23 Sep 2024 15:20:13 -0700
Subject: [PATCH 04/17] Bump version 25 (#3633)

---
 composer/_version.py            |  2 +-
 docker/README.md                |  4 ++--
 docker/build_matrix.yaml        | 16 ++++++++--------
 docker/generate_build_matrix.py |  2 +-
 4 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/composer/_version.py b/composer/_version.py
index 564ccfe75f..ef0d45782a 100644
--- a/composer/_version.py
+++ b/composer/_version.py
@@ -3,4 +3,4 @@

 """The Composer Version."""

-__version__ = '0.25.0.dev0'
+__version__ = '0.26.0.dev0'
diff --git a/docker/README.md b/docker/README.md
index a561d1237d..fd68d04951 100644
--- a/docker/README.md
+++ b/docker/README.md
@@ -15,8 +15,8 @@ all dependencies for both NLP and Vision models.
They are built on top of the | Composer Version | CUDA Support | Docker Tag | |--------------------|----------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| 0.24.1 | Yes | `mosaicml/composer:latest`, `mosaicml/composer:0.24.1` | -| 0.24.1 | No | `mosaicml/composer:latest_cpu`, `mosaicml/composer:0.24.1_cpu` | +| 0.25.0 | Yes | `mosaicml/composer:latest`, `mosaicml/composer:0.25.0` | +| 0.25.0 | No | `mosaicml/composer:latest_cpu`, `mosaicml/composer:0.25.0_cpu` | **Note**: For a lightweight installation, we recommended using a [MosaicML PyTorch Image](#pytorch-images) and manually diff --git a/docker/build_matrix.yaml b/docker/build_matrix.yaml index 40edd23992..65b8e747a1 100644 --- a/docker/build_matrix.yaml +++ b/docker/build_matrix.yaml @@ -194,9 +194,9 @@ TORCHVISION_VERSION: 0.17.2 - AWS_OFI_NCCL_VERSION: '' BASE_IMAGE: nvidia/cuda:12.4.1-cudnn-devel-ubuntu20.04 - COMPOSER_INSTALL_COMMAND: mosaicml[all]==0.24.1 + COMPOSER_INSTALL_COMMAND: mosaicml[all]==0.25.0 CUDA_VERSION: 12.4.1 - IMAGE_NAME: composer-0-24-1 + IMAGE_NAME: composer-0-25-0 MOFED_VERSION: latest-23.10 NVIDIA_REQUIRE_CUDA_OVERRIDE: '' PYTHON_VERSION: '3.11' @@ -204,17 +204,17 @@ PYTORCH_NIGHTLY_VERSION: '' PYTORCH_VERSION: 2.4.1 TAGS: - - mosaicml/composer:0.24.1 - - ghcr.io/databricks-mosaic/composer:0.24.1 + - mosaicml/composer:0.25.0 + - ghcr.io/databricks-mosaic/composer:0.25.0 - mosaicml/composer:latest - ghcr.io/databricks-mosaic/composer:latest TARGET: composer_stage TORCHVISION_VERSION: 0.19.1 - AWS_OFI_NCCL_VERSION: '' BASE_IMAGE: ubuntu:20.04 - COMPOSER_INSTALL_COMMAND: mosaicml[all]==0.24.1 + COMPOSER_INSTALL_COMMAND: mosaicml[all]==0.25.0 CUDA_VERSION: '' - IMAGE_NAME: composer-0-24-1-cpu + IMAGE_NAME: composer-0-25-0-cpu MOFED_VERSION: latest-23.10 NVIDIA_REQUIRE_CUDA_OVERRIDE: '' PYTHON_VERSION: '3.11' @@ -222,8 +222,8 @@ PYTORCH_NIGHTLY_VERSION: '' PYTORCH_VERSION: 2.4.1 TAGS: - - mosaicml/composer:0.24.1_cpu - - ghcr.io/databricks-mosaic/composer:0.24.1_cpu + - mosaicml/composer:0.25.0_cpu + - ghcr.io/databricks-mosaic/composer:0.25.0_cpu - mosaicml/composer:latest_cpu - ghcr.io/databricks-mosaic/composer:latest_cpu TARGET: composer_stage diff --git a/docker/generate_build_matrix.py b/docker/generate_build_matrix.py index 9e47662a4b..a3336a3d19 100644 --- a/docker/generate_build_matrix.py +++ b/docker/generate_build_matrix.py @@ -244,7 +244,7 @@ def _main(): composer_entries = [] # The `GIT_COMMIT` is a placeholder and Jenkins will substitute it with the actual git commit for the `composer_staging` images - composer_versions = ['0.24.1'] # Only build images for the latest composer version + composer_versions = ['0.25.0'] # Only build images for the latest composer version composer_python_versions = [PRODUCTION_PYTHON_VERSION] # just build composer against the latest for product in itertools.product(composer_python_versions, composer_versions, cuda_options): From 82b9d1fea2b44b198169e7ac8fed882ee7299036 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Wed, 25 Sep 2024 08:52:14 -0700 Subject: [PATCH 05/17] Add backward compatibility checkpoint tests for v0.25.0 (#3635) --- tests/trainer/test_fsdp_checkpoint.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py index 5f97c6092c..f9b40e2e78 100644 --- a/tests/trainer/test_fsdp_checkpoint.py 
+++ b/tests/trainer/test_fsdp_checkpoint.py @@ -513,6 +513,7 @@ def test_fsdp_mixed_with_sync( '0.22.0', '0.23.0', '0.24.0', + '0.25.0', ], ) @pytest.mark.filterwarnings(r'ignore:.*metrics are not saved with sharded state dict.*:UserWarning') @@ -532,8 +533,9 @@ def test_fsdp_load_old_checkpoint( if composer_version == '0.18.1' and state_dict_type == 'full' and precision == 'amp_bf16' and sharding_strategy == 'FULL_SHARD': pytest.skip('TODO: This checkpoint is missing') - if (composer_version in ['0.22.0', '0.23.0'] and version.parse(torch.__version__) < version.parse('2.3.0') - ) or (composer_version == '0.24.0' and version.parse(torch.__version__) < version.parse('2.4.0')): + if (composer_version in ['0.22.0', '0.23.0'] and version.parse(torch.__version__) < version.parse('2.3.0')) or ( + composer_version == '0.24.0' and version.parse(torch.__version__) < version.parse('2.4.0') + ) or (composer_version == '0.25.0' and version.parse(torch.__version__) < version.parse('2.5.0')): pytest.skip('Current torch version is older than torch version that checkpoint was written with.') if composer_version in ['0.13.5', '0.14.0', '0.14.1', '0.15.1']: From 566d262e622797c76c5246cff87453e3876393ca Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Fri, 27 Sep 2024 13:15:53 -0400 Subject: [PATCH 06/17] Don't use TP when `tensor_parallel_degree` is 1 (#3636) Co-authored-by: Eitan Turok --- composer/core/state.py | 6 ++++++ tests/trainer/test_tp.py | 39 ++++++++++++++++++++++++++------------- 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/composer/core/state.py b/composer/core/state.py index 3980514380..4c1e1a92bb 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -612,6 +612,12 @@ def _validate_parallelism_configs(self): 'Tensor parallelism (TP) currently requires FSDP with use_orig_params=True, ' 'which is the default and recommended setting.', ) + if self.tp_config.tensor_parallel_degree == 1: + warnings.warn( + 'Received tensor_parallel_degree of 1, which is a no-op. 
Tensor parallelism will not be used.', + UserWarning, + ) + self.tp_config = None # Load monolith rank0 only if self.load_monolith_rank0_only: diff --git a/tests/trainer/test_tp.py b/tests/trainer/test_tp.py index 03dc37fc1b..b03d170a05 100644 --- a/tests/trainer/test_tp.py +++ b/tests/trainer/test_tp.py @@ -1,11 +1,14 @@ # Copyright 2022 MosaicML Composer authors # SPDX-License-Identifier: Apache-2.0 +import contextlib + import pytest import torch from packaging import version from torch.utils.data import DataLoader +from composer.optim import DecoupledSGDW from composer.trainer.trainer import Trainer from composer.utils import dist from tests.common import ( @@ -17,12 +20,14 @@ @pytest.mark.gpu @world_size(4) -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='requires PyTorch 2.3+') @pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning') -def test_tp_train(world_size: int): +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='requires PyTorch 2.3+') +@pytest.mark.parametrize('tensor_parallel_degree', [1, 2]) +def test_tp_train(world_size: int, tensor_parallel_degree: int): from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel model = SimpleModel() + optimizer = DecoupledSGDW(model.parameters(), lr=0.1) dataset = RandomClassificationDataset(size=8) dataloader = DataLoader(dataset, batch_size=2, sampler=dist.get_sampler(dataset)) @@ -31,18 +36,26 @@ def test_tp_train(world_size: int): 'fc2': RowwiseParallel(), } - trainer = Trainer( - model=model, - train_dataloader=dataloader, - parallelism_config={ - 'tp': { - 'layer_plan': layer_plan, - 'tensor_parallel_degree': 2, + if tensor_parallel_degree == 1: + expected_warning = 'Received tensor_parallel_degree of 1, which is a no-op. Tensor parallelism will not be used.' 
+ ctx = pytest.warns(UserWarning, match=expected_warning) + else: + ctx = contextlib.nullcontext() + + with ctx: + trainer = Trainer( + model=model, + optimizers=optimizer, + train_dataloader=dataloader, + parallelism_config={ + 'tp': { + 'layer_plan': layer_plan, + 'tensor_parallel_degree': tensor_parallel_degree, + }, + 'fsdp': {}, }, - 'fsdp': {}, - }, - max_duration='3ba', - ) + max_duration='3ba', + ) trainer.fit() From 71efd5876c5fff8381c2d960930f1105a45aae40 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 30 Sep 2024 08:12:52 -0700 Subject: [PATCH 07/17] Update huggingface-hub requirement from <0.25,>=0.21.2 to >=0.21.2,<0.26 (#3637) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 2f8d41de48..f88cb94071 100644 --- a/setup.py +++ b/setup.py @@ -182,7 +182,7 @@ def package_files(prefix: str, directory: str, extension: str): extra_deps['nlp'] = [ 'transformers>=4.11,!=4.34.0,<4.45', 'datasets>=2.4,<4', - 'huggingface-hub>=0.21.2,<0.25', + 'huggingface-hub>=0.21.2,<0.26', ] extra_deps['peft'] = [ From ab98cba78d5f7f65db441372594d411088f8a9a7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 30 Sep 2024 08:13:01 -0700 Subject: [PATCH 08/17] Update transformers requirement from !=4.34.0,<4.45,>=4.11 to >=4.11,!=4.34.0,<4.46 (#3638) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f88cb94071..39ddd8de38 100644 --- a/setup.py +++ b/setup.py @@ -180,7 +180,7 @@ def package_files(prefix: str, directory: str, extension: str): ] extra_deps['nlp'] = [ - 'transformers>=4.11,!=4.34.0,<4.45', + 'transformers>=4.11,!=4.34.0,<4.46', 'datasets>=2.4,<4', 'huggingface-hub>=0.21.2,<0.26', ] From e5e2f744baa9f5bec8ff24b58f5a6b2a564919d0 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 30 Sep 2024 08:13:13 -0700 Subject: [PATCH 09/17] Bump databricks-sdk from 0.32.0 to 0.33.0 (#3639) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 39ddd8de38..d676b781a6 100644 --- a/setup.py +++ b/setup.py @@ -225,13 +225,13 @@ def package_files(prefix: str, directory: str, extension: str): extra_deps['mlflow'] = [ 'mlflow>=2.14.1,<3.0', - 'databricks-sdk==0.32.0', + 'databricks-sdk==0.33.0', 'pynvml>=11.5.0,<12', ] extra_deps['pandas'] = ['pandas>=2.0.0,<3.0'] -extra_deps['databricks'] = ['databricks-sdk==0.32.0'] +extra_deps['databricks'] = ['databricks-sdk==0.33.0'] extra_deps['all'] = {dep for deps in extra_deps.values() for dep in deps} From 29632931896c3d0f05b396a55c70231c45163ad1 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 1 Oct 2024 13:39:54 -0400 Subject: [PATCH 10/17] Remove Legacy Checkpointing (#3631) --- composer/utils/checkpoint.py | 52 +++++---------------------- tests/trainer/test_fsdp_checkpoint.py | 27 +++----------- 2 files changed, 13 insertions(+), 66 deletions(-) diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index ccefee4e60..9f480059d3 100644 --- 
a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -413,24 +413,6 @@ def format(self, state: State, is_deepspeed: bool = False, keep_placeholders: bo ) + extra_suffix -def is_checkpoint_legacy_sharded(object_store: Optional[Union[LoggerDestination, ObjectStore]], source_path: str): - if source_path.endswith('.symlink') or os.path.islink(source_path): - source_path = extract_path_from_symlink(source_path, object_store=object_store) - metadata_path = str(Path(source_path) / Path('.metadata')) - log.debug(f'Checking if checkpoint is legacy sharded by checking for metadata file at {metadata_path}.') - if object_store is None: - return not os.path.exists(metadata_path) - else: - try: - _, _, metadata_path = parse_uri(metadata_path) - with tempfile.TemporaryDirectory() as temp_dir: - metadata_destination = os.path.join(str(temp_dir), '.metadata') - download_object_or_file(metadata_path, metadata_destination, object_store) - return False - except FileNotFoundError: - return True - - def load_checkpoint( path: str, state: State, @@ -533,16 +515,8 @@ def load_checkpoint( """ path = partial_format(path, run_name=state.run_name) log.debug(f'Loading checkpoint from formatted path: {path}') - using_legacy_sharded = False + if state.fsdp_sharded_state_dict_enabled: - assert object_store is None or isinstance( - object_store, - ObjectStore, - ), 'For loading sharded checkpoints load_object_store must be set with the class ObjectStore' - using_legacy_sharded = is_checkpoint_legacy_sharded(object_store, path) - log.info(f'Using legacy sharded checkpoint: {using_legacy_sharded}') - - if state.fsdp_sharded_state_dict_enabled and not using_legacy_sharded: rng_state_dicts = load_sharded_checkpoint( source_path=path, state=state, @@ -557,26 +531,20 @@ def load_checkpoint( ) else: # Download the checkpoint to the node-local folder - log.debug('Loading checkpoint at %s', path) # Each node gets one unique folder to store checkpoints that is shared amongst all local ranks in that node. # If fsdp sharded state_dicts is enabled then EVERY rank gets a unique checkpoint folder. - needs_unique_checkpoint_folder = state.fsdp_sharded_state_dict_enabled or dist.get_local_rank() == 0 - tempdir_ctx = tempfile.TemporaryDirectory() if needs_unique_checkpoint_folder else contextlib.nullcontext(None) + tempdir_ctx = tempfile.TemporaryDirectory() if dist.get_local_rank() == 0 else contextlib.nullcontext(None) with tempdir_ctx as tempdir: try: # Get the path to the proper checkpoint folder corresponding to the current rank's node. # If fsdp_sharded_state_dict_enabled then just use that rank's unique tempdir. - node_checkpoint_folder = ( - tempdir if state.fsdp_sharded_state_dict_enabled else _get_local_rank_zero_path(tempdir) - ) - assert node_checkpoint_folder is not None + node_checkpoint_folder = _get_local_rank_zero_path(tempdir) composer_states_filepath, extracted_checkpoint_folder, extracted_rank_n = download_checkpoint( path=path, node_checkpoint_folder=node_checkpoint_folder, object_store=object_store, progress_bar=progress_bar, - fsdp_sharded_state_dict_enabled=state.fsdp_sharded_state_dict_enabled, deepspeed_sharded_checkpoint=is_model_deepspeed(state.model), ) rng_state_dicts = _restore_checkpoint( @@ -596,6 +564,8 @@ def load_checkpoint( # be a shared resource between nodes. 
dist.barrier() log.info('%s loaded from %s', 'Model weights' if load_weights_only else 'Trainer checkpoint', path) + + # Verify all ranks resumed on same step step_to_resume_from = state.timestamp.batch.value max_step_to_resume_from = state.device.tensor_to_device( torch.tensor(state.timestamp.batch.value, dtype=torch.int64), @@ -802,7 +772,6 @@ def download_checkpoint( node_checkpoint_folder: str, object_store: Optional[Union[ObjectStore, LoggerDestination]], progress_bar: bool, - fsdp_sharded_state_dict_enabled: bool = False, deepspeed_sharded_checkpoint: bool = False, ) -> tuple[str, Optional[str], bool]: """Download the checkpoint stored at ``path``, potentially in ``object_store``, to ``node_checkpoint_folder``. @@ -829,9 +798,7 @@ def download_checkpoint( # and only rank zero has this file unless fsdp_sharded_state_dict_enabled then # every rank has it's own file. extracted_checkpoint_folder = None - composer_states_filepath = ( - rank_n_checkpoint_filepath if fsdp_sharded_state_dict_enabled else rank_zero_checkpoint_filepath - ) + composer_states_filepath = rank_zero_checkpoint_filepath if is_compressed_pt(path): original_path = path @@ -841,9 +808,8 @@ def download_checkpoint( with compressor.decompress(original_path) as in_file: shutil.copyfileobj(in_file, out_file) - checkpoint_is_sharded = fsdp_sharded_state_dict_enabled or deepspeed_sharded_checkpoint try: - if not checkpoint_is_sharded and dist.get_local_rank() == 0: + if not deepspeed_sharded_checkpoint and dist.get_local_rank() == 0: # If the checkpoint is not sharded, then local rank 0 on each node needs to download the # global rank 0 checkpoint path = _format_path_with_rank_zero(path) @@ -862,7 +828,7 @@ def download_checkpoint( # the underlying issue is that the checkpoint file does not exist on the disk # or could not be downloaded raise RuntimeError(f'Checkpoint {path} does not exist') - elif checkpoint_is_sharded: + elif deepspeed_sharded_checkpoint: # If the checkpoint is sharded, then every rank needs to download its own checkpoint path = _format_path_with_current_rank(path) try: @@ -892,7 +858,7 @@ def download_checkpoint( finally: # Use busy wait to avoid timeouts on large downloads for non-sharded checkpoints - if not checkpoint_is_sharded: + if not deepspeed_sharded_checkpoint: signal_file_path = os.path.join( node_checkpoint_folder, dist.get_node_signal_file_name(), diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py index f9b40e2e78..9f785a94ff 100644 --- a/tests/trainer/test_fsdp_checkpoint.py +++ b/tests/trainer/test_fsdp_checkpoint.py @@ -31,7 +31,7 @@ from composer.optim import DecoupledAdamW from composer.trainer import Trainer from composer.utils import FSDPConfig, TPConfig, dist, parse_uri -from composer.utils.checkpoint import dist_cp_load, is_checkpoint_legacy_sharded +from composer.utils.checkpoint import dist_cp_load from composer.utils.file_helpers import get_file from composer.utils.object_store import S3ObjectStore from composer.utils.reproducibility import get_rng_state @@ -539,7 +539,8 @@ def test_fsdp_load_old_checkpoint( pytest.skip('Current torch version is older than torch version that checkpoint was written with.') if composer_version in ['0.13.5', '0.14.0', '0.14.1', '0.15.1']: - rank = 0 if state_dict_type == 'full' else '{rank}' + if state_dict_type == 'sharded': + pytest.skip('Loading legacy sharded checkpoints are not supported after v0.25.0.') load_path_dir = ( f's3://{s3_bucket}/{s3_read_only_prefix}/backwards_compatibility/' @@ -549,11 +550,7 
@@ def test_fsdp_load_old_checkpoint( if ((version.parse(composer_version) > version.parse('0.15.0')) and state_dict_type != 'full'): load_path_dir = (load_path_dir + 'ep0-ba2/') - load_path = load_path_dir + f'ba2_rank{rank}.pt' - assert is_checkpoint_legacy_sharded( - object_store=S3ObjectStore(bucket=f'{s3_bucket}'), - source_path=load_path.lstrip(f's3://{s3_bucket}/'), - ) + load_path = load_path_dir + f'ba2_rank0.pt' else: load_path = ( f's3://{s3_bucket}/{s3_read_only_prefix}/backwards_compatibility/' @@ -911,16 +908,9 @@ def test_fsdp_partitioned_state_dict_load( load_path = 's3://' + save_folder.strip('s3://').format( run_name=run_name, ) + ('/ba2' if not use_symlink else '/latest-rank0.pt.symlink') - object_store = S3ObjectStore(bucket=f'{s3_bucket}') else: - object_store = None load_path = str(save_folder.format(run_name=run_name) / pathlib.Path('ba2')) - assert not is_checkpoint_legacy_sharded( - object_store=object_store, - source_path=load_path.replace(f's3://{s3_bucket}/', ''), - ) - if autoresume: load_path = None trainer2 = get_trainer( @@ -1015,10 +1005,6 @@ def test_elastic_resumption( else: save_folder = None sharded_load_path = os.path.join(base_path, 'ba2') - assert not is_checkpoint_legacy_sharded( - object_store=S3ObjectStore(bucket=f'{s3_bucket}'), - source_path=sharded_load_path.replace(f's3://{s3_bucket}/', ''), - ) sharded_trainer = get_trainer( save_folder=save_folder, @@ -1239,11 +1225,6 @@ def set_up_planner( load_path = str(save_folder.format(run_name=run_name) / pathlib.Path('ba2')) - assert not is_checkpoint_legacy_sharded( - object_store=None, - source_path=load_path, - ) - trainer2 = get_trainer( save_folder=str(save_folder), load_path=load_path, From 3eda9cf461cfde3ecedd8c913fa776cd34b23f3b Mon Sep 17 00:00:00 2001 From: Brian <23239305+b-chu@users.noreply.github.com> Date: Tue, 1 Oct 2024 14:42:07 -0400 Subject: [PATCH 11/17] Surface UC permission error (#3642) --- composer/utils/object_store/uc_object_store.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/composer/utils/object_store/uc_object_store.py b/composer/utils/object_store/uc_object_store.py index 138214e0e9..0bccb4a67e 100644 --- a/composer/utils/object_store/uc_object_store.py +++ b/composer/utils/object_store/uc_object_store.py @@ -29,10 +29,12 @@ def _wrap_errors(uri: str, e: Exception): # Wrap DatabricksError in ObjectStoreTransientError. # If the file is not found, raise FileNotFoundError. from databricks.sdk.errors import DatabricksError - from databricks.sdk.errors.platform import NotFound + from databricks.sdk.errors.platform import NotFound, PermissionDenied if isinstance(e, DatabricksError): if isinstance(e, NotFound) or e.error_code == _NOT_FOUND_ERROR_CODE: # type: ignore raise FileNotFoundError(f'Object {uri} not found') from e + if isinstance(e, PermissionDenied): + raise e raise ObjectStoreTransientError from e # Wrap ChunkedEncodingError in ObjectStoreTransientError. 
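The Unity Catalog change in PATCH 11/17 above (#3642) surfaces the Databricks SDK's PermissionDenied error directly instead of wrapping it in the retryable ObjectStoreTransientError, while NotFound still maps to FileNotFoundError. The following is a minimal sketch of how calling code might distinguish those outcomes after this change; the volume path and filenames are hypothetical placeholders, and UCObjectStore assumes DATABRICKS_HOST and DATABRICKS_TOKEN are configured in the environment.

from databricks.sdk.errors.platform import PermissionDenied

from composer.utils import ObjectStoreTransientError
from composer.utils.object_store import UCObjectStore

# Hypothetical UC volume path; UCObjectStore authenticates via DATABRICKS_HOST/DATABRICKS_TOKEN.
store = UCObjectStore(path='Volumes/main/my_schema/my_volume/checkpoints')

try:
    store.download_object(
        object_name='Volumes/main/my_schema/my_volume/checkpoints/latest-rank0.pt',
        filename='/tmp/latest-rank0.pt',
    )
except FileNotFoundError:
    print('Checkpoint object does not exist in the volume (NotFound from the SDK).')
except PermissionDenied:
    # Before this patch, permission failures were wrapped as ObjectStoreTransientError
    # and retried; now the caller sees the real problem immediately.
    print('Missing READ VOLUME permission on the UC volume.')
except ObjectStoreTransientError:
    print('Other Databricks error; treated as transient and safe to retry.')

Surfacing the permission failure avoids retry loops against an error that retrying can never fix.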
From f76c2fffa2f6bc0d1f81eb43c4b998ac4237b85f Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Wed, 2 Oct 2024 13:55:22 -0400 Subject: [PATCH 12/17] Tensor Parallelism Tests (#3620) Co-authored-by: Eitan Turok Co-authored-by: Mihir Patel --- tests/common/__init__.py | 6 + tests/common/datasets.py | 69 ++++++-- tests/common/models.py | 14 +- tests/trainer/test_tp.py | 330 +++++++++++++++++++++++++++++++++++++-- 4 files changed, 398 insertions(+), 21 deletions(-) diff --git a/tests/common/__init__.py b/tests/common/__init__.py index 05e4ae2629..d15dfa8afb 100644 --- a/tests/common/__init__.py +++ b/tests/common/__init__.py @@ -8,6 +8,7 @@ InfiniteClassificationDataset, ParityDataset, RandomClassificationDataset, + RandomClassificationDatasetReplicated, RandomImageDataset, RandomSegmentationDataset, RandomTextClassificationDataset, @@ -21,6 +22,7 @@ EmbeddedWeightTiedModel, EmptyModel, EvenSimplerMLP, + SimpleComposerMLP, SimpleConvModel, SimpleMLP, SimpleModel, @@ -28,6 +30,7 @@ SimpleTransformerClassifier, SimpleTransformerMaskedLM, SimpleWeightTiedModel, + TPSimpleComposerMLP, ZeroModel, composer_resnet, ) @@ -42,6 +45,7 @@ def get_module_subclasses(module: types.ModuleType, cls: type) -> list[type]: __all__ = [ 'assert_state_equivalent', 'RandomClassificationDataset', + 'RandomClassificationDatasetReplicated', 'RandomTextClassificationDataset', 'RandomTextLMDataset', 'RandomImageDataset', @@ -67,4 +71,6 @@ def get_module_subclasses(module: types.ModuleType, cls: type) -> list[type]: 'composer_resnet', 'SimpleMLP', 'EvenSimplerMLP', + 'SimpleComposerMLP', + 'TPSimpleComposerMLP', ] diff --git a/tests/common/datasets.py b/tests/common/datasets.py index eea4271543..69ee95b041 100644 --- a/tests/common/datasets.py +++ b/tests/common/datasets.py @@ -8,7 +8,7 @@ from torch.utils.data import DataLoader, Dataset, IterableDataset from torchvision.datasets import VisionDataset -from composer.utils import dist +from composer.utils import dist, reproducibility from tests.common.models import configure_tiny_bert_tokenizer, configure_tiny_gpt2_tokenizer @@ -55,12 +55,19 @@ class RandomClassificationDataset(Dataset): num_classes (int): number of classes (default: 2) """ - def __init__(self, shape: Sequence[int] = (1, 1, 1), size: int = 100, num_classes: int = 2): - self.size = size - self.shape = shape - self.num_classes = num_classes - self.x = None - self.y = None + def __init__( + self, + shape: Sequence[int] = (1, 1, 1), + size: int = 100, + num_classes: int = 2, + device: Optional[torch.device] = None, + ): + self.size: int = size + self.shape: Sequence[int] = shape + self.num_classes: int = num_classes + self.device: Optional[torch.device] = device + self.x: Optional[torch.Tensor] = None + self.y: Optional[torch.Tensor] = None def __len__(self): return self.size @@ -69,12 +76,56 @@ def __getitem__(self, index: int): # Note: lazily generate data so it runs after Composer seeds everything, giving the same # dataset across multiple calls when using the same seed. 
if self.x is None: - self.x = torch.randn(self.size, *self.shape) + self.x = torch.randn( + self.size, + *self.shape, + device=self.device, + ) if self.y is None: - self.y = torch.randint(0, self.num_classes, size=(self.size,)) + self.y = torch.randint(0, self.num_classes, size=(self.size,), device=self.device) return self.x[index], self.y[index] +class RandomClassificationDatasetReplicated(RandomClassificationDataset): + """Like RandomClassificationDataset but samples are replicated across tensor parallelism groups.""" + + def __init__( + self, + shape: Sequence[int] = (1, 1, 1), + size: int = 100, + num_classes: int = 2, + device: Optional[torch.device] = None, + seed: int = 44, + replication: Optional[int] = 2, + ): + super().__init__(shape, size, num_classes, device) + self.rank = dist.get_local_rank() + self.world_size = dist.get_world_size() + assert replication is not None + self.n_tp_groups = replication # the number of tp groups that we are replicating across + self.seed = seed + + def _generate_data(self): + tp_group_id = self.rank // self.n_tp_groups + seed = self.seed + tp_group_id # all ranks in the same TP group have the same seed + reproducibility.seed_all(seed) + self.x = torch.randn(self.size, *self.shape, device=self.device) + self.y = torch.randint(0, self.num_classes, size=(self.size,), device=self.device) + + def __len__(self): + return self.size + + def __getitem__(self, idx): + if self.x is None and self.y is None: + self._generate_data() + + assert self.x is not None + assert self.y is not None + + rank_idx = idx // self.world_size + return self.x[rank_idx], self.y[rank_idx] + + class RandomImageDataset(VisionDataset): """ Image Classification dataset with values drawn from a normal distribution Args: diff --git a/tests/common/models.py b/tests/common/models.py index 2310a03a82..00a6be34a4 100644 --- a/tests/common/models.py +++ b/tests/common/models.py @@ -135,13 +135,25 @@ def forward(self, x): # test ComposerModels instead of nn.Module. class SimpleComposerMLP(ComposerClassifier): - def __init__(self, num_features: int, device: str, num_classes: int = 3): + def __init__(self, num_features: int, device: Union[str, torch.device], num_classes: int = 3): fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=False) fc2 = torch.nn.Linear(num_features, num_classes, device=device, bias=False) + net = torch.nn.Sequential(fc1, torch.nn.ReLU(), fc2) + super().__init__(num_classes=num_classes, module=net) + +# Like SimpleComposerMLP but saves each layer which is necessary to TP to it. +class TPSimpleComposerMLP(ComposerClassifier): + + def __init__(self, num_features: int, device: Union[str, torch.device], num_classes: int = 3): + fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=False) + fc2 = torch.nn.Linear(num_features, num_classes, device=device, bias=False) net = torch.nn.Sequential(fc1, torch.nn.ReLU(), fc2) super().__init__(num_classes=num_classes, module=net) + self.fc1 = fc1 + self.fc2 = fc2 + class SimpleWeightTiedModel(ComposerClassifier): """Small classification model with tied weights. 
diff --git a/tests/trainer/test_tp.py b/tests/trainer/test_tp.py index b03d170a05..07f1ad1b37 100644 --- a/tests/trainer/test_tp.py +++ b/tests/trainer/test_tp.py @@ -2,22 +2,314 @@ # SPDX-License-Identifier: Apache-2.0 import contextlib +from typing import Any, Optional, Union +import numpy as np import pytest import torch from packaging import version -from torch.utils.data import DataLoader +from torch.distributed._tensor import Replicate +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.utils.data import DataLoader, Dataset +from composer.callbacks import MemoryMonitor +from composer.loggers import InMemoryLogger from composer.optim import DecoupledSGDW from composer.trainer.trainer import Trainer -from composer.utils import dist +from composer.utils import FSDPConfig, ParallelismConfig, TPConfig, dist, reproducibility from tests.common import ( RandomClassificationDataset, + RandomClassificationDatasetReplicated, SimpleModel, + TPSimpleComposerMLP, + deep_compare, world_size, ) +def get_base_trainer( + parallelism_config: Optional[ParallelismConfig] = None, + size: int = 4, + batch_size: int = 1, + num_classes: int = 2, + num_features: int = 2, + seed: int = 44, + device: Union[torch.device, str] = 'cuda', + replication: Optional[int] = None, +): + """Trainer for a simple model with any parallelism_config.""" + + reproducibility.seed_all(seed) + if isinstance(device, str): + device = torch.device(device) + + dataset: Dataset = RandomClassificationDatasetReplicated( + shape=(num_features,), + num_classes=num_classes, + size=size, + device=device, + replication=replication, + ) # X=(num_features,), y=(,), i.e. scalar + + dataloader = DataLoader( + dataset, + sampler=dist.get_sampler(dataset), + batch_size=batch_size, + ) # X=(batch_size, num_features), y=(batch_size,) + + model = TPSimpleComposerMLP(num_features=num_features, device=device, num_classes=num_classes) + + trainer = Trainer( + seed=seed, + device='gpu', + model=model, + max_duration='1ep', + train_dataloader=dataloader, + precision='fp32', + parallelism_config=parallelism_config, + callbacks=[MemoryMonitor()], + loggers=[InMemoryLogger()], + progress_bar=False, + log_to_console=False, + ) + + return trainer + + +def get_trainer( + parallelism_strategy: str, + size: int = 4, + batch_size: int = 1, + num_classes: int = 2, + num_features: int = 2, + seed: int = 44, + device: Union[torch.device, str] = 'cuda', + replication: Optional[int] = None, +) -> Trainer: + + if parallelism_strategy == 'ddp': + parallelism_config = None + elif parallelism_strategy == 'fsdp': + fsdp_config = FSDPConfig( + state_dict_type='full', + sharding_strategy='SHARD_GRAD_OP', + mixed_precision='full', + use_orig_params=True, + ) + parallelism_config = ParallelismConfig(fsdp=fsdp_config) + elif parallelism_strategy == 'tp-fsdp': + from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel + + fsdp_config = FSDPConfig( + state_dict_type='full', + sharding_strategy='SHARD_GRAD_OP', + mixed_precision='full', + use_orig_params=True, + ) + layer_plan = { + 'fc1': ColwiseParallel(), + 'fc2': RowwiseParallel(), + } + tp_config = TPConfig( + layer_plan=layer_plan, + tensor_parallel_degree=1 if replication is None else replication, + ) + parallelism_config = ParallelismConfig(fsdp=fsdp_config, tp=tp_config) + else: + raise ValueError( + f'`parallelism_strategy` must be one of `ddp`, `fsdp`, `tp-fsdp` but was {parallelism_strategy=}.', + ) + + trainer = get_base_trainer( + size=size, + batch_size=batch_size, + 
num_classes=num_classes, + num_features=num_features, + seed=seed, + device=device, + replication=replication, + parallelism_config=parallelism_config, + ) + + return trainer + + +def forward_pass(trainer): + reproducibility.seed_all(trainer.state.seed) + batch = next(iter(trainer.state.train_dataloader)) + output = trainer.state.model.forward(batch) + return output + + +def _replace_state_dict_name(state_dict: dict[str, Any], old_name: str, new_name: str) -> dict[str, Any]: + keys = list(state_dict.keys()) + for key in keys: + if old_name in key: + new_key = key.replace(old_name, new_name, 1) + state_dict[new_key] = state_dict.pop(key) + return state_dict + + +def compare_models( + ddp_trainer: Trainer, + fsdp_trainer: Trainer, + tp_fsdp_trainer: Trainer, + check_grad: bool = False, + atol: float = 0.0, + rtol: float = 0.0, +): + + # Normally, we compare various models by their state_dict(). + # However, calling `tp_fsdp_trainer.state.state_dict()` directly causes a NCCL timeout + # due to this pytorch bug: https://github.com/pytorch/pytorch/issues/134095/. + # As a workaround, we use `tp_fsdp_trainer.state.model.named_parameters()` instead. + # This issues only exists with `tp_fsdp_trainer.state.state_dict()` and does not + # arise when calling `ddp_trainer.state.state_dict()` or `fsdp_trainer.state.state_dict()`. + with FSDP.summon_full_params(fsdp_trainer.state.model, with_grads=check_grad): + with FSDP.summon_full_params(tp_fsdp_trainer.state.model, with_grads=check_grad): + ddp_params = dict(ddp_trainer.state.model.named_parameters()) + fsdp_params = dict(fsdp_trainer.state.model.named_parameters()) + tp_fsdp_params = dict(tp_fsdp_trainer.state.model.named_parameters()) + + # patch the state dict names: + # - ddp adds an extra 'module.' to all param names + # - fsdp adds an extra '_fsdp_wrapped_module.' to all param names + # - tp-fsdp adds an extra '_fsdp_wrapped_module.' 
to all param names + ddp_params = _replace_state_dict_name(ddp_params, 'module.', '') + fsdp_params = _replace_state_dict_name(fsdp_params, '_fsdp_wrapped_module.', '') + tp_fsdp_params = _replace_state_dict_name(tp_fsdp_params, '_fsdp_wrapped_module.', '') + + # check grad + if check_grad: + + def get_grads(params): + return {name: param.grad for name, param in params.items()} + + ddp_params = get_grads(ddp_params) + fsdp_params = get_grads(fsdp_params) + tp_fsdp_params = get_grads(tp_fsdp_params) + + # collect tensors from different ranks for comparison + tp_fsdp_params = { + name: param.redistribute(device_mesh=param.device_mesh, placements=[Replicate()]).to_local() + for name, param in tp_fsdp_params.items() + } + + deep_compare(ddp_params, fsdp_params, atol=atol, rtol=rtol) + deep_compare(tp_fsdp_params, fsdp_params, atol=atol, rtol=rtol) + deep_compare(ddp_params, fsdp_params, atol=atol, rtol=rtol) + + +def get_stats(trainer: Trainer) -> dict[str, np.ndarray]: + logger = trainer.logger.destinations[0] + stats = { + 'loss_array': + logger.get_timeseries('loss/train/total')['loss/train/total'], # type: ignore + 'accuracy_array': + logger.get_timeseries('metrics/train/MulticlassAccuracy') # type: ignore + ['metrics/train/MulticlassAccuracy'], + } + return stats + + +@pytest.mark.gpu +@world_size(4) +@pytest.mark.parametrize('replication', [2]) +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='Requires PyTorch 2.3+') +@pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning') +def test_tp_forwards_backwards_correctness(world_size: int, replication: int): + """Test that training with DDP, FSDP, TP-FSDP results in the same: + - initial weights + - forward pass + - gradients + - updated weights + after training for a single step via manually doing forward, backward pass. 
+ """ + + # Initialize trainers with DDP, FSDP, TP-FSDP + ddp_trainer = get_trainer('ddp', replication=replication) + fsdp_trainer = get_trainer('fsdp', replication=replication) + tp_fsdp_trainer = get_trainer('tp-fsdp', replication=replication) + + # Ensure initial model weights are the same + compare_models(ddp_trainer, fsdp_trainer, tp_fsdp_trainer) + + # Forward pass + ddp_out = forward_pass(ddp_trainer) + fsdp_out = forward_pass(fsdp_trainer) + tp_fsdp_out = forward_pass(tp_fsdp_trainer) + + # Ensure output of the forward pass is the same + deep_compare(ddp_out, fsdp_out) + deep_compare(ddp_out, tp_fsdp_out) + deep_compare(fsdp_out, tp_fsdp_out) + + # Compute gradients + torch.sum(ddp_out).backward() + torch.sum(fsdp_out).backward() + torch.sum(tp_fsdp_out).backward() + + # Ensure the model gradients are the same + compare_models(ddp_trainer, fsdp_trainer, tp_fsdp_trainer, check_grad=True) + + # Update the model weights + ddp_trainer.state.optimizers[0].step() + fsdp_trainer.state.optimizers[0].step() + tp_fsdp_trainer.state.optimizers[0].step() + + # Ensure the updated model weights are the same + compare_models(ddp_trainer, fsdp_trainer, tp_fsdp_trainer) + + +@pytest.mark.gpu +@world_size(4) +@pytest.mark.parametrize('replication', [2]) +@pytest.mark.parametrize('batch_size', [1, 4]) +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='Requires PyTorch 2.3+') +@pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning') +def test_tp_fit_correctness(world_size: int, batch_size: int, replication: int): + """Test that training with DDP, FSDP, TP-FSDP results in the same: + - updated weights + - loss + - accuracy + after training for multiple steps via trainer.fit(). + """ + + # Initialize number of samples in the dataset + # train_steps = 20 # number of steps to train for + train_steps = 20 + samples_per_batch = world_size * batch_size // replication + dataset_size = samples_per_batch * train_steps + + # DDP fit + ddp_trainer = get_trainer('ddp', size=dataset_size, batch_size=batch_size, replication=replication) + ddp_trainer.fit() + ddp_trainer.close() + ddp_stats = get_stats(ddp_trainer) + + # FSDP fit + fsdp_trainer = get_trainer('fsdp', size=dataset_size, batch_size=batch_size, replication=replication) + fsdp_trainer.fit() + fsdp_trainer.close() + fsdp_stats = get_stats(fsdp_trainer) + + # TP-FSDP fit + tp_fsdp_trainer = get_trainer('tp-fsdp', size=dataset_size, batch_size=batch_size, replication=replication) + tp_fsdp_trainer.fit() + tp_fsdp_trainer.close() + tp_fsdp_stats = get_stats(tp_fsdp_trainer) + + # Ensure the updated models weights are the same + # Drop tolerance due to precision issues across different parallelism strategies + compare_models(ddp_trainer, fsdp_trainer, tp_fsdp_trainer, atol=1e-5, rtol=1e-3) + + # Compare the loss, accuracy stats + # Drop tolerance due to precision issues across different parallelism strategies + deep_compare(ddp_stats, fsdp_stats, atol=6e-5) + deep_compare(tp_fsdp_stats, fsdp_stats, atol=6e-5) + deep_compare(ddp_stats, tp_fsdp_stats, atol=6e-5) + + @pytest.mark.gpu @world_size(4) @pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning') @@ -26,6 +318,9 @@ def test_tp_train(world_size: int, tensor_parallel_degree: int): from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel + # For TP to produce the correct result, each TP rank receives the same data + # In this test, TP ranks receive different data as we are testing the TP + # mechanism, not 
actual TP correctness. model = SimpleModel() optimizer = DecoupledSGDW(model.parameters(), lr=0.1) dataset = RandomClassificationDataset(size=8) @@ -67,6 +362,9 @@ def test_tp_train(world_size: int, tensor_parallel_degree: int): def test_tp_with_param_groups(world_size: int): from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel + # For TP to produce the correct result, each TP rank receives the same data + # In this test, TP ranks receive different data as we are testing the TP + # mechanism, not actual TP correctness. model = SimpleModel() dataset = RandomClassificationDataset(size=8) dataloader = DataLoader(dataset, batch_size=2, sampler=dist.get_sampler(dataset)) @@ -91,23 +389,23 @@ def test_tp_with_param_groups(world_size: int): optimizers=optimizer, train_dataloader=dataloader, parallelism_config={ - 'tp': { - 'layer_plan': layer_plan, - 'tensor_parallel_degree': 2, - }, + 'tp': TPConfig(layer_plan=layer_plan, tensor_parallel_degree=2), 'fsdp': {}, }, max_duration='3ba', ) -@pytest.mark.gpu @world_size(4) +@pytest.mark.gpu @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='requires PyTorch 2.3+') @pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning') def test_tp_with_subset_of_params(world_size: int): from torch.distributed.tensor.parallel import ColwiseParallel + # For TP to produce the correct result, each TP rank receives the same data + # In this test, TP ranks receive different data as we are testing the TP + # mechanism, not actual TP correctness. model = SimpleModel() dataset = RandomClassificationDataset(size=8) dataloader = DataLoader(dataset, batch_size=2, sampler=dist.get_sampler(dataset)) @@ -125,11 +423,21 @@ def test_tp_with_subset_of_params(world_size: int): optimizers=optimizer, train_dataloader=dataloader, parallelism_config={ - 'tp': { - 'layer_plan': layer_plan, - 'tensor_parallel_degree': 2, - }, + 'tp': TPConfig(layer_plan=layer_plan, tensor_parallel_degree=2), 'fsdp': {}, }, max_duration='3ba', ) + + +@world_size(4) +@pytest.mark.gpu +@pytest.mark.skip('This is broken due to https://github.com/pytorch/pytorch/issues/134095/.') +@pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning') +def test_tp_fsdp_state_dict(world_size: int): + tp_fsdp_trainer = get_trainer('tp_fsdp', replication=2) + tp_fsdp_state_dict1 = tp_fsdp_trainer.state.state_dict() # work sometimes, fails sometimes + with FSDP.summon_full_params(tp_fsdp_trainer.state.model, with_grads=True): + tp_fsdp_state_dict2 = tp_fsdp_trainer.state.state_dict() # fails always + + deep_compare(tp_fsdp_state_dict1['model'], tp_fsdp_state_dict2['model']) From cf3844fc918fd88d82103c89294b22195c1e0aba Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Thu, 3 Oct 2024 22:52:10 -0400 Subject: [PATCH 13/17] Switch to log.info for deterministic mode (#3643) Co-authored-by: Saaketh Narayan --- composer/utils/reproducibility.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/utils/reproducibility.py b/composer/utils/reproducibility.py index 7eb0037475..4158b2ceef 100644 --- a/composer/utils/reproducibility.py +++ b/composer/utils/reproducibility.py @@ -127,7 +127,7 @@ def configure_deterministic_mode(): # See https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html # and https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' - warnings.warn('Deterministic mode is activated. 
This will negatively impact performance.', category=UserWarning)
+    log.info('Deterministic mode is activated. This will negatively impact performance.')


 def get_random_seed() -> int:

From b89d699a40614a2ffa344d84ebc597e1332d4c15 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Mon, 7 Oct 2024 07:20:50 -0700
Subject: [PATCH 14/17] Update pre-commit requirement from <4,>=3.4.0 to >=3.4.0,<5 (#3645)

Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index d676b781a6..3ebcd94851 100644
--- a/setup.py
+++ b/setup.py
@@ -112,7 +112,7 @@ def package_files(prefix: str, directory: str, extension: str):
     'yamllint==1.35.1',
     'recommonmark==0.7.1',
     'sphinx==4.4.0',
-    'pre-commit>=3.4.0,<4',
+    'pre-commit>=3.4.0,<5',
     # embedding md in rst require docutils>=0.17. See
     # https://myst-parser.readthedocs.io/en/latest/sphinx/use.html?highlight=parser#include-markdown-files-into-an-rst-file
     'docutils==0.17.1',

From 4e1ec1798845cadace5ee431b9a1b7b6208262de Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Mon, 7 Oct 2024 07:36:45 -0700
Subject: [PATCH 15/17] Update peft requirement from <0.13,>=0.10.0 to >=0.10.0,<0.14 (#3646)

Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 3ebcd94851..7ba4a2d569 100644
--- a/setup.py
+++ b/setup.py
@@ -186,7 +186,7 @@ def package_files(prefix: str, directory: str, extension: str):
 ]

 extra_deps['peft'] = [
-    'peft>=0.10.0,<0.13',
+    'peft>=0.10.0,<0.14',
 ]

 extra_deps['sentencepiece'] = [

From bb7ea43d3b9e57162f7cb820f10539239ba9caed Mon Sep 17 00:00:00 2001
From: Irene Dea
Date: Mon, 7 Oct 2024 15:11:03 -0700
Subject: [PATCH 16/17] Create callback to load checkpoint (#3641)

---
 composer/callbacks/__init__.py          |  2 +
 composer/callbacks/load_checkpoint.py   | 76 +++++++++++++++++++
 pyproject.toml                          |  4 +-
 tests/callbacks/callback_settings.py    | 14 ++++
 tests/callbacks/test_callbacks.py       | 23 ++++--
 tests/callbacks/test_load_checkpoint.py | 68 +++++++++++++++++
 .../test_loggers_across_callbacks.py    | 28 ++++---
 tests/fixtures/fixtures.py              | 18 +++++
 tests/loggers/test_mosaicml_logger.py   | 27 ++++---
 tests/loggers/test_wandb_logger.py      | 28 ++++---
 tests/models/test_hf_model.py           | 19 +----
 11 files changed, 253 insertions(+), 54 deletions(-)
 create mode 100644 composer/callbacks/load_checkpoint.py
 create mode 100644 tests/callbacks/test_load_checkpoint.py

diff --git a/composer/callbacks/__init__.py b/composer/callbacks/__init__.py
index 16a50a31a9..30b7053ae1 100644
--- a/composer/callbacks/__init__.py
+++ b/composer/callbacks/__init__.py
@@ -13,6 +13,7 @@
 from composer.callbacks.free_outputs import FreeOutputs
 from composer.callbacks.generate import Generate
 from composer.callbacks.image_visualizer import ImageVisualizer
+from composer.callbacks.load_checkpoint import LoadCheckpoint
 from composer.callbacks.lr_monitor import LRMonitor
 from composer.callbacks.memory_monitor import MemoryMonitor
 from composer.callbacks.memory_snapshot import MemorySnapshot
@@ -44,4 +45,5 @@
     'FreeOutputs',
     'MemorySnapshot',
     'OOMObserver',
+    'LoadCheckpoint',
 ]
diff --git a/composer/callbacks/load_checkpoint.py b/composer/callbacks/load_checkpoint.py
new file mode 100644
index 0000000000..148746a19e
--- /dev/null
+++ b/composer/callbacks/load_checkpoint.py
@@ -0,0 +1,76 @@
+# Copyright 2024 MosaicML Composer authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Load a checkpoint."""
+import logging
+from typing import Optional, Union
+
+from composer.core import Callback, State
+from composer.core.event import Event
+from composer.loggers import Logger
+from composer.models.huggingface import HuggingFaceModel
+from composer.utils.checkpoint import load_checkpoint
+from composer.utils.file_helpers import maybe_create_object_store_from_uri, parse_uri
+
+log = logging.getLogger(__name__)
+
+
+class LoadCheckpoint(Callback):
+    """Callback that loads a checkpoint at the specified event.
+
+    Args:
+        load_path (str): The path to the checkpoint to load.
+        load_options (Optional[dict]): A dictionary of options to pass to the checkpoint loading function.
+        event (Union[str, Event]): The event at which to load the checkpoint. Defaults to ``Event.BEFORE_LOAD``.
+    """
+
+    def __init__(
+        self,
+        load_path: str,
+        load_weights_only: bool = False,
+        strict_model_weights: bool = True,
+        ignore_keys: Optional[list[str]] = None,
+        event: Union[str, Event] = Event.BEFORE_LOAD,
+    ):
+        super().__init__()
+        self.load_path = load_path
+        self.load_object_store = maybe_create_object_store_from_uri(load_path)
+        _, _, self.parsed_path = parse_uri(load_path)
+
+        self.load_weights_only = load_weights_only
+        self.strict_model_weights = strict_model_weights
+        self.ignore_keys = ignore_keys
+
+        self.event = event if isinstance(event, Event) else Event[event.upper()]
+
+    def run_event(self, event: Event, state: State, logger: Logger) -> None:
+        if event == self.event:
+            log.info(f'Loading checkpoint from {self.load_path} at {self.event}.')
+            self._load(state, logger)
+            log.info(f'Finished loading checkpoint from {self.load_path} at {self.event}.')
+
+        return super().run_event(event, state, logger)
+
+    def _load(self, state: State, logger: Logger) -> None:
+
+        # We need to temporarily disable the `should_save_peft_only` flag on the model
+        # so that we can have access to the full model weights for loading.
+ model = state.model + original_should_save_peft_only = False + if isinstance(model, HuggingFaceModel): + original_should_save_peft_only = model.should_save_peft_only + model.should_save_peft_only = False + + load_checkpoint( + path=self.parsed_path, + state=state, + logger=logger, + object_store=self.load_object_store, + strict_model_weights=self.strict_model_weights, + ignore_keys=self.ignore_keys, + load_weights_only=self.load_weights_only, + ) + + # Restore the original `should_save_peft_only` flag on the model + if isinstance(model, HuggingFaceModel): + model.should_save_peft_only = original_should_save_peft_only diff --git a/pyproject.toml b/pyproject.toml index 1c3b82a699..76f6986c82 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -168,7 +168,9 @@ filterwarnings = [ '''ignore:'.*_state_dict' is deprecated and will be removed in future versions.*:UserWarning''', # Ignore mlflow warnings about transformers versions, '''ignore:The 'transformers' MLflow Models integration.*:UserWarning''', - # Ignore our own deprecation warnings, + # Ignore the flash v3 warnings from transformer engine + '''ignore:To use flash-attn v3*:UserWarning''', + # Ignore our own deprecation warnings '''ignore::composer.utils.warnings.VersionedDeprecationWarning''', # Ignore deprecation warning for torch.load '''ignore:You are using `torch.load` with `weights_only=False`.*:FutureWarning''', diff --git a/tests/callbacks/callback_settings.py b/tests/callbacks/callback_settings.py index 4b3ef478a2..18be26a64d 100644 --- a/tests/callbacks/callback_settings.py +++ b/tests/callbacks/callback_settings.py @@ -1,8 +1,10 @@ # Copyright 2022 MosaicML Composer authors # SPDX-License-Identifier: Apache-2.0 +import contextlib import os from typing import Any +from unittest import mock from unittest.mock import MagicMock import pytest @@ -26,6 +28,7 @@ SystemMetricsMonitor, ThresholdStopper, ) +from composer.callbacks.load_checkpoint import LoadCheckpoint from composer.loggers import ( CometMLLogger, ConsoleLogger, @@ -155,6 +158,9 @@ 'trace_handlers': [MagicMock()], 'schedule': composer.profiler.cyclic_schedule(), }, + LoadCheckpoint: { + 'load_path': 'fake-path', + }, } _callback_marks: dict[ @@ -201,6 +207,14 @@ NeptuneLogger: [pytest.mark.skipif(not _NEPTUNE_INSTALLED, reason='neptune is optional')], } +_callback_patches: dict[type[Callback], Any] = { + LoadCheckpoint: mock.patch('composer.callbacks.load_checkpoint.load_checkpoint'), +} + + +def get_cb_patches(impl: type[Callback]): + return _callback_patches.get(impl, contextlib.nullcontext()) + def get_cb_kwargs(impl: type[Callback]): return _callback_kwargs.get(impl, {}) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 6df2637e69..7a76fcd6b5 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -10,7 +10,12 @@ from composer.loggers import Logger, LoggerDestination from composer.profiler import Profiler, ProfilerAction from composer.trainer import Trainer -from tests.callbacks.callback_settings import get_cb_kwargs, get_cb_model_and_datasets, get_cbs_and_marks +from tests.callbacks.callback_settings import ( + get_cb_kwargs, + get_cb_model_and_datasets, + get_cb_patches, + get_cbs_and_marks, +) from tests.common import EventCounterCallback @@ -154,8 +159,12 @@ def test_trains(self, cb_cls: type[Callback], device_train_microbatch_size: int, del _remote # unused. `_remote` must be passed through to parameterize the test markers. 
         cb_kwargs = get_cb_kwargs(cb_cls)
         cb = cb_cls(**cb_kwargs)
-        trainer = self._get_trainer(cb, device_train_microbatch_size)
-        trainer.fit()
+
+        maybe_patch_context = get_cb_patches(cb_cls)
+
+        with maybe_patch_context:
+            trainer = self._get_trainer(cb, device_train_microbatch_size)
+            trainer.fit()
 
     @pytest.mark.filterwarnings('ignore::UserWarning')
     def test_trains_multiple_calls(self, cb_cls: type[Callback], device_train_microbatch_size: int, _remote: bool):
@@ -167,8 +176,12 @@ def test_trains_multiple_calls(self, cb_cls: type[Callback], device_train_microb
         del _remote  # unused. `_remote` must be passed through to parameterize the test markers.
         cb_kwargs = get_cb_kwargs(cb_cls)
         cb = cb_cls(**cb_kwargs)
-        trainer = self._get_trainer(cb, device_train_microbatch_size)
-        trainer.fit()
+
+        maybe_patch_context = get_cb_patches(cb_cls)
+
+        with maybe_patch_context:
+            trainer = self._get_trainer(cb, device_train_microbatch_size)
+            trainer.fit()
 
         assert trainer.state.max_duration is not None
         trainer.state.max_duration = cast(Time[int], trainer.state.max_duration * 2)
diff --git a/tests/callbacks/test_load_checkpoint.py b/tests/callbacks/test_load_checkpoint.py
new file mode 100644
index 0000000000..e93840a7e5
--- /dev/null
+++ b/tests/callbacks/test_load_checkpoint.py
@@ -0,0 +1,68 @@
+# Copyright 2024 MosaicML Composer authors
+# SPDX-License-Identifier: Apache-2.0
+
+from unittest import mock
+from unittest.mock import call
+
+from torch.utils.data import DataLoader
+
+from composer.callbacks import LoadCheckpoint
+from composer.core.state import State
+from composer.models.huggingface import HuggingFaceModel
+from composer.trainer.trainer import Trainer
+from tests.common.datasets import RandomTextLMDataset
+
+
+def test_load_checkpoint_callback(
+    tiny_gpt2_model,
+    tiny_gpt2_tokenizer,
+    gpt2_peft_config,
+):
+
+    model = HuggingFaceModel(
+        tiny_gpt2_model,
+        tokenizer=tiny_gpt2_tokenizer,
+        peft_config=gpt2_peft_config,
+        should_save_peft_only=True,
+    )
+
+    # Function to check the arguments passed to the load_checkpoint function.
+    def check_callback_load_args(state: State, **kwargs):
+        assert state.model == model
+
+        # Check that the `should_save_peft_only` flag on the model was set to False when loading the checkpoint.
+        assert state.model.should_save_peft_only == False
+
+    # Patch the load_checkpoint function to check the arguments passed to it.
+    with mock.patch(
+        'composer.callbacks.load_checkpoint.load_checkpoint',
+        new=mock.MagicMock(wraps=check_callback_load_args),
+    ) as callback_load:
+        with mock.patch('composer.trainer.trainer.checkpoint.load_checkpoint') as trainer_load:
+
+            calls = mock.MagicMock()
+            calls.attach_mock(trainer_load, 'trainer_load')
+            calls.attach_mock(callback_load, 'callback_load')
+
+            Trainer(
+                model=model,
+                callbacks=[LoadCheckpoint(
+                    load_path='fake-path',
+                    event='BEFORE_LOAD',
+                )],
+                train_dataloader=DataLoader(RandomTextLMDataset()),
+                max_duration='1ba',
+                load_path='fake_path',
+            )
+
+            callback_load.assert_called_once()
+            trainer_load.assert_called_once()
+
+            # Assert that the callback_load and trainer_load functions were called in the correct order.
+            assert calls.mock_calls == [
+                call.callback_load(**callback_load.call_args.kwargs),
+                call.trainer_load(**trainer_load.call_args.kwargs),
+            ]
+
+            # Check that the `should_save_peft_only` flag on the model was reset to its original value after loading the checkpoint.
+            assert model.should_save_peft_only == True
diff --git a/tests/callbacks/test_loggers_across_callbacks.py b/tests/callbacks/test_loggers_across_callbacks.py
index 5286d4f207..17886874b1 100644
--- a/tests/callbacks/test_loggers_across_callbacks.py
+++ b/tests/callbacks/test_loggers_across_callbacks.py
@@ -9,7 +9,12 @@
 from composer.loggers import ConsoleLogger, LoggerDestination, ProgressBarLogger, SlackLogger
 from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
 from composer.trainer import Trainer
-from tests.callbacks.callback_settings import get_cb_kwargs, get_cb_model_and_datasets, get_cbs_and_marks
+from tests.callbacks.callback_settings import (
+    get_cb_kwargs,
+    get_cb_model_and_datasets,
+    get_cb_patches,
+    get_cbs_and_marks,
+)
 
 
 @pytest.mark.parametrize('logger_cls', get_cbs_and_marks(loggers=True))
@@ -27,12 +32,15 @@ def test_loggers_on_callbacks(logger_cls: type[LoggerDestination], callback_cls:
     callback_kwargs = get_cb_kwargs(callback_cls)
     callback = callback_cls(**callback_kwargs)
     model, train_dataloader, _ = get_cb_model_and_datasets(callback)
-    trainer = Trainer(
-        model=model,
-        train_dataloader=train_dataloader,
-        train_subset_num_batches=2,
-        max_duration='1ep',
-        callbacks=callback,
-        loggers=logger,
-    )
-    trainer.fit()
+    maybe_patch_context = get_cb_patches(callback_cls)
+
+    with maybe_patch_context:
+        trainer = Trainer(
+            model=model,
+            train_dataloader=train_dataloader,
+            train_subset_num_batches=2,
+            max_duration='1ep',
+            callbacks=callback,
+            loggers=logger,
+        )
+        trainer.fit()
diff --git a/tests/fixtures/fixtures.py b/tests/fixtures/fixtures.py
index c4dd3fa65f..b6ae26c996 100644
--- a/tests/fixtures/fixtures.py
+++ b/tests/fixtures/fixtures.py
@@ -424,6 +424,24 @@ def tiny_gpt2_model(_session_tiny_gpt2_model):
     return copy.deepcopy(_session_tiny_gpt2_model)
 
 
+def _gpt2_peft_config():
+    pytest.importorskip('peft')
+    from peft import get_peft_config
+
+    peft_config = get_peft_config({
+        'peft_type': 'LORA',
+        'task_type': 'CAUSAL_LM',
+        'target_modules': ['c_attn'],
+        'fan_in_fan_out': True,
+    })
+    return peft_config
+
+
+@pytest.fixture
+def gpt2_peft_config():
+    return _gpt2_peft_config()
+
+
 @pytest.fixture
 def tiny_opt_config(_session_tiny_opt_config):
     return copy.deepcopy(_session_tiny_opt_config)
diff --git a/tests/loggers/test_mosaicml_logger.py b/tests/loggers/test_mosaicml_logger.py
index f04d4acf17..53311b3334 100644
--- a/tests/loggers/test_mosaicml_logger.py
+++ b/tests/loggers/test_mosaicml_logger.py
@@ -21,7 +21,12 @@
 )
 from composer.trainer import Trainer
 from composer.utils import dist, get_composer_env_dict
-from tests.callbacks.callback_settings import get_cb_kwargs, get_cb_model_and_datasets, get_cbs_and_marks
+from tests.callbacks.callback_settings import (
+    get_cb_kwargs,
+    get_cb_model_and_datasets,
+    get_cb_patches,
+    get_cbs_and_marks,
+)
 from tests.common import RandomClassificationDataset, SimpleModel
 from tests.common.markers import world_size
 
@@ -121,15 +126,17 @@ def test_logged_data_is_json_serializable(monkeypatch, callback_cls: type[Callba
     callback = callback_cls(**callback_kwargs)
     train_dataset = RandomClassificationDataset()
     model, train_dataloader, _ = get_cb_model_and_datasets(callback, sampler=dist.get_sampler(train_dataset))
-    trainer = Trainer(
-        model=model,
-        train_dataloader=train_dataloader,
-        train_subset_num_batches=1,
-        max_duration='1ep',
-        callbacks=callback,
-        loggers=MosaicMLLogger(),
-    )
-    trainer.fit()
+    maybe_patch_context = get_cb_patches(callback_cls)
+    with maybe_patch_context:
+        trainer = Trainer(
+            model=model,
+            train_dataloader=train_dataloader,
+            train_subset_num_batches=1,
+            max_duration='1ep',
+            callbacks=callback,
+            loggers=MosaicMLLogger(),
+        )
+        trainer.fit()
 
     if dist.get_global_rank() == 0:
         assert len(mock_mapi.run_metadata[run_name].keys()) > 0
diff --git a/tests/loggers/test_wandb_logger.py b/tests/loggers/test_wandb_logger.py
index b0462fc842..79186b661b 100644
--- a/tests/loggers/test_wandb_logger.py
+++ b/tests/loggers/test_wandb_logger.py
@@ -21,7 +21,12 @@
 from composer.loggers import InMemoryLogger, Logger, WandBLogger
 from composer.trainer import Trainer
 from composer.utils import dist
-from tests.callbacks.callback_settings import get_cb_kwargs, get_cb_model_and_datasets, get_cbs_and_marks
+from tests.callbacks.callback_settings import (
+    get_cb_kwargs,
+    get_cb_model_and_datasets,
+    get_cb_patches,
+    get_cbs_and_marks,
+)
 from tests.common.datasets import RandomImageDataset
 from tests.common.models import SimpleConvModel
 
@@ -290,15 +295,18 @@ def test_logged_data_is_json_serializable(callback_cls: type[Callback]):
     callback = callback_cls(**callback_kwargs)
     logger = InMemoryLogger()  # using an in memory logger to manually validate json serializability
     model, train_dataloader, _ = get_cb_model_and_datasets(callback)
-    trainer = Trainer(
-        model=model,
-        train_dataloader=train_dataloader,
-        train_subset_num_batches=2,
-        max_duration='1ep',
-        callbacks=callback,
-        loggers=logger,
-    )
-    trainer.fit()
+    maybe_patch_context = get_cb_patches(callback_cls)
+
+    with maybe_patch_context:
+        trainer = Trainer(
+            model=model,
+            train_dataloader=train_dataloader,
+            train_subset_num_batches=2,
+            max_duration='1ep',
+            callbacks=callback,
+            loggers=logger,
+        )
+        trainer.fit()
 
     for log_calls in logger.data.values():
         for _, data in log_calls:
diff --git a/tests/models/test_hf_model.py b/tests/models/test_hf_model.py
index 89a892f452..27873d573c 100644
--- a/tests/models/test_hf_model.py
+++ b/tests/models/test_hf_model.py
@@ -35,30 +35,13 @@
     configure_tiny_t5_model,
     configure_tiny_t5_tokenizer,
 )
+from tests.fixtures.fixtures import _gpt2_peft_config
 from tests.loggers.test_remote_uploader_downloader import DummyObjectStore
 
 if TYPE_CHECKING:
     from peft import PeftConfig
 
 
-def _gpt2_peft_config():
-    pytest.importorskip('peft')
-    from peft import get_peft_config
-
-    peft_config = get_peft_config({
-        'peft_type': 'LORA',
-        'task_type': 'CAUSAL_LM',
-        'target_modules': ['c_attn'],
-        'fan_in_fan_out': True,
-    })
-    return peft_config
-
-
-@pytest.fixture
-def gpt2_peft_config():
-    return _gpt2_peft_config()
-
-
 def _mpt_peft_config():
     pytest.importorskip('peft')
     from peft import get_peft_config
From 5538d2c31bdf00ff141e840b1a83e36ac3321d8c Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Tue, 8 Oct 2024 09:14:01 -0400
Subject: [PATCH 17/17] Bump jupyter from 1.0.0 to 1.1.1 (#3595)

Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Mihir Patel
---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 7ba4a2d569..789f3e9720 100644
--- a/setup.py
+++ b/setup.py
@@ -108,7 +108,7 @@ def package_files(prefix: str, directory: str, extension: str):
     'pytest==7.4.4',
     'ipython==8.11.0',
     'ipykernel==6.29.5',
-    'jupyter==1.0.0',
+    'jupyter==1.1.1',
     'yamllint==1.35.1',
     'recommonmark==0.7.1',
     'sphinx==4.4.0',