DM-45726: Pass Butler to RepoExportContext rather than Registry #1052

Merged 1 commit on Aug 13, 2024
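Summary of the change: `RepoExportContext` previously took a `SqlRegistry` and a `Datastore` as separate arguments; it now takes the `DirectButler` itself and reaches both through it. A before/after sketch of the single call site in `DirectButler.export`, taken from the diff below:

```python
# Before: registry and datastore passed separately.
helper = RepoExportContext(
    self._registry, self._datastore, backend=backend, directory=directory, transfer=transfer
)

# After (this PR): the butler is passed and provides both.
helper = RepoExportContext(self, backend=backend, directory=directory, transfer=transfer)
```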
4 changes: 1 addition & 3 deletions python/lsst/daf/butler/direct_butler/_direct_butler.py
@@ -1607,9 +1607,7 @@ def export(
         with open(filename, "w") as stream:
             backend = BackendClass(stream, universe=self.dimensions)
             try:
-                helper = RepoExportContext(
-                    self._registry, self._datastore, backend=backend, directory=directory, transfer=transfer
-                )
+                helper = RepoExportContext(self, backend=backend, directory=directory, transfer=transfer)
                 with self._caching_context():
                     yield helper
             except BaseException:
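The public entry point is unchanged by this refactor: callers still use `Butler.export` as a context manager and never construct `RepoExportContext` directly. A minimal usage sketch; the repository path, collection name, and dataset type below are illustrative, not from this PR:

```python
from lsst.daf.butler import Butler

butler = Butler("./repo")  # illustrative repository path

# Export a collection and its datasets to a YAML file.
with butler.export(filename="export.yaml", directory="./exported", transfer="copy") as export:
    export.saveCollection("raw/all")  # illustrative collection name
    refs = butler.registry.queryDatasets("raw", collections="raw/all")
    export.saveDatasets(refs)
```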
35 changes: 15 additions & 20 deletions python/lsst/daf/butler/transfers/_context.py
@@ -37,15 +37,14 @@
 from .._dataset_ref import DatasetId, DatasetRef
 from .._dataset_type import DatasetType
 from .._file_dataset import FileDataset
-from ..datastore import Datastore
 from ..dimensions import DataCoordinate, DimensionElement, DimensionRecord
 from ..registry import CollectionType
 from ..registry.interfaces import ChainedCollectionRecord, CollectionRecord

 if TYPE_CHECKING:
     from lsst.resources import ResourcePathExpression

-    from ..registry.sql_registry import SqlRegistry
+    from ..direct_butler import DirectButler
     from ._interfaces import RepoExportBackend

@@ -61,10 +60,8 @@ class RepoExportContext:

     Parameters
     ----------
-    registry : `SqlRegistry`
-        Registry to export from.
-    datastore : `Datastore`
-        Datastore to export from.
+    butler : `lsst.daf.butler.direct_butler.DirectButler`
+        Butler to export from.
     backend : `RepoExportBackend`
         Implementation class for a particular export file format.
     directory : `~lsst.resources.ResourcePathExpression`, optional
@@ -76,15 +73,13 @@

     def __init__(
         self,
-        registry: SqlRegistry,
-        datastore: Datastore,
+        butler: DirectButler,  # Requires butler._registry to work for now.
         backend: RepoExportBackend,
         *,
         directory: ResourcePathExpression | None = None,
         transfer: str | None = None,
     ):
-        self._registry = registry
-        self._datastore = datastore
+        self._butler = butler
         self._backend = backend
         self._directory = directory
         self._transfer = transfer
@@ -118,7 +113,7 @@ def saveCollection(self, name: str) -> None:
         export its child collections; these must be explicitly exported or
         already be present in the repository they are being imported into.
         """
-        self._collections[name] = self._registry.get_collection_record(name)
+        self._collections[name] = self._butler._registry.get_collection_record(name)

     def saveDimensionData(
         self, element: str | DimensionElement, records: Iterable[dict | DimensionRecord]
@@ -136,7 +131,7 @@ def saveDimensionData(
         `dict` instances.
         """
         if not isinstance(element, DimensionElement):
-            element = self._registry.dimensions[element]
+            element = self._butler.dimensions[element]
         for record in records:
             if not isinstance(record, DimensionRecord):
                 record = element.RecordClass(**record)
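`saveDimensionData` accepts either `DimensionRecord` instances or plain `dict`s, which it coerces via `element.RecordClass(**record)` as shown above. A hedged usage sketch inside a `butler.export(...)` block; the record fields are illustrative and a real record must satisfy the element's full schema:

```python
# `export` is the RepoExportContext yielded by `with butler.export(...) as export:`.
export.saveDimensionData(
    "instrument",  # element name; resolved via butler.dimensions["instrument"]
    [{"name": "MyCam", "class_name": "my_pkg.MyCam"}],  # illustrative, likely incomplete
)
```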
@@ -170,13 +165,13 @@ def saveDataIds(
         standardized_elements: Set[DimensionElement]
         if elements is None:
             standardized_elements = frozenset(
-                element for element in self._registry.dimensions.elements if element.has_own_table
+                element for element in self._butler.dimensions.elements if element.has_own_table
             )
         else:
             standardized_elements = set()
             for element in elements:
                 if not isinstance(element, DimensionElement):
-                    element = self._registry.dimensions[element]
+                    element = self._butler.dimensions[element]
                 if element.has_own_table:
                     standardized_elements.add(element)
         for dataId in dataIds:
@@ -185,7 +180,7 @@ def saveDataIds(
             # if the data ID is already expanded, and DM-26692 will add (or at
             # least start to add / unblock) query functionality that should
             # let us speed this up internally as well.
-            dataId = self._registry.expandDataId(dataId)
+            dataId = self._butler.registry.expandDataId(dataId)
             for element_name in dataId.dimensions.elements:
                 record = dataId.records[element_name]
                 if record is not None and record.definition in standardized_elements:
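From the caller's side, `saveDataIds` takes an iterable of data IDs and an optional restriction on which dimension elements are exported; each data ID is expanded through the registry as shown above. A hedged sketch; the dimension and instrument names are illustrative:

```python
# Inside a `with butler.export(...) as export:` block (see earlier sketch).
data_ids = butler.registry.queryDataIds(["exposure"], instrument="MyCam")
export.saveDataIds(data_ids, elements=["exposure"])
```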
@@ -243,7 +238,7 @@ def saveDatasets(
                 refs_to_export[dataset_id] = ref
         # Do a vectorized datastore export, which might be a lot faster than
         # one-by-one.
-        exports = self._datastore.export(
+        exports = self._butler._datastore.export(
             refs_to_export.values(),
             directory=self._directory,
             transfer=self._transfer,
@@ -267,15 +262,15 @@ def _finish(self) -> None:

         For use by `Butler.export` only.
         """
-        for element in self._registry.dimensions.sorted(self._records.keys()):
+        for element in self._butler.dimensions.sorted(self._records.keys()):
             # To make export deterministic sort the DataCoordinate instances.
             r = self._records[element]
             self._backend.saveDimensionData(element, *[r[dataId] for dataId in sorted(r.keys())])
         for datasetsByRun in self._datasets.values():
             for run in datasetsByRun:
-                self._collections[run] = self._registry.get_collection_record(run)
+                self._collections[run] = self._butler._registry.get_collection_record(run)
         for collectionName in self._computeSortedCollections():
-            doc = self._registry.getCollectionDocumentation(collectionName)
+            doc = self._butler.registry.getCollectionDocumentation(collectionName)
             self._backend.saveCollection(self._collections[collectionName], doc)
         # Sort the dataset types and runs before exporting to ensure
         # reproducible order in export file.
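`_finish` sorts dimension records by data ID (and, per the comment above, dataset types and runs by name) so that repeated exports of the same content yield byte-identical files. A toy illustration of that sorting idea, with made-up keys and values:

```python
# Deterministic iteration over a mapping keyed by (sortable) data IDs.
records = {("MyCam", 2): "record_b", ("MyCam", 1): "record_a"}  # made-up data
ordered = [records[key] for key in sorted(records)]
assert ordered == ["record_a", "record_b"]
```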
@@ -363,7 +358,7 @@ def _computeDatasetAssociations(self) -> dict[str, list[DatasetAssociation]]:
         collectionTypes = {CollectionType.TAGGED}
         if datasetType.isCalibration():
             collectionTypes.add(CollectionType.CALIBRATION)
-        associationIter = self._registry.queryDatasetAssociations(
+        associationIter = self._butler.registry.queryDatasetAssociations(
             datasetType,
             collections=self._collections.keys(),
             collectionTypes=collectionTypes,
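For context, this method collects TAGGED (and, for calibration dataset types, CALIBRATION) associations for each exported dataset type. A hedged sketch of the equivalent query against a butler; the dataset type and collection names are illustrative:

```python
from lsst.daf.butler import Butler, CollectionType

butler = Butler("./repo")  # illustrative repository path
dataset_type = butler.registry.getDatasetType("bias")  # illustrative name

collection_types = {CollectionType.TAGGED}
if dataset_type.isCalibration():
    collection_types.add(CollectionType.CALIBRATION)

for assoc in butler.registry.queryDatasetAssociations(
    dataset_type,
    collections=["calib/tagged"],  # illustrative collection names
    collectionTypes=collection_types,
):
    print(assoc.ref, assoc.collection)
```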