Skip to content

Commit

Permalink
anndata- and df-based dataframe round-trip tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ryan-williams committed Aug 5, 2024
1 parent 20d769e commit b967eda
Showing 1 changed file with 80 additions and 25 deletions.
105 changes: 80 additions & 25 deletions apis/python/tests/test_dataframe_io_roundtrips.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from dataclasses import asdict, dataclass, fields
from inspect import getfullargspec
from os.path import join
from pathlib import Path
from typing import List, Optional, Tuple
Expand All @@ -13,8 +14,9 @@

from tiledbsoma import SOMA_JOINID, DataFrame, Experiment
from tiledbsoma.io._common import _DATAFRAME_ORIGINAL_INDEX_NAME_JSON
from tiledbsoma.io.ingest import from_anndata
from tiledbsoma.io.outgest import to_anndata
from tiledbsoma.io._registration import AxisIDMapping
from tiledbsoma.io.ingest import IngestionParams, _write_dataframe, from_anndata
from tiledbsoma.io.outgest import _read_dataframe, to_anndata


def parse_col(col_str: str) -> Tuple[Optional[str], List[str]]:
Expand Down Expand Up @@ -62,26 +64,37 @@ class RoundTrip:
persisted_metadata: Optional[str] = None
# Argument passed to `_write_dataframe` on ingest (default here matches `from_anndata`'s "obs" path)
ingest_id_column_name: Optional[str] = "obs_id"
# Argument passed to `_extract_pdf` on outgest (default here matches `to_anndata`'s "obs_id" path)
# Argument passed to `_read_dataframe` on outgest (default here matches `to_anndata`'s "obs_id" path)
outgest_default_index_name: Optional[str] = None
# Argument passed to `_read_dataframe` on outgest (default here matches `to_anndata`'s "obs_id" path)
outgest_fallback_index_name: Optional[str] = "obs_id"


def parametrize_roundtrips(*roundtrips: RoundTrip):
def parametrize_roundtrips(roundtrips: List[RoundTrip]):
def wrapper(fn):
names = [f.name for f in fields(RoundTrip)[1:]]
values = [list(asdict(rt).values()) for rt in roundtrips]
ids, values = zip(*([(vs[0], vs[1:]) for vs in values]))
# Test-case IDs
ids = [rt.name for rt in roundtrips]
# Convert `RoundTrip`s to "values" arrays, filtered and reordered to match kwargs expected by the wrapped
# function
fields_names = [f.name for f in fields(RoundTrip)]
spec = getfullargspec(fn)
names = [arg for arg in spec.args if arg in fields_names]
values = [
{name: rt_dict[name] for name in names}.values()
for rt_dict in [asdict(rt) for rt in roundtrips]
]
# Delegate to PyTest `parametrize`
return pytest.mark.parametrize(
names,
values,
ids=ids,
names, # arg names
values, # arg value lists
ids=ids, # test-case names
)(fn)

return wrapper


# fmt: off
@parametrize_roundtrips(
ROUND_TRIPS = [
RoundTrip(
'1. `df.index` named "index"',
make_df("index=xx,yy,zz", col0="aa,bb,cc", col1="AA,BB,CC"),
Expand Down Expand Up @@ -138,9 +151,30 @@ def wrapper(fn):
[ "idx", "obs_id", ],
"idx",
),
)
]
# fmt: on
def test_io_roundtrips(


def verify_metadata(
sdf: DataFrame, persisted_column_names: List[str], persisted_metadata: Optional[str]
):
# Verify column names and types
schema = sdf.schema
assert schema.names == [SOMA_JOINID, *persisted_column_names]
[soma_joinid_type, *string_col_types] = schema.types
assert soma_joinid_type == pa.int64() and schema.field(0).nullable is False
for string_col_type in string_col_types:
assert string_col_type == pa.large_string()

# Verify "original index metadata"
actual_index_metadata = json.loads(
sdf.metadata[_DATAFRAME_ORIGINAL_INDEX_NAME_JSON]
)
assert actual_index_metadata == persisted_metadata


@parametrize_roundtrips(ROUND_TRIPS)
def test_adata_io_roundtrips(
tmp_path: Path,
original_df: pd.DataFrame,
persisted_column_names: List[str],
Expand All @@ -161,21 +195,42 @@ def test_io_roundtrips(
# Verify column names and types
obs_uri = join(uri, "obs")
obs = DataFrame.open(obs_uri)
schema = obs.schema
assert schema.names == [SOMA_JOINID, *persisted_column_names]
[soma_joinid_type, *string_col_types] = schema.types
assert soma_joinid_type == pa.int64() and schema.field(0).nullable is False
for string_col_type in string_col_types:
assert string_col_type == pa.large_string()

# Verify "original index metadata"
actual_index_metadata = json.loads(
obs.metadata[_DATAFRAME_ORIGINAL_INDEX_NAME_JSON]
)
assert actual_index_metadata == persisted_metadata
verify_metadata(obs, persisted_column_names, persisted_metadata)

# Verify outgested pd.DataFrame
with Experiment.open(ingested_uri) as exp:
adata1 = to_anndata(exp, "meas", obs_id_name=outgest_default_index_name)
outgested_obs = adata1.obs
assert_frame_equal(outgested_obs, outgested_df)


@parametrize_roundtrips(ROUND_TRIPS)
def test_df_io_roundtrips(
tmp_path: Path,
original_df: pd.DataFrame,
persisted_column_names: List[str],
persisted_metadata: Optional[str],
ingest_id_column_name: Optional[str],
outgest_default_index_name: Optional[str],
outgest_fallback_index_name: Optional[str],
outgested_df: pd.DataFrame,
):
uri = str(tmp_path)
_write_dataframe(
uri,
original_df,
id_column_name=ingest_id_column_name,
axis_mapping=AxisIDMapping(data=tuple(range(len(original_df)))),
ingestion_params=IngestionParams("write", None),
).close()

sdf = DataFrame.open(uri)
verify_metadata(sdf, persisted_column_names, persisted_metadata)

# Verify outgested pd.DataFrame
actual_outgested_df = _read_dataframe(
sdf,
default_index_name=outgest_default_index_name,
fallback_index_name=outgest_fallback_index_name,
)
assert_frame_equal(actual_outgested_df, outgested_df)

0 comments on commit b967eda

Please sign in to comment.