Add some simple query interface tests
These replicate the advanced tests using the simple query
interface where possible.
timj committed Sep 4, 2024
1 parent 999b737 commit 536b3ff
Showing 2 changed files with 193 additions and 18 deletions.
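
For orientation: the "simple" interface exercised by this commit wraps the context-manager query system in one-shot `Butler` methods. A minimal sketch of the two equivalent styles, assuming a repo populated like the test fixtures (the repo path is hypothetical):

    from lsst.daf.butler import Butler

    butler = Butler.from_config("repo")  # hypothetical repo path

    # Advanced interface: explicit query context manager.
    with butler.query() as query:
        records = list(query.dimension_records("detector").order_by("detector"))

    # Simple interface: one call, no context manager needed.
    records = butler.query_dimension_records("detector", order_by="detector")
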
23 changes: 19 additions & 4 deletions python/lsst/daf/butler/_butler.py
@@ -1549,6 +1549,8 @@ def query_data_ids(
"""
if data_id is None:
data_id = DataCoordinate.make_empty(self.dimensions)
if order_by is None:
order_by = []
with self.query() as query:
result = (
query.where(data_id, where, bind=bind, **kwargs)
@@ -1573,6 +1575,8 @@ def query_datasets(
where: str = "",
bind: Mapping[str, Any] | None = None,
with_dimension_records: bool = False,
order_by: Iterable[str] | str | None = None,
limit: int = 20_000,
explain: bool = True,
**kwargs: Any,
) -> list[DatasetRef]:
@@ -1609,6 +1613,12 @@ def query_datasets(
with_dimension_records : `bool`, optional
If `True` (default is `False`) then returned data IDs will have
dimension records.
order_by : `~collections.abc.Iterable` [`str`] or `str`, optional
Names of the columns/dimensions to use for ordering the returned
datasets. A column name can be prefixed with a minus (``-``) for
descending ordering.
limit : `int`, optional
Upper limit on the number of returned records.
explain : `bool`, optional
If `True` (default) then an `EmptyQueryResultError` exception is
raised when the resulting list is empty. The exception contains
@@ -1654,11 +1664,14 @@ def query_datasets(
"""
if data_id is None:
data_id = DataCoordinate.make_empty(self.dimensions)
if order_by is None:
order_by = []
with self.query() as query:
result = query.where(data_id, where, bind=bind, **kwargs).datasets(
dataset_type,
collections=collections,
find_first=find_first,
result = (
query.where(data_id, where, bind=bind, **kwargs)
.datasets(dataset_type, collections=collections, find_first=find_first)
.order_by(*ensure_iterable(order_by))
.limit(limit)
)
if with_dimension_records:
result = result.with_dimension_records()
@@ -1738,6 +1751,8 @@ def query_dimension_records(
"""
if data_id is None:
data_id = DataCoordinate.make_empty(self.dimensions)
if order_by is None:
order_by = []
with self.query() as query:
result = (
query.where(data_id, where, bind=bind, **kwargs)
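
As the hunks above show, `query_datasets` gains `order_by` and `limit` arguments and simply threads them into the underlying result chain. A hedged sketch of the equivalence (argument values are illustrative, and `find_first=True` is assumed to be the default):

    from lsst.daf.butler import DataCoordinate

    # Simple form:
    refs = butler.query_datasets("bias", "imported_g", order_by="detector", limit=100)

    # Roughly what it expands to, following the implementation above:
    with butler.query() as query:
        result = (
            query.where(DataCoordinate.make_empty(butler.dimensions), "")
            .datasets("bias", collections="imported_g", find_first=True)
            .order_by("detector")
            .limit(100)
        )
        refs = list(result)
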
188 changes: 174 additions & 14 deletions python/lsst/daf/butler/tests/butler_queries.py
@@ -153,6 +153,14 @@ def check_detector_records(
self.assertFalse(results.any(exact=False, execute=False))
self.assertFalse(results.any(exact=True, execute=False))
self.assertCountEqual(results.explain_no_results(), list(messages))
self.check_detector_records_returned(list(results), ids=ids, ordered=ordered)

def check_detector_records_returned(
self,
results: list[DimensionRecord],
ids: Sequence[int] = (1, 2, 3, 4),
ordered: bool = False,
) -> None:
expected = [DETECTOR_TUPLES[i] for i in ids]
queried = list(make_detector_tuples(results).values())
if ordered:
@@ -173,33 +181,63 @@ def test_simple_record_query(self) -> None:
_x = query.expression_factory
results = query.dimension_records("detector")
self.check_detector_records(results)
self.check_detector_records_returned(butler.query_dimension_records("detector"))
self.check_detector_records(results.order_by("detector"), ordered=True)
self.check_detector_records_returned(
butler.query_dimension_records("detector", order_by="detector"), ordered=True
)
self.check_detector_records(
results.order_by(_x.detector.full_name.desc), [4, 3, 2, 1], ordered=True
)
self.check_detector_records_returned(
butler.query_dimension_records("detector", order_by="-full_name"),
ids=[4, 3, 2, 1],
ordered=True,
)
self.check_detector_records(results.order_by("detector").limit(2), [1, 2], ordered=True)
self.check_detector_records_returned(
butler.query_dimension_records("detector", limit=2, order_by="detector"),
ids=[1, 2],
ordered=True,
)
self.check_detector_records(results.where(_x.detector.raft == "B", instrument="Cam1"), [3, 4])
self.check_detector_records_returned(
butler.query_dimension_records(
"detector", where="detector.raft = R", bind={"R": "B"}, instrument="Cam1"
),
ids=[3, 4],
)

def test_simple_data_coordinate_query(self) -> None:
butler = self.make_butler("base.yaml")

expected_detectors = [1, 2, 3, 4]
universe = butler.dimensions
expected_coordinates = [
DataCoordinate.standardize({"instrument": "Cam1", "detector": x}, universe=universe)
for x in expected_detectors
]

with butler.query() as query:
# Test empty query
self.assertCountEqual(query.data_ids([]), [DataCoordinate.makeEmpty(butler.dimensions)])
empty = DataCoordinate.make_empty(butler.dimensions)
self.assertCountEqual(list(query.data_ids([])), [empty])
self.assertCountEqual(butler.query_data_ids([]), [empty])

# Test query for a single dimension
results = query.data_ids(["detector"])
expected_detectors = [1, 2, 3, 4]
universe = butler.dimensions
expected_coordinates = [
DataCoordinate.standardize({"instrument": "Cam1", "detector": x}, universe=universe)
for x in expected_detectors
]
self.assertCountEqual(list(results), expected_coordinates)

data_ids = butler.query_data_ids("detector")
self.assertCountEqual(data_ids, expected_coordinates)

def test_simple_dataset_query(self) -> None:
butler = self.make_butler("base.yaml", "datasets.yaml")
with butler.query() as query:
refs = list(query.datasets("bias", "imported_g").order_by("detector"))
refs_q = list(query.datasets("bias", "imported_g").order_by("detector"))
refs_simple = butler.query_datasets("bias", "imported_g", order_by="detector")

for refs in (refs_q, refs_simple):
self.assertEqual(len(refs), 3)
self.assertEqual(refs[0].id, UUID("e15ab039-bc8b-4135-87c5-90902a7c0b22"))
self.assertEqual(refs[1].id, UUID("51352db4-a47a-447c-b12d-a50b206b17cd"))
@@ -425,6 +463,14 @@ def test_implied_union_record_query(self) -> None:
list(query.where(physical_filter="Cam1-R1", instrument="Cam1").dimension_records("band")),
[band.RecordClass(name="r")],
)
self.assertCountEqual(
butler.query_dimension_records("band"),
[band.RecordClass(name="g"), band.RecordClass(name="r")],
)
self.assertCountEqual(
butler.query_dimension_records("band", physical_filter="Cam1-R1", instrument="Cam1"),
[band.RecordClass(name="r")],
)

def test_dataset_constrained_record_query(self) -> None:
"""Test a query for dimension records constrained by the existence of
@@ -531,6 +577,15 @@ def test_spatial_overlaps(self) -> None:
[1, 3, 4],
has_postprocessing=True,
)
self.check_detector_records_returned(
butler.query_dimension_records(
"detector",
where="visit_detector_region.region OVERLAPS region",
bind={"region": htm7.pixelization.pixel(253954)},
visit=1,
),
ids=[1, 3, 4],
)
# Query for detectors from a particular visit that overlap an htm7
# ID. This is basically the same query as the last one, but
# expressed as a spatial join, and we can recognize that
@@ -556,6 +611,14 @@ def test_spatial_overlaps(self) -> None:
[1, 3, 4],
has_postprocessing=False,
)
self.check_detector_records_returned(
butler.query_dimension_records(
"detector",
visit=1,
htm7=253954,
),
ids=[1, 3, 4],
)
# Query for the detectors from any visit that overlap a region:
# this gets contributions from multiple visits, and would have
# duplicates if we didn't get rid of them via GROUP BY.
@@ -566,6 +629,14 @@ def test_spatial_overlaps(self) -> None:
[1, 2, 3, 4],
has_postprocessing=True,
)
self.check_detector_records_returned(
butler.query_dimension_records(
"detector",
where="visit_detector_region.region OVERLAPS region",
bind={"region": htm7.pixelization.pixel(253954)},
),
ids=[1, 2, 3, 4],
)
# Once again we rewrite the region-constraint query as a spatial
# join, which drops the postprocessing. This join has to be
# explicit because `visit` no longer gets into the query dimensions
@@ -604,6 +675,14 @@ def test_spatial_overlaps(self) -> None:
[1, 2, 3],
has_postprocessing=True,
)
self.check_detector_records_returned(
butler.query_dimension_records(
"detector",
where="visit_detector_region.region OVERLAPS region",
bind={"region": patch_record.region},
),
ids=[1, 2, 3],
)
# Combine postprocessing with order_by and limit.
self.check_detector_records(
query.where(
@@ -615,6 +694,16 @@ def test_spatial_overlaps(self) -> None:
[3, 2],
has_postprocessing=True,
)
self.check_detector_records_returned(
butler.query_dimension_records(
"detector",
where="visit_detector_region.region OVERLAPS region",
bind={"region": patch_record.region},
order_by="-detector",
limit=2,
),
ids=[3, 2],
)
# Try a case where there are some records before postprocessing but
# none afterwards.
self.check_detector_records(
@@ -625,7 +714,16 @@ def test_spatial_overlaps(self) -> None:
[],
has_postprocessing=True,
)

self.check_detector_records_returned(
butler.query_dimension_records(
"detector",
where="visit_detector_region.region OVERLAPS region",
bind={"region": patch_record.region},
detector=4,
explain=False,
),
ids=[],
)
# Check spatial queries using points instead of regions.
# This (ra, dec) is a point in the center of the region for visit
# 1, detector 3.
@@ -744,6 +842,13 @@ def test_common_skypix_overlaps(self) -> None:
],
[253954, 253955],
)
self.assertCountEqual(
[
record.id
for record in butler.query_dimension_records("htm7", skymap="SkyMap1", tract=0, patch=4)
],
[253954, 253955],
)
# Constraint on the patch region (with the query not knowing it
# corresponds to that patch).
(patch,) = query.where(skymap="SkyMap1", tract=0, patch=4).dimension_records("patch")
@@ -774,6 +879,13 @@ def test_spatial_constraint_queries(self) -> None:
for data_id in query.data_ids(["patch"]).where({"instrument": "HSC", "visit": 318})
],
)
self.assertEqual(
[(9813, 72)],
[
(data_id["tract"], data_id["patch"])
for data_id in butler.query_data_ids(["patch"], instrument="HSC", visit=318)
],
)

# This tests the case where the 'patch' region is needed in
# postprocessing AND is also returned in the result rows.
@@ -789,6 +901,13 @@ def test_spatial_constraint_queries(self) -> None:
for record in query.dimension_records("patch").where({"instrument": "HSC", "visit": 318})
],
)
self.assertEqual(
[(9813, 72, region_hex)],
[
(record.tract, record.id, record.region.encode().hex())
for record in butler.query_dimension_records("patch", instrument="HSC", visit=318)
],
)

def test_data_coordinate_upload(self) -> None:
"""Test queries for dimension records with a data coordinate upload."""
@@ -1005,11 +1124,11 @@ def test_timespan_results(self) -> None:
"""Test returning dimension records that include timespans."""
butler = self.make_butler("base.yaml", "spatial.yaml")
with butler.query() as query:
query_results = list(query.dimension_records("visit"))
simple_results = butler.query_dimension_records("visit")
for results in (query_results, simple_results):
self.assertCountEqual(
[
(record.id, record.timespan.begin, record.timespan.end)
for record in query.dimension_records("visit")
],
[(record.id, record.timespan.begin, record.timespan.end) for record in results],
[
(
1,
@@ -1060,6 +1179,10 @@ def test_column_expressions(self) -> None:
query.where(_x.not_(_x.detector != 2)).dimension_records("detector"),
[2],
)
self.check_detector_records_returned(
butler.query_dimension_records("detector", where="NOT (detector != 2)"),
[2],
)
self.check_detector_records(
# Empty string expression should evaluate to True.
query.where(_x.detector == 2, "").dimension_records("detector"),
@@ -1121,6 +1244,18 @@ def test_column_expressions(self) -> None:
],
[2],
)
self.assertCountEqual(
[
record.id
for record in butler.query_dimension_records(
# In the middle of the timespan.
"visit",
where="visit.timespan OVERLAPS(ts)",
bind={"ts": astropy.time.Time("2021-09-09T03:02:30", format="isot", scale="tai")},
)
],
[2],
)
self.assertCountEqual(
[
record.id
@@ -1199,6 +1334,12 @@ def test_column_expressions(self) -> None:
query.where(_x.detector.in_iterable([1, 3, 4])).dimension_records("detector"),
[1, 3, 4],
)
self.check_detector_records_returned(
butler.query_dimension_records(
"detector", where="detector IN (det)", bind={"det": [1, 3, 4]}
),
[1, 3, 4],
)
self.check_detector_records(
query.where(_x.detector.in_range(start=2, stop=None)).dimension_records("detector"),
[2, 3, 4],
@@ -1259,14 +1400,23 @@ def _run_registry_query(where: str) -> list[int]:
butler.registry.queryDimensionRecords("exposure", where=where, instrument="HSC")
)

def _run_simple_query(where: str) -> list[int]:
return _get_exposure_ids_from_dimension_records(
butler.query_dimension_records("exposure", where=where, instrument="HSC")
)

def _run_query(where: str) -> list[int]:
with butler.query() as query:
return _get_exposure_ids_from_dimension_records(
query.dimension_records("exposure").where(where, instrument="HSC")
)

# Test boolean columns in the `where` string syntax.
for test, query_func in [("registry", _run_registry_query), ("new-query", _run_query)]:
for test, query_func in [
("registry", _run_registry_query),
("new-query", _run_query),
("simple", _run_simple_query),
]:
with self.subTest(test):
# Boolean columns should be usable standalone as an expression.
self.assertCountEqual(query_func("exposure.can_see_sky"), [TRUE_ID])
@@ -1392,6 +1542,16 @@ def test_dataset_region_queries(self) -> None:
refs = list(results)
self.assertEqual(len(refs), count, f"POS={pos} REFS={refs}")

simple_refs = butler.query_datasets(
"calexp",
collections=run,
instrument="HSC",
where="visit_detector_region.region OVERLAPS(POS)",
bind={"POS": Region.from_ivoa_pos(pos)},
explain=False,
)
self.assertCountEqual(refs, simple_refs)

def test_dataset_time_queries(self) -> None:
"""Test region queries for datasets."""
# Import data to play with.
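
One behavior several tests above rely on: with `explain=True` (the default) the simple methods raise `EmptyQueryResultError` instead of returning an empty list, which is why the empty-result cases pass `explain=False`. A hedged usage sketch, assuming the exception exposes its explanatory messages as a `reasons` attribute:

    from lsst.daf.butler import Butler, EmptyQueryResultError

    butler = Butler.from_config("repo")  # hypothetical repo path
    try:
        # A constraint that matches nothing (detector 99 does not exist).
        refs = butler.query_datasets("bias", "imported_g", where="detector = 99", instrument="Cam1")
    except EmptyQueryResultError as err:
        print(err.reasons)  # assumed attribute carrying the explanation strings
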