From 3972c9d4f98b677d27465f9fac629c97efefc2d7 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Fri, 15 Dec 2023 07:47:16 -0800 Subject: [PATCH] Improve naming: JSON shards are actually JSONL, etc. (#537) * Standardize docstrings; also fix ordering of get_sample_data, decode_sample. * Terminology: "joint" -> "mono". * "split" -> "dual" to stop confusing people (SplitWriter != dataset splits) * "Reader" -> "Shard". They manage shards. They do more than read. * Fix filenames accordingly. * Finally, JSON -> JSONL. * Switch order of decorators... * Fix markdown code. --- STYLE_GUIDE.md | 4 +- benchmarks/backends/write.py | 4 +- benchmarks/samples/bench_and_plot.py | 4 +- streaming/__init__.py | 6 +- streaming/format/__init__.py | 47 ++++-- streaming/format/json/__init__.py | 9 -- streaming/format/{json => jsonl}/README.md | 10 +- streaming/format/jsonl/__init__.py | 9 ++ streaming/format/{json => jsonl}/encodings.py | 10 +- .../format/{json/reader.py => jsonl/shard.py} | 42 ++--- streaming/format/{json => jsonl}/writer.py | 22 +-- streaming/format/mds/__init__.py | 6 +- streaming/format/mds/{reader.py => shard.py} | 56 +++---- streaming/format/mds/writer.py | 10 +- streaming/format/{reader.py => shard.py} | 45 ++++-- streaming/format/writer.py | 24 +-- streaming/format/xsv/__init__.py | 6 +- streaming/format/xsv/{reader.py => shard.py} | 52 +++---- streaming/format/xsv/writer.py | 10 +- streaming/local.py | 4 +- streaming/stream.py | 21 ++- tests/test_encodings.py | 146 +++++++++--------- tests/test_writer.py | 26 ++-- 23 files changed, 301 insertions(+), 272 deletions(-) delete mode 100644 streaming/format/json/__init__.py rename streaming/format/{json => jsonl}/README.md (83%) create mode 100644 streaming/format/jsonl/__init__.py rename streaming/format/{json => jsonl}/encodings.py (86%) rename streaming/format/{json/reader.py => jsonl/shard.py} (88%) rename streaming/format/{json => jsonl}/writer.py (88%) rename streaming/format/mds/{reader.py => shard.py} (95%) rename streaming/format/{reader.py => shard.py} (94%) rename streaming/format/xsv/{reader.py => shard.py} (96%) diff --git a/STYLE_GUIDE.md b/STYLE_GUIDE.md index 67156e2a0..4b888accd 100644 --- a/STYLE_GUIDE.md +++ b/STYLE_GUIDE.md @@ -207,10 +207,10 @@ For example, from [streaming/dataset.py](streaming/dataset.py) """The :class:`Dataset` class, used for building streaming iterable datasets.""" from torch.utils.data import IterableDataset -from streaming.format import reader_from_json +from streaming.format import shard_from_json from streaming.spanner import Spanner -__all__ = ["Dataset"] # export only the Dataset, not other imports like `Spanner` or `reader_from_json` +__all__ = ["Dataset"] # Export `Dataset` only, not the others, e.g. `Spanner` or `shard_from_json`.
class Dataset(IterableDataset): diff --git a/benchmarks/backends/write.py b/benchmarks/backends/write.py index 9404e5d0a..78c85bbfb 100644 --- a/benchmarks/backends/write.py +++ b/benchmarks/backends/write.py @@ -22,7 +22,7 @@ from wurlitzer import pipes from benchmarks.backends.datagen import generate -from streaming import CSVWriter, JSONWriter, MDSWriter +from streaming import CSVWriter, JSONLWriter, MDSWriter from streaming.util.tabulation import Tabulator @@ -108,7 +108,7 @@ def _write_jsonl(nums: List[int], 'num': 'int', 'txt': 'str', } - with JSONWriter(out=root, columns=columns, size_limit=size_limit) as out: + with JSONLWriter(out=root, columns=columns, size_limit=size_limit) as out: each_sample = zip(nums, txts) if show_progress: each_sample = tqdm(each_sample, total=len(nums), leave=False) diff --git a/benchmarks/samples/bench_and_plot.py b/benchmarks/samples/bench_and_plot.py index 31307ff32..306049875 100644 --- a/benchmarks/samples/bench_and_plot.py +++ b/benchmarks/samples/bench_and_plot.py @@ -17,7 +17,7 @@ from numpy.typing import DTypeLike, NDArray from tqdm import trange -from streaming import CSVWriter, JSONWriter, MDSWriter, StreamingDataset +from streaming import CSVWriter, JSONLWriter, MDSWriter, StreamingDataset def parse_args() -> Namespace: @@ -244,7 +244,7 @@ def bench(args: Namespace, bench_name: str, desc: str, generate: Callable, format_infos = [ ('mds', MDSWriter, args.mds_color), - ('jsonl', JSONWriter, args.jsonl_color), + ('jsonl', JSONLWriter, args.jsonl_color), ('csv', CSVWriter, args.csv_color), ] format_infos = list(filter(lambda info: info[0] in formats, format_infos)) diff --git a/streaming/__init__.py b/streaming/__init__.py index 45ca3f1cf..c8efa5a36 100644 --- a/streaming/__init__.py +++ b/streaming/__init__.py @@ -6,12 +6,12 @@ from streaming._version import __version__ from streaming.dataloader import StreamingDataLoader from streaming.dataset import StreamingDataset -from streaming.format import CSVWriter, JSONWriter, MDSWriter, TSVWriter, XSVWriter +from streaming.format import CSVWriter, JSONLWriter, MDSWriter, TSVWriter, XSVWriter from streaming.local import LocalDataset from streaming.stream import Stream from streaming.util import clean_stale_shared_memory __all__ = [ - 'StreamingDataLoader', 'Stream', 'StreamingDataset', 'CSVWriter', 'JSONWriter', 'LocalDataset', - 'MDSWriter', 'TSVWriter', 'XSVWriter', 'clean_stale_shared_memory' + 'StreamingDataLoader', 'Stream', 'StreamingDataset', 'CSVWriter', 'JSONLWriter', + 'LocalDataset', 'MDSWriter', 'TSVWriter', 'XSVWriter', 'clean_stale_shared_memory' ] diff --git a/streaming/format/__init__.py b/streaming/format/__init__.py index bbec4927e..dec5ac15c 100644 --- a/streaming/format/__init__.py +++ b/streaming/format/__init__.py @@ -1,32 +1,45 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -"""Individual dataset writer for every format.""" +"""Streaming serialization format, consisting of an index and multiple types of shards.""" from typing import Any, Dict, Optional from streaming.format.index import get_index_basename -from streaming.format.json import JSONReader, JSONWriter -from streaming.format.mds import MDSReader, MDSWriter -from streaming.format.reader import FileInfo, Reader -from streaming.format.xsv import CSVReader, CSVWriter, TSVReader, TSVWriter, XSVReader, XSVWriter +from streaming.format.jsonl import JSONLShard, JSONLWriter +from streaming.format.mds import MDSShard, MDSWriter +from streaming.format.shard import FileInfo, Shard +from 
streaming.format.xsv import CSVShard, CSVWriter, TSVShard, TSVWriter, XSVShard, XSVWriter __all__ = [ - 'CSVWriter', 'FileInfo', 'get_index_basename', 'JSONWriter', 'MDSWriter', 'Reader', - 'reader_from_json', 'TSVWriter', 'XSVWriter' + 'CSVWriter', 'FileInfo', 'get_index_basename', 'JSONLWriter', 'MDSWriter', 'Shard', + 'shard_from_json', 'TSVWriter', 'XSVWriter' ] -_readers = { - 'csv': CSVReader, - 'json': JSONReader, - 'mds': MDSReader, - 'tsv': TSVReader, - 'xsv': XSVReader +# Mapping of shard metadata dict "format" field to what type of Shard it is. +_shards = { + 'csv': CSVShard, + 'jsonl': JSONLShard, + 'mds': MDSShard, + 'tsv': TSVShard, + 'xsv': XSVShard, } -def reader_from_json(dirname: str, split: Optional[str], obj: Dict[str, Any]) -> Reader: - """Initialize the reader from JSON object. +def _get_shard_class(format_name: str) -> Shard: + """Get the associated Shard class given a Shard format name. + + Args: + format_name (str): Shard format name. + """ + # JSONL shards were originally called JSON shards (while containing JSONL). + if format_name == 'json': + format_name = 'jsonl' + return _shards[format_name] + + +def shard_from_json(dirname: str, split: Optional[str], obj: Dict[str, Any]) -> Shard: + """Create a shard from a JSON config. Args: dirname (str): Local directory containing shards. @@ -34,8 +47,8 @@ def reader_from_json(dirname: str, split: Optional[str], obj: Dict[str, Any]) -> obj (Dict[str, Any]): JSON object to load. Returns: - Reader: Loaded Reader of `format` type + Shard: The loaded Shard. """ assert obj['version'] == 2 - cls = _readers[obj['format']] + cls = _get_shard_class(obj['format']) return cls.from_json(dirname, split, obj) diff --git a/streaming/format/json/__init__.py b/streaming/format/json/__init__.py deleted file mode 100644 index 47e8be8f6..000000000 --- a/streaming/format/json/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright 2023 MosaicML Streaming authors -# SPDX-License-Identifier: Apache-2.0 - -"""Module to write and read the dataset in JSON format.""" - -from streaming.format.json.reader import JSONReader -from streaming.format.json.writer import JSONWriter - -__all__ = ['JSONReader', 'JSONWriter'] diff --git a/streaming/format/json/README.md b/streaming/format/jsonl/README.md similarity index 83% rename from streaming/format/json/README.md rename to streaming/format/jsonl/README.md index 13cd1fd99..59def4e38 100644 --- a/streaming/format/json/README.md +++ b/streaming/format/jsonl/README.md @@ -7,14 +7,14 @@ Example: "words": "str" }, "compression": "zstd:7", - "format": "json", + "format": "jsonl", "hashes": [ "sha1", "xxh3_64" ], "newline": "\n", "raw_data": { - "basename": "shard.00000.json", + "basename": "shard.00000.jsonl", "bytes": 1048546, "hashes": { "sha1": "bfb6509ba6f041726943ce529b36a1cb74e33957", @@ -22,7 +22,7 @@ Example: } }, "raw_meta": { - "basename": "shard.00000.json.meta", + "basename": "shard.00000.jsonl.meta", "bytes": 53590, "hashes": { "sha1": "15ae80e002fe625b0b18f1a45058532ee867fa9b", @@ -33,7 +33,7 @@ Example: "size_limit": 1048576, "version": 2, "zip_data": { - "basename": "shard.00000.json.zstd", + "basename": "shard.00000.jsonl.zstd", "bytes": 149268, "hashes": { "sha1": "7d45c600a71066ca8d43dbbaa2ffce50a91b735e", @@ -41,7 +41,7 @@ Example: } }, "zip_meta": { - "basename": "shard.00000.json.meta.zstd", + "basename": "shard.00000.jsonl.meta.zstd", "bytes": 42180, "hashes": { "sha1": "f64477cca5d27fc3a0301eeb4452ef7310cbf670", diff --git a/streaming/format/jsonl/__init__.py 
b/streaming/format/jsonl/__init__.py new file mode 100644 index 000000000..53d630a3e --- /dev/null +++ b/streaming/format/jsonl/__init__.py @@ -0,0 +1,9 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Streaming JSONL shards.""" + +from streaming.format.jsonl.shard import JSONLShard +from streaming.format.jsonl.writer import JSONLWriter + +__all__ = ['JSONLShard', 'JSONLWriter'] diff --git a/streaming/format/json/encodings.py b/streaming/format/jsonl/encodings.py similarity index 86% rename from streaming/format/json/encodings.py rename to streaming/format/jsonl/encodings.py index 215b8ee36..2f3048e8f 100644 --- a/streaming/format/json/encodings.py +++ b/streaming/format/jsonl/encodings.py @@ -1,16 +1,16 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -"""Check whether sample encoding is of supported JSON types.""" +"""Check whether sample encoding is of supported JSONL types.""" from abc import ABC, abstractmethod from typing import Any -__all__ = ['is_json_encoded', 'is_json_encoding'] +__all__ = ['is_jsonl_encoded', 'is_jsonl_encoding'] class Encoding(ABC): - """Encoding of an object of JSON type.""" + """Encoding of an object of JSONL type.""" @classmethod @abstractmethod @@ -60,7 +60,7 @@ def is_encoded(cls, obj: Any) -> bool: _encodings = {'str': Str, 'int': Int, 'float': Float} -def is_json_encoded(encoding: str, value: Any) -> bool: +def is_jsonl_encoded(encoding: str, value: Any) -> bool: """Get whether the given object is of this encoding type. Args: @@ -74,7 +74,7 @@ def is_json_encoded(encoding: str, value: Any) -> bool: return cls.is_encoded(value) -def is_json_encoding(encoding: str) -> bool: +def is_jsonl_encoding(encoding: str) -> bool: """Get whether the given encoding is supported. Args: diff --git a/streaming/format/json/reader.py b/streaming/format/jsonl/shard.py similarity index 88% rename from streaming/format/json/reader.py rename to streaming/format/jsonl/shard.py index 698783d71..985b75684 100644 --- a/streaming/format/json/reader.py +++ b/streaming/format/jsonl/shard.py @@ -1,7 +1,7 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -""":class:`JSONReader` reads samples from `.json` files that were written by :class:`MDSWriter`.""" +"""Streaming JSONL shard reading.""" import json import os @@ -11,13 +11,13 @@ import numpy as np from typing_extensions import Self -from streaming.format.reader import FileInfo, SplitReader +from streaming.format.shard import DualShard, FileInfo -__all__ = ['JSONReader'] +__all__ = ['JSONLShard'] -class JSONReader(SplitReader): - """Provides random access to the samples of a JSON shard. +class JSONLShard(DualShard): + """Provides random access to the samples of a JSONL shard. Args: dirname (str): Local dataset directory. @@ -68,7 +68,7 @@ def from_json(cls, dirname: str, split: Optional[str], obj: Dict[str, Any]) -> S obj (Dict[str, Any]): JSON object to load. Returns: - Self: Loaded JSONReader. + Self: Loaded JSONLShard. """ args = deepcopy(obj) # Version check. @@ -77,9 +77,9 @@ def from_json(cls, dirname: str, split: Optional[str], obj: Dict[str, Any]) -> S f'Expected version 2.') del args['version'] # Check format. - if args['format'] != 'json': - raise ValueError(f'Unsupported data format: {args["format"]}. 
' + - f'Expected to be `json`.') + if args['format'] not in {'json', 'jsonl'}: + raise ValueError(f'Unsupported data format: got {args["format"]}, but expected ' + + f'"jsonl" (or "json").') del args['format'] args['dirname'] = dirname args['split'] = split @@ -88,18 +88,6 @@ def from_json(cls, dirname: str, split: Optional[str], obj: Dict[str, Any]) -> S args[key] = FileInfo(**arg) if arg else None return cls(**args) - def decode_sample(self, data: bytes) -> Dict[str, Any]: - """Decode a sample dict from bytes. - - Args: - data (bytes): The sample encoded as bytes. - - Returns: - Dict[str, Any]: Sample dict. - """ - text = data.decode('utf-8') - return json.loads(text) - def get_sample_data(self, idx: int) -> bytes: """Get the raw sample data at the index. @@ -120,3 +108,15 @@ def get_sample_data(self, idx: int) -> bytes: fp.seek(begin) data = fp.read(end - begin) return data + + def decode_sample(self, data: bytes) -> Dict[str, Any]: + """Decode a sample dict from bytes. + + Args: + data (bytes): The sample encoded as bytes. + + Returns: + Dict[str, Any]: Sample dict. + """ + text = data.decode('utf-8') + return json.loads(text) diff --git a/streaming/format/json/writer.py b/streaming/format/jsonl/writer.py similarity index 88% rename from streaming/format/json/writer.py rename to streaming/format/jsonl/writer.py index b0117a47f..99a01d14f 100644 --- a/streaming/format/json/writer.py +++ b/streaming/format/jsonl/writer.py @@ -1,21 +1,21 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -""":class:`JSONWriter` writes samples to `.json` files that can be read by :class:`JSONReader`.""" +"""Streaming JSONL shard writing.""" import json from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np -from streaming.format.json.encodings import is_json_encoded, is_json_encoding -from streaming.format.writer import SplitWriter +from streaming.format.jsonl.encodings import is_jsonl_encoded, is_jsonl_encoding +from streaming.format.writer import DualWriter -__all__ = ['JSONWriter'] +__all__ = ['JSONLWriter'] -class JSONWriter(SplitWriter): - r"""Writes a streaming JSON dataset. +class JSONLWriter(DualWriter): + r"""Writes a streaming JSONL dataset. Args: columns (Dict[str, str]): Sample columns. @@ -47,7 +47,7 @@ class JSONWriter(SplitWriter): file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``. """ - format = 'json' + format = 'jsonl' def __init__(self, *, @@ -66,7 +66,7 @@ def __init__(self, size_limit=size_limit, **kwargs) for encoding in columns.values(): - assert is_json_encoding(encoding) + assert is_jsonl_encoding(encoding) self.columns = columns self.newline = newline @@ -83,7 +83,7 @@ def encode_sample(self, sample: Dict[str, Any]) -> bytes: obj = {} for key, encoding in self.columns.items(): value = sample[key] - assert is_json_encoded(encoding, value) + assert is_jsonl_encoded(encoding, value) obj[key] = value text = json.dumps(obj, sort_keys=True) + self.newline return text.encode('utf-8') @@ -98,8 +98,8 @@ def get_config(self) -> Dict[str, Any]: obj.update({'columns': self.columns, 'newline': self.newline}) return obj - def encode_split_shard(self) -> Tuple[bytes, bytes]: - """Encode a split shard out of the cached samples (data, meta files). + def encode_dual_shard(self) -> Tuple[bytes, bytes]: + """Encode a dual shard out of the cached samples (data, meta files). Returns: Tuple[bytes, bytes]: Data file, meta file. 
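A quick sanity sketch of the backward-compat path above: an index entry written by the old JSONWriter (format name "json") still resolves to a JSONLShard, because _get_shard_class() maps "json" to "jsonl". Field values below are illustrative, loosely following the JSONL README example; constructing the shard reads no files, so this runs as-is:

    from streaming.format import shard_from_json

    # One entry from an index.json 'shards' list, as written by the old JSONWriter.
    # Note the legacy format name 'json' (not 'jsonl'). Values are illustrative.
    obj = {
        'version': 2,
        'format': 'json',
        'columns': {'words': 'str'},
        'compression': None,
        'hashes': [],
        'newline': '\n',
        'samples': 1000,
        'size_limit': 1048576,
        'raw_data': {'basename': 'shard.00000.json', 'bytes': 1048546, 'hashes': {}},
        'raw_meta': {'basename': 'shard.00000.json.meta', 'bytes': 53590, 'hashes': {}},
        'zip_data': None,
        'zip_meta': None,
    }

    shard = shard_from_json('/data/my_dataset', None, obj)  # Dirname is illustrative.
    print(type(shard).__name__)  # JSONLShard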
diff --git a/streaming/format/mds/__init__.py b/streaming/format/mds/__init__.py index 67a5be56f..5136f7efd 100644 --- a/streaming/format/mds/__init__.py +++ b/streaming/format/mds/__init__.py @@ -1,9 +1,9 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -"""Module to write and read the dataset in MDS format.""" +"""MDS shards.""" -from streaming.format.mds.reader import MDSReader +from streaming.format.mds.shard import MDSShard from streaming.format.mds.writer import MDSWriter -__all__ = ['MDSReader', 'MDSWriter'] +__all__ = ['MDSShard', 'MDSWriter'] diff --git a/streaming/format/mds/reader.py b/streaming/format/mds/shard.py similarity index 95% rename from streaming/format/mds/reader.py rename to streaming/format/mds/shard.py index 7ec93c98a..956bb069b 100644 --- a/streaming/format/mds/reader.py +++ b/streaming/format/mds/shard.py @@ -1,7 +1,7 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -""":class:`MDSReader` reads samples in `.mds` files written by :class:`StreamingDatasetWriter`.""" +"""MDS shard reading.""" import os from copy import deepcopy @@ -11,12 +11,12 @@ from typing_extensions import Self from streaming.format.mds.encodings import is_mds_encoding_safe, mds_decode -from streaming.format.reader import FileInfo, JointReader +from streaming.format.shard import FileInfo, MonoShard -__all__ = ['MDSReader'] +__all__ = ['MDSShard'] -class MDSReader(JointReader): +class MDSShard(MonoShard): """Provides random access to the samples of an MDS shard. Args: @@ -66,7 +66,7 @@ def from_json(cls, dirname: str, split: Optional[str], obj: Dict[str, Any]) -> S obj (Dict[str, Any]): JSON object to load. Returns: - Self: Loaded MDSReader. + Self: Loaded MDSShard. """ args = deepcopy(obj) if args['version'] != 2: @@ -99,6 +99,29 @@ def validate(self, allow_unsafe_types: bool) -> None: raise ValueError(f'Column {name} contains an unsafe type: {encoding}. To ' + f'proceed anyway, set ``allow_unsafe_types=True``.') + def get_sample_data(self, idx: int) -> bytes: + """Get the raw sample data at the index. + + Args: + idx (int): Sample index. + + Returns: + bytes: Sample data. + """ + filename = os.path.join(self.dirname, self.split, self.raw_data.basename) + offset = (1 + idx) * 4 + with open(filename, 'rb', 0) as fp: + fp.seek(offset) + pair = fp.read(8) + begin, end = np.frombuffer(pair, np.uint32) + fp.seek(begin) + data = fp.read(end - begin) + if not data: + raise IndexError( + f'Relative sample index {idx} is not present in the {self.raw_data.basename} file.' + ) + return data + def decode_sample(self, data: bytes) -> Dict[str, Any]: """Decode a sample dict from bytes. @@ -123,26 +146,3 @@ def decode_sample(self, data: bytes) -> Dict[str, Any]: sample[key] = mds_decode(encoding, value) idx += size return sample - - def get_sample_data(self, idx: int) -> bytes: - """Get the raw sample data at the index. - - Args: - idx (int): Sample index. - - Returns: - bytes: Sample data. - """ - filename = os.path.join(self.dirname, self.split, self.raw_data.basename) - offset = (1 + idx) * 4 - with open(filename, 'rb', 0) as fp: - fp.seek(offset) - pair = fp.read(8) - begin, end = np.frombuffer(pair, np.uint32) - fp.seek(begin) - data = fp.read(end - begin) - if not data: - raise IndexError( - f'Relative sample index {idx} is not present in the {self.raw_data.basename} file.' 
- ) - return data diff --git a/streaming/format/mds/writer.py b/streaming/format/mds/writer.py index 950c60f20..e7fc9ef4c 100644 --- a/streaming/format/mds/writer.py +++ b/streaming/format/mds/writer.py @@ -1,7 +1,7 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -""":class:`MDSWriter` writes samples to ``.mds`` files that can be read by :class:`MDSReader`.""" +"""MDS shard writing.""" import json from typing import Any, Dict, List, Optional, Tuple, Union @@ -10,12 +10,12 @@ from streaming.format.mds.encodings import (get_mds_encoded_size, get_mds_encodings, is_mds_encoding, mds_encode) -from streaming.format.writer import JointWriter +from streaming.format.writer import MonoWriter __all__ = ['MDSWriter'] -class MDSWriter(JointWriter): +class MDSWriter(MonoWriter): """Writes a streaming MDS dataset. Args: @@ -127,8 +127,8 @@ def get_config(self) -> Dict[str, Any]: }) return obj - def encode_joint_shard(self) -> bytes: - """Encode a joint shard out of the cached samples (single file). + def encode_mono_shard(self) -> bytes: + """Encode a mono shard out of the cached samples (single file). Returns: bytes: File data. diff --git a/streaming/format/reader.py b/streaming/format/shard.py similarity index 94% rename from streaming/format/reader.py rename to streaming/format/shard.py index e2e5271fc..818fc036f 100644 --- a/streaming/format/reader.py +++ b/streaming/format/shard.py @@ -8,10 +8,12 @@ from dataclasses import dataclass from typing import Any, Dict, Iterator, List, Optional, Set, Union +from typing_extensions import Self + from streaming.array import Array from streaming.util.shorthand import normalize_bytes -__all__ = ['FileInfo', 'Reader', 'JointReader', 'SplitReader'] +__all__ = ['FileInfo', 'Shard', 'MonoShard', 'DualShard'] @dataclass @@ -28,7 +30,7 @@ class FileInfo(object): hashes: Dict[str, str] -class Reader(Array, ABC): +class Shard(Array, ABC): """Provides random access to the samples of a shard. Args: @@ -61,6 +63,21 @@ def __init__( self.file_pairs = [] + @classmethod + @abstractmethod + def from_json(cls, dirname: str, split: Optional[str], obj: Dict[str, Any]) -> Self: + """Initialize from JSON object. + + Args: + dirname (str): Local directory containing shards. + split (str, optional): Which dataset split to use, if any. + obj (Dict[str, Any]): JSON object to load. + + Returns: + Self: Loaded Shard. + """ + raise NotImplementedError + def validate(self, allow_unsafe_types: bool) -> None: """Check whether this shard is acceptable to be part of some Stream. @@ -276,26 +293,26 @@ def get_persistent_size(self, keep_zip: bool) -> int: return size @abstractmethod - def decode_sample(self, data: bytes) -> Dict[str, Any]: - """Decode a sample dict from bytes. + def get_sample_data(self, idx: int) -> bytes: + """Get the raw sample data at the index. Args: - data (bytes): The sample encoded as bytes. + idx (int): Sample index. Returns: - Dict[str, Any]: Sample dict. + bytes: Sample data. """ raise NotImplementedError @abstractmethod - def get_sample_data(self, idx: int) -> bytes: - """Get the raw sample data at the index. + def decode_sample(self, data: bytes) -> Dict[str, Any]: + """Decode a sample dict from bytes. Args: - idx (int): Sample index. + data (bytes): The sample encoded as bytes. Returns: - bytes: Sample data. + Dict[str, Any]: Sample dict. 
""" raise NotImplementedError @@ -321,8 +338,8 @@ def __iter__(self) -> Iterator[Dict[str, Any]]: yield self[i] -class JointReader(Reader): - """Provides random access to the samples of a joint shard. +class MonoShard(Shard): + """Provides random access to the samples of a mono shard. Args: dirname (str): Local dataset directory. @@ -353,8 +370,8 @@ def __init__( self.file_pairs.append((raw_data, zip_data)) -class SplitReader(Reader): - """Provides random access to the samples of a split shard. +class DualShard(Shard): + """Provides random access to the samples of a dual shard. Args: dirname (str): Local dataset directory. diff --git a/streaming/format/writer.py b/streaming/format/writer.py index 4b98b93d4..7cc606034 100644 --- a/streaming/format/writer.py +++ b/streaming/format/writer.py @@ -24,7 +24,7 @@ from streaming.storage.upload import CloudUploader from streaming.util.shorthand import normalize_bytes -__all__ = ['JointWriter', 'SplitWriter'] +__all__ = ['MonoWriter', 'DualWriter'] logger = logging.getLogger(__name__) @@ -340,8 +340,8 @@ def __exit__(self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseEx self.finish() -class JointWriter(Writer): - """Writes a streaming dataset with joint shards. +class MonoWriter(Writer): + """Writes a streaming dataset with mono shards. Args: out (str | Tuple[str, str]): Output dataset directory to save shard files. @@ -395,8 +395,8 @@ def __init__(self, **kwargs) @abstractmethod - def encode_joint_shard(self) -> bytes: - """Encode a joint shard out of the cached samples (single file). + def encode_mono_shard(self) -> bytes: + """Encode a mono shard out of the cached samples (single file). Returns: bytes: File data. @@ -411,7 +411,7 @@ def flush_shard(self) -> None: return raw_data_basename, zip_data_basename = self._name_next_shard() - raw_data = self.encode_joint_shard() + raw_data = self.encode_mono_shard() raw_data_info, zip_data_info = self._process_file(raw_data, raw_data_basename, zip_data_basename) obj = { @@ -428,10 +428,10 @@ def flush_shard(self) -> None: future.add_done_callback(self.exception_callback) -class SplitWriter(Writer): - """Writes a streaming dataset with split shards. +class DualWriter(Writer): + """Writes a streaming dataset with dual shards. - Split shards refer to raw data (csv, json, etc.) paired with an index into it. + Dual shards refer to raw data (csv, json, etc.) paired with an index into it. Args: out (str | Tuple[str, str]): Output dataset directory to save shard files. @@ -482,8 +482,8 @@ def __init__(self, **kwargs) @abstractmethod - def encode_split_shard(self) -> Tuple[bytes, bytes]: - """Encode a split shard out of the cached samples (data, meta files). + def encode_dual_shard(self) -> Tuple[bytes, bytes]: + """Encode a dual shard out of the cached samples (data, meta files). Returns: Tuple[bytes, bytes]: Data file, meta file. 
@@ -499,7 +499,7 @@ def flush_shard(self) -> None: raw_data_basename, zip_data_basename = self._name_next_shard() raw_meta_basename, zip_meta_basename = self._name_next_shard('meta') - raw_data, raw_meta = self.encode_split_shard() + raw_data, raw_meta = self.encode_dual_shard() raw_data_info, zip_data_info = self._process_file(raw_data, raw_data_basename, zip_data_basename) raw_meta_info, zip_meta_info = self._process_file(raw_meta, raw_meta_basename, diff --git a/streaming/format/xsv/__init__.py b/streaming/format/xsv/__init__.py index 985010a42..8532c1013 100644 --- a/streaming/format/xsv/__init__.py +++ b/streaming/format/xsv/__init__.py @@ -1,9 +1,9 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -"""Module to write and read the dataset in Tabular format.""" +"""Streaming XSV shards, with specializations for CSV and TSV.""" -from streaming.format.xsv.reader import CSVReader, TSVReader, XSVReader +from streaming.format.xsv.shard import CSVShard, TSVShard, XSVShard from streaming.format.xsv.writer import CSVWriter, TSVWriter, XSVWriter -__all__ = ['CSVReader', 'CSVWriter', 'TSVReader', 'TSVWriter', 'XSVReader', 'XSVWriter'] +__all__ = ['CSVShard', 'CSVWriter', 'TSVShard', 'TSVWriter', 'XSVShard', 'XSVWriter'] diff --git a/streaming/format/xsv/reader.py b/streaming/format/xsv/shard.py similarity index 96% rename from streaming/format/xsv/reader.py rename to streaming/format/xsv/shard.py index f43ee6f5d..426954638 100644 --- a/streaming/format/xsv/reader.py +++ b/streaming/format/xsv/shard.py @@ -1,7 +1,7 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -"""Reads and decode samples from tabular formatted files such as XSV, CSV, and TSV.""" +"""Streaming XSV shard reading, with specializations for CSV and TSV.""" import os from copy import deepcopy @@ -10,13 +10,13 @@ import numpy as np from typing_extensions import Self -from streaming.format.reader import FileInfo, SplitReader +from streaming.format.shard import DualShard, FileInfo from streaming.format.xsv.encodings import xsv_decode -__all__ = ['XSVReader', 'CSVReader', 'TSVReader'] +__all__ = ['XSVShard', 'CSVShard', 'TSVShard'] -class XSVReader(SplitReader): +class XSVShard(DualShard): """Provides random access to the samples of an XSV shard. Args: @@ -73,7 +73,7 @@ def from_json(cls, dirname: str, split: Optional[str], obj: Dict[str, Any]) -> S obj (Dict[str, Any]): JSON object to load. Returns: - Self: Loaded XSVReader. + Self: Loaded XSVShard. """ args = deepcopy(obj) if args['version'] != 2: @@ -91,23 +91,6 @@ def from_json(cls, dirname: str, split: Optional[str], obj: Dict[str, Any]) -> S args[key] = FileInfo(**arg) if arg else None return cls(**args) - def decode_sample(self, data: bytes) -> Dict[str, Any]: - """Decode a sample dict from bytes. - - Args: - data (bytes): The sample encoded as bytes. - - Returns: - Dict[str, Any]: Sample dict. - """ - text = data.decode('utf-8') - text = text[:-len(self.newline)] - parts = text.split(self.separator) - sample = {} - for name, encoding, part in zip(self.column_names, self.column_encodings, parts): - sample[name] = xsv_decode(encoding, part) - return sample - def get_sample_data(self, idx: int) -> bytes: """Get the raw sample data at the index. @@ -129,8 +112,25 @@ def get_sample_data(self, idx: int) -> bytes: data = fp.read(end - begin) return data + def decode_sample(self, data: bytes) -> Dict[str, Any]: + """Decode a sample dict from bytes. + + Args: + data (bytes): The sample encoded as bytes. 
+ + Returns: + Dict[str, Any]: Sample dict. + """ + text = data.decode('utf-8') + text = text[:-len(self.newline)] + parts = text.split(self.separator) + sample = {} + for name, encoding, part in zip(self.column_names, self.column_encodings, parts): + sample[name] = xsv_decode(encoding, part) + return sample + -class CSVReader(XSVReader): +class CSVShard(XSVShard): """Provides random access to the samples of a CSV shard. Args: @@ -182,7 +182,7 @@ def from_json(cls, dirname: str, split: Optional[str], obj: Dict[str, Any]) -> S obj (Dict[str, Any]): JSON object to load. Returns: - Self: Loaded CSVReader. + Self: Loaded CSVShard. """ args = deepcopy(obj) if args['version'] != 2: @@ -201,7 +201,7 @@ def from_json(cls, dirname: str, split: Optional[str], obj: Dict[str, Any]) -> S return cls(**args) -class TSVReader(XSVReader): +class TSVShard(XSVShard): """Provides random access to the samples of an XSV shard. Args: @@ -253,7 +253,7 @@ def from_json(cls, dirname: str, split: Optional[str], obj: Dict[str, Any]) -> S obj (Dict[str, Any]): JSON object to load. Returns: - Self: Loaded TSVReader. + Self: Loaded TSVShard. """ args = deepcopy(obj) if args['version'] != 2: diff --git a/streaming/format/xsv/writer.py b/streaming/format/xsv/writer.py index b1ab720d3..519ec881b 100644 --- a/streaming/format/xsv/writer.py +++ b/streaming/format/xsv/writer.py @@ -1,20 +1,20 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -""":class:`XSVWriter` writes samples to `.xsv` files that can be read by :class:`XSVReader`.""" +"""Streaming XSV shard writing, with specializations for CSV and TSV.""" import json from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np -from streaming.format.writer import SplitWriter +from streaming.format.writer import DualWriter from streaming.format.xsv.encodings import is_xsv_encoding, xsv_encode __all__ = ['XSVWriter', 'CSVWriter', 'TSVWriter'] -class XSVWriter(SplitWriter): +class XSVWriter(DualWriter): r"""Writes a streaming XSV dataset. Args: @@ -114,8 +114,8 @@ def get_config(self) -> Dict[str, Any]: }) return obj - def encode_split_shard(self) -> Tuple[bytes, bytes]: - """Encode a split shard out of the cached samples (data, meta files). + def encode_dual_shard(self) -> Tuple[bytes, bytes]: + """Encode a dual shard out of the cached samples (data, meta files). Returns: Tuple[bytes, bytes]: Data file, meta file. 
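To make the new terminology concrete, here is a self-contained toy of the two layouts that MonoWriter/MonoShard and DualWriter/DualShard formalize. It is a simplification (real shards also carry config, hashes, and optional compression) that only demonstrates the offset arithmetic used by the get_sample_data implementations above; all names and values are made up for illustration:

    import numpy as np

    samples = [b'first', b'second', b'third']
    n = len(samples)

    # Mono layout (MDS-style, one file): a u32 sample count, then n + 1 u32
    # absolute offsets, then the concatenated sample bytes.
    header = 4 + 4 * (n + 1)
    ends = header + np.cumsum([len(s) for s in samples])
    mono = (np.uint32(n).tobytes() +
            np.concatenate([[header], ends]).astype(np.uint32).tobytes() +
            b''.join(samples))

    # Dual layout (JSONL/XSV-style, two files): the data file holds just the
    # raw samples; the sample count and offsets live in a separate meta file.
    data = b''.join(samples)
    begins = np.concatenate([[0], np.cumsum([len(s) for s in samples])])
    meta = np.uint32(n).tobytes() + begins.astype(np.uint32).tobytes()

    # Random access uses the same arithmetic either way: the (begin, end) pair
    # for sample idx sits at byte (1 + idx) * 4 of the indexed file.
    idx = 1
    lo, hi = np.frombuffer(mono[(1 + idx) * 4:][:8], np.uint32)
    assert mono[lo:hi] == samples[idx]
    lo, hi = np.frombuffer(meta[(1 + idx) * 4:][:8], np.uint32)
    assert data[lo:hi] == samples[idx]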
diff --git a/streaming/local.py b/streaming/local.py index 47dd8134f..ed6e99469 100644 --- a/streaming/local.py +++ b/streaming/local.py @@ -11,7 +11,7 @@ from torch.utils.data import Dataset from streaming.array import Array -from streaming.format import get_index_basename, reader_from_json +from streaming.format import get_index_basename, shard_from_json from streaming.spanner import Spanner __all__ = ['LocalDataset'] @@ -39,7 +39,7 @@ def __init__(self, local: str, split: Optional[str] = None): self.shards = [] for info in obj['shards']: - shard = reader_from_json(local, split, info) + shard = shard_from_json(local, split, info) self.shards.append(shard) self.num_samples = sum([shard.samples for shard in self.shards]) diff --git a/streaming/stream.py b/streaming/stream.py index 974ceaac7..3c3735e9f 100644 --- a/streaming/stream.py +++ b/streaming/stream.py @@ -16,7 +16,7 @@ from streaming.compression import decompress from streaming.constant import TICK from streaming.distributed import barrier, get_local_rank -from streaming.format import FileInfo, Reader, get_index_basename, reader_from_json +from streaming.format import FileInfo, Shard, get_index_basename, shard_from_json from streaming.hashing import get_hash from streaming.storage import download_file, wait_for_file_to_exist from streaming.util import retry @@ -352,9 +352,8 @@ def _prepare_shard_part(self, compression: Optional[str] = None) -> int: """Get shard data given metadata for the raw and compressed versions of it. - MDS format uses joint shards (ie, one file per shard). Other formats supported by streaming - use split shards (ie, shard data lives in two files per shard: the raw data itself and - metadata in a separate file). + Shards are either mono shards (one file per shard, like MDS) or dual shards (a pair of data + and meta files per shard, like the Streaming JSONL and XSV shard formats). Args: raw_info (FileInfo): Raw file info. @@ -407,11 +406,11 @@ def _prepare_shard_part(self, raise ValueError(f'Checksum failure: {raw_filename}') return delta - def prepare_shard(self, shard: Reader) -> int: + def prepare_shard(self, shard: Shard) -> int: """Ensure (download, validate, extract, etc.) that we have the given shard. Args: - shard (Reader): Which shard. + shard (Shard): Which shard. Returns: int: Change in cache usage. @@ -421,7 +420,7 @@ def prepare_shard(self, shard: Reader) -> int: delta += self._prepare_shard_part(raw_info, zip_info, shard.compression) return delta - def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: + def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Shard]: """Load this Stream's index, retrieving its shard readers. Args: world (World): The world state. allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code execution during deserialization, whether to keep going if ``True`` or raise an error. Returns: - `List[Reader]: Shard readers. + List[Shard]: Shard readers. """ # Download the index file if it does not exist locally. basename = get_index_basename() @@ -471,17 +470,17 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Shard]: # Initialize shard readers according to the loaded info.
shards = [] for info in obj['shards']: - shard = reader_from_json(self.local, self.split, info) + shard = shard_from_json(self.local, self.split, info) shard.validate(allow_unsafe_types) shards.append(shard) return shards - def set_up_local(self, shards: List[Reader], cache_usage_per_shard: NDArray[np.int64]) -> None: + def set_up_local(self, shards: List[Shard], cache_usage_per_shard: NDArray[np.int64]) -> None: """Bring a local directory into a consistent state, getting which shards are present. Args: - shards (List[Reader]): List of this stream's shards. + shards (List[Shard]): List of this stream's shards. cache_usage_per_shard (NDArray[np.int64]): Cache usage per shard of this stream. """ # List the cache directory (so that we hit the filesystem once). diff --git a/tests/test_encodings.py b/tests/test_encodings.py index 70d048647..dce91d8cd 100644 --- a/tests/test_encodings.py +++ b/tests/test_encodings.py @@ -10,16 +10,16 @@ import pytest from PIL import Image -import streaming.format.json.encodings as jsonEnc -import streaming.format.mds.encodings as mdsEnc -import streaming.format.xsv.encodings as xsvEnc +import streaming.format.jsonl.encodings as jsonl_enc +import streaming.format.mds.encodings as mds_enc +import streaming.format.xsv.encodings as xsv_enc class TestMDSEncodings: @pytest.mark.parametrize('data', [b'5', b'\x00\x00']) def test_byte_encode_decode(self, data: bytes): - byte_enc = mdsEnc.Bytes() + byte_enc = mds_enc.Bytes() assert byte_enc.size is None output = byte_enc.encode(data) assert output == data @@ -29,13 +29,13 @@ def test_byte_encode_decode(self, data: bytes): @pytest.mark.parametrize('data', ['9', 25]) def test_byte_encode_invalid_data(self, data: Any): with pytest.raises(AttributeError): - byte_enc = mdsEnc.Bytes() + byte_enc = mds_enc.Bytes() _ = byte_enc.encode(data) @pytest.mark.parametrize(('data', 'encode_data'), [('99', b'99'), ('streaming dataset', b'streaming dataset')]) def test_str_encode_decode(self, data: str, encode_data: bytes): - str_enc = mdsEnc.Str() + str_enc = mds_enc.Str() assert str_enc.size is None # Test encode @@ -51,13 +51,13 @@ def test_str_encode_decode(self, data: str, encode_data: bytes): @pytest.mark.parametrize('data', [b'9', 25]) def test_str_encode_invalid_data(self, data: Any): with pytest.raises(AttributeError): - str_enc = mdsEnc.Str() + str_enc = mds_enc.Str() _ = str_enc.encode(data) @pytest.mark.parametrize(('data', 'encode_data'), [(99, b'c\x00\x00\x00\x00\x00\x00\x00'), (987654321, b'\xb1h\xde:\x00\x00\x00\x00')]) def test_int_encode_decode(self, data: int, encode_data: bytes): - int_enc = mdsEnc.Int() + int_enc = mds_enc.Int() assert int_enc.size == 8 # Test encode @@ -73,7 +73,7 @@ def test_int_encode_decode(self, data: int, encode_data: bytes): @pytest.mark.parametrize('data', [b'9', 25.9]) def test_int_encode_invalid_data(self, data: Any): with pytest.raises(AttributeError): - int_enc = mdsEnc.Int() + int_enc = mds_enc.Int() _ = int_enc.encode(data) @pytest.mark.parametrize('dtype_str', [ @@ -103,28 +103,28 @@ def test_ndarray_encode_decode(self, dtype_str: str, shape: Tuple[int]): a = np.random.randint(0, 1000, shape).astype(dtype) encoding = 'ndarray' - assert mdsEnc.is_mds_encoding(encoding) - assert mdsEnc.get_mds_encoded_size(encoding) is None - b = mdsEnc.mds_encode(encoding, a) - c = mdsEnc.mds_decode(encoding, b) + assert mds_enc.is_mds_encoding(encoding) + assert mds_enc.get_mds_encoded_size(encoding) is None + b = mds_enc.mds_encode(encoding, a) + c = mds_enc.mds_decode(encoding, b) assert (a == c).all() 
b1_len = len(b) encoding = f'ndarray:{dtype.__name__}' - assert mdsEnc.is_mds_encoding(encoding) - assert mdsEnc.get_mds_encoded_size(encoding) is None - b = mdsEnc.mds_encode(encoding, a) - c = mdsEnc.mds_decode(encoding, b) + assert mds_enc.is_mds_encoding(encoding) + assert mds_enc.get_mds_encoded_size(encoding) is None + b = mds_enc.mds_encode(encoding, a) + c = mds_enc.mds_decode(encoding, b) assert (a == c).all() b2_len = len(b) shape_str = ','.join(map(str, shape)) encoding = f'ndarray:{dtype.__name__}:{shape_str}' - assert mdsEnc.is_mds_encoding(encoding) - b_size = mdsEnc.get_mds_encoded_size(encoding) + assert mds_enc.is_mds_encoding(encoding) + b_size = mds_enc.get_mds_encoded_size(encoding) assert b_size is not None - b = mdsEnc.mds_encode(encoding, a) - c = mdsEnc.mds_decode(encoding, b) + b = mds_enc.mds_encode(encoding, a) + c = mds_enc.mds_decode(encoding, b) assert (a == c).all() assert len(b) == b_size b3_len = len(b) @@ -134,7 +134,7 @@ def test_ndarray_encode_decode(self, dtype_str: str, shape: Tuple[int]): @pytest.mark.parametrize('mode', ['I', 'L', 'RGB']) def test_pil_encode_decode(self, mode: str): - pil_enc = mdsEnc.PIL() + pil_enc = mds_enc.PIL() assert pil_enc.size is None # Creating the (32 x 32) NumPy Array with random values @@ -158,12 +158,12 @@ def test_pil_encode_decode(self, mode: str): @pytest.mark.parametrize('data', [b'9', 25.9]) def test_pil_encode_invalid_data(self, data: Any): with pytest.raises(AttributeError): - pil_enc = mdsEnc.PIL() + pil_enc = mds_enc.PIL() _ = pil_enc.encode(data) @pytest.mark.parametrize('mode', ['L', 'RGB']) def test_jpeg_encode_decode(self, mode: str): - jpeg_enc = mdsEnc.JPEG() + jpeg_enc = mds_enc.JPEG() assert jpeg_enc.size is None # Creating the (32 x 32) NumPy Array with random values @@ -182,7 +182,7 @@ def test_jpeg_encode_decode(self, mode: str): @pytest.mark.parametrize('mode', ['L', 'RGB']) def test_jpegfile_encode_decode(self, mode: str): - jpeg_enc = mdsEnc.JPEG() + jpeg_enc = mds_enc.JPEG() assert jpeg_enc.size is None # Creating the (32 x 32) NumPy Array with random values @@ -208,12 +208,12 @@ def test_jpegfile_encode_decode(self, mode: str): @pytest.mark.parametrize('data', [b'99', 12.5]) def test_jpeg_encode_invalid_data(self, data: Any): with pytest.raises(AttributeError): - jpeg_enc = mdsEnc.JPEG() + jpeg_enc = mds_enc.JPEG() _ = jpeg_enc.encode(data) @pytest.mark.parametrize('mode', ['I', 'L', 'RGB']) def test_png_encode_decode(self, mode: str): - png_enc = mdsEnc.PNG() + png_enc = mds_enc.PNG() assert png_enc.size is None # Creating the (32 x 32) NumPy Array with random values @@ -237,12 +237,12 @@ def test_png_encode_decode(self, mode: str): @pytest.mark.parametrize('data', [b'123', 77.7]) def test_png_encode_invalid_data(self, data: Any): with pytest.raises(AttributeError): - png_enc = mdsEnc.PNG() + png_enc = mds_enc.PNG() _ = png_enc.encode(data) @pytest.mark.parametrize('data', [25, 'streaming', np.array(7)]) def test_pickle_encode_decode(self, data: Any): - pkl_enc = mdsEnc.Pickle() + pkl_enc = mds_enc.Pickle() assert pkl_enc.size is None # Test encode @@ -258,7 +258,7 @@ def test_pickle_encode_decode(self, data: Any): @pytest.mark.parametrize('data', [25, 'streaming', {'alpha': 1, 'beta': 2}]) def test_json_encode_decode(self, data: Any): - json_enc = mdsEnc.JSON() + json_enc = mds_enc.JSON() assert json_enc.size is None # Test encode @@ -275,12 +275,12 @@ def test_json_encode_decode(self, data: Any): def test_json_invalid_data(self): wrong_json_with_single_quotes = "{'name': 'streaming'}" with 
pytest.raises(json.JSONDecodeError): - json_enc = mdsEnc.JSON() + json_enc = mds_enc.JSON() json_enc._is_valid(wrong_json_with_single_quotes, wrong_json_with_single_quotes) @pytest.mark.parametrize(('decoded', 'encoded'), [(42, b'*')]) def test_mds_uint8(self, decoded: int, encoded: bytes): - coder = mdsEnc.UInt8() + coder = mds_enc.UInt8() assert coder.size == 1 enc = coder.encode(decoded) @@ -293,7 +293,7 @@ def test_mds_uint8(self, decoded: int, encoded: bytes): @pytest.mark.parametrize(('decoded', 'encoded'), [(42, b'*\0')]) def test_mds_uint16(self, decoded: int, encoded: bytes): - coder = mdsEnc.UInt16() + coder = mds_enc.UInt16() assert coder.size == 2 enc = coder.encode(decoded) @@ -306,7 +306,7 @@ def test_mds_uint16(self, decoded: int, encoded: bytes): @pytest.mark.parametrize(('decoded', 'encoded'), [(42, b'*\0\0\0')]) def test_mds_uint32(self, decoded: int, encoded: bytes): - coder = mdsEnc.UInt32() + coder = mds_enc.UInt32() assert coder.size == 4 enc = coder.encode(decoded) @@ -319,7 +319,7 @@ def test_mds_uint32(self, decoded: int, encoded: bytes): @pytest.mark.parametrize(('decoded', 'encoded'), [(42, b'*\0\0\0\0\0\0\0')]) def test_mds_uint64(self, decoded: int, encoded: bytes): - coder = mdsEnc.UInt64() + coder = mds_enc.UInt64() assert coder.size == 8 enc = coder.encode(decoded) @@ -332,7 +332,7 @@ def test_mds_uint64(self, decoded: int, encoded: bytes): @pytest.mark.parametrize(('decoded', 'encoded'), [(42, b'*')]) def test_mds_int8(self, decoded: int, encoded: bytes): - coder = mdsEnc.Int8() + coder = mds_enc.Int8() assert coder.size == 1 enc = coder.encode(decoded) @@ -345,7 +345,7 @@ def test_mds_int8(self, decoded: int, encoded: bytes): @pytest.mark.parametrize(('decoded', 'encoded'), [(42, b'*\0')]) def test_mds_int16(self, decoded: int, encoded: bytes): - coder = mdsEnc.Int16() + coder = mds_enc.Int16() assert coder.size == 2 enc = coder.encode(decoded) @@ -358,7 +358,7 @@ def test_mds_int16(self, decoded: int, encoded: bytes): @pytest.mark.parametrize(('decoded', 'encoded'), [(42, b'*\0\0\0')]) def test_mds_int32(self, decoded: int, encoded: bytes): - coder = mdsEnc.Int32() + coder = mds_enc.Int32() assert coder.size == 4 enc = coder.encode(decoded) @@ -371,7 +371,7 @@ def test_mds_int32(self, decoded: int, encoded: bytes): @pytest.mark.parametrize(('decoded', 'encoded'), [(42, b'*\0\0\0\0\0\0\0')]) def test_mds_int64(self, decoded: int, encoded: bytes): - coder = mdsEnc.Int64() + coder = mds_enc.Int64() assert coder.size == 8 enc = coder.encode(decoded) @@ -384,7 +384,7 @@ def test_mds_int64(self, decoded: int, encoded: bytes): @pytest.mark.parametrize(('decoded', 'encoded'), [(42.0, b'@Q')]) def test_mds_float16(self, decoded: float, encoded: bytes): - coder = mdsEnc.Float16() + coder = mds_enc.Float16() assert coder.size == 2 enc = coder.encode(decoded) @@ -397,7 +397,7 @@ def test_mds_float16(self, decoded: float, encoded: bytes): @pytest.mark.parametrize(('decoded', 'encoded'), [(42.0, b'\0\0(B')]) def test_mds_float32(self, decoded: float, encoded: bytes): - coder = mdsEnc.Float32() + coder = mds_enc.Float32() assert coder.size == 4 enc = coder.encode(decoded) @@ -410,7 +410,7 @@ def test_mds_float32(self, decoded: float, encoded: bytes): @pytest.mark.parametrize(('decoded', 'encoded'), [(42.0, b'\0\0\0\0\0\0E@')]) def test_mds_float64(self, decoded: float, encoded: bytes): - coder = mdsEnc.Float64() + coder = mds_enc.Float64() assert coder.size == 8 enc = coder.encode(decoded) @@ -423,7 +423,7 @@ def test_mds_float64(self, decoded: float, encoded: bytes): 
@pytest.mark.parametrize(('decoded', 'encoded'), [(42, b'42'), (-42, b'-42')]) def test_mds_StrInt(self, decoded: int, encoded: bytes): - coder = mdsEnc.StrInt() + coder = mds_enc.StrInt() enc = coder.encode(decoded) assert isinstance(enc, bytes) assert enc == encoded @@ -434,7 +434,7 @@ def test_mds_StrInt(self, decoded: int, encoded: bytes): @pytest.mark.parametrize(('decoded', 'encoded'), [(42.0, b'42.0'), (-42.0, b'-42.0')]) def test_mds_StrFloat(self, decoded: float, encoded: bytes): - coder = mdsEnc.StrFloat() + coder = mds_enc.StrFloat() enc = coder.encode(decoded) assert isinstance(enc, bytes) assert enc == encoded @@ -446,7 +446,7 @@ def test_mds_StrFloat(self, decoded: float, encoded: bytes): @pytest.mark.parametrize(('decoded', 'encoded'), [(Decimal('4E15'), b'4E+15'), (Decimal('-4E15'), b'-4E+15')]) def test_mds_StrDecimal(self, decoded: Decimal, encoded: bytes): - coder = mdsEnc.StrDecimal() + coder = mds_enc.StrDecimal() enc = coder.encode(decoded) assert isinstance(enc, bytes) assert enc == encoded @@ -463,14 +463,14 @@ def test_get_mds_encodings(self): expected_encodings = { 'int', 'bytes', 'json', 'ndarray', 'png', 'jpeg', 'str', 'pil', 'pkl' } | scalars - enc = mdsEnc.get_mds_encodings() + enc = mds_enc.get_mds_encodings() assert len(enc) == len(expected_encodings) assert enc == expected_encodings @pytest.mark.parametrize(('enc_name', 'expected_output'), [('jpeg', True), ('', False), ('pngg', False)]) def test_is_mds_encoding(self, enc_name: str, expected_output: bool): - is_supported = mdsEnc.is_mds_encoding(enc_name) + is_supported = mds_enc.is_mds_encoding(enc_name) assert is_supported is expected_output @pytest.mark.parametrize(('encoding', 'decoded', 'encoded'), @@ -480,35 +480,35 @@ def test_is_mds_encoding(self, enc_name: str, expected_output: bool): ('int64', 42, b'*\0\0\0\0\0\0\0'), ('float16', 42.0, b'@Q'), ('float32', 42.0, b'\0\0(B'), ('float64', 42.0, b'\0\0\0\0\0\0E@')]) def test_mds_scalar(self, encoding: str, decoded: Union[int, float], encoded: bytes): - enc = mdsEnc.mds_encode(encoding, decoded) + enc = mds_enc.mds_encode(encoding, decoded) assert isinstance(enc, bytes) assert enc == encoded - dec = mdsEnc.mds_decode(encoding, enc) + dec = mds_enc.mds_decode(encoding, enc) assert dec == decoded - dec = mdsEnc.mds_decode(encoding, encoded) + dec = mds_enc.mds_decode(encoding, encoded) assert dec == decoded @pytest.mark.parametrize(('enc_name', 'data'), [('bytes', b'9'), ('int', 27), ('str', 'mosaicml')]) def test_mds_encode(self, enc_name: str, data: Any): - output = mdsEnc.mds_encode(enc_name, data) + output = mds_enc.mds_encode(enc_name, data) assert isinstance(output, bytes) @pytest.mark.parametrize(('enc_name', 'data'), [('bytes', 9), ('int', '27'), ('str', 12.5)]) def test_mds_encode_invalid_data(self, enc_name: str, data: Any): with pytest.raises(AttributeError): - _ = mdsEnc.mds_encode(enc_name, data) + _ = mds_enc.mds_encode(enc_name, data) @pytest.mark.parametrize(('enc_name', 'data', 'expected_data_type'), [('bytes', b'c\x00\x00\x00\x00\x00\x00\x00', bytes), ('str', b'mosaicml', str)]) def test_mds_decode(self, enc_name: str, data: Any, expected_data_type: Any): - output = mdsEnc.mds_decode(enc_name, data) + output = mds_enc.mds_decode(enc_name, data) assert isinstance(output, expected_data_type) @pytest.mark.parametrize(('enc_name', 'expected_size'), [('bytes', None), ('int', 8)]) def test_get_mds_encoded_size(self, enc_name: str, expected_size: Any): - output = mdsEnc.get_mds_encoded_size(enc_name) + output = 
mds_enc.get_mds_encoded_size(enc_name) assert output is expected_size @@ -517,7 +517,7 @@ class TestXSVEncodings: @pytest.mark.parametrize(('data', 'encode_data'), [('99', '99'), ('streaming dataset', 'streaming dataset')]) def test_str_encode_decode(self, data: str, encode_data: str): - str_enc = xsvEnc.Str() + str_enc = xsv_enc.Str() # Test encode enc_data = str_enc.encode(data) @@ -532,12 +532,12 @@ def test_str_encode_decode(self, data: str, encode_data: str): @pytest.mark.parametrize('data', [99, b'streaming dataset', 123.45]) def test_str_encode_invalid_data(self, data: Any): with pytest.raises(Exception): - str_enc = xsvEnc.Str() + str_enc = xsv_enc.Str() _ = str_enc.encode(data) @pytest.mark.parametrize(('data', 'encode_data'), [(99, '99'), (987675432, '987675432')]) def test_int_encode_decode(self, data: int, encode_data: str): - int_enc = xsvEnc.Int() + int_enc = xsv_enc.Int() # Test encode enc_data = int_enc.encode(data) @@ -552,12 +552,12 @@ def test_int_encode_decode(self, data: int, encode_data: str): @pytest.mark.parametrize('data', ['99', b'streaming dataset', 123.45]) def test_int_encode_invalid_data(self, data: Any): with pytest.raises(Exception): - int_enc = xsvEnc.Int() + int_enc = xsv_enc.Int() _ = int_enc.encode(data) @pytest.mark.parametrize(('data', 'encode_data'), [(1.24, '1.24'), (9.0, '9.0')]) def test_float_encode_decode(self, data: int, encode_data: str): - float_enc = xsvEnc.Float() + float_enc = xsv_enc.Float() # Test encode enc_data = float_enc.encode(data) @@ -572,7 +572,7 @@ def test_float_encode_decode(self, data: int, encode_data: str): @pytest.mark.parametrize('data', ['99', b'streaming dataset', 12]) def test_float_encode_invalid_data(self, data: Any): with pytest.raises(Exception): - float_enc = xsvEnc.Float() + float_enc = xsv_enc.Float() _ = float_enc.encode(data) @pytest.mark.parametrize(('enc_name', 'expected_output'), [ @@ -582,14 +582,14 @@ def test_float_encode_invalid_data(self, data: Any): ('', False), ]) def test_is_xsv_encoding(self, enc_name: str, expected_output: bool): - is_supported = xsvEnc.is_xsv_encoding(enc_name) + is_supported = xsv_enc.is_xsv_encoding(enc_name) assert is_supported is expected_output @pytest.mark.parametrize(('enc_name', 'data', 'expected_data'), [('str', 'mosaicml', 'mosaicml'), ('int', 27, '27'), ('float', 1.25, '1.25')]) def test_xsv_encode(self, enc_name: str, data: Any, expected_data: str): - output = xsvEnc.xsv_encode(enc_name, data) + output = xsv_enc.xsv_encode(enc_name, data) assert isinstance(output, str) assert output == expected_data @@ -597,7 +597,7 @@ def test_xsv_encode(self, enc_name: str, data: Any, expected_data: str): [('str', 'mosaicml', 'mosaicml'), ('int', '27', 27), ('float', '1.25', 1.25)]) def test_xsv_decode(self, enc_name: str, data: str, expected_data: Any): - output = xsvEnc.xsv_decode(enc_name, data) + output = xsv_enc.xsv_decode(enc_name, data) assert isinstance(output, type(expected_data)) assert output == expected_data @@ -606,7 +606,7 @@ class TestJSONEncodings: @pytest.mark.parametrize('data', ['99', 'mosaicml']) def test_str_is_encoded(self, data: str): - json_enc = jsonEnc.Str() + json_enc = jsonl_enc.Str() # Test encode enc_data = json_enc.is_encoded(data) @@ -615,12 +615,12 @@ def test_str_is_encoded(self, data: str): @pytest.mark.parametrize('data', [99, b'mosaicml']) def test_str_is_encoded_invalid_data(self, data: Any): with pytest.raises(AttributeError): - json_enc = jsonEnc.Str() + json_enc = jsonl_enc.Str() _ = json_enc.is_encoded(data) @pytest.mark.parametrize('data', 
[99, 987675432]) def test_int_is_encoded(self, data: int): - int_enc = jsonEnc.Int() + int_enc = jsonl_enc.Int() # Test encode enc_data = int_enc.is_encoded(data) @@ -629,12 +629,12 @@ def test_int_is_encoded(self, data: int): @pytest.mark.parametrize('data', ['99', b'mosaicml', 1.25]) def test_int_is_encoded_invalid_data(self, data: Any): with pytest.raises(AttributeError): - int_enc = jsonEnc.Int() + int_enc = jsonl_enc.Int() _ = int_enc.is_encoded(data) @pytest.mark.parametrize('data', [1.25]) def test_float_is_encoded(self, data: int): - float_enc = jsonEnc.Float() + float_enc = jsonl_enc.Float() # Test encode enc_data = float_enc.is_encoded(data) @@ -643,7 +643,7 @@ def test_float_is_encoded(self, data: int): @pytest.mark.parametrize('data', ['99', b'mosaicml', 25]) def test_float_is_encoded_invalid_data(self, data: Any): with pytest.raises(AttributeError): - float_enc = jsonEnc.Float() + float_enc = jsonl_enc.Float() _ = float_enc.is_encoded(data) @pytest.mark.parametrize(('enc_name', 'expected_output'), [ @@ -652,13 +652,13 @@ def test_float_is_encoded_invalid_data(self, data: Any): ('float', True), ('', False), ]) - def test_is_json_encoding(self, enc_name: str, expected_output: bool): - is_supported = jsonEnc.is_json_encoding(enc_name) + def test_is_jsonl_encoding(self, enc_name: str, expected_output: bool): + is_supported = jsonl_enc.is_jsonl_encoding(enc_name) assert is_supported is expected_output @pytest.mark.parametrize(('enc_name', 'data', 'expected_output'), [('str', 'hello', True), ('int', 10, True), ('float', 9.9, True)]) - def test_is_json_encoded(self, enc_name: str, data: Any, expected_output: bool): - is_supported = jsonEnc.is_json_encoded(enc_name, data) + def test_is_jsonl_encoded(self, enc_name: str, data: Any, expected_output: bool): + is_supported = jsonl_enc.is_jsonl_encoded(enc_name, data) assert is_supported is expected_output diff --git a/tests/test_writer.py b/tests/test_writer.py index 188a6b40b..d2aa691a1 100644 --- a/tests/test_writer.py +++ b/tests/test_writer.py @@ -9,7 +9,7 @@ import numpy as np import pytest -from streaming import CSVWriter, JSONWriter, MDSWriter, StreamingDataset, TSVWriter, XSVWriter +from streaming import CSVWriter, JSONLWriter, MDSWriter, StreamingDataset, TSVWriter, XSVWriter from tests.common.datasets import NumberAndSayDataset, SequenceDataset from tests.common.utils import get_config_in_bytes @@ -122,7 +122,7 @@ def test_dataset_iter_determinism(self, local_remote_dir: Tuple[str, str], num_s assert before == after -class TestJSONWriter: +class TestJSONLWriter: @pytest.mark.parametrize('num_samples', [100]) @pytest.mark.parametrize('size_limit', [32]) @@ -133,18 +133,18 @@ def test_config(self, local_remote_dir: Tuple[str, str], num_samples: int, columns = dict(zip(dataset.column_names, dataset.column_encodings)) expected_config = { 'version': 2, - 'format': 'json', + 'format': 'jsonl', 'compression': None, 'hashes': [], 'size_limit': size_limit, 'columns': columns, 'newline': '\n' } - writer = JSONWriter(out=local, - columns=columns, - compression=None, - hashes=None, - size_limit=size_limit) + writer = JSONLWriter(out=local, + columns=columns, + compression=None, + hashes=None, + size_limit=size_limit) assert writer.get_config() == expected_config @pytest.mark.parametrize('num_samples', [50000]) @@ -158,11 +158,11 @@ def test_dataset_iter_determinism(self, local_remote_dir: Tuple[str, str], num_s local, _ = local_remote_dir dataset = NumberAndSayDataset(num_samples, seed=seed) columns = dict(zip(dataset.column_names, 
dataset.column_encodings)) - with JSONWriter(out=local, - columns=columns, - compression=compression, - hashes=hashes, - size_limit=size_limit) as out: + with JSONLWriter(out=local, + columns=columns, + compression=compression, + hashes=hashes, + size_limit=size_limit) as out: for sample in dataset: out.write(sample)
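After this patch, the user-visible changes are the writer class name and the on-disk names (shard.00000.jsonl plus shard.00000.jsonl.meta, and "format": "jsonl" in index.json); reading is unchanged, and old indexes that say "json" still load. A short end-to-end smoke sketch of the renamed API (output path and sample values are illustrative):

    from streaming import JSONLWriter, LocalDataset

    out_dir = '/tmp/jsonl_demo'  # Illustrative path.
    columns = {'num': 'int', 'txt': 'str'}
    with JSONLWriter(out=out_dir, columns=columns, size_limit=1 << 20) as out:
        for i in range(100):
            out.write({'num': i, 'txt': f'sample {i}'})

    # Read it back; LocalDataset resolves each index entry via shard_from_json.
    dataset = LocalDataset(local=out_dir)
    print(dataset.num_samples)  # 100
    print(dataset[7])           # {'num': 7, 'txt': 'sample 7'}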