From 76675137361a7d5f1c6fe2409595553996e19211 Mon Sep 17 00:00:00 2001 From: "David H. Irving" Date: Thu, 21 Dec 2023 17:21:57 -0700 Subject: [PATCH] Make ButlerRepoIndex threadsafe Python does not guarantee that dict mutations are threadsafe, and in particular the current CPython implementation is not if __eq__ is defined for the object used as the key. (Almost every Butler-related object defines a custom __eq__.) --- python/lsst/daf/butler/_butler_repo_index.py | 10 ++- python/lsst/daf/butler/_thread_safe_cache.py | 78 ++++++++++++++++++++ tests/test_thread_safe_cache.py | 47 ++++++++++++ 3 files changed, 131 insertions(+), 4 deletions(-) create mode 100644 python/lsst/daf/butler/_thread_safe_cache.py create mode 100644 tests/test_thread_safe_cache.py diff --git a/python/lsst/daf/butler/_butler_repo_index.py b/python/lsst/daf/butler/_butler_repo_index.py index 184e1f2e3d..bff4e77c95 100644 --- a/python/lsst/daf/butler/_butler_repo_index.py +++ b/python/lsst/daf/butler/_butler_repo_index.py @@ -35,6 +35,7 @@ from lsst.resources import ResourcePath from ._config import Config +from ._thread_safe_cache import ThreadSafeCache class ButlerRepoIndex: @@ -57,7 +58,7 @@ class ButlerRepoIndex: index_env_var: ClassVar[str] = "DAF_BUTLER_REPOSITORY_INDEX" """The name of the environment variable to read to locate the index.""" - _cache: ClassVar[dict[ResourcePath, Config]] = {} + _cache: ClassVar[ThreadSafeCache[ResourcePath, Config]] = ThreadSafeCache() """Cache of indexes. In most scenarios only one index will be found and the environment will not change. In tests this may not be true.""" @@ -88,8 +89,9 @@ def _read_repository_index(cls, index_uri: ResourcePath) -> Config: ----- Does check the cache before reading the file. """ - if index_uri in cls._cache: - return cls._cache[index_uri] + config = cls._cache.get(index_uri) + if config is not None: + return config try: repo_index = Config(index_uri) @@ -100,7 +102,7 @@ def _read_repository_index(cls, index_uri: ResourcePath) -> Config: raise RuntimeError( f"Butler repository index file at {index_uri} could not be read: {type(e).__qualname__} {e}" ) from e - cls._cache[index_uri] = repo_index + repo_index = cls._cache.set_or_get(index_uri, repo_index) return repo_index diff --git a/python/lsst/daf/butler/_thread_safe_cache.py b/python/lsst/daf/butler/_thread_safe_cache.py new file mode 100644 index 0000000000..d1dda9157c --- /dev/null +++ b/python/lsst/daf/butler/_thread_safe_cache.py @@ -0,0 +1,78 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import threading +from typing import Generic, TypeVar + +TKey = TypeVar("TKey") +TValue = TypeVar("TValue") + + +class ThreadSafeCache(Generic[TKey, TValue]): + """A simple thread-safe cache. Ensures that once a value is stored for + a key, it does not change. + """ + + def __init__(self) -> None: + self._mutex = threading.Lock() + self._cache: dict[TKey, TValue] = dict() + + def get(self, key: TKey) -> TValue | None: + """Return the value associated with the given key, or ``None`` if no + value has been assigned to that key. + + Parameters + ---------- + key : ``TKey`` + Key used to look up the value. + """ + with self._mutex: + return self._cache.get(key) + + def set_or_get(self, key: TKey, value: TValue) -> TValue: + """Set a value for a key if the key does not already have a value. + + Parameters + ---------- + key : ``TKey`` + Key used to look up the value. + value : ``TValue`` + Value to store in the cache. + + Returns + ------- + value : ``TValue`` + The existing value stored for the key if it was present, or + ``value`` if this was a new key. + """ + with self._mutex: + existing_value = self._cache.get(key) + if existing_value is None: + self._cache[key] = value + return value + else: + return existing_value diff --git a/tests/test_thread_safe_cache.py b/tests/test_thread_safe_cache.py new file mode 100644 index 0000000000..bc2bc03404 --- /dev/null +++ b/tests/test_thread_safe_cache.py @@ -0,0 +1,47 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import unittest + +from lsst.daf.butler._thread_safe_cache import ThreadSafeCache + + +class ThreadSafeCacheTestCase(unittest.TestCase): + """Test ThreadSafeCache.""" + + def test_cache(self): + cache = ThreadSafeCache() + self.assertIsNone(cache.get("unknown")) + self.assertEqual(cache.set_or_get("key", "a"), "a") + self.assertEqual(cache.get("key"), "a") + self.assertEqual(cache.set_or_get("key", "b"), "a") + self.assertEqual(cache.get("key"), "a") + self.assertIsNone(cache.get("other")) + + +if __name__ == "__main__": + unittest.main()