diff --git a/python/lsst/daf/butler/registry/queries/_query.py b/python/lsst/daf/butler/registry/queries/_query.py
index ea8483831f..069657afea 100644
--- a/python/lsst/daf/butler/registry/queries/_query.py
+++ b/python/lsst/daf/butler/registry/queries/_query.py
@@ -22,6 +22,7 @@
 
 __all__ = ()
 
+import itertools
 from collections.abc import Iterable, Iterator, Mapping, Sequence, Set
 from contextlib import contextmanager
 from typing import Any, cast, final
@@ -648,9 +649,6 @@ def find_datasets(
         lsst.daf.relation.ColumnError
             Raised if a dataset search is already present in this query and
             this is a find-first search.
-        ValueError
-            Raised if the given dataset type's dimensions are not a subset of
-            the current query's dimensions.
         """
         if find_first and DatasetColumnTag.filter_from(self._relation.columns):
             raise ColumnError(
@@ -680,14 +678,6 @@ def find_datasets(
         # where we materialize the initial data ID query into a temp table
         # and hence can't go back and "recover" those dataset columns anyway;
         #
-        if not (dataset_type.dimensions <= self._dimensions):
-            raise ValueError(
-                "Cannot find datasets from a query unless the dataset types's dimensions "
-                f"({dataset_type.dimensions}, for {dataset_type.name}) are a subset of the query's "
-                f"({self._dimensions})."
-            )
-        columns = set(columns)
-        columns.add("dataset_id")
         collections = CollectionWildcard.from_expression(collections)
         if find_first:
             collections.require_ordered()
@@ -699,27 +689,81 @@ def find_datasets(
             allow_calibration_collections=True,
             rejections=rejections,
         )
+        # If the dataset type has dimensions not in the current query, or we
+        # need a temporal join for a calibration collection, either restore
+        # those columns or join them in.
+        full_dimensions = dataset_type.dimensions.union(self._dimensions)
         relation = self._relation
+        record_caches = self._record_caches
+        base_columns_required: set[ColumnTag] = {
+            DimensionKeyColumnTag(name) for name in full_dimensions.names
+        }
+        spatial_joins: list[tuple[str, str]] = []
+        if not (dataset_type.dimensions <= self._dimensions):
+            if self._has_record_columns is True:
+                # This query is for expanded data IDs, so if we add new
+                # dimensions to the query we need to be able to get records for
+                # the new dimensions.
+                record_caches = dict(self._record_caches)
+                for element in self.dimensions.elements:
+                    if element in record_caches:
+                        continue
+                    if (
+                        cache := self._backend.get_dimension_record_cache(element.name, self._context)
+                    ) is not None:
+                        record_caches[element] = cache
+                    else:
+                        base_columns_required.update(element.RecordClass.fields.columns.keys())
+            # See if we need spatial joins between the current query and the
+            # dataset type's dimensions. The logic here is for multiple
+            # spatial joins in general, but in practice it'll be exceedingly
+            # rare for there to be more than one. We start by figuring out
+            # which spatial "families" (observations vs. skymaps, skypix
+            # systems) are present on only one side and not the other.
+            lhs_spatial_families = self._dimensions.spatial - dataset_type.dimensions.spatial
+            rhs_spatial_families = dataset_type.dimensions.spatial - self._dimensions.spatial
+            # Now we iterate over the Cartesian product of those, so e.g.
+            # if the query has {tract, patch, visit} and the dataset type
+            # has {htm7} dimensions, the iterations of this loop
+            # correspond to: (skymap, htm), (observations, htm).
+            for lhs_spatial_family, rhs_spatial_family in itertools.product(
+                lhs_spatial_families, rhs_spatial_families
+            ):
+                # For each pair we add a join between the most-precise element
+                # present in each family (e.g. patch beats tract).
+                spatial_joins.append(
+                    (
+                        lhs_spatial_family.choose(full_dimensions.elements).name,
+                        rhs_spatial_family.choose(full_dimensions.elements).name,
+                    )
+                )
+        # Set up any temporal join between the query dimensions and CALIBRATION
+        # collection's validity ranges.
         temporal_join_on: set[ColumnTag] = set()
         if any(r.type is CollectionType.CALIBRATION for r in collection_records):
             for family in self._dimensions.temporal:
-                element = family.choose(self._dimensions.elements)
-                temporal_join_on.add(DimensionRecordColumnTag(element.name, "timespan"))
-            timespan_columns_required = set(temporal_join_on)
-            relation, columns_found = self._context.restore_columns(self._relation, timespan_columns_required)
-            timespan_columns_required.difference_update(columns_found)
-            if timespan_columns_required:
-                relation = self._backend.make_dimension_relation(
-                    self._dimensions,
-                    timespan_columns_required,
-                    self._context,
-                    initial_relation=relation,
-                    # Don't permit joins to use any columns beyond those in the
-                    # original relation, as that would change what this
-                    # operation does.
-                    initial_join_max_columns=frozenset(self._relation.columns),
-                    governor_constraints=self._governor_constraints,
-                )
+                endpoint = family.choose(self._dimensions.elements)
+                temporal_join_on.add(DimensionRecordColumnTag(endpoint.name, "timespan"))
+            base_columns_required.update(temporal_join_on)
+        # Note which of the many kinds of potentially-missing columns we have
+        # and add the rest.
+        base_columns_required.difference_update(relation.columns)
+        if base_columns_required:
+            relation = self._backend.make_dimension_relation(
+                full_dimensions,
+                base_columns_required,
+                self._context,
+                initial_relation=relation,
+                # Don't permit joins to use any columns beyond those in the
+                # original relation, as that would change what this
+                # operation does.
+                initial_join_max_columns=frozenset(self._relation.columns),
+                governor_constraints=self._governor_constraints,
+                spatial_joins=spatial_joins,
+            )
+        # Finally we can join in the search for the dataset query.
+        columns = set(columns)
+        columns.add("dataset_id")
         if not collection_records:
             relation = relation.join(
                 self._backend.make_doomed_dataset_relation(dataset_type, columns, rejections, self._context)
@@ -742,7 +786,7 @@ def find_datasets(
             join_to=relation,
             temporal_join_on=temporal_join_on,
         )
-        return self._chain(relation, defer=defer)
+        return self._chain(relation, dimensions=full_dimensions, record_caches=record_caches, defer=defer)
 
     def sliced(
         self,
diff --git a/python/lsst/daf/butler/registry/queries/_results.py b/python/lsst/daf/butler/registry/queries/_results.py
index 5e5f924b6c..1cc958bd09 100644
--- a/python/lsst/daf/butler/registry/queries/_results.py
+++ b/python/lsst/daf/butler/registry/queries/_results.py
@@ -254,8 +254,6 @@ def findDatasets(
 
         Raises
         ------
-        ValueError
-            Raised if ``datasetType.dimensions.issubset(self.graph) is False``.
         MissingDatasetTypeError
             Raised if the given dataset type is not registered.
         """
@@ -314,12 +312,11 @@ def findRelatedDatasets(
 
         Raises
         ------
-        ValueError
-            Raised if ``datasetType.dimensions.issubset(self.graph) is False``
-            or ``dimensions.issubset(self.graph) is False``.
         MissingDatasetTypeError
             Raised if the given dataset type is not registered.
""" + if dimensions is None: + dimensions = self.graph parent_dataset_type, _ = self._query.backend.resolve_single_dataset_type_wildcard( datasetType, components=False, explicit_only=True ) diff --git a/python/lsst/daf/butler/registry/queries/_sql_query_backend.py b/python/lsst/daf/butler/registry/queries/_sql_query_backend.py index 6dea7b9c59..b3ea3ebe57 100644 --- a/python/lsst/daf/butler/registry/queries/_sql_query_backend.py +++ b/python/lsst/daf/butler/registry/queries/_sql_query_backend.py @@ -245,6 +245,10 @@ def make_dimension_relation( "it is part of a dataset subquery, spatial join, or other initial relation." ) + # Before joining in new tables to provide columns, attempt to restore + # them from the given relation by weakening projections applied to it. + relation, _ = context.restore_columns(relation, columns_required) + # Categorize columns not yet included in the relation to associate them # with dimension elements and detect bad inputs. missing_columns = ColumnCategorization.from_iterable(columns_required - relation.columns) diff --git a/python/lsst/daf/butler/registry/tests/_registry.py b/python/lsst/daf/butler/registry/tests/_registry.py index 5df6ca3630..35c7bc817e 100644 --- a/python/lsst/daf/butler/registry/tests/_registry.py +++ b/python/lsst/daf/butler/registry/tests/_registry.py @@ -1489,9 +1489,12 @@ def testQueryResults(self): expectedDeduplicatedBiases, ) - # Check dimensions match. - with self.assertRaises(ValueError): - subsetDataIds.findDatasets("flat", collections=["imported_r", "imported_g"], findFirst=True) + # Searching for a dataset with dimensions we had projected away + # restores those dimensions. + self.assertCountEqual( + list(subsetDataIds.findDatasets("flat", collections=["imported_r"], findFirst=True)), + expectedFlats, + ) # Use a component dataset type. self.assertCountEqual( @@ -3630,3 +3633,43 @@ def test_query_empty_collections(self) -> None: messages = list(result.explain_no_results()) self.assertTrue(messages) self.assertTrue(any("because collection list is empty" in message for message in messages)) + + def test_dataset_followup_spatial_joins(self) -> None: + """Test queryDataIds(...).findRelatedDatasets(...) where a spatial join + is involved. 
+ """ + registry = self.makeRegistry() + self.loadData(registry, "base.yaml") + self.loadData(registry, "spatial.yaml") + pvi_dataset_type = DatasetType( + "pvi", {"visit", "detector"}, storageClass="StructuredDataDict", universe=registry.dimensions + ) + registry.registerDatasetType(pvi_dataset_type) + collection = "datasets" + registry.registerRun(collection) + (pvi1,) = registry.insertDatasets( + pvi_dataset_type, [{"instrument": "Cam1", "visit": 1, "detector": 1}], run=collection + ) + (pvi2,) = registry.insertDatasets( + pvi_dataset_type, [{"instrument": "Cam1", "visit": 1, "detector": 2}], run=collection + ) + (pvi3,) = registry.insertDatasets( + pvi_dataset_type, [{"instrument": "Cam1", "visit": 1, "detector": 3}], run=collection + ) + self.assertEqual( + set( + registry.queryDataIds(["patch"], skymap="SkyMap1", tract=0).findRelatedDatasets( + "pvi", [collection] + ) + ), + { + (registry.expandDataId(skymap="SkyMap1", tract=0, patch=0), pvi1), + (registry.expandDataId(skymap="SkyMap1", tract=0, patch=0), pvi2), + (registry.expandDataId(skymap="SkyMap1", tract=0, patch=1), pvi2), + (registry.expandDataId(skymap="SkyMap1", tract=0, patch=2), pvi1), + (registry.expandDataId(skymap="SkyMap1", tract=0, patch=2), pvi2), + (registry.expandDataId(skymap="SkyMap1", tract=0, patch=2), pvi3), + (registry.expandDataId(skymap="SkyMap1", tract=0, patch=3), pvi2), + (registry.expandDataId(skymap="SkyMap1", tract=0, patch=4), pvi3), + }, + ) diff --git a/tests/data/registry/spatial.py b/tests/data/registry/spatial.py index d18ff224fb..393e63fdb5 100644 --- a/tests/data/registry/spatial.py +++ b/tests/data/registry/spatial.py @@ -252,7 +252,7 @@ def make_plots(detector_grid: bool, patch_grid: bool): index_labels(color="black", alpha=0.5), ) colors = iter(["red", "blue", "cyan", "green"]) - for (visit_id, visit_data), color in zip(VISIT_DATA.items(), colors, strict=True): + for (visit_id, visit_data), color in zip(VISIT_DATA.items(), colors, strict=False): for detector_id, pixel_indices in visit_data["detector_regions"].items(): label = f"visit={visit_id}" if label in labels_used: