Add handling for various types of malformed finetuning data (#576)
dakinggg committed Sep 5, 2023
1 parent 1c75fda commit 2799918
Showing 4 changed files with 160 additions and 6 deletions.
21 changes: 20 additions & 1 deletion llmfoundry/data/finetuning/tasks.py
@@ -52,6 +52,14 @@ def _tokenize_formatted_example(example: Dict[str, Any],
'"prompt" and "response" are required keys but at least one was missing ' +\
f'from {example=}.'
)
if not isinstance(example['prompt'], str):
raise TypeError(
f'Unable to tokenize example because "prompt" was not a string. {example=}'
)
if not isinstance(example['response'], str):
raise TypeError(
f'Unable to tokenize example because "response" was not a string. {example=}'
)
return tokenizer(text=example['prompt'], text_target=example['response'])
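Illustrative sketch (not part of the diff): how the new type checks surface to a caller. The GPT-2 tokenizer and the direct import of the private helper are assumptions made only for this example.

from transformers import AutoTokenizer

from llmfoundry.data.finetuning.tasks import _tokenize_formatted_example

# Any Hugging Face tokenizer works here; gpt2 is just a small, convenient choice.
tokenizer = AutoTokenizer.from_pretrained('gpt2')

# A well-formed example tokenizes as before; the result carries input_ids for
# the prompt and labels for the response.
tokenized = _tokenize_formatted_example(
    {'prompt': 'hello', 'response': 'goodbye'}, tokenizer)

# A non-string field now fails fast with a descriptive TypeError instead of
# erroring somewhere inside the tokenizer.
try:
    _tokenize_formatted_example({'prompt': None, 'response': 'goodbye'}, tokenizer)
except TypeError as err:
    print(err)  # Unable to tokenize example because "prompt" was not a string. ...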


@@ -306,7 +314,18 @@ def dataset_mapper(example: Dict):
f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}.'
)

-return prompt_length_filtered_dataset
empty_examples_dropped_dataset = prompt_length_filtered_dataset.filter(
lambda example: len(example['input_ids']) > 0 and len(example[
'labels']) > 0 and any(token_id != tokenizer.pad_token_id
for token_id in example['labels']))
empty_examples_removed = len(prompt_length_filtered_dataset) - len(
empty_examples_dropped_dataset)
if empty_examples_removed > 0:
warnings.warn(
f'Dropped {empty_examples_removed} examples where the prompt or response was empty, '
+ 'or the response was only padding tokens.')

return empty_examples_dropped_dataset

def build_from_streaming(self, *args: Any, **kwargs: Any):
return StreamingFinetuningDataset(*args, **kwargs)
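Illustrative sketch (not part of the diff): the post-tokenization filter added above keeps an example only when all three conditions hold; keeps() is a hypothetical stand-in for the lambda passed to .filter(...).

def keeps(example, pad_token_id):
    return (len(example['input_ids']) > 0 and len(example['labels']) > 0 and
            any(token_id != pad_token_id for token_id in example['labels']))

assert keeps({'input_ids': [1, 2], 'labels': [3, 4]}, pad_token_id=0)
assert not keeps({'input_ids': [], 'labels': [3, 4]}, pad_token_id=0)  # empty prompt
assert not keeps({'input_ids': [1, 2], 'labels': []}, pad_token_id=0)  # empty response
assert not keeps({'input_ids': [1, 2], 'labels': [0, 0, 0]}, pad_token_id=0)  # pad-only response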
1 change: 1 addition & 0 deletions setup.py
@@ -53,6 +53,7 @@
'mosaicml-streaming>=0.5.1,<0.6',
'torch>=1.13.1,<=2.0.1',
'datasets==2.10.1',
'fsspec==2023.6.0', # newer version results in a bug in datasets that duplicates data
'sentencepiece==0.1.97',
'einops==0.5.0',
'omegaconf>=2.2.3,<3',
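Not part of the diff: a one-line sanity check of the pin, based on the comment added above about newer fsspec releases triggering duplicated data in datasets.

import fsspec
assert fsspec.__version__ == '2023.6.0'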
141 changes: 137 additions & 4 deletions tests/test_dataloader.py
@@ -3,6 +3,7 @@
import contextlib
import json
import os
import pathlib
import shutil
import sys
import tempfile
@@ -11,7 +12,7 @@

import pytest
import torch
-from composer.utils import dist
from composer.utils import dist, using_torch_2
from omegaconf import OmegaConf as om

from llmfoundry import (build_finetuning_dataloader,
@@ -278,11 +279,63 @@ def test_finetuning_dataloader(decoder_only_format: bool,
break


-def make_tiny_ft_dataset(path: str, size: int = 4):
-sample = {'prompt': 'hello', 'response': 'goodbye'}
def make_tiny_ft_dataset(
path: str,
size: int = 4,
add_bad_data_dropped: bool = False,
add_bad_data_error: bool = False,
add_just_bos_eos_pad: bool = False,
pad_token: Optional[str] = None,
start_token: Optional[str] = None,
end_token: Optional[str] = None,
):
good_sample = {'prompt': 'hello', 'response': 'goodbye'}
samples = [good_sample] * size
if add_bad_data_dropped:
if pad_token is None:
raise ValueError(
'pad_token, start_token, and end_token must be specified if add_bad_data is True'
)
# empty prompt
samples.append({'prompt': '', 'response': 'goodbye'})
# empty response
samples.append({'prompt': 'hello', 'response': ''})
# response just pad
samples.append({'prompt': 'hello', 'response': pad_token})
# response just pad multiple times
samples.append({'prompt': 'hello', 'response': pad_token * 3})

if add_bad_data_error:
# prompt just None
samples.append({
'prompt': None,
'response': 'goodbye'
}) # type: ignore (intentional test)
# response just None
samples.append({
'prompt': 'hello',
'response': None
}) # type: ignore (intentional test)

if add_just_bos_eos_pad:
if pad_token is None or start_token is None or end_token is None:
raise ValueError(
'pad_token, start_token, and end_token must be specified if add_just_bos_eos is True'
)
# prompt just start
samples.append({'prompt': start_token, 'response': 'goodbye'})
# response just start
samples.append({'prompt': 'hello', 'response': start_token})
# prompt just end
samples.append({'prompt': end_token, 'response': 'goodbye'})
# response just end
samples.append({'prompt': 'hello', 'response': end_token})
# prompt just pad
samples.append({'prompt': pad_token, 'response': 'goodbye'})

os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, 'w') as _f:
-for _ in range(size):
for sample in samples:
_f.write(json.dumps(sample))
_f.write('\n')
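For orientation, a hypothetical invocation of the helper above; the path and pad token are placeholders, while the test itself passes a tmp_path-based path and the tokenizer's real special tokens.

make_tiny_ft_dataset(
    path='/tmp/tiny_ft/train.jsonl',
    size=4,
    add_bad_data_dropped=True,
    pad_token='<pad>',
)
# train.jsonl then contains 4 copies of the good prompt/response pair plus four
# malformed rows: an empty prompt, an empty response, a pad-only response, and a
# response made of three pad tokens.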

@@ -339,3 +392,83 @@ def test_finetuning_dataloader_small_data(dataset_size: int,

if dist.get_global_rank() == 0:
shutil.rmtree(tiny_dataset_folder_path)


@pytest.mark.parametrize('add_bad_data_dropped', [True, False])
@pytest.mark.parametrize('add_bad_data_error', [True, False])
def test_malformed_data(
add_bad_data_dropped: bool,
add_bad_data_error: bool,
tmp_path: pathlib.Path,
):
tokenizer_name = 'mosaicml/mpt-7b'
max_seq_len = 2048
dataset_size = 5
device_batch_size = 5
tiny_dataset_folder_path = tmp_path
tiny_dataset_path = str(tiny_dataset_folder_path / 'train.jsonl')

tokenizer = build_tokenizer(
tokenizer_name=tokenizer_name,
tokenizer_kwargs={'model_max_length': max_seq_len},
)
tokenizer.add_special_tokens({
'pad_token': '<pad>',
'bos_token': '<bos>',
'eos_token': '<eos>',
})

if dist.get_global_rank() == 0:
make_tiny_ft_dataset(
path=tiny_dataset_path,
size=dataset_size,
add_bad_data_dropped=add_bad_data_dropped,
add_bad_data_error=add_bad_data_error,
add_just_bos_eos_pad=True,
pad_token=tokenizer.pad_token,
start_token=tokenizer.bos_token,
end_token=tokenizer.eos_token,
)

cfg = {
'name': 'finetuning',
'dataset': {
'hf_name': str(tiny_dataset_folder_path),
'split': 'train',
'max_seq_len': max_seq_len,
'decoder_only_format': True,
'allow_pad_trimming': False,
'packing_ratio': None,
'shuffle': True,
},
'drop_last': False,
'num_workers': 0,
# set prefetch to 2 if < torch 2, else set it to None
'prefetch_factor': None if using_torch_2() else 2,
'pin_memory': False,
'persistent_workers': False,
'timeout': 0
}

cfg = om.create(cfg)

expected_keys = ['input_ids', 'attention_mask', 'labels']
expected_keys += ['bidirectional_mask']

error_context = contextlib.nullcontext()
if add_bad_data_error:
error_context = pytest.raises(TypeError,
match='Unable to tokenize example')

with error_context:
dl = build_finetuning_dataloader(cfg, tokenizer, device_batch_size)

if not add_bad_data_error:
# +5 because we added samples with just bos/eos in each of prompt/response
expected_num_batches = (dataset_size + 5) // device_batch_size

actual_num_batches = 0
for _ in dl:
actual_num_batches += 1

assert actual_num_batches == expected_num_batches
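As a sanity check on the assertion above, the batch arithmetic for the non-error case works out like this (a sketch using the literal values from this test):

# 5 good samples plus the 5 bos/eos/pad edge cases survive the dataset's
# filtering; the add_bad_data_dropped rows, when present, are all removed.
surviving_samples = 5 + 5
expected = surviving_samples // 5  # device_batch_size == 5, so 2 batches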
3 changes: 2 additions & 1 deletion tests/test_eval.py
@@ -40,7 +40,8 @@ def mock_saved_model_path():
device = 'cpu'
model_cfg.model.init_device = device
# build tokenizer
-tokenizer = build_tokenizer(model_cfg.tokenizer)
tokenizer = build_tokenizer(model_cfg.tokenizer.name,
model_cfg.tokenizer.get('kwargs', {}))
# build model
model = COMPOSER_MODEL_REGISTRY[model_cfg.model.name](model_cfg.model,
tokenizer)
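For reference, a sketch of a tokenizer config in the shape the updated call expects; the values are illustrative, and the name/kwargs layout mirrors the explicit build_tokenizer call in tests/test_dataloader.py above.

from omegaconf import OmegaConf as om

tokenizer_cfg = om.create({
    'name': 'mosaicml/mpt-7b',
    'kwargs': {'model_max_length': 2048},
})
tokenizer = build_tokenizer(tokenizer_cfg.name, tokenizer_cfg.get('kwargs', {}))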
