From 693ef1028f764c97afad427dbc28819995535fd8 Mon Sep 17 00:00:00 2001 From: Anna Date: Mon, 11 Dec 2023 10:06:31 -0800 Subject: [PATCH 1/2] Remove from mcli.sdk imports (#793) --- .github/mcp/mcp_pytest.py | 12 +++++++----- scripts/train/benchmarking/submit_benchmarks.py | 12 +++++++----- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/.github/mcp/mcp_pytest.py b/.github/mcp/mcp_pytest.py index 7571200d8d..e7d4270272 100644 --- a/.github/mcp/mcp_pytest.py +++ b/.github/mcp/mcp_pytest.py @@ -6,8 +6,8 @@ import argparse import time -from mcli.sdk import (RunConfig, RunStatus, create_run, follow_run_logs, - wait_for_run_status) +from mcli import (RunConfig, RunStatus, create_run, follow_run_logs, + wait_for_run_status) if __name__ == '__main__': @@ -107,9 +107,11 @@ config = RunConfig( name=name, - cluster=args.cluster, - gpu_type=args.gpu_type, - gpu_num=args.gpu_num, + compute={ + 'cluster': args.cluster, + 'gpu_type': args.gpu_type, + 'gpus': args.gpu_num + }, image=args.image, integrations=[git_integration], command=command, diff --git a/scripts/train/benchmarking/submit_benchmarks.py b/scripts/train/benchmarking/submit_benchmarks.py index b8fc1eab96..bfff10165a 100644 --- a/scripts/train/benchmarking/submit_benchmarks.py +++ b/scripts/train/benchmarking/submit_benchmarks.py @@ -7,8 +7,8 @@ import requests import yaml -from mcli.models.run_config import SchedulingConfig -from mcli.sdk import RunConfig, create_run, get_clusters + +from mcli import RunConfig, SchedulingConfig, create_run, get_clusters def _get_cluster_info(): @@ -470,9 +470,11 @@ def run_config(config: Tuple[str, int, int, str, str, int, str], parameters['model']['fc_type'] = 'te' # Create run config mcli sdk/api config = RunConfig(name=name, - gpu_type=gpu_type, - gpu_num=gpu_num, - cluster=cluster, + compute={ + 'cluster': cluster, + 'gpu_type': gpu_type, + 'gpus': gpu_num + }, image=args.image, integrations=integrations, command=command, From 34ec2f7399cd3eb34ce426ae0cfa4f3fd715d8d8 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Mon, 11 Dec 2023 10:50:00 -0800 Subject: [PATCH 2/2] Auto packing fixes (#783) * fix * move to auto only * Fixes + tests * code quality * remove torch 2 --- llmfoundry/data/packing.py | 37 +++++++++++++++++++++++------- scripts/misc/profile_packing.py | 2 +- tests/data/test_packing.py | 40 +++++++++++++++++++++++++++++++++ 3 files changed, 70 insertions(+), 9 deletions(-) diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py index 45322c9b2f..d3084c72c8 100644 --- a/llmfoundry/data/packing.py +++ b/llmfoundry/data/packing.py @@ -1,6 +1,9 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import logging +import os +import tempfile from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple import numpy as np @@ -8,6 +11,8 @@ from omegaconf import DictConfig from transformers import PreTrainedTokenizerBase +log = logging.getLogger(__name__) + class BinPackCollator: """Utility collator for packing to reduce padding.""" @@ -289,8 +294,13 @@ def auto_packing_ratio(dataloader_cfg: DictConfig, # Set the seed so that auto packing is deterministic. reproducibility.seed_all(0) + max_seq_len = dataloader_cfg.dataset.max_seq_len + # If max_seq_len is very small, skip profiling and select packing ratio of 1. 
+ if max_seq_len <= 100: + return 1 + min_ratio = 1 - max_ratio = dataloader_cfg.dataset.max_seq_len / 100 + max_ratio = max_seq_len / 100 profiling_results = profile_packing(dataloader_cfg, tokenizer, min_ratio, max_ratio, num_packing_ratios, device_batch_size) @@ -299,7 +309,7 @@ def auto_packing_ratio(dataloader_cfg: DictConfig, # profiling_results are sorted from smallest to largest packing_ratio. packing_ratio = 1 for packing_ratio_candidate, _, waste in profiling_results: - if waste > 0: + if waste is None or waste > 0: break packing_ratio = packing_ratio_candidate @@ -318,9 +328,10 @@ def auto_packing_ratio(dataloader_cfg: DictConfig, def profile_packing( - dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, - min_ratio: float, max_ratio: float, num_packing_ratios: int, - device_batch_size: int) -> Iterable[Tuple[float, float, float]]: + dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, + min_ratio: float, max_ratio: float, num_packing_ratios: int, + device_batch_size: int +) -> Iterable[Tuple[float, Optional[float], Optional[float]]]: """Generator function that profiles example packing across packing ratios. Args: @@ -350,6 +361,10 @@ def profile_packing( dataloader_cfg.prefetch_factor = None dataloader_cfg.persistent_workers = False + # If streaming dataset, use a temporary local folder for profiling + if dataloader_cfg.dataset.get('remote') is not None: + dataloader_cfg.dataset.local = tempfile.TemporaryDirectory().name + # Determine the packing_ratio values we'll try packing_ratios, raw_batch_sizes = [], [] for packing_ratio in np.linspace(min_ratio, @@ -382,7 +397,7 @@ def split_big_batch(raw_batch_size: int) -> List: batches[idx].update({key: split}) return batches - def profile(raw_batch_size: int) -> Tuple[float, float]: + def profile(raw_batch_size: int) -> Tuple[Optional[float], Optional[float]]: packer = BinPackCollator( collator=lambda x: x, target_batch_size=device_batch_size, @@ -395,9 +410,15 @@ def profile(raw_batch_size: int) -> Tuple[float, float]: for batch in split_big_batch(raw_batch_size): if batch['input_ids'].shape[0] < device_batch_size: continue - _ = packer.pack(batch) + packer.pack(batch) + + if packer.n_packed_examples == 0: + log.debug( + 'No examples packed during profiling. Dataset is smaller than device batch size.' 
+ ) + return None, None - # Return the padding / waste stats over that bunch of data + # Return the padding and waste stats over that bunch of data padding_percent = 100 * (1 - packer.efficiency) waste_percent = 100 * packer.waste return padding_percent, waste_percent diff --git a/scripts/misc/profile_packing.py b/scripts/misc/profile_packing.py index 51841d669e..fff10d158b 100644 --- a/scripts/misc/profile_packing.py +++ b/scripts/misc/profile_packing.py @@ -26,7 +26,7 @@ def parse_args() -> Namespace: help='Path to the YAML that defines the workload to profile.') parser.add_argument('--num-devices', type=int, - default=None, + required=True, help='How many devices your run will use.') parser.add_argument('--min', type=float, diff --git a/tests/data/test_packing.py b/tests/data/test_packing.py index a86d88f360..963f8e56b6 100644 --- a/tests/data/test_packing.py +++ b/tests/data/test_packing.py @@ -1,6 +1,7 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +from pathlib import Path from typing import Any, Dict, List from unittest.mock import Mock, patch @@ -9,6 +10,7 @@ from composer.utils import dist, reproducibility from omegaconf import DictConfig from pytest import approx +from streaming import MDSWriter from torch.utils.data import DataLoader from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader @@ -149,6 +151,44 @@ def patched_packing_ratio(*args: Any, **kwargs: Any): return auto_packing_ratio(*args, **kwargs, num_packing_ratios=4) +@patch('llmfoundry.data.finetuning.dataloader.auto_packing_ratio', + patched_packing_ratio) +def test_auto_packing_with_streaming_dataloader(tmp_path: Path): + columns = {'prompt': 'str', 'response': 'str'} + tokenizer = build_tokenizer('gpt2', {}) + remote_dir = str(tmp_path / 'remote') + local_dir = str(tmp_path / 'local') + with MDSWriter(out=remote_dir, columns=columns, compression=None) as out: + out.write({'prompt': 'HELLO', 'response': 'WORLD'}) + cfg = DictConfig({ + 'name': 'finetuning', + 'dataset': { + 'remote': remote_dir, + 'local': local_dir, + 'packing_ratio': 'auto', + 'max_seq_len': 200, + 'decoder_only_format': True + }, + 'drop_last': False, + # Need to test with 0 num_workers because the packing collator object + # Gets copied per worker and we cannot check the waste for child processes. + 'num_workers': 0, + 'pin_memory': False, + 'prefetch_factor': None, + 'persistent_workers': False, + 'timeout': 0, + }) + + loader = build_finetuning_dataloader(cfg, tokenizer, + device_batch_size=6).dataloader + + batch_ix = 0 + for _ in loader: + batch_ix += 1 + if batch_ix >= 3: + break + + @pytest.mark.parametrize('packing_ratio', ['auto', 2.0]) @patch('llmfoundry.data.finetuning.dataloader.auto_packing_ratio', patched_packing_ratio)