DM-45993: Optimize DirectButlerCollections.query_info to avoid too many queries #1075

Merged 5 commits on Sep 10, 2024
45 changes: 43 additions & 2 deletions python/lsst/daf/butler/_butler_collections.py
@@ -30,13 +30,17 @@
__all__ = ("ButlerCollections", "CollectionInfo")

from abc import ABC, abstractmethod
from collections.abc import Iterable, Sequence, Set
from typing import Any, overload
from collections import defaultdict
from collections.abc import Iterable, Mapping, Sequence, Set
from typing import TYPE_CHECKING, Any, overload

from pydantic import BaseModel

from ._collection_type import CollectionType

if TYPE_CHECKING:
from ._dataset_type import DatasetType


class CollectionInfo(BaseModel):
"""Information about a single Butler collection."""
@@ -275,6 +279,8 @@
include_chains: bool | None = None,
include_parents: bool = False,
include_summary: bool = False,
include_doc: bool = False,
summary_datasets: Iterable[DatasetType] | None = None,
) -> Sequence[CollectionInfo]:
"""Query the butler for collections matching an expression and
return detailed information about those collections.
@@ -298,6 +304,14 @@
include_summary : `bool`, optional
Whether the returned information includes dataset type and
governor information for the collections.
include_doc : `bool`, optional
Whether the returned information includes collection documentation
string.
summary_datasets : `~collections.abc.Iterable` [ `DatasetType` ], \
optional
Dataset types to include in returned summaries. Only used if
``include_summary`` is `True`. If not specified then all dataset
types will be included.

Returns
-------
@@ -411,3 +425,30 @@
collection_dataset_types.update(info.dataset_types)
dataset_types_set = dataset_types_set.intersection(collection_dataset_types)
return dataset_types_set

def _group_by_dataset_type(
self, dataset_types: Set[str], collection_infos: Iterable[CollectionInfo]
) -> Mapping[str, list[str]]:
"""Filter dataset types and collections names based on summary in
collecion infos.

Parameters
----------
dataset_types : `~collections.abc.Set` [`str`]
Set of dataset type names to extract.
collection_infos : `~collections.abc.Iterable` [`CollectionInfo`]
Collection infos, must contain dataset type summary.

Returns
-------
filtered : `~collections.abc.Mapping` [`str`, `list`[`str`]]
Mapping of the dataset type name to its corresponding list of
collection names.
"""
dataset_type_collections: dict[str, list[str]] = defaultdict(list)
for info in collection_infos:
if info.dataset_types is None:
raise RuntimeError("Can only filter by collections if include_summary was True")

[Codecov warning: added line #L451 was not covered by tests]
for dataset_type in info.dataset_types & dataset_types:
dataset_type_collections[dataset_type].append(info.name)
return dataset_type_collections
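
For orientation, here is a minimal usage sketch of the helper added above. The repository path, collection pattern, and dataset type names are invented for illustration and are not part of the PR.

```python
from lsst.daf.butler import Butler

butler = Butler("/path/to/repo")  # hypothetical repository

# Summaries are required for grouping, so request them up front.
infos = butler.collections.query_info("HSC/runs/*", include_summary=True)

# Group the requested dataset types by the collections whose summaries
# actually report them, so later queries only touch relevant collections.
by_type = butler.collections._group_by_dataset_type({"calexp", "src"}, infos)
# e.g. {"calexp": ["HSC/runs/a", "HSC/runs/b"], "src": ["HSC/runs/a"]}
```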
59 changes: 51 additions & 8 deletions python/lsst/daf/butler/direct_butler/_direct_butler_collections.py
@@ -29,7 +29,8 @@

__all__ = ("DirectButlerCollections",)

from collections.abc import Iterable, Sequence, Set
from collections.abc import Iterable, Mapping, Sequence, Set
from typing import TYPE_CHECKING, Any

import sqlalchemy
from lsst.utils.iteration import ensure_iterable
@@ -39,6 +40,11 @@
from ..registry._exceptions import OrphanedRecordError
from ..registry.interfaces import ChainedCollectionRecord
from ..registry.sql_registry import SqlRegistry
from ..registry.wildcards import CollectionWildcard

if TYPE_CHECKING:
from .._dataset_type import DatasetType
from ..registry._collection_summary import CollectionSummary


class DirectButlerCollections(ButlerCollections):
@@ -107,20 +113,57 @@ def query_info(
include_chains: bool | None = None,
include_parents: bool = False,
include_summary: bool = False,
include_doc: bool = False,
summary_datasets: Iterable[DatasetType] | None = None,
) -> Sequence[CollectionInfo]:
info = []
with self._registry.caching_context():
if collection_types is None:
collection_types = CollectionType.all()
for name in self._registry.queryCollections(
expression,
collectionTypes=collection_types,
flattenChains=flatten_chains,
includeChains=include_chains,
):
elif isinstance(collection_types, CollectionType):
collection_types = {collection_types}

records = self._registry._managers.collections.resolve_wildcard(
CollectionWildcard.from_expression(expression),
collection_types=collection_types,
flatten_chains=flatten_chains,
include_chains=include_chains,
)

summaries: Mapping[Any, CollectionSummary] = {}
if include_summary:
summaries = self._registry._managers.datasets.fetch_summaries(records, summary_datasets)

docs: Mapping[Any, str] = {}
if include_doc:
docs = self._registry._managers.collections.get_docs(record.key for record in records)

for record in records:
doc = docs.get(record.key, "")
children: tuple[str, ...] = tuple()
if record.type == CollectionType.CHAINED:
assert isinstance(record, ChainedCollectionRecord)
children = tuple(record.children)
parents: frozenset[str] | None = None
if include_parents:
# TODO: This is non-vectorized, so expensive to do in a
# loop.
parents = frozenset(self._registry.getCollectionParentChains(record.name))
dataset_types: Set[str] | None = None
if summary := summaries.get(record.key):
dataset_types = frozenset([dt.name for dt in summary.dataset_types])

info.append(
self.get_info(name, include_parents=include_parents, include_summary=include_summary)
CollectionInfo(
name=record.name,
type=record.type,
doc=doc,
parents=parents,
children=children,
dataset_types=dataset_types,
)
)

return info

def get_info(
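
To show how the optimized method is meant to be driven, here is a hedged sketch of a caller using the new keyword arguments; the repository path, collection pattern, and dataset type glob are assumptions.

```python
from lsst.daf.butler import Butler

butler = Butler("/path/to/repo")  # hypothetical repository

# Restrict per-collection summaries to the dataset types of interest and
# fetch documentation strings in the same pass, instead of issuing one
# get_info query per collection.
dataset_types = list(butler.registry.queryDatasetTypes("calexp"))
infos = butler.collections.query_info(
    "HSC/runs/*",
    include_summary=True,
    include_doc=True,
    summary_datasets=dataset_types,
)
for info in infos:
    print(info.name, info.doc, sorted(info.dataset_types or ()))
```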
26 changes: 19 additions & 7 deletions python/lsst/daf/butler/registry/collections/_base.py
@@ -36,6 +36,7 @@
from typing import TYPE_CHECKING, Any, Generic, Literal, NamedTuple, TypeVar, cast

import sqlalchemy
from lsst.utils.iteration import chunk_iterable

from ..._collection_type import CollectionType
from ..._exceptions import CollectionCycleError, CollectionTypeError, MissingCollectionError
@@ -450,13 +451,24 @@ def filter_types(records: Iterable[CollectionRecord[K]]) -> Iterator[CollectionR

def getDocumentation(self, key: K) -> str | None:
# Docstring inherited from CollectionManager.
sql = (
sqlalchemy.sql.select(self._tables.collection.columns.doc)
.select_from(self._tables.collection)
.where(self._tables.collection.columns[self._collectionIdName] == key)
)
with self._db.query(sql) as sql_result:
return sql_result.scalar()
docs = self.get_docs([key])
return docs.get(key)

def get_docs(self, keys: Iterable[K]) -> Mapping[K, str]:
# Docstring inherited from CollectionManager.
docs: dict[K, str] = {}
id_column = self._tables.collection.columns[self._collectionIdName]
doc_column = self._tables.collection.columns.doc
for chunk in chunk_iterable(keys):
sql = (
sqlalchemy.sql.select(id_column, doc_column)
.select_from(self._tables.collection)
.where(sqlalchemy.sql.and_(id_column.in_(chunk), doc_column != sqlalchemy.literal("")))
)
with self._db.query(sql) as sql_result:
for row in sql_result:
docs[row[0]] = row[1]
return docs

def setDocumentation(self, key: K, doc: str | None) -> None:
# Docstring inherited from CollectionManager.
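
The batched `get_docs` implementation above leans on `chunk_iterable` to keep each SQL `IN` clause bounded. Below is a small self-contained sketch of that pattern; `fetch_chunk` is a hypothetical stand-in for the per-chunk SELECT and is not part of the butler API.

```python
from lsst.utils.iteration import chunk_iterable


def fetch_chunk(chunk: list[int]) -> dict[int, str]:
    """Stand-in for one SELECT ... WHERE id IN (:chunk) AND doc != ''."""
    return {key: f"doc for collection {key}" for key in chunk if key % 2 == 0}


docs: dict[int, str] = {}
for chunk in chunk_iterable(range(2500), chunk_size=1000):
    docs.update(fetch_chunk(list(chunk)))
# Three queries (chunks of 1000, 1000, 500) instead of 2500 single-key lookups.
```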
19 changes: 18 additions & 1 deletion python/lsst/daf/butler/registry/interfaces/_collections.py
@@ -36,7 +36,7 @@
]

from abc import abstractmethod
from collections.abc import Iterable, Set
from collections.abc import Iterable, Mapping, Set
from typing import TYPE_CHECKING, Any, Generic, Self, TypeVar

import sqlalchemy
@@ -570,6 +570,23 @@ def getDocumentation(self, key: _Key) -> str | None:
"""
raise NotImplementedError()

@abstractmethod
def get_docs(self, keys: Iterable[_Key]) -> Mapping[_Key, str]:
"""Retrieve the documentation strings for multiple collections.

Parameters
----------
keys : `~collections.abc.Iterable` [ _Key ]
Internal primary key values for the collections.

Returns
-------
docs : `~collections.abc.Mapping` [ _Key, `str` ]
Documentation strings indexed by collection key. Only collections
with non-empty documentation strings are returned.
"""
raise NotImplementedError()

@abstractmethod
def setDocumentation(self, key: _Key, doc: str | None) -> None:
"""Set the documentation string for a collection.
@@ -36,6 +36,7 @@
from .._collection_type import CollectionType

if TYPE_CHECKING:
from .._dataset_type import DatasetType
from ._registry import RemoteButlerRegistry


@@ -79,6 +80,8 @@ def query_info(
include_chains: bool | None = None,
include_parents: bool = False,
include_summary: bool = False,
include_doc: bool = False,
summary_datasets: Iterable[DatasetType] | None = None,
) -> Sequence[CollectionInfo]:
# This should become a single call on the server in the future.
if collection_types is None:
1 change: 1 addition & 0 deletions python/lsst/daf/butler/script/exportCalibs.py
@@ -138,6 +138,7 @@ def exportCalibs(
collections_query,
flatten_chains=True,
include_chains=True,
include_doc=True,
collection_types={CollectionType.CALIBRATION, CollectionType.CHAINED},
):
log.info("Checking collection: %s", collection.name)
21 changes: 9 additions & 12 deletions python/lsst/daf/butler/script/queryDataIds.py
@@ -177,27 +177,24 @@ def queryDataIds(
if datasets:
# Need to constrain results based on dataset type and collection.
query_collections = collections or "*"
collections_info = butler.collections.query_info(query_collections, include_summary=True)
collections_info = butler.collections.query_info(
query_collections, include_summary=True, summary_datasets=dataset_types
)
expanded_collections = [info.name for info in collections_info]
filtered_dataset_types = list(
butler.collections._filter_dataset_types([dt.name for dt in dataset_types], collections_info)
dataset_type_collections = butler.collections._group_by_dataset_type(
{dt.name for dt in dataset_types}, collections_info
)
if not filtered_dataset_types:
if not dataset_type_collections:
return (
None,
f"No datasets of type {datasets!r} existed in the specified "
f"collections {','.join(expanded_collections)}.",
)

sub_query = query.join_dataset_search(
filtered_dataset_types.pop(0), collections=expanded_collections
)
for dt in filtered_dataset_types:
sub_query = sub_query.join_dataset_search(dt, collections=expanded_collections)
for dt, dt_collections in dataset_type_collections.items():
query = query.join_dataset_search(dt, collections=dt_collections)

results = sub_query.data_ids(dimensions)
else:
results = query.data_ids(dimensions)
results = query.data_ids(dimensions)

if where:
results = results.where(where)
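
In condensed form, the per-dataset-type join the diff introduces; `query` and `dimensions` are assumed to come from earlier in the script, and the mapping contents are invented for illustration.

```python
# Hypothetical output of _group_by_dataset_type (names are invented).
dataset_type_collections = {
    "calexp": ["HSC/runs/a", "HSC/runs/b"],
    "src": ["HSC/runs/a"],
}

# Join each dataset type only against the collections whose summaries
# report that type, rather than the full expanded collection list.
for dt, dt_collections in dataset_type_collections.items():
    query = query.join_dataset_search(dt, collections=dt_collections)
results = query.data_ids(dimensions)
```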
28 changes: 18 additions & 10 deletions python/lsst/daf/butler/script/queryDatasets.py
@@ -240,31 +240,39 @@
Dataset references matching the given query criteria grouped
by dataset type.
"""
datasetTypes = self._dataset_type_glob or ...
datasetTypes = self._dataset_type_glob
query_collections: Iterable[str] = self._collections_wildcard or ["*"]

# Currently need to use old interface to get all the matching
# dataset types and loop over the dataset types executing a new
# query each time.
dataset_types: set[str] = {d.name for d in self.butler.registry.queryDatasetTypes(datasetTypes)}
dataset_types = set(self.butler.registry.queryDatasetTypes(datasetTypes or ...))
n_dataset_types = len(dataset_types)
if n_dataset_types == 0:
_LOG.info("The given dataset type, %s, is not known to this butler.", datasetTypes)
return

[Codecov warning: added lines #L252 - L253 were not covered by tests]

# Expand the collections query and include summary information.
query_collections_info = self.butler.collections.query_info(query_collections, include_summary=True)
query_collections_info = self.butler.collections.query_info(
query_collections,
include_summary=True,
flatten_chains=True,
include_chains=False,
summary_datasets=dataset_types,
)
expanded_query_collections = [c.name for c in query_collections_info]
if self._find_first and set(query_collections) != set(expanded_query_collections):
raise RuntimeError("Can not use wildcards in collections when find_first=True")
query_collections = expanded_query_collections

# Only iterate over dataset types that are relevant for the query.
dataset_types = set(
self.butler.collections._filter_dataset_types(dataset_types, query_collections_info)
dataset_type_names = {dataset_type.name for dataset_type in dataset_types}
dataset_type_collections = self.butler.collections._group_by_dataset_type(
dataset_type_names, query_collections_info
)

if (n_filtered := len(dataset_types)) != n_dataset_types:
if (n_filtered := len(dataset_type_collections)) != n_dataset_types:
_LOG.info("Filtered %d dataset types down to %d", n_dataset_types, n_filtered)
elif n_dataset_types == 0:
_LOG.info("The given dataset type, %s, is not known to this butler.", datasetTypes)
else:
_LOG.info("Processing %d dataset type%s", n_dataset_types, "" if n_dataset_types == 1 else "s")

@@ -278,7 +286,7 @@
# possible dataset types to query.
warn_limit = True
limit = abs(limit) + 1 # +1 to tell us we hit the limit.
for dt in sorted(dataset_types):
for dt, collections in sorted(dataset_type_collections.items()):
kwargs: dict[str, Any] = {}
if self._where:
kwargs["where"] = self._where
Expand All @@ -288,7 +296,7 @@
_LOG.debug("Querying dataset type %s with %s", dt, kwargs)
results = self.butler.query_datasets(
dt,
collections=query_collections,
collections=collections,
find_first=self._find_first,
with_dimension_records=True,
order_by=self._order_by,
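
Equivalently for queryDatasets, a hedged sketch of the per-type loop: each dataset type is now queried only against the collections its summaries say contain it. The mapping and limit are invented; the keyword arguments mirror those in the diff.

```python
refs = []
for dt, dt_collections in sorted(dataset_type_collections.items()):
    # Hypothetical per-type query; collections are already narrowed to
    # those known (from the summaries) to hold this dataset type.
    refs.extend(
        butler.query_datasets(
            dt,
            collections=dt_collections,
            find_first=False,
            with_dimension_records=True,
            limit=100,
        )
    )
```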
24 changes: 12 additions & 12 deletions python/lsst/daf/butler/script/queryDimensionRecords.py
@@ -83,21 +83,21 @@ def queryDimensionRecords(

if datasets:
query_collections = collections or "*"
collections_info = butler.collections.query_info(query_collections, include_summary=True)
expanded_collections = [info.name for info in collections_info]
dataset_types = [dt.name for dt in butler.registry.queryDatasetTypes(datasets)]
dataset_types = list(butler.collections._filter_dataset_types(dataset_types, collections_info))

if not dataset_types:
dataset_types = butler.registry.queryDatasetTypes(datasets)
collections_info = butler.collections.query_info(
query_collections, include_summary=True, summary_datasets=dataset_types
)
dataset_type_collections = butler.collections._group_by_dataset_type(
{dt.name for dt in dataset_types}, collections_info
)

if not dataset_type_collections:
return None

sub_query = query.join_dataset_search(dataset_types.pop(0), collections=expanded_collections)
for dt in dataset_types:
sub_query = sub_query.join_dataset_search(dt, collections=expanded_collections)
for dt, dt_collections in dataset_type_collections.items():
query = query.join_dataset_search(dt, collections=dt_collections)

query_results = sub_query.dimension_records(element)
else:
query_results = query.dimension_records(element)
query_results = query.dimension_records(element)

if where:
query_results = query_results.where(where)