Skip to content

Commit

Permalink
Fix task cancellation propagate to subtasks when using sync middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
ttys0dev committed Jan 16, 2024
1 parent 19e14e7 commit 20fb4ce
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 14 deletions.
42 changes: 30 additions & 12 deletions asgiref/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,11 @@ async def main_wrap(
if context is not None:
_restore_context(context[0])

current_task = asyncio.current_task()
if current_task is not None:
task_context = SyncToAsync.task_context.get()
task_context.append(current_task)

try:
# If we have an exception, run the function inside the except block
# after raising it so exc_info is correctly populated.
Expand All @@ -324,6 +329,8 @@ async def main_wrap(
else:
call_result.set_result(result)
finally:
if current_task is not None:
task_context.remove(current_task)
context[0] = contextvars.copy_context()


Expand Down Expand Up @@ -355,6 +362,10 @@ class SyncToAsync(Generic[_P, _R]):
# Single-thread executor for thread-sensitive code
single_thread_executor = ThreadPoolExecutor(max_workers=1)

task_context: "contextvars.ContextVar[list[Future[Any]]]" = contextvars.ContextVar(
"task_context", default=[]
)

# Maintain a contextvar for the current execution context. Optionally used
# for thread sensitive mode.
thread_sensitive_context: "contextvars.ContextVar[ThreadSensitiveContext]" = (
Expand Down Expand Up @@ -438,19 +449,26 @@ async def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
child = functools.partial(self.func, *args, **kwargs)
func = context.run

# Run the code in the right thread
exec_coro = loop.run_in_executor(
executor,
functools.partial(
self.thread_handler,
loop,
sys.exc_info(),
func,
child,
),
)
ret: _R
try:
# Run the code in the right thread
ret: _R = await loop.run_in_executor(
executor,
functools.partial(
self.thread_handler,
loop,
sys.exc_info(),
func,
child,
),
)

ret = await asyncio.shield(exec_coro)
except asyncio.CancelledError:
tasks = self.task_context.get()
for task in tasks:
task.cancel()
await task
ret = await exec_coro
finally:
_restore_context(context)
self.deadlock_context.set(False)
Expand Down
3 changes: 1 addition & 2 deletions tests/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,7 @@ async def server_entry():


@pytest.mark.asyncio
@pytest.mark.xfail
@pytest.mark.skip(reason="deadlocks")
async def test_sync_to_async_with_blocker_thread_sensitive():
"""
Tests sync_to_async running on a long-time blocker in a thread_sensitive context.
Expand Down Expand Up @@ -852,7 +852,6 @@ def sync_task():


@pytest.mark.asyncio
@pytest.mark.skip(reason="deadlocks")
async def test_inner_shield_sync_middleware():
"""
Tests that asyncio.shield is capable of preventing http.disconnect from
Expand Down

0 comments on commit 20fb4ce

Please sign in to comment.