Skip to content

Commit

Permalink
Make ButlerRepoIndex threadsafe
Browse files Browse the repository at this point in the history
Python does not guarantee that dict mutations are thread-safe; in particular, the current CPython implementation is not thread-safe when the object used as the key defines a custom __eq__. (Almost every Butler-related object defines a custom __eq__.)
  • Loading branch information
dhirving committed Dec 22, 2023
1 parent 610542d commit 7667513
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 4 deletions.
10 changes: 6 additions & 4 deletions python/lsst/daf/butler/_butler_repo_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from lsst.resources import ResourcePath

from ._config import Config
from ._thread_safe_cache import ThreadSafeCache


class ButlerRepoIndex:
Expand All @@ -57,7 +58,7 @@ class ButlerRepoIndex:
index_env_var: ClassVar[str] = "DAF_BUTLER_REPOSITORY_INDEX"
"""The name of the environment variable to read to locate the index."""

_cache: ClassVar[dict[ResourcePath, Config]] = {}
_cache: ClassVar[ThreadSafeCache[ResourcePath, Config]] = ThreadSafeCache()
"""Cache of indexes. In most scenarios only one index will be found
and the environment will not change. In tests this may not be true."""

Expand Down Expand Up @@ -88,8 +89,9 @@ def _read_repository_index(cls, index_uri: ResourcePath) -> Config:
-----
Does check the cache before reading the file.
"""
if index_uri in cls._cache:
return cls._cache[index_uri]
config = cls._cache.get(index_uri)
if config is not None:
return config

try:
repo_index = Config(index_uri)
Expand All @@ -100,7 +102,7 @@ def _read_repository_index(cls, index_uri: ResourcePath) -> Config:
raise RuntimeError(
f"Butler repository index file at {index_uri} could not be read: {type(e).__qualname__} {e}"
) from e
cls._cache[index_uri] = repo_index
repo_index = cls._cache.set_or_get(index_uri, repo_index)

return repo_index

Expand Down
78 changes: 78 additions & 0 deletions python/lsst/daf/butler/_thread_safe_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import threading
from typing import Generic, TypeVar

TKey = TypeVar("TKey")
TValue = TypeVar("TValue")


class ThreadSafeCache(Generic[TKey, TValue]):
"""A simple thread-safe cache. Ensures that once a value is stored for
a key, it does not change.
"""

def __init__(self) -> None:
self._mutex = threading.Lock()
self._cache: dict[TKey, TValue] = dict()

def get(self, key: TKey) -> TValue | None:
"""Return the value associated with the given key, or ``None`` if no
value has been assigned to that key.
Parameters
----------
key : ``TKey``
Key used to look up the value.
"""
with self._mutex:
return self._cache.get(key)

def set_or_get(self, key: TKey, value: TValue) -> TValue:
"""Set a value for a key if the key does not already have a value.
Parameters
----------
key : ``TKey``
Key used to look up the value.
value : ``TValue``
Value to store in the cache.
Returns
-------
value : ``TValue``
The existing value stored for the key if it was present, or
``value`` if this was a new key.
"""
with self._mutex:
existing_value = self._cache.get(key)
if existing_value is None:
self._cache[key] = value
return value
else:
return existing_value
47 changes: 47 additions & 0 deletions tests/test_thread_safe_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import unittest

from lsst.daf.butler._thread_safe_cache import ThreadSafeCache


class ThreadSafeCacheTestCase(unittest.TestCase):
    """Exercise the basic contract of ThreadSafeCache."""

    def test_cache(self):
        """Check lookups of unknown keys and that the first stored value
        for a key is the one that is kept.
        """
        cache = ThreadSafeCache()

        # Nothing has been stored yet, so a lookup yields None.
        missing = cache.get("unknown")
        self.assertIsNone(missing)

        # The first store for a key returns the value that was passed in
        # and makes it visible to subsequent lookups.
        first = cache.set_or_get("key", "a")
        self.assertEqual(first, "a")
        self.assertEqual(cache.get("key"), "a")

        # A second store for the same key returns the original value and
        # leaves the cached entry untouched.
        second = cache.set_or_get("key", "b")
        self.assertEqual(second, "a")
        self.assertEqual(cache.get("key"), "a")

        # Keys that were never stored are still reported as absent.
        absent = cache.get("other")
        self.assertIsNone(absent)


if __name__ == "__main__":
    unittest.main()

0 comments on commit 7667513

Please sign in to comment.