
fix the places that assumed iterable
dakinggg committed Oct 15, 2023
1 parent b3fe63b commit f88d267
Showing 5 changed files with 8 additions and 8 deletions.
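Every change below is the same one-line fix: the dataloader builders no longer return a bare, directly iterable DataLoader but a wrapper object that exposes the underlying loader via its .dataloader attribute (in llm-foundry this wrapper is Composer's DataSpec), so call sites that iterate over the result or inspect .dataset must unwrap it first. A minimal sketch of the pattern, assuming the post-change builder behavior shown in the diffs (cfg, tokenizer, and device_batch_size come from the surrounding scripts):

    # Sketch only: assumes build_text_dataloader now returns a DataSpec-style
    # wrapper rather than a torch DataLoader.
    from llmfoundry.data.text_data import build_text_dataloader

    # Before this commit, call sites iterated the return value directly:
    #   loader = build_text_dataloader(cfg, tokenizer, device_batch_size)

    # After, the underlying DataLoader must be unwrapped before use:
    loader = build_text_dataloader(cfg, tokenizer, device_batch_size).dataloader
    batch = next(iter(loader))  # now iterable, as the call sites below expect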
llmfoundry/data/denoising.py: 2 changes (1 addition & 1 deletion)
@@ -876,7 +876,7 @@ def _format_tokens_for_decoder_only(
     tokenizer = build_tokenizer(tokenizer_name=tokenizer_name,
                                 tokenizer_kwargs=tokenizer_kwargs)
 
-    loader = build_text_denoising_dataloader(cfg, tokenizer, device_batch_size)
+    loader = build_text_denoising_dataloader(cfg, tokenizer, device_batch_size).dataloader
     assert isinstance(loader.dataset, StreamingTextDataset)
 
     print(f'\n\nTRUNCATING TO: {loader.dataset.max_seq_len}\n\n')
llmfoundry/data/finetuning/dataloader.py: 2 changes (1 addition & 1 deletion)
@@ -449,7 +449,7 @@ def _build_collate_fn(
     tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
 
     device_batch_size = 2
-    dataloader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size)
+    dataloader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size).dataloader
 
     packing = cfg.dataset.get('packing_ratio') is not None
llmfoundry/data/packing.py: 2 changes (1 addition & 1 deletion)
@@ -377,7 +377,7 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
     dataloader_cfg.dataset.packing_ratio = None
     dataloader_cfg.dataset.max_leftovers_to_keep = None
     train_dataloader = build_dataloader(dataloader_cfg, tokenizer,
-                                        max(raw_batch_sizes) * 100)
+                                        max(raw_batch_sizes) * 100).dataloader
 
     # Get a bunch of raw examples
     big_batch = next(iter(train_dataloader))
llmfoundry/data/text_data.py: 2 changes (1 addition & 1 deletion)
@@ -378,7 +378,7 @@ def get_num_samples_in_batch(batch: Batch) -> int:
     tokenizer_kwargs = {'model_max_length': args.max_seq_len}
     tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
 
-    loader = build_text_dataloader(cfg, tokenizer, device_batch_size)
+    loader = build_text_dataloader(cfg, tokenizer, device_batch_size).dataloader
     assert isinstance(loader.dataset, StreamingTextDataset)
     tokenizer = loader.dataset.tokenizer
tests/test_dataloader.py: 8 changes (4 additions & 4 deletions)
@@ -142,7 +142,7 @@ def test_correct_padding(tokenizer_name: str,
         test_cfg.eval_loader,
         tokenizer,
         batch_size,
-    )
+    ).dataloader
     batch = next(iter(eval_loader))
 
     assert batch['input_ids'].shape == torch.Size([batch_size, 2048])

@@ -233,7 +233,7 @@ def test_denoising_dataloader(decoder_only_format: bool, pretokenize: bool,
                                tokenizer_kwargs={'model_max_length': max_seq_len})
 
     loader = build_text_denoising_dataloader(cfg, tokenizer,
-                                             device_batch_size)
+                                             device_batch_size).dataloader
     batch_ix = 0
     for batch in loader:
         for k in expected_keys:

@@ -292,7 +292,7 @@ def test_finetuning_dataloader(decoder_only_format: bool,
     else:
         expected_keys += ['decoder_attention_mask', 'decoder_input_ids']
 
-    loader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size)
+    loader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size).dataloader
     batch_ix = 0
     for batch in loader:
         for k in expected_keys:

@@ -546,7 +546,7 @@ def test_malformed_data(
             match='Unable to tokenize example')
 
     with error_context:
-        dl = build_finetuning_dataloader(cfg, tokenizer, device_batch_size)
+        dl = build_finetuning_dataloader(cfg, tokenizer, device_batch_size).dataloader
 
     if not add_bad_data_error:
         # +5 because we added samples with just bos/eos in each of prompt/response
