diff --git a/tiled/_tests/test_client_cache.py b/tiled/_tests/test_client_cache.py index 8aa0f351f..4233254b5 100644 --- a/tiled/_tests/test_client_cache.py +++ b/tiled/_tests/test_client_cache.py @@ -11,7 +11,7 @@ from ..adapters.array import ArrayAdapter from ..adapters.mapping import MapAdapter from ..client import Context, from_context, record_history -from ..client.cache import Cache, CachedResponse, with_thread_lock +from ..client.cache import Cache, CachedResponse, ThreadingMode, with_thread_lock from ..server.app import build_app tree = MapAdapter( @@ -177,7 +177,7 @@ def test_clear_cache(client): def test_not_thread_safe(client, monkeypatch): # Check that writes fail if thread safety is disabled - monkeypatch.setattr(sqlite3, "threadsafety", 0) + monkeypatch.setattr(sqlite3, "threadsafety", ThreadingMode.SINGLE_THREAD) cache = client.context.cache # Clear the cache in another thread with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: @@ -187,7 +187,8 @@ def test_not_thread_safe(client, monkeypatch): @pytest.mark.skipif( - sqlite3.threadsafety < 2, reason="sqlite not built with thread safe support" + sqlite3.threadsafety != ThreadingMode.SERIALIZED, + reason="sqlite not built with thread safe support", ) def test_thread_safety(client): cache = client.context.cache diff --git a/tiled/client/cache.py b/tiled/client/cache.py index 85781c4f1..44d85be20 100644 --- a/tiled/client/cache.py +++ b/tiled/client/cache.py @@ -1,3 +1,4 @@ +import enum import json import os import sqlite3 @@ -162,6 +163,18 @@ def wrapper(obj, *args, **kwargs): return wrapper +class ThreadingMode(enum.IntEnum): + """Threading mode used in the sqlite3 package. + + https://docs.python.org/3/library/sqlite3.html#sqlite3.threadsafety + + """ + + SINGLE_THREAD = 0 + MULTI_THREAD = 1 + SERIALIZED = 3 + + class Cache: def __init__( self, @@ -201,9 +214,8 @@ def write_safe(self): lock (``@with_thread_lock``) to prevent parallel writes. """ - SERIALIZED = 3 # Could be defined in an enum elsewhere is_main_thread = threading.current_thread().ident == self._owner_thread - sqlite_is_safe = sqlite3.threadsafety == SERIALIZED + sqlite_is_safe = sqlite3.threadsafety == ThreadingMode.SERIALIZED return is_main_thread or sqlite_is_safe def __getstate__(self):