Skip to content

Instantly share code, notes, and snippets.

@msdrigg
Created July 14, 2023 21:54
Show Gist options
  • Select an option

  • Save msdrigg/02c7716d6e2a0cb4e5ef08d14f180119 to your computer and use it in GitHub Desktop.

Select an option

Save msdrigg/02c7716d6e2a0cb4e5ef08d14f180119 to your computer and use it in GitHub Desktop.
Cancel a FastAPI route handler when the client disconnects
import asyncio
from contextlib import asynccontextmanager
import functools
from typing import Annotated, Any, AsyncContextManager, Awaitable, Callable
from anyio import create_task_group
from anyio.abc import TaskGroup
from fastapi import Request, FastAPI
from fastapi.params import Depends
# Created this gist for discussion at https://github.com/tiangolo/fastapi/discussions/8805
# Note: these helpers read from the request's ASGI ``receive`` channel and
# discard every message that is not "http.disconnect", so the request body
# will likely be lost unless it was already read before they start listening.
# FastAPI application instance; the example routes below are registered on it.
app = FastAPI()
async def wait_for_disconnect(request: Request) -> bool:
    """Wait until the ASGI server reports that the client has disconnected.

    Reads messages from the request's ASGI ``receive`` callable until one
    whose ``"type"`` is ``"http.disconnect"`` arrives. Every other message is
    read and thrown away, so any request body not consumed beforehand is lost.

    Args:
        request: The incoming request whose connection is watched.

    Returns:
        Always ``True``, once the disconnect message has been received.
    """
    # Public accessor for the ASGI receive callable (the original reached
    # into the private ``_receive`` attribute).
    receive = request.receive
    while (await receive())["type"] != "http.disconnect":
        pass
    return True
async def cancellation(
    request: Request,
):
    """FastAPI dependency that yields an event set once the client disconnects.

    A watcher task, owned by an ``asyncio.TaskGroup`` (Python 3.11+), waits
    for the disconnect and then sets the event. When FastAPI resumes this
    dependency after the handler, the watcher is cancelled so the task group
    can exit.

    Yields:
        An ``asyncio.Event`` that becomes set when the client disconnects.
    """
    disconnected = asyncio.Event()

    async def _watch_for_disconnect():
        # Flip the shared event as soon as the disconnect message arrives.
        await wait_for_disconnect(request)
        disconnected.set()

    async with asyncio.TaskGroup() as group:
        watcher = group.create_task(_watch_for_disconnect())
        yield disconnected
        # The TaskGroup waits for its tasks on exit, and the watcher may never
        # finish while the client stays connected, so stop it explicitly.
        watcher.cancel()


# Annotate a route parameter with this type to receive the disconnect event
# produced by the ``cancellation`` dependency.
CancellationEvent = Annotated[asyncio.Event, Depends(cancellation)]
@asynccontextmanager
async def create_request_task_group(request: Request):
    """Async context manager yielding an anyio task group tied to the client.

    An outer task group runs a watcher that waits for the client to
    disconnect. If the disconnect happens before the ``async with`` body (and
    the tasks started in the yielded group) finish, the watcher cancels the
    outer group's cancel scope. That also cancels the nested task group
    yielded to the caller. If the body completes first, the watcher itself is
    cancelled instead.

    Args:
        request: The request whose connection is watched.

    Yields:
        An ``anyio.abc.TaskGroup`` whose tasks are cancelled on disconnect.
    """
    async with create_task_group() as outer_tg:

        async def cancel_on_disconnect():
            await wait_for_disconnect(request)
            # Cancel through anyio's cancel scope instead of raising
            # ``asyncio.CancelledError`` by hand. A hand-raised CancelledError
            # bypasses anyio's cancellation machinery and is reported as a
            # child-task exception rather than a deliberate cancellation.
            outer_tg.cancel_scope.cancel()

        outer_tg.start_soon(cancel_on_disconnect)
        async with create_task_group() as tg:
            yield tg
        # The body finished while the client is still connected: stop the
        # watcher so the outer task group can exit.
        outer_tg.cancel_scope.cancel()
async def request_task_group(request: Request):
    """FastAPI dependency returning a zero-argument task-group factory.

    Calling the returned object produces a ``create_request_task_group``
    context manager bound to this request. Handlers can therefore decide when
    to open the disconnect-aware task group.
    """
    factory = functools.partial(create_request_task_group, request)
    return factory


# Route parameter type: a zero-argument callable whose result is an async
# context manager yielding a task group meant to be cancelled on disconnect.
RequestTaskGroup = Annotated[
    Callable[[], AsyncContextManager[TaskGroup]],
    Depends(request_task_group),
]
async def huge_work(duration: float = 5):
    """Simulate a long-running job for the example routes.

    Prints a start message, sleeps without blocking the event loop, then
    prints a completion message. A run that is cancelled during the sleep
    never prints the second message.

    Args:
        duration: Seconds to sleep. Defaults to 5, preserving the original
            demo timing; pass a smaller value to shorten the job.
    """
    print("Starting work")
    await asyncio.sleep(duration)
    print("Work completed!")
@app.get("/cancel_tg_dependency")
async def cancel_tg_dependency(
    get_task_group: RequestTaskGroup,
):
    """Run ``huge_work`` in a task group obtained through dependency injection.

    The factory comes from the ``RequestTaskGroup`` dependency. Leaving the
    ``async with`` waits for ``huge_work`` unless the client disconnects
    first, in which case the group is meant to be cancelled.
    """
    request_group = get_task_group()
    async with request_group as worker_group:
        worker_group.start_soon(huge_work)
    response = {"message": "Done"}
    return response
@app.get("/cancel_tg")
async def cancel_tg(request: Request):
    """Run ``huge_work`` in a disconnect-aware task group built inline.

    Instead of receiving a factory through dependency injection, this handler
    calls ``create_request_task_group`` directly with its own request.
    """
    guarded_group = create_request_task_group(request)
    async with guarded_group as worker_group:
        worker_group.start_soon(huge_work)
    response = {"message": "Done"}
    return response
@app.get("/cancel_event")
async def cancel_event(event: CancellationEvent):
    """Race ``huge_work`` against the client-disconnect event.

    Both run as tasks in one anyio task group. Whichever finishes first
    cancels the group's cancel scope, which cancels the other task. The
    handler returns the same response either way.
    """
    async with create_task_group() as tg:

        async def cancel_after_completion(func: Callable[[], Awaitable[Any]]):
            # ``func`` is a zero-argument async callable (``huge_work`` or the
            # bound ``event.wait``), not an awaitable, so it must be called
            # before awaiting; the previous ``Awaitable[Any]`` hint was wrong.
            await func()
            tg.cancel_scope.cancel()

        tg.start_soon(cancel_after_completion, huge_work)
        tg.start_soon(cancel_after_completion, event.wait)
    return {"message": "Done"}
@msdrigg
Copy link
Author

msdrigg commented Sep 8, 2023

Glad it will help you!

@ha-sante
Copy link

ha-sante commented Nov 4, 2023

Unfortunately, this didn't work — same code, with the virtual environment replicated from the above requirements file.

  • No errors, by the way.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment