Commit
Fix huggingface custom split path issue (#588)
dakinggg committed Sep 8, 2023
1 parent d6ebcc5 commit 9a087cb
Showing 2 changed files with 91 additions and 4 deletions.
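In brief: downloaded splits previously landed inside the package at a path keyed only by the split name; they now land three directory levels up from dataloader.py (the llmfoundry package root) in a Hugging Face hub-style data/<split>-00000-of-00001.<ext> shard layout, and a split literally named 'data' is redirected to a 'data_not' directory, presumably so it cannot collide with that fixed 'data' subdirectory. A sketch of the two path computations, with /repo standing in for the checkout root (the literal paths here are illustrative, not from the diff):

import os

split = 'data'       # the split name that previously caused a path collision
extension = 'jsonl'  # one of the supported extensions: jsonl, csv, parquet

# Before the fix: rooted inside the package, keyed only by the split name.
old = os.path.join('/repo/llmfoundry/data/finetuning',
                   'downloaded_finetuning_data', split,
                   f'{split}.{extension}')

# After the fix: hub-style shard layout; a split named 'data' maps to the
# 'data_not' directory, while the filename keeps the real split name.
split_dir = split if split != 'data' else 'data_not'
new = os.path.join('/repo/llmfoundry', 'downloaded_finetuning', split_dir,
                   'data', f'{split}-00000-of-00001.{extension}')

print(old)  # /repo/llmfoundry/data/finetuning/downloaded_finetuning_data/data/data.jsonl
print(new)  # /repo/llmfoundry/downloaded_finetuning/data_not/data/data-00000-of-00001.jsonl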
14 changes: 10 additions & 4 deletions llmfoundry/data/finetuning/dataloader.py
@@ -292,19 +292,25 @@ def _build_hf_dataset_from_remote(
     """
     supported_extensions = ['jsonl', 'csv', 'parquet']
     finetune_dir = os.path.join(
-        os.path.dirname(os.path.realpath(__file__)),
-        f'downloaded_finetuning_data/{cfg.dataset.split}')
+        os.path.dirname(
+            os.path.dirname(os.path.dirname(os.path.realpath(__file__)))),
+        'downloaded_finetuning',
+        cfg.dataset.split if cfg.dataset.split != 'data' else 'data_not',
+    )
     os.makedirs(finetune_dir, exist_ok=True)
     for extension in supported_extensions:
         name = f'{cfg.dataset.hf_name.strip("/")}/{cfg.dataset.split}.{extension}'
         destination = str(
-            os.path.abspath(f'{finetune_dir}/{cfg.dataset.split}.{extension}'))
+            os.path.abspath(
+                os.path.join(
+                    finetune_dir, 'data',
+                    f'{cfg.dataset.split}-00000-of-00001.{extension}')))
         # Since we don't know exactly what the extension will be, since it is one of a list
         # use a signal file to wait for instead of the desired file
         signal_file_path = os.path.join(finetune_dir, '.the_eagle_has_landed')
         if dist.get_local_rank() == 0:
             try:
-                get_file(name, destination, overwrite=True)
+                get_file(path=name, destination=destination, overwrite=True)
             except FileNotFoundError as e:
                 if extension == supported_extensions[-1]:
                     files_searched = [
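The signal-file comment above is the heart of the multi-process handling: only local rank 0 downloads, and because the extension is not known until a download succeeds, the other ranks wait on a fixed-name marker file rather than on the data file itself. A minimal sketch of that pattern (the helper name, polling loop, and marker payload are illustrative assumptions, not the code elided from this hunk):

import os
import time
from typing import Callable

def rank_zero_download_and_signal(local_rank: int, signal_file_path: str,
                                  do_download: Callable[[], None]) -> None:
    if local_rank == 0:
        # One process per node performs the download.
        do_download()
        # Touch a fixed-name marker so waiters need not guess the extension.
        with open(signal_file_path, 'wb') as f:
            f.write(b'download_complete')
    else:
        # Remaining local ranks block on the marker, not on the data file.
        while not os.path.exists(signal_file_path):
            time.sleep(1)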
81 changes: 81 additions & 0 deletions tests/test_dataloader.py
@@ -394,6 +394,87 @@ def test_finetuning_dataloader_small_data(dataset_size: int,
     shutil.rmtree(tiny_dataset_folder_path)


+@pytest.mark.parametrize('split', ['train', 'custom', 'data'])
+def test_finetuning_dataloader_custom_split(tmp_path: pathlib.Path, split: str):
+    tokenizer_name = 'gpt2'
+    max_seq_len = 2048
+    tiny_dataset_folder_path = str(tmp_path)
+    tiny_dataset_path = os.path.join(tiny_dataset_folder_path, 'data',
+                                     f'{split}-00000-of-00001.jsonl')
+    if dist.get_global_rank() == 0:
+        make_tiny_ft_dataset(path=tiny_dataset_path, size=16)
+
+    cfg = {
+        'name': 'finetuning',
+        'dataset': {
+            'hf_name': tiny_dataset_folder_path,
+            'split': split,
+            'max_seq_len': max_seq_len,
+            'decoder_only_format': True,
+            'allow_pad_trimming': False,
+            'packing_ratio': None,
+            'shuffle': True,
+        },
+        'drop_last': False,
+        'num_workers': 4,
+        'pin_memory': False,
+        'prefetch_factor': 2,
+        'persistent_workers': False,
+        'timeout': 0
+    }
+
+    cfg = om.create(cfg)
+
+    tokenizer = build_tokenizer(
+        tokenizer_name=tokenizer_name,
+        tokenizer_kwargs={'model_max_length': max_seq_len},
+    )
+
+    _ = build_finetuning_dataloader(cfg, tokenizer, 4)
+
+
+def mock_get_file(path: str, destination: str, overwrite: bool = False):
+    make_tiny_ft_dataset(path=destination, size=16)
+
+
+@pytest.mark.parametrize('split', ['train', 'custom', 'data'])
+def test_finetuning_dataloader_custom_split_remote(
+        tmp_path: pathlib.Path, split: str, monkeypatch: pytest.MonkeyPatch):
+    tokenizer_name = 'gpt2'
+    max_seq_len = 2048
+
+    cfg = {
+        'name': 'finetuning',
+        'dataset': {
+            'hf_name': 's3://test-bucket/path/to/data',
+            'split': split,
+            'max_seq_len': max_seq_len,
+            'decoder_only_format': True,
+            'allow_pad_trimming': False,
+            'packing_ratio': None,
+            'shuffle': True,
+        },
+        'drop_last': False,
+        'num_workers': 4,
+        'pin_memory': False,
+        'prefetch_factor': 2,
+        'persistent_workers': False,
+        'timeout': 0
+    }
+
+    cfg = om.create(cfg)
+
+    tokenizer = build_tokenizer(
+        tokenizer_name=tokenizer_name,
+        tokenizer_kwargs={'model_max_length': max_seq_len},
+    )
+
+    with monkeypatch.context() as m:
+        m.setattr('llmfoundry.data.finetuning.dataloader.get_file',
+                  mock_get_file)
+        _ = build_finetuning_dataloader(cfg, tokenizer, 4)
+
+
 @pytest.mark.parametrize('add_bad_data_dropped', [True, False])
 @pytest.mark.parametrize('add_bad_data_error', [True, False])
 def test_malformed_data(
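A note on the remote test: the monkeypatch swaps out the module-level get_file reference inside llmfoundry.data.finetuning.dataloader, so no network or S3 access happens; mock_get_file simply writes a tiny dataset to whatever destination the dataloader computed, exercising the new path logic end to end. It also binds the same parameter names used at the (now keyword-argument) call site above, which keeps the substitution drop-in. Assuming a standard pytest setup, something like the following runs just these two tests:

pytest tests/test_dataloader.py -k custom_split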
