diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py
index 9b27f4f0d0..4fcb7c4f25 100644
--- a/llmfoundry/data/finetuning/dataloader.py
+++ b/llmfoundry/data/finetuning/dataloader.py
@@ -2,10 +2,12 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import logging
+import os
+import tempfile
 from typing import Union
 
 import torch
-from composer.utils import dist
+from composer.utils import dist, get_file, parse_uri
 from omegaconf import DictConfig
 from torch.utils.data import DataLoader
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -38,7 +40,9 @@ def build_finetuning_dataloader(cfg: DictConfig, tokenizer: Tokenizer,
         ---
         *** HuggingFace dataset config fields ***
         cfg.dataset.hf_name (str, optional): The name of the HuggingFace dataset
-            to use.
+            to use. Can also be a remote http(s) directory or object store bucket
+            containing the file {split}.jsonl in the format (prompt, response),
+            in which case the builder will create a HuggingFace dataset.
         cfg.dataset.hf_kwargs (DictConfig, optional): Additional kwargs to
             pass to `datasets.load_dataset`, which can be used to load
             a dataset from local files.
@@ -145,7 +149,51 @@ def build_finetuning_dataloader(cfg: DictConfig, tokenizer: Tokenizer,
         )
 
     else:
-        dataset = dataset_constructor.build_from_hf(cfg.dataset, tokenizer)
+        backend, _, _ = parse_uri(cfg.dataset.hf_name)
+        if backend not in ['', None]:
+            if cfg.dataset.get('split') is None:
+                raise ValueError(
+                    'When using a HuggingFace dataset from a URL, you must set the ' + \
+                    '`split` key in the dataset config.'
+                )
+            supported_extensions = ['jsonl', 'csv', 'parquet']
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                for extension in supported_extensions:
+                    name = f'{cfg.dataset.hf_name.strip("/")}/{cfg.dataset.split}.{extension}'
+                    destination = str(
+                        os.path.abspath(
+                            f'{tmp_dir}/{cfg.dataset.split}.{extension}'))
+                    try:
+                        with dist.run_local_rank_zero_first():
+                            get_file(name, destination, overwrite=True)
+                    except FileNotFoundError as e:
+                        if extension == supported_extensions[-1]:
+                            raise FileNotFoundError(
+                                f'Could not find a {cfg.dataset.split} file with any of ' + \
+                                f'the supported extensions: {supported_extensions}\n' + \
+                                f'at {cfg.dataset.hf_name}/{cfg.dataset.split}'
+                            ) from e
+                        else:
+                            log.debug(
+                                f'Could not find {name}, looking for another extension'
+                            )
+                        continue
+                    # 'json' causes special behavior in the dataset constructor
+                    cfg.dataset.hf_name = extension if extension != 'jsonl' else 'json'
+                    kwargs = cfg.dataset.get('hf_kwargs', {})
+                    kwargs['data_files'] = destination
+                    cfg.dataset['hf_kwargs'] = kwargs
+                    log.debug(cfg.dataset)
+                    dataset = dataset_constructor.build_from_hf(
+                        cfg.dataset,
+                        tokenizer=tokenizer,
+                    )
+                    break
+        else:
+            dataset = dataset_constructor.build_from_hf(
+                cfg.dataset,
+                tokenizer=tokenizer,
+            )
 
     collate_fn, dataloader_batch_size = _build_collate_fn(
         cfg.dataset, tokenizer, device_batch_size)
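
For reviewers: the new branch hinges on `composer.utils.parse_uri`, which returns an empty backend string when the value has no URI scheme, so plain HF Hub names fall through to the existing `build_from_hf` path unchanged. A minimal sketch of that dispatch (the example URIs below are illustrative, not part of this diff):

```python
# Minimal sketch of the dispatch in the hunk above; URIs are placeholders.
from composer.utils import parse_uri

backend, _, _ = parse_uri('s3://my-bucket/my-finetune-data')
assert backend == 's3'  # has a scheme -> fetch {split}.<ext> and load via data_files

backend, _, _ = parse_uri('tatsu-lab/alpaca')
assert backend == ''  # no scheme -> treated as a HuggingFace Hub dataset name
```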
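
And a hedged end-to-end usage sketch, assuming a bucket laid out as `<hf_name>/<split>.jsonl` where each line is a `{"prompt": ..., "response": ...}` record. The bucket path, tokenizer, batch size, and the exact set of dataset config keys below are placeholders/assumptions for illustration, not part of this diff:

```python
# Usage sketch (placeholders throughout): with this change, hf_name may point
# at an object store prefix containing {split}.jsonl / .csv / .parquet.
from omegaconf import OmegaConf as om
from transformers import AutoTokenizer

from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader

cfg = om.create({
    'name': 'finetuning',
    'dataset': {
        'hf_name': 's3://my-bucket/my-finetune-data',  # loader looks for train.jsonl here
        'split': 'train',  # required for remote hf_name, per the new ValueError
        'max_seq_len': 2048,
        'decoder_only_format': True,
        'shuffle': True,
    },
    'drop_last': False,
    'num_workers': 8,
})

tokenizer = AutoTokenizer.from_pretrained('gpt2')
dataloader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size=8)
```

Note the `with dist.run_local_rank_zero_first():` guard around `get_file` in the diff: only local rank zero downloads the split file, and the other ranks wait and then read the same local copy, avoiding redundant downloads per node.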