Skip to content

Commit

Permalink
Thread-safety for the client Cache() (bluesky#662)
Browse files Browse the repository at this point in the history
* Added a locking mechanism and checks to make the client Cache() thread-safe.

* isort.

* Use a named constant for the sqlite thread safety mode.

Co-authored-by: Padraic Shafer <[email protected]>

* Used a proper IntEnum class for the sqlite threading mode using for caching.

* Use SerializableLock

* Test lock identity through roundtrip.

---------

Co-authored-by: Padraic Shafer <[email protected]>
Co-authored-by: Dan Allan <[email protected]>
  • Loading branch information
3 people authored Feb 23, 2024
1 parent 4a493b4 commit 5fa30b3
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 10 deletions.
56 changes: 55 additions & 1 deletion tiled/_tests/test_client_cache.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import asyncio
import concurrent.futures
import sqlite3
import threading
import time
from contextlib import closing

import numpy
Expand All @@ -7,7 +11,7 @@
from ..adapters.array import ArrayAdapter
from ..adapters.mapping import MapAdapter
from ..client import Context, from_context, record_history
from ..client.cache import Cache, CachedResponse
from ..client.cache import Cache, CachedResponse, ThreadingMode, with_thread_lock
from ..server.app import build_app

tree = MapAdapter(
Expand Down Expand Up @@ -169,3 +173,53 @@ def test_clear_cache(client):
assert cache.count() > 0
cache.clear()
assert cache.size() == cache.count() == 0


def test_not_thread_safe(client, monkeypatch):
# Check that writes fail if thread safety is disabled
monkeypatch.setattr(sqlite3, "threadsafety", ThreadingMode.SINGLE_THREAD)
cache = client.context.cache
# Clear the cache in another thread
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
future = executor.submit(cache.clear)
with pytest.raises(RuntimeError):
future.result(timeout=1)


@pytest.mark.skipif(
sqlite3.threadsafety != ThreadingMode.SERIALIZED,
reason="sqlite not built with thread safe support",
)
def test_thread_safety(client):
cache = client.context.cache
# Clear the cache in another thread
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
future = executor.submit(cache.clear)
future.result(timeout=1)


@pytest.mark.asyncio
async def test_thread_lock():
"""Check that we can prevent concurrent thread writes."""

class Timer:
_lock = threading.Lock()
sleep_time = 0.01

@with_thread_lock
def sleep(self):
time.sleep(self.sleep_time)

timer = Timer()
# Run the timer twice concurrently
loop = asyncio.get_running_loop()
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
coros = [
loop.run_in_executor(executor, timer.sleep),
loop.run_in_executor(executor, timer.sleep),
]
t0 = time.perf_counter()
await asyncio.gather(*coros)
run_time = time.perf_counter() - t0
# Check that the threads didn't run in parallel
assert run_time >= (2.0 * timer.sleep_time), "Threads did not lock"
13 changes: 11 additions & 2 deletions tiled/_tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from packaging.version import parse

from ..client import from_context
from ..client.cache import Cache
from ..client.context import Context

MIN_VERSION = "0.1.0a104"
Expand All @@ -24,7 +25,7 @@ def test_pickle_context():


@pytest.mark.parametrize("structure_clients", ["numpy", "dask"])
def test_pickle_clients(structure_clients):
def test_pickle_clients(structure_clients, tmpdir):
try:
httpx.get(API_URL).raise_for_status()
except Exception:
Expand All @@ -34,7 +35,8 @@ def test_pickle_clients(structure_clients):
raise pytest.skip(
f"Server at {API_URL} is running too old a version to test against."
)
client = from_context(context, structure_clients)
cache = Cache(tmpdir / "http_response_cache.db")
client = from_context(context, structure_clients, cache)
pickle.loads(pickle.dumps(client))
for segements in [
["generated"],
Expand All @@ -46,3 +48,10 @@ def test_pickle_clients(structure_clients):
original = original[segment]
roundtripped = pickle.loads(pickle.dumps(original))
assert roundtripped.uri == original.uri


def test_lock_round_trip(tmpdir):
cache = Cache(tmpdir / "http_response_cache.db")
cache_round_tripped = pickle.loads(pickle.dumps(cache))
# implementation detail!
assert cache._lock.lock is cache_round_tripped._lock.lock
60 changes: 53 additions & 7 deletions tiled/client/cache.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import enum
import json
import os
import sqlite3
import threading
import typing as tp
from contextlib import closing
from datetime import datetime
from functools import wraps
from pathlib import Path

import appdirs
import httpx

from .utils import TiledResponse
from .utils import SerializableLock, TiledResponse

CACHE_DATABASE_SCHEMA_VERSION = 1

Expand Down Expand Up @@ -146,6 +148,33 @@ def _prepare_database(filepath, readonly):
return conn


def with_thread_lock(fn):
"""Makes sure the wrapper isn't accessed concurrently."""

@wraps(fn)
def wrapper(obj, *args, **kwargs):
obj._lock.acquire()
try:
result = fn(obj, *args, **kwargs)
finally:
obj._lock.release()
return result

return wrapper


class ThreadingMode(enum.IntEnum):
"""Threading mode used in the sqlite3 package.
https://docs.python.org/3/library/sqlite3.html#sqlite3.threadsafety
"""

SINGLE_THREAD = 0
MULTI_THREAD = 1
SERIALIZED = 3


class Cache:
def __init__(
self,
Expand All @@ -171,29 +200,42 @@ def __init__(
self._filepath = filepath
self._owner_thread = threading.current_thread().ident
self._conn = _prepare_database(filepath, readonly)
self._lock = SerializableLock()

def __repr__(self):
return f"<{type(self).__name__} {str(self._filepath)!r}>"

def write_safe(self):
"""
Check that it is safe to write.
"""Check that it is safe to write.
SQLite is not threadsafe for concurrent _writes_ unless the
underlying sqlite library was built with thread safety
enabled. Even still, it may be a good idea to use a thread
lock (``@with_thread_lock``) to prevent parallel writes.
SQLite is not threadsafe for concurrent _writes_.
"""
return threading.current_thread().ident == self._owner_thread
is_main_thread = threading.current_thread().ident == self._owner_thread
sqlite_is_safe = sqlite3.threadsafety == ThreadingMode.SERIALIZED
return is_main_thread or sqlite_is_safe

def __getstate__(self):
return (self.filepath, self.capacity, self.max_item_size, self._readonly)
return (
self.filepath,
self.capacity,
self.max_item_size,
self._readonly,
self._lock,
)

def __setstate__(self, state):
(filepath, capacity, max_item_size, readonly) = state
(filepath, capacity, max_item_size, readonly, lock) = state
self._capacity = capacity
self._max_item_size = max_item_size
self._readonly = readonly
self._filepath = filepath
self._owner_thread = threading.current_thread().ident
self._conn = _prepare_database(filepath, readonly)
self._lock = lock

@property
def readonly(self):
Expand Down Expand Up @@ -223,6 +265,7 @@ def max_item_size(self):
def max_item_size(self, max_item_size):
self._max_item_size = max_item_size

@with_thread_lock
def clear(self):
"""
Drop all entries from HTTP response cache.
Expand All @@ -237,6 +280,7 @@ def clear(self):
cur.execute("DELETE FROM responses")
self._conn.commit()

@with_thread_lock
def get(self, request: httpx.Request) -> tp.Optional[httpx.Response]:
"""Get cached response from Cache.
Expand Down Expand Up @@ -271,6 +315,7 @@ def get(self, request: httpx.Request) -> tp.Optional[httpx.Response]:

return load(row, request)

@with_thread_lock
def set(
self,
*,
Expand Down Expand Up @@ -321,6 +366,7 @@ def set(
self._conn.commit()
return True

@with_thread_lock
def delete(self, request: httpx.Request) -> None:
"""Delete an entry from cache.
Expand Down

0 comments on commit 5fa30b3

Please sign in to comment.