diff --git a/xarray_ms/backend/msv2/entrypoint.py b/xarray_ms/backend/msv2/entrypoint.py index 75107e2..6f44edb 100644 --- a/xarray_ms/backend/msv2/entrypoint.py +++ b/xarray_ms/backend/msv2/entrypoint.py @@ -77,9 +77,8 @@ def initialise_default_args( epoch: str | None, table_factory: TableFactory | None, partition_columns: List[str] | None, - partition_key: PartitionKeyT | None, structure_factory: MSv2StructureFactory | None, -) -> Tuple[str, TableFactory, List[str], PartitionKeyT, MSv2StructureFactory]: +) -> Tuple[str, TableFactory, List[str], MSv2StructureFactory]: """ Ensures consistency when initialising default arguments from multiple locations """ @@ -98,16 +97,7 @@ def initialise_default_args( structure_factory = structure_factory or MSv2StructureFactory( table_factory, partition_columns, auto_corrs=auto_corrs ) - structure = structure_factory() - if partition_key is None: - partition_key = next(iter(structure.keys())) - warnings.warn( - f"No partition_key was supplied. Selected first partition {partition_key}" - ) - elif partition_key not in structure: - raise ValueError(f"{partition_key} not in {list(structure.keys())}") - - return epoch, table_factory, partition_columns, partition_key, structure_factory + return epoch, table_factory, partition_columns, structure_factory class MSv2Store(AbstractWritableDataStore): @@ -164,7 +154,7 @@ def open( if not isinstance(ms, str): raise ValueError("Measurement Sets paths must be strings") - epoch, table_factory, partition_columns, partition_key, structure_factory = ( + epoch, table_factory, partition_columns, structure_factory = ( initialise_default_args( ms, ninstances, @@ -172,11 +162,20 @@ def open( epoch, None, partition_columns, - partition_key, structure_factory, ) ) + structure = structure_factory() + + if partition_key is None: + partition_key = next(iter(structure.keys())) + warnings.warn( + f"No partition_key was supplied. Selected first partition {partition_key}" + ) + elif partition_key not in structure: + raise ValueError(f"{partition_key} not in {list(structure.keys())}") + return cls( table_factory, structure_factory, @@ -332,8 +331,8 @@ def open_datatree( else: raise ValueError("Measurement Set paths must be strings") - epoch, _, partition_columns, _, structure_factory = initialise_default_args( - ms, ninstances, auto_corrs, epoch, None, partition_columns, None, None + epoch, _, partition_columns, structure_factory = initialise_default_args( + ms, ninstances, auto_corrs, epoch, None, partition_columns, None ) structure = structure_factory() @@ -341,7 +340,7 @@ def open_datatree( chunks = kwargs.pop("chunks", None) pchunks = promote_chunks(structure, chunks) - for i, partition_key in enumerate(structure): + for partition_key in structure: ds = xarray.open_dataset( ms, drop_variables=drop_variables, @@ -354,6 +353,8 @@ def open_datatree( chunks=None if pchunks is None else pchunks[partition_key], **kwargs, ) - datasets[str(i)] = ds + + key = ",".join(f"{k}={v}" for k, v in sorted(partition_key)) + datasets[key] = ds return DataTree.from_dict(datasets)