diff --git a/doc/changes/DM-45431.feature.md b/doc/changes/DM-45431.feature.md
new file mode 100644
index 0000000000..495ca87ef2
--- /dev/null
+++ b/doc/changes/DM-45431.feature.md
@@ -0,0 +1,3 @@
+The ParquetFormatter now declares, via can_accept, that it accepts Arrow tables, Astropy tables, Numpy tables, and pandas DataFrames.
+This means we now have fully lossless storage of any parquet-compatible type into a dataset type that declares a different (but still parquet-compatible) storage class; e.g. an astropy table with units can be persisted into a DataFrame storage class without those units being stripped.
+This ticket also adds can_accept to the InMemoryDatastore delegates, and a single ArrowTableDelegate now handles all of the parquet-compatible dataset types.
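As a rough usage sketch of the behaviour described above (assuming a `butler` for a repository in which a hypothetical dataset type `catalog` is registered with the `DataFrame` storage class):

    import astropy.units as u
    from astropy.table import Table

    # An astropy table with units, put into a dataset type declared as DataFrame.
    tab = Table({"flux": [1.0, 2.0] * u.nJy})
    butler.put(tab, "catalog", dataId={})

    # Reading it back with a storage class override preserves the units.
    tab2 = butler.get("catalog", dataId={}, storageClass="ArrowAstropy")
    assert tab2["flux"].unit == u.nJy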
diff --git a/python/lsst/daf/butler/_storage_class_delegate.py b/python/lsst/daf/butler/_storage_class_delegate.py
index 344c299f2e..1392ba583c 100644
--- a/python/lsst/daf/butler/_storage_class_delegate.py
+++ b/python/lsst/daf/butler/_storage_class_delegate.py
@@ -86,6 +86,31 @@ def __init__(self, storageClass: StorageClass):
assert storageClass is not None
self.storageClass = storageClass
+ def can_accept(self, inMemoryDataset: Any) -> bool:
+ """Indicate whether this delegate can accept the specified
+ dataset directly.
+
+ Parameters
+ ----------
+ inMemoryDataset : `object`
+ The dataset that is to be stored.
+
+ Returns
+ -------
+ accepts : `bool`
+ If `True` the delegate can handle data of this type without
+ requiring datastore to convert it. If `False` the datastore
+ will attempt to convert before storage.
+
+ Notes
+ -----
+ The base class always returns `False` even if the given type is an
+ instance of the delegate type. This will result in a storage class
+ conversion no-op but also allows mocks with mocked storage classes
+ to work properly.
+ """
+ return False
+
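A subclass opts in by overriding this hook; a minimal sketch with a hypothetical delegate (the real ArrowTableDelegate later in this diff instead checks for arrow-compatible table types):

    from typing import Any

    from lsst.daf.butler import StorageClassDelegate


    class SequenceDelegate(StorageClassDelegate):
        """Hypothetical delegate that can store lists and tuples directly."""

        def can_accept(self, inMemoryDataset: Any) -> bool:
            # Anything else is coerced by the datastore to the storage class
            # python type before being stored.
            return isinstance(inMemoryDataset, (list, tuple))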
@staticmethod
def _attrNames(componentName: str, getter: bool = True) -> tuple[str, ...]:
"""Return list of suitable attribute names to attempt to use.
diff --git a/python/lsst/daf/butler/configs/storageClasses.yaml b/python/lsst/daf/butler/configs/storageClasses.yaml
index 08a1287c50..e56a2e111f 100644
--- a/python/lsst/daf/butler/configs/storageClasses.yaml
+++ b/python/lsst/daf/butler/configs/storageClasses.yaml
@@ -127,7 +127,7 @@ storageClasses:
astropy.table.Table: lsst.daf.butler.formatters.parquet.astropy_to_pandas
numpy.ndarray: pandas.DataFrame.from_records
dict: pandas.DataFrame.from_records
- delegate: lsst.daf.butler.delegates.dataframe.DataFrameDelegate
+ delegate: lsst.daf.butler.delegates.arrowtable.ArrowTableDelegate
derivedComponents:
columns: DataFrameIndex
rowcount: int
@@ -179,7 +179,7 @@ storageClasses:
pandas.core.frame.DataFrame: lsst.daf.butler.formatters.parquet.pandas_to_astropy
numpy.ndarray: lsst.daf.butler.formatters.parquet.numpy_to_astropy
dict: astropy.table.Table
- delegate: lsst.daf.butler.delegates.arrowastropy.ArrowAstropyDelegate
+ delegate: lsst.daf.butler.delegates.arrowtable.ArrowTableDelegate
derivedComponents:
columns: ArrowColumnList
rowcount: int
@@ -200,7 +200,7 @@ storageClasses:
pandas.core.frame.DataFrame: pandas.DataFrame.to_records
astropy.table.Table: astropy.table.Table.as_array
dict: lsst.daf.butler.formatters.parquet._numpy_dict_to_numpy
- delegate: lsst.daf.butler.delegates.arrownumpy.ArrowNumpyDelegate
+ delegate: lsst.daf.butler.delegates.arrowtable.ArrowTableDelegate
derivedComponents:
columns: ArrowColumnList
rowcount: int
@@ -221,7 +221,7 @@ storageClasses:
pandas.core.frame.DataFrame: lsst.daf.butler.formatters.parquet._pandas_to_numpy_dict
astropy.table.Table: lsst.daf.butler.formatters.parquet._astropy_to_numpy_dict
numpy.ndarray: lsst.daf.butler.formatters.parquet._numpy_to_numpy_dict
- delegate: lsst.daf.butler.delegates.arrownumpydict.ArrowNumpyDictDelegate
+ delegate: lsst.daf.butler.delegates.arrowtable.ArrowTableDelegate
derivedComponents:
columns: ArrowColumnList
rowcount: int
diff --git a/python/lsst/daf/butler/datastore/generic_base.py b/python/lsst/daf/butler/datastore/generic_base.py
index 2921edd987..75415fea29 100644
--- a/python/lsst/daf/butler/datastore/generic_base.py
+++ b/python/lsst/daf/butler/datastore/generic_base.py
@@ -35,7 +35,6 @@
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Generic, TypeVar
-from .._exceptions import DatasetTypeNotSupportedError
from ..datastore._datastore import Datastore
from .stored_file_info import StoredDatastoreItemInfo
@@ -54,34 +53,6 @@ class GenericBaseDatastore(Datastore, Generic[_InfoType]):
Should always be sub-classed since key abstract methods are missing.
"""
- def _validate_put_parameters(self, inMemoryDataset: object, ref: DatasetRef) -> None:
- """Validate the supplied arguments for put.
-
- Parameters
- ----------
- inMemoryDataset : `object`
- The dataset to store.
- ref : `DatasetRef`
- Reference to the associated Dataset.
- """
- storageClass = ref.datasetType.storageClass
-
- # Sanity check
- if not isinstance(inMemoryDataset, storageClass.pytype):
- raise TypeError(
- f"Inconsistency between supplied object ({type(inMemoryDataset)}) "
- f"and storage class type ({storageClass.pytype})"
- )
-
- # Confirm that we can accept this dataset
- if not self.constraints.isAcceptable(ref):
- # Raise rather than use boolean return value.
- raise DatasetTypeNotSupportedError(
- f"Dataset {ref} has been rejected by this datastore via configuration."
- )
-
- return
-
def remove(self, ref: DatasetRef) -> None:
"""Indicate to the Datastore that a dataset can be removed.
diff --git a/python/lsst/daf/butler/datastores/inMemoryDatastore.py b/python/lsst/daf/butler/datastores/inMemoryDatastore.py
index 25b2b1099f..d40659b0b8 100644
--- a/python/lsst/daf/butler/datastores/inMemoryDatastore.py
+++ b/python/lsst/daf/butler/datastores/inMemoryDatastore.py
@@ -39,6 +39,7 @@
from urllib.parse import urlencode
from lsst.daf.butler import DatasetId, DatasetRef, StorageClass
+from lsst.daf.butler._exceptions import DatasetTypeNotSupportedError
from lsst.daf.butler.datastore import DatasetRefURIs, DatastoreConfig
from lsst.daf.butler.datastore.generic_base import GenericBaseDatastore, post_process_get
from lsst.daf.butler.datastore.record_data import DatastoreRecordData
@@ -397,11 +398,21 @@ def put(self, inMemoryDataset: Any, ref: DatasetRef) -> None:
allow `ChainedDatastore` to put to multiple datastores without
requiring that every datastore accepts the dataset.
"""
+ if not self.constraints.isAcceptable(ref):
+ # Raise rather than use boolean return value.
+ raise DatasetTypeNotSupportedError(
+ f"Dataset {ref} has been rejected by this datastore via configuration."
+ )
+
# May need to coerce the in memory dataset to the correct
# python type, otherwise parameters may not work.
- inMemoryDataset = ref.datasetType.storageClass.coerce_type(inMemoryDataset)
-
- self._validate_put_parameters(inMemoryDataset, ref)
+ try:
+ delegate = ref.datasetType.storageClass.delegate()
+ except TypeError:
+ # TypeError is raised when a storage class doesn't have a delegate.
+ delegate = None
+ if not delegate or not delegate.can_accept(inMemoryDataset):
+ inMemoryDataset = ref.datasetType.storageClass.coerce_type(inMemoryDataset)
self.datasets[ref.id] = inMemoryDataset
log.debug("Store %s in %s", ref, self.name)
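The practical effect for the in-memory datastore is that an arrow-compatible object is now stored exactly as supplied whenever the delegate accepts it; a sketch assuming an in-memory butler and a hypothetical `catalog` dataset type registered with the `DataFrame` storage class:

    import numpy as np

    # A structured array is arrow-compatible, so ArrowTableDelegate.can_accept
    # returns True and the array is stored without a DataFrame conversion.
    arr = np.zeros(3, dtype=[("a", "i8"), ("b", "f8")])
    butler.put(arr, "catalog", dataId={})

    # It reads back losslessly with a storage class override.
    arr2 = butler.get("catalog", dataId={}, storageClass="ArrowNumpy")
    assert arr2.dtype.names == ("a", "b")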
diff --git a/python/lsst/daf/butler/delegates/arrowastropy.py b/python/lsst/daf/butler/delegates/arrowastropy.py
deleted file mode 100644
index de9de17e93..0000000000
--- a/python/lsst/daf/butler/delegates/arrowastropy.py
+++ /dev/null
@@ -1,83 +0,0 @@
-# This file is part of daf_butler.
-#
-# Developed for the LSST Data Management System.
-# This product includes software developed by the LSST Project
-# (http://www.lsst.org).
-# See the COPYRIGHT file at the top-level directory of this distribution
-# for details of code ownership.
-#
-# This software is dual licensed under the GNU General Public License and also
-# under a 3-clause BSD license. Recipients may choose which of these licenses
-# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
-# respectively. If you choose the GPL option then the following text applies
-# (but note that there is still no warranty even if you opt for BSD instead):
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with this program. If not, see <https://www.gnu.org/licenses/>.
-
-"""Support for reading Astropy tables with the Arrow formatter."""
-from __future__ import annotations
-
-from typing import Any
-
-import astropy.table as atable
-from lsst.daf.butler.formatters.parquet import ArrowAstropySchema
-from lsst.utils.introspection import get_full_type_name
-
-from .arrowtable import ArrowTableDelegate
-
-__all__ = ["ArrowAstropyDelegate"]
-
-
-class ArrowAstropyDelegate(ArrowTableDelegate):
- """Delegate that understands the ``ArrowAstropy`` storage class."""
-
- _datasetType = atable.Table
-
- def getComponent(self, composite: atable.Table, componentName: str) -> Any:
- """Get a component from an astropy table stored via ArrowAstropy.
-
- Parameters
- ----------
- composite : `~astropy.table.Table`
- Astropy table to access component.
- componentName : `str`
- Name of component to retrieve.
-
- Returns
- -------
- component : `object`
- The component.
-
- Raises
- ------
- AttributeError
- The component can not be found.
- """
- match componentName:
- case "columns":
- return list(composite.columns.keys())
- case "schema":
- return ArrowAstropySchema(composite)
- case "rowcount":
- return len(composite)
-
- raise AttributeError(
- f"Do not know how to retrieve component {componentName} from {get_full_type_name(composite)}"
- )
-
- def _getColumns(self, inMemoryDataset: atable.Table) -> list[str]:
- return inMemoryDataset.columns.keys()
-
- def _selectColumns(self, inMemoryDataset: atable.Table, columns: list[str]) -> atable.Table:
- return inMemoryDataset[columns]
diff --git a/python/lsst/daf/butler/delegates/arrownumpy.py b/python/lsst/daf/butler/delegates/arrownumpy.py
deleted file mode 100644
index 4fcd34b604..0000000000
--- a/python/lsst/daf/butler/delegates/arrownumpy.py
+++ /dev/null
@@ -1,85 +0,0 @@
-# This file is part of daf_butler.
-#
-# Developed for the LSST Data Management System.
-# This product includes software developed by the LSST Project
-# (http://www.lsst.org).
-# See the COPYRIGHT file at the top-level directory of this distribution
-# for details of code ownership.
-#
-# This software is dual licensed under the GNU General Public License and also
-# under a 3-clause BSD license. Recipients may choose which of these licenses
-# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
-# respectively. If you choose the GPL option then the following text applies
-# (but note that there is still no warranty even if you opt for BSD instead):
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with this program. If not, see <https://www.gnu.org/licenses/>.
-
-"""Support for reading numpy tables (structured arrays) with the Arrow
-formatter.
-"""
-from __future__ import annotations
-
-from typing import Any
-
-import numpy as np
-from lsst.daf.butler.formatters.parquet import ArrowNumpySchema
-from lsst.utils.introspection import get_full_type_name
-
-from .arrowtable import ArrowTableDelegate
-
-__all__ = ["ArrowNumpyDelegate"]
-
-
-class ArrowNumpyDelegate(ArrowTableDelegate):
- """Delegate that understands the ``ArrowNumpy`` storage class."""
-
- _datasetType = np.ndarray
-
- def getComponent(self, composite: np.ndarray, componentName: str) -> Any:
- """Get a component from a numpy table stored via ArrowNumpy.
-
- Parameters
- ----------
- composite : `~numpy.ndarray`
- Numpy table to access component.
- componentName : `str`
- Name of component to retrieve.
-
- Returns
- -------
- component : `object`
- The component.
-
- Raises
- ------
- AttributeError
- The component can not be found.
- """
- match componentName:
- case "columns":
- return list(composite.dtype.names)
- case "schema":
- return ArrowNumpySchema(composite.dtype)
- case "rowcount":
- return len(composite)
-
- raise AttributeError(
- f"Do not know how to retrieve component {componentName} from {get_full_type_name(composite)}"
- )
-
- def _getColumns(self, inMemoryDataset: np.ndarray) -> list[str]:
- return inMemoryDataset.dtype.names
-
- def _selectColumns(self, inMemoryDataset: np.ndarray, columns: list[str]) -> np.ndarray:
- return inMemoryDataset[columns]
diff --git a/python/lsst/daf/butler/delegates/arrownumpydict.py b/python/lsst/daf/butler/delegates/arrownumpydict.py
deleted file mode 100644
index b94d0f0ee3..0000000000
--- a/python/lsst/daf/butler/delegates/arrownumpydict.py
+++ /dev/null
@@ -1,89 +0,0 @@
-# This file is part of daf_butler.
-#
-# Developed for the LSST Data Management System.
-# This product includes software developed by the LSST Project
-# (http://www.lsst.org).
-# See the COPYRIGHT file at the top-level directory of this distribution
-# for details of code ownership.
-#
-# This software is dual licensed under the GNU General Public License and also
-# under a 3-clause BSD license. Recipients may choose which of these licenses
-# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
-# respectively. If you choose the GPL option then the following text applies
-# (but note that there is still no warranty even if you opt for BSD instead):
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with this program. If not, see <https://www.gnu.org/licenses/>.
-
-"""Support for reading dictionaries of numpy arrays with the Arrow
-formatter.
-"""
-from __future__ import annotations
-
-from typing import Any
-
-import numpy as np
-from lsst.daf.butler.formatters.parquet import ArrowNumpySchema, _numpy_dict_to_dtype
-from lsst.utils.introspection import get_full_type_name
-
-from .arrowtable import ArrowTableDelegate
-
-__all__ = ["ArrowNumpyDictDelegate"]
-
-
-class ArrowNumpyDictDelegate(ArrowTableDelegate):
- """Delegate that understands the ``ArrowNumpyDict`` storage class."""
-
- _datasetType = dict
-
- def getComponent(self, composite: dict[str, np.ndarray], componentName: str) -> Any:
- """Get a component from a dict of numpy arrays stored via
- ArrowNumpyDict.
-
- Parameters
- ----------
- composite : `~numpy.ndarray`
- Numpy table to access component.
- componentName : `str`
- Name of component to retrieve.
-
- Returns
- -------
- component : `object`
- The component.
-
- Raises
- ------
- AttributeError
- The component can not be found.
- """
- match componentName:
- case "columns":
- return list(composite.keys())
- case "schema":
- dtype, _ = _numpy_dict_to_dtype(composite)
- return ArrowNumpySchema(dtype)
- case "rowcount":
- return len(composite[list(composite.keys())[0]])
-
- raise AttributeError(
- f"Do not know how to retrieve component {componentName} from {get_full_type_name(composite)}"
- )
-
- def _getColumns(self, inMemoryDataset: dict[str, np.ndarray]) -> list[str]:
- return list(inMemoryDataset.keys())
-
- def _selectColumns(
- self, inMemoryDataset: dict[str, np.ndarray], columns: list[str]
- ) -> dict[str, np.ndarray]:
- return {column: inMemoryDataset[column] for column in columns}
diff --git a/python/lsst/daf/butler/delegates/arrowtable.py b/python/lsst/daf/butler/delegates/arrowtable.py
index 88e552922b..1d4c728245 100644
--- a/python/lsst/daf/butler/delegates/arrowtable.py
+++ b/python/lsst/daf/butler/delegates/arrowtable.py
@@ -28,24 +28,29 @@
"""Support for reading Arrow tables."""
from __future__ import annotations
+__all__ = ["ArrowTableDelegate"]
+
from collections.abc import Mapping
-from typing import Any
+from typing import TYPE_CHECKING, Any
import pyarrow as pa
from lsst.daf.butler import StorageClassDelegate
from lsst.utils.introspection import get_full_type_name
from lsst.utils.iteration import ensure_iterable
-__all__ = ["ArrowTableDelegate"]
+if TYPE_CHECKING:
+ import pandas
class ArrowTableDelegate(StorageClassDelegate):
- """Delegate that understands the ``ArrowTable`` storage class."""
+ """Delegate that understands ArrowTable and related storage classes."""
- _datasetType = pa.Table
+ def can_accept(self, inMemoryDataset: Any) -> bool:
+ # Docstring inherited.
+ return _checkArrowCompatibleType(inMemoryDataset) is not None
- def getComponent(self, composite: pa.Table, componentName: str) -> Any:
- """Get a component from an Arrow table.
+ def getComponent(self, composite: Any, componentName: str) -> Any:
+ """Get a component from an Arrow table or equivalent.
Parameters
----------
@@ -64,72 +69,179 @@ def getComponent(self, composite: pa.Table, componentName: str) -> Any:
AttributeError
The component can not be found.
"""
- if componentName in ("columns", "schema"):
- # The schema will be translated to column format
- # depending on the input type.
- return composite.schema
+ typeString = _checkArrowCompatibleType(composite)
+
+ if typeString is None:
+ raise ValueError(f"Unsupported composite type {get_full_type_name(composite)}")
+
+ if componentName == "columns":
+ if typeString == "arrow":
+ return composite.schema
+ elif typeString == "astropy":
+ return list(composite.columns.keys())
+ elif typeString == "numpy":
+ return list(composite.dtype.names)
+ elif typeString == "numpydict":
+ return list(composite.keys())
+ elif typeString == "pandas":
+ import pandas
+
+ if isinstance(composite.columns, pandas.MultiIndex):
+ return composite.columns
+ else:
+ return pandas.Index(self._getAllDataframeColumns(composite))
+
+ elif componentName == "schema":
+ if typeString == "arrow":
+ return composite.schema
+ elif typeString == "astropy":
+ from lsst.daf.butler.formatters.parquet import ArrowAstropySchema
+
+ return ArrowAstropySchema(composite)
+ elif typeString == "numpy":
+ from lsst.daf.butler.formatters.parquet import ArrowNumpySchema
+
+ return ArrowNumpySchema(composite.dtype)
+ elif typeString == "numpydict":
+ from lsst.daf.butler.formatters.parquet import ArrowNumpySchema, _numpy_dict_to_dtype
+
+ dtype, _ = _numpy_dict_to_dtype(composite)
+ return ArrowNumpySchema(dtype)
+ elif typeString == "pandas":
+ from lsst.daf.butler.formatters.parquet import DataFrameSchema
+
+ return DataFrameSchema(composite.iloc[:0])
elif componentName == "rowcount":
- return len(composite[composite.schema.names[0]])
+ if typeString == "arrow":
+ return len(composite[composite.schema.names[0]])
+ elif typeString in ["astropy", "numpy", "pandas"]:
+ return len(composite)
+ elif typeString == "numpydict":
+ return len(composite[list(composite.keys())[0]])
raise AttributeError(
f"Do not know how to retrieve component {componentName} from {get_full_type_name(composite)}"
)
def handleParameters(self, inMemoryDataset: Any, parameters: Mapping[str, Any] | None = None) -> Any:
- if not isinstance(inMemoryDataset, self._datasetType):
- raise ValueError(
- f"inMemoryDataset must be a {get_full_type_name(self._datasetType)} and "
- f"not {get_full_type_name(inMemoryDataset)}."
- )
+ typeString = _checkArrowCompatibleType(inMemoryDataset)
+
+ if typeString is None:
+ raise ValueError(f"Unsupported inMemoryDataset type {get_full_type_name(inMemoryDataset)}")
if parameters is None:
return inMemoryDataset
if "columns" in parameters:
- read_columns = list(ensure_iterable(parameters["columns"]))
- for column in read_columns:
- if not isinstance(column, str):
- raise NotImplementedError(
- "InMemoryDataset of an Arrow Table only supports string column names."
- )
- if column not in self._getColumns(inMemoryDataset):
- raise ValueError(f"Unrecognized column name {column!r}.")
+ readColumns = list(ensure_iterable(parameters["columns"]))
+
+ if typeString == "arrow":
+ allColumns = inMemoryDataset.schema.names
+ elif typeString == "astropy":
+ allColumns = inMemoryDataset.columns.keys()
+ elif typeString == "numpy":
+ allColumns = inMemoryDataset.dtype.names
+ elif typeString == "numpydict":
+ allColumns = list(inMemoryDataset.keys())
+ elif typeString == "pandas":
+ import pandas
+
+ allColumns = self._getAllDataframeColumns(inMemoryDataset)
+
+ if typeString == "pandas" and isinstance(inMemoryDataset.columns, pandas.MultiIndex):
+ from ..formatters.parquet import _standardize_multi_index_columns
+
+ # We have a multi-index dataframe which needs special
+ # handling.
+ readColumns = _standardize_multi_index_columns(
+ inMemoryDataset.columns,
+ parameters["columns"],
+ stringify=False,
+ )
+ else:
+ readColumns = list(ensure_iterable(parameters["columns"]))
+
+ for column in readColumns:
+ if not isinstance(column, str):
+ raise NotImplementedError(
+ f"InMemoryDataset of a {get_full_type_name(inMemoryDataset)} only "
+ "supports string column names."
+ )
+ if column not in allColumns:
+ raise ValueError(f"Unrecognized column name {column!r}.")
+
+ if typeString == "pandas":
+ # Exclude index columns from the subset.
+ readColumns = [
+ name
+ for name in ensure_iterable(parameters["columns"])
+ if name not in inMemoryDataset.index.names
+ ]
# Ensure uniqueness, keeping order.
- read_columns = list(dict.fromkeys(read_columns))
-
- return self._selectColumns(inMemoryDataset, read_columns)
+ readColumns = list(dict.fromkeys(readColumns))
+
+ if typeString == "arrow":
+ return inMemoryDataset.select(readColumns)
+ elif typeString in ("astropy", "numpy", "pandas"):
+ return inMemoryDataset[readColumns]
+ elif typeString == "numpydict":
+ return {column: inMemoryDataset[column] for column in readColumns}
else:
return inMemoryDataset
- def _getColumns(self, inMemoryDataset: pa.Table) -> list[str]:
- """Get the column names from the inMemoryDataset.
-
- Parameters
- ----------
- inMemoryDataset : `object`
- Dataset to extract columns.
+ def _getAllDataframeColumns(self, dataset: pandas.DataFrame) -> list[str]:
+ """Get all columns, including index columns.
Returns
-------
columns : `list` [`str`]
- List of columns.
- """
- return inMemoryDataset.schema.names
-
- def _selectColumns(self, inMemoryDataset: pa.Table, columns: list[str]) -> pa.Table:
- """Select a subset of columns from the inMemoryDataset.
-
- Parameters
- ----------
- inMemoryDataset : `object`
- Dataset to extract columns.
- columns : `list` [`str`]
- List of columns to extract.
-
- Returns
- -------
- subDataset : `object`
- Subselection of inMemoryDataset.
+ List of all columns.
"""
- return inMemoryDataset.select(columns)
+ allColumns = list(dataset.columns)
+ if dataset.index.names[0] is not None:
+ allColumns.extend(dataset.index.names)
+
+ return allColumns
+
+
+def _checkArrowCompatibleType(dataset: Any) -> str | None:
+ """Check a dataset for arrow compatibility and return type string.
+
+ Parameters
+ ----------
+ dataset : `object`
+ Dataset object.
+
+ Returns
+ -------
+ typeString : `str` or `None`
+ Type string, one of ``arrow``, ``astropy``, ``numpy``, ``numpydict``,
+ or ``pandas``. `None` if the dataset is not an arrow-compatible type.
+ """
+ import numpy as np
+ from astropy.table import Table as astropyTable
+
+ if isinstance(dataset, pa.Table):
+ return "arrow"
+ elif isinstance(dataset, astropyTable):
+ return "astropy"
+ elif isinstance(dataset, np.ndarray):
+ return "numpy"
+ elif isinstance(dataset, dict):
+ for key, item in dataset.items():
+ if not isinstance(item, np.ndarray):
+ # This is some other sort of dictionary.
+ return None
+ return "numpydict"
+ elif hasattr(dataset, "to_parquet"):
+ # This may be a pandas DataFrame
+ try:
+ import pandas
+ except ImportError:
+ pandas = None
+
+ if pandas is not None and isinstance(dataset, pandas.DataFrame):
+ return "pandas"
+
+ return None
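For reference, the private helper added above classifies in-memory objects roughly as follows (a sketch; _checkArrowCompatibleType is internal, not public API):

    import numpy as np
    import pyarrow as pa
    from astropy.table import Table

    from lsst.daf.butler.delegates.arrowtable import _checkArrowCompatibleType

    assert _checkArrowCompatibleType(pa.table({"a": [1, 2]})) == "arrow"
    assert _checkArrowCompatibleType(Table({"a": [1, 2]})) == "astropy"
    assert _checkArrowCompatibleType(np.zeros(2, dtype=[("a", "i8")])) == "numpy"
    assert _checkArrowCompatibleType({"a": np.arange(2)}) == "numpydict"
    assert _checkArrowCompatibleType({"a": 1}) is None  # not a dict of numpy arrays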
diff --git a/python/lsst/daf/butler/delegates/dataframe.py b/python/lsst/daf/butler/delegates/dataframe.py
deleted file mode 100644
index 342cb3408e..0000000000
--- a/python/lsst/daf/butler/delegates/dataframe.py
+++ /dev/null
@@ -1,163 +0,0 @@
-# This file is part of daf_butler.
-#
-# Developed for the LSST Data Management System.
-# This product includes software developed by the LSST Project
-# (http://www.lsst.org).
-# See the COPYRIGHT file at the top-level directory of this distribution
-# for details of code ownership.
-#
-# This software is dual licensed under the GNU General Public License and also
-# under a 3-clause BSD license. Recipients may choose which of these licenses
-# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
-# respectively. If you choose the GPL option then the following text applies
-# (but note that there is still no warranty even if you opt for BSD instead):
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with this program. If not, see <https://www.gnu.org/licenses/>.
-
-"""Support for reading DataFrames."""
-from __future__ import annotations
-
-import collections.abc
-from collections.abc import Mapping
-from typing import Any
-
-import pandas
-from lsst.daf.butler import StorageClassDelegate
-from lsst.daf.butler.formatters.parquet import DataFrameSchema
-from lsst.utils.introspection import get_full_type_name
-from lsst.utils.iteration import ensure_iterable
-
-from ..formatters.parquet import _standardize_multi_index_columns
-
-__all__ = ["DataFrameDelegate"]
-
-
-class DataFrameDelegate(StorageClassDelegate):
- """Delegate that understands the ``DataFrame`` storage class."""
-
- def getComponent(self, composite: pandas.DataFrame, componentName: str) -> Any:
- """Get a component from a DataFrame.
-
- Parameters
- ----------
- composite : `~pandas.DataFrame`
- ``DataFrame`` to access component.
- componentName : `str`
- Name of component to retrieve.
-
- Returns
- -------
- component : `object`
- The component.
-
- Raises
- ------
- AttributeError
- The component can not be found.
- """
- if componentName == "columns":
- if isinstance(composite.columns, pandas.MultiIndex):
- return composite.columns
- else:
- return pandas.Index(self._getAllColumns(composite))
- elif componentName == "rowcount":
- return len(composite)
- elif componentName == "schema":
- return DataFrameSchema(composite.iloc[:0])
- else:
- raise AttributeError(
- f"Do not know how to retrieve component {componentName} from {get_full_type_name(composite)}"
- )
-
- def handleParameters(
- self, inMemoryDataset: pandas.DataFrame, parameters: Mapping[str, Any] | None = None
- ) -> Any:
- """Return possibly new in-memory dataset using the supplied parameters.
-
- Parameters
- ----------
- inMemoryDataset : `object`
- Object to modify based on the parameters.
- parameters : `dict`, optional
- Parameters to apply. Values are specific to the parameter.
- Supported parameters are defined in the associated
- `StorageClass`. If no relevant parameters are specified the
- ``inMemoryDataset`` will be return unchanged.
-
- Returns
- -------
- inMemoryDataset : `object`
- Original in-memory dataset, or updated form after parameters
- have been used.
- """
- if not isinstance(inMemoryDataset, pandas.DataFrame):
- raise ValueError(
- "handleParameters for a DataFrame must get a DataFrame, "
- f"not {get_full_type_name(inMemoryDataset)}."
- )
-
- if parameters is None:
- return inMemoryDataset
-
- if "columns" in parameters:
- allColumns = self._getAllColumns(inMemoryDataset)
-
- if not isinstance(parameters["columns"], collections.abc.Iterable):
- raise NotImplementedError(
- "InMemoryDataset of a DataFrame only supports list/tuple of string column names"
- )
-
- if isinstance(inMemoryDataset.columns, pandas.MultiIndex):
- # We have a multi-index dataframe which needs special handling.
- readColumns = _standardize_multi_index_columns(
- inMemoryDataset.columns,
- parameters["columns"],
- stringify=False,
- )
- else:
- for column in ensure_iterable(parameters["columns"]):
- if not isinstance(column, str):
- raise NotImplementedError(
- "InMemoryDataset of a DataFrame only supports string column names."
- )
- if column not in allColumns:
- raise ValueError(f"Unrecognized column name {column!r}.")
-
- # Exclude index columns from the subset.
- readColumns = [
- name
- for name in ensure_iterable(parameters["columns"])
- if name not in inMemoryDataset.index.names
- ]
-
- # Ensure uniqueness, keeping order.
- readColumns = list(dict.fromkeys(readColumns))
-
- return inMemoryDataset[readColumns]
- else:
- return inMemoryDataset
-
- def _getAllColumns(self, inMemoryDataset: pandas.DataFrame) -> list[str]:
- """Get all columns, including index columns.
-
- Returns
- -------
- columns : `list` [`str`]
- List of all columns.
- """
- allColumns = list(inMemoryDataset.columns)
- if inMemoryDataset.index.names[0] is not None:
- allColumns.extend(inMemoryDataset.index.names)
-
- return allColumns
diff --git a/python/lsst/daf/butler/formatters/parquet.py b/python/lsst/daf/butler/formatters/parquet.py
index 97e800d139..5c9c952bf4 100644
--- a/python/lsst/daf/butler/formatters/parquet.py
+++ b/python/lsst/daf/butler/formatters/parquet.py
@@ -57,6 +57,7 @@
import pyarrow as pa
import pyarrow.parquet as pq
from lsst.daf.butler import FormatterV2
+from lsst.daf.butler.delegates.arrowtable import _checkArrowCompatibleType
from lsst.resources import ResourcePath
from lsst.utils.introspection import get_full_type_name
from lsst.utils.iteration import ensure_iterable
@@ -77,6 +78,10 @@ class ParquetFormatter(FormatterV2):
default_extension = ".parq"
can_read_from_local_file = True
+ def can_accept(self, in_memory_dataset: Any) -> bool:
+ # Docstring inherited.
+ return _checkArrowCompatibleType(in_memory_dataset) is not None
+
def read_from_local_file(self, path: str, component: str | None = None, expected_size: int = -1) -> Any:
# Docstring inherited from Formatter.read.
schema = pq.read_schema(path)
@@ -143,44 +148,30 @@ def read_from_local_file(self, path: str, component: str | None = None, expected
return arrow_table
def write_local_file(self, in_memory_dataset: Any, uri: ResourcePath) -> None:
- import numpy as np
- from astropy.table import Table as astropyTable
- arrow_table = None
- if isinstance(in_memory_dataset, pa.Table):
- # This will be the most likely match.
- arrow_table = in_memory_dataset
- elif isinstance(in_memory_dataset, astropyTable):
- arrow_table = astropy_to_arrow(in_memory_dataset)
- elif isinstance(in_memory_dataset, np.ndarray):
- arrow_table = numpy_to_arrow(in_memory_dataset)
- elif isinstance(in_memory_dataset, dict):
- try:
- arrow_table = numpy_dict_to_arrow(in_memory_dataset)
- except (TypeError, AttributeError) as e:
- raise ValueError(
- "Input dict for inMemoryDataset does not appear to be a dict of numpy arrays."
- ) from e
- elif isinstance(in_memory_dataset, pa.Schema):
+ if isinstance(in_memory_dataset, pa.Schema):
pq.write_metadata(in_memory_dataset, uri.ospath)
return
- else:
- if hasattr(in_memory_dataset, "to_parquet"):
- # This may be a pandas DataFrame
- try:
- import pandas as pd
- except ImportError:
- pd = None
- if pd is not None and isinstance(in_memory_dataset, pd.DataFrame):
- arrow_table = pandas_to_arrow(in_memory_dataset)
+ type_string = _checkArrowCompatibleType(in_memory_dataset)
- if arrow_table is None:
+ if type_string is None:
raise ValueError(
f"Unsupported type {get_full_type_name(in_memory_dataset)} of "
"inMemoryDataset for ParquetFormatter."
)
+ if type_string == "arrow":
+ arrow_table = in_memory_dataset
+ elif type_string == "astropy":
+ arrow_table = astropy_to_arrow(in_memory_dataset)
+ elif type_string == "numpy":
+ arrow_table = numpy_to_arrow(in_memory_dataset)
+ elif type_string == "numpydict":
+ arrow_table = numpy_dict_to_arrow(in_memory_dataset)
+ else:
+ arrow_table = pandas_to_arrow(in_memory_dataset)
+
row_group_size = compute_row_group_size(arrow_table.schema)
pq.write_table(arrow_table, uri.ospath, row_group_size=row_group_size)
diff --git a/tests/test_parquet.py b/tests/test_parquet.py
index 351b90f94f..97485d46e0 100644
--- a/tests/test_parquet.py
+++ b/tests/test_parquet.py
@@ -62,24 +62,11 @@
StorageClassFactory,
)
-try:
- from lsst.daf.butler.delegates.arrowastropy import ArrowAstropyDelegate
-except ImportError:
- atable = None
- pa = None
-try:
- from lsst.daf.butler.delegates.arrownumpy import ArrowNumpyDelegate
-except ImportError:
- np = None
- pa = None
try:
from lsst.daf.butler.delegates.arrowtable import ArrowTableDelegate
except ImportError:
pa = None
-try:
- from lsst.daf.butler.delegates.dataframe import DataFrameDelegate
-except ImportError:
- pd = None
+
try:
from lsst.daf.butler.formatters.parquet import (
ArrowAstropySchema,
@@ -751,10 +738,55 @@ def testBadDataFrameColumnParquet(self):
with self.assertRaises(RuntimeError):
self.butler.put(bad_df, self.datasetType, dataId={})
+ @unittest.skipUnless(atable is not None, "Cannot test reading as astropy without astropy.")
+ def testWriteReadAstropyTableLossless(self):
+ tab1 = _makeSimpleAstropyTable(include_multidim=True, include_masked=True)
+
+ self.butler.put(tab1, self.datasetType, dataId={})
+
+ tab2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowAstropy")
+
+ _checkAstropyTableEquality(tab1, tab2)
+
+ @unittest.skipUnless(np is not None, "Cannot test reading as numpy without numpy.")
+ def testWriteReadNumpyTableLossless(self):
+ tab1 = _makeSimpleNumpyTable(include_multidim=True)
+
+ self.butler.put(tab1, self.datasetType, dataId={})
+
+ tab2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowNumpy")
+
+ _checkNumpyTableEquality(tab1, tab2)
+
+ @unittest.skipUnless(pa is not None, "Cannot test reading as arrow without pyarrow.")
+ def testWriteReadArrowTableLossless(self):
+ tab1 = _makeSimpleArrowTable(include_multidim=False, include_masked=True)
+
+ self.butler.put(tab1, self.datasetType, dataId={})
+
+ tab2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowTable")
+
+ self.assertEqual(tab1.schema, tab2.schema)
+ tab1_np = arrow_to_numpy(tab1)
+ tab2_np = arrow_to_numpy(tab2)
+ for col in tab1.column_names:
+ np.testing.assert_array_equal(tab2_np[col], tab1_np[col])
+
+ @unittest.skipUnless(np is not None, "Cannot test reading as numpy dict without numpy.")
+ def testWriteReadNumpyDictLossless(self):
+ tab1 = _makeSimpleNumpyTable(include_multidim=True)
+ dict1 = _numpy_to_numpy_dict(tab1)
+
+ self.butler.put(dict1, self.datasetType, dataId={})
+
+ dict2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowNumpyDict")
-@unittest.skipUnless(pd is not None, "Cannot test InMemoryDataFrameDelegate without pandas.")
+ _checkNumpyDictEquality(dict1, dict2)
+
+
+@unittest.skipUnless(pd is not None, "Cannot test InMemoryDatastore with DataFrames without pandas.")
class InMemoryDataFrameDelegateTestCase(ParquetFormatterDataFrameTestCase):
- """Tests for InMemoryDatastore, using DataFrameDelegate."""
+ """Tests for InMemoryDatastore, using ArrowTableDelegate with DataFrame."""
configFile = os.path.join(TESTDIR, "config/basic/butler-inmemory.yaml")
@@ -776,7 +808,7 @@ def testLegacyDataFrame(self):
def testBadInput(self):
df1, _ = _makeSingleIndexDataFrame()
- delegate = DataFrameDelegate("DataFrame")
+ delegate = ArrowTableDelegate("DataFrame")
with self.assertRaises(ValueError):
delegate.handleParameters(inMemoryDataset="not_a_dataframe")
@@ -832,7 +864,7 @@ def testAstropyTable(self):
self.butler.put(tab1, self.datasetType, dataId={})
# Read the whole Table.
tab2 = self.butler.get(self.datasetType, dataId={})
- self._checkAstropyTableEquality(tab1, tab2)
+ _checkAstropyTableEquality(tab1, tab2)
# Read the columns.
columns2 = self.butler.get(self.datasetType.componentTypeName("columns"), dataId={})
self.assertEqual(len(columns2), len(tab1.dtype.names))
@@ -846,15 +878,15 @@ def testAstropyTable(self):
self.assertEqual(schema, ArrowAstropySchema(tab1))
# Read just some columns a few different ways.
tab3 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": ["a", "c"]})
- self._checkAstropyTableEquality(tab1[("a", "c")], tab3)
+ _checkAstropyTableEquality(tab1[("a", "c")], tab3)
tab4 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": "a"})
- self._checkAstropyTableEquality(tab1[("a",)], tab4)
+ _checkAstropyTableEquality(tab1[("a",)], tab4)
tab5 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": ["index", "a"]})
- self._checkAstropyTableEquality(tab1[("index", "a")], tab5)
+ _checkAstropyTableEquality(tab1[("index", "a")], tab5)
tab6 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": "ddd"})
- self._checkAstropyTableEquality(tab1[("ddd",)], tab6)
+ _checkAstropyTableEquality(tab1[("ddd",)], tab6)
tab7 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": ["a", "a"]})
- self._checkAstropyTableEquality(tab1[("a",)], tab7)
+ _checkAstropyTableEquality(tab1[("a",)], tab7)
# Passing an unrecognized column should be a ValueError.
with self.assertRaises(ValueError):
self.butler.get(self.datasetType, dataId={}, parameters={"columns": ["e"]})
@@ -865,7 +897,7 @@ def testAstropyTableBigEndian(self):
self.butler.put(tab1, self.datasetType, dataId={})
# Read the whole Table.
tab2 = self.butler.get(self.datasetType, dataId={})
- self._checkAstropyTableEquality(tab1, tab2, has_bigendian=True)
+ _checkAstropyTableEquality(tab1, tab2, has_bigendian=True)
def testAstropyTableWithMetadata(self):
tab1 = _makeSimpleAstropyTable(include_multidim=True)
@@ -884,7 +916,7 @@ def testAstropyTableWithMetadata(self):
# Read the whole Table.
tab2 = self.butler.get(self.datasetType, dataId={})
# This will check that the metadata is equivalent as well.
- self._checkAstropyTableEquality(tab1, tab2)
+ _checkAstropyTableEquality(tab1, tab2)
def testArrowAstropySchema(self):
tab1 = _makeSimpleAstropyTable()
@@ -945,7 +977,7 @@ def testAstropyParquet(self):
tab2a = self.butler.get(self.datasetType, dataId={})
tab2b = self.butler.get("astropy_parquet", dataId={})
- self._checkAstropyTableEquality(tab2a, tab2b)
+ _checkAstropyTableEquality(tab2a, tab2b)
columns2a = self.butler.get(self.datasetType.componentTypeName("columns"), dataId={})
columns2b = self.butler.get("astropy_parquet.columns", dataId={})
@@ -971,7 +1003,7 @@ def testWriteAstropyReadAsArrowTable(self):
tab2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowTable")
tab2_astropy = arrow_to_astropy(tab2)
- self._checkAstropyTableEquality(tab1, tab2_astropy)
+ _checkAstropyTableEquality(tab1, tab2_astropy)
# Check reading the columns.
columns = tab2.schema.names
@@ -1056,7 +1088,7 @@ def testWriteSingleIndexDataFrameWithMaskedColsReadAsAstropyTable(self):
df1_tab = pandas_to_astropy(df1)
- self._checkAstropyTableEquality(df1_tab, tab2)
+ _checkAstropyTableEquality(df1_tab, tab2)
@unittest.skipUnless(np is not None, "Cannot test reading as numpy without numpy.")
def testWriteAstropyReadAsNumpyTable(self):
@@ -1068,7 +1100,7 @@ def testWriteAstropyReadAsNumpyTable(self):
# This is tricky because it loses the units.
tab2_astropy = atable.Table(tab2)
- self._checkAstropyTableEquality(tab1, tab2_astropy, skip_units=True)
+ _checkAstropyTableEquality(tab1, tab2_astropy, skip_units=True)
# Check reading the columns.
columns = list(tab2.dtype.names)
@@ -1095,7 +1127,7 @@ def testWriteAstropyReadAsNumpyDict(self):
# This is tricky because it loses the units.
tab2_astropy = atable.Table(tab2)
- self._checkAstropyTableEquality(tab1, tab2_astropy, skip_units=True)
+ _checkAstropyTableEquality(tab1, tab2_astropy, skip_units=True)
def testBadAstropyColumnParquet(self):
tab1 = _makeSimpleAstropyTable()
@@ -1120,52 +1152,12 @@ def testBadAstropyColumnParquet(self):
with self.assertRaises(RuntimeError):
self.butler.put(bad_tab, self.datasetType, dataId={})
- def _checkAstropyTableEquality(self, table1, table2, skip_units=False, has_bigendian=False):
- """Check if two astropy tables have the same columns/values.
- Parameters
- ----------
- table1 : `astropy.table.Table`
- table2 : `astropy.table.Table`
- skip_units : `bool`
- has_bigendian : `bool`
- """
- if not has_bigendian:
- self.assertEqual(table1.dtype, table2.dtype)
- else:
- for name in table1.dtype.names:
- # Only check type matches, force to little-endian.
- self.assertEqual(table1.dtype[name].newbyteorder(">"), table2.dtype[name].newbyteorder(">"))
-
- self.assertEqual(table1.meta, table2.meta)
- if not skip_units:
- for name in table1.columns:
- self.assertEqual(table1[name].unit, table2[name].unit)
- self.assertEqual(table1[name].description, table2[name].description)
- self.assertEqual(table1[name].format, table2[name].format)
- # We need to check masked/regular columns after filling.
- has_masked = False
- if isinstance(table1[name], atable.column.MaskedColumn):
- c1 = table1[name].filled()
- has_masked = True
- else:
- c1 = np.array(table1[name])
- if has_masked:
- self.assertIsInstance(table2[name], atable.column.MaskedColumn)
- c2 = table2[name].filled()
- else:
- self.assertFalse(isinstance(table2[name], atable.column.MaskedColumn))
- c2 = np.array(table2[name])
- np.testing.assert_array_equal(c1, c2)
- # If we have a masked column then we test the underlying data.
- if has_masked:
- np.testing.assert_array_equal(np.array(c1), np.array(c2))
- np.testing.assert_array_equal(table1[name].mask, table2[name].mask)
-
-
-@unittest.skipUnless(atable is not None, "Cannot test InMemoryArrowAstropyDelegate without astropy.")
+@unittest.skipUnless(atable is not None, "Cannot test InMemoryDatastore with AstropyTable without astropy.")
class InMemoryArrowAstropyDelegateTestCase(ParquetFormatterArrowAstropyTestCase):
- """Tests for InMemoryDatastore, using ArrowAstropyDelegate."""
+ """Tests for InMemoryDatastore, using ArrowTableDelegate with
+ AstropyTable.
+ """
configFile = os.path.join(TESTDIR, "config/basic/butler-inmemory.yaml")
@@ -1179,7 +1171,7 @@ def testBadAstropyColumnParquet(self):
def testBadInput(self):
tab1 = _makeSimpleAstropyTable()
- delegate = ArrowAstropyDelegate("ArrowAstropy")
+ delegate = ArrowTableDelegate("ArrowAstropy")
with self.assertRaises(ValueError):
delegate.handleParameters(inMemoryDataset="not_an_astropy_table")
@@ -1221,7 +1213,7 @@ def testNumpyTable(self):
self.butler.put(tab1, self.datasetType, dataId={})
# Read the whole Table.
tab2 = self.butler.get(self.datasetType, dataId={})
- self._checkNumpyTableEquality(tab1, tab2)
+ _checkNumpyTableEquality(tab1, tab2)
# Read the columns.
columns2 = self.butler.get(self.datasetType.componentTypeName("columns"), dataId={})
self.assertEqual(len(columns2), len(tab1.dtype.names))
@@ -1235,9 +1227,9 @@ def testNumpyTable(self):
self.assertEqual(schema, ArrowNumpySchema(tab1.dtype))
# Read just some columns a few different ways.
tab3 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": ["a", "c"]})
- self._checkNumpyTableEquality(tab1[["a", "c"]], tab3)
+ _checkNumpyTableEquality(tab1[["a", "c"]], tab3)
tab4 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": "a"})
- self._checkNumpyTableEquality(
+ _checkNumpyTableEquality(
tab1[
[
"a",
@@ -1246,9 +1238,9 @@ def testNumpyTable(self):
tab4,
)
tab5 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": ["index", "a"]})
- self._checkNumpyTableEquality(tab1[["index", "a"]], tab5)
+ _checkNumpyTableEquality(tab1[["index", "a"]], tab5)
tab6 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": "ddd"})
- self._checkNumpyTableEquality(
+ _checkNumpyTableEquality(
tab1[
[
"ddd",
@@ -1257,7 +1249,7 @@ def testNumpyTable(self):
tab6,
)
tab7 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": ["a", "a"]})
- self._checkNumpyTableEquality(
+ _checkNumpyTableEquality(
tab1[
[
"a",
@@ -1275,7 +1267,7 @@ def testNumpyTableBigEndian(self):
self.butler.put(tab1, self.datasetType, dataId={})
# Read the whole Table.
tab2 = self.butler.get(self.datasetType, dataId={})
- self._checkNumpyTableEquality(tab1, tab2, has_bigendian=True)
+ _checkNumpyTableEquality(tab1, tab2, has_bigendian=True)
def testArrowNumpySchema(self):
tab1 = _makeSimpleNumpyTable(include_multidim=True)
@@ -1317,7 +1309,7 @@ def testWriteNumpyTableReadAsArrowTable(self):
tab2_numpy = arrow_to_numpy(tab2)
- self._checkNumpyTableEquality(tab1, tab2_numpy)
+ _checkNumpyTableEquality(tab1, tab2_numpy)
# Check reading the columns.
columns = tab2.schema.names
@@ -1372,7 +1364,7 @@ def testWriteNumpyTableReadAsAstropyTable(self):
tab2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowAstropy")
tab2_numpy = tab2.as_array()
- self._checkNumpyTableEquality(tab1, tab2_numpy)
+ _checkNumpyTableEquality(tab1, tab2_numpy)
# Check reading the columns.
columns = list(tab2.columns.keys())
@@ -1397,7 +1389,7 @@ def testWriteNumpyTableReadAsNumpyDict(self):
tab2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowNumpyDict")
tab2_numpy = _numpy_dict_to_numpy(tab2)
- self._checkNumpyTableEquality(tab1, tab2_numpy)
+ _checkNumpyTableEquality(tab1, tab2_numpy)
def testBadNumpyColumnParquet(self):
tab1 = _makeSimpleAstropyTable()
@@ -1426,28 +1418,22 @@ def testBadNumpyColumnParquet(self):
with self.assertRaises(RuntimeError):
self.butler.put(bad_tab_np, self.datasetType, dataId={})
- def _checkNumpyTableEquality(self, table1, table2, has_bigendian=False):
- """Check if two numpy tables have the same columns/values
+ @unittest.skipUnless(atable is not None, "Cannot test reading as astropy without astropy.")
+ def testWriteReadAstropyTableLossless(self):
+ tab1 = _makeSimpleAstropyTable(include_multidim=True, include_masked=True)
- Parameters
- ----------
- table1 : `numpy.ndarray`
- table2 : `numpy.ndarray`
- has_bigendian : `bool`
- """
- self.assertEqual(table1.dtype.names, table2.dtype.names)
- for name in table1.dtype.names:
- if not has_bigendian:
- self.assertEqual(table1.dtype[name], table2.dtype[name])
- else:
- # Only check type matches, force to little-endian.
- self.assertEqual(table1.dtype[name].newbyteorder(">"), table2.dtype[name].newbyteorder(">"))
- self.assertTrue(np.all(table1 == table2))
+ self.butler.put(tab1, self.datasetType, dataId={})
+ tab2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowAstropy")
-@unittest.skipUnless(np is not None, "Cannot test ParquetFormatterArrowNumpy without numpy.")
+ _checkAstropyTableEquality(tab1, tab2)
+
+
+@unittest.skipUnless(np is not None, "Cannot test InMemoryDatastore with Numpy table without numpy.")
class InMemoryArrowNumpyDelegateTestCase(ParquetFormatterArrowNumpyTestCase):
- """Tests for InMemoryDatastore, using ArrowNumpyDelegate."""
+ """Tests for InMemoryDatastore, using ArrowTableDelegate with
+ Numpy table.
+ """
configFile = os.path.join(TESTDIR, "config/basic/butler-inmemory.yaml")
@@ -1457,7 +1443,7 @@ def testBadNumpyColumnParquet(self):
def testBadInput(self):
tab1 = _makeSimpleNumpyTable()
- delegate = ArrowNumpyDelegate("ArrowNumpy")
+ delegate = ArrowTableDelegate("ArrowNumpy")
with self.assertRaises(ValueError):
delegate.handleParameters(inMemoryDataset="not_a_numpy_table")
@@ -1674,12 +1660,12 @@ def testWriteArrowTableReadAsAstropyTable(self):
# Read back out as an astropy table.
tab2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowAstropy")
- self._checkAstropyTableEquality(tab1, tab2)
+ _checkAstropyTableEquality(tab1, tab2)
# Read back out as an arrow table, convert to astropy table.
atab3 = self.butler.get(self.datasetType, dataId={})
tab3 = arrow_to_astropy(atab3)
- self._checkAstropyTableEquality(tab1, tab3)
+ _checkAstropyTableEquality(tab1, tab3)
# Check reading the columns.
columns = list(tab2.columns.keys())
@@ -1717,12 +1703,12 @@ def testWriteArrowTableReadAsNumpyTable(self):
# Read back out as a numpy table.
tab2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowNumpy")
- self._checkNumpyTableEquality(tab1, tab2)
+ _checkNumpyTableEquality(tab1, tab2)
# Read back out as an arrow table, convert to numpy table.
atab3 = self.butler.get(self.datasetType, dataId={})
tab3 = arrow_to_numpy(atab3)
- self._checkNumpyTableEquality(tab1, tab3)
+ _checkNumpyTableEquality(tab1, tab3)
# Check reading the columns.
columns = list(tab2.dtype.names)
@@ -1746,55 +1732,20 @@ def testWriteArrowTableReadAsNumpyDict(self):
tab2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowNumpyDict")
tab2_numpy = _numpy_dict_to_numpy(tab2)
- self._checkNumpyTableEquality(tab1, tab2_numpy)
+ _checkNumpyTableEquality(tab1, tab2_numpy)
- def _checkAstropyTableEquality(self, table1, table2):
- """Check if two astropy tables have the same columns/values
+ @unittest.skipUnless(atable is not None, "Cannot test reading as astropy without astropy.")
+ def testWriteReadAstropyTableLossless(self):
+ tab1 = _makeSimpleAstropyTable(include_multidim=True, include_masked=True)
- Parameters
- ----------
- table1 : `astropy.table.Table`
- table2 : `astropy.table.Table`
- """
- self.assertEqual(table1.dtype, table2.dtype)
- for name in table1.columns:
- self.assertEqual(table1[name].unit, table2[name].unit)
- self.assertEqual(table1[name].description, table2[name].description)
- self.assertEqual(table1[name].format, table2[name].format)
- # We need to check masked/regular columns after filling.
- has_masked = False
- if isinstance(table1[name], atable.column.MaskedColumn):
- c1 = table1[name].filled()
- has_masked = True
- else:
- c1 = np.array(table1[name])
- if has_masked:
- self.assertIsInstance(table2[name], atable.column.MaskedColumn)
- c2 = table2[name].filled()
- else:
- self.assertFalse(isinstance(table2[name], atable.column.MaskedColumn))
- c2 = np.array(table2[name])
- np.testing.assert_array_equal(c1, c2)
- # If we have a masked column then we test the underlying data.
- if has_masked:
- np.testing.assert_array_equal(np.array(c1), np.array(c2))
- np.testing.assert_array_equal(table1[name].mask, table2[name].mask)
-
- def _checkNumpyTableEquality(self, table1, table2):
- """Check if two numpy tables have the same columns/values
-
- Parameters
- ----------
- table1 : `numpy.ndarray`
- table2 : `numpy.ndarray`
- """
- self.assertEqual(table1.dtype.names, table2.dtype.names)
- for name in table1.dtype.names:
- self.assertEqual(table1.dtype[name], table2.dtype[name])
- self.assertTrue(np.all(table1 == table2))
+ self.butler.put(tab1, self.datasetType, dataId={})
+
+ tab2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowAstropy")
+ _checkAstropyTableEquality(tab1, tab2)
-@unittest.skipUnless(pa is not None, "Cannot test InMemoryArrowTableDelegate without pyarrow.")
+
+@unittest.skipUnless(pa is not None, "Cannot test InMemoryDatastore with ArrowTable without pyarrow.")
class InMemoryArrowTableDelegateTestCase(ParquetFormatterArrowTableTestCase):
"""Tests for InMemoryDatastore, using ArrowTableDelegate."""
@@ -1861,7 +1812,7 @@ def testNumpyDict(self):
self.butler.put(dict1, self.datasetType, dataId={})
# Read the whole table.
dict2 = self.butler.get(self.datasetType, dataId={})
- self._checkNumpyDictEquality(dict1, dict2)
+ _checkNumpyDictEquality(dict1, dict2)
# Read the columns.
columns2 = self.butler.get(self.datasetType.componentTypeName("columns"), dataId={})
self.assertEqual(len(columns2), len(dict1.keys()))
@@ -1876,19 +1827,19 @@ def testNumpyDict(self):
# Read just some columns a few different ways.
tab3 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": ["a", "c"]})
subdict = {key: dict1[key] for key in ["a", "c"]}
- self._checkNumpyDictEquality(subdict, tab3)
+ _checkNumpyDictEquality(subdict, tab3)
tab4 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": "a"})
subdict = {key: dict1[key] for key in ["a"]}
- self._checkNumpyDictEquality(subdict, tab4)
+ _checkNumpyDictEquality(subdict, tab4)
tab5 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": ["index", "a"]})
subdict = {key: dict1[key] for key in ["index", "a"]}
- self._checkNumpyDictEquality(subdict, tab5)
+ _checkNumpyDictEquality(subdict, tab5)
tab6 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": "ddd"})
subdict = {key: dict1[key] for key in ["ddd"]}
- self._checkNumpyDictEquality(subdict, tab6)
+ _checkNumpyDictEquality(subdict, tab6)
tab7 = self.butler.get(self.datasetType, dataId={}, parameters={"columns": ["a", "a"]})
subdict = {key: dict1[key] for key in ["a"]}
- self._checkNumpyDictEquality(subdict, tab7)
+ _checkNumpyDictEquality(subdict, tab7)
# Passing an unrecognized column should be a ValueError.
with self.assertRaises(ValueError):
self.butler.get(self.datasetType, dataId={}, parameters={"columns": ["e"]})
@@ -1904,7 +1855,7 @@ def testWriteNumpyDictReadAsArrowTable(self):
tab2_dict = arrow_to_numpy_dict(tab2)
- self._checkNumpyDictEquality(dict1, tab2_dict)
+ _checkNumpyDictEquality(dict1, tab2_dict)
@unittest.skipUnless(pd is not None, "Cannot test reading as a dataframe without pandas.")
def testWriteNumpyDictReadAsDataFrame(self):
@@ -1934,7 +1885,7 @@ def testWriteNumpyDictReadAsAstropyTable(self):
tab2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowAstropy")
tab2_dict = _astropy_to_numpy_dict(tab2)
- self._checkNumpyDictEquality(dict1, tab2_dict)
+ _checkNumpyDictEquality(dict1, tab2_dict)
def testWriteNumpyDictReadAsNumpyTable(self):
tab1 = _makeSimpleNumpyTable(include_multidim=True)
@@ -1945,7 +1896,7 @@ def testWriteNumpyDictReadAsNumpyTable(self):
tab2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowNumpy")
tab2_dict = _numpy_to_numpy_dict(tab2)
- self._checkNumpyDictEquality(dict1, tab2_dict)
+ _checkNumpyDictEquality(dict1, tab2_dict)
def testWriteNumpyDictBad(self):
dict1 = {"a": 4, "b": np.ndarray([1])}
@@ -1964,24 +1915,23 @@ def testWriteNumpyDictBad(self):
with self.assertRaises(RuntimeError):
self.butler.put(dict4, self.datasetType, dataId={})
- def _checkNumpyDictEquality(self, dict1, dict2):
- """Check if two numpy dicts have the same columns/values.
+ @unittest.skipUnless(atable is not None, "Cannot test reading as astropy without astropy.")
+ def testWriteReadAstropyTableLossless(self):
+ tab1 = _makeSimpleAstropyTable(include_multidim=True, include_masked=True)
+
+ self.butler.put(tab1, self.datasetType, dataId={})
- Parameters
- ----------
- dict1 : `dict` [`str`, `np.ndarray`]
- dict2 : `dict` [`str`, `np.ndarray`]
- """
- self.assertEqual(set(dict1.keys()), set(dict2.keys()))
- for name in dict1:
- self.assertEqual(dict1[name].dtype, dict2[name].dtype)
- self.assertTrue(np.all(dict1[name] == dict2[name]))
+ tab2 = self.butler.get(self.datasetType, dataId={}, storageClass="ArrowAstropy")
+ _checkAstropyTableEquality(tab1, tab2)
-@unittest.skipUnless(np is not None, "Cannot test ParquetFormatterArrowNumpy without numpy.")
-@unittest.skipUnless(pa is not None, "Cannot test ParquetFormatterArrowNumpy without pyarrow.")
+
+@unittest.skipUnless(np is not None, "Cannot test InMemoryDatastore with NumpyDict without numpy.")
+@unittest.skipUnless(pa is not None, "Cannot test InMemoryDatastore with NumpyDict without pyarrow.")
class InMemoryNumpyDictDelegateTestCase(ParquetFormatterArrowNumpyDictTestCase):
- """Tests for InMemoryDatastore, using ArrowNumpyDictDelegate."""
+ """Tests for InMemoryDatastore, using ArrowTableDelegate with
+ Numpy dict.
+ """
configFile = os.path.join(TESTDIR, "config/basic/butler-inmemory.yaml")
@@ -2152,7 +2102,7 @@ def testWriteArrowSchemaReadAsArrowNumpySchema(self):
self.assertEqual(np_schema2, np_schema1)
-@unittest.skipUnless(pa is not None, "Cannot test InMemoryArrowSchemaDelegate without pyarrow.")
+@unittest.skipUnless(pa is not None, "Cannot test InMemoryDatastore with ArrowSchema without pyarrow.")
class InMemoryArrowSchemaDelegateTestCase(ParquetFormatterArrowSchemaTestCase):
"""Tests for InMemoryDatastore and ArrowSchema."""
@@ -2212,5 +2162,83 @@ def testRowGroupSizeDataFrameWithLists(self):
self.assertGreater(row_group_size, 1_000_000)
+def _checkAstropyTableEquality(table1, table2, skip_units=False, has_bigendian=False):
+ """Check if two astropy tables have the same columns/values.
+
+ Parameters
+ ----------
+ table1 : `astropy.table.Table`
+ table2 : `astropy.table.Table`
+ skip_units : `bool`
+ has_bigendian : `bool`
+ """
+ if not has_bigendian:
+ assert table1.dtype == table2.dtype
+ else:
+ for name in table1.dtype.names:
+ # Only check type matches, force to little-endian.
+ assert table1.dtype[name].newbyteorder(">") == table2.dtype[name].newbyteorder(">")
+
+ assert table1.meta == table2.meta
+ if not skip_units:
+ for name in table1.columns:
+ assert table1[name].unit == table2[name].unit
+ assert table1[name].description == table2[name].description
+ assert table1[name].format == table2[name].format
+
+ for name in table1.columns:
+ # We need to check masked/regular columns after filling.
+ has_masked = False
+ if isinstance(table1[name], atable.column.MaskedColumn):
+ c1 = table1[name].filled()
+ has_masked = True
+ else:
+ c1 = np.array(table1[name])
+ if has_masked:
+ assert isinstance(table2[name], atable.column.MaskedColumn)
+ c2 = table2[name].filled()
+ else:
+ assert not isinstance(table2[name], atable.column.MaskedColumn)
+ c2 = np.array(table2[name])
+ np.testing.assert_array_equal(c1, c2)
+ # If we have a masked column then we test the underlying data.
+ if has_masked:
+ np.testing.assert_array_equal(np.array(c1), np.array(c2))
+ np.testing.assert_array_equal(table1[name].mask, table2[name].mask)
+
+
+def _checkNumpyTableEquality(table1, table2, has_bigendian=False):
+ """Check if two numpy tables have the same columns/values
+
+ Parameters
+ ----------
+ table1 : `numpy.ndarray`
+ table2 : `numpy.ndarray`
+ has_bigendian : `bool`
+ """
+ assert table1.dtype.names == table2.dtype.names
+ for name in table1.dtype.names:
+ if not has_bigendian:
+ assert table1.dtype[name] == table2.dtype[name]
+ else:
+ # Only check type matches, force to little-endian.
+ assert table1.dtype[name].newbyteorder(">") == table2.dtype[name].newbyteorder(">")
+ assert np.all(table1 == table2)
+
+
+def _checkNumpyDictEquality(dict1, dict2):
+ """Check if two numpy dicts have the same columns/values.
+
+ Parameters
+ ----------
+ dict1 : `dict` [`str`, `np.ndarray`]
+ dict2 : `dict` [`str`, `np.ndarray`]
+ """
+ assert set(dict1.keys()) == set(dict2.keys())
+ for name in dict1:
+ assert dict1[name].dtype == dict2[name].dtype
+ assert np.all(dict1[name] == dict2[name])
+
+
if __name__ == "__main__":
unittest.main()