From 5b4498ca1974bdffa509693f4488c93caa90e5da Mon Sep 17 00:00:00 2001 From: "David H. Irving" Date: Thu, 8 Aug 2024 16:51:03 -0700 Subject: [PATCH] Allow boolean columns to be used in where strings Fix an issue where boolean metadata columns (like exposure.can_see_sky and exposure.has_simulated) were not usable in "where" clauses for Registry query functions. These column names can now be used as a boolean expression, for example `where="exposure.can_see_sky"` or `where="NOT exposure.can_see_sky"`. --- doc/changes/DM-45680.bugfix.md | 4 ++ .../_sql_column_visitor.py | 5 +++ .../daf/butler/queries/_expression_strings.py | 9 +++- .../daf/butler/queries/tree/_predicate.py | 42 ++++++++++++++++++- python/lsst/daf/butler/queries/visitors.py | 24 +++++++++++ .../queries/expressions/_predicate.py | 7 +++- .../daf/butler/registry/tests/_registry.py | 42 +++++++++++++++++++ 7 files changed, 130 insertions(+), 3 deletions(-) create mode 100644 doc/changes/DM-45680.bugfix.md diff --git a/doc/changes/DM-45680.bugfix.md b/doc/changes/DM-45680.bugfix.md new file mode 100644 index 0000000000..a6471b9d4a --- /dev/null +++ b/doc/changes/DM-45680.bugfix.md @@ -0,0 +1,4 @@ +Fix an issue where boolean metadata columns (like `exposure.can_see_sky` and +`exposure.has_simulated`) were not usable in `where` clauses for Registry query +functions. These column names can now be used as a boolean expression, for +example `where="exposure.can_see_sky` or `where="NOT exposure.can_see_sky"`. diff --git a/python/lsst/daf/butler/direct_query_driver/_sql_column_visitor.py b/python/lsst/daf/butler/direct_query_driver/_sql_column_visitor.py index 23a2485f64..86dd8f93e6 100644 --- a/python/lsst/daf/butler/direct_query_driver/_sql_column_visitor.py +++ b/python/lsst/daf/butler/direct_query_driver/_sql_column_visitor.py @@ -132,6 +132,11 @@ def visit_reversed(self, expression: qt.Reversed) -> sqlalchemy.ColumnElement[An # Docstring inherited. return self.expect_scalar(expression.operand).desc() + def visit_boolean_wrapper( + self, value: qt.ColumnExpression, flags: PredicateVisitFlags + ) -> sqlalchemy.ColumnElement[bool]: + return self.expect_scalar(value) + def visit_comparison( self, a: qt.ColumnExpression, diff --git a/python/lsst/daf/butler/queries/_expression_strings.py b/python/lsst/daf/butler/queries/_expression_strings.py index dbc4694763..cc0cc755cf 100644 --- a/python/lsst/daf/butler/queries/_expression_strings.py +++ b/python/lsst/daf/butler/queries/_expression_strings.py @@ -198,7 +198,14 @@ def visitIdentifier(self, name: str, node: Node) -> _VisitorResult: if categorizeConstant(name) == ExpressionConstant.NULL: return _Null() - return _ColExpr(interpret_identifier(self.context, name)) + column_expression = interpret_identifier(self.context, name) + if column_expression.column_type == "bool": + # Expression-handling code (in this file and elsewhere) expects + # boolean-valued expressions to be represented as Predicate, not a + # column expression. + return Predicate.from_bool_expression(column_expression) + else: + return _ColExpr(column_expression) def visitNumericLiteral(self, value: str, node: Node) -> _VisitorResult: numeric: int | float diff --git a/python/lsst/daf/butler/queries/tree/_predicate.py b/python/lsst/daf/butler/queries/tree/_predicate.py index 9525411e8a..8335c81e3a 100644 --- a/python/lsst/daf/butler/queries/tree/_predicate.py +++ b/python/lsst/daf/butler/queries/tree/_predicate.py @@ -155,6 +155,26 @@ def from_bool(cls, value: bool) -> Predicate: # return cls.model_construct(operands=() if value else ((),)) + @classmethod + def from_bool_expression(cls, value: ColumnExpression) -> Predicate: + """Construct a predicate that wraps a boolean ColumnExpression, taking + on the value of the underlying ColumnExpression. + + Parameters + ---------- + value : `ColumnExpression` + Boolean-valued expression to convert to Predicate. + + Returns + ------- + predicate : `Predicate` + Predicate representing the expression. + """ + if value.column_type != "bool": + raise ValueError(f"ColumnExpression must have column type 'bool', not '{value.column_type}'") + + return cls._from_leaf(BooleanWrapper(operand=value)) + @classmethod def compare(cls, a: ColumnExpression, operator: ComparisonOperator, b: ColumnExpression) -> Predicate: """Construct a predicate representing a binary comparison between @@ -412,6 +432,26 @@ def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlag return visitor._visit_logical_not(self.operand, flags) +class BooleanWrapper(PredicateLeafBase): + """Pass-through to a pre-existing boolean column expression.""" + + predicate_type: Literal["boolean_wrapper"] = "boolean_wrapper" + + operand: ColumnExpression + """Wrapped expression that will be used as the value for this predicate.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + self.operand.gather_required_columns(columns) + + def __str__(self) -> str: + return f"{self.operand}" + + def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L: + # Docstring inherited. + return visitor.visit_boolean_wrapper(self.operand, flags) + + @final class IsNull(PredicateLeafBase): """A boolean column expression that tests whether its operand is NULL.""" @@ -639,7 +679,7 @@ def _validate_column_types(self) -> InQuery: return self -LogicalNotOperand: TypeAlias = IsNull | Comparison | InContainer | InRange | InQuery +LogicalNotOperand: TypeAlias = IsNull | Comparison | InContainer | InRange | InQuery | BooleanWrapper PredicateLeaf: TypeAlias = Annotated[ LogicalNotOperand | LogicalNot, pydantic.Field(discriminator="predicate_type") ] diff --git a/python/lsst/daf/butler/queries/visitors.py b/python/lsst/daf/butler/queries/visitors.py index 8340389e19..5e77161d15 100644 --- a/python/lsst/daf/butler/queries/visitors.py +++ b/python/lsst/daf/butler/queries/visitors.py @@ -197,6 +197,25 @@ class PredicateVisitor(Generic[_A, _O, _L]): visit method arguments. """ + @abstractmethod + def visit_boolean_wrapper(self, value: tree.ColumnExpression, flags: PredicateVisitFlags) -> _L: + """Visit a boolean-valued column expression. + + Parameters + ---------- + value : `tree.ColumnExpression` + Column expression, guaranteed to have `column_type == "bool"`. + flags : `PredicateVisitFlags` + Information about where this leaf appears in the larger predicate + tree. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + @abstractmethod def visit_comparison( self, @@ -448,6 +467,11 @@ class SimplePredicateVisitor( return a replacement `Predicate` to construct a new tree. """ + def visit_boolean_wrapper( + self, value: tree.ColumnExpression, flags: PredicateVisitFlags + ) -> tree.Predicate | None: + return None + def visit_comparison( self, a: tree.ColumnExpression, diff --git a/python/lsst/daf/butler/registry/queries/expressions/_predicate.py b/python/lsst/daf/butler/registry/queries/expressions/_predicate.py index bd38773e8f..9c5d611cfb 100644 --- a/python/lsst/daf/butler/registry/queries/expressions/_predicate.py +++ b/python/lsst/daf/butler/registry/queries/expressions/_predicate.py @@ -390,7 +390,12 @@ def visitIdentifier(self, name: str, node: Node) -> VisitorResult: if column == timespan_database_representation.TimespanDatabaseRepresentation.NAME else element.RecordClass.fields.standard[column].getPythonType() ) - return ColumnExpression.reference(tag, dtype) + if dtype is bool: + # ColumnExpression is for non-boolean columns only. Booleans + # are represented as Predicate. + return Predicate.reference(tag) + else: + return ColumnExpression.reference(tag, dtype) else: tag = DimensionKeyColumnTag(element.name) assert isinstance(element, Dimension) diff --git a/python/lsst/daf/butler/registry/tests/_registry.py b/python/lsst/daf/butler/registry/tests/_registry.py index f60e4fc151..8b9c9444af 100644 --- a/python/lsst/daf/butler/registry/tests/_registry.py +++ b/python/lsst/daf/butler/registry/tests/_registry.py @@ -4133,3 +4133,45 @@ def test_collection_summary(self) -> None: # Note that instrument governor resurrects here, even though there are # no datasets left with that governor. self.assertEqual(summary.governors, {"instrument": {"Cam1"}, "skymap": {"SkyMap1"}}) + + def test_query_where_string_boolean_expressions(self) -> None: + """Test that 'where' clauses for queries return the expected results + for boolean columns used as expressions. + """ + registry = self.makeRegistry() + # Exposure is the only dimension that has boolean columns, and this set + # of data has all the pre-requisites for exposure set up. + self.loadData(registry, "hsc-rc2-subset.yaml") + base_data = {"instrument": "HSC", "physical_filter": "HSC-R", "group": "903342", "day_obs": 20130617} + + TRUE_ID_1 = 1001 + TRUE_ID_2 = 2001 + FALSE_ID_1 = 1002 + FALSE_ID_2 = 2002 + records = [ + {"id": TRUE_ID_1, "obs_id": "true-1", "can_see_sky": True}, + {"id": TRUE_ID_2, "obs_id": "true-2", "can_see_sky": True}, + {"id": FALSE_ID_1, "obs_id": "false-1", "can_see_sky": False}, + {"id": FALSE_ID_2, "obs_id": "false-2", "can_see_sky": False}, + # There is also a record ID 903342 from the YAML file with a NULL + # value for can_see_sky. + ] + for record in records: + registry.insertDimensionData("exposure", base_data | record) + + def _run_query(where: str) -> list[str]: + result = list(registry.queryDimensionRecords("exposure", where=where, instrument="HSC")) + return [x.dataId["exposure"] for x in result] + + # Boolean columns should be usable standalone as an expression. + self.assertCountEqual(_run_query("exposure.can_see_sky"), [TRUE_ID_1, TRUE_ID_2]) + + # You can find false values in the column with NOT. The NOT of NULL + # is NULL, consistent with SQL semantics -- so records with NULL + # can_see_sky are not included here. + self.assertCountEqual(_run_query("NOT exposure.can_see_sky"), [FALSE_ID_1, FALSE_ID_2]) + + # Make sure the bare column composes with other expressions correctly. + self.assertCountEqual( + _run_query("exposure.can_see_sky OR exposure = 1002"), [TRUE_ID_1, TRUE_ID_2, FALSE_ID_1] + )