Commit

Fix propagation of drop_last and add error message when it would produce no batches (#549)

* add error message, text, and fix drop last

* pyright

* pyright

* pyright

* lots of fixes

* precommit

* precommit
dakinggg committed Aug 25, 2023
1 parent 3e3c7d3 commit 3eac52e
Showing 3 changed files with 99 additions and 5 deletions.
28 changes: 25 additions & 3 deletions llmfoundry/data/finetuning/dataloader.py
@@ -1,9 +1,10 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import logging
import os
from typing import Union

import datasets as hf_datasets
import torch
from composer.utils import dist, get_file, parse_uri
from omegaconf import DictConfig
@@ -165,11 +166,30 @@ def build_finetuning_dataloader(cfg: DictConfig,
collate_fn, dataloader_batch_size = _build_collate_fn(
    cfg.dataset, tokenizer, device_batch_size)

if cfg.drop_last:
    world_size = dist.get_world_size()
    minimum_dataset_size = world_size * dataloader_batch_size
    if hasattr(dataset, '__len__'):
        full_dataset_size = len(dataset)
        if full_dataset_size < minimum_dataset_size:
            raise ValueError(
                f'Your dataset (name={cfg.dataset.hf_name}, split={cfg.dataset.split}) '
                +
                f'has {full_dataset_size} samples, but your minimum batch size '
                +
                f'is {minimum_dataset_size} because you are running on {world_size} gpus and '
                +
                f'your per device batch size is {dataloader_batch_size}. Please increase the number '
                +
                f'of samples in your dataset to at least {minimum_dataset_size}.'
            )

assert dataset is not None
return DataLoader(
    dataset,
    collate_fn=collate_fn,
    batch_size=dataloader_batch_size,
    drop_last=cfg.drop_last,
    sampler=dist.get_sampler(dataset,
                             drop_last=cfg.drop_last,
                             shuffle=cfg.dataset.shuffle),
@@ -235,8 +255,10 @@ def _validate_config(dataset_cfg: DictConfig):
)


def _build_hf_dataset_from_remote(cfg: DictConfig,
                                  tokenizer: PreTrainedTokenizerBase):
def _build_hf_dataset_from_remote(
    cfg: DictConfig, tokenizer: PreTrainedTokenizerBase
) -> Union[hf_datasets.DatasetDict, hf_datasets.Dataset,
           hf_datasets.IterableDatasetDict, hf_datasets.IterableDataset]:
"""Builds a dataset from a remote object store.
This function supports 'jsonl', 'csv', and 'parquet' file formats for the dataset. It will attempt to download
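For reference, a minimal standalone sketch of the check this hunk introduces (not part of the diff; the function name and parameters below are illustrative). With drop_last=True, each rank only yields full batches, so a dataset smaller than world_size * per-device batch size would otherwise silently produce zero batches:

    def check_min_dataset_size(num_samples: int, per_device_batch_size: int,
                               world_size: int, drop_last: bool) -> None:
        # With drop_last=True, the sampler on each rank only yields full batches,
        # so fewer than world_size * per_device_batch_size samples means the run
        # would produce zero batches.
        minimum = world_size * per_device_batch_size
        if drop_last and num_samples < minimum:
            raise ValueError(
                f'Dataset has {num_samples} samples, but at least {minimum} are '
                f'needed for one batch on {world_size} devices with a per-device '
                f'batch size of {per_device_batch_size}.')


    # 4 samples, per-device batch size 4, 2 ranks -> minimum is 8, so this raises.
    try:
        check_min_dataset_size(num_samples=4, per_device_batch_size=4,
                               world_size=2, drop_last=True)
    except ValueError as err:
        print(err)

Running the snippet mirrors the failing case the new test below exercises; the actual dataloader additionally propagates drop_last into both the DataLoader and its sampler.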
7 changes: 5 additions & 2 deletions llmfoundry/data/finetuning/tasks.py
@@ -254,8 +254,11 @@ def get_preprocessing_fn_from_str(self,

return preprocessing_fn

def build_from_hf(self, cfg: DictConfig, max_seq_len: int,
                  tokenizer: PreTrainedTokenizerBase):
def build_from_hf(
    self, cfg: DictConfig, max_seq_len: int,
    tokenizer: PreTrainedTokenizerBase
) -> Union[hf_datasets.DatasetDict, hf_datasets.Dataset,
           hf_datasets.IterableDatasetDict, hf_datasets.IterableDataset]:
"""Load a HuggingFace Datasets, preprocess, and tokenize.
Note: This function will drop examples where the prompt is longer than the max_seq_len
69 changes: 69 additions & 0 deletions tests/test_dataloader.py
@@ -1,5 +1,7 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
import contextlib
import json
import os
import shutil
import sys
@@ -9,6 +11,7 @@

import pytest
import torch
from composer.utils import dist
from omegaconf import OmegaConf as om

from llmfoundry import (build_finetuning_dataloader,
@@ -282,3 +285,69 @@ def test_finetuning_dataloader(decoder_only_format: bool,
        batch_ix += 1
        if batch_ix >= 3:
            break


def make_tiny_ft_dataset(path: str, size: int = 4):
    sample = {'prompt': 'hello', 'response': 'goodbye'}
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, 'w') as _f:
        for _ in range(size):
            _f.write(json.dumps(sample))
            _f.write('\n')


@pytest.mark.world_size(2)
@pytest.mark.parametrize('dataset_size', [4, 8])
@pytest.mark.parametrize('device_batch_size', [2, 4])
@pytest.mark.parametrize('drop_last', [True, False])
def test_finetuning_dataloader_small_data(dataset_size: int,
                                          device_batch_size: int,
                                          drop_last: bool):
    tokenizer_name = 'gpt2'
    max_seq_len = 2048
    tiny_dataset_folder_path = os.path.join(os.getcwd(), 'test-ift-data-small')
    tiny_dataset_path = os.path.join(tiny_dataset_folder_path, 'train.jsonl')
    if dist.get_global_rank() == 0:
        make_tiny_ft_dataset(path=tiny_dataset_path, size=dataset_size)

    cfg = {
        'name': 'finetuning',
        'dataset': {
            'hf_name': tiny_dataset_folder_path,
            'split': 'train',
            'max_seq_len': max_seq_len,
            'decoder_only_format': True,
            'allow_pad_trimming': False,
            'packing_ratio': None,
            'shuffle': True,
        },
        'drop_last': drop_last,
        'num_workers': 4,
        'pin_memory': False,
        'prefetch_factor': 2,
        'persistent_workers': False,
        'timeout': 0
    }

    cfg = om.create(cfg)

    tokenizer = build_tokenizer(
        om.create({
            'name': tokenizer_name,
            'kwargs': {
                'model_max_length': max_seq_len
            }
        }))

    expected_keys = ['input_ids', 'attention_mask', 'labels']
    expected_keys += ['bidirectional_mask']

    error_context = contextlib.nullcontext()
    if (dist.get_world_size() * device_batch_size > dataset_size) and drop_last:
        error_context = pytest.raises(ValueError, match='Your dataset')

    with error_context:
        _ = build_finetuning_dataloader(cfg, tokenizer, device_batch_size)

    if dist.get_global_rank() == 0:
        shutil.rmtree(tiny_dataset_folder_path)
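As a rough sanity check of the parametrization above (assuming the world_size(2) marker gives two ranks), only dataset_size=4 with device_batch_size=4 and drop_last=True should trigger the new ValueError, since 2 * 4 = 8 exceeds 4 samples; a quick enumeration, again illustrative rather than part of the diff:

    from itertools import product

    world_size = 2  # assumed from the @pytest.mark.world_size(2) marker above
    for dataset_size, device_batch_size, drop_last in product([4, 8], [2, 4],
                                                               [True, False]):
        expects_error = drop_last and world_size * device_batch_size > dataset_size
        outcome = 'raises ValueError' if expects_error else 'builds a dataloader'
        print(f'dataset_size={dataset_size}, device_batch_size={device_batch_size}, '
              f'drop_last={drop_last} -> {outcome}')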
