From 9a087cbe441e3ff438b90654492c75a9ab20323c Mon Sep 17 00:00:00 2001
From: Daniel King <43149077+dakinggg@users.noreply.github.com>
Date: Fri, 8 Sep 2023 16:18:04 -0700
Subject: [PATCH] Fix huggingface custom split path issue (#588)

---
 llmfoundry/data/finetuning/dataloader.py | 14 ++--
 tests/test_dataloader.py                 | 81 ++++++++++++++++++++++++
 2 files changed, 91 insertions(+), 4 deletions(-)

diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py
index 263aa47fb2..b5a7420b34 100644
--- a/llmfoundry/data/finetuning/dataloader.py
+++ b/llmfoundry/data/finetuning/dataloader.py
@@ -292,19 +292,25 @@ def _build_hf_dataset_from_remote(
     """
     supported_extensions = ['jsonl', 'csv', 'parquet']
     finetune_dir = os.path.join(
-        os.path.dirname(os.path.realpath(__file__)),
-        f'downloaded_finetuning_data/{cfg.dataset.split}')
+        os.path.dirname(
+            os.path.dirname(os.path.dirname(os.path.realpath(__file__)))),
+        'downloaded_finetuning',
+        cfg.dataset.split if cfg.dataset.split != 'data' else 'data_not',
+    )
     os.makedirs(finetune_dir, exist_ok=True)
     for extension in supported_extensions:
         name = f'{cfg.dataset.hf_name.strip("/")}/{cfg.dataset.split}.{extension}'
         destination = str(
-            os.path.abspath(f'{finetune_dir}/{cfg.dataset.split}.{extension}'))
+            os.path.abspath(
+                os.path.join(
+                    finetune_dir, 'data',
+                    f'{cfg.dataset.split}-00000-of-00001.{extension}')))
         # Since we don't know exactly what the extension will be, since it is one of a list
         # use a signal file to wait for instead of the desired file
         signal_file_path = os.path.join(finetune_dir, '.the_eagle_has_landed')
         if dist.get_local_rank() == 0:
             try:
-                get_file(name, destination, overwrite=True)
+                get_file(path=name, destination=destination, overwrite=True)
             except FileNotFoundError as e:
                 if extension == supported_extensions[-1]:
                     files_searched = [
diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py
index 3aad4c68d5..f9a281efa7 100644
--- a/tests/test_dataloader.py
+++ b/tests/test_dataloader.py
@@ -394,6 +394,87 @@ def test_finetuning_dataloader_small_data(dataset_size: int,
     shutil.rmtree(tiny_dataset_folder_path)
 
 
+@pytest.mark.parametrize('split', ['train', 'custom', 'data'])
+def test_finetuning_dataloader_custom_split(tmp_path: pathlib.Path, split: str):
+    tokenizer_name = 'gpt2'
+    max_seq_len = 2048
+    tiny_dataset_folder_path = str(tmp_path)
+    tiny_dataset_path = os.path.join(tiny_dataset_folder_path, 'data',
+                                     f'{split}-00000-of-00001.jsonl')
+    if dist.get_global_rank() == 0:
+        make_tiny_ft_dataset(path=tiny_dataset_path, size=16)
+
+    cfg = {
+        'name': 'finetuning',
+        'dataset': {
+            'hf_name': tiny_dataset_folder_path,
+            'split': split,
+            'max_seq_len': max_seq_len,
+            'decoder_only_format': True,
+            'allow_pad_trimming': False,
+            'packing_ratio': None,
+            'shuffle': True,
+        },
+        'drop_last': False,
+        'num_workers': 4,
+        'pin_memory': False,
+        'prefetch_factor': 2,
+        'persistent_workers': False,
+        'timeout': 0
+    }
+
+    cfg = om.create(cfg)
+
+    tokenizer = build_tokenizer(
+        tokenizer_name=tokenizer_name,
+        tokenizer_kwargs={'model_max_length': max_seq_len},
+    )
+
+    _ = build_finetuning_dataloader(cfg, tokenizer, 4)
+
+
+def mock_get_file(path: str, destination: str, overwrite: bool = False):
+    make_tiny_ft_dataset(path=destination, size=16)
+
+
+@pytest.mark.parametrize('split', ['train', 'custom', 'data'])
+def test_finetuning_dataloader_custom_split_remote(
+        tmp_path: pathlib.Path, split: str, monkeypatch: pytest.MonkeyPatch):
+    tokenizer_name = 'gpt2'
+    max_seq_len = 2048
+
+    cfg = {
+        'name': 'finetuning',
+        'dataset': {
+            'hf_name': 's3://test-bucket/path/to/data',
+            'split': split,
+            'max_seq_len': max_seq_len,
+            'decoder_only_format': True,
+            'allow_pad_trimming': False,
+            'packing_ratio': None,
+            'shuffle': True,
+        },
+        'drop_last': False,
+        'num_workers': 4,
+        'pin_memory': False,
+        'prefetch_factor': 2,
+        'persistent_workers': False,
+        'timeout': 0
+    }
+
+    cfg = om.create(cfg)
+
+    tokenizer = build_tokenizer(
+        tokenizer_name=tokenizer_name,
+        tokenizer_kwargs={'model_max_length': max_seq_len},
+    )
+
+    with monkeypatch.context() as m:
+        m.setattr('llmfoundry.data.finetuning.dataloader.get_file',
+                  mock_get_file)
+        _ = build_finetuning_dataloader(cfg, tokenizer, 4)
+
+
 @pytest.mark.parametrize('add_bad_data_dropped', [True, False])
 @pytest.mark.parametrize('add_bad_data_error', [True, False])
 def test_malformed_data(
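
Note on the path scheme: the new destination follows the
data/{split}-00000-of-00001.{extension} layout that HF datasets recognizes
when inferring split names from the files in a plain directory, which is what
lets a custom split name like 'custom' resolve; the download root swaps in
'data_not' when the split itself is named 'data', presumably to keep the
nested data/ directory unambiguous. Below is a minimal standalone sketch of
that convention, assuming a datasets version that infers custom split names
from this file pattern; the temp directory and example record are
illustrative, not taken from the patch.

    import json
    import os
    import tempfile

    from datasets import load_dataset

    root = tempfile.mkdtemp()
    os.makedirs(os.path.join(root, 'data'), exist_ok=True)

    # One shard of a split named 'custom', using the same file-name scheme
    # as the patched download destination (illustrative record contents).
    with open(os.path.join(root, 'data', 'custom-00000-of-00001.jsonl'),
              'w') as f:
        f.write(json.dumps({'prompt': 'hi', 'response': 'hello'}) + '\n')

    # datasets maps data/custom-00000-of-00001.jsonl to a split named
    # 'custom', so the non-standard split name loads without a script.
    ds = load_dataset(root, split='custom')
    print(ds[0])

The new remote test exercises the same layout without any network access: it
monkeypatches llmfoundry.data.finetuning.dataloader.get_file with
mock_get_file, so the s3:// URL in the config is never actually fetched.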