From bdb57250aeced0f2e529c8af40844f7b9a2eb5c7 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sun, 24 Dec 2023 17:08:31 -0800 Subject: [PATCH] Break Stream in two, rewrite core part of Stream, in prep for new shards (#547) * Stream -> StreamCore (Shard args) + Stream (all). * Drop pointless underscore vars. * Auto keyword. * Fix handling of generating local when default split. * Clean up. * Improve apply_defaults(). * Plug it in, propagate rewrites outward. * Adjust keep_old_phases vs keep_zip handling. * Adjust hash args handling. * Default apply_defaults() args to auto so you don't have to provide them all. * Update usage in test cases. * Fix edge case. * Another tweak. --- simulation/core/sim_dataset.py | 116 +++++----- simulation/core/yaml_processing.py | 30 ++- streaming/dataset.py | 96 ++++---- streaming/format/score.py | 346 +++++++++++++++++++++++++++++ streaming/phasing.py | 66 ++++++ streaming/stream.py | 345 +++++++++++++--------------- streaming/util/__init__.py | 5 +- streaming/util/auto.py | 36 +++ tests/test_stream.py | 11 +- 9 files changed, 762 insertions(+), 289 deletions(-) create mode 100644 streaming/format/score.py create mode 100644 streaming/phasing.py create mode 100644 streaming/util/auto.py diff --git a/simulation/core/sim_dataset.py b/simulation/core/sim_dataset.py index c3859fb6d..a15dde2b0 100644 --- a/simulation/core/sim_dataset.py +++ b/simulation/core/sim_dataset.py @@ -33,41 +33,50 @@ class SimulationDataset(StreamingDataset): nodes (int): Number of nodes. devices (int): Number of devices. workers (int): Number of workers. - streams (Optional[Sequence[Stream]]): One or more streams to stream/cache samples from, + streams (Sequence[Stream], optional): One or more streams to stream/cache samples from, which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. - remote (Optional[str]): Remote path or directory to download the dataset from. If ``None``, + remote (str, optional): Remote path or directory to download the dataset from. If ``None``, its data must exist locally. StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. - local (Optional[str]): Local working directory to download shards to. This is where shards + local (str, optional): Local working directory to download shards to. This is where shards are cached while they are being used. Uses a temp directory if not set. StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. - split (Optional[str]): Which dataset split to use, if any. If provided, we stream from/to + split (str, optional): Which dataset split to use, if any. If provided, we stream from/to the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``. - download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``. - download_timeout (float): Number of seconds to wait for a shard to download before raising - an exception. Defaults to ``60``. - validate_hash (Optional[str]): Optional hash or checksum algorithm to use to validate - shards. Defaults to ``None``. - keep_zip (bool): Whether to keep or delete the compressed form when decompressing - downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to - ``False``. - epoch_size (Union[int, str], optional): Number of samples to draw per epoch balanced across all - streams. If ``None``, takes its value from the total number of underlying samples. 
- Provide this field if you are weighting streams relatively to target a larger or - smaller epoch size. Defaults to ``None``. Can also take in human-readable number - abbreviations (e.g., ``"100k"``, ``"64M"``, ``"77b"``, and so on). Defaults to ``None``. + download_retry (int): Number of download re-attempts before raising an error. Defaults to + ``2``. + download_timeout (str | float): Time in seconds to wait for a file download to complete + before raising an error. Streaming duration shorthand (e.g., ``1m23s``) is also + accepted. Defaults to ``1m``. + hash_algos (str | Sequence[str], optional): Ranked list of hashing algorithms to try. + Defaults to ``None``. + validate_hash (str, optional): Deprecated. See ``hash_algos``. Defaults to ``None``. + keep_old_phases (str): Which old phases of shard files to cache (until shard eviction). + Must be one of ``nil``, ``src``, or ``all``. Defaults to ``nil``. + keep_zip (bool, optional): Deprecated. See ``keep_old_phases``. Defaults to ``None``. + epoch_size (Union[str, int], optional): Number of samples to draw per epoch balanced + across all streams. If ``None``, takes its value from the total number of underlying + samples. Provide this field if you are weighting streams relatively to target a larger + or smaller epoch size. Defaults to ``None``. Can also take in human-readable number + abbreviations (e.g., ``"100k"``, ``"64M"``, ``"77b"``, etc). Defaults to ``None``. predownload (int, optional): Target number of samples to download per worker in advance of current sample. Workers will attempt to download ahead by this many samples during, but not before, training. Recommendation is to provide a value greater than per device batch size to ensure at-least per device batch size number of samples cached locally. If ``None``, its value is set to ``8 * batch_size``. Defaults to ``None``. - cache_limit (Union[int, str], optional): Maximum size in bytes of this StreamingDataset's + cache_limit (Union[str, int], optional): Maximum size in bytes of this StreamingDataset's shard cache. Before downloading a shard, the least recently used resident shard(s) may be evicted (deleted from the local cache) in order to stay under the limit. Set to ``None`` to disable shard eviction. Supports integer bytes as well as string human-readable bytes (e.g., ``100b``, ``64kb``, ``77mb``, and so on). Defaults to ``None``. + sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``. + Defaults to ``balanced``. + sampling_granularity (int): When picking samples for a stream's final partial repeat, + how many samples to pick from the same shard at a time (``1`` for evenly balanced + across shards, ``1000`` to pick 1000 samples from the same shard at a time, etc). + Defaults to ``1``. partition_algo (str): Which partitioning algorithm to use. Defaults to ``relaxed``. num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with resumption. The sample space is divided evenly according to the number of canonical @@ -86,17 +95,11 @@ class SimulationDataset(StreamingDataset): shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to ``False``. shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1e``. - shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. + shuffle_seed (int): Seed for deterministic data shuffling. Defaults to ``9176``. shuffle_block_size (int, optional): Unit of shuffle. 
A canonical node's samples are split into blocks of this size, and samples within each block are shuffled. If ``None``, its value is calculated as ``max(4_000_000 // num_canonical_nodes), 1 << 18)``. Defaults to ``None``. - sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``. - Defaults to ``balanced``. - sampling_granularity (int): When picking samples for a stream's final partial repeat, - how many samples to pick from the same shard at a time (``1`` for evenly balanced - across shards, ``1000`` to pick 1000 samples from the same shard at a time, etc). - Defaults to ``1``. batching_method (str): Which batching method to use, either ``random``, ``stratified``, or ``per_stream``. Defaults to ``random``. allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code @@ -105,6 +108,7 @@ class SimulationDataset(StreamingDataset): """ def __init__(self, + *, nodes: int, devices: int, workers: int, @@ -113,12 +117,16 @@ def __init__(self, local: Optional[str] = None, split: Optional[str] = None, download_retry: int = 2, - download_timeout: float = 60, + download_timeout: Union[str, float] = '1m', + hash_algos: Optional[Union[str, Sequence[str]]] = None, validate_hash: Optional[str] = None, - keep_zip: bool = False, - epoch_size: Optional[Union[int, str]] = None, + keep_old_phases: str = 'nil', + keep_zip: Optional[bool] = None, + epoch_size: Optional[Union[str, int]] = None, predownload: Optional[int] = None, - cache_limit: Optional[Union[int, str]] = None, + cache_limit: Optional[Union[str, int]] = None, + sampling_method: str = 'balanced', + sampling_granularity: int = 1, partition_algo: str = 'relaxed', num_canonical_nodes: Optional[int] = None, batch_size: Optional[int] = None, @@ -126,29 +134,27 @@ def __init__(self, shuffle_algo: str = 'py1e', shuffle_seed: int = 9176, shuffle_block_size: Optional[int] = None, - sampling_method: str = 'balanced', - sampling_granularity: int = 1, batching_method: str = 'random', allow_unsafe_types: bool = False) -> None: - # Time how long it takes for StreamingDataset instantiation t0 = time.time() - # Global arguments (which do not live in Streams). self.nodes = nodes self.devices = devices self.workers = workers - self.partition_algo = partition_algo + + # Global arguments (which do not live in Streams). self.predownload = predownload + self.sampling_method = sampling_method + self.sampling_granularity = sampling_granularity + self.partition_algo = partition_algo + self.num_canonical_nodes = num_canonical_nodes self.batch_size = batch_size self.shuffle = shuffle self.shuffle_algo = shuffle_algo self.shuffle_seed = shuffle_seed self.shuffle_block_size = shuffle_block_size - self.sampling_method = sampling_method - self.sampling_granularity = sampling_granularity self.batching_method = batching_method - self.num_canonical_nodes = num_canonical_nodes self.allow_unsafe_types = allow_unsafe_types self.initial_physical_nodes = nodes @@ -197,26 +203,24 @@ def __init__(self, # Initialize the Stream defaults and normalize to a list of Streams. 
if streams: - default = { - 'remote': remote, - 'local': local, - 'split': split, - 'download_retry': download_retry, - 'download_timeout': download_timeout, - 'validate_hash': validate_hash, - 'keep_zip': keep_zip, - } for stream in streams: - stream.apply_default(default) + stream.apply_defaults(split=split, + download_retry=download_retry, + download_timeout=download_timeout, + hash_algos=hash_algos, + validate_hash=validate_hash, + keep_old_phases=keep_old_phases, + keep_zip=keep_zip) else: - default = Stream(remote=remote, + streams = Stream(remote=remote, local=local, split=split, download_retry=download_retry, download_timeout=download_timeout, + hash_algos=hash_algos, validate_hash=validate_hash, - keep_zip=keep_zip) - streams = [default] + keep_old_phases=keep_old_phases, + keep_zip=keep_zip), # Validate the stream weighting scheme (relative or absolute) to catch errors before we go # to the trouble of loading them. @@ -231,10 +235,10 @@ def __init__(self, indices_created = [] for stream_idx, stream in enumerate(self.streams): if stream.remote: - filepath = os.path.join(stream.remote, stream.split, get_index_basename()) + filepath = os.path.join(stream.remote, stream.split or '', get_index_basename()) indices_created.append(0) else: - filepath = os.path.join(stream.local, stream.split, get_index_basename()) + filepath = os.path.join(stream.local, stream.split or '', get_index_basename()) # This suffix means a mock index file was created. Have to clean up later. if stream.local.split('_')[-1] == 'indexcreated': indices_created.append(2) @@ -245,9 +249,9 @@ def __init__(self, 'path': filepath, 'local': stream.local, 'remote': stream.remote, - 'proportion': stream._proportion, - 'repeat': stream._repeat, - 'choose': stream._choose + 'proportion': getattr(stream, 'proportion', None), + 'repeat': getattr(stream, 'repeat', None), + 'choose': getattr(stream, 'choose', None), } # Initialize the SimulationWorld, which tells us about nodes/devices/workers @@ -267,7 +271,7 @@ def __init__(self, logger.info(f' Processing index file for stream {stream_id + 1}') stream_shards = stream.get_shards(self.world, self.allow_unsafe_types) num_stream_samples = sum(map(len, stream_shards)) - index_filename = os.path.join(stream.local, stream.split, get_index_basename()) + index_filename = os.path.join(stream.local, stream.split or '', get_index_basename()) index_filenames.append(index_filename) local_foldernames.append(stream.local) if not num_stream_samples: diff --git a/simulation/core/yaml_processing.py b/simulation/core/yaml_processing.py index e1ddefab2..d0dbae699 100644 --- a/simulation/core/yaml_processing.py +++ b/simulation/core/yaml_processing.py @@ -187,11 +187,29 @@ def create_simulation_dataset(nodes: int, devices: int, workers: int, global_bat sampling_granularity = train_dataset.get('sampling_granularity', 1) batching_method = train_dataset.get('batching_method', 'random') - dataset = SimulationDataset(nodes, devices, workers, streams, remote, local, split, - download_retry, download_timeout, validate_hash, keep_zip, - epoch_size, predownload, cache_limit, partition_algo, - num_canonical_nodes, batch_size, shuffle, shuffle_algo, - shuffle_seed, shuffle_block_size, sampling_method, - sampling_granularity, batching_method) + dataset = SimulationDataset(nodes=nodes, + devices=devices, + workers=workers, + streams=streams, + remote=remote, + local=local, + split=split, + download_retry=download_retry, + download_timeout=download_timeout, + validate_hash=validate_hash, + keep_zip=keep_zip, + 
epoch_size=epoch_size, + predownload=predownload, + cache_limit=cache_limit, + partition_algo=partition_algo, + num_canonical_nodes=num_canonical_nodes, + batch_size=batch_size, + shuffle=shuffle, + shuffle_algo=shuffle_algo, + shuffle_seed=shuffle_seed, + shuffle_block_size=shuffle_block_size, + sampling_method=sampling_method, + sampling_granularity=sampling_granularity, + batching_method=batching_method) return dataset diff --git a/streaming/dataset.py b/streaming/dataset.py index e320d7426..9ed84aba8 100644 --- a/streaming/dataset.py +++ b/streaming/dataset.py @@ -248,15 +248,19 @@ class StreamingDataset(Array, IterableDataset): StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. split (str, optional): Which dataset split to use, if any. If provided, we stream from/to the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``. - download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``. - download_timeout (float): Number of seconds to wait for a shard to download before raising - an exception. Defaults to ``60``. - validate_hash (str, optional): Optional hash or checksum algorithm to use to validate - shards. Defaults to ``None``. - keep_zip (bool): Whether to keep or delete the compressed form when decompressing - downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to - ``False``. - epoch_size (Union[int, str], optional): Number of samples to draw per epoch balanced + download_retry (int): Number of download re-attempts before raising an error. Defaults to + ``2``. + download_timeout (str | float): Time in seconds to wait for a file download to complete + before raising an error. Streaming duration shorthand (e.g., ``1m23s``) is also + accepted. Defaults to ``1m``. + hash_algos (str | Sequence[str], optional): Ranked list of hashing algorithms to try. + Defaults to ``None``. + validate_hash (str, optional): Deprecated. See ``hash_algos``. Defaults to ``None``. + keep_old_phases (str, optional): Which old phases of shard files to cache (until shard + eviction). If set, must be one of ``nil``, ``src``, or ``all``. Defaults to ``None``, + which uses ``keep_zip``, falling back to ``nil``. + keep_zip (bool, optional): Deprecated. See ``keep_old_phases``. Defaults to ``None``. + epoch_size (Union[str, int], optional): Number of samples to draw per epoch balanced across all streams. If ``None``, takes its value from the total number of underlying samples. Provide this field if you are weighting streams relatively to target a larger or smaller epoch size. Defaults to ``None``. Can also take in human-readable number @@ -266,7 +270,7 @@ class StreamingDataset(Array, IterableDataset): but not before, training. Recommendation is to provide a value greater than per device batch size to ensure at-least per device batch size number of samples cached locally. If ``None``, its value is set to ``8 * batch_size``. Defaults to ``None``. - cache_limit (Union[int, str], optional): Maximum size in bytes of this StreamingDataset's + cache_limit (Union[str, int], optional): Maximum size in bytes of this StreamingDataset's shard cache. Before downloading a shard, the least recently used resident shard(s) may be evicted (deleted from the local cache) in order to stay under the limit. Set to ``None`` to disable shard eviction. 
Supports integer bytes as well as string @@ -315,12 +319,14 @@ def __init__(self, local: Optional[str] = None, split: Optional[str] = None, download_retry: int = 2, - download_timeout: float = 60, + download_timeout: Union[str, float] = '1m', + hash_algos: Optional[Union[str, Sequence[str]]] = None, validate_hash: Optional[str] = None, - keep_zip: bool = False, - epoch_size: Optional[Union[int, str]] = None, + keep_old_phases: Optional[str] = None, + keep_zip: Optional[bool] = None, + epoch_size: Optional[Union[str, int]] = None, predownload: Optional[int] = None, - cache_limit: Optional[Union[int, str]] = None, + cache_limit: Optional[Union[str, int]] = None, sampling_method: str = 'balanced', sampling_granularity: int = 1, partition_algo: str = 'relaxed', @@ -407,26 +413,26 @@ def __init__(self, # Initialize the Stream defaults and normalize to a list of Streams. if streams: - default = { - 'remote': remote, - 'local': local, - 'split': split, - 'download_retry': download_retry, - 'download_timeout': download_timeout, - 'validate_hash': validate_hash, - 'keep_zip': keep_zip, - } for stream in streams: - stream.apply_default(default) + stream.apply_defaults(split=split, + download_retry=download_retry, + download_timeout=download_timeout, + hash_algos=hash_algos, + validate_hash=validate_hash, + keep_old_phases=keep_old_phases, + keep_zip=keep_zip) else: - default = Stream(remote=remote, - local=local, - split=split, - download_retry=download_retry, - download_timeout=download_timeout, - validate_hash=validate_hash, - keep_zip=keep_zip) - streams = [default] + stream = Stream(remote=remote, + local=local, + split=split, + download_retry=download_retry, + download_timeout=download_timeout, + hash_algos=hash_algos, + validate_hash=validate_hash, + keep_old_phases=keep_old_phases, + keep_zip=keep_zip) + stream.apply_defaults() + streams = stream, # Validate the stream weighting scheme (relative or absolute) to catch errors before we go # to the trouble of loading them. @@ -455,7 +461,8 @@ def __init__(self, stream_shards = stream.get_shards(world, self.allow_unsafe_types) num_stream_samples = sum(map(len, stream_shards)) if not num_stream_samples: - index_filename = os.path.join(stream.local, stream.split, get_index_basename()) + index_filename = os.path.join(stream.local, stream.split or '', + get_index_basename()) raise RuntimeError(f'Stream contains no samples: {index_filename}.') stream_per_shard += [stream_id] * len(stream_shards) self.shard_offset_per_stream[stream_id] = len(self.shards) @@ -502,11 +509,21 @@ def __init__(self, self.length = ceil(self.epoch_size / world.num_ranks) # Register/lookup our shared memory prefix and filelock root directory. 
- streams_local = [os.path.abspath(os.path.join(x.local, x.split)) for x in streams] - streams_remote = [ - os.path.join(x.remote, x.split) if x.remote is not None else None for x in streams - ] - self._shm_prefix_int, self._locals_shm = get_shm_prefix(streams_local, streams_remote, + stream_locals = [] + for stream in streams: + local = os.path.join(stream.local, stream.split or '') + local = os.path.abspath(local) + stream_locals.append(local) + + stream_remotes = [] + for stream in streams: + if stream.remote is not None: + remote = os.path.join(stream.remote, stream.split or '') + else: + remote = None + stream_remotes.append(remote) + + self._shm_prefix_int, self._locals_shm = get_shm_prefix(stream_locals, stream_remotes, world) self._filelock_root = os.path.join(os.path.sep, 'tmp', 'streaming') os.makedirs(self._filelock_root, exist_ok=True) @@ -1134,7 +1151,8 @@ def prepare_shard(self, shard_id: int, blocking: bool = True) -> None: # We may need to decompress the shard (if local dir just contains zips). raw_info, _ = shard.file_pairs[0] # Each file pair is present in the same way. - raw_filename = os.path.join(stream.local, stream.split, raw_info.basename) # Find raw. + raw_filename = os.path.join(stream.local, stream.split or '', + raw_info.basename) # Find raw. if not os.path.isfile(raw_filename): # Is raw missing? self._shard_states[shard_id] = _ShardState.PREPARING # Lock the shard. lock.release() # Unblock other workers. diff --git a/streaming/format/score.py b/streaming/format/score.py new file mode 100644 index 000000000..9e4e5cb1f --- /dev/null +++ b/streaming/format/score.py @@ -0,0 +1,346 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""A dataset, or sub-dataset if mixing, from which we stream/cache samples.""" + +import os +from hashlib import blake2s +from tempfile import gettempdir +from typing import List, Optional, Sequence, Union +from warnings import warn + +from streaming.distributed import barrier, get_local_rank +from streaming.hashing import is_hash +from streaming.phasing import get_phasings, get_safe_phasing, is_phasing +from streaming.util.auto import Auto, auto +from streaming.util.shorthand import normalize_duration + + +def _normalize_download_retry(download_retry: int) -> int: + """Normalize ``download_retry``. + + Args: + download_retry (int): Input download retry. + + Returns: + int: Normalized download retry. + """ + if download_retry < 0: + raise ValueError(f'Download retry must be non-negative, but got: {download_retry}.') + return download_retry + + +def _normalize_download_timeout(download_timeout: Union[str, float]) -> float: + """Normalize ``download_timeout``. + + Args: + download_timeout (str | float): Input download timeout. + + Returns: + float: Normalized download timeout. + """ + norm_download_timeout = normalize_duration(download_timeout) + if norm_download_timeout <= 0: + raise ValueError(f'Download timeout must be positive, but got: {download_timeout}.') + return norm_download_timeout + + +def _normalize_hash_algos(hash_algos: Optional[Union[str, Sequence[str], Auto]], + validate_hash: Optional[str]) -> List[str]: + """Normalize ``hash_algos`` and ``validate_hash`` (deprecated argument). + + Args: + hash_algos (str | Sequence[str] | Auto, optional): Input hash algos. + validate_hash (str, optional): Input validate hash. + + Returns: + List[str]: Normalized hash algos. + """ + # Normalize `hash_algos`. 
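+    # `hash_algos` may arrive as None/empty (unset), a single algorithm name, a sequence of
+    # names in priority order, or the `auto` sentinel (when only the deprecated `validate_hash`
+    # was supplied); anything that is not a str or Sequence is treated as unset here and is
+    # resolved against `validate_hash` below.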
+    if not hash_algos:
+        norm_hash_algos = None
+    elif isinstance(hash_algos, str):
+        norm_hash_algos = [hash_algos]
+    elif isinstance(hash_algos, Sequence):
+        norm_hash_algos = list(hash_algos)
+    else:
+        norm_hash_algos = None
+
+    # Normalize `validate_hash`.
+    if validate_hash:
+        warn(f'`validate_hash` is deprecated. Please use `hash_algos` instead, which also ' +
+             f'accepts a ranked list specifying the hashing algorithms to attempt to apply.')
+        norm_validate_hash = [validate_hash]
+    else:
+        norm_validate_hash = None
+
+    # Compare and combine normalized `hash_algos` and normalized `validate_hash`.
+    if not norm_hash_algos:
+        if not norm_validate_hash:
+            algos = []
+        else:
+            algos = norm_validate_hash
+    else:
+        if not norm_validate_hash:
+            algos = norm_hash_algos
+        else:
+            if norm_hash_algos != norm_validate_hash:
+                raise ValueError(f'You have specified hashes to check in both the old way and ' +
+                                 f'the new way, and also differently: `hash_algos` = ' +
+                                 f'{hash_algos}, `validate_hash` = {validate_hash}.')
+            algos = norm_hash_algos
+
+    # Check each hash algo.
+    for algo in algos:
+        if not is_hash(algo):
+            raise ValueError(f'Unknown hash algorithm: {algo}.')
+
+    return algos
+
+
+def _normalize_keep_zip(keep_zip: bool) -> str:
+    """Normalize ``keep_zip`` (deprecated argument).
+
+    Args:
+        keep_zip (bool): Input keep zip.
+
+    Returns:
+        str: Normalized phasing.
+    """
+    warn(f'`keep_zip` is deprecated. Please use `keep_old_phases="src"` instead. You stream ' +
+         f'the earliest form of a file (say, zipped), and access samples from its latest ' +
+         f'form (say, after unzipping). The intent of the argument is: do we keep that ' +
+         f'earliest form, so we will be able to stream with this dir as a remote? Options ' +
+         f'for `keep_old_phases` are {sorted(get_phasings())}.')
+    return 'src' if keep_zip else 'nil'
+
+
+def _normalize_keep_old_phases(keep_old_phases: Optional[str], keep_zip: Optional[bool]) -> str:
+    """Normalize ``keep_old_phases`` and ``keep_zip`` (deprecated argument).
+
+    Args:
+        keep_old_phases (str, optional): Input keep old phases.
+        keep_zip (bool, optional): Input keep zip.
+
+    Returns:
+        str: Normalized phasing.
+    """
+    if keep_old_phases is None:
+        if keep_zip is None:
+            phasing = 'nil'
+        else:
+            phasing = _normalize_keep_zip(keep_zip)
+    else:
+        if keep_zip is None:
+            phasing = keep_old_phases
+        else:
+            norm_keep_zip = _normalize_keep_zip(keep_zip)
+            if keep_old_phases != norm_keep_zip:
+                raise ValueError(f'You have specified old phases to keep in both the old way ' +
+                                 f'and the new way, and also differently: `keep_old_phases` = ' +
+                                 f'{keep_old_phases}, `keep_zip` = {keep_zip}.')
+            phasing = keep_old_phases
+
+    if not is_phasing(phasing):
+        raise ValueError(f'Unknown phasing (i.e., `keep_old_phases` or `keep_zip`): {phasing}.')
+
+    return phasing
+
+
+def _generate_local(remote: str, split: Optional[str]) -> str:
+    """Derive a local dirname deterministically from remote and optional split.
+
+    Args:
+        remote (str): Remote path. Must be provided.
+        split (str, optional): Optional split.
+
+    Returns:
+        str: Local path.
+    """
+    data = remote.encode('utf-8')
+    hex_digest = blake2s(data, digest_size=16).hexdigest()
+    return os.path.join(gettempdir(), hex_digest, split or '')
+
+
+class StreamCore:
+    """The core configuration of a Streaming dataset directory (Stream).
+
+    A StreamingDataset is composed of one/multiple/many Streams.
+
+    Notes:
+      * Paths: You must provide ``remote`` and/or ``local``. If no ``remote``, the dataset must be
+        cached. If no ``local``, it deterministically picks ``local`` based on the other paths.
+      * Splits: This is implemented as a sub-path which is appended to ``remote`` and/or ``local``
+        in order to derive the root of this Streaming dataset directory (Stream), which all other
+        dataset paths descend from. E.g., ``/path/to/dataset/index.json`` if ``split=None``, vs
+        ``/path/to/dataset/train/index.json`` if ``split='train'``.
+      * Hashing: Trying a hash algorithm means if the Streaming index records the expected hex
+        digest for this hash of this file, we apply the hash, compare the result to the expected,
+        and then we are done: either exit success on match, or raise an error on mismatch. If we
+        are given hash algorithms to apply but the index notes none of them for a file, we raise
+        an error. Typically, because of the somewhat severe performance impact, hashes are not
+        used in training.
+      * Phasing: Streaming downloads shards as their first phase and accesses samples from their
+        last phase, to which they are converted on the fly. Do we keep the old phases (until the
+        shard is evicted)? Options are ``nil``, ``src``, and ``all``. ``safe_keep_old_phases`` is
+        derived from ``keep_old_phases`` -- it is the same, unless there is no separate remote, in
+        which case ``nil`` is converted to ``src`` (i.e., keep first phase) in order to prevent
+        making the streaming dataset directory un-streamable by using it.
+
+    Args:
+        remote (str, optional): Remote path to stream the dataset from. If ``None``, dataset must
+            be complete locally. Defaults to ``None``.
+        local (str, optional): Local working directory to stream the dataset to. Uses a temp
+            directory if not set. Defaults to ``None``.
+        split (str | Auto, optional): Which dataset split to use, if any. Set to ``auto`` to
+            inherit from StreamingDataset. Defaults to ``auto``.
+        download_retry (int | Auto): Number of download re-attempts before raising an error. Set
+            to ``auto`` to inherit from StreamingDataset. Defaults to ``auto``.
+        download_timeout (str | float | Auto, optional): Time in seconds to wait for a file
+            download to complete before raising an error. Streaming duration shorthand (e.g.,
+            ``1m23s``) is also accepted. Set to ``auto`` to inherit from StreamingDataset.
+            Defaults to ``auto``.
+        hash_algos (str | Sequence[str] | Auto, optional): Ranked list of hashing algorithms to
+            try. Set to ``auto`` to inherit from StreamingDataset. Defaults to ``auto``.
+        validate_hash (str, optional): Deprecated. See ``hash_algos``. Defaults to ``None``.
+        keep_old_phases (str | Auto, optional): Which old phases of shard files to cache (until
+            shard eviction). Must be one of ``nil``, ``src``, or ``all``. Set to ``auto`` to
+            inherit from StreamingDataset. If ``None``, uses ``keep_zip``, falling back to
+            ``nil``. Defaults to ``auto``.
+        keep_zip (bool, optional): Deprecated. See ``keep_old_phases``. Defaults to ``None``.
+ """ + + def __init__( + self, + *, + remote: Optional[str] = None, + local: Optional[str] = None, + split: Optional[Union[str, Auto]] = auto, + download_retry: Union[int, Auto] = auto, + download_timeout: Union[str, float, Auto] = auto, + hash_algos: Optional[Union[str, Sequence[str], Auto]] = auto, + validate_hash: Optional[str] = None, + keep_old_phases: Optional[Union[str, Auto]] = auto, + keep_zip: Optional[bool] = None, + ) -> None: + self.remote = remote + + if local is not None: + self.local = local + + if remote is None and local is None: + raise ValueError('Remote and/or local paths must be provided.') + + if not isinstance(split, Auto): + self.split = split + + if local is None and remote is not None and split and isinstance(split, str): + self.local = _generate_local(remote, split) + + if not isinstance(download_retry, Auto): + self.download_retry = _normalize_download_retry(download_retry) + + if not isinstance(download_timeout, Auto): + self.download_timeout = _normalize_download_timeout(download_timeout) + + if not isinstance(hash_algos, Auto) or validate_hash: + self.hash_algos = _normalize_hash_algos(hash_algos, validate_hash) + + if not isinstance(keep_old_phases, Auto): + self.keep_old_phases = _normalize_keep_old_phases(keep_old_phases, keep_zip) + elif keep_zip is not None: + self.keep_old_phases = _normalize_keep_zip(keep_zip) + + if hasattr(self, 'keep_old_phases') and hasattr(self, 'local'): + self.safe_keep_old_phases = get_safe_phasing(self.keep_old_phases, self.remote, + self.local) + + def apply_defaults( + self, + *, + split: Optional[Union[str, Auto]] = auto, + download_retry: Union[int, Auto] = auto, + download_timeout: Union[str, float, Auto] = auto, + hash_algos: Optional[Union[str, Sequence[str], Auto]] = auto, + validate_hash: Optional[str] = None, + keep_old_phases: Optional[Union[str, Auto]] = auto, + keep_zip: Optional[bool] = None, + ) -> None: + """Apply defaults, setting any unknown fields. + + Args: + split (str | Auto, optional): Which dataset split to use, if any. If ``auto``, this + field is skipped and must already have a valid value. Defaults to ``auto``. + download_retry (int | Auto): Number of download re-attempts before raising an error. + If ``auto``, this field is skipped and must already have a valid value. Defaults to + ``auto``. + download_timeout (str | float | Auto, optional): Time in seconds to wait for a file + download to complete before raising an error. Streaming duration shorthand (e.g., + ``1m23s``) is also accepted. If ``auto``, this field is skipped and must already + have a valid value. Defaults to ``auto``. + hash_algos (str | Sequence[str] | Auto, optional): Ranked list of hashing algorithms to + try. If ``auto``, this field is skipped and must already have a valid value. + Defaults to ``auto``. + validate_hash (str, optional): Deprecated. See ``hash_algos``. Defaults to ``None``. + keep_old_phases (str | Auto, optional): Which old phases of shard files to cache (until + shard eviction). If set, must be one of ``nil``, ``src``, or ``all``. If ``None``, + uses ``keep_zip``, falling back to ``nil``. If ``auto``, this field is skipped and + must already have a valid value. Defaults to ``auto``. + keep_zip (bool, optional): Deprecated. See ``keep_old_phases``. Defaults to ``None``. 
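+
+        For example, ``StreamingDataset`` calls this method on each of its Streams: a field left
+        as ``auto`` at Stream construction time (say, ``download_timeout``) is filled in with the
+        dataset-wide value passed here, while a field the Stream already set explicitly is left
+        untouched.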
+ """ + if not hasattr(self, 'split'): + if isinstance(split, Auto): + raise RuntimeError('Split was not set.') + else: + self.split = split + + if not hasattr(self, 'local'): + if self.remote is None: + raise ValueError('`remote` and/or `local` path must be provided.') + self.local = _generate_local(self.remote, self.split) + + if not get_local_rank(): + if os.path.exists(self.local): + raise ValueError( + f'Could not create a temporary local directory {self.local}. Either ' + + f'delete the directory or specify a unique local directory with the ' + + f'`local` value.') + os.makedirs(self.local) + barrier() + + if not hasattr(self, 'download_retry'): + if isinstance(download_retry, Auto): + raise RuntimeError('Download retry was not set.') + else: + self.download_retry = _normalize_download_retry(download_retry) + + if not hasattr(self, 'download_timeout'): + if isinstance(download_timeout, Auto): + raise RuntimeError('Download timeout was not set.') + else: + self.download_timeout = _normalize_download_timeout(download_timeout) + + if not hasattr(self, 'hash_algos'): + if isinstance(hash_algos, Auto): + raise RuntimeError('Hash algos was not set.') + else: + self.hash_algos = _normalize_hash_algos(hash_algos, validate_hash) + + if not hasattr(self, 'keep_old_phases'): + if isinstance(keep_old_phases, Auto): + raise RuntimeError('Keep old phases was not set.') + else: + self.keep_old_phases = _normalize_keep_old_phases(keep_old_phases, keep_zip) + + if not hasattr(self, 'safe_keep_old_phases'): + self.safe_keep_old_phases = get_safe_phasing(self.keep_old_phases, self.remote, + self.local) + + @property + def safe_keep_zip(self) -> bool: + """Derive ``safe_keep_zip`` for existing code. + + Returns: + bool: Whether to keep the zip phase of files. + """ + return self.safe_keep_old_phases != 'nil' diff --git a/streaming/phasing.py b/streaming/phasing.py new file mode 100644 index 000000000..88ff8ad43 --- /dev/null +++ b/streaming/phasing.py @@ -0,0 +1,66 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Handling for the phasing out of old forms of shard files.""" + +from typing import Optional, Set + +# TODO +_keep_all_phases2phase_out = { + 'origin': None, + 'active': None, + 'both': None, + 'all': None, +} + +# TODO +_keep_old_phases2phase_out = { + 'nil': None, + 'src': None, + 'all': None, +} + + +def get_phasings() -> Set[str]: + """Get all possible values of phasing. + + Returns: + Set[str]: All phasings. + """ + return set(_keep_old_phases2phase_out) + + +def is_phasing(phasing: str) -> bool: + """Determine whether the given str is a valid phasing. + + Args: + phasing (Str): The purported phasing. + + Returns: + bool: Whether it is a phasing. + """ + return phasing in _keep_old_phases2phase_out + + +def get_safe_phasing(phasing: str, remote: Optional[str], local: str) -> str: + """Get a phasing value which protects against destroying a dataset in-place. + + That is, you need the source form to be able to stream from it, but the final form to be able + to use it. Do you drop the source phase (``nil``), keep the source (``src``), or keep all + phases (``all``)? + + Args: + phasing (str): Unsafe phasing. + remote (str, optional): Remote path. + local (str): Local dirname. + + Returns: + str: Safe phasing. 
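+
+    For example, ``get_safe_phasing('nil', None, '/data/train')`` returns ``'src'`` (there is no
+    separate remote to re-stream from, so the source phase must be kept), whereas
+    ``get_safe_phasing('nil', 's3://bucket/data', '/data/train')`` returns ``'nil'`` unchanged.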
+ """ + if remote not in {None, local}: + return phasing + + if phasing != 'nil': + return phasing + + return 'src' diff --git a/streaming/stream.py b/streaming/stream.py index 3c3735e9f..013153528 100644 --- a/streaming/stream.py +++ b/streaming/stream.py @@ -3,11 +3,9 @@ """A dataset, or sub-dataset if mixing, from which we stream/cache samples.""" -import hashlib import json import os -import tempfile -from typing import List, Optional, Sequence, Tuple +from typing import List, Optional, Sequence, Tuple, Union import numpy as np from numpy.typing import NDArray @@ -15,95 +13,99 @@ from streaming.compression import decompress from streaming.constant import TICK -from streaming.distributed import barrier, get_local_rank from streaming.format import FileInfo, Shard, get_index_basename, shard_from_json +from streaming.format.score import StreamCore from streaming.hashing import get_hash from streaming.storage import download_file, wait_for_file_to_exist -from streaming.util import retry +from streaming.util.auto import Auto, auto +from streaming.util.retrying import retry +from streaming.util.shorthand import normalize_count from streaming.world import World -class Stream: - """A dataset, or sub-dataset if mixing, from which we stream/cache samples. - - We initialize a StreamingDataset with one or more Streams. Streams may be resampled to achieve - different mixtures of samples. - - Stream init takes three kinds of arguments: - - * At least one of ``remote`` and ``local`` must exist. If no ``remote``, the data must be - local. If no ``local``, we cache to a temp directory. - - * ``remote`` - * ``local`` - - * At most one of ``proportion``, ``repeat``, or ``choose`` may exist. If provided one of these, - we derive the rest. Note that ``proportion`` (relative) and ``repeat``/``choose`` (absolute) - are mutually incompatible -- you must entirely use one or the other (or neither) for all - sub-datasets. If none are provided for all streams and ``epoch_size`` is unspecified, then - each sample from each stream is seen once per epoch. If none are provided for all streams - and ``epoch_size`` is specified, then streams are sampled in proportion to their size. - - * ``proportion`` - * ``repeat`` - * ``choose`` - - * The remaining arguments are optional knobs for controlling downloading behavior and default - to ``None``. If ``None``, they take a default value provided to or by the StreamingDataset - init. - - * ``split`` - * ``download_retry`` - * ``download_timeout`` - * ``validate_hash`` - * ``keep_zip`` +class Stream(StreamCore): + """A Streaming dataset directory. + + A StreamingDataset is composed of one/multiple/many Streams. + + Notes: + * Weights: ``proportion`` is relative, and ``repeat``, ``choose``, and nothing are absolute. + Relative and absolute weighting cannot be mixed. If weighting relatively and ``epoch_size`` + is not provided, takes the total number of underlying samples as the epoch size. + * Paths: You must provide ``remote`` and/or ``local``. If no ``remote``, the dataset must be + cached. If no ``local``, it deterministically picks ``local`` based on the other paths. + * Splits: This is implemented as sub-path which is appended to ``remote`` and/or ``local`` in + order to derive the root of this Streaming dataset directory (Stream), which all other + dataset paths descend from. E.g., ``/path/to/dataset/index.json`` if ``split=None``, vs + ``/path/to/dataset/train/index.json`` if ``split='train'``. 
+ * Hashing: Trying a hash algorithm means if the Streaming index records the expected hex + digest for this hash of this file, we apply the hash, compare the result to the expected, + and then we are done: either exit success on match, or raise an error on mismatch. If we + are given hash algorithms to apply but the index notes none of them for a file, we raise an + error. Typically, because of the somewhat severe performance impact, hashes are not used + in training. + * Phasing: Streaming downloads shards as their first phase and accesses samples from their + last phase, to which they are converted on the fly. Do we keep the old phases (until the + shard is evicted)? Options are ``nil``, ``src``, and ``all``. ``safe_keep_old_phases`` is + derived from ``keep_old_phases`` -- it is the same, unless there is no separate remote, in + which case ``nil`` is converted to ``src`` (i.e., keep first phase) in order to prevent + making the streaming dataset directory un-streamable by using it. Args: - remote (str, optional): Remote path or directory to download the dataset from. If ``None``, - its data must exist locally. Defaults to ``None``. - local (str, optional): Local working directory to download shards to. This is where shards - are cached while they are being used. Uses a temp directory if not set. Defaults to - ``None``. - split (str, optional): Which dataset split to use, if any. If provided, we stream from/to - the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``. - proportion (float, optional): How much to upsample or downsample this sub-dataset, as the - proportion of the total combined dataset that consists of this sub-dataset. If - using proportions, all sub-datasets provided together to the StreamingDataset init must - define their proportions. The total combined number of samples is either the - StreamingDataset argument "epoch_size" if provided, or kept the same total size as the - underlying data if not. If provided, must be non-negative. Defaults to ``None``. - repeat (float, optional): How much to upsample or downsample this sub-dataset, as a - multipler on the number of samples. If provided, must be non-negative. Defaults to - ``None``. - choose (int, optional): How much to upsample or downsample this sub-dataset, as the exact - number of resulting samples. If provided, must be non-negative. Defaults to ``None``. - download_retry (int, optional): Number of download re-attempts before giving up. Defaults - to ``None``. - download_timeout (float, optional): Number of seconds to wait for a shard to download - before raising an exception. Defaults to ``None``. - validate_hash (str, optional): Optional hash or checksum algorithm to use to validate - shards. Defaults to ``None``. - keep_zip (bool, optional): Whether to keep or delete the compressed form when decompressing - downloaded shards. If ``False``, keep if and only if remote is local or no remote. + proportion (float, optional): The proportion of this StreamingDataset's samples that are + sampled from this Stream. As this is a relative measure, use ``epoch_size`` to + determine the absolute resulting size in samples. Defaults to ``None``. + repeat (float, optional): Stream size multiplier, aka number of times to see each of this + Stream's samples per epoch. Defaults to ``None``. + choose (str | int, optional): Stream size, aka number of samples to draw from this Stream + per epoch. Defaults to ``None``. + remote (str, optional): Remote path to stream the dataset from. 
If ``None``, dataset must + be complete locally. Defaults to ``None``. + local (str, optional): Local working directory to stream the dataset to. Uses a temp + directory if not set. Defaults to ``None``. + split (str | Auto, optional): Which dataset split to use, if any. Set to ``auto`` to + inherit from StreamingDataset. Defaults to ``auto``. + download_retry (int | Auto): Number of download re-attempts before raising an error. Set to + ``auto`` to inherit from StreamingDataset. Defaults to ``auto``. + download_timeout (str | float | Auto, optional): Time in seconds to wait for a file + download to complete before raising an error. Streaming duration shorthand (e.g., + ``1m23s``) is also accepted. Set to ``auto`` to inherit from StreamingDataset. Defaults + to ``auto``. + hash_algos (str | Sequence[str] | Auto, optional): Ranked list of hashing algorithms to + try. Set to ``auto`` to inherit from StreamingDataset. Defaults to ``auto``. + validate_hash (str, optional): Deprecated. See ``hash_algos``. Defaults to ``None``. + keep_old_phases (str | Auto, optional): Which old phases of shard files to cache (until + shard eviction). Must be one of ``nil``, ``src``, or ``all``. Set to ``auto`` to + inherit from StreamingDataset. If ``None``, uses ``keep_zip``, falling back to ``nil``. Defaults to ``None``. + keep_zip (bool, optional): Deprecated. See ``keep_old_phases``. Defaults to ``None``. """ - def __init__(self, - *, - remote: Optional[str] = None, - local: Optional[str] = None, - split: Optional[str] = None, - proportion: Optional[float] = None, - repeat: Optional[float] = None, - choose: Optional[int] = None, - download_retry: Optional[int] = None, - download_timeout: Optional[float] = None, - validate_hash: Optional[str] = None, - keep_zip: Optional[bool] = None) -> None: - self.remote = remote - self._local = local - self.split = split or '' - + def __init__( + self, + *, + proportion: Optional[float] = None, + repeat: Optional[float] = None, + choose: Optional[Union[str, int]] = None, + remote: Optional[str] = None, + local: Optional[str] = None, + split: Optional[Union[str, Auto]] = auto, + download_retry: Union[int, Auto] = auto, + download_timeout: Union[str, float, Auto] = auto, + hash_algos: Optional[Union[str, Sequence[str], Auto]] = auto, + validate_hash: Optional[str] = None, + keep_old_phases: Optional[Union[str, Auto]] = auto, + keep_zip: Optional[bool] = None, + ) -> None: + super().__init__(remote=remote, + local=local, + split=split, + download_retry=download_retry, + download_timeout=download_timeout, + hash_algos=hash_algos, + validate_hash=validate_hash, + keep_old_phases=keep_old_phases, + keep_zip=keep_zip) has_proportion = proportion is not None has_repeat = repeat is not None has_choose = choose is not None @@ -111,89 +113,26 @@ def __init__(self, raise ValueError('At most one of `proportion`, `repeat`, and `choose` may be ' + 'specified; the others are derived') - self._proportion = proportion if proportion is not None: if proportion < 0: raise ValueError('`proportion` must be non-negative') self.proportion = proportion - self._repeat = repeat if repeat is not None: if repeat < 0: raise ValueError('`repeat` must be non-negative') self.repeat = repeat - self._choose = choose if choose is not None: - if choose < 0: + self.choose = normalize_count(choose) + if self.choose < 0: raise ValueError('`choose` must be non-negative') - self.choose = choose - - self._download_retry = download_retry - if download_retry is not None: - if download_retry < 0: - raise 
ValueError('`download_retry` must be non-negative') - self.download_retry = download_retry - - self._download_timeout = download_timeout - if download_timeout is not None: - if download_timeout <= 0: - raise ValueError('`download_timeout` must be positive') - self.download_timeout = download_timeout - - self.validate_hash = validate_hash - - if local is None: - self.local = self._get_temporary_directory() - if get_local_rank() == 0: - if os.path.exists(self.local): - raise ValueError( - f'Could not create a temporary local directory {self.local} . Either ' + - f'delete the directory or specify a unique local directory with the ' + - f'`local` value.') - os.makedirs(self.local) - barrier() - else: - self.local = local - - self._keep_zip = keep_zip - if keep_zip is not None: - self.keep_zip = keep_zip - self.safe_keep_zip = self.keep_zip or self.remote in {None, self.local} - - def _get_temporary_directory(self) -> str: - """Construct a path to a temporary directory based on remote and split.""" - root = tempfile.gettempdir() - hash = '' - if self.remote is not None: - hash = hashlib.blake2s(self.remote.encode('utf-8'), digest_size=16).hexdigest() - return os.path.join(root, hash, self.split) - - def apply_default(self, default: dict) -> None: - """Apply defaults, setting any unset fields. - - We use pairs of (name, _name) in order to make type checking happy. - - Args: - default (Self): Stream containing default values for all optional fields. - """ - if not (self.remote or self._local): - raise ValueError('`remote` and/or `local` path must be provided') - - if not self.split: - self.split = default['split'] or '' - if self._download_retry is None: - self.download_retry = default['download_retry'] - if self._download_timeout is None: - self.download_timeout = default['download_timeout'] - if self.validate_hash is None: - self.validate_hash = default['validate_hash'] or None - if self._keep_zip is None: - self.keep_zip = default['keep_zip'] - self.safe_keep_zip = default['keep_zip'] or self.remote in {None, self.local} @classmethod - def validate_weights(cls, streams: Sequence[Self]) -> Tuple[bool, bool]: + def validate_weights( + cls, + streams: Sequence[Self], + ) -> Tuple[bool, bool]: """Validate stream weights, returning whether relative or absolute weighting was used. Args: @@ -221,8 +160,13 @@ def validate_weights(cls, streams: Sequence[Self]) -> Tuple[bool, bool]: return is_proportional, is_unspecified @classmethod - def apply_weights(cls, streams: Sequence[Self], samples_per_stream: NDArray[np.int64], - choose_per_epoch: Optional[int], seed: int) -> int: + def apply_weights( + cls, + streams: Sequence[Self], + samples_per_stream: NDArray[np.int64], + choose_per_epoch: Optional[int], + seed: int, + ) -> int: """Given samples per stream, derive each stream's proportion/repeat/samples. Modifies streams to save the derived weights. @@ -289,7 +233,11 @@ def apply_weights(cls, streams: Sequence[Self], samples_per_stream: NDArray[np.i return choose_per_epoch - def _download_file(self, from_basename: str, to_basename: Optional[str] = None) -> str: + def _download_file( + self, + from_basename: str, + to_basename: Optional[str] = None, + ) -> str: """Safely download a file from remote to local cache. 
Args: @@ -303,8 +251,8 @@ def _download_file(self, from_basename: str, to_basename: Optional[str] = None) if self.remote is None: remote = None else: - remote = os.path.join(self.remote, self.split, from_basename) - local = os.path.join(self.local, self.split, to_basename or from_basename) + remote = os.path.join(self.remote, self.split or '', from_basename) + local = os.path.join(self.local, self.split or '', to_basename or from_basename) # Attempt to download, possibly repeating on failure. retry(num_attempts=self.download_retry)( @@ -312,8 +260,13 @@ def _download_file(self, from_basename: str, to_basename: Optional[str] = None) return local - def _decompress_shard_part(self, zip_info: FileInfo, zip_filename: str, raw_filename: str, - compression: Optional[str]) -> None: + def _decompress_shard_part( + self, + zip_info: FileInfo, + zip_filename: str, + raw_filename: str, + compression: Optional[str], + ) -> None: """Validate and decompress shard data. Args: @@ -326,14 +279,18 @@ def _decompress_shard_part(self, zip_info: FileInfo, zip_filename: str, raw_file data = open(zip_filename, 'rb').read() # Validate what was downloaded. - if self.validate_hash: - if self.validate_hash not in zip_info.hashes: - raise ValueError( - f'Hash algorithm `{self.validate_hash}` chosen for data ' + + if self.hash_algos: + for algo in self.hash_algos: + if algo in zip_info.hashes: + if get_hash(algo, data) == zip_info.hashes[algo]: + break + else: + raise RuntimeError(f'Hash check failure: {zip_filename}.') + else: + raise RuntimeError( + f'Hash algorithms `{self.hash_algos}` chosen for data ' + f'validation does not match with those provided during dataset ' + - f'creation `{sorted(zip_info.hashes.keys())}`. Provide one of those.') - if get_hash(self.validate_hash, data) != zip_info.hashes[self.validate_hash]: - raise ValueError(f'Checksum failure: {zip_filename}') + f'creation: `{sorted(zip_info.hashes)}`. Provide one of those.') # Decompress and save that. data = decompress(compression, data) # pyright: ignore @@ -346,10 +303,12 @@ def _decompress_shard_part(self, zip_info: FileInfo, zip_filename: str, raw_file if not self.safe_keep_zip: os.remove(zip_filename) - def _prepare_shard_part(self, - raw_info: FileInfo, - zip_info: Optional[FileInfo] = None, - compression: Optional[str] = None) -> int: + def _prepare_shard_part( + self, + raw_info: FileInfo, + zip_info: Optional[FileInfo] = None, + compression: Optional[str] = None, + ) -> int: """Get shard data given metadata for the raw and compressed versions of it. Shards are either mono shards (one file per shard, like MDS) or dual shards (a pair of data @@ -366,11 +325,11 @@ def _prepare_shard_part(self, """ # Has raw? delta = 0 - raw_filename = os.path.join(self.local, self.split, raw_info.basename) + raw_filename = os.path.join(self.local, self.split or '', raw_info.basename) if os.path.isfile(raw_filename): # Has raw. if zip_info and not self.safe_keep_zip: - zip_filename = os.path.join(self.local, self.split, zip_info.basename) + zip_filename = os.path.join(self.local, self.split or '', zip_info.basename) if os.path.isfile(zip_filename): # If don't keep zip and it has a zip, drop the zip. os.remove(zip_filename) @@ -379,7 +338,7 @@ def _prepare_shard_part(self, # Missing raw. Uses zip? if zip_info: # Ensure has zip. 
- zip_filename = os.path.join(self.local, self.split, zip_info.basename) + zip_filename = os.path.join(self.local, self.split or '', zip_info.basename) if not os.path.isfile(zip_filename): self._download_file(zip_info.basename) delta += zip_info.bytes @@ -395,18 +354,26 @@ def _prepare_shard_part(self, delta += raw_info.bytes # Validate. - if self.validate_hash: - if self.validate_hash not in raw_info.hashes: + if self.hash_algos: + data = open(raw_filename, 'rb').read() + for algo in self.hash_algos: + if algo in raw_info.hashes: + if get_hash(algo, data) == raw_info.hashes[algo]: + break + else: + raise RuntimeError(f'Hash check failure: {raw_filename}.') + else: raise ValueError( - f'Hash algorithm `{self.validate_hash}` chosen for data ' + + f'Hash algorithms `{self.hash_algos}` chosen for data ' + f'validation does not match with those provided during dataset ' + - f'creation `{sorted(raw_info.hashes.keys())}`. Provide one of those.') - data = open(raw_filename, 'rb').read() - if get_hash(self.validate_hash, data) != raw_info.hashes[self.validate_hash]: - raise ValueError(f'Checksum failure: {raw_filename}') + f'creation: `{sorted(raw_info.hashes.keys())}`. Provide one of those.') + return delta - def prepare_shard(self, shard: Shard) -> int: + def prepare_shard( + self, + shard: Shard, + ) -> int: """Ensure (download, validate, extract, etc.) that we have the given shard. Args: @@ -420,7 +387,11 @@ def prepare_shard(self, shard: Shard) -> int: delta += self._prepare_shard_part(raw_info, zip_info, shard.compression) return delta - def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Shard]: + def get_shards( + self, + world: World, + allow_unsafe_types: bool, + ) -> List[Shard]: """Load this Stream's index, retrieving its shard readers. Args: @@ -434,7 +405,7 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Shard]: """ # Download the index file if it does not exist locally. basename = get_index_basename() - filename = os.path.join(self.local, self.split, basename) # pyright: ignore + filename = os.path.join(self.local, self.split or '', basename) # pyright: ignore if not os.path.exists(filename): if world.is_local_leader: if self.remote: @@ -476,7 +447,11 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Shard]: return shards - def set_up_local(self, shards: List[Shard], cache_usage_per_shard: NDArray[np.int64]) -> None: + def set_up_local( + self, + shards: List[Shard], + cache_usage_per_shard: NDArray[np.int64], + ) -> None: """Bring a local directory into a consistent state, getting which shards are present. Args: @@ -484,7 +459,7 @@ def set_up_local(self, shards: List[Shard], cache_usage_per_shard: NDArray[np.in cache_usage_per_shard (NDArray[np.int64]): Cache usage per shard of this stream. """ # List the cache directory (so that we hit the filesystem once). - local_dirname = os.path.join(self.local, self.split) + local_dirname = os.path.join(self.local, self.split or '') listing = set() for dirname, _, subfiles in os.walk(local_dirname): for subfile in subfiles: @@ -495,11 +470,11 @@ def set_up_local(self, shards: List[Shard], cache_usage_per_shard: NDArray[np.in for i, shard in enumerate(shards): cache_usage_per_shard[i] = shard.set_up_local(listing, self.safe_keep_zip) - def get_index_size(self) -> int: + def get_index_size(self,) -> int: """Get the size of the index file in bytes. Returns: int: Size in bytes. 
""" - filename = os.path.join(self.local, self.split, get_index_basename()) + filename = os.path.join(self.local, self.split or '', get_index_basename()) return os.stat(filename).st_size diff --git a/streaming/util/__init__.py b/streaming/util/__init__.py index d17cf214d..021b75a0c 100644 --- a/streaming/util/__init__.py +++ b/streaming/util/__init__.py @@ -3,6 +3,7 @@ """Utilities for streaming.""" +from streaming.util.auto import Auto, auto, is_auto from streaming.util.importing import get_import_exception_message, redirect_imports from streaming.util.merging import merge_index from streaming.util.retrying import retry @@ -13,7 +14,7 @@ from streaming.util.tabulation import Tabulator __all__ = [ - 'get_import_exception_message', 'redirect_imports', 'merge_index', 'retry', - 'clean_stale_shared_memory', 'get_list_arg', 'get_str2str_arg', 'normalize_dec_bytes', + 'Auto', 'auto', 'is_auto', 'get_import_exception_message', 'redirect_imports', 'merge_index', + 'retry', 'clean_stale_shared_memory', 'get_list_arg', 'get_str2str_arg', 'normalize_dec_bytes', 'normalize_bin_bytes', 'normalize_bytes', 'normalize_count', 'normalize_duration', 'Tabulator' ] diff --git a/streaming/util/auto.py b/streaming/util/auto.py new file mode 100644 index 000000000..7c027dd23 --- /dev/null +++ b/streaming/util/auto.py @@ -0,0 +1,36 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""A magical argument keyword that means derive this argument's value automatically.""" + +from typing import Any + +__all__ = ['Auto', 'auto', 'is_auto'] + + +class Auto: + """A magical argument keyword that means derive this argument's value automatically. + + This is useful when your argument's type doesn't have any blank space like ``0`` or ``''`` in + this method's usage, ``None`` has its own productive meaning, and using a different type would + be ugly and hard to follow. + """ + pass + + +# The singleton instance of this class. +auto = Auto() + + +def is_auto(arg: Any) -> bool: + """Wrap the is-auto checking hack. + + Typechecking is not satisfied with `is auto`, you have to do `isinstance(Auto)`. + + Args: + arg (Any): The argument. + + Returns: + bool: Whether the argument is auto. + """ + return isinstance(arg, Auto) diff --git a/tests/test_stream.py b/tests/test_stream.py index 818c19ae8..1140fcfd0 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -23,6 +23,11 @@ def test_local_is_none_with_no_split() -> None: shutil.rmtree(local, ignore_errors=True) barrier() stream = Stream(remote=remote, local=None) + stream.apply_defaults(split=None, + download_retry=2, + download_timeout='1m', + hash_algos=None, + keep_old_phases=None) assert local == stream.local shutil.rmtree(local, ignore_errors=True) @@ -35,6 +40,10 @@ def test_local_is_none_with_split() -> None: shutil.rmtree(local, ignore_errors=True) barrier() stream = Stream(remote=remote, local=None, split='train') + stream.apply_defaults(download_retry=2, + download_timeout='1m', + hash_algos=None, + keep_old_phases=None) assert local == stream.local shutil.rmtree(local, ignore_errors=True) @@ -52,7 +61,7 @@ def test_local_exists(split: Optional[str]) -> None: def test_existing_local_raises_exception(monkeypatch: MonkeyPatch) -> None: local = tempfile.mkdtemp() monkeypatch.setattr(tempfile, 'gettempdir', lambda: local) - with pytest.raises(ValueError, match=f'Could not create a temporary local directory.*'): + with pytest.raises(ValueError): _ = Stream() shutil.rmtree(local, ignore_errors=True)