diff --git a/python/lsst/daf/butler/remote_butler/_remote_butler.py b/python/lsst/daf/butler/remote_butler/_remote_butler.py
index cf39ea92bc..33d5fd524d 100644
--- a/python/lsst/daf/butler/remote_butler/_remote_butler.py
+++ b/python/lsst/daf/butler/remote_butler/_remote_butler.py
@@ -40,7 +40,7 @@
 from .._butler_config import ButlerConfig
 from .._config import Config
 from .._dataset_existence import DatasetExistence
-from .._dataset_ref import DatasetIdGenEnum, DatasetRef
+from .._dataset_ref import DatasetIdGenEnum, DatasetRef, SerializedDatasetRef
 from .._dataset_type import DatasetType, SerializedDatasetType
 from .._deferredDatasetHandle import DeferredDatasetHandle
 from .._file_dataset import FileDataset
@@ -48,10 +48,12 @@
 from .._storage_class import StorageClass
 from .._timespan import Timespan
 from ..datastore import DatasetRefURIs
-from ..dimensions import DataId, DimensionConfig, DimensionUniverse
-from ..registry import CollectionArgType, Registry, RegistryDefaults
+from ..dimensions import DataCoordinate, DataId, DimensionConfig, DimensionUniverse, SerializedDataCoordinate
+from ..registry import CollectionArgType, NoDefaultCollectionError, Registry, RegistryDefaults
+from ..registry.wildcards import CollectionWildcard
 from ..transfers import RepoExportContext
 from ._config import RemoteButlerConfigModel
+from .server import FindDatasetModel
 
 
 class RemoteButler(Butler):
@@ -101,6 +103,38 @@ def dimensions(self) -> DimensionUniverse:
         self._dimensions = DimensionUniverse(config)
         return self._dimensions
 
+    def _simplify_dataId(
+        self, dataId: DataId | None, **kwargs: int | str
+    ) -> SerializedDataCoordinate | None:
+        """Take a generic Data ID and convert it to a serializable form.
+
+        Parameters
+        ----------
+        dataId : `dict`, `None`, `DataCoordinate`
+            The data ID to serialize.
+        **kwargs : `int` or `str`
+            Additional values that should be included if this is not
+            a `DataCoordinate`.
+
+        Returns
+        -------
+        data_id : `SerializedDataCoordinate` or `None`
+            A serializable form, or `None` if no data ID and no keyword
+            values were given.
+        """
+        if dataId is None and not kwargs:
+            return None
+        if isinstance(dataId, DataCoordinate):
+            return dataId.to_simple()
+
+        # Copy into a new dict because DataId may be an immutable mapping.
+        # The copy is made even without kwargs so a plain mapping is handled.
+        data_id: dict[str, Any] = {} if dataId is None else dict(dataId)
+        data_id.update(kwargs)
+
+        # Assume we can treat it as a dict.
+        return SerializedDataCoordinate(dataId=data_id)
+
     def getDatasetType(self, name: str) -> DatasetType:
         # Docstring inherited.
         raise NotImplementedError()
@@ -198,7 +232,37 @@ def find_dataset(
         datastore_records: bool = False,
         **kwargs: Any,
     ) -> DatasetRef | None:
-        raise NotImplementedError()
+        if collections is None:
+            if not self.collections:
+                raise NoDefaultCollectionError(
+                    "No collections provided to find_dataset, and no defaults from butler construction."
+                )
+            collections = self.collections
+        # Temporary hack. Assume strings for collections. In future
+        # want to construct CollectionWildcard and filter it through collection
+        # cache to generate list of collection names.
+        # NOTE: only the data ID and collections are sent to the server; other
+        # arguments such as datastore_records are currently ignored.
+        wildcards = CollectionWildcard.from_expression(collections)
+
+        if isinstance(datasetType, DatasetType):
+            datasetType = datasetType.name
+
+        query = FindDatasetModel(
+            dataId=self._simplify_dataId(dataId, **kwargs), collections=wildcards.strings
+        )
+
+        path = f"find_dataset/{datasetType}"
+        response = self._client.post(
+            self._get_url(path), json=query.model_dump(mode="json", exclude_unset=True)
+        )
+        response.raise_for_status()
+
+        # The server replies with JSON null when no matching dataset exists.
+        payload = response.json()
+        if payload is None:
+            return None
+        return DatasetRef.from_simple(SerializedDatasetRef(**payload), universe=self.dimensions)
 
     def retrieveArtifacts(
         self,
diff --git a/python/lsst/daf/butler/remote_butler/server/__init__.py b/python/lsst/daf/butler/remote_butler/server/__init__.py
index d63badaf11..93c9018bc4 100644
--- a/python/lsst/daf/butler/remote_butler/server/__init__.py
+++ b/python/lsst/daf/butler/remote_butler/server/__init__.py
@@ -27,3 +27,4 @@
 
 from ._factory import *
 from ._server import *
+from ._server_models import *
diff --git a/python/lsst/daf/butler/remote_butler/server/_server.py b/python/lsst/daf/butler/remote_butler/server/_server.py
index 51bf01a4e6..b791651c47 100644
--- a/python/lsst/daf/butler/remote_butler/server/_server.py
+++ b/python/lsst/daf/butler/remote_butler/server/_server.py
@@ -35,9 +35,16 @@
 
 from fastapi import Depends, FastAPI
 from fastapi.middleware.gzip import GZipMiddleware
-from lsst.daf.butler import Butler, SerializedDatasetType
+from lsst.daf.butler import (
+    Butler,
+    DataCoordinate,
+    SerializedDataCoordinate,
+    SerializedDatasetRef,
+    SerializedDatasetType,
+)
 
 from ._factory import Factory
+from ._server_models import FindDatasetModel
 
 BUTLER_ROOT = "ci_hsc_gen3/DATA"
 
@@ -56,6 +63,26 @@ def factory_dependency() -> Factory:
     return Factory(butler=_make_global_butler())
 
 
+def unpack_dataId(butler: Butler, data_id: SerializedDataCoordinate | None) -> DataCoordinate | None:
+    """Convert the serialized dataId back to full DataCoordinate.
+
+    Parameters
+    ----------
+    butler : `lsst.daf.butler.Butler`
+        The butler to use for registry and universe.
+    data_id : `SerializedDataCoordinate` or `None`
+        The serialized form.
+
+    Returns
+    -------
+    dataId : `DataCoordinate` or `None`
+        The DataId usable by registry.
+    """
+    if data_id is None:
+        return None
+    return DataCoordinate.from_simple(data_id, registry=butler.registry)
+
+
 @app.get("/butler/v1/universe", response_model=dict[str, Any])
 def get_dimension_universe(factory: Factory = Depends(factory_dependency)) -> dict[str, Any]:
     """Allow remote client to get dimensions definition."""
@@ -78,3 +105,31 @@ def get_dataset_type(
     butler = factory.create_butler()
     datasetType = butler.get_dataset_type(dataset_type_name)
     return datasetType.to_simple()
+
+
+# Not yet supported: TimeSpan is not yet a pydantic model.
+# collections parameter assumes client-side has resolved regexes.
+@app.post(
+    "/butler/v1/find_dataset/{dataset_type}",
+    summary="Retrieve this dataset definition from collection, dataset type, and dataId",
+    # The handler returns None when no dataset matches, so the response
+    # model has to accept null.
+    response_model=SerializedDatasetRef | None,
+    response_model_exclude_unset=True,
+    response_model_exclude_defaults=True,
+    response_model_exclude_none=True,
+)
+def find_dataset(
+    dataset_type: str,
+    query: FindDatasetModel,
+    factory: Factory = Depends(factory_dependency),
+) -> SerializedDatasetRef | None:
+    """Find a single dataset; the response is null when none matches."""
+    # Treat an empty collection list the same as no collections given.
+    collection_query = query.collections or None
+
+    butler = factory.create_butler()
+    ref = butler.find_dataset(
+        dataset_type, dataId=unpack_dataId(butler, query.dataId), collections=collection_query
+    )
+    return ref.to_simple() if ref is not None else None
diff --git a/python/lsst/daf/butler/remote_butler/server/_server_models.py b/python/lsst/daf/butler/remote_butler/server/_server_models.py
index 1c34747e33..686a9ad571 100644
--- a/python/lsst/daf/butler/remote_butler/server/_server_models.py
+++ b/python/lsst/daf/butler/remote_butler/server/_server_models.py
@@ -26,3 +26,18 @@
 # along with this program. If not, see .
 
 """Models used for client/server communication."""
+
+__all__ = ["FindDatasetModel"]
+
+from lsst.daf.butler import SerializedDataCoordinate
+
+from ..._compat import _BaseModelCompat
+
+
+class FindDatasetModel(_BaseModelCompat):
+    """Request body for the ``find_dataset`` endpoint."""
+
+    # Serialized data ID; None when the client gave no data ID values.
+    dataId: SerializedDataCoordinate | None
+    # Names of the collections to search.
+    collections: list[str]
diff --git a/tests/test_server.py b/tests/test_server.py
index 394ceae728..3d3c3742c7 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -27,6 +27,7 @@
 
 import os.path
 import unittest
+import uuid
 
 try:
     # Failing to import any of these should disable the tests.
@@ -37,7 +38,8 @@
     TestClient = None
     app = None
 
-from lsst.daf.butler import Butler
+from lsst.daf.butler import Butler, DatasetRef
+from lsst.daf.butler.tests import DatastoreMock
 from lsst.daf.butler.tests.utils import MetricTestRepo, makeTestTempDir, removeTestTempDir
 
 TESTDIR = os.path.abspath(os.path.dirname(__file__))
@@ -68,6 +70,9 @@ def setUpClass(cls):
         # Override the server's Butler initialization to point at our test repo
         server_butler = Butler.from_config(cls.root, writeable=True)
 
+        # Not yet testing butler.get()
+        DatastoreMock.apply(server_butler)
+
         def create_factory_dependency():
             return Factory(butler=server_butler)
 
@@ -79,6 +84,7 @@ def create_factory_dependency():
 
         # Populate the test server.
         server_butler.import_(filename=os.path.join(TESTDIR, "data", "registry", "base.yaml"))
+        server_butler.import_(filename=os.path.join(TESTDIR, "data", "registry", "datasets-uuid.yaml"))
 
     @classmethod
     def tearDownClass(cls):
@@ -98,6 +104,18 @@ def test_get_dataset_type(self):
         bias_type = self.butler.get_dataset_type("bias")
         self.assertEqual(bias_type.name, "bias")
 
+    def test_find_dataset(self):
+        ref = self.butler.find_dataset("bias", collections="imported_g", detector=1, instrument="Cam1")
+        self.assertIsInstance(ref, DatasetRef)
+        self.assertEqual(ref.id, uuid.UUID("e15ab039-bc8b-4135-87c5-90902a7c0b22"))
+
+        # The same data ID given as a plain mapping, with no keyword values,
+        # must resolve to the same dataset.
+        ref_from_mapping = self.butler.find_dataset(
+            "bias", dataId={"instrument": "Cam1", "detector": 1}, collections="imported_g"
+        )
+        self.assertEqual(ref_from_mapping.id, ref.id)
+
 
 if __name__ == "__main__":
     unittest.main()