Created
July 14, 2023 21:54
-
-
Save msdrigg/02c7716d6e2a0cb4e5ef08d14f180119 to your computer and use it in GitHub Desktop.
Cancel a FastAPI route handler when the client disconnects
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import asyncio | |
| from contextlib import asynccontextmanager | |
| import functools | |
| from typing import Annotated, Any, AsyncContextManager, Awaitable, Callable | |
| from anyio import create_task_group | |
| from anyio.abc import TaskGroup | |
| from fastapi import Request, FastAPI | |
| from fastapi.params import Depends | |
# Creating this gist for discussion here https://github.com/tiangolo/fastapi/discussions/8805
# Note: using these helpers will likely discard the request body. They call the
# ASGI `receive` callable in a loop and throw away every message (including any
# "http.request" body chunks) until an "http.disconnect" message arrives.
app = FastAPI()
async def wait_for_disconnect(request: Request) -> bool:
    """Block until the ASGI server reports that the client disconnected.

    Reads messages from the request's ASGI ``receive`` channel and discards
    them until an ``http.disconnect`` message arrives. Any ``http.request``
    messages (body chunks) received along the way are thrown away, so the
    request body is lost unless it was consumed before this is awaited.

    Args:
        request: The incoming request whose connection is watched.

    Returns:
        Always ``True``, once the disconnect message has been received.
    """
    # `Request.receive` is Starlette's public accessor for the ASGI receive
    # callable; no need to reach into the private `_receive` attribute.
    receive = request.receive
    while (await receive())["type"] != "http.disconnect":
        pass
    return True
async def cancellation(
    request: Request,
):
    """FastAPI yield-dependency providing an event that is set on client disconnect.

    A background task watches the request's ASGI receive channel. When the
    client disconnects, the task sets the yielded ``asyncio.Event``, so an
    endpoint can ``await event.wait()`` or check ``event.is_set()`` to stop early.

    The watcher consumes the request's ASGI messages, so the request body is
    discarded unless it is read before the watcher sees those messages.

    Yields:
        An ``asyncio.Event`` that becomes set when the client disconnects.
    """
    event = asyncio.Event()

    async def set_event_on_disconnect():
        await wait_for_disconnect(request)
        event.set()

    async with asyncio.TaskGroup() as tg:
        disconnect_task = tg.create_task(set_event_on_disconnect())
        try:
            yield event
        finally:
            # Stop the watcher on every exit path: normal completion, an
            # exception thrown back into the dependency, or generator close.
            # On normal exit the TaskGroup would otherwise keep waiting for
            # a disconnect that may never come.
            disconnect_task.cancel()


# Endpoint parameter type: injects the disconnect event produced by `cancellation`.
CancellationEvent = Annotated[asyncio.Event, Depends(cancellation)]
@asynccontextmanager
async def create_request_task_group(request: Request):
    """Open an anyio task group that is cancelled if the client disconnects.

    Usage::

        async with create_request_task_group(request) as tg:
            tg.start_soon(some_coroutine_function)

    When the ``async with`` body finishes, the yielded group waits for its
    children. If the client disconnects first, the enclosing cancel scope is
    cancelled. That cancels both the children and the ``async with`` body.

    The watcher consumes the request's ASGI messages, so the request body is
    discarded unless it was read beforehand.

    Args:
        request: The incoming request whose connection is watched.

    Yields:
        A task group scoped to the lifetime of the client connection.
    """

    async def cancel_on_disconnect(scope):
        await wait_for_disconnect(request)
        # Cancel through anyio's cancel scope rather than raising
        # CancelledError by hand. Scope cancellation is anyio's supported way
        # to stop everything running inside it.
        scope.cancel()

    async with create_task_group() as outer_tg:
        outer_tg.start_soon(cancel_on_disconnect, outer_tg.cancel_scope)
        try:
            async with create_task_group() as tg:
                yield tg
        finally:
            # The body and its child tasks are finished (or failing): stop the
            # disconnect watcher so the outer group can exit.
            outer_tg.cancel_scope.cancel()
async def request_task_group(request: Request):
    """FastAPI dependency returning a zero-argument task-group factory.

    Calling the returned object is equivalent to calling
    ``create_request_task_group(request)``. That yields an async context
    manager whose task group is cancelled when the client disconnects.
    """
    factory = functools.partial(create_request_task_group, request)
    return factory


# Endpoint parameter type: a callable that opens a disconnect-aware task group.
RequestTaskGroup = Annotated[
    Callable[[], AsyncContextManager[TaskGroup]],
    Depends(request_task_group),
]
async def huge_work():
    """Stand-in for a slow job: log a start line, sleep five seconds, log completion."""
    duration_seconds = 5
    print("Starting work")
    await asyncio.sleep(duration_seconds)
    print("Work completed!")
@app.get("/cancel_tg_dependency")
async def cancel_tg_dependency(
    get_task_group: RequestTaskGroup,
):
    """Run ``huge_work`` in a task group obtained from the ``request_task_group`` dependency.

    The task group is cancelled if the client disconnects before it exits.
    """
    group_context = get_task_group()
    async with group_context as group:
        group.start_soon(huge_work)
    response = {"message": "Done"}
    return response
@app.get("/cancel_tg")
async def cancel_tg(request: Request):
    """Run ``huge_work`` in a disconnect-aware task group built directly from ``request``."""
    group_context = create_request_task_group(request)
    async with group_context as group:
        group.start_soon(huge_work)
    return dict(message="Done")
@app.get("/cancel_event")
async def cancel_event(event: CancellationEvent):
    """Run ``huge_work``, stopping it early if the client disconnects.

    Two tasks race inside one task group: the work itself, and a wait on the
    disconnect event. Whichever finishes first cancels the group's scope,
    which cancels the other. The same response is returned in both cases.
    """
    async with create_task_group() as tg:

        async def cancel_after_completion(func: Callable[[], Awaitable[Any]]):
            # `func` is a coroutine *function* (it is called here, not just
            # awaited), hence Callable[[], Awaitable[Any]] rather than Awaitable.
            await func()
            tg.cancel_scope.cancel()

        tg.start_soon(cancel_after_completion, huge_work)
        tg.start_soon(cancel_after_completion, event.wait)
    return {"message": "Done"}
Author
This didn't work, unfortunately — same code, with the virtual environment replicated from the requirements file above.
- No errors, by the way.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Glad it will help you!