diff --git a/.mypy.ini b/.mypy.ini
index fd2f013a1..4bb08e36c 100644
--- a/.mypy.ini
+++ b/.mypy.ini
@@ -16,6 +16,15 @@ disallow_incomplete_defs = True
 disallow_untyped_calls = True
 disallow_untyped_decorators = True

+[mypy-tiled._tests.adapters.*]
+ignore_errors = False
+ignore_missing_imports = False
+check_untyped_defs = True
+disallow_untyped_defs = True
+disallow_incomplete_defs = True
+disallow_untyped_calls = True
+disallow_untyped_decorators = False
+
 [mypy-tiled._tests.test_protocols]
 ignore_errors = False
 ignore_missing_imports = False
diff --git a/CHANGELOG.md b/CHANGELOG.md
index a4e761482..f686e4688 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ Write the date in place of the "Unreleased" in the case a new version is release
 ### Added

 - Add method to `TableAdapter` which accepts a Python dictionary.
+- Added an `Arrow` adapter which supports reading/writing Arrow tables via `RecordBatchFileReader`/`RecordBatchFileWriter`.

 ### Changed
 - Make `tiled.client` accept a Python dictionary when fed to `write_dataframe()`.
@@ -71,7 +72,6 @@ Write the date in place of the "Unreleased" in the case a new version is release
 ## v0.1.0b1 (2024-05-25)

 ### Added
-
 - Support for `FullText` search on SQLite-backed catalogs

 ### Fixed
diff --git a/tiled/_tests/adapters/__init__.py b/tiled/_tests/adapters/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tiled/_tests/adapters/test_arrow.py b/tiled/_tests/adapters/test_arrow.py
new file mode 100644
index 000000000..2f10b1dfb
--- /dev/null
+++ b/tiled/_tests/adapters/test_arrow.py
@@ -0,0 +1,69 @@
+import tempfile
+
+import pyarrow as pa
+import pytest
+
+from tiled.adapters.arrow import ArrowAdapter
+from tiled.structures.table import TableStructure
+
+names = ["f0", "f1", "f2"]
+data0 = [
+    pa.array([1, 2, 3, 4, 5]),
+    pa.array(["foo0", "bar0", "baz0", None, "goo0"]),
+    pa.array([True, None, False, True, None]),
+]
+data1 = [
+    pa.array([6, 7, 8, 9, 10, 11, 12]),
+    pa.array(["foo1", "bar1", None, "baz1", "biz", None, "goo"]),
+    pa.array([None, True, True, False, False, None, True]),
+]
+data2 = [pa.array([13, 14]), pa.array(["foo2", "baz2"]), pa.array([False, None])]
+
+batch0 = pa.record_batch(data0, names=names)
+batch1 = pa.record_batch(data1, names=names)
+batch2 = pa.record_batch(data2, names=names)
+data_uri = "file://localhost/" + tempfile.gettempdir()
+
+
+@pytest.fixture
+def adapter() -> ArrowAdapter:
+    table = pa.Table.from_arrays(data0, names)
+    structure = TableStructure.from_arrow_table(table, npartitions=3)
+    assets = ArrowAdapter.init_storage(data_uri, structure=structure)
+    return ArrowAdapter([asset.data_uri for asset in assets], structure=structure)
+
+
+def test_attributes(adapter: ArrowAdapter) -> None:
+    assert adapter.structure().columns == names
+    assert adapter.structure().npartitions == 3
+
+
+def test_write_read(adapter: ArrowAdapter) -> None:
+    # test writing to a single partition and reading it back
+    adapter.write_partition(batch0, 0)
+    assert pa.Table.from_arrays(data0, names) == pa.Table.from_pandas(
+        adapter.read_partition(0)
+    )
+
+    adapter.write_partition([batch0, batch1], 1)
+    assert pa.Table.from_batches([batch0, batch1]) == pa.Table.from_pandas(
+        adapter.read_partition(1)
+    )
+
+    adapter.write_partition([batch0, batch1, batch2], 2)
+    assert pa.Table.from_batches([batch0, batch1, batch2]) == pa.Table.from_pandas(
+        adapter.read_partition(2)
+    )
+
+    # test writing to all partitions and reading them back together
+    adapter.write_partition([batch0, batch1, batch2], 0)
+    adapter.write_partition([batch2, batch0, batch1], 1)
+    adapter.write_partition([batch1, batch2, batch0], 2)
+
+    assert pa.Table.from_pandas(adapter.read()) == pa.Table.from_batches(
+        [batch0, batch1, batch2, batch2, batch0, batch1, batch1, batch2, batch0]
+    )
+
+    # test that adapter.write() raises NotImplementedError when there is more than one partition
+    with pytest.raises(NotImplementedError):
+        adapter.write(batch0)
diff --git a/tiled/_tests/test_writing.py b/tiled/_tests/test_writing.py
index c207afb2b..571aaad59 100644
--- a/tiled/_tests/test_writing.py
+++ b/tiled/_tests/test_writing.py
@@ -30,7 +30,7 @@ from ..structures.data_source import DataSource
 from ..structures.sparse import COOStructure
 from ..structures.table import TableStructure
-from ..utils import patch_mimetypes
+from ..utils import APACHE_ARROW_FILE_MIME_TYPE, patch_mimetypes
 from ..validation_registration import ValidationRegistry
 from .utils import fail_with_status_code
@@ -537,7 +537,7 @@ def test_write_with_specified_mimetype(tree):
     df = pandas.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
     structure = TableStructure.from_pandas(df)

-    for mimetype in [PARQUET_MIMETYPE, "text/csv"]:
+    for mimetype in [PARQUET_MIMETYPE, "text/csv", APACHE_ARROW_FILE_MIME_TYPE]:
         x = client.new(
             "table",
             [
diff --git a/tiled/adapters/arrow.py b/tiled/adapters/arrow.py
new file mode 100644
index 000000000..84e5afffb
--- /dev/null
+++ b/tiled/adapters/arrow.py
@@ -0,0 +1,334 @@
+from pathlib import Path
+from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union
+
+import pandas
+import pyarrow
+import pyarrow.feather as feather
+import pyarrow.fs
+
+from ..structures.core import Spec, StructureFamily
+from ..structures.data_source import Asset, DataSource, Management
+from ..structures.table import TableStructure
+from ..utils import ensure_uri, path_from_uri
+from .array import ArrayAdapter
+from .protocols import AccessPolicy
+from .type_alliases import JSON
+
+
+class ArrowAdapter:
+    """Adapter for tables stored in the Arrow IPC file format, one file per partition."""
+
+    structure_family = StructureFamily.table
+
+    def __init__(
+        self,
+        data_uris: List[str],
+        structure: Optional[TableStructure] = None,
+        metadata: Optional[JSON] = None,
+        specs: Optional[List[Spec]] = None,
+        access_policy: Optional[AccessPolicy] = None,
+    ) -> None:
+        """
+
+        Parameters
+        ----------
+        data_uris : list of URIs where the data sits, one per partition.
+        structure : the TableStructure of the data; inferred from the stored files if omitted.
+        metadata : optional metadata dictionary.
+        specs : optional list of Specs.
+        access_policy : optional AccessPolicy.
+        """
+        # TODO Store data_uris instead and generalize to non-file schemes.
+        self._partition_paths = [path_from_uri(uri) for uri in data_uris]
+        self._metadata = metadata or {}
+        if structure is None:
+            # Infer the structure by reading the stored partitions.
+            table = pyarrow.concat_tables(
+                [feather.read_table(path) for path in self._partition_paths]
+            )
+            structure = TableStructure.from_arrow_table(
+                table, npartitions=len(self._partition_paths)
+            )
+        self._structure = structure
+        self.specs = list(specs or [])
+        self.access_policy = access_policy
+
+    def metadata(self) -> JSON:
+        """
+
+        Returns
+        -------
+        The metadata of this adapter.
+        """
+        return self._metadata
+
+    @classmethod
+    def init_storage(cls, data_uri: str, structure: TableStructure) -> List[Asset]:
+        """
+        Initialize the list of assets for the given URI.
+        Parameters
+        ----------
+        data_uri : the directory URI under which the partition files will be stored.
+        structure : the TableStructure of the data.
+
+        Returns
+        -------
+        The list of assets, one per partition.
+        """
+        directory = path_from_uri(data_uri)
+        directory.mkdir(parents=True, exist_ok=True)
+        assets = [
+            Asset(
+                data_uri=f"{data_uri}/partition-{i}.arrow",
+                is_directory=False,
+                parameter="data_uris",
+                num=i,
+            )
+            for i in range(structure.npartitions)
+        ]
+        return assets
+
+    def structure(self) -> TableStructure:
+        """
+
+        Returns
+        -------
+        The TableStructure of this adapter.
+        """
+        return self._structure
+
+    def get(self, key: str) -> Union[ArrayAdapter, None]:
+        """
+
+        Parameters
+        ----------
+        key : the column name.
+
+        Returns
+        -------
+        An ArrayAdapter for the column, or None if the column does not exist.
+        """
+        if key not in self.structure().columns:
+            return None
+        return ArrayAdapter.from_array(self.read([key])[key].values)
+
+    def generate_data_sources(
+        self,
+        mimetype: str,
+        dict_or_none: Callable[[TableStructure], Dict[str, str]],
+        item: Union[str, Path],
+        is_directory: bool,
+    ) -> List[DataSource]:
+        """
+
+        Parameters
+        ----------
+        mimetype : the MIME type of the data.
+        dict_or_none : callable that serializes the structure.
+        item : the path or URI of the data.
+        is_directory : whether the item is a directory.
+
+        Returns
+        -------
+        A list with a single DataSource describing the externally managed data.
+        """
+        return [
+            DataSource(
+                structure_family=self.structure_family,
+                mimetype=mimetype,
+                structure=dict_or_none(self.structure()),
+                parameters={},
+                management=Management.external,
+                assets=[
+                    Asset(
+                        data_uri=ensure_uri(item),
+                        is_directory=is_directory,
+                        parameter="data_uris",  # <-- PLURAL!
+                        num=0,  # <-- denoting that the Adapter expects a list, and this is the first element
+                    )
+                ],
+            )
+        ]
+
+    @classmethod
+    def from_single_file(
+        cls,
+        data_uri: str,
+        structure: Optional[TableStructure] = None,
+        metadata: Optional[JSON] = None,
+        specs: Optional[List[Spec]] = None,
+        access_policy: Optional[AccessPolicy] = None,
+    ) -> "ArrowAdapter":
+        """
+
+        Parameters
+        ----------
+        data_uri : the URI of a single Arrow file.
+        structure : the TableStructure of the data.
+        metadata : optional metadata dictionary.
+        specs : optional list of Specs.
+        access_policy : optional AccessPolicy.
+
+        Returns
+        -------
+        An ArrowAdapter backed by a single partition.
+        """
+        return cls(
+            [data_uri],
+            structure=structure,
+            metadata=metadata,
+            specs=specs,
+            access_policy=access_policy,
+        )
+
+    def __getitem__(self, key: str) -> ArrayAdapter:
+        """
+
+        Parameters
+        ----------
+        key : the column name.
+
+        Returns
+        -------
+        An ArrayAdapter for the column.
+        """
+        # Must compute to determine shape.
+        return ArrayAdapter.from_array(self.read([key])[key].values)
+
+    def items(self) -> Iterator[Tuple[str, ArrayAdapter]]:
+        """
+
+        Returns
+        -------
+        An iterator of (column name, ArrayAdapter) pairs.
+        """
+        yield from (
+            (key, ArrayAdapter.from_array(self.read([key])[key].values))
+            for key in self._structure.columns
+        )
+
+    def reader_handle_partition(self, partition: int) -> pyarrow.RecordBatchFileReader:
+        """
+        Function to initialize and return the reader handle.
+        Parameters
+        ----------
+        partition : the integer index corresponding to a specific partition.
+        Returns
+        -------
+        The reader handle for the specified partition.
+        """
+        if not Path(self._partition_paths[partition]).exists():
+            raise ValueError(f"partition {partition} has not been stored yet")
+        else:
+            return pyarrow.ipc.open_file(self._partition_paths[partition])
+
+    def reader_handle_all(self) -> Iterator[pyarrow.RecordBatchFileReader]:
+        """
+        Function to initialize and return a reader handle for each partition.
+        Returns
+        -------
+        An iterator of reader handles, one per partition.
+        """
+        for path in self._partition_paths:
+            if not Path(path).exists():
+                raise ValueError(f"path {path} has not been stored yet")
+            else:
+                with pyarrow.ipc.open_file(path) as reader:
+                    yield reader
+
+    def write_partition(
+        self,
+        data: Union[List[pyarrow.RecordBatch], pyarrow.RecordBatch, pandas.DataFrame],
+        partition: int,
+    ) -> None:
+        """
+        Write the given data into a specific partition as an Arrow file.
+        Parameters
+        ----------
+        data : the data to write; a record batch, a list of record batches, or a pandas DataFrame.
+        partition : the integer index of the partition to write.
+        Returns
+        -------
+        """
+        if isinstance(data, pandas.DataFrame):
+            table = pyarrow.Table.from_pandas(data)
+            batches = table.to_batches()
+        else:
+            if not isinstance(data, list):
+                batches = [data]
+            else:
+                batches = data
+
+        schema = batches[0].schema
+
+        uri = self._partition_paths[partition]
+
+        with pyarrow.ipc.new_file(uri, schema) as file_writer:
+            for batch in batches:
+                file_writer.write_batch(batch)
+
+    def write(
+        self,
+        data: Union[List[pyarrow.RecordBatch], pyarrow.RecordBatch, pandas.DataFrame],
+    ) -> None:
+        """
+        Write the given data as an Arrow file (single-partition datasets only).
+        Parameters
+        ----------
+        data : the data to write; a record batch, a list of record batches, or a pandas DataFrame.
+        Returns
+        -------
+        """
+        if isinstance(data, pandas.DataFrame):
+            table = pyarrow.Table.from_pandas(data)
+            batches = table.to_batches()
+        else:
+            if not isinstance(data, list):
+                batches = [data]
+            else:
+                batches = data
+
+        schema = batches[0].schema
+
+        if self.structure().npartitions != 1:
+            raise NotImplementedError
+        uri = self._partition_paths[0]
+
+        with pyarrow.ipc.new_file(uri, schema) as file_writer:
+            for batch in batches:
+                file_writer.write_batch(batch)
+
+    def read(self, fields: Optional[Union[str, List[str]]] = None) -> pandas.DataFrame:
+        """
+        Read the concatenated data from the full set of partitions.
+        Parameters
+        ----------
+        fields : optional column name or list of column names to select.
+        Returns
+        -------
+        The concatenated data as a pandas DataFrame.
+        """
+        data = pyarrow.concat_tables(
+            [partition.read_all() for partition in self.reader_handle_all()]
+        )
+        table = data.to_pandas()
+        if fields is not None:
+            return table[fields]
+        return table
+
+    def read_partition(
+        self,
+        partition: int,
+        fields: Optional[Union[str, List[str]]] = None,
+    ) -> pandas.DataFrame:
+        """
+        Read the data from a given partition.
+        Parameters
+        ----------
+        partition : the index of the partition to read.
+        fields : optional column name or list of column names to select.
+
+        Returns
+        -------
+        The data of the given partition as a pandas DataFrame.
+        """
+        reader = self.reader_handle_partition(partition)
+        table = reader.read_all().to_pandas()
+        if fields is not None:
+            return table[fields]
+        return table
diff --git a/tiled/catalog/adapter.py b/tiled/catalog/adapter.py
index 2c4f51881..b3b407955 100644
--- a/tiled/catalog/adapter.py
+++ b/tiled/catalog/adapter.py
@@ -53,6 +53,7 @@
 )
 from ..mimetypes import (
+    APACHE_ARROW_FILE_MIME_TYPE,
     AWKWARD_BUFFERS_MIMETYPE,
     DEFAULT_ADAPTERS_BY_MIMETYPE,
     PARQUET_MIMETYPE,
@@ -109,6 +110,9 @@
         SPARSE_BLOCKS_PARQUET_MIMETYPE: lambda: importlib.import_module(
             "...adapters.sparse_blocks_parquet", __name__
         ).SparseBlocksParquetAdapter.init_storage,
+        APACHE_ARROW_FILE_MIME_TYPE: lambda: importlib.import_module(
+            "...adapters.arrow", __name__
+        ).ArrowAdapter.init_storage,
     }
 )
diff --git a/tiled/mimetypes.py b/tiled/mimetypes.py
index 0ce87f5a0..adc86c203 100644
--- a/tiled/mimetypes.py
+++ b/tiled/mimetypes.py
@@ -2,7 +2,7 @@
 import importlib

 from .serialization.table import XLSX_MIME_TYPE
-from .utils import OneShotCachedMap
+from .utils import APACHE_ARROW_FILE_MIME_TYPE, OneShotCachedMap

 # This maps MIME types (i.e. file formats) for appropriate Readers.
 # OneShotCachedMap is used to defer imports. We don't want to pay up front
@@ -46,6 +46,9 @@
         AWKWARD_BUFFERS_MIMETYPE: lambda: importlib.import_module(
             "..adapters.awkward_buffers", __name__
         ).AwkwardBuffersAdapter.from_directory,
+        APACHE_ARROW_FILE_MIME_TYPE: lambda: importlib.import_module(
+            "..adapters.arrow", __name__
+        ).ArrowAdapter,
     }
 )
diff --git a/tiled/structures/table.py b/tiled/structures/table.py
index 8cf6de0f1..26c3f4496 100644
--- a/tiled/structures/table.py
+++ b/tiled/structures/table.py
@@ -56,6 +56,26 @@ def from_dict(cls, d):
         data_uri = B64_ENCODED_PREFIX + schema_b64
         return cls(arrow_schema=data_uri, npartitions=1, columns=list(d.keys()))

+    @classmethod
+    def from_arrays(cls, arr, names) -> "TableStructure":
+        import pyarrow
+
+        schema_bytes = pyarrow.Table.from_arrays(arr, names).schema.serialize()
+        schema_b64 = base64.b64encode(schema_bytes).decode("utf-8")
+        data_uri = B64_ENCODED_PREFIX + schema_b64
+        return cls(arrow_schema=data_uri, npartitions=1, columns=list(names))
+
+    @classmethod
+    def from_arrow_table(cls, table, npartitions=1) -> "TableStructure":
+        schema_bytes = table.schema.serialize()
+        schema_b64 = base64.b64encode(schema_bytes).decode("utf-8")
+        data_uri = B64_ENCODED_PREFIX + schema_b64
+        return cls(
+            arrow_schema=data_uri,
+            npartitions=npartitions,
+            columns=list(table.column_names),
+        )
+
     @property
     def arrow_schema_decoded(self):
         import pyarrow
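For reference, a minimal usage sketch of the new adapter, mirroring the test fixture above. Only API introduced in this diff is used; the temporary-directory URI and the example arrays are illustrative, not part of the API.

import tempfile

import pyarrow as pa

from tiled.adapters.arrow import ArrowAdapter
from tiled.structures.table import TableStructure

# Describe the table and allocate one .arrow file per partition under a local URI.
names = ["f0", "f1"]
data = [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])]
structure = TableStructure.from_arrow_table(
    pa.Table.from_arrays(data, names), npartitions=1
)
data_uri = "file://localhost/" + tempfile.mkdtemp()
assets = ArrowAdapter.init_storage(data_uri, structure=structure)
adapter = ArrowAdapter([asset.data_uri for asset in assets], structure=structure)

# Write a record batch to partition 0, then read everything back as a pandas DataFrame.
adapter.write_partition(pa.record_batch(data, names=names), 0)
df = adapter.read()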