Inject SqlQueryContext into ObsCoreManager at top level
Modified the ExposureRegionFactory interface to no longer expose SqlQueryContext outside daf_butler.

The ExposureRegionFactory interface has one internal implementation and one external implementation in the dax_obscore package. The internal implementation needs privileged access to SqlRegistry internals, but the external one does not; the external implementation now needs to support RemoteButler and can no longer provide a SqlRegistry object.

To make this change, the SqlQueryContext is now created once, when ObsCoreLiveTableManager is constructed, and passed to the internal ExposureRegionFactory's constructor instead of being part of the ExposureRegionFactory method-call interface.
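
For orientation, a minimal sketch of the new wiring using made-up stand-in classes (the real SqlQueryContext, _ExposureRegionFactory, and ObsCoreLiveTableManager take many more arguments, as the diff below shows):

    class _StubContext:
        """Stand-in for SqlQueryContext (illustrative only)."""


    class _StubRegionFactory:
        """Stand-in for the internal _ExposureRegionFactory."""

        def __init__(self, context: _StubContext) -> None:
            # The query context is injected once, at construction time.
            self._context = context

        def exposure_region(self, data_id: str) -> None:
            # The lookup method no longer takes a `context` argument; it uses
            # the one stored on the instance.
            _ = self._context
            return None


    class _StubObsCoreManager:
        """Stand-in for ObsCoreLiveTableManager."""

        def __init__(self) -> None:
            # The manager builds the context once and hands it to the factory,
            # so registry code never has to pass it through add_datasets().
            self._region_factory = _StubRegionFactory(_StubContext())

        def add_datasets(self, refs: list[str]) -> int:
            for ref in refs:
                self._region_factory.exposure_region(ref)
            return len(refs)


    print(_StubObsCoreManager().add_datasets(["exposure-1", "exposure-2"]))  # prints 2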
dhirving committed Sep 16, 2024
1 parent 910bec6 commit 74ff087
Showing 5 changed files with 48 additions and 45 deletions.
16 changes: 7 additions & 9 deletions python/lsst/daf/butler/registry/interfaces/_obscore.py
@@ -44,9 +44,9 @@
if TYPE_CHECKING:
from lsst.sphgeom import Region

from ..._column_type_info import ColumnTypeInfo
from ..._dataset_ref import DatasetRef
from ...dimensions import DimensionUniverse
from ..queries import SqlQueryContext
from ._collections import CollectionRecord
from ._database import Database, StaticTablesContext
from ._datasets import DatasetRecordStorageManager
@@ -103,6 +103,7 @@ def initialize(
datasets: type[DatasetRecordStorageManager],
dimensions: DimensionRecordStorageManager,
registry_schema_version: VersionTuple | None = None,
column_type_info: ColumnTypeInfo,
) -> ObsCoreTableManager:
"""Construct an instance of the manager.
@@ -124,6 +125,9 @@
Manager for Registry dimensions.
registry_schema_version : `VersionTuple` or `None`
Schema version of this extension as defined in registry.
column_type_info : `ColumnTypeInfo`
Information about column types that can differ between data
repositories and registry instances.
Returns
-------
@@ -144,7 +148,7 @@ def config_json(self) -> str:
raise NotImplementedError()

@abstractmethod
def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
def add_datasets(self, refs: Iterable[DatasetRef]) -> int:
"""Possibly add datasets to the obscore table.
This method should be called when new datasets are added to a RUN
@@ -156,8 +160,6 @@ def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) ->
Dataset refs to add. Dataset refs have to be completely expanded.
If a record with the same dataset ID is already in obscore table,
the dataset is ignored.
context : `SqlQueryContext`
Context used to execute queries for additional dimension metadata.
Returns
-------
@@ -180,9 +182,7 @@ def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) ->
raise NotImplementedError()

@abstractmethod
def associate(
self, refs: Iterable[DatasetRef], collection: CollectionRecord, context: SqlQueryContext
) -> int:
def associate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
"""Possibly add datasets to the obscore table.
This method should be called when existing datasets are associated with
Expand All @@ -196,8 +196,6 @@ def associate(
the dataset is ignored.
collection : `CollectionRecord`
Collection record for a TAGGED collection.
context : `SqlQueryContext`
Context used to execute queries for additional dimension metadata.
Returns
-------
23 changes: 12 additions & 11 deletions python/lsst/daf/butler/registry/managers.py
@@ -444,6 +444,17 @@ def initialize(
universe=universe,
registry_schema_version=schema_versions.get("datastores"),
)
kwargs["column_types"] = ColumnTypeInfo(
database.getTimespanRepresentation(),
universe,
dataset_id_spec=types.datasets.addDatasetForeignKey(
dummy_table,
primaryKey=False,
nullable=False,
),
run_key_spec=types.collections.addRunForeignKey(dummy_table, primaryKey=False, nullable=False),
ingest_date_dtype=datasets.ingest_date_dtype(),
)
if types.obscore is not None and "obscore" in types.manager_configs:
kwargs["obscore"] = types.obscore.initialize(
database,
@@ -453,20 +464,10 @@
datasets=types.datasets,
dimensions=kwargs["dimensions"],
registry_schema_version=schema_versions.get("obscore"),
column_type_info=kwargs["column_types"],
)
else:
kwargs["obscore"] = None
kwargs["column_types"] = ColumnTypeInfo(
database.getTimespanRepresentation(),
universe,
dataset_id_spec=types.datasets.addDatasetForeignKey(
dummy_table,
primaryKey=False,
nullable=False,
),
run_key_spec=types.collections.addRunForeignKey(dummy_table, primaryKey=False, nullable=False),
ingest_date_dtype=datasets.ingest_date_dtype(),
)
kwargs["caching_context"] = caching_context
return cls(**kwargs)

34 changes: 23 additions & 11 deletions python/lsst/daf/butler/registry/obscore/_manager.py
@@ -43,6 +43,7 @@
from lsst.utils.introspection import find_outside_stacklevel
from lsst.utils.iteration import chunk_iterable

from ..._column_type_info import ColumnTypeInfo
from ..interfaces import ObsCoreTableManager, VersionTuple
from ._config import ConfigCollectionType, ObsCoreManagerConfig
from ._records import ExposureRegionFactory, Record, RecordFactory
@@ -71,14 +72,16 @@ class _ExposureRegionFactory(ExposureRegionFactory):
The dimension records storage manager.
"""

def __init__(self, dimensions: DimensionRecordStorageManager):
def __init__(self, dimensions: DimensionRecordStorageManager, context: SqlQueryContext):
self.dimensions = dimensions
self.universe = dimensions.universe
self.exposure_dimensions = self.universe["exposure"].minimal_group
self.exposure_detector_dimensions = self.universe.conform(["exposure", "detector"])
self._context = context

def exposure_region(self, dataId: DataCoordinate, context: SqlQueryContext) -> Region | None:
def exposure_region(self, dataId: DataCoordinate) -> Region | None:
# Docstring is inherited from a base class.
context = self._context
# Make a relation that starts with visit_definition (mapping between
# exposure and visit).
relation = context.make_initial_relation()
@@ -134,6 +137,9 @@ class ObsCoreLiveTableManager(ObsCoreTableManager):
Spatial plugins.
registry_schema_version : `VersionTuple` or `None`, optional
Version of registry schema.
column_type_info : `ColumnTypeInfo`
Information about column types that can differ between data
repositories and registry instances.
"""

def __init__(
@@ -147,6 +153,7 @@ def __init__(
dimensions: DimensionRecordStorageManager,
spatial_plugins: Collection[SpatialObsCorePlugin],
registry_schema_version: VersionTuple | None = None,
column_type_info: ColumnTypeInfo,
):
super().__init__(registry_schema_version=registry_schema_version)
self.db = db
@@ -155,7 +162,11 @@
self.universe = universe
self.config = config
self.spatial_plugins = spatial_plugins
exposure_region_factory = _ExposureRegionFactory(dimensions)
self._column_type_info = column_type_info
exposure_region_factory = _ExposureRegionFactory(
dimensions,
SqlQueryContext(self.db, column_type_info),
)
self.record_factory = RecordFactory(
config, schema, universe, spatial_plugins, exposure_region_factory
)
@@ -189,6 +200,7 @@ def clone(self, *, db: Database, dimensions: DimensionRecordStorageManager) -> O
# 'initialize'.
spatial_plugins=self.spatial_plugins,
registry_schema_version=self._registry_schema_version,
column_type_info=self._column_type_info,
)

@classmethod
@@ -202,6 +214,7 @@ def initialize(
datasets: type[DatasetRecordStorageManager],
dimensions: DimensionRecordStorageManager,
registry_schema_version: VersionTuple | None = None,
column_type_info: ColumnTypeInfo,
) -> ObsCoreTableManager:
# Docstring inherited from base class.
config_data = Config(config)
@@ -227,6 +240,7 @@
dimensions=dimensions,
spatial_plugins=spatial_plugins,
registry_schema_version=registry_schema_version,
column_type_info=column_type_info,
)

def config_json(self) -> str:
@@ -244,7 +258,7 @@ def currentVersions(cls) -> list[VersionTuple]:
# Docstring inherited from base class.
return [_VERSION]

def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
def add_datasets(self, refs: Iterable[DatasetRef]) -> int:
# Docstring inherited from base class.

# Only makes sense for RUN collection types
@@ -279,19 +293,17 @@ def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) ->
# Take all refs, no collection check.
obscore_refs = refs

return self._populate(obscore_refs, context)
return self._populate(obscore_refs)

def associate(
self, refs: Iterable[DatasetRef], collection: CollectionRecord, context: SqlQueryContext
) -> int:
def associate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
# Docstring inherited from base class.

# Only works when collection type is TAGGED
if self.tagged_collection is None:
return 0

if collection.name == self.tagged_collection:
return self._populate(refs, context)
return self._populate(refs)
else:
return 0

@@ -315,11 +327,11 @@ def disassociate(self, refs: Iterable[DatasetRef], collection: CollectionRecord)
count += self.db.deleteWhere(self.table, where)
return count

def _populate(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
def _populate(self, refs: Iterable[DatasetRef]) -> int:
"""Populate obscore table with the data from given datasets."""
records: list[Record] = []
for ref in refs:
record = self.record_factory(ref, context)
record = self.record_factory(ref)
if record is not None:
records.append(record)

11 changes: 3 additions & 8 deletions python/lsst/daf/butler/registry/obscore/_records.py
@@ -49,9 +49,6 @@
from ._schema import ObsCoreSchema
from ._spatial import SpatialObsCorePlugin

if TYPE_CHECKING:
from ..queries import SqlQueryContext

_LOG = logging.getLogger(__name__)

# Map extra column type to a conversion method that takes string.
@@ -67,15 +64,13 @@ class ExposureRegionFactory:
"""Abstract interface for a class that returns a Region for an exposure."""

@abstractmethod
def exposure_region(self, dataId: DataCoordinate, context: SqlQueryContext) -> Region | None:
def exposure_region(self, dataId: DataCoordinate) -> Region | None:
"""Return a region for a given DataId that corresponds to an exposure.
Parameters
----------
dataId : `DataCoordinate`
Data ID for an exposure dataset.
context : `SqlQueryContext`
Context used to execute queries for additional dimension metadata.
Returns
-------
@@ -125,7 +120,7 @@ def __init__(
self.visit = universe["visit"]
self.physical_filter = cast(Dimension, universe["physical_filter"])

def __call__(self, ref: DatasetRef, context: SqlQueryContext) -> Record | None:
def __call__(self, ref: DatasetRef) -> Record | None:
"""Make an ObsCore record from a dataset.
Parameters
@@ -194,7 +189,7 @@ def __call__(self, ref: DatasetRef, context: SqlQueryContext) -> Record | None:
if (dimension_record := dataId.records[self.exposure.name]) is not None:
self._exposure_records(dimension_record, record)
if self.exposure_region_factory is not None:
region = self.exposure_region_factory.exposure_region(dataId, context)
region = self.exposure_region_factory.exposure_region(dataId)
elif self.visit.name in dataId and (dimension_record := dataId.records[self.visit.name]) is not None:
self._visit_records(dimension_record, record)

9 changes: 3 additions & 6 deletions python/lsst/daf/butler/registry/sql_registry.py
@@ -1078,8 +1078,7 @@ def insertDatasets(
try:
refs = list(storage.insert(runRecord, expandedDataIds, idGenerationMode))
if self._managers.obscore:
context = queries.SqlQueryContext(self._db, self._managers.column_types)
self._managers.obscore.add_datasets(refs, context)
self._managers.obscore.add_datasets(refs)
except sqlalchemy.exc.IntegrityError as err:
raise ConflictingDefinitionError(
"A database constraint failure was triggered by inserting "
@@ -1193,8 +1192,7 @@ def _importDatasets(
try:
refs = list(storage.import_(runRecord, expandedDatasets))
if self._managers.obscore:
context = queries.SqlQueryContext(self._db, self._managers.column_types)
self._managers.obscore.add_datasets(refs, context)
self._managers.obscore.add_datasets(refs)
except sqlalchemy.exc.IntegrityError as err:
raise ConflictingDefinitionError(
"A database constraint failure was triggered by inserting "
@@ -1307,8 +1305,7 @@ def associate(self, collection: str, refs: Iterable[DatasetRef]) -> None:
if self._managers.obscore:
# If a TAGGED collection is being monitored by ObsCore
# manager then we may need to save the dataset.
context = queries.SqlQueryContext(self._db, self._managers.column_types)
self._managers.obscore.associate(refsForType, collectionRecord, context)
self._managers.obscore.associate(refsForType, collectionRecord)
except sqlalchemy.exc.IntegrityError as err:
raise ConflictingDefinitionError(
f"Constraint violation while associating dataset of type {datasetType.name} with "