From 3eac52e64f22afd6f6b2be78ab773c9acdda7a1e Mon Sep 17 00:00:00 2001
From: Daniel King <43149077+dakinggg@users.noreply.github.com>
Date: Thu, 24 Aug 2023 21:07:43 -0700
Subject: [PATCH] Fix propagation of `drop_last` and add error message when it
 would produce no batches (#549)

* add error message, text, and fix drop last

* pyright

* pyright

* pyright

* lots of fixes

* precommit

* precommit
---
 llmfoundry/data/finetuning/dataloader.py | 28 ++++++++--
 llmfoundry/data/finetuning/tasks.py      |  7 ++-
 tests/test_dataloader.py                 | 69 ++++++++++++++++++++++++
 3 files changed, 99 insertions(+), 5 deletions(-)

diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py
index 86aa7e0815..004870d7b4 100644
--- a/llmfoundry/data/finetuning/dataloader.py
+++ b/llmfoundry/data/finetuning/dataloader.py
@@ -1,9 +1,10 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
-
 import logging
 import os
+from typing import Union
 
+import datasets as hf_datasets
 import torch
 from composer.utils import dist, get_file, parse_uri
 from omegaconf import DictConfig
@@ -165,11 +166,30 @@ def build_finetuning_dataloader(cfg: DictConfig,
         collate_fn, dataloader_batch_size = _build_collate_fn(
             cfg.dataset, tokenizer, device_batch_size)
 
+        if cfg.drop_last:
+            world_size = dist.get_world_size()
+            minimum_dataset_size = world_size * dataloader_batch_size
+            if hasattr(dataset, '__len__'):
+                full_dataset_size = len(dataset)
+                if full_dataset_size < minimum_dataset_size:
+                    raise ValueError(
+                        f'Your dataset (name={cfg.dataset.hf_name}, split={cfg.dataset.split}) '
+                        +
+                        f'has {full_dataset_size} samples, but your minimum batch size '
+                        +
+                        f'is {minimum_dataset_size} because you are running on {world_size} gpus and '
+                        +
+                        f'your per device batch size is {dataloader_batch_size}. Please increase the number '
+                        +
+                        f'of samples in your dataset to at least {minimum_dataset_size}.'
+                    )
+
         assert dataset is not None
         return DataLoader(
             dataset,
             collate_fn=collate_fn,
             batch_size=dataloader_batch_size,
+            drop_last=cfg.drop_last,
             sampler=dist.get_sampler(dataset,
                                      drop_last=cfg.drop_last,
                                      shuffle=cfg.dataset.shuffle),
@@ -235,8 +255,10 @@ def _validate_config(dataset_cfg: DictConfig):
     )
 
 
-def _build_hf_dataset_from_remote(cfg: DictConfig,
-                                  tokenizer: PreTrainedTokenizerBase):
+def _build_hf_dataset_from_remote(
+    cfg: DictConfig, tokenizer: PreTrainedTokenizerBase
+) -> Union[hf_datasets.DatasetDict, hf_datasets.Dataset,
+           hf_datasets.IterableDatasetDict, hf_datasets.IterableDataset]:
     """Builds a dataset from a remote object store.
 
     This function supports 'jsonl', 'csv', and 'parquet' file formats for the dataset. It will attempt to download
diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index 9c01aab49c..59b62413d4 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -254,8 +254,11 @@ def get_preprocessing_fn_from_str(self,
 
         return preprocessing_fn
 
-    def build_from_hf(self, cfg: DictConfig, max_seq_len: int,
-                      tokenizer: PreTrainedTokenizerBase):
+    def build_from_hf(
+        self, cfg: DictConfig, max_seq_len: int,
+        tokenizer: PreTrainedTokenizerBase
+    ) -> Union[hf_datasets.DatasetDict, hf_datasets.Dataset,
+               hf_datasets.IterableDatasetDict, hf_datasets.IterableDataset]:
         """Load a HuggingFace Datasets, preprocess, and tokenize.
 
         Note: This function will drop examples where the prompt is longer than the max_seq_len
diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py
index 7039814d42..3cd930b85e 100644
--- a/tests/test_dataloader.py
+++ b/tests/test_dataloader.py
@@ -1,5 +1,7 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
+import contextlib
+import json
 import os
 import shutil
 import sys
@@ -9,6 +11,7 @@
 
 import pytest
 import torch
+from composer.utils import dist
 from omegaconf import OmegaConf as om
 
 from llmfoundry import (build_finetuning_dataloader,
@@ -282,3 +285,69 @@ def test_finetuning_dataloader(decoder_only_format: bool,
         batch_ix += 1
         if batch_ix >= 3:
             break
+
+
+def make_tiny_ft_dataset(path: str, size: int = 4):
+    sample = {'prompt': 'hello', 'response': 'goodbye'}
+    os.makedirs(os.path.dirname(path), exist_ok=True)
+    with open(path, 'w') as _f:
+        for _ in range(size):
+            _f.write(json.dumps(sample))
+            _f.write('\n')
+
+
+@pytest.mark.world_size(2)
+@pytest.mark.parametrize('dataset_size', [4, 8])
+@pytest.mark.parametrize('device_batch_size', [2, 4])
+@pytest.mark.parametrize('drop_last', [True, False])
+def test_finetuning_dataloader_small_data(dataset_size: int,
+                                          device_batch_size: int,
+                                          drop_last: bool):
+    tokenizer_name = 'gpt2'
+    max_seq_len = 2048
+    tiny_dataset_folder_path = os.path.join(os.getcwd(), 'test-ift-data-small')
+    tiny_dataset_path = os.path.join(tiny_dataset_folder_path, 'train.jsonl')
+    if dist.get_global_rank() == 0:
+        make_tiny_ft_dataset(path=tiny_dataset_path, size=dataset_size)
+
+    cfg = {
+        'name': 'finetuning',
+        'dataset': {
+            'hf_name': tiny_dataset_folder_path,
+            'split': 'train',
+            'max_seq_len': max_seq_len,
+            'decoder_only_format': True,
+            'allow_pad_trimming': False,
+            'packing_ratio': None,
+            'shuffle': True,
+        },
+        'drop_last': drop_last,
+        'num_workers': 4,
+        'pin_memory': False,
+        'prefetch_factor': 2,
+        'persistent_workers': False,
+        'timeout': 0
+    }
+
+    cfg = om.create(cfg)
+
+    tokenizer = build_tokenizer(
+        om.create({
+            'name': tokenizer_name,
+            'kwargs': {
+                'model_max_length': max_seq_len
+            }
+        }))
+
+    expected_keys = ['input_ids', 'attention_mask', 'labels']
+    expected_keys += ['bidirectional_mask']
+
+    error_context = contextlib.nullcontext()
+    if (dist.get_world_size() * device_batch_size > dataset_size) and drop_last:
+        error_context = pytest.raises(ValueError, match='Your dataset')
+
+    with error_context:
+        _ = build_finetuning_dataloader(cfg, tokenizer, device_batch_size)
+
+    if dist.get_global_rank() == 0:
+        shutil.rmtree(tiny_dataset_folder_path)
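The validation added to build_finetuning_dataloader reduces to simple arithmetic: with drop_last=True, the dataset must hold at least world_size * device_batch_size samples, otherwise the distributed sampler discards every partial batch and the loader yields nothing. A minimal standalone sketch of that check, assuming illustrative function and variable names that are not part of the patch:

def check_minimum_dataset_size(full_dataset_size: int, world_size: int,
                               device_batch_size: int, drop_last: bool) -> None:
    # Mirrors the guard added in the patch: every rank draws device_batch_size
    # samples per step, so fewer than world_size * device_batch_size samples
    # combined with drop_last=True means zero complete batches.
    minimum_dataset_size = world_size * device_batch_size
    if drop_last and full_dataset_size < minimum_dataset_size:
        raise ValueError(
            f'Dataset has {full_dataset_size} samples but needs at least '
            f'{minimum_dataset_size} ({world_size} ranks x '
            f'{device_batch_size} per-device batch size) when drop_last=True.')


# Matches the failing case exercised by test_finetuning_dataloader_small_data:
# 2 ranks x 4 per-device batch size = 8 > 4 samples, so this raises ValueError.
check_minimum_dataset_size(full_dataset_size=4, world_size=2,
                           device_batch_size=4, drop_last=True)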