Skip to content

Commit

Permalink
Use partition keys as datatree node names
Browse files (browse the repository at this point in the history)
  • Loading branch information
sjperkins committed Sep 10, 2024
1 parent f201cd3 commit 2cd496c
Showing 1 changed file with 19 additions and 18 deletions.
37 changes: 19 additions & 18 deletions xarray_ms/backend/msv2/entrypoint.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -77,9 +77,8 @@ def initialise_default_args(
epoch: str | None,
table_factory: TableFactory | None,
partition_columns: List[str] | None,
partition_key: PartitionKeyT | None,
structure_factory: MSv2StructureFactory | None,
) -> Tuple[str, TableFactory, List[str], PartitionKeyT, MSv2StructureFactory]:
) -> Tuple[str, TableFactory, List[str], MSv2StructureFactory]:
"""
Ensures consistency when initialising default arguments from multiple locations
"""
Expand All @@ -98,16 +97,7 @@ def initialise_default_args(
structure_factory = structure_factory or MSv2StructureFactory(
table_factory, partition_columns, auto_corrs=auto_corrs
)
structure = structure_factory()
if partition_key is None:
partition_key = next(iter(structure.keys()))
warnings.warn(
f"No partition_key was supplied. Selected first partition {partition_key}"
)
elif partition_key not in structure:
raise ValueError(f"{partition_key} not in {list(structure.keys())}")

return epoch, table_factory, partition_columns, partition_key, structure_factory
return epoch, table_factory, partition_columns, structure_factory


class MSv2Store(AbstractWritableDataStore):
Expand Down Expand Up @@ -164,19 +154,28 @@ def open(
if not isinstance(ms, str):
raise ValueError("Measurement Sets paths must be strings")

epoch, table_factory, partition_columns, partition_key, structure_factory = (
epoch, table_factory, partition_columns, structure_factory = (
initialise_default_args(
ms,
ninstances,
auto_corrs,
epoch,
None,
partition_columns,
partition_key,
structure_factory,
)
)

structure = structure_factory()

if partition_key is None:
partition_key = next(iter(structure.keys()))
warnings.warn(
f"No partition_key was supplied. Selected first partition {partition_key}"
)
elif partition_key not in structure:
raise ValueError(f"{partition_key} not in {list(structure.keys())}")

return cls(
table_factory,
structure_factory,
Expand Down Expand Up @@ -332,16 +331,16 @@ def open_datatree(
else:
raise ValueError("Measurement Set paths must be strings")

epoch, _, partition_columns, _, structure_factory = initialise_default_args(
ms, ninstances, auto_corrs, epoch, None, partition_columns, None, None
epoch, _, partition_columns, structure_factory = initialise_default_args(
ms, ninstances, auto_corrs, epoch, None, partition_columns, None
)

structure = structure_factory()
datasets = {}
chunks = kwargs.pop("chunks", None)
pchunks = promote_chunks(structure, chunks)

for i, partition_key in enumerate(structure):
for partition_key in structure:
ds = xarray.open_dataset(
ms,
drop_variables=drop_variables,
Expand All @@ -354,6 +353,8 @@ def open_datatree(
chunks=None if pchunks is None else pchunks[partition_key],
**kwargs,
)
datasets[str(i)] = ds

key = ",".join(f"{k}={v}" for k, v in sorted(partition_key))
datasets[key] = ds

return DataTree.from_dict(datasets)

0 comments on commit 2cd496c

Please sign in to comment.