
precommit and remove erroneous token counting from denoising text dataloader
dakinggg committed Oct 15, 2023
1 parent f88d267 commit eac52cb
Showing 3 changed files with 11 additions and 11 deletions.
13 changes: 5 additions & 8 deletions llmfoundry/data/denoising.py
@@ -17,8 +17,7 @@
 from transformers import PreTrainedTokenizerBase

 from llmfoundry.data.packing import BinPackWrapper
-from llmfoundry.data.text_data import (StreamingTextDataset,
-                                        get_tokens_per_batch_func)
+from llmfoundry.data.text_data import StreamingTextDataset
 from llmfoundry.models import utils

 __all__ = ['MixtureOfDenoisersCollator', 'build_text_denoising_dataloader']
@@ -355,7 +354,7 @@ def build_text_denoising_dataloader(
     cfg: DictConfig,
     tokenizer: PreTrainedTokenizerBase,
     device_batch_size: int,
-) -> DataLoader[Dict]:
+) -> DataSpec:
     """Constructor function for a Mixture of Denoisers dataloader.

     This function constructs a dataloader that can be used to train an
@@ -520,10 +519,7 @@ def build_text_denoising_dataloader(
         timeout=cfg.get('timeout', 0),
     )

-    token_counting_func = get_tokens_per_batch_func(
-        pad_token_id=tokenizer.pad_token_id)
-
-    return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)
+    return DataSpec(dataloader=dl)


 def noise_token_sequence(
@@ -876,7 +872,8 @@ def _format_tokens_for_decoder_only(
     tokenizer = build_tokenizer(tokenizer_name=tokenizer_name,
                                 tokenizer_kwargs=tokenizer_kwargs)

-    loader = build_text_denoising_dataloader(cfg, tokenizer, device_batch_size).dataloader
+    loader = build_text_denoising_dataloader(cfg, tokenizer,
+                                              device_batch_size).dataloader
     assert isinstance(loader.dataset, StreamingTextDataset)

     print(f'\n\nTRUNCATING TO: {loader.dataset.max_seq_len}\n\n')
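For context, the dropped call built the batch-level token counter that had been wired into this loader's DataSpec. One plausible reading of "erroneous" in the commit message: a counter that only inspects non-padding input_ids ignores decoder-side tensors, and the Mixture-of-Denoisers collator can emit encoder-decoder batches that also carry decoder_input_ids. The sketch below illustrates that failure mode only; it is not the actual get_tokens_per_batch_func from llmfoundry/data/text_data.py, and the make_token_counter name is made up for the example.

import torch

# Illustrative sketch only -- not the llm-foundry implementation.
def make_token_counter(pad_token_id: int):
    def count_tokens(batch: dict) -> int:
        # Count every non-padding token in the (encoder) input_ids.
        return int((batch['input_ids'] != pad_token_id).sum().item())
    return count_tokens

counter = make_token_counter(pad_token_id=0)
batch = {
    'input_ids': torch.tensor([[5, 6, 7, 0]]),
    'decoder_input_ids': torch.tensor([[8, 9, 0, 0]]),  # ignored by the counter above
}
print(counter(batch))  # prints 3; the decoder-side tokens are never counted

Whatever the exact failure mode was, the commit simply drops the hook and returns a plain DataSpec(dataloader=dl).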
3 changes: 2 additions & 1 deletion llmfoundry/data/finetuning/dataloader.py
@@ -449,7 +449,8 @@ def _build_collate_fn(
     tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)

     device_batch_size = 2
-    dataloader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size).dataloader
+    dataloader = build_finetuning_dataloader(cfg, tokenizer,
+                                              device_batch_size).dataloader

     packing = cfg.dataset.get('packing_ratio') is not None

6 changes: 4 additions & 2 deletions tests/test_dataloader.py
@@ -292,7 +292,8 @@ def test_finetuning_dataloader(decoder_only_format: bool,
     else:
         expected_keys += ['decoder_attention_mask', 'decoder_input_ids']

-    loader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size).dataloader
+    loader = build_finetuning_dataloader(cfg, tokenizer,
+                                          device_batch_size).dataloader
     batch_ix = 0
     for batch in loader:
         for k in expected_keys:
@@ -546,7 +547,8 @@ def test_malformed_data(
                                       match='Unable to tokenize example')

     with error_context:
-        dl = build_finetuning_dataloader(cfg, tokenizer, device_batch_size).dataloader
+        dl = build_finetuning_dataloader(cfg, tokenizer,
+                                          device_batch_size).dataloader

     if not add_bad_data_error:
         # +5 because we added samples with just bos/eos in each of prompt/response
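All three touched call sites share the same pattern: build the loader, then immediately unwrap the underlying torch DataLoader from the returned DataSpec. A minimal sketch of that pattern is below; the import path follows the file shown above, and the config construction is elided because it depends on the dataset.

from omegaconf import DictConfig
from transformers import PreTrainedTokenizerBase

from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader


def get_finetuning_loader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
                          device_batch_size: int):
    # The returned object (a composer DataSpec) exposes the raw torch
    # DataLoader used for iteration via its `.dataloader` attribute.
    spec = build_finetuning_dataloader(cfg, tokenizer, device_batch_size)
    return spec.dataloader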
