Merge branch 'main' into shashank/fix_FA_CE
ShashankMosaicML committed Dec 11, 2023
2 parents eb1bb73 + 34ec2f7 commit 02f7d09
Showing 5 changed files with 84 additions and 19 deletions.
12 changes: 7 additions & 5 deletions .github/mcp/mcp_pytest.py
@@ -6,8 +6,8 @@
import argparse
import time

from mcli.sdk import (RunConfig, RunStatus, create_run, follow_run_logs,
wait_for_run_status)
from mcli import (RunConfig, RunStatus, create_run, follow_run_logs,
wait_for_run_status)

if __name__ == '__main__':

@@ -107,9 +107,11 @@

config = RunConfig(
name=name,
cluster=args.cluster,
gpu_type=args.gpu_type,
gpu_num=args.gpu_num,
compute={
'cluster': args.cluster,
'gpu_type': args.gpu_type,
'gpus': args.gpu_num
},
image=args.image,
integrations=[git_integration],
command=command,
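For context, a minimal sketch of the updated `mcli` call pattern that this hunk migrates to. The run name, cluster, GPU type, image, and command below are placeholders, not values from this commit:

```python
from mcli import RunConfig, create_run

# Hardware selection now lives under the `compute` mapping instead of
# top-level cluster/gpu_type/gpu_num keyword arguments; note the key is
# 'gpus' rather than 'gpu_num'. All values below are placeholders.
config = RunConfig(
    name='llm-foundry-ci-pytest',
    compute={
        'cluster': 'my-cluster',
        'gpu_type': 'a100_80gb',
        'gpus': 8,
    },
    image='mosaicml/llm-foundry:latest',
    command='pytest tests/',
)

run = create_run(config)  # submit the run via the MosaicML platform
```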
37 changes: 29 additions & 8 deletions llmfoundry/data/packing.py
@@ -1,13 +1,18 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import logging
import os
import tempfile
from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple

import numpy as np
import torch
from omegaconf import DictConfig
from transformers import PreTrainedTokenizerBase

log = logging.getLogger(__name__)


class BinPackCollator:
"""Utility collator for packing to reduce padding."""
@@ -289,8 +294,13 @@ def auto_packing_ratio(dataloader_cfg: DictConfig,
# Set the seed so that auto packing is deterministic.
reproducibility.seed_all(0)

max_seq_len = dataloader_cfg.dataset.max_seq_len
# If max_seq_len is very small, skip profiling and select packing ratio of 1.
if max_seq_len <= 100:
return 1

min_ratio = 1
max_ratio = dataloader_cfg.dataset.max_seq_len / 100
max_ratio = max_seq_len / 100
profiling_results = profile_packing(dataloader_cfg, tokenizer, min_ratio,
max_ratio, num_packing_ratios,
device_batch_size)
@@ -299,7 +309,7 @@ def auto_packing_ratio(dataloader_cfg: DictConfig,
# profiling_results are sorted from smallest to largest packing_ratio.
packing_ratio = 1
for packing_ratio_candidate, _, waste in profiling_results:
if waste > 0:
if waste is None or waste > 0:
break
packing_ratio = packing_ratio_candidate
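To make the selection rule above concrete, here is a small standalone sketch of that loop. The `profiling_results` tuples are made-up numbers in the (packing_ratio, padding, waste) order implied by the loop unpacking:

```python
# Hypothetical profiling output: (packing_ratio, padding %, waste %).
# A waste of None means profiling could not pack a full device batch.
profiling_results = [(1.0, 40.0, 0.0), (2.0, 12.0, 0.0), (4.0, 3.0, None)]

packing_ratio = 1
for packing_ratio_candidate, _, waste in profiling_results:
    # Stop at the first candidate that wastes data or could not be profiled;
    # keep the largest earlier candidate with zero waste.
    if waste is None or waste > 0:
        break
    packing_ratio = packing_ratio_candidate

print(packing_ratio)  # -> 2.0
```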

@@ -318,9 +328,10 @@ def auto_packing_ratio(dataloader_cfg: DictConfig,


def profile_packing(
dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
min_ratio: float, max_ratio: float, num_packing_ratios: int,
device_batch_size: int) -> Iterable[Tuple[float, float, float]]:
dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
min_ratio: float, max_ratio: float, num_packing_ratios: int,
device_batch_size: int
) -> Iterable[Tuple[float, Optional[float], Optional[float]]]:
"""Generator function that profiles example packing across packing ratios.
Args:
@@ -350,6 +361,10 @@ def profile_packing(
dataloader_cfg.prefetch_factor = None
dataloader_cfg.persistent_workers = False

# If streaming dataset, use a temporary local folder for profiling
if dataloader_cfg.dataset.get('remote') is not None:
dataloader_cfg.dataset.local = tempfile.TemporaryDirectory().name

# Determine the packing_ratio values we'll try
packing_ratios, raw_batch_sizes = [], []
for packing_ratio in np.linspace(min_ratio,
@@ -382,7 +397,7 @@ def split_big_batch(raw_batch_size: int) -> List:
batches[idx].update({key: split})
return batches

def profile(raw_batch_size: int) -> Tuple[float, float]:
def profile(raw_batch_size: int) -> Tuple[Optional[float], Optional[float]]:
packer = BinPackCollator(
collator=lambda x: x,
target_batch_size=device_batch_size,
@@ -395,9 +410,15 @@ def profile(raw_batch_size: int) -> Tuple[float, float]:
for batch in split_big_batch(raw_batch_size):
if batch['input_ids'].shape[0] < device_batch_size:
continue
_ = packer.pack(batch)
packer.pack(batch)

if packer.n_packed_examples == 0:
log.debug(
'No examples packed during profiling. Dataset is smaller than device batch size.'
)
return None, None

# Return the padding / waste stats over that bunch of data
# Return the padding and waste stats over that bunch of data
padding_percent = 100 * (1 - packer.efficiency)
waste_percent = 100 * packer.waste
return padding_percent, waste_percent
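Because `profile_packing` can now yield `None` padding/waste values, callers have to guard against them. A small sketch of such a consumer (the helper name and print format are illustrative, not part of this commit):

```python
from typing import Iterable, Optional, Tuple


def summarize_packing_profile(
    results: Iterable[Tuple[float, Optional[float], Optional[float]]]
) -> None:
    """Print padding/waste per candidate packing ratio, handling the None case."""
    for ratio, padding, waste in results:
        if padding is None or waste is None:
            # profile() returned (None, None): no full device batch could be
            # packed, e.g. the dataset is smaller than the device batch size.
            print(f'packing_ratio={ratio:.2f}: not enough data to profile')
        else:
            print(f'packing_ratio={ratio:.2f}: '
                  f'padding={padding:.1f}%, waste={waste:.1f}%')
```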
2 changes: 1 addition & 1 deletion scripts/misc/profile_packing.py
@@ -26,7 +26,7 @@ def parse_args() -> Namespace:
help='Path to the YAML that defines the workload to profile.')
parser.add_argument('--num-devices',
type=int,
default=None,
required=True,
help='How many devices your run will use.')
parser.add_argument('--min',
type=float,
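The behavioral effect of switching `--num-devices` from `default=None` to `required=True` can be seen with a small argparse snippet (a standalone sketch, not the script itself):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--num-devices', type=int, required=True,
                    help='How many devices your run will use.')

# Passing the flag works as before.
args = parser.parse_args(['--num-devices', '8'])
print(args.num_devices)  # -> 8

# Omitting it now fails fast with "the following arguments are required:
# --num-devices" instead of silently leaving the value as None.
# parser.parse_args([])  # would raise SystemExit
```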
12 changes: 7 additions & 5 deletions scripts/train/benchmarking/submit_benchmarks.py
@@ -7,8 +7,8 @@

import requests
import yaml
from mcli.models.run_config import SchedulingConfig
from mcli.sdk import RunConfig, create_run, get_clusters

from mcli import RunConfig, SchedulingConfig, create_run, get_clusters


def _get_cluster_info():
@@ -470,9 +470,11 @@ def run_config(config: Tuple[str, int, int, str, str, int, str],
parameters['model']['fc_type'] = 'te'
# Create run config mcli sdk/api
config = RunConfig(name=name,
gpu_type=gpu_type,
gpu_num=gpu_num,
cluster=cluster,
compute={
'cluster': cluster,
'gpu_type': gpu_type,
'gpus': gpu_num
},
image=args.image,
integrations=integrations,
command=command,
40 changes: 40 additions & 0 deletions tests/data/test_packing.py
@@ -1,6 +1,7 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from pathlib import Path
from typing import Any, Dict, List
from unittest.mock import Mock, patch

@@ -9,6 +10,7 @@
from composer.utils import dist, reproducibility
from omegaconf import DictConfig
from pytest import approx
from streaming import MDSWriter
from torch.utils.data import DataLoader

from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader
@@ -149,6 +151,44 @@ def patched_packing_ratio(*args: Any, **kwargs: Any):
return auto_packing_ratio(*args, **kwargs, num_packing_ratios=4)


@patch('llmfoundry.data.finetuning.dataloader.auto_packing_ratio',
patched_packing_ratio)
def test_auto_packing_with_streaming_dataloader(tmp_path: Path):
columns = {'prompt': 'str', 'response': 'str'}
tokenizer = build_tokenizer('gpt2', {})
remote_dir = str(tmp_path / 'remote')
local_dir = str(tmp_path / 'local')
with MDSWriter(out=remote_dir, columns=columns, compression=None) as out:
out.write({'prompt': 'HELLO', 'response': 'WORLD'})
cfg = DictConfig({
'name': 'finetuning',
'dataset': {
'remote': remote_dir,
'local': local_dir,
'packing_ratio': 'auto',
'max_seq_len': 200,
'decoder_only_format': True
},
'drop_last': False,
# Need to test with 0 num_workers because the packing collator object
# gets copied per worker and we cannot check the waste for child processes.
'num_workers': 0,
'pin_memory': False,
'prefetch_factor': None,
'persistent_workers': False,
'timeout': 0,
})

loader = build_finetuning_dataloader(cfg, tokenizer,
device_batch_size=6).dataloader

batch_ix = 0
for _ in loader:
batch_ix += 1
if batch_ix >= 3:
break


@pytest.mark.parametrize('packing_ratio', ['auto', 2.0])
@patch('llmfoundry.data.finetuning.dataloader.auto_packing_ratio',
patched_packing_ratio)
