Skip to content

Commit

Permalink
Updated streaming args for StreamingDataset subclasses (#602)
Browse files Browse the repository at this point in the history
* Fix streaming dataset setup for fine-tuning

Currently, when a StreamingFinetuningDataset is created using the
build_finetuning_dataloader method, a failure is returned as some
incorrect parameters are passed through to the constructor of
StreamingFinetuningDataset. This patch fixes the paramter mismatch and
adds test coverage for this case.

* updated StreamingTextDataset and StreamingFinetuningDataset with new streaming args, bumped streaming version

* updated StreamingTextDataset and StreamingFinetuningDataset with new streaming args, bumped streaming version

---------

Co-authored-by: Aiden Grossman <[email protected]>
Co-authored-by: Daniel King <[email protected]>
  • Loading branch information
3 people committed Sep 16, 2023
1 parent c308d10 commit 229ab4f
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 33 deletions.
2 changes: 2 additions & 0 deletions llmfoundry/data/finetuning/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ def build_finetuning_dataloader(cfg: DictConfig,
shuffle_seed=cfg.dataset.get('shuffle_seed', 9176),
shuffle_block_size=cfg.dataset.get('shuffle_block_size', 1 << 18),
sampling_method=cfg.dataset.get('sampling_method', 'balanced'),
sampling_granularity=cfg.dataset.get('sampling_granularity', 1),
batching_method=cfg.dataset.get('batching_method', 'random'),
)

collate_fn, dataloader_batch_size = _build_collate_fn(
Expand Down
104 changes: 73 additions & 31 deletions llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,44 +71,76 @@ class StreamingFinetuningDataset(StreamingDataset):
"""Finetuning dataset with flexible tokenization using StreamingDataset.
Args:
local (str): Local dataset directory where shards are cached by split.
tokenizer (Tokenizer): The name of the HuggingFace tokenizer to use to
tokenize samples.
remote (str, optional): Download shards from this remote path or directory. If None, this
rank and worker's partition of the dataset must all exist locally. Defaults to ``None``.
split (str, optional): Which dataset split to use, if any. Defaults to ``None``.
shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to ``False``.
predownload (int, optional): Target number of samples ahead to download the shards of while
iterating. Defaults to ``100_000``.
keep_zip (bool, optional): Whether to keep or delete the compressed file when
decompressing downloaded shards. If set to None, keep if remote is local. Defaults to
``None``.
local (str): Local dataset directory where shards are cached by split.
remote (str, optional): Remote path or directory to download the dataset from. If ``None``,
its data must exist locally. StreamingDataset uses either ``streams`` or
``remote``/``local``. Defaults to ``None``.
split (str, optional): Which dataset split to use, if any. If provided, we stream from/to
the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``.
download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
download_timeout (float): Number of seconds to wait for a shard to download before raising
an exception. Defaults to ``60``.
validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
shards. Defaults to ``None``.
shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``.
num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with resumption.
If ``None``, defaults to the number of nodes of the initial run. Defaults to 128.
keep_zip (bool): Whether to keep or delete the compressed form when decompressing
downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
`False``.
epoch_size (int, optional): Number of samples to draw per epoch balanced across all
streams. If ``None``, takes its value from the total number of underlying samples.
Provide this field if you are weighting streams relatively to target a larger or
smaller epoch size. Defaults to ``None``.
predownload (int, optional): Target number of samples ahead to download the shards of while
iterating. Defaults to ``100_000``.
cache_limit (Union[int, str], optional) - Maximum size in bytes of this StreamingDataset's
shard cache. Before downloading a shard, the least recently used resident shard(s) may
be evicted (deleted from the local cache) in order to stay under the limit. Set to None
to disable shard eviction. Supports integer bytes as well as string human-readable
bytes (e.g., 100b, 64kb, 77mb, and so on). Defaults to None.
partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``.
num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
resumption. Defaults to ``None``, which is interpreted as the number of nodes of the
initial run.
batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
partitioned over the workers. Defaults to ``None``.
shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
``False``.
shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1b``.
shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``.
shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``.
sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``.
Defaults to ``balanced``.
sampling_granularity (int): When picking samples for a stream's final partial repeat,
how many samples to pick from the same shard at a time (``1`` for evenly balanced
across shards, ``1000`` to pick 1000 samples from the same shard at a time, etc).
Defaults to ``1``.
batching_method (str): Which batching method to use, either ``random``, ``stratified``, or
``per_stream``. Defaults to ``random``.
"""

def __init__(self,
local: str,
tokenizer: PreTrainedTokenizerBase,
local: str,
remote: Optional[str] = None,
split: Optional[str] = None,
shuffle: bool = False,
predownload: Optional[int] = 100_000,
keep_zip: bool = False,
download_retry: int = 2,
download_timeout: float = 60,
validate_hash: Optional[str] = None,
shuffle_seed: int = 9176,
num_canonical_nodes: Optional[int] = 128,
keep_zip: bool = False,
epoch_size: Optional[int] = None,
predownload: Optional[int] = None,
cache_limit: Optional[Union[int, str]] = None,
partition_algo: str = 'orig',
num_canonical_nodes: Optional[int] = None,
batch_size: Optional[int] = None,
shuffle: bool = False,
shuffle_algo: str = 'py1b',
shuffle_seed: int = 9176,
shuffle_block_size: int = 1 << 18,
sampling_method: str = 'balanced',
sampling_granularity: int = 1,
batching_method: str = 'random',
**kwargs: Any):

if len(kwargs) > 0:
Expand All @@ -125,18 +157,28 @@ def __init__(self,
)

# Build Dataset
super().__init__(local=local,
remote=remote,
split=split,
shuffle=shuffle,
predownload=predownload,
keep_zip=keep_zip,
download_retry=download_retry,
download_timeout=download_timeout,
validate_hash=validate_hash,
shuffle_seed=shuffle_seed,
num_canonical_nodes=num_canonical_nodes,
batch_size=batch_size)
super().__init__(
local=local,
remote=remote,
split=split,
download_retry=download_retry,
download_timeout=download_timeout,
validate_hash=validate_hash,
keep_zip=keep_zip,
epoch_size=epoch_size,
predownload=predownload,
cache_limit=cache_limit,
partition_algo=partition_algo,
num_canonical_nodes=num_canonical_nodes,
batch_size=batch_size,
shuffle=shuffle,
shuffle_algo=shuffle_algo,
shuffle_seed=shuffle_seed,
shuffle_block_size=shuffle_block_size,
sampling_method=sampling_method,
sampling_granularity=sampling_granularity,
batching_method=batching_method,
)

self.tokenizer = tokenizer

Expand Down
13 changes: 12 additions & 1 deletion llmfoundry/data/text_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,14 @@ class StreamingTextDataset(StreamingDataset):
shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1b``.
shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``.
shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``.
sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``. Defaults to ``balanced``.
sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``.
Defaults to ``balanced``.
sampling_granularity (int): When picking samples for a stream's final partial repeat,
how many samples to pick from the same shard at a time (``1`` for evenly balanced
across shards, ``1000`` to pick 1000 samples from the same shard at a time, etc).
Defaults to ``1``.
batching_method (str): Which batching method to use, either ``random``, ``stratified``, or
``per_stream``. Defaults to ``random``.
"""

def __init__(self,
Expand All @@ -91,6 +98,8 @@ def __init__(self,
shuffle_seed: int = 9176,
shuffle_block_size: int = 1 << 18,
sampling_method: str = 'balanced',
sampling_granularity: int = 1,
batching_method: str = 'random',
**kwargs: Any):

group_method = kwargs.pop('group_method', None)
Expand Down Expand Up @@ -138,6 +147,8 @@ def __init__(self,
shuffle_seed=shuffle_seed,
shuffle_block_size=shuffle_block_size,
sampling_method=sampling_method,
sampling_granularity=sampling_granularity,
batching_method=batching_method,
)
self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
'mosaicml[libcloud,wandb,mlflow]>=0.16.1,<0.17',
'accelerate>=0.20,<0.21', # for HF inference `device_map`
'transformers>=4.33,<4.34',
'mosaicml-streaming>=0.5.1,<0.6',
'mosaicml-streaming>=0.6,<0.7',
'torch>=1.13.1,<2.1.1',
'datasets>=2.14.5,<2.15',
'fsspec==2023.6.0', # newer version results in a bug in datasets that duplicates data
Expand Down
58 changes: 58 additions & 0 deletions tests/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import torch
from composer.utils import dist, using_torch_2
from omegaconf import OmegaConf as om
from streaming import MDSWriter

from llmfoundry import (build_finetuning_dataloader,
build_text_denoising_dataloader)
Expand Down Expand Up @@ -42,6 +43,25 @@ def get_abs_data_path(data_local: str):
return os.path.join(os.getcwd(), data_local)


def build_mock_ft_streaming_dataset(data_path: str, split: str):
columns = {'prompt': 'str', 'response': 'str'}

dataset = [{
'prompt': 'This is just a test1',
'response': 'Hello World1'
}, {
'prompt': 'This is just a test2',
'response': 'Hello world2'
}]

output_path = os.path.join(data_path, split)

with MDSWriter(columns=columns, out=output_path,
compression=None) as output_writer:
for sample in dataset:
output_writer.write(sample)


@pytest.mark.parametrize('tokenizer_name', ['gpt2', 'facebook/opt-125m'])
@pytest.mark.parametrize('pretokenize', [False, True])
def test_correct_padding(tokenizer_name: str,
Expand Down Expand Up @@ -414,6 +434,44 @@ def test_finetuning_dataloader_custom_split_remote(
_ = build_finetuning_dataloader(cfg, tokenizer, 4)


def test_finetuning_dataloader_streaming(tmp_path: pathlib.Path):
max_seq_len = 2048

remote_path = os.path.join(tmp_path, 'remote')
local_path = os.path.join(tmp_path, 'local')

build_mock_ft_streaming_dataset(remote_path, 'train')

cfg = {
'name': 'finetuning',
'dataset': {
'remote': remote_path,
'local': local_path,
'split': 'train',
'max_seq_len': 2048,
'decoder_only_format': True,
'allow_pad_trimming': False,
'packing_ratio': None,
'shuffle': True,
},
'drop_last': False,
'num_workers': 4,
'pin_memory': False,
'prefetch_factor': 2,
'persistent_workers': False,
'timeout': 0
}

cfg = om.create(cfg)

tokenizer = build_tokenizer(
tokenizer_name='gpt2',
tokenizer_kwargs={'model_max_length': max_seq_len},
)

_ = build_finetuning_dataloader(cfg, tokenizer, 4)


@pytest.mark.parametrize('add_bad_data_dropped', [True, False])
@pytest.mark.parametrize('add_bad_data_error', [True, False])
def test_malformed_data(
Expand Down

0 comments on commit 229ab4f

Please sign in to comment.