diff --git a/doc/changes/DM-45431.feature.md b/doc/changes/DM-45431.feature.md new file mode 100644 index 0000000000..495ca87ef2 --- /dev/null +++ b/doc/changes/DM-45431.feature.md @@ -0,0 +1,3 @@ +The ParquetFormatter now declares, via can_accept, that it accepts Arrow tables, Astropy tables, Numpy tables, and pandas DataFrames. +This means that we have complete lossless storage of any parquet-compatible type into a datastore that has declared a different type; e.g. an astropy table with units can be persisted into a DataFrame storage class without those units being stripped. +This ticket also adds can_accept to the InMemoryDatastore delegates, and a single ArrowTableDelegate now handles all of the parquet-compatible dataset types. diff --git a/python/lsst/daf/butler/_storage_class_delegate.py b/python/lsst/daf/butler/_storage_class_delegate.py index 344c299f2e..1392ba583c 100644 --- a/python/lsst/daf/butler/_storage_class_delegate.py +++ b/python/lsst/daf/butler/_storage_class_delegate.py @@ -86,6 +86,31 @@ def __init__(self, storageClass: StorageClass): assert storageClass is not None self.storageClass = storageClass + def can_accept(self, inMemoryDataset: Any) -> bool: + """Indicate whether this delegate can accept the specified + dataset directly. + + Parameters + ---------- + inMemoryDataset : `object` + The dataset that is to be stored. + + Returns + ------- + accepts : `bool` + If `True` the delegate can handle data of this type without + requiring the datastore to convert it. If `False` the datastore + will attempt to convert before storage. + + Notes + ----- + The base class always returns `False` even if the given type is an + instance of the delegate type. This results in a no-op storage class + conversion but allows mocks with mocked storage classes + to work properly. + """ + return False + @staticmethod def _attrNames(componentName: str, getter: bool = True) -> tuple[str, ...]: """Return list of suitable attribute names to attempt to use. 
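The base-class hook above deliberately returns False; concrete delegates opt in by overriding it. A minimal sketch of such an override, using a hypothetical delegate and placeholder type that are not part of this change, could look like:

    from typing import Any

    from lsst.daf.butler import StorageClassDelegate


    class MyTable:
        """Placeholder for a delegate's native in-memory type (hypothetical)."""


    class MyTableDelegate(StorageClassDelegate):
        """Hypothetical delegate that accepts its native type without coercion."""

        def can_accept(self, inMemoryDataset: Any) -> bool:
            # Returning True tells the datastore it may skip the storage class
            # coercion step for this object.
            return isinstance(inMemoryDataset, MyTable)
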
diff --git a/python/lsst/daf/butler/configs/storageClasses.yaml b/python/lsst/daf/butler/configs/storageClasses.yaml index 08a1287c50..e56a2e111f 100644 --- a/python/lsst/daf/butler/configs/storageClasses.yaml +++ b/python/lsst/daf/butler/configs/storageClasses.yaml @@ -127,7 +127,7 @@ storageClasses: astropy.table.Table: lsst.daf.butler.formatters.parquet.astropy_to_pandas numpy.ndarray: pandas.DataFrame.from_records dict: pandas.DataFrame.from_records - delegate: lsst.daf.butler.delegates.dataframe.DataFrameDelegate + delegate: lsst.daf.butler.delegates.arrowtable.ArrowTableDelegate derivedComponents: columns: DataFrameIndex rowcount: int @@ -179,7 +179,7 @@ storageClasses: pandas.core.frame.DataFrame: lsst.daf.butler.formatters.parquet.pandas_to_astropy numpy.ndarray: lsst.daf.butler.formatters.parquet.numpy_to_astropy dict: astropy.table.Table - delegate: lsst.daf.butler.delegates.arrowastropy.ArrowAstropyDelegate + delegate: lsst.daf.butler.delegates.arrowtable.ArrowTableDelegate derivedComponents: columns: ArrowColumnList rowcount: int @@ -200,7 +200,7 @@ storageClasses: pandas.core.frame.DataFrame: pandas.DataFrame.to_records astropy.table.Table: astropy.table.Table.as_array dict: lsst.daf.butler.formatters.parquet._numpy_dict_to_numpy - delegate: lsst.daf.butler.delegates.arrownumpy.ArrowNumpyDelegate + delegate: lsst.daf.butler.delegates.arrowtable.ArrowTableDelegate derivedComponents: columns: ArrowColumnList rowcount: int @@ -221,7 +221,7 @@ storageClasses: pandas.core.frame.DataFrame: lsst.daf.butler.formatters.parquet._pandas_to_numpy_dict astropy.table.Table: lsst.daf.butler.formatters.parquet._astropy_to_numpy_dict numpy.ndarray: lsst.daf.butler.formatters.parquet._numpy_to_numpy_dict - delegate: lsst.daf.butler.delegates.arrownumpydict.ArrowNumpyDictDelegate + delegate: lsst.daf.butler.delegates.arrowtable.ArrowTableDelegate derivedComponents: columns: ArrowColumnList rowcount: int diff --git a/python/lsst/daf/butler/datastore/generic_base.py b/python/lsst/daf/butler/datastore/generic_base.py index 2921edd987..75415fea29 100644 --- a/python/lsst/daf/butler/datastore/generic_base.py +++ b/python/lsst/daf/butler/datastore/generic_base.py @@ -35,7 +35,6 @@ from collections.abc import Mapping from typing import TYPE_CHECKING, Any, Generic, TypeVar -from .._exceptions import DatasetTypeNotSupportedError from ..datastore._datastore import Datastore from .stored_file_info import StoredDatastoreItemInfo @@ -54,34 +53,6 @@ class GenericBaseDatastore(Datastore, Generic[_InfoType]): Should always be sub-classed since key abstract methods are missing. """ - def _validate_put_parameters(self, inMemoryDataset: object, ref: DatasetRef) -> None: - """Validate the supplied arguments for put. - - Parameters - ---------- - inMemoryDataset : `object` - The dataset to store. - ref : `DatasetRef` - Reference to the associated Dataset. - """ - storageClass = ref.datasetType.storageClass - - # Sanity check - if not isinstance(inMemoryDataset, storageClass.pytype): - raise TypeError( - f"Inconsistency between supplied object ({type(inMemoryDataset)}) " - f"and storage class type ({storageClass.pytype})" - ) - - # Confirm that we can accept this dataset - if not self.constraints.isAcceptable(ref): - # Raise rather than use boolean return value. - raise DatasetTypeNotSupportedError( - f"Dataset {ref} has been rejected by this datastore via configuration." - ) - - return - def remove(self, ref: DatasetRef) -> None: """Indicate to the Datastore that a dataset can be removed. 
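With _validate_put_parameters removed and all four parquet-compatible storage classes now pointing at ArrowTableDelegate, the put-time decision moves into the datastore itself; InMemoryDatastore.put in the next file applies it inline. A standalone sketch of that control flow, under a hypothetical helper name, is:

    def _prepare_for_storage(inMemoryDataset, ref):
        # Hypothetical helper mirroring the new put() logic: consult the
        # storage class delegate first and only coerce when it cannot take
        # the object as-is.
        try:
            delegate = ref.datasetType.storageClass.delegate()
        except TypeError:
            # The storage class has no delegate configured.
            delegate = None
        if delegate is not None and delegate.can_accept(inMemoryDataset):
            return inMemoryDataset
        return ref.datasetType.storageClass.coerce_type(inMemoryDataset)
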
diff --git a/python/lsst/daf/butler/datastores/inMemoryDatastore.py b/python/lsst/daf/butler/datastores/inMemoryDatastore.py index 25b2b1099f..d40659b0b8 100644 --- a/python/lsst/daf/butler/datastores/inMemoryDatastore.py +++ b/python/lsst/daf/butler/datastores/inMemoryDatastore.py @@ -39,6 +39,7 @@ from urllib.parse import urlencode from lsst.daf.butler import DatasetId, DatasetRef, StorageClass +from lsst.daf.butler._exceptions import DatasetTypeNotSupportedError from lsst.daf.butler.datastore import DatasetRefURIs, DatastoreConfig from lsst.daf.butler.datastore.generic_base import GenericBaseDatastore, post_process_get from lsst.daf.butler.datastore.record_data import DatastoreRecordData @@ -397,11 +398,21 @@ def put(self, inMemoryDataset: Any, ref: DatasetRef) -> None: allow `ChainedDatastore` to put to multiple datastores without requiring that every datastore accepts the dataset. """ + if not self.constraints.isAcceptable(ref): + # Raise rather than use boolean return value. + raise DatasetTypeNotSupportedError( + f"Dataset {ref} has been rejected by this datastore via configuration." + ) + # May need to coerce the in memory dataset to the correct # python type, otherwise parameters may not work. - inMemoryDataset = ref.datasetType.storageClass.coerce_type(inMemoryDataset) - - self._validate_put_parameters(inMemoryDataset, ref) + try: + delegate = ref.datasetType.storageClass.delegate() + except TypeError: + # TypeError is raised when a storage class doesn't have a delegate. + delegate = None + if not delegate or not delegate.can_accept(inMemoryDataset): + inMemoryDataset = ref.datasetType.storageClass.coerce_type(inMemoryDataset) self.datasets[ref.id] = inMemoryDataset log.debug("Store %s in %s", ref, self.name) diff --git a/python/lsst/daf/butler/delegates/arrowastropy.py b/python/lsst/daf/butler/delegates/arrowastropy.py deleted file mode 100644 index de9de17e93..0000000000 --- a/python/lsst/daf/butler/delegates/arrowastropy.py +++ /dev/null @@ -1,83 +0,0 @@ -# This file is part of daf_butler. -# -# Developed for the LSST Data Management System. -# This product includes software developed by the LSST Project -# (http://www.lsst.org). -# See the COPYRIGHT file at the top-level directory of this distribution -# for details of code ownership. -# -# This software is dual licensed under the GNU General Public License and also -# under a 3-clause BSD license. Recipients may choose which of these licenses -# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, -# respectively. If you choose the GPL option then the following text applies -# (but note that there is still no warranty even if you opt for BSD instead): -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . 
- -"""Support for reading Astropy tables with the Arrow formatter.""" -from __future__ import annotations - -from typing import Any - -import astropy.table as atable -from lsst.daf.butler.formatters.parquet import ArrowAstropySchema -from lsst.utils.introspection import get_full_type_name - -from .arrowtable import ArrowTableDelegate - -__all__ = ["ArrowAstropyDelegate"] - - -class ArrowAstropyDelegate(ArrowTableDelegate): - """Delegate that understands the ``ArrowAstropy`` storage class.""" - - _datasetType = atable.Table - - def getComponent(self, composite: atable.Table, componentName: str) -> Any: - """Get a component from an astropy table stored via ArrowAstropy. - - Parameters - ---------- - composite : `~astropy.table.Table` - Astropy table to access component. - componentName : `str` - Name of component to retrieve. - - Returns - ------- - component : `object` - The component. - - Raises - ------ - AttributeError - The component can not be found. - """ - match componentName: - case "columns": - return list(composite.columns.keys()) - case "schema": - return ArrowAstropySchema(composite) - case "rowcount": - return len(composite) - - raise AttributeError( - f"Do not know how to retrieve component {componentName} from {get_full_type_name(composite)}" - ) - - def _getColumns(self, inMemoryDataset: atable.Table) -> list[str]: - return inMemoryDataset.columns.keys() - - def _selectColumns(self, inMemoryDataset: atable.Table, columns: list[str]) -> atable.Table: - return inMemoryDataset[columns] diff --git a/python/lsst/daf/butler/delegates/arrownumpy.py b/python/lsst/daf/butler/delegates/arrownumpy.py deleted file mode 100644 index 4fcd34b604..0000000000 --- a/python/lsst/daf/butler/delegates/arrownumpy.py +++ /dev/null @@ -1,85 +0,0 @@ -# This file is part of daf_butler. -# -# Developed for the LSST Data Management System. -# This product includes software developed by the LSST Project -# (http://www.lsst.org). -# See the COPYRIGHT file at the top-level directory of this distribution -# for details of code ownership. -# -# This software is dual licensed under the GNU General Public License and also -# under a 3-clause BSD license. Recipients may choose which of these licenses -# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, -# respectively. If you choose the GPL option then the following text applies -# (but note that there is still no warranty even if you opt for BSD instead): -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . - -"""Support for reading numpy tables (structured arrays) with the Arrow -formatter. 
-""" -from __future__ import annotations - -from typing import Any - -import numpy as np -from lsst.daf.butler.formatters.parquet import ArrowNumpySchema -from lsst.utils.introspection import get_full_type_name - -from .arrowtable import ArrowTableDelegate - -__all__ = ["ArrowNumpyDelegate"] - - -class ArrowNumpyDelegate(ArrowTableDelegate): - """Delegate that understands the ``ArrowNumpy`` storage class.""" - - _datasetType = np.ndarray - - def getComponent(self, composite: np.ndarray, componentName: str) -> Any: - """Get a component from a numpy table stored via ArrowNumpy. - - Parameters - ---------- - composite : `~numpy.ndarray` - Numpy table to access component. - componentName : `str` - Name of component to retrieve. - - Returns - ------- - component : `object` - The component. - - Raises - ------ - AttributeError - The component can not be found. - """ - match componentName: - case "columns": - return list(composite.dtype.names) - case "schema": - return ArrowNumpySchema(composite.dtype) - case "rowcount": - return len(composite) - - raise AttributeError( - f"Do not know how to retrieve component {componentName} from {get_full_type_name(composite)}" - ) - - def _getColumns(self, inMemoryDataset: np.ndarray) -> list[str]: - return inMemoryDataset.dtype.names - - def _selectColumns(self, inMemoryDataset: np.ndarray, columns: list[str]) -> np.ndarray: - return inMemoryDataset[columns] diff --git a/python/lsst/daf/butler/delegates/arrownumpydict.py b/python/lsst/daf/butler/delegates/arrownumpydict.py deleted file mode 100644 index b94d0f0ee3..0000000000 --- a/python/lsst/daf/butler/delegates/arrownumpydict.py +++ /dev/null @@ -1,89 +0,0 @@ -# This file is part of daf_butler. -# -# Developed for the LSST Data Management System. -# This product includes software developed by the LSST Project -# (http://www.lsst.org). -# See the COPYRIGHT file at the top-level directory of this distribution -# for details of code ownership. -# -# This software is dual licensed under the GNU General Public License and also -# under a 3-clause BSD license. Recipients may choose which of these licenses -# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, -# respectively. If you choose the GPL option then the following text applies -# (but note that there is still no warranty even if you opt for BSD instead): -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . - -"""Support for reading dictionaries of numpy arrays with the Arrow -formatter. 
-""" -from __future__ import annotations - -from typing import Any - -import numpy as np -from lsst.daf.butler.formatters.parquet import ArrowNumpySchema, _numpy_dict_to_dtype -from lsst.utils.introspection import get_full_type_name - -from .arrowtable import ArrowTableDelegate - -__all__ = ["ArrowNumpyDictDelegate"] - - -class ArrowNumpyDictDelegate(ArrowTableDelegate): - """Delegate that understands the ``ArrowNumpyDict`` storage class.""" - - _datasetType = dict - - def getComponent(self, composite: dict[str, np.ndarray], componentName: str) -> Any: - """Get a component from a dict of numpy arrays stored via - ArrowNumpyDict. - - Parameters - ---------- - composite : `~numpy.ndarray` - Numpy table to access component. - componentName : `str` - Name of component to retrieve. - - Returns - ------- - component : `object` - The component. - - Raises - ------ - AttributeError - The component can not be found. - """ - match componentName: - case "columns": - return list(composite.keys()) - case "schema": - dtype, _ = _numpy_dict_to_dtype(composite) - return ArrowNumpySchema(dtype) - case "rowcount": - return len(composite[list(composite.keys())[0]]) - - raise AttributeError( - f"Do not know how to retrieve component {componentName} from {get_full_type_name(composite)}" - ) - - def _getColumns(self, inMemoryDataset: dict[str, np.ndarray]) -> list[str]: - return list(inMemoryDataset.keys()) - - def _selectColumns( - self, inMemoryDataset: dict[str, np.ndarray], columns: list[str] - ) -> dict[str, np.ndarray]: - return {column: inMemoryDataset[column] for column in columns} diff --git a/python/lsst/daf/butler/delegates/arrowtable.py b/python/lsst/daf/butler/delegates/arrowtable.py index 88e552922b..1d4c728245 100644 --- a/python/lsst/daf/butler/delegates/arrowtable.py +++ b/python/lsst/daf/butler/delegates/arrowtable.py @@ -28,24 +28,29 @@ """Support for reading Arrow tables.""" from __future__ import annotations +__all__ = ["ArrowTableDelegate"] + from collections.abc import Mapping -from typing import Any +from typing import TYPE_CHECKING, Any import pyarrow as pa from lsst.daf.butler import StorageClassDelegate from lsst.utils.introspection import get_full_type_name from lsst.utils.iteration import ensure_iterable -__all__ = ["ArrowTableDelegate"] +if TYPE_CHECKING: + import pandas class ArrowTableDelegate(StorageClassDelegate): - """Delegate that understands the ``ArrowTable`` storage class.""" + """Delegate that understands ArrowTable and related storage classes.""" - _datasetType = pa.Table + def can_accept(self, inMemoryDataset: Any) -> bool: + # Docstring inherited. + return _checkArrowCompatibleType(inMemoryDataset) is not None - def getComponent(self, composite: pa.Table, componentName: str) -> Any: - """Get a component from an Arrow table. + def getComponent(self, composite: Any, componentName: str) -> Any: + """Get a component from an Arrow table or equivalent. Parameters ---------- @@ -64,72 +69,179 @@ def getComponent(self, composite: pa.Table, componentName: str) -> Any: AttributeError The component can not be found. """ - if componentName in ("columns", "schema"): - # The schema will be translated to column format - # depending on the input type. 
- return composite.schema + typeString = _checkArrowCompatibleType(composite) + + if typeString is None: + raise ValueError(f"Unsupported composite type {get_full_type_name(composite)}") + + if componentName == "columns": + if typeString == "arrow": + return composite.schema + elif typeString == "astropy": + return list(composite.columns.keys()) + elif typeString == "numpy": + return list(composite.dtype.names) + elif typeString == "numpydict": + return list(composite.keys()) + elif typeString == "pandas": + import pandas + + if isinstance(composite.columns, pandas.MultiIndex): + return composite.columns + else: + return pandas.Index(self._getAllDataframeColumns(composite)) + + elif componentName == "schema": + if typeString == "arrow": + return composite.schema + elif typeString == "astropy": + from lsst.daf.butler.formatters.parquet import ArrowAstropySchema + + return ArrowAstropySchema(composite) + elif typeString == "numpy": + from lsst.daf.butler.formatters.parquet import ArrowNumpySchema + + return ArrowNumpySchema(composite.dtype) + elif typeString == "numpydict": + from lsst.daf.butler.formatters.parquet import ArrowNumpySchema, _numpy_dict_to_dtype + + dtype, _ = _numpy_dict_to_dtype(composite) + return ArrowNumpySchema(dtype) + elif typeString == "pandas": + from lsst.daf.butler.formatters.parquet import DataFrameSchema + + return DataFrameSchema(composite.iloc[:0]) elif componentName == "rowcount": - return len(composite[composite.schema.names[0]]) + if typeString == "arrow": + return len(composite[composite.schema.names[0]]) + elif typeString in ["astropy", "numpy", "pandas"]: + return len(composite) + elif typeString == "numpydict": + return len(composite[list(composite.keys())[0]]) raise AttributeError( f"Do not know how to retrieve component {componentName} from {get_full_type_name(composite)}" ) def handleParameters(self, inMemoryDataset: Any, parameters: Mapping[str, Any] | None = None) -> Any: - if not isinstance(inMemoryDataset, self._datasetType): - raise ValueError( - f"inMemoryDataset must be a {get_full_type_name(self._datasetType)} and " - f"not {get_full_type_name(inMemoryDataset)}." - ) + typeString = _checkArrowCompatibleType(inMemoryDataset) + + if typeString is None: + raise ValueError(f"Unsupported inMemoryDataset type {get_full_type_name(inMemoryDataset)}") if parameters is None: return inMemoryDataset if "columns" in parameters: - read_columns = list(ensure_iterable(parameters["columns"])) - for column in read_columns: - if not isinstance(column, str): - raise NotImplementedError( - "InMemoryDataset of an Arrow Table only supports string column names." - ) - if column not in self._getColumns(inMemoryDataset): - raise ValueError(f"Unrecognized column name {column!r}.") + readColumns = list(ensure_iterable(parameters["columns"])) + + if typeString == "arrow": + allColumns = inMemoryDataset.schema.names + elif typeString == "astropy": + allColumns = inMemoryDataset.columns.keys() + elif typeString == "numpy": + allColumns = inMemoryDataset.dtype.names + elif typeString == "numpydict": + allColumns = list(inMemoryDataset.keys()) + elif typeString == "pandas": + import pandas + + allColumns = self._getAllDataframeColumns(inMemoryDataset) + + if typeString == "pandas" and isinstance(inMemoryDataset.columns, pandas.MultiIndex): + from ..formatters.parquet import _standardize_multi_index_columns + + # We have a multi-index dataframe which needs special + # handling. 
+ readColumns = _standardize_multi_index_columns( + inMemoryDataset.columns, + parameters["columns"], + stringify=False, + ) + else: + readColumns = list(ensure_iterable(parameters["columns"])) + + for column in readColumns: + if not isinstance(column, str): + raise NotImplementedError( + f"InMemoryDataset of a {get_full_type_name(inMemoryDataset)} only " + "supports string column names." + ) + if column not in allColumns: + raise ValueError(f"Unrecognized column name {column!r}.") + + if typeString == "pandas": + # Exclude index columns from the subset. + readColumns = [ + name + for name in ensure_iterable(parameters["columns"]) + if name not in inMemoryDataset.index.names + ] # Ensure uniqueness, keeping order. - read_columns = list(dict.fromkeys(read_columns)) - - return self._selectColumns(inMemoryDataset, read_columns) + readColumns = list(dict.fromkeys(readColumns)) + + if typeString == "arrow": + return inMemoryDataset.select(readColumns) + elif typeString in ("astropy", "numpy", "pandas"): + return inMemoryDataset[readColumns] + elif typeString == "numpydict": + return {column: inMemoryDataset[column] for column in readColumns} else: return inMemoryDataset - def _getColumns(self, inMemoryDataset: pa.Table) -> list[str]: - """Get the column names from the inMemoryDataset. - - Parameters - ---------- - inMemoryDataset : `object` - Dataset to extract columns. + def _getAllDataframeColumns(self, dataset: pandas.DataFrame) -> list[str]: + """Get all columns, including index columns. Returns ------- columns : `list` [`str`] - List of columns. - """ - return inMemoryDataset.schema.names - - def _selectColumns(self, inMemoryDataset: pa.Table, columns: list[str]) -> pa.Table: - """Select a subset of columns from the inMemoryDataset. - - Parameters - ---------- - inMemoryDataset : `object` - Dataset to extract columns. - columns : `list` [`str`] - List of columns to extract. - - Returns - ------- - subDataset : `object` - Subselection of inMemoryDataset. + List of all columns. """ - return inMemoryDataset.select(columns) + allColumns = list(dataset.columns) + if dataset.index.names[0] is not None: + allColumns.extend(dataset.index.names) + + return allColumns + + +def _checkArrowCompatibleType(dataset: Any) -> str | None: + """Check a dataset for arrow compatibility and return type string. + + Parameters + ---------- + dataset : `object` + Dataset object. + + Returns + ------- + typeString : `str` or `None` + Type string, one of ``arrow``, ``astropy``, ``numpy``, ``numpydict``, + or ``pandas``; `None` if the dataset is not a supported type. + """ + import numpy as np + from astropy.table import Table as astropyTable + + if isinstance(dataset, pa.Table): + return "arrow" + elif isinstance(dataset, astropyTable): + return "astropy" + elif isinstance(dataset, np.ndarray): + return "numpy" + elif isinstance(dataset, dict): + for key, item in dataset.items(): + if not isinstance(item, np.ndarray): + # This is some other sort of dictionary. + return None + return "numpydict" + elif hasattr(dataset, "to_parquet"): + # This may be a pandas DataFrame. + try: + import pandas + except ImportError: + pandas = None + + if pandas is not None and isinstance(dataset, pandas.DataFrame): + return "pandas" + + return None diff --git a/python/lsst/daf/butler/delegates/dataframe.py b/python/lsst/daf/butler/delegates/dataframe.py deleted file mode 100644 index 342cb3408e..0000000000 --- a/python/lsst/daf/butler/delegates/dataframe.py +++ /dev/null @@ -1,163 +0,0 @@ -# This file is part of daf_butler. -# -# Developed for the LSST Data Management System. 
-# This product includes software developed by the LSST Project -# (http://www.lsst.org). -# See the COPYRIGHT file at the top-level directory of this distribution -# for details of code ownership. -# -# This software is dual licensed under the GNU General Public License and also -# under a 3-clause BSD license. Recipients may choose which of these licenses -# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, -# respectively. If you choose the GPL option then the following text applies -# (but note that there is still no warranty even if you opt for BSD instead): -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . - -"""Support for reading DataFrames.""" -from __future__ import annotations - -import collections.abc -from collections.abc import Mapping -from typing import Any - -import pandas -from lsst.daf.butler import StorageClassDelegate -from lsst.daf.butler.formatters.parquet import DataFrameSchema -from lsst.utils.introspection import get_full_type_name -from lsst.utils.iteration import ensure_iterable - -from ..formatters.parquet import _standardize_multi_index_columns - -__all__ = ["DataFrameDelegate"] - - -class DataFrameDelegate(StorageClassDelegate): - """Delegate that understands the ``DataFrame`` storage class.""" - - def getComponent(self, composite: pandas.DataFrame, componentName: str) -> Any: - """Get a component from a DataFrame. - - Parameters - ---------- - composite : `~pandas.DataFrame` - ``DataFrame`` to access component. - componentName : `str` - Name of component to retrieve. - - Returns - ------- - component : `object` - The component. - - Raises - ------ - AttributeError - The component can not be found. - """ - if componentName == "columns": - if isinstance(composite.columns, pandas.MultiIndex): - return composite.columns - else: - return pandas.Index(self._getAllColumns(composite)) - elif componentName == "rowcount": - return len(composite) - elif componentName == "schema": - return DataFrameSchema(composite.iloc[:0]) - else: - raise AttributeError( - f"Do not know how to retrieve component {componentName} from {get_full_type_name(composite)}" - ) - - def handleParameters( - self, inMemoryDataset: pandas.DataFrame, parameters: Mapping[str, Any] | None = None - ) -> Any: - """Return possibly new in-memory dataset using the supplied parameters. - - Parameters - ---------- - inMemoryDataset : `object` - Object to modify based on the parameters. - parameters : `dict`, optional - Parameters to apply. Values are specific to the parameter. - Supported parameters are defined in the associated - `StorageClass`. If no relevant parameters are specified the - ``inMemoryDataset`` will be return unchanged. - - Returns - ------- - inMemoryDataset : `object` - Original in-memory dataset, or updated form after parameters - have been used. 
- """ - if not isinstance(inMemoryDataset, pandas.DataFrame): - raise ValueError( - "handleParameters for a DataFrame must get a DataFrame, " - f"not {get_full_type_name(inMemoryDataset)}." - ) - - if parameters is None: - return inMemoryDataset - - if "columns" in parameters: - allColumns = self._getAllColumns(inMemoryDataset) - - if not isinstance(parameters["columns"], collections.abc.Iterable): - raise NotImplementedError( - "InMemoryDataset of a DataFrame only supports list/tuple of string column names" - ) - - if isinstance(inMemoryDataset.columns, pandas.MultiIndex): - # We have a multi-index dataframe which needs special handling. - readColumns = _standardize_multi_index_columns( - inMemoryDataset.columns, - parameters["columns"], - stringify=False, - ) - else: - for column in ensure_iterable(parameters["columns"]): - if not isinstance(column, str): - raise NotImplementedError( - "InMemoryDataset of a DataFrame only supports string column names." - ) - if column not in allColumns: - raise ValueError(f"Unrecognized column name {column!r}.") - - # Exclude index columns from the subset. - readColumns = [ - name - for name in ensure_iterable(parameters["columns"]) - if name not in inMemoryDataset.index.names - ] - - # Ensure uniqueness, keeping order. - readColumns = list(dict.fromkeys(readColumns)) - - return inMemoryDataset[readColumns] - else: - return inMemoryDataset - - def _getAllColumns(self, inMemoryDataset: pandas.DataFrame) -> list[str]: - """Get all columns, including index columns. - - Returns - ------- - columns : `list` [`str`] - List of all columns. - """ - allColumns = list(inMemoryDataset.columns) - if inMemoryDataset.index.names[0] is not None: - allColumns.extend(inMemoryDataset.index.names) - - return allColumns diff --git a/python/lsst/daf/butler/formatters/parquet.py b/python/lsst/daf/butler/formatters/parquet.py index 97e800d139..5c9c952bf4 100644 --- a/python/lsst/daf/butler/formatters/parquet.py +++ b/python/lsst/daf/butler/formatters/parquet.py @@ -57,6 +57,7 @@ import pyarrow as pa import pyarrow.parquet as pq from lsst.daf.butler import FormatterV2 +from lsst.daf.butler.delegates.arrowtable import _checkArrowCompatibleType from lsst.resources import ResourcePath from lsst.utils.introspection import get_full_type_name from lsst.utils.iteration import ensure_iterable @@ -77,6 +78,10 @@ class ParquetFormatter(FormatterV2): default_extension = ".parq" can_read_from_local_file = True + def can_accept(self, in_memory_dataset: Any) -> bool: + # Docstring inherited. + return _checkArrowCompatibleType(in_memory_dataset) is not None + def read_from_local_file(self, path: str, component: str | None = None, expected_size: int = -1) -> Any: # Docstring inherited from Formatter.read. schema = pq.read_schema(path) @@ -143,44 +148,30 @@ def read_from_local_file(self, path: str, component: str | None = None, expected return arrow_table def write_local_file(self, in_memory_dataset: Any, uri: ResourcePath) -> None: - import numpy as np - from astropy.table import Table as astropyTable - arrow_table = None - if isinstance(in_memory_dataset, pa.Table): - # This will be the most likely match. 
- arrow_table = in_memory_dataset - elif isinstance(in_memory_dataset, astropyTable): - arrow_table = astropy_to_arrow(in_memory_dataset) - elif isinstance(in_memory_dataset, np.ndarray): - arrow_table = numpy_to_arrow(in_memory_dataset) - elif isinstance(in_memory_dataset, dict): - try: - arrow_table = numpy_dict_to_arrow(in_memory_dataset) - except (TypeError, AttributeError) as e: - raise ValueError( - "Input dict for inMemoryDataset does not appear to be a dict of numpy arrays." - ) from e - elif isinstance(in_memory_dataset, pa.Schema): + if isinstance(in_memory_dataset, pa.Schema): pq.write_metadata(in_memory_dataset, uri.ospath) return - else: - if hasattr(in_memory_dataset, "to_parquet"): - # This may be a pandas DataFrame - try: - import pandas as pd - except ImportError: - pd = None - if pd is not None and isinstance(in_memory_dataset, pd.DataFrame): - arrow_table = pandas_to_arrow(in_memory_dataset) + type_string = _checkArrowCompatibleType(in_memory_dataset) - if arrow_table is None: + if type_string is None: raise ValueError( f"Unsupported type {get_full_type_name(in_memory_dataset)} of " "inMemoryDataset for ParquetFormatter." ) + if type_string == "arrow": + arrow_table = in_memory_dataset + elif type_string == "astropy": + arrow_table = astropy_to_arrow(in_memory_dataset) + elif type_string == "numpy": + arrow_table = numpy_to_arrow(in_memory_dataset) + elif type_string == "numpydict": + arrow_table = numpy_dict_to_arrow(in_memory_dataset) + else: + arrow_table = pandas_to_arrow(in_memory_dataset) + row_group_size = compute_row_group_size(arrow_table.schema) pq.write_table(arrow_table, uri.ospath, row_group_size=row_group_size) diff --git a/tests/test_parquet.py b/tests/test_parquet.py index 351b90f94f..97485d46e0 100644 --- a/tests/test_parquet.py +++ b/tests/test_parquet.py @@ -62,24 +62,11 @@ StorageClassFactory, ) -try: - from lsst.daf.butler.delegates.arrowastropy import ArrowAstropyDelegate -except ImportError: - atable = None - pa = None -try: - from lsst.daf.butler.delegates.arrownumpy import ArrowNumpyDelegate -except ImportError: - np = None - pa = None try: from lsst.daf.butler.delegates.arrowtable import ArrowTableDelegate except ImportError: pa = None -try: - from lsst.daf.butler.delegates.dataframe import DataFrameDelegate -except ImportError: - pd = None + try: from lsst.daf.butler.formatters.parquet import ( ArrowAstropySchema, @@ -751,10 +738,55 @@ def testBadDataFrameColumnParquet(self): with self.assertRaises(RuntimeError): self.butler.put(bad_df, self.datasetType, dataId={}) + @unittest.skipUnless(atable is not None, "Cannot test reading as astropy without astropy.") + def testWriteReadAstropyTableLossless(self): + tab1 = _makeSimpleAstropyTable(include_multidim=True, include_masked=True) + + self.butler.put(tab1, self.datasetType, dataId={}) + + tab2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowAstropy") + + _checkAstropyTableEquality(tab1, tab2) + + @unittest.skipUnless(np is not None, "Cannot test reading as numpy without numpy.") + def testWriteReadNumpyTableLossless(self): + tab1 = _makeSimpleNumpyTable(include_multidim=True) + + self.butler.put(tab1, self.datasetType, dataId={}) + + tab2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowNumpy") + + _checkNumpyTableEquality(tab1, tab2) + + @unittest.skipUnless(pa is not None, "Cannot test reading as arrow without pyarrow.") + def testWriteReadArrowTableLossless(self): + tab1 = _makeSimpleArrowTable(include_multidim=False, include_masked=True) + + 
self.butler.put(tab1, self.datasetType, dataId={}) + + tab2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowTable") + + self.assertEqual(tab1.schema, tab2.schema) + tab1_np = arrow_to_numpy(tab1) + tab2_np = arrow_to_numpy(tab2) + for col in tab1.column_names: + np.testing.assert_array_equal(tab2_np[col], tab1_np[col]) + + @unittest.skipUnless(np is not None, "Cannot test reading as numpy dict without numpy.") + def testWriteReadNumpyDictLossless(self): + tab1 = _makeSimpleNumpyTable(include_multidim=True) + dict1 = _numpy_to_numpy_dict(tab1) + + self.butler.put(tab1, self.datasetType, dataId={}) + + dict2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowNumpyDict") -@unittest.skipUnless(pd is not None, "Cannot test InMemoryDataFrameDelegate without pandas.") + _checkNumpyDictEquality(dict1, dict2) + + +@unittest.skipUnless(pd is not None, "Cannot test InMemoryDatastore with DataFrames without pandas.") class InMemoryDataFrameDelegateTestCase(ParquetFormatterDataFrameTestCase): - """Tests for InMemoryDatastore, using DataFrameDelegate.""" + """Tests for InMemoryDatastore, using ArrowTableDelegate with Dataframe.""" configFile = os.path.join(TESTDIR, "config/basic/butler-inmemory.yaml") @@ -776,7 +808,7 @@ def testLegacyDataFrame(self): def testBadInput(self): df1, _ = _makeSingleIndexDataFrame() - delegate = DataFrameDelegate("DataFrame") + delegate = ArrowTableDelegate("DataFrame") with self.assertRaises(ValueError): delegate.handleParameters(inMemoryDataset="not_a_dataframe") @@ -832,7 +864,7 @@ def testAstropyTable(self): self.butler.put(tab1, self.datasetType, dataId={}) # Read the whole Table. tab2 = self.butler.get(self.datasetType, dataId={}) - self._checkAstropyTableEquality(tab1, tab2) + _checkAstropyTableEquality(tab1, tab2) # Read the columns. columns2 = self.butler.get(self.datasetType.componentTypeName("columns"), dataId={}) self.assertEqual(len(columns2), len(tab1.dtype.names)) @@ -846,15 +878,15 @@ def testAstropyTable(self): self.assertEqual(schema, ArrowAstropySchema(tab1)) # Read just some columns a few different ways. tab3 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": ["a", "c"]}) - self._checkAstropyTableEquality(tab1[("a", "c")], tab3) + _checkAstropyTableEquality(tab1[("a", "c")], tab3) tab4 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": "a"}) - self._checkAstropyTableEquality(tab1[("a",)], tab4) + _checkAstropyTableEquality(tab1[("a",)], tab4) tab5 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": ["index", "a"]}) - self._checkAstropyTableEquality(tab1[("index", "a")], tab5) + _checkAstropyTableEquality(tab1[("index", "a")], tab5) tab6 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": "ddd"}) - self._checkAstropyTableEquality(tab1[("ddd",)], tab6) + _checkAstropyTableEquality(tab1[("ddd",)], tab6) tab7 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": ["a", "a"]}) - self._checkAstropyTableEquality(tab1[("a",)], tab7) + _checkAstropyTableEquality(tab1[("a",)], tab7) # Passing an unrecognized column should be a ValueError. with self.assertRaises(ValueError): self.butler.get(self.datasetType, dataId={}, parameters={"columns": ["e"]}) @@ -865,7 +897,7 @@ def testAstropyTableBigEndian(self): self.butler.put(tab1, self.datasetType, dataId={}) # Read the whole Table. 
tab2 = self.butler.get(self.datasetType, dataId={}) - self._checkAstropyTableEquality(tab1, tab2, has_bigendian=True) + _checkAstropyTableEquality(tab1, tab2, has_bigendian=True) def testAstropyTableWithMetadata(self): tab1 = _makeSimpleAstropyTable(include_multidim=True) @@ -884,7 +916,7 @@ def testAstropyTableWithMetadata(self): # Read the whole Table. tab2 = self.butler.get(self.datasetType, dataId={}) # This will check that the metadata is equivalent as well. - self._checkAstropyTableEquality(tab1, tab2) + _checkAstropyTableEquality(tab1, tab2) def testArrowAstropySchema(self): tab1 = _makeSimpleAstropyTable() @@ -945,7 +977,7 @@ def testAstropyParquet(self): tab2a = self.butler.get(self.datasetType, dataId={}) tab2b = self.butler.get("astropy_parquet", dataId={}) - self._checkAstropyTableEquality(tab2a, tab2b) + _checkAstropyTableEquality(tab2a, tab2b) columns2a = self.butler.get(self.datasetType.componentTypeName("columns"), dataId={}) columns2b = self.butler.get("astropy_parquet.columns", dataId={}) @@ -971,7 +1003,7 @@ def testWriteAstropyReadAsArrowTable(self): tab2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowTable") tab2_astropy = arrow_to_astropy(tab2) - self._checkAstropyTableEquality(tab1, tab2_astropy) + _checkAstropyTableEquality(tab1, tab2_astropy) # Check reading the columns. columns = tab2.schema.names @@ -1056,7 +1088,7 @@ def testWriteSingleIndexDataFrameWithMaskedColsReadAsAstropyTable(self): df1_tab = pandas_to_astropy(df1) - self._checkAstropyTableEquality(df1_tab, tab2) + _checkAstropyTableEquality(df1_tab, tab2) @unittest.skipUnless(np is not None, "Cannot test reading as numpy without numpy.") def testWriteAstropyReadAsNumpyTable(self): @@ -1068,7 +1100,7 @@ def testWriteAstropyReadAsNumpyTable(self): # This is tricky because it loses the units. tab2_astropy = atable.Table(tab2) - self._checkAstropyTableEquality(tab1, tab2_astropy, skip_units=True) + _checkAstropyTableEquality(tab1, tab2_astropy, skip_units=True) # Check reading the columns. columns = list(tab2.dtype.names) @@ -1095,7 +1127,7 @@ def testWriteAstropyReadAsNumpyDict(self): # This is tricky because it loses the units. tab2_astropy = atable.Table(tab2) - self._checkAstropyTableEquality(tab1, tab2_astropy, skip_units=True) + _checkAstropyTableEquality(tab1, tab2_astropy, skip_units=True) def testBadAstropyColumnParquet(self): tab1 = _makeSimpleAstropyTable() @@ -1120,52 +1152,12 @@ def testBadAstropyColumnParquet(self): with self.assertRaises(RuntimeError): self.butler.put(bad_tab, self.datasetType, dataId={}) - def _checkAstropyTableEquality(self, table1, table2, skip_units=False, has_bigendian=False): - """Check if two astropy tables have the same columns/values. - Parameters - ---------- - table1 : `astropy.table.Table` - table2 : `astropy.table.Table` - skip_units : `bool` - has_bigendian : `bool` - """ - if not has_bigendian: - self.assertEqual(table1.dtype, table2.dtype) - else: - for name in table1.dtype.names: - # Only check type matches, force to little-endian. - self.assertEqual(table1.dtype[name].newbyteorder(">"), table2.dtype[name].newbyteorder(">")) - - self.assertEqual(table1.meta, table2.meta) - if not skip_units: - for name in table1.columns: - self.assertEqual(table1[name].unit, table2[name].unit) - self.assertEqual(table1[name].description, table2[name].description) - self.assertEqual(table1[name].format, table2[name].format) - # We need to check masked/regular columns after filling. 
- has_masked = False - if isinstance(table1[name], atable.column.MaskedColumn): - c1 = table1[name].filled() - has_masked = True - else: - c1 = np.array(table1[name]) - if has_masked: - self.assertIsInstance(table2[name], atable.column.MaskedColumn) - c2 = table2[name].filled() - else: - self.assertFalse(isinstance(table2[name], atable.column.MaskedColumn)) - c2 = np.array(table2[name]) - np.testing.assert_array_equal(c1, c2) - # If we have a masked column then we test the underlying data. - if has_masked: - np.testing.assert_array_equal(np.array(c1), np.array(c2)) - np.testing.assert_array_equal(table1[name].mask, table2[name].mask) - - -@unittest.skipUnless(atable is not None, "Cannot test InMemoryArrowAstropyDelegate without astropy.") +@unittest.skipUnless(atable is not None, "Cannot test InMemoryDatastore with AstropyTable without astropy.") class InMemoryArrowAstropyDelegateTestCase(ParquetFormatterArrowAstropyTestCase): - """Tests for InMemoryDatastore, using ArrowAstropyDelegate.""" + """Tests for InMemoryDatastore, using ArrowTableDelegate with + AstropyTable. + """ configFile = os.path.join(TESTDIR, "config/basic/butler-inmemory.yaml") @@ -1179,7 +1171,7 @@ def testBadAstropyColumnParquet(self): def testBadInput(self): tab1 = _makeSimpleAstropyTable() - delegate = ArrowAstropyDelegate("ArrowAstropy") + delegate = ArrowTableDelegate("ArrowAstropy") with self.assertRaises(ValueError): delegate.handleParameters(inMemoryDataset="not_an_astropy_table") @@ -1221,7 +1213,7 @@ def testNumpyTable(self): self.butler.put(tab1, self.datasetType, dataId={}) # Read the whole Table. tab2 = self.butler.get(self.datasetType, dataId={}) - self._checkNumpyTableEquality(tab1, tab2) + _checkNumpyTableEquality(tab1, tab2) # Read the columns. columns2 = self.butler.get(self.datasetType.componentTypeName("columns"), dataId={}) self.assertEqual(len(columns2), len(tab1.dtype.names)) @@ -1235,9 +1227,9 @@ def testNumpyTable(self): self.assertEqual(schema, ArrowNumpySchema(tab1.dtype)) # Read just some columns a few different ways. tab3 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": ["a", "c"]}) - self._checkNumpyTableEquality(tab1[["a", "c"]], tab3) + _checkNumpyTableEquality(tab1[["a", "c"]], tab3) tab4 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": "a"}) - self._checkNumpyTableEquality( + _checkNumpyTableEquality( tab1[ [ "a", @@ -1246,9 +1238,9 @@ def testNumpyTable(self): tab4, ) tab5 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": ["index", "a"]}) - self._checkNumpyTableEquality(tab1[["index", "a"]], tab5) + _checkNumpyTableEquality(tab1[["index", "a"]], tab5) tab6 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": "ddd"}) - self._checkNumpyTableEquality( + _checkNumpyTableEquality( tab1[ [ "ddd", @@ -1257,7 +1249,7 @@ def testNumpyTable(self): tab6, ) tab7 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": ["a", "a"]}) - self._checkNumpyTableEquality( + _checkNumpyTableEquality( tab1[ [ "a", @@ -1275,7 +1267,7 @@ def testNumpyTableBigEndian(self): self.butler.put(tab1, self.datasetType, dataId={}) # Read the whole Table. 
tab2 = self.butler.get(self.datasetType, dataId={}) - self._checkNumpyTableEquality(tab1, tab2, has_bigendian=True) + _checkNumpyTableEquality(tab1, tab2, has_bigendian=True) def testArrowNumpySchema(self): tab1 = _makeSimpleNumpyTable(include_multidim=True) @@ -1317,7 +1309,7 @@ def testWriteNumpyTableReadAsArrowTable(self): tab2_numpy = arrow_to_numpy(tab2) - self._checkNumpyTableEquality(tab1, tab2_numpy) + _checkNumpyTableEquality(tab1, tab2_numpy) # Check reading the columns. columns = tab2.schema.names @@ -1372,7 +1364,7 @@ def testWriteNumpyTableReadAsAstropyTable(self): tab2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowAstropy") tab2_numpy = tab2.as_array() - self._checkNumpyTableEquality(tab1, tab2_numpy) + _checkNumpyTableEquality(tab1, tab2_numpy) # Check reading the columns. columns = list(tab2.columns.keys()) @@ -1397,7 +1389,7 @@ def testWriteNumpyTableReadAsNumpyDict(self): tab2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowNumpyDict") tab2_numpy = _numpy_dict_to_numpy(tab2) - self._checkNumpyTableEquality(tab1, tab2_numpy) + _checkNumpyTableEquality(tab1, tab2_numpy) def testBadNumpyColumnParquet(self): tab1 = _makeSimpleAstropyTable() @@ -1426,28 +1418,22 @@ def testBadNumpyColumnParquet(self): with self.assertRaises(RuntimeError): self.butler.put(bad_tab_np, self.datasetType, dataId={}) - def _checkNumpyTableEquality(self, table1, table2, has_bigendian=False): - """Check if two numpy tables have the same columns/values + @unittest.skipUnless(atable is not None, "Cannot test reading as astropy without astropy.") + def testWriteReadAstropyTableLossless(self): + tab1 = _makeSimpleAstropyTable(include_multidim=True, include_masked=True) - Parameters - ---------- - table1 : `numpy.ndarray` - table2 : `numpy.ndarray` - has_bigendian : `bool` - """ - self.assertEqual(table1.dtype.names, table2.dtype.names) - for name in table1.dtype.names: - if not has_bigendian: - self.assertEqual(table1.dtype[name], table2.dtype[name]) - else: - # Only check type matches, force to little-endian. - self.assertEqual(table1.dtype[name].newbyteorder(">"), table2.dtype[name].newbyteorder(">")) - self.assertTrue(np.all(table1 == table2)) + self.butler.put(tab1, self.datasetType, dataId={}) + tab2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowAstropy") -@unittest.skipUnless(np is not None, "Cannot test ParquetFormatterArrowNumpy without numpy.") + _checkAstropyTableEquality(tab1, tab2) + + +@unittest.skipUnless(np is not None, "Cannot test InMemoryDatastore with Numpy table without numpy.") class InMemoryArrowNumpyDelegateTestCase(ParquetFormatterArrowNumpyTestCase): - """Tests for InMemoryDatastore, using ArrowNumpyDelegate.""" + """Tests for InMemoryDatastore, using ArrowTableDelegate with + Numpy table. + """ configFile = os.path.join(TESTDIR, "config/basic/butler-inmemory.yaml") @@ -1457,7 +1443,7 @@ def testBadNumpyColumnParquet(self): def testBadInput(self): tab1 = _makeSimpleNumpyTable() - delegate = ArrowNumpyDelegate("ArrowNumpy") + delegate = ArrowTableDelegate("ArrowNumpy") with self.assertRaises(ValueError): delegate.handleParameters(inMemoryDataset="not_a_numpy_table") @@ -1674,12 +1660,12 @@ def testWriteArrowTableReadAsAstropyTable(self): # Read back out as an astropy table. tab2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowAstropy") - self._checkAstropyTableEquality(tab1, tab2) + _checkAstropyTableEquality(tab1, tab2) # Read back out as an arrow table, convert to astropy table. 
atab3 = self.butler.get(self.datasetType, dataId={}) tab3 = arrow_to_astropy(atab3) - self._checkAstropyTableEquality(tab1, tab3) + _checkAstropyTableEquality(tab1, tab3) # Check reading the columns. columns = list(tab2.columns.keys()) @@ -1717,12 +1703,12 @@ def testWriteArrowTableReadAsNumpyTable(self): # Read back out as a numpy table. tab2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowNumpy") - self._checkNumpyTableEquality(tab1, tab2) + _checkNumpyTableEquality(tab1, tab2) # Read back out as an arrow table, convert to numpy table. atab3 = self.butler.get(self.datasetType, dataId={}) tab3 = arrow_to_numpy(atab3) - self._checkNumpyTableEquality(tab1, tab3) + _checkNumpyTableEquality(tab1, tab3) # Check reading the columns. columns = list(tab2.dtype.names) @@ -1746,55 +1732,20 @@ def testWriteArrowTableReadAsNumpyDict(self): tab2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowNumpyDict") tab2_numpy = _numpy_dict_to_numpy(tab2) - self._checkNumpyTableEquality(tab1, tab2_numpy) + _checkNumpyTableEquality(tab1, tab2_numpy) - def _checkAstropyTableEquality(self, table1, table2): - """Check if two astropy tables have the same columns/values + @unittest.skipUnless(atable is not None, "Cannot test reading as astropy without astropy.") + def testWriteReadAstropyTableLossless(self): + tab1 = _makeSimpleAstropyTable(include_multidim=True, include_masked=True) - Parameters - ---------- - table1 : `astropy.table.Table` - table2 : `astropy.table.Table` - """ - self.assertEqual(table1.dtype, table2.dtype) - for name in table1.columns: - self.assertEqual(table1[name].unit, table2[name].unit) - self.assertEqual(table1[name].description, table2[name].description) - self.assertEqual(table1[name].format, table2[name].format) - # We need to check masked/regular columns after filling. - has_masked = False - if isinstance(table1[name], atable.column.MaskedColumn): - c1 = table1[name].filled() - has_masked = True - else: - c1 = np.array(table1[name]) - if has_masked: - self.assertIsInstance(table2[name], atable.column.MaskedColumn) - c2 = table2[name].filled() - else: - self.assertFalse(isinstance(table2[name], atable.column.MaskedColumn)) - c2 = np.array(table2[name]) - np.testing.assert_array_equal(c1, c2) - # If we have a masked column then we test the underlying data. - if has_masked: - np.testing.assert_array_equal(np.array(c1), np.array(c2)) - np.testing.assert_array_equal(table1[name].mask, table2[name].mask) - - def _checkNumpyTableEquality(self, table1, table2): - """Check if two numpy tables have the same columns/values - - Parameters - ---------- - table1 : `numpy.ndarray` - table2 : `numpy.ndarray` - """ - self.assertEqual(table1.dtype.names, table2.dtype.names) - for name in table1.dtype.names: - self.assertEqual(table1.dtype[name], table2.dtype[name]) - self.assertTrue(np.all(table1 == table2)) + self.butler.put(tab1, self.datasetType, dataId={}) + + tab2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowAstropy") + _checkAstropyTableEquality(tab1, tab2) -@unittest.skipUnless(pa is not None, "Cannot test InMemoryArrowTableDelegate without pyarrow.") + +@unittest.skipUnless(pa is not None, "Cannot test InMemoryDatastore with ArrowTable without pyarrow.") class InMemoryArrowTableDelegateTestCase(ParquetFormatterArrowTableTestCase): """Tests for InMemoryDatastore, using ArrowTableDelegate.""" configFile = os.path.join(TESTDIR, "config/basic/butler-inmemory.yaml") @@ -1861,7 +1812,7 @@ def testNumpyDict(self): self.butler.put(dict1, self.datasetType, dataId={}) # Read the whole table. 
dict2 = self.butler.get(self.datasetType, dataId={}) - self._checkNumpyDictEquality(dict1, dict2) + _checkNumpyDictEquality(dict1, dict2) # Read the columns. columns2 = self.butler.get(self.datasetType.componentTypeName("columns"), dataId={}) self.assertEqual(len(columns2), len(dict1.keys())) @@ -1876,19 +1827,19 @@ def testNumpyDict(self): # Read just some columns a few different ways. tab3 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": ["a", "c"]}) subdict = {key: dict1[key] for key in ["a", "c"]} - self._checkNumpyDictEquality(subdict, tab3) + _checkNumpyDictEquality(subdict, tab3) tab4 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": "a"}) subdict = {key: dict1[key] for key in ["a"]} - self._checkNumpyDictEquality(subdict, tab4) + _checkNumpyDictEquality(subdict, tab4) tab5 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": ["index", "a"]}) subdict = {key: dict1[key] for key in ["index", "a"]} - self._checkNumpyDictEquality(subdict, tab5) + _checkNumpyDictEquality(subdict, tab5) tab6 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": "ddd"}) subdict = {key: dict1[key] for key in ["ddd"]} - self._checkNumpyDictEquality(subdict, tab6) + _checkNumpyDictEquality(subdict, tab6) tab7 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": ["a", "a"]}) subdict = {key: dict1[key] for key in ["a"]} - self._checkNumpyDictEquality(subdict, tab7) + _checkNumpyDictEquality(subdict, tab7) # Passing an unrecognized column should be a ValueError. with self.assertRaises(ValueError): self.butler.get(self.datasetType, dataId={}, parameters={"columns": ["e"]}) @@ -1904,7 +1855,7 @@ def testWriteNumpyDictReadAsArrowTable(self): tab2_dict = arrow_to_numpy_dict(tab2) - self._checkNumpyDictEquality(dict1, tab2_dict) + _checkNumpyDictEquality(dict1, tab2_dict) @unittest.skipUnless(pd is not None, "Cannot test reading as a dataframe without pandas.") def testWriteNumpyDictReadAsDataFrame(self): @@ -1934,7 +1885,7 @@ def testWriteNumpyDictReadAsAstropyTable(self): tab2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowAstropy") tab2_dict = _astropy_to_numpy_dict(tab2) - self._checkNumpyDictEquality(dict1, tab2_dict) + _checkNumpyDictEquality(dict1, tab2_dict) def testWriteNumpyDictReadAsNumpyTable(self): tab1 = _makeSimpleNumpyTable(include_multidim=True) @@ -1945,7 +1896,7 @@ def testWriteNumpyDictReadAsNumpyTable(self): tab2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowNumpy") tab2_dict = _numpy_to_numpy_dict(tab2) - self._checkNumpyDictEquality(dict1, tab2_dict) + _checkNumpyDictEquality(dict1, tab2_dict) def testWriteNumpyDictBad(self): dict1 = {"a": 4, "b": np.ndarray([1])} @@ -1964,24 +1915,23 @@ def testWriteNumpyDictBad(self): with self.assertRaises(RuntimeError): self.butler.put(dict4, self.datasetType, dataId={}) - def _checkNumpyDictEquality(self, dict1, dict2): - """Check if two numpy dicts have the same columns/values. 
+ @unittest.skipUnless(atable is not None, "Cannot test reading as astropy without astropy.") + def testWriteReadAstropyTableLossless(self): + tab1 = _makeSimpleAstropyTable(include_multidim=True, include_masked=True) + + self.butler.put(tab1, self.datasetType, dataId={}) - Parameters - ---------- - dict1 : `dict` [`str`, `np.ndarray`] - dict2 : `dict` [`str`, `np.ndarray`] - """ - self.assertEqual(set(dict1.keys()), set(dict2.keys())) - for name in dict1: - self.assertEqual(dict1[name].dtype, dict2[name].dtype) - self.assertTrue(np.all(dict1[name] == dict2[name])) + tab2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowAstropy") + _checkAstropyTableEquality(tab1, tab2) -@unittest.skipUnless(np is not None, "Cannot test ParquetFormatterArrowNumpy without numpy.") -@unittest.skipUnless(pa is not None, "Cannot test ParquetFormatterArrowNumpy without pyarrow.") + +@unittest.skipUnless(np is not None, "Cannot test InMemoryDatastore with NumpyDict without numpy.") +@unittest.skipUnless(pa is not None, "Cannot test InMemoryDatastore with NumpyDict without pyarrow.") class InMemoryNumpyDictDelegateTestCase(ParquetFormatterArrowNumpyDictTestCase): - """Tests for InMemoryDatastore, using ArrowNumpyDictDelegate.""" + """Tests for InMemoryDatastore, using ArrowTableDelegate with + Numpy dict. + """ configFile = os.path.join(TESTDIR, "config/basic/butler-inmemory.yaml") @@ -2152,7 +2102,7 @@ def testWriteArrowSchemaReadAsArrowNumpySchema(self): self.assertEqual(np_schema2, np_schema1) -@unittest.skipUnless(pa is not None, "Cannot test InMemoryArrowSchemaDelegate without pyarrow.") +@unittest.skipUnless(pa is not None, "Cannot test InMemoryDatastore with ArrowSchema without pyarrow.") class InMemoryArrowSchemaDelegateTestCase(ParquetFormatterArrowSchemaTestCase): """Tests for InMemoryDatastore and ArrowSchema.""" @@ -2212,5 +2162,83 @@ def testRowGroupSizeDataFrameWithLists(self): self.assertGreater(row_group_size, 1_000_000) +def _checkAstropyTableEquality(table1, table2, skip_units=False, has_bigendian=False): + """Check if two astropy tables have the same columns/values. + + Parameters + ---------- + table1 : `astropy.table.Table` + table2 : `astropy.table.Table` + skip_units : `bool` + has_bigendian : `bool` + """ + if not has_bigendian: + assert table1.dtype == table2.dtype + else: + for name in table1.dtype.names: + # Only check type matches, force to little-endian. + assert table1.dtype[name].newbyteorder(">") == table2.dtype[name].newbyteorder(">") + + assert table1.meta == table2.meta + if not skip_units: + for name in table1.columns: + assert table1[name].unit == table2[name].unit + assert table1[name].description == table2[name].description + assert table1[name].format == table2[name].format + + for name in table1.columns: + # We need to check masked/regular columns after filling. + has_masked = False + if isinstance(table1[name], atable.column.MaskedColumn): + c1 = table1[name].filled() + has_masked = True + else: + c1 = np.array(table1[name]) + if has_masked: + assert isinstance(table2[name], atable.column.MaskedColumn) + c2 = table2[name].filled() + else: + assert not isinstance(table2[name], atable.column.MaskedColumn) + c2 = np.array(table2[name]) + np.testing.assert_array_equal(c1, c2) + # If we have a masked column then we test the underlying data. 
+ if has_masked: + np.testing.assert_array_equal(np.array(c1), np.array(c2)) + np.testing.assert_array_equal(table1[name].mask, table2[name].mask) + + +def _checkNumpyTableEquality(table1, table2, has_bigendian=False): + """Check if two numpy tables have the same columns/values + + Parameters + ---------- + table1 : `numpy.ndarray` + table2 : `numpy.ndarray` + has_bigendian : `bool` + """ + assert table1.dtype.names == table2.dtype.names + for name in table1.dtype.names: + if not has_bigendian: + assert table1.dtype[name] == table2.dtype[name] + else: + # Only check type matches, force to little-endian. + assert table1.dtype[name].newbyteorder(">") == table2.dtype[name].newbyteorder(">") + assert np.all(table1 == table2) + + +def _checkNumpyDictEquality(dict1, dict2): + """Check if two numpy dicts have the same columns/values. + + Parameters + ---------- + dict1 : `dict` [`str`, `np.ndarray`] + dict2 : `dict` [`str`, `np.ndarray`] + """ + assert set(dict1.keys()) == set(dict2.keys()) + for name in dict1: + assert dict1[name].dtype == dict2[name].dtype + assert np.all(dict1[name] == dict2[name]) + + if __name__ == "__main__": unittest.main()
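Taken together, these changes give the behaviour described in the changelog: a dataset type declared with the DataFrame storage class can be written from an astropy table and read back without losing units, as the new lossless round-trip tests exercise. A rough end-to-end sketch (assuming an existing Butler instance named butler and a registered dataset type "table" that uses the DataFrame storage class with no dimensions; these names are illustrative and not taken from the diff):

    import astropy.units as u
    from astropy.table import Table

    tab = Table({"flux": [1.0, 2.0] * u.nJy})

    # ParquetFormatter.can_accept now recognises the astropy table, so no
    # lossy coercion to a DataFrame happens on the way in.
    butler.put(tab, "table", dataId={})

    # Reading back with an ArrowAstropy storage class override keeps the units.
    roundtripped = butler.get("table", dataId={}, storageClass="ArrowAstropy")
    assert roundtripped["flux"].unit == u.nJy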