-
Notifications
You must be signed in to change notification settings - Fork 210
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Allow nesting sync_to_async via asyncio.wait_for (#367)
Change the order of fallbacks used by SyncToAsync to find the appropriate executor for sync code, so it prefers to use AsyncToSync.loop_thread_executors rather than thread_sensitive_context. Add test case to demonstrate problem.
- Loading branch information
1 parent
d1ee1fa
commit d920c3c
Showing
5 changed files
with
206 additions
and
209 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,120 +1,128 @@ | ||
import random | ||
import string | ||
import sys | ||
import asyncio | ||
import contextlib | ||
import contextvars | ||
import threading | ||
import weakref | ||
from typing import Any, Dict, Union | ||
|
||
|
||
class _CVar: | ||
"""Storage utility for Local.""" | ||
|
||
def __init__(self) -> None: | ||
self._data: "contextvars.ContextVar[Dict[str, Any]]" = contextvars.ContextVar( | ||
"asgiref.local" | ||
) | ||
|
||
def __getattr__(self, key): | ||
storage_object = self._data.get({}) | ||
try: | ||
return storage_object[key] | ||
except KeyError: | ||
raise AttributeError(f"{self!r} object has no attribute {key!r}") | ||
|
||
def __setattr__(self, key: str, value: Any) -> None: | ||
if key == "_data": | ||
return super().__setattr__(key, value) | ||
|
||
storage_object = self._data.get({}) | ||
storage_object[key] = value | ||
self._data.set(storage_object) | ||
|
||
def __delattr__(self, key: str) -> None: | ||
storage_object = self._data.get({}) | ||
if key in storage_object: | ||
del storage_object[key] | ||
self._data.set(storage_object) | ||
else: | ||
raise AttributeError(f"{self!r} object has no attribute {key!r}") | ||
|
||
|
||
class Local: | ||
""" | ||
A drop-in replacement for threading.locals that also works with asyncio | ||
Tasks (via the current_task asyncio method), and passes locals through | ||
sync_to_async and async_to_sync. | ||
Specifically: | ||
- Locals work per-coroutine on any thread not spawned using asgiref | ||
- Locals work per-thread on any thread not spawned using asgiref | ||
- Locals are shared with the parent coroutine when using sync_to_async | ||
- Locals are shared with the parent thread when using async_to_sync | ||
(and if that thread was launched using sync_to_async, with its parent | ||
coroutine as well, with this working for indefinite levels of nesting) | ||
Set thread_critical to True to not allow locals to pass from an async Task | ||
to a thread it spawns. This is needed for code that truly needs | ||
thread-safety, as opposed to things used for helpful context (e.g. sqlite | ||
does not like being called from a different thread to the one it is from). | ||
Thread-critical code will still be differentiated per-Task within a thread | ||
as it is expected it does not like concurrent access. | ||
This doesn't use contextvars as it needs to support 3.6. Once it can support | ||
3.7 only, we can then reimplement the storage more nicely. | ||
"""Local storage for async tasks. | ||
This is a namespace object (similar to `threading.local`) where data is | ||
also local to the current async task (if there is one). | ||
In async threads, local means in the same sense as the `contextvars` | ||
module - i.e. a value set in an async frame will be visible: | ||
- to other async code `await`-ed from this frame. | ||
- to tasks spawned using `asyncio` utilities (`create_task`, `wait_for`, | ||
`gather` and probably others). | ||
- to code scheduled in a sync thread using `sync_to_async` | ||
In "sync" threads (a thread with no async event loop running), the | ||
data is thread-local, but additionally shared with async code executed | ||
via the `async_to_sync` utility, which schedules async code in a new thread | ||
and copies context across to that thread. | ||
If `thread_critical` is True, then the local will only be visible per-thread, | ||
behaving exactly like `threading.local` if the thread is sync, and as | ||
`contextvars` if the thread is async. This allows genuinely thread-sensitive | ||
code (such as DB handles) to be kept stricly to their initial thread and | ||
disable the sharing across `sync_to_async` and `async_to_sync` wrapped calls. | ||
Unlike plain `contextvars` objects, this utility is threadsafe. | ||
""" | ||
|
||
def __init__(self, thread_critical: bool = False) -> None: | ||
self._thread_critical = thread_critical | ||
self._thread_lock = threading.RLock() | ||
self._context_refs: "weakref.WeakSet[object]" = weakref.WeakSet() | ||
# Random suffixes stop accidental reuse between different Locals, | ||
# though we try to force deletion as well. | ||
self._attr_name = "_asgiref_local_impl_{}_{}".format( | ||
id(self), | ||
"".join(random.choice(string.ascii_letters) for i in range(8)), | ||
) | ||
|
||
def _get_context_id(self): | ||
""" | ||
Get the ID we should use for looking up variables | ||
""" | ||
# Prevent a circular reference | ||
from .sync import AsyncToSync, SyncToAsync | ||
|
||
# First, pull the current task if we can | ||
context_id = SyncToAsync.get_current_task() | ||
context_is_async = True | ||
# OK, let's try for a thread ID | ||
if context_id is None: | ||
context_id = threading.current_thread() | ||
context_is_async = False | ||
# If we're thread-critical, we stop here, as we can't share contexts. | ||
self._storage: "Union[threading.local, _CVar]" | ||
|
||
if thread_critical: | ||
# Thread-local storage | ||
self._storage = threading.local() | ||
else: | ||
# Contextvar storage | ||
self._storage = _CVar() | ||
|
||
@contextlib.contextmanager | ||
def _lock_storage(self): | ||
# Thread safe access to storage | ||
if self._thread_critical: | ||
return context_id | ||
# Now, take those and see if we can resolve them through the launch maps | ||
for i in range(sys.getrecursionlimit()): | ||
try: | ||
if context_is_async: | ||
# Tasks have a source thread in AsyncToSync | ||
context_id = AsyncToSync.launch_map[context_id] | ||
context_is_async = False | ||
else: | ||
# Threads have a source task in SyncToAsync | ||
context_id = SyncToAsync.launch_map[context_id] | ||
context_is_async = True | ||
except KeyError: | ||
break | ||
# this is a test for are we in a async or sync | ||
# thread - will raise RuntimeError if there is | ||
# no current loop | ||
asyncio.get_running_loop() | ||
except RuntimeError: | ||
# We are in a sync thread, the storage is | ||
# just the plain thread local (i.e, "global within | ||
# this thread" - it doesn't matter where you are | ||
# in a call stack you see the same storage) | ||
yield self._storage | ||
else: | ||
# We are in an async thread - storage is still | ||
# local to this thread, but additionally should | ||
# behave like a context var (is only visible with | ||
# the same async call stack) | ||
|
||
# Ensure context exists in the current thread | ||
if not hasattr(self._storage, "cvar"): | ||
self._storage.cvar = _CVar() | ||
|
||
# self._storage is a thread local, so the members | ||
# can't be accessed in another thread (we don't | ||
# need any locks) | ||
yield self._storage.cvar | ||
else: | ||
# Catch infinite loops (they happen if you are screwing around | ||
# with AsyncToSync implementations) | ||
raise RuntimeError("Infinite launch_map loops") | ||
return context_id | ||
|
||
def _get_storage(self): | ||
context_obj = self._get_context_id() | ||
if not hasattr(context_obj, self._attr_name): | ||
setattr(context_obj, self._attr_name, {}) | ||
self._context_refs.add(context_obj) | ||
return getattr(context_obj, self._attr_name) | ||
|
||
def __del__(self): | ||
try: | ||
for context_obj in self._context_refs: | ||
try: | ||
delattr(context_obj, self._attr_name) | ||
except AttributeError: | ||
pass | ||
except TypeError: | ||
# WeakSet.__iter__ can crash when interpreter is shutting down due | ||
# to _IterationGuard being None. | ||
pass | ||
# Lock for thread_critical=False as other threads | ||
# can access the exact same storage object | ||
with self._thread_lock: | ||
yield self._storage | ||
|
||
def __getattr__(self, key): | ||
with self._thread_lock: | ||
storage = self._get_storage() | ||
if key in storage: | ||
return storage[key] | ||
else: | ||
raise AttributeError(f"{self!r} object has no attribute {key!r}") | ||
with self._lock_storage() as storage: | ||
return getattr(storage, key) | ||
|
||
def __setattr__(self, key, value): | ||
if key in ("_context_refs", "_thread_critical", "_thread_lock", "_attr_name"): | ||
if key in ("_local", "_storage", "_thread_critical", "_thread_lock"): | ||
return super().__setattr__(key, value) | ||
with self._thread_lock: | ||
storage = self._get_storage() | ||
storage[key] = value | ||
with self._lock_storage() as storage: | ||
setattr(storage, key, value) | ||
|
||
def __delattr__(self, key): | ||
with self._thread_lock: | ||
storage = self._get_storage() | ||
if key in storage: | ||
del storage[key] | ||
else: | ||
raise AttributeError(f"{self!r} object has no attribute {key!r}") | ||
with self._lock_storage() as storage: | ||
delattr(storage, key) |
Oops, something went wrong.