Skip to content

Commit

Permalink
Add limited support for non-common skypix in expressions (DM-44362)
Browse files Browse the repository at this point in the history
First attempt to fix skypix constraint in DataId. Non-common skypix failed
in the new query system, adding a special visitor that rewrites query
predicate in terms of common skypix. The fix is far from perfect though,
a general case is not very optimal, and htm-specific optimization
depends too much on htm internals. This fixes test_skypix_constraint_queries
unit test but not completely. The test with non-common skypix now works, but
the test using three-way join with common htm7 still fails.
  • Loading branch information
andy-slac committed Jul 8, 2024
1 parent f00c9fd commit fd3d28a
Show file tree
Hide file tree
Showing 2 changed files with 198 additions and 2 deletions.
28 changes: 26 additions & 2 deletions python/lsst/daf/butler/direct_query_driver/_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from .. import ddl
from .._dataset_type import DatasetType
from .._exceptions import InvalidQueryError
from ..dimensions import DataCoordinate, DataIdValue, DimensionGroup, DimensionUniverse
from ..dimensions import DataCoordinate, DataIdValue, DimensionGroup, DimensionUniverse, SkyPixDimension
from ..dimensions.record_cache import DimensionRecordCache
from ..queries import tree as qt
from ..queries.driver import (
Expand Down Expand Up @@ -80,6 +80,7 @@
ResultPageConverter,
ResultPageConverterContext,
)
from ._skypix_visitor import SkyPixRewriteVisitor
from ._sql_column_visitor import SqlColumnVisitor

if TYPE_CHECKING:
Expand Down Expand Up @@ -484,7 +485,7 @@ def analyze_query(
Column expressions to sort by.
find_first_dataset : `str` or `None`, optional
Name of a dataset type for which only one result row for each data
ID should be returned, with the colletions searched in order.
ID should be returned, with the collections searched in order.
Returns
-------
Expand Down Expand Up @@ -874,6 +875,18 @@ def _analyze_query_tree(self, tree: qt.QueryTree) -> tuple[QueryJoinsPlan, Query
potentially included, with the remainder still present in
`QueryJoinPlans.predicate`.
"""
skypix_visitor = SkyPixRewriteVisitor(tree.dimensions.universe)
if predicate := tree.predicate.visit(skypix_visitor):
# Rewritten predicate, we also want to update tree dimensions to
# remove non-common skypix dimensions.
dimensions = tree.dimensions
common_skypix = tree.dimensions.universe.commonSkyPix
if dimensions.skypix - {common_skypix.name}:
names = dimensions.names - tree.dimensions.skypix
names |= {common_skypix.name}
dimensions = DimensionGroup(dimensions.universe, names)
tree = tree.model_copy(update=dict(predicate=predicate, dimensions=dimensions))

# Delegate to the dimensions manager to rewrite the predicate and start
# a QueryBuilder to cover any spatial overlap joins or constraints.
# We'll return that QueryBuilder at the end.
Expand All @@ -886,6 +899,17 @@ def _analyze_query_tree(self, tree: qt.QueryTree) -> tuple[QueryJoinsPlan, Query
tree.get_joined_dimension_groups(),
)
result = QueryJoinsPlan(predicate=predicate, columns=builder.columns)

# Add spatial constraints from SkyPix visitor for every spatial
# dimension.
if skypix_visitor.region_constraints:
for element_name in tree.dimensions.elements:
element = tree.dimensions.universe[element_name]
if element.spatial and not isinstance(element, SkyPixDimension):
builder.postprocessing.spatial_where_filtering.extend(
(element, region) for region in skypix_visitor.region_constraints
)

# Add columns required by postprocessing.
builder.postprocessing.gather_columns_required(result.columns)
# We also check that the predicate doesn't reference any dimensions
Expand Down
172 changes: 172 additions & 0 deletions python/lsst/daf/butler/direct_query_driver/_skypix_visitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ["SkyPixRewriteVisitor"]

from typing import Any

from lsst.sphgeom import Region

from ..dimensions import DimensionUniverse, SkyPixDimension
from ..queries import tree as qt
from ..queries.tree._column_literal import IntColumnLiteral
from ..queries.visitors import ColumnExpressionVisitor, PredicateVisitFlags, SimplePredicateVisitor


class SkyPixRewriteVisitor(
SimplePredicateVisitor,
ColumnExpressionVisitor[tuple[SkyPixDimension, None] | tuple[None, Any] | tuple[None, None]],
):
"""A predicate visitor that rewrites skypix constraints that use non-common
skypix.
Parameters
----------
universe : `DimensionUniverse`
Dimension universe.
"""

def __init__(self, universe: DimensionUniverse):
self.universe = universe
self._common_skypix = universe.commonSkyPix
self.region_constraints: list[Region] = []

def visit_comparison(
self,
a: qt.ColumnExpression,
operator: qt.ComparisonOperator,
b: qt.ColumnExpression,
flags: PredicateVisitFlags,
) -> qt.Predicate | None:
if flags & PredicateVisitFlags.HAS_OR_SIBLINGS:
return None
if flags & PredicateVisitFlags.INVERTED:
if operator == "!=":
operator = "=="
else:
return None
if operator == "==":
k_a, v_a = a.visit(self)
k_b, v_b = b.visit(self)
if k_a is not None and v_b is not None:
skypix_dimension = k_a
value = v_b
elif k_b is not None and v_a is not None:
skypix_dimension = k_b
value = v_a

Check warning on line 82 in python/lsst/daf/butler/direct_query_driver/_skypix_visitor.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/daf/butler/direct_query_driver/_skypix_visitor.py#L81-L82

Added lines #L81 - L82 were not covered by tests
else:
return None

if skypix_dimension == self._common_skypix:
# Common skypix should be handled properly, no need to rewrite.
return None

predicate: qt.Predicate | None = None
region: Region | None = None
if skypix_dimension.system.name == "htm" and self._common_skypix.system.name == "htm":
# In case of HTM we can do some things in more optimal way.
# TODO: This depends on HTM index mapping, maybe we should add
# this facility to sphgeom classes.
if skypix_dimension.level < self._common_skypix.level:
# In case of more coarse skypix we can just replace
# equality with a range constraint on a common skypix.
level_shift = (self._common_skypix.level - skypix_dimension.level) * 2
begin, end = (value << level_shift, ((value + 1) << level_shift) - 1)
predicate = qt.Predicate.in_range(

Check warning on line 101 in python/lsst/daf/butler/direct_query_driver/_skypix_visitor.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/daf/butler/direct_query_driver/_skypix_visitor.py#L99-L101

Added lines #L99 - L101 were not covered by tests
qt.DimensionKeyReference.model_construct(dimension=self._common_skypix), begin, end
)
else:
# In case of a finer HTM we want to constraint on a common
# skypix and add post-processing filter for its region.
level_shift = (skypix_dimension.level - self._common_skypix.level) * 2
common_index = value >> level_shift
predicate = qt.Predicate.compare(
qt.DimensionKeyReference.model_construct(dimension=self._common_skypix),
"==",
IntColumnLiteral.model_construct(value=common_index),
)
region = skypix_dimension.pixelization.pixel(value)
else:
# More general case will use an envelope around the pixel
# region, not super efficient.
region = skypix_dimension.pixelization.pixel(value)

Check warning on line 118 in python/lsst/daf/butler/direct_query_driver/_skypix_visitor.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/daf/butler/direct_query_driver/_skypix_visitor.py#L118

Added line #L118 was not covered by tests
# Try to limit the number of ranges, as it probably does not
# help to have super-precise envelope.
envelope = self._common_skypix.pixelization.envelope(region, 64)
predicates: list[qt.Predicate] = []

Check warning on line 122 in python/lsst/daf/butler/direct_query_driver/_skypix_visitor.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/daf/butler/direct_query_driver/_skypix_visitor.py#L121-L122

Added lines #L121 - L122 were not covered by tests
for begin, end in envelope:
if begin == end:
predicates.append(

Check warning on line 125 in python/lsst/daf/butler/direct_query_driver/_skypix_visitor.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/daf/butler/direct_query_driver/_skypix_visitor.py#L125

Added line #L125 was not covered by tests
qt.Predicate.compare(
qt.DimensionKeyReference.model_construct(dimension=self._common_skypix),
"==",
IntColumnLiteral.model_construct(value=begin),
)
)
else:
predicates.append(

Check warning on line 133 in python/lsst/daf/butler/direct_query_driver/_skypix_visitor.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/daf/butler/direct_query_driver/_skypix_visitor.py#L133

Added line #L133 was not covered by tests
qt.Predicate.in_range(
qt.DimensionKeyReference.model_construct(dimension=self._common_skypix),
begin,
end,
)
)
predicate = qt.Predicate.from_bool(False).logical_or(*predicates)

Check warning on line 140 in python/lsst/daf/butler/direct_query_driver/_skypix_visitor.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/daf/butler/direct_query_driver/_skypix_visitor.py#L140

Added line #L140 was not covered by tests

if region is not None:
self.region_constraints.append(region)
return predicate

return None

def visit_binary_expression(self, expression: qt.BinaryExpression) -> tuple[None, None]:
return None, None

def visit_unary_expression(self, expression: qt.UnaryExpression) -> tuple[None, None]:
return None, None

def visit_literal(self, expression: qt.ColumnLiteral) -> tuple[None, Any]:
return None, expression.get_literal_value()

def visit_dimension_key_reference(
self, expression: qt.DimensionKeyReference
) -> tuple[SkyPixDimension, None] | tuple[None, None]:
if isinstance(expression.dimension, SkyPixDimension):
return expression.dimension, None
else:
return None, None

def visit_dimension_field_reference(self, expression: qt.DimensionFieldReference) -> tuple[None, None]:
return None, None

def visit_dataset_field_reference(self, expression: qt.DatasetFieldReference) -> tuple[None, None]:
return None, None

Check warning on line 169 in python/lsst/daf/butler/direct_query_driver/_skypix_visitor.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/daf/butler/direct_query_driver/_skypix_visitor.py#L169

Added line #L169 was not covered by tests

def visit_reversed(self, expression: qt.Reversed) -> tuple[None, None]:
raise AssertionError("No Reversed expressions in predicates.")

0 comments on commit fd3d28a

Please sign in to comment.