diff --git a/python/lsst/daf/butler/_utilities/locked_object.py b/python/lsst/daf/butler/_utilities/locked_object.py new file mode 100644 index 0000000000..4b8432c016 --- /dev/null +++ b/python/lsst/daf/butler/_utilities/locked_object.py @@ -0,0 +1,52 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from contextlib import contextmanager +from threading import Lock +from typing import Generic, Iterator, TypeVar + +_T = TypeVar("_T") + + +class LockedObject(Generic[_T]): + """Wraps an object to enforce that all accesses to the object are performed + while holding a mutex lock. + + Parameters + ---------- + obj : `object` + The object that will be returned from the ``access()`` method. + """ + + def __init__(self, obj: _T): + self._lock = Lock() + self._obj = obj + + @contextmanager + def access(self) -> Iterator[_T]: + with self._lock: + yield self._obj diff --git a/python/lsst/daf/butler/remote_butler/_factory.py b/python/lsst/daf/butler/remote_butler/_factory.py index 624c4afd5e..64ff112f6c 100644 --- a/python/lsst/daf/butler/remote_butler/_factory.py +++ b/python/lsst/daf/butler/remote_butler/_factory.py @@ -36,7 +36,7 @@ from .._butler_instance_options import ButlerInstanceOptions from ._authentication import get_authentication_token_from_environment from ._config import RemoteButlerConfigModel -from ._remote_butler import RemoteButler +from ._remote_butler import RemoteButler, RemoteButlerCache class RemoteButlerFactory: @@ -67,6 +67,7 @@ def __init__(self, server_url: str, http_client: httpx.Client | None = None): self.http_client = http_client else: self.http_client = httpx.Client() + self._cache = RemoteButlerCache() @staticmethod def create_factory_from_config(config: ButlerConfig) -> RemoteButlerFactory: @@ -95,6 +96,7 @@ def create_butler_for_access_token( access_token=access_token, options=butler_options, server_url=self.server_url, + cache=self._cache, ) def create_butler_with_credentials_from_environment( diff --git a/python/lsst/daf/butler/remote_butler/_remote_butler.py b/python/lsst/daf/butler/remote_butler/_remote_butler.py index 5d620771a4..393a643aa0 100644 --- a/python/lsst/daf/butler/remote_butler/_remote_butler.py +++ b/python/lsst/daf/butler/remote_butler/_remote_butler.py @@ -31,6 +31,7 @@ from collections.abc import Collection, Iterable, Mapping, Sequence from contextlib import AbstractContextManager +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, TextIO, Type, TypeVar, cast import httpx @@ -46,6 +47,7 @@ from .._dataset_ref import DatasetId, DatasetIdGenEnum, DatasetRef, SerializedDatasetRef from .._dataset_type import DatasetType, SerializedDatasetType from .._storage_class import StorageClass +from .._utilities.locked_object import LockedObject from ..datastore import DatasetRefURIs from ..dimensions import DataCoordinate, DataIdValue, DimensionConfig, DimensionUniverse, SerializedDataId from ..registry import MissingDatasetTypeError, NoDefaultCollectionError, RegistryDefaults @@ -89,9 +91,12 @@ class RemoteButler(Butler): # numpydoc ignore=PR02 Default values and other settings for the Butler instance. http_client : `httpx.Client` HTTP connection pool we will use to connect to the server. - access_token : `str` or `None`, optional + access_token : `str` Rubin Science Platform Gafaelfawr access token that will be used to authenticate with the server. + cache : RemoteButlerCache + Cache of data shared between multiple RemoteButler instances connected + to the same server. Notes ----- @@ -99,11 +104,11 @@ class RemoteButler(Butler): # numpydoc ignore=PR02 `Butler.from_config` or `RemoteButlerFactory`. """ - _dimensions: DimensionUniverse | None _registry_defaults: RegistryDefaults _client: httpx.Client _server_url: str _headers: dict[str, str] + _cache: RemoteButlerCache # This is __new__ instead of __init__ because we have to support # instantiation via the legacy constructor Butler.__new__(), which @@ -114,13 +119,19 @@ class RemoteButler(Butler): # numpydoc ignore=PR02 # a second time with the original arguments to Butler() when the instance # is returned from Butler.__new__() def __new__( - cls, *, server_url: str, options: ButlerInstanceOptions, http_client: httpx.Client, access_token: str + cls, + *, + server_url: str, + options: ButlerInstanceOptions, + http_client: httpx.Client, + access_token: str, + cache: RemoteButlerCache, ) -> RemoteButler: self = cast(RemoteButler, super().__new__(cls)) self._client = http_client self._server_url = server_url - self._dimensions = None + self._cache = cache # TODO: RegistryDefaults should have finish() called on it, but this # requires getCollectionSummary() which is not yet implemented @@ -142,15 +153,19 @@ def isWriteable(self) -> bool: @property def dimensions(self) -> DimensionUniverse: # Docstring inherited. - if self._dimensions is not None: - return self._dimensions + with self._cache.access() as cache: + if cache.dimensions is not None: + return cache.dimensions response = self._client.get(self._get_url("universe")) response.raise_for_status() config = DimensionConfig.fromString(response.text, format="json") - self._dimensions = DimensionUniverse(config) - return self._dimensions + universe = DimensionUniverse(config) + with self._cache.access() as cache: + if cache.dimensions is None: + cache.dimensions = universe + return cache.dimensions def _simplify_dataId(self, dataId: DataId | None, kwargs: dict[str, DataIdValue]) -> SerializedDataId: """Take a generic Data ID and convert it to a serializable form. @@ -649,3 +664,13 @@ def _extract_dataset_type(datasetRefOrType: DatasetRef | DatasetType | str) -> D return datasetRefOrType.datasetType else: return None + + +@dataclass +class _RemoteButlerCacheData: + dimensions: DimensionUniverse | None = None + + +class RemoteButlerCache(LockedObject[_RemoteButlerCacheData]): + def __init__(self) -> None: + super().__init__(_RemoteButlerCacheData()) diff --git a/tests/test_locked_object.py b/tests/test_locked_object.py new file mode 100644 index 0000000000..4ce97a8af5 --- /dev/null +++ b/tests/test_locked_object.py @@ -0,0 +1,47 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import unittest + +from lsst.daf.butler._utilities.locked_object import LockedObject + + +class LockedObjectTestCase(unittest.TestCase): + """Test LockedObject.""" + + def test_named_locks(self): + data = object() + locked_obj = LockedObject(data) + self.assertFalse(locked_obj._lock.locked()) + with locked_obj.access() as accessed: + self.assertTrue(locked_obj._lock.locked()) + self.assertIs(data, accessed) + self.assertFalse(locked_obj._lock.locked()) + + +if __name__ == "__main__": + unittest.main()