
Commit

Fixes + tests
irenedea committed Dec 8, 2023
1 parent 3339924 commit 0fa1123
Showing 2 changed files with 70 additions and 10 deletions.
39 changes: 29 additions & 10 deletions llmfoundry/data/packing.py
@@ -1,6 +1,9 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import logging
import os
import tempfile
from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple

import numpy as np
@@ -10,6 +13,8 @@
from streaming.base.util import clean_stale_shared_memory
from transformers import PreTrainedTokenizerBase

log = logging.getLogger(__name__)


class BinPackCollator:
"""Utility collator for packing to reduce padding."""
@@ -291,8 +296,13 @@ def auto_packing_ratio(dataloader_cfg: DictConfig,
# Set the seed so that auto packing is deterministic.
reproducibility.seed_all(0)

max_seq_len = dataloader_cfg.dataset.max_seq_len
# If max_seq_len is very small, skip profiling and select packing ratio of 1.
if max_seq_len <= 100:
return 1

min_ratio = 1
max_ratio = dataloader_cfg.dataset.max_seq_len / 100
max_ratio = max_seq_len / 100
profiling_results = profile_packing(dataloader_cfg, tokenizer, min_ratio,
max_ratio, num_packing_ratios,
device_batch_size)
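
For reference, a minimal sketch (illustrative, not the committed code) of the ratio search space implied by the change above: profiling is skipped entirely when max_seq_len is at most 100, and otherwise the candidate ratios are spread linearly between 1 and max_seq_len / 100. The helper name and the default of four candidates are assumptions for this sketch.

import numpy as np

def candidate_packing_ratios(max_seq_len: int, num_packing_ratios: int = 4) -> list:
    # Mirrors the early exit above: very small sequence lengths get a packing
    # ratio of 1 without any profiling.
    if max_seq_len <= 100:
        return [1.0]
    # Otherwise candidates span [1, max_seq_len / 100], evenly spaced.
    return list(np.linspace(1, max_seq_len / 100, num_packing_ratios))

print(candidate_packing_ratios(50))    # [1.0]
print(candidate_packing_ratios(2048))  # 1.0 ... 20.48, four evenly spaced values
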
@@ -301,12 +311,10 @@ def auto_packing_ratio(dataloader_cfg: DictConfig,
# profiling_results are sorted from smallest to largest packing_ratio.
packing_ratio = 1
for packing_ratio_candidate, _, waste in profiling_results:
if waste > 0:
if waste is None or waste > 0:
break
packing_ratio = packing_ratio_candidate

clean_stale_shared_memory()

# Select the minimum packing ratio across all ranks.
if dist.is_available() and dist.is_initialized():
device = get_device(None)
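
As a note on the cross-rank synchronization above, here is a hedged sketch of "select the minimum packing ratio across all ranks" written against plain torch.distributed rather than the composer dist and get_device helpers the file actually uses; names are illustrative.

import torch
import torch.distributed as torch_dist

def min_packing_ratio_across_ranks(local_ratio: float) -> float:
    # Each rank profiles its own shard of data; adopting the smallest ratio
    # keeps every rank within what its data can support.
    if torch_dist.is_available() and torch_dist.is_initialized():
        # With a NCCL backend this tensor would need to live on the current
        # CUDA device; CPU is fine for gloo.
        ratio_tensor = torch.tensor(local_ratio)
        torch_dist.all_reduce(ratio_tensor, op=torch_dist.ReduceOp.MIN)
        return ratio_tensor.item()
    return local_ratio
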
@@ -322,9 +330,10 @@ def auto_packing_ratio(dataloader_cfg: DictConfig,


def profile_packing(
dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
min_ratio: float, max_ratio: float, num_packing_ratios: int,
device_batch_size: int) -> Iterable[Tuple[float, float, float]]:
dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
min_ratio: float, max_ratio: float, num_packing_ratios: int,
device_batch_size: int
) -> Iterable[Tuple[float, Optional[float], Optional[float]]]:
"""Generator function that profiles example packing across packing ratios.
Args:
@@ -354,6 +363,10 @@ def profile_packing(
dataloader_cfg.prefetch_factor = None if using_torch_2() else 2
dataloader_cfg.persistent_workers = False

# If streaming dataset, use a temporary local folder for profiling
if dataloader_cfg.dataset.get('remote') is not None:
dataloader_cfg.dataset.local = tempfile.TemporaryDirectory().name

# Determine the packing_ratio values we'll try
packing_ratios, raw_batch_sizes = [], []
for packing_ratio in np.linspace(min_ratio,
@@ -386,7 +399,7 @@ def split_big_batch(raw_batch_size: int) -> List:
batches[idx].update({key: split})
return batches

def profile(raw_batch_size: int) -> Tuple[float, float]:
def profile(raw_batch_size: int) -> Tuple[Optional[float], Optional[float]]:
packer = BinPackCollator(
collator=lambda x: x,
target_batch_size=device_batch_size,
@@ -399,9 +412,15 @@ def profile(raw_batch_size: int) -> Tuple[float, float]:
for batch in split_big_batch(raw_batch_size):
if batch['input_ids'].shape[0] < device_batch_size:
continue
_ = packer.pack(batch)
packer.pack(batch)

if packer.n_packed_examples == 0:
log.debug(
'No examples packed during profiling. Dataset is smaller than device batch size.'
)
return None, None

# Return the padding / waste stats over that bunch of data
# Return the padding and waste stats over that bunch of data
padding_percent = 100 * (1 - packer.efficiency)
waste_percent = 100 * packer.waste
return padding_percent, waste_percent
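
To make the None handling above concrete, a minimal sketch (illustrative, not the library code) of how the selection loop in auto_packing_ratio now treats a missing waste value: profiling results are sorted from smallest to largest packing ratio, and the search stops at the first candidate whose waste is None (nothing was packed during profiling) or positive.

from typing import Iterable, Optional, Tuple

def select_packing_ratio(
    profiling_results: Iterable[Tuple[float, Optional[float], Optional[float]]]
) -> float:
    # Each entry is (packing_ratio, padding_percent, waste_percent).
    packing_ratio = 1.0
    for candidate, _, waste in profiling_results:
        if waste is None or waste > 0:
            break
        packing_ratio = candidate
    return packing_ratio

# A None waste now stops the search cleanly instead of raising a TypeError
# from comparing None with 0, as the old `waste > 0` check would have.
print(select_packing_ratio([(1.0, 10.0, 0.0), (2.0, None, None), (4.0, 5.0, 0.0)]))  # -> 1.0
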
41 changes: 41 additions & 0 deletions tests/data/test_packing.py
@@ -1,14 +1,17 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from pathlib import Path
from typing import Any, Dict, List
from unittest.mock import Mock, patch

import numpy as np
import pytest
import torch
from composer.utils import dist, reproducibility, using_torch_2
from omegaconf import DictConfig
from pytest import approx
from streaming import MDSWriter
from torch.utils.data import DataLoader

from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader
@@ -149,6 +152,44 @@ def patched_packing_ratio(*args: Any, **kwargs: Any):
return auto_packing_ratio(*args, **kwargs, num_packing_ratios=4)


@patch('llmfoundry.data.finetuning.dataloader.auto_packing_ratio',
patched_packing_ratio)
def test_auto_packing_with_streaming_dataloader(tmp_path: Path):
columns = {'prompt': 'str', 'response': 'str'}
tokenizer = build_tokenizer('gpt2', {})
remote_dir = str(tmp_path / 'remote')
local_dir = str(tmp_path / 'local')
with MDSWriter(out=remote_dir, columns=columns, compression=None) as out:
out.write({'prompt': 'HELLO', 'response': 'WORLD'})
cfg = DictConfig({
'name': 'finetuning',
'dataset': {
'remote': remote_dir,
'local': local_dir,
'packing_ratio': 'auto',
'max_seq_len': 200,
'decoder_only_format': True
},
'drop_last': False,
# Need to test with 0 num_workers because the packing collator object
# gets copied per worker and we cannot check the waste for child processes.
'num_workers': 0,
'pin_memory': False,
'prefetch_factor': None if using_torch_2() else 2,
'persistent_workers': False,
'timeout': 0,
})

loader = build_finetuning_dataloader(cfg, tokenizer,
device_batch_size=6).dataloader

batch_ix = 0
for _ in loader:
batch_ix += 1
if batch_ix >= 3:
break


@pytest.mark.parametrize('packing_ratio', ['auto', 2.0])
@patch('llmfoundry.data.finetuning.dataloader.auto_packing_ratio',
patched_packing_ratio)
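
For context on the streaming pieces above, a rough sketch of the scenario the new test exercises, reusing the same MDSWriter call as the test; the paths are hypothetical, and the temporary-directory step mirrors the profile_packing change that gives profiling its own throwaway local cache so it cannot collide with the training dataloader's cache for the same remote.

import tempfile
from streaming import MDSWriter

columns = {'prompt': 'str', 'response': 'str'}
remote_dir = '/tmp/packing_demo/remote'  # hypothetical path, for illustration only

# Write a single-sample MDS dataset, exactly as the test does.
with MDSWriter(out=remote_dir, columns=columns, compression=None) as out:
    out.write({'prompt': 'HELLO', 'response': 'WORLD'})

dataset_cfg = {'remote': remote_dir, 'local': '/tmp/packing_demo/local'}
# Mirror of the profile_packing change: when a remote is set, point the
# profiling pass at a fresh temporary local directory instead of the shared one.
if dataset_cfg.get('remote') is not None:
    dataset_cfg['local'] = tempfile.TemporaryDirectory().name
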
