From eac52cb0a45cc8b63f5f16538a588eb4c9c37960 Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Sat, 14 Oct 2023 20:53:58 -0700
Subject: [PATCH] precommit and remove erroneous token counting from denoising text dataloader

---
 llmfoundry/data/denoising.py             | 13 +++++--------
 llmfoundry/data/finetuning/dataloader.py |  3 ++-
 tests/test_dataloader.py                 |  6 ++++--
 3 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/llmfoundry/data/denoising.py b/llmfoundry/data/denoising.py
index 3101b71faa..4c74f4c773 100644
--- a/llmfoundry/data/denoising.py
+++ b/llmfoundry/data/denoising.py
@@ -17,8 +17,7 @@
 from transformers import PreTrainedTokenizerBase
 
 from llmfoundry.data.packing import BinPackWrapper
-from llmfoundry.data.text_data import (StreamingTextDataset,
-                                       get_tokens_per_batch_func)
+from llmfoundry.data.text_data import StreamingTextDataset
 from llmfoundry.models import utils
 
 __all__ = ['MixtureOfDenoisersCollator', 'build_text_denoising_dataloader']
@@ -355,7 +354,7 @@ def build_text_denoising_dataloader(
     cfg: DictConfig,
     tokenizer: PreTrainedTokenizerBase,
     device_batch_size: int,
-) -> DataLoader[Dict]:
+) -> DataSpec:
     """Constructor function for a Mixture of Denoisers dataloader.
 
     This function constructs a dataloader that can be used to train an
@@ -520,10 +519,7 @@ def build_text_denoising_dataloader(
         timeout=cfg.get('timeout', 0),
     )
 
-    token_counting_func = get_tokens_per_batch_func(
-        pad_token_id=tokenizer.pad_token_id)
-
-    return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)
+    return DataSpec(dataloader=dl)
 
 
 def noise_token_sequence(
@@ -876,7 +872,8 @@ def _format_tokens_for_decoder_only(
     tokenizer = build_tokenizer(tokenizer_name=tokenizer_name,
                                 tokenizer_kwargs=tokenizer_kwargs)
 
-    loader = build_text_denoising_dataloader(cfg, tokenizer, device_batch_size).dataloader
+    loader = build_text_denoising_dataloader(cfg, tokenizer,
+                                             device_batch_size).dataloader
     assert isinstance(loader.dataset, StreamingTextDataset)
 
     print(f'\n\nTRUNCATING TO: {loader.dataset.max_seq_len}\n\n')
diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py
index 763fe1e2a3..2dde563ac6 100644
--- a/llmfoundry/data/finetuning/dataloader.py
+++ b/llmfoundry/data/finetuning/dataloader.py
@@ -449,7 +449,8 @@ def _build_collate_fn(
     tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
 
     device_batch_size = 2
-    dataloader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size).dataloader
+    dataloader = build_finetuning_dataloader(cfg, tokenizer,
+                                             device_batch_size).dataloader
 
     packing = cfg.dataset.get('packing_ratio') is not None
 
diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py
index 5313005106..b9da212df0 100644
--- a/tests/test_dataloader.py
+++ b/tests/test_dataloader.py
@@ -292,7 +292,8 @@ def test_finetuning_dataloader(decoder_only_format: bool,
     else:
         expected_keys += ['decoder_attention_mask', 'decoder_input_ids']
 
-    loader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size).dataloader
+    loader = build_finetuning_dataloader(cfg, tokenizer,
+                                         device_batch_size).dataloader
     batch_ix = 0
     for batch in loader:
         for k in expected_keys:
@@ -546,7 +547,8 @@ def test_malformed_data(
             match='Unable to tokenize example')
 
     with error_context:
-        dl = build_finetuning_dataloader(cfg, tokenizer, device_batch_size).dataloader
+        dl = build_finetuning_dataloader(cfg, tokenizer,
+                                         device_batch_size).dataloader
 
     if not add_bad_data_error:
         # +5 because we added samples with just bos/eos in each of prompt/response
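
Note for reviewers (not part of the patch): after this change, `build_text_denoising_dataloader` returns a Composer `DataSpec` rather than a bare `DataLoader[Dict]`, and it no longer attaches a denoising-specific `get_num_tokens_in_batch`, so Composer's default token counting applies. The sketch below is a minimal, illustrative usage example that loosely mirrors the `__main__` block in `llmfoundry/data/denoising.py`; the dataset path, tokenizer name, and config values are placeholder assumptions, not something this patch introduces, and it requires a prepared local streaming dataset to actually run.

```python
# Illustrative only -- paths and config values are placeholders.
from omegaconf import OmegaConf as om

from llmfoundry.data.denoising import build_text_denoising_dataloader
from llmfoundry.utils.builders import build_tokenizer

cfg = om.create({
    'name': 'text_denoising',
    'dataset': {
        'local': '/tmp/my-copy-c4',  # placeholder: a locally prepared streaming dataset
        'remote': None,
        'split': 'val',
        'shuffle': False,
        'max_seq_len': 128,
        'predownload': 1000,
        'keep_zip': True,
    },
    'mixture_of_denoisers': {
        'decoder_only_format': False,
        'span_mean_lengths_and_ratios': [[3, 0.15], [8, 0.5]],
        'sequence_mask_ratios': 0.25,
    },
    'drop_last': False,
    'num_workers': 0,
})

tokenizer = build_tokenizer(tokenizer_name='t5-base',
                            tokenizer_kwargs={'model_max_length': 128})

# The builder now returns a DataSpec; the torch DataLoader lives on `.dataloader`,
# and no custom get_num_tokens_in_batch is attached anymore.
dataspec = build_text_denoising_dataloader(cfg, tokenizer, device_batch_size=2)
for i, batch in enumerate(dataspec.dataloader):
    print({k: v.shape for k, v in batch.items()})
    if i >= 1:
        break
```

Callers that previously consumed the return value as a `DataLoader` should now unwrap it via `.dataloader`, as the updated `__main__` blocks and tests in this patch do.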