From 2799918ac056f21364f4c32e235bdc0349a7393f Mon Sep 17 00:00:00 2001
From: Daniel King <43149077+dakinggg@users.noreply.github.com>
Date: Tue, 5 Sep 2023 15:44:15 -0700
Subject: [PATCH] Add handling for various types of malformed finetuning data (#576)

---
 llmfoundry/data/finetuning/tasks.py |  21 ++++-
 setup.py                            |   1 +
 tests/test_dataloader.py            | 141 +++++++++++++++++++++++++++-
 tests/test_eval.py                  |   3 +-
 4 files changed, 160 insertions(+), 6 deletions(-)

diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index 59b62413d4..f4da9750c7 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -52,6 +52,14 @@ def _tokenize_formatted_example(example: Dict[str, Any],
             '"prompt" and "response" are required keys but at least one was missing ' +\
             f'from {example=}.'
         )
+    if not isinstance(example['prompt'], str):
+        raise TypeError(
+            f'Unable to tokenize example because "prompt" was not a string. {example=}'
+        )
+    if not isinstance(example['response'], str):
+        raise TypeError(
+            f'Unable to tokenize example because "response" was not a string. {example=}'
+        )
     return tokenizer(text=example['prompt'], text_target=example['response'])
@@ -306,7 +314,18 @@ def dataset_mapper(example: Dict):
                 f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}.'
             )
 
-        return prompt_length_filtered_dataset
+        empty_examples_dropped_dataset = prompt_length_filtered_dataset.filter(
+            lambda example: len(example['input_ids']) > 0 and len(example[
+                'labels']) > 0 and any(token_id != tokenizer.pad_token_id
+                                       for token_id in example['labels']))
+        empty_examples_removed = len(prompt_length_filtered_dataset) - len(
+            empty_examples_dropped_dataset)
+        if empty_examples_removed > 0:
+            warnings.warn(
+                f'Dropped {empty_examples_removed} examples where the prompt or response was empty, '
+                + 'or the response was only padding tokens.')
+
+        return empty_examples_dropped_dataset
 
     def build_from_streaming(self, *args: Any, **kwargs: Any):
         return StreamingFinetuningDataset(*args, **kwargs)
diff --git a/setup.py b/setup.py
index e963833a40..fc3973e7ce 100644
--- a/setup.py
+++ b/setup.py
@@ -53,6 +53,7 @@
     'mosaicml-streaming>=0.5.1,<0.6',
     'torch>=1.13.1,<=2.0.1',
     'datasets==2.10.1',
+    'fsspec==2023.6.0',  # newer version results in a bug in datasets that duplicates data
     'sentencepiece==0.1.97',
     'einops==0.5.0',
     'omegaconf>=2.2.3,<3',
diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py
index 457265806f..3aad4c68d5 100644
--- a/tests/test_dataloader.py
+++ b/tests/test_dataloader.py
@@ -3,6 +3,7 @@
 import contextlib
 import json
 import os
+import pathlib
 import shutil
 import sys
 import tempfile
@@ -11,7 +12,7 @@
 
 import pytest
 import torch
-from composer.utils import dist
+from composer.utils import dist, using_torch_2
 from omegaconf import OmegaConf as om
 
 from llmfoundry import (build_finetuning_dataloader,
@@ -278,11 +279,63 @@ def test_finetuning_dataloader(decoder_only_format: bool,
             break
 
 
-def make_tiny_ft_dataset(path: str, size: int = 4):
-    sample = {'prompt': 'hello', 'response': 'goodbye'}
+def make_tiny_ft_dataset(
+    path: str,
+    size: int = 4,
+    add_bad_data_dropped: bool = False,
+    add_bad_data_error: bool = False,
+    add_just_bos_eos_pad: bool = False,
+    pad_token: Optional[str] = None,
+    start_token: Optional[str] = None,
+    end_token: Optional[str] = None,
+):
+    good_sample = {'prompt': 'hello', 'response': 'goodbye'}
+    samples = [good_sample] * size
+    if add_bad_data_dropped:
+        if pad_token is None:
+            raise ValueError(
+                'pad_token, start_token, and end_token must be specified if add_bad_data is True'
+            )
+        # empty prompt
+        samples.append({'prompt': '', 'response': 'goodbye'})
+        # empty response
+        samples.append({'prompt': 'hello', 'response': ''})
+        # response just pad
+        samples.append({'prompt': 'hello', 'response': pad_token})
+        # response just pad multiple times
+        samples.append({'prompt': 'hello', 'response': pad_token * 3})
+
+    if add_bad_data_error:
+        # prompt just None
+        samples.append({
+            'prompt': None,
+            'response': 'goodbye'
+        })  # type: ignore (intentional test)
+        # response just None
+        samples.append({
+            'prompt': 'hello',
+            'response': None
+        })  # type: ignore (intentional test)
+
+    if add_just_bos_eos_pad:
+        if pad_token is None or start_token is None or end_token is None:
+            raise ValueError(
+                'pad_token, start_token, and end_token must be specified if add_just_bos_eos is True'
+            )
+        # prompt just start
+        samples.append({'prompt': start_token, 'response': 'goodbye'})
+        # response just start
+        samples.append({'prompt': 'hello', 'response': start_token})
+        # prompt just end
+        samples.append({'prompt': end_token, 'response': 'goodbye'})
+        # response just end
+        samples.append({'prompt': 'hello', 'response': end_token})
+        # prompt just pad
+        samples.append({'prompt': pad_token, 'response': 'goodbye'})
+
     os.makedirs(os.path.dirname(path), exist_ok=True)
     with open(path, 'w') as _f:
-        for _ in range(size):
+        for sample in samples:
             _f.write(json.dumps(sample))
             _f.write('\n')
@@ -339,3 +392,83 @@ def test_finetuning_dataloader_small_data(dataset_size: int,
 
     if dist.get_global_rank() == 0:
         shutil.rmtree(tiny_dataset_folder_path)
+
+
+@pytest.mark.parametrize('add_bad_data_dropped', [True, False])
+@pytest.mark.parametrize('add_bad_data_error', [True, False])
+def test_malformed_data(
+    add_bad_data_dropped: bool,
+    add_bad_data_error: bool,
+    tmp_path: pathlib.Path,
+):
+    tokenizer_name = 'mosaicml/mpt-7b'
+    max_seq_len = 2048
+    dataset_size = 5
+    device_batch_size = 5
+    tiny_dataset_folder_path = tmp_path
+    tiny_dataset_path = str(tiny_dataset_folder_path / 'train.jsonl')
+
+    tokenizer = build_tokenizer(
+        tokenizer_name=tokenizer_name,
+        tokenizer_kwargs={'model_max_length': max_seq_len},
+    )
+    tokenizer.add_special_tokens({
+        'pad_token': '<pad>',
+        'bos_token': '<bos>',
+        'eos_token': '<eos>',
+    })
+
+    if dist.get_global_rank() == 0:
+        make_tiny_ft_dataset(
+            path=tiny_dataset_path,
+            size=dataset_size,
+            add_bad_data_dropped=add_bad_data_dropped,
+            add_bad_data_error=add_bad_data_error,
+            add_just_bos_eos_pad=True,
+            pad_token=tokenizer.pad_token,
+            start_token=tokenizer.bos_token,
+            end_token=tokenizer.eos_token,
+        )
+
+    cfg = {
+        'name': 'finetuning',
+        'dataset': {
+            'hf_name': str(tiny_dataset_folder_path),
+            'split': 'train',
+            'max_seq_len': max_seq_len,
+            'decoder_only_format': True,
+            'allow_pad_trimming': False,
+            'packing_ratio': None,
+            'shuffle': True,
+        },
+        'drop_last': False,
+        'num_workers': 0,
+        # set prefetch to 2 if < torch 2, else set it to None
+        'prefetch_factor': None if using_torch_2() else 2,
+        'pin_memory': False,
+        'persistent_workers': False,
+        'timeout': 0
+    }
+
+    cfg = om.create(cfg)
+
+    expected_keys = ['input_ids', 'attention_mask', 'labels']
+    expected_keys += ['bidirectional_mask']
+
+    error_context = contextlib.nullcontext()
+    if add_bad_data_error:
+        error_context = pytest.raises(TypeError,
+                                      match='Unable to tokenize example')
+
+    with error_context:
+        dl = build_finetuning_dataloader(cfg, tokenizer, device_batch_size)
+
+    if not add_bad_data_error:
+        # +5 because we added samples with just bos/eos in each of prompt/response
+        expected_num_batches = (dataset_size + 5) // device_batch_size
+
+        actual_num_batches = 0
+        for _ in dl:
+            actual_num_batches += 1
+
+        assert actual_num_batches == expected_num_batches
diff --git a/tests/test_eval.py b/tests/test_eval.py
index 9b7573bf17..ecd15ab62f 100644
--- a/tests/test_eval.py
+++ b/tests/test_eval.py
@@ -40,7 +40,8 @@ def mock_saved_model_path():
     device = 'cpu'
     model_cfg.model.init_device = device
     # build tokenizer
-    tokenizer = build_tokenizer(model_cfg.tokenizer)
+    tokenizer = build_tokenizer(model_cfg.tokenizer.name,
+                                model_cfg.tokenizer.get('kwargs', {}))
     # build model
     model = COMPOSER_MODEL_REGISTRY[model_cfg.model.name](model_cfg.model,
                                                           tokenizer)
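
Note (not part of the patch): a minimal sketch of how the new validation in
llmfoundry/data/finetuning/tasks.py surfaces to a caller, assuming an llm-foundry
install that includes this change plus the transformers package; the 'gpt2'
tokenizer is only an illustrative stand-in for the MPT tokenizer used in the
tests above.

# Illustrative sketch only; not part of the patch.
from transformers import AutoTokenizer

from llmfoundry.data.finetuning.tasks import _tokenize_formatted_example

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # stand-in tokenizer

# A well-formed example tokenizes as before.
ok = _tokenize_formatted_example({'prompt': 'hello', 'response': 'goodbye'},
                                 tokenizer)
print(list(ok.keys()))  # includes 'input_ids' and 'labels'

# A non-string field now fails fast with a descriptive TypeError instead of
# erroring somewhere inside the tokenizer.
try:
    _tokenize_formatted_example({'prompt': None, 'response': 'goodbye'},
                                tokenizer)
except TypeError as err:
    print(err)

Empty prompts or responses and padding-only responses are handled differently:
the dataset-building path in tasks.py drops those examples with a warning
rather than raising.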