Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds max seq len filter before finetuning #358

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion llmfoundry/data/finetuning/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,10 @@ def build_finetuning_dataloader(cfg: DictConfig, tokenizer: Tokenizer,
)

else:
dataset = dataset_constructor.build_from_hf(cfg.dataset, tokenizer)
dataset = dataset_constructor.build_from_hf(
cfg.dataset,
max_seq_len=cfg.dataset.max_seq_len,
tokenizer=tokenizer)

collate_fn, dataloader_batch_size = _build_collate_fn(
cfg.dataset, tokenizer, device_batch_size)
Expand Down
20 changes: 18 additions & 2 deletions llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:

import importlib
import os
import warnings
from typing import Any, Callable, Dict, Optional, Union

import datasets as hf_datasets
Expand Down Expand Up @@ -220,11 +221,15 @@ def get_preprocessing_fn_from_str(self,

return preprocessing_fn

def build_from_hf(self, cfg: DictConfig, tokenizer: Tokenizer):
def build_from_hf(self, cfg: DictConfig, max_seq_len: int,
tokenizer: Tokenizer):
"""Load a HuggingFace Datasets, preprocess, and tokenize.

Note: This function will drop examples where the prompt is longer than the max_seq_len

Args:
cfg (DictConfig): The dataset configuration.
max_seq_len (int): The maximum sequence length. Examples with prompts longer than this will be dropped.
tokenizer (Tokenizer): The tokenizer to be used for tokenizing the dataset.

Returns:
Expand All @@ -248,9 +253,20 @@ def dataset_mapper(example: Dict):
dataset_mapper,
batched=False,
remove_columns=columns_to_remove,
num_proc=max(os.cpu_count() - 2, 1),
)
prompt_length_filtered_dataset = tokenized_dataset.filter(
lambda example: len(example['input_ids']) < max_seq_len,
num_proc=max(os.cpu_count() - 2, 1))

examples_removed = len(tokenized_dataset) - len(
prompt_length_filtered_dataset)
if examples_removed > 0:
warnings.warn(
f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}.'
)

return tokenized_dataset
return prompt_length_filtered_dataset

def build_from_streaming(self, *args: Any, **kwargs: Any):
return StreamingFinetuningDataset(*args, **kwargs)
Expand Down