Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix task cancellation propagation to subtasks when using sync middleware #435

Merged
merged 1 commit into from
Jan 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 45 additions & 14 deletions asgiref/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,10 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
# `main_wrap`.
context = [contextvars.copy_context()]

# Get task context so that parent task knows which task to propagate
# an asyncio.CancelledError to.
task_context = getattr(SyncToAsync.threadlocal, "task_context", None)

loop = None
# Use call_soon_threadsafe to schedule a synchronous callback on the
# main event loop's thread if it's there, otherwise make a new loop
Expand All @@ -211,6 +215,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
awaitable = self.main_wrap(
call_result,
sys.exc_info(),
task_context,
context,
*args,
**kwargs,
Expand Down Expand Up @@ -295,6 +300,7 @@ async def main_wrap(
self,
call_result: "Future[_R]",
exc_info: "OptExcInfo",
task_context: "Optional[List[asyncio.Task[Any]]]",
context: List[contextvars.Context],
*args: _P.args,
**kwargs: _P.kwargs,
Expand All @@ -309,6 +315,10 @@ async def main_wrap(
if context is not None:
_restore_context(context[0])

current_task = asyncio.current_task()
if current_task is not None and task_context is not None:
task_context.append(current_task)

try:
# If we have an exception, run the function inside the except block
# after raising it so exc_info is correctly populated.
Expand All @@ -324,6 +334,8 @@ async def main_wrap(
else:
call_result.set_result(result)
finally:
if current_task is not None and task_context is not None:
task_context.remove(current_task)
context[0] = contextvars.copy_context()


Expand Down Expand Up @@ -437,20 +449,38 @@ async def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
context = contextvars.copy_context()
child = functools.partial(self.func, *args, **kwargs)
func = context.run

task_context: List[asyncio.Task[Any]] = []

# Run the code in the right thread
exec_coro = loop.run_in_executor(
executor,
functools.partial(
self.thread_handler,
loop,
sys.exc_info(),
task_context,
func,
child,
),
)
ret: _R
try:
# Run the code in the right thread
ret: _R = await loop.run_in_executor(
executor,
functools.partial(
self.thread_handler,
loop,
sys.exc_info(),
func,
child,
),
)

ret = await asyncio.shield(exec_coro)
except asyncio.CancelledError:
cancel_parent = True
try:
task = task_context[0]
task.cancel()
try:
await task
cancel_parent = False
except asyncio.CancelledError:
pass
except IndexError:
pass
if cancel_parent:
exec_coro.cancel()
ret = await exec_coro
finally:
_restore_context(context)
self.deadlock_context.set(False)
Expand All @@ -466,7 +496,7 @@ def __get__(
func = functools.partial(self.__call__, parent)
return functools.update_wrapper(func, self.func)

def thread_handler(self, loop, exc_info, func, *args, **kwargs):
def thread_handler(self, loop, exc_info, task_context, func, *args, **kwargs):
"""
Wraps the sync application with exception handling.
"""
Expand All @@ -476,6 +506,7 @@ def thread_handler(self, loop, exc_info, func, *args, **kwargs):
# Set the threadlocal for AsyncToSync
self.threadlocal.main_event_loop = loop
self.threadlocal.main_event_loop_pid = os.getpid()
self.threadlocal.task_context = task_context

# Run the function
# If we have an exception, run the function inside the except block
Expand Down
159 changes: 156 additions & 3 deletions tests/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,13 +852,10 @@ def sync_task():


@pytest.mark.asyncio
@pytest.mark.skip(reason="deadlocks")
async def test_inner_shield_sync_middleware():
"""
Tests that asyncio.shield is capable of preventing http.disconnect from
cancelling a django request task when using sync middleware.

Currently this tests is skipped as it causes a deadlock.
"""

# Hypothetical Django scenario - middleware function is sync
Expand Down Expand Up @@ -968,3 +965,159 @@ async def async_task():
assert task_complete

assert task_executed


@pytest.mark.asyncio
async def test_inner_shield_sync_and_async_middleware():
"""
Tests that asyncio.shield is capable of preventing http.disconnect from
cancelling a django request task when using sync and middleware chained
together.
"""

# Hypothetical Django scenario - middleware function is sync
def sync_middleware_1():
async_to_sync(async_middleware_2)()

# Hypothetical Django scenario - middleware function is async
async def async_middleware_2():
await sync_to_async(sync_middleware_3)()

# Hypothetical Django scenario - middleware function is sync
def sync_middleware_3():
async_to_sync(async_middleware_4)()

# Hypothetical Django scenario - middleware function is async
async def async_middleware_4():
await sync_to_async(sync_middleware_5)()

# Hypothetical Django scenario - middleware function is sync
def sync_middleware_5():
async_to_sync(async_view)()

task_complete = False
task_cancel_caught = False

# Future that completes when subtask cancellation attempt is caught
task_blocker = asyncio.Future()

async def async_view():
"""Async view with a task that is shielded from cancellation."""
nonlocal task_complete, task_cancel_caught, task_blocker
task = asyncio.create_task(async_task())
try:
await asyncio.shield(task)
except asyncio.CancelledError:
task_cancel_caught = True
task_blocker.set_result(True)
await task
task_complete = True

task_executed = False

# Future that completes after subtask is created
task_started_future = asyncio.Future()

async def async_task():
"""Async subtask that should not be canceled when parent is canceled."""
nonlocal task_started_future, task_executed, task_blocker
task_started_future.set_result(True)
await task_blocker
task_executed = True

task_cancel_propagated = False

async with ThreadSensitiveContext():
task = asyncio.create_task(sync_to_async(sync_middleware_1)())
await task_started_future
task.cancel()
try:
await task
except asyncio.CancelledError:
task_cancel_propagated = True
assert not task_cancel_propagated
assert task_cancel_caught
assert task_complete

assert task_executed


@pytest.mark.asyncio
async def test_inner_shield_sync_and_async_middleware_sync_task():
"""
Tests that asyncio.shield is capable of preventing http.disconnect from
cancelling a django request task when using sync and middleware chained
together with an async view calling a sync function calling an async task.

This test ensures that a parent initiated task cancellation will not
propagate to a shielded subtask.
"""

# Hypothetical Django scenario - middleware function is sync
def sync_middleware_1():
async_to_sync(async_middleware_2)()

# Hypothetical Django scenario - middleware function is async
async def async_middleware_2():
await sync_to_async(sync_middleware_3)()

# Hypothetical Django scenario - middleware function is sync
def sync_middleware_3():
async_to_sync(async_middleware_4)()

# Hypothetical Django scenario - middleware function is async
async def async_middleware_4():
await sync_to_async(sync_middleware_5)()

# Hypothetical Django scenario - middleware function is sync
def sync_middleware_5():
async_to_sync(async_view)()

task_complete = False
task_cancel_caught = False

# Future that completes when subtask cancellation attempt is caught
task_blocker = asyncio.Future()

async def async_view():
"""Async view with a task that is shielded from cancellation."""
nonlocal task_complete, task_cancel_caught, task_blocker
task = asyncio.create_task(sync_to_async(sync_parent)())
try:
await asyncio.shield(task)
except asyncio.CancelledError:
task_cancel_caught = True
task_blocker.set_result(True)
await task
task_complete = True

task_executed = False

# Future that completes after subtask is created
task_started_future = asyncio.Future()

def sync_parent():
async_to_sync(async_task)()

async def async_task():
"""Async subtask that should not be canceled when parent is canceled."""
nonlocal task_started_future, task_executed, task_blocker
task_started_future.set_result(True)
await task_blocker
task_executed = True

task_cancel_propagated = False

async with ThreadSensitiveContext():
task = asyncio.create_task(sync_to_async(sync_middleware_1)())
await task_started_future
task.cancel()
try:
await task
except asyncio.CancelledError:
task_cancel_propagated = True
assert not task_cancel_propagated
assert task_cancel_caught
assert task_complete

assert task_executed