Last active
August 20, 2025 12:50
-
-
Save richardhundt/17dfccb5c1e253f798999fc2b2417d7e to your computer and use it in GitHub Desktop.
Patch httpx ASGITransport to stream responses for testing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import asyncio | |
| from contextlib import asynccontextmanager | |
| import typing | |
| import asyncio | |
| from httpx._models import Request, Response | |
| from httpx._transports.asgi import ASGITransport | |
| from httpx._types import AsyncByteStream | |
class ASGIResponseByteStream(AsyncByteStream):
    """Expose an async generator of body chunks as an httpx ``AsyncByteStream``.

    Iterating the stream iterates the wrapped generator; closing the stream
    closes the generator.
    """

    def __init__(self, stream: typing.AsyncGenerator[bytes, None]) -> None:
        self._stream = stream

    def __aiter__(self) -> typing.AsyncIterator[bytes]:
        generator = self._stream
        return generator.__aiter__()

    async def aclose(self) -> None:
        close_generator = self._stream.aclose
        await close_generator()
# Strong references to in-flight app tasks. The event loop keeps only weak
# references to tasks, so a task nobody holds on to can be garbage-collected
# before it finishes (see the asyncio.create_task documentation).
_BACKGROUND_TASKS: typing.Set["asyncio.Task[None]"] = set()


async def patch_handle_async_request(
    self: ASGITransport,
    request: Request,
) -> Response:
    """Streaming replacement for ``ASGITransport.handle_async_request``.

    The ASGI app runs in a background task. This coroutine returns as soon as
    the app sends ``http.response.start``; body chunks are passed through a
    queue and yielded by the returned response stream as the app sends them.

    Args:
        self: The transport (this function is patched onto the class).
        request: The outgoing request; its stream must be an ``AsyncByteStream``.

    Returns:
        A ``Response`` whose stream yields body chunks as they are produced.
        If the app crashes before starting the response and
        ``self.raise_app_exceptions`` is false, an empty 500 response.

    Raises:
        Exception: The app's own exception when ``self.raise_app_exceptions``
            is true -- raised here if it happened before the response started,
            otherwise from the body stream if it is recorded before the end of
            the body is read.
    """
    assert isinstance(request.stream, AsyncByteStream)

    # ASGI scope.
    scope = {
        "type": "http",
        "asgi": {"version": "3.0"},
        "http_version": "1.1",
        "method": request.method,
        "headers": [(k.lower(), v) for (k, v) in request.headers.raw],
        "scheme": request.url.scheme,
        "path": request.url.path,
        "raw_path": request.url.raw_path,
        "query_string": request.url.query,
        "server": (request.url.host, request.url.port),
        "client": self.client,
        "root_path": self.root_path,
    }

    # Request.
    request_body_chunks = request.stream.__aiter__()
    request_complete = False

    # Response.
    status_code: typing.Optional[int] = None
    response_headers = None
    app_error: typing.Optional[Exception] = None
    sentinel = object()
    body_queue: "asyncio.Queue[typing.Any]" = asyncio.Queue()
    response_started = asyncio.Event()
    response_complete = asyncio.Event()
    # Set when the app should be told the client is gone: the response
    # completed, the app exited, or the client closed the body stream early.
    disconnected = asyncio.Event()

    def finish_response() -> None:
        """Mark the body complete (only once) and wake the body consumer."""
        if not response_complete.is_set():
            response_complete.set()
            disconnected.set()
            body_queue.put_nowait(sentinel)

    # ASGI callables.
    async def receive() -> typing.Dict[str, typing.Any]:
        nonlocal request_complete
        if request_complete:
            await disconnected.wait()
            return {"type": "http.disconnect"}
        try:
            body = await request_body_chunks.__anext__()
        except StopAsyncIteration:
            request_complete = True
            return {"type": "http.request", "body": b"", "more_body": False}
        return {"type": "http.request", "body": body, "more_body": True}

    async def send(message: typing.Dict[str, typing.Any]) -> None:
        nonlocal status_code, response_headers
        if message["type"] == "http.response.start":
            assert not response_started.is_set()
            status_code = message["status"]
            response_headers = message.get("headers", [])
            response_started.set()
        elif message["type"] == "http.response.body":
            assert response_started.is_set()
            assert not response_complete.is_set()
            body = message.get("body", b"")
            more_body = message.get("more_body", False)
            if body and request.method != "HEAD":
                # The queue is unbounded, so this never blocks the app.
                body_queue.put_nowait(body)
            if not more_body:
                finish_response()

    async def run_app() -> None:
        nonlocal status_code, response_headers, app_error
        try:
            await self.app(scope, receive, send)
        except Exception as exc:  # noqa: PIE-786
            if self.raise_app_exceptions:
                # Hand the error to the caller / body consumer instead of
                # letting it die unobserved inside the background task.
                app_error = exc
            elif status_code is None:
                # Like httpx's own ASGITransport: a crash before the
                # response started becomes an empty 500 response.
                status_code = 500
                response_headers = []
        finally:
            # Never leave the caller (waiting for the start) or the body
            # consumer (waiting for the sentinel) blocked forever.
            response_started.set()
            finish_response()

    async def body_stream() -> typing.AsyncGenerator[bytes, None]:
        try:
            while True:
                chunk = await body_queue.get()
                if chunk is sentinel:
                    break
                yield chunk
            # Only errors recorded by the time the sentinel is read surface here.
            if app_error is not None:
                raise app_error
        finally:
            # Runs on exhaustion and on an early ``aclose()`` of a started
            # stream, so the app sees ``http.disconnect`` rather than waiting
            # on a client that is no longer reading. (A stream closed before
            # iteration ever began does not reach this block.)
            disconnected.set()

    app_task = asyncio.create_task(run_app())
    _BACKGROUND_TASKS.add(app_task)
    app_task.add_done_callback(_BACKGROUND_TASKS.discard)

    await response_started.wait()
    if app_error is not None:
        raise app_error
    assert status_code is not None
    assert response_headers is not None
    stream = ASGIResponseByteStream(body_stream())
    return Response(status_code, headers=response_headers, stream=stream)
@asynccontextmanager
async def patch_asgi_transport():
    """Temporarily install the streaming ``handle_async_request`` on ``ASGITransport``.

    The original method is restored on exit, including when the body of the
    ``async with`` raises, so a failing test cannot leak the patch into
    later tests.
    """
    restore = ASGITransport.handle_async_request
    ASGITransport.handle_async_request = patch_handle_async_request
    try:
        yield
    finally:
        ASGITransport.handle_async_request = restore
Author
Thanks, this is awesome!
If I may, here is a minimal working FastAPI + pytest example using your code:
"""
Read: - https://github.com/encode/httpx/issues/2186
- https://gist.github.com/richardhundt/17dfccb5c1e253f798999fc2b2417d7e
- https://stackoverflow.com/a/75760884
"""
import httpx
import pytest
import asyncio
from contextlib import asynccontextmanager
import typing
from httpx._models import Request, Response
from httpx._transports.asgi import ASGITransport
from httpx._types import AsyncByteStream
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
# The ASGI application exercised by the tests in this example.
app = FastAPI()
async def fake_data_streamer():
    """Yield ten identical SSE chunks, pausing 0.5 s after each one."""
    event_chunk = b'data: some fake data\n\n'
    for _ in range(10):
        yield event_chunk
        print("write to kk")
        await asyncio.sleep(0.5)
@app.get('/')
async def main():
    """Serve the fake data as a ``text/event-stream`` streaming response."""
    event_source = fake_data_streamer()
    return StreamingResponse(event_source, media_type='text/event-stream')
@pytest.fixture
def anyio_backend():
    """Select the asyncio backend for ``@pytest.mark.anyio`` tests."""
    backend_name = 'asyncio'
    return backend_name
class ASGIResponseByteStream(AsyncByteStream):
    """Expose an async generator of body chunks as an httpx ``AsyncByteStream``.

    Iterating the stream iterates the wrapped generator; closing the stream
    closes the generator.
    """

    def __init__(self, stream: typing.AsyncGenerator[bytes, None]) -> None:
        self._stream = stream

    def __aiter__(self) -> typing.AsyncIterator[bytes]:
        generator = self._stream
        return generator.__aiter__()

    async def aclose(self) -> None:
        close_generator = self._stream.aclose
        await close_generator()
# Strong references to in-flight app tasks. The event loop keeps only weak
# references to tasks, so a task nobody holds on to can be garbage-collected
# before it finishes (see the asyncio.create_task documentation).
_BACKGROUND_TASKS: typing.Set["asyncio.Task[None]"] = set()


async def patch_handle_async_request(
    self: ASGITransport,
    request: Request,
) -> Response:
    """Streaming replacement for ``ASGITransport.handle_async_request``.

    The ASGI app runs in a background task. This coroutine returns as soon as
    the app sends ``http.response.start``; body chunks are passed through a
    queue and yielded by the returned response stream as the app sends them.

    Args:
        self: The transport (this function is patched onto the class).
        request: The outgoing request; its stream must be an ``AsyncByteStream``.

    Returns:
        A ``Response`` whose stream yields body chunks as they are produced.
        If the app crashes before starting the response and
        ``self.raise_app_exceptions`` is false, an empty 500 response.

    Raises:
        Exception: The app's own exception when ``self.raise_app_exceptions``
            is true -- raised here if it happened before the response started,
            otherwise from the body stream if it is recorded before the end of
            the body is read.
    """
    assert isinstance(request.stream, AsyncByteStream)

    # ASGI scope.
    scope = {
        "type": "http",
        "asgi": {"version": "3.0"},
        "http_version": "1.1",
        "method": request.method,
        "headers": [(k.lower(), v) for (k, v) in request.headers.raw],
        "scheme": request.url.scheme,
        "path": request.url.path,
        "raw_path": request.url.raw_path,
        "query_string": request.url.query,
        "server": (request.url.host, request.url.port),
        "client": self.client,
        "root_path": self.root_path,
    }

    # Request.
    request_body_chunks = request.stream.__aiter__()
    request_complete = False

    # Response.
    status_code: typing.Optional[int] = None
    response_headers = None
    app_error: typing.Optional[Exception] = None
    sentinel = object()
    body_queue: "asyncio.Queue[typing.Any]" = asyncio.Queue()
    response_started = asyncio.Event()
    response_complete = asyncio.Event()
    # Set when the app should be told the client is gone: the response
    # completed, the app exited, or the client closed the body stream early.
    disconnected = asyncio.Event()

    def finish_response() -> None:
        """Mark the body complete (only once) and wake the body consumer."""
        if not response_complete.is_set():
            response_complete.set()
            disconnected.set()
            body_queue.put_nowait(sentinel)

    # ASGI callables.
    async def receive() -> typing.Dict[str, typing.Any]:
        nonlocal request_complete
        if request_complete:
            await disconnected.wait()
            return {"type": "http.disconnect"}
        try:
            body = await request_body_chunks.__anext__()
        except StopAsyncIteration:
            request_complete = True
            return {"type": "http.request", "body": b"", "more_body": False}
        return {"type": "http.request", "body": body, "more_body": True}

    async def send(message: typing.Dict[str, typing.Any]) -> None:
        nonlocal status_code, response_headers
        if message["type"] == "http.response.start":
            assert not response_started.is_set()
            status_code = message["status"]
            response_headers = message.get("headers", [])
            response_started.set()
        elif message["type"] == "http.response.body":
            assert response_started.is_set()
            assert not response_complete.is_set()
            body = message.get("body", b"")
            more_body = message.get("more_body", False)
            if body and request.method != "HEAD":
                # The queue is unbounded, so this never blocks the app.
                body_queue.put_nowait(body)
            if not more_body:
                finish_response()

    async def run_app() -> None:
        nonlocal status_code, response_headers, app_error
        try:
            await self.app(scope, receive, send)
        except Exception as exc:  # noqa: PIE-786
            if self.raise_app_exceptions:
                # Hand the error to the caller / body consumer instead of
                # letting it die unobserved inside the background task.
                app_error = exc
            elif status_code is None:
                # Like httpx's own ASGITransport: a crash before the
                # response started becomes an empty 500 response.
                status_code = 500
                response_headers = []
        finally:
            # Never leave the caller (waiting for the start) or the body
            # consumer (waiting for the sentinel) blocked forever.
            response_started.set()
            finish_response()

    async def body_stream() -> typing.AsyncGenerator[bytes, None]:
        try:
            while True:
                chunk = await body_queue.get()
                if chunk is sentinel:
                    break
                yield chunk
            # Only errors recorded by the time the sentinel is read surface here.
            if app_error is not None:
                raise app_error
        finally:
            # Runs on exhaustion and on an early ``aclose()`` of a started
            # stream, so the app sees ``http.disconnect`` rather than waiting
            # on a client that is no longer reading. (A stream closed before
            # iteration ever began does not reach this block.)
            disconnected.set()

    app_task = asyncio.create_task(run_app())
    _BACKGROUND_TASKS.add(app_task)
    app_task.add_done_callback(_BACKGROUND_TASKS.discard)

    await response_started.wait()
    if app_error is not None:
        raise app_error
    assert status_code is not None
    assert response_headers is not None
    stream = ASGIResponseByteStream(body_stream())
    return Response(status_code, headers=response_headers, stream=stream)
@asynccontextmanager
async def patch_asgi_transport():
    """Temporarily install the streaming ``handle_async_request`` on ``ASGITransport``.

    The original method is restored on exit, including when the body of the
    ``async with`` raises, so a failing test cannot leak the patch into
    later tests.
    """
    restore = ASGITransport.handle_async_request
    ASGITransport.handle_async_request = patch_handle_async_request
    try:
        yield
    finally:
        ASGITransport.handle_async_request = restore
@pytest.fixture
async def test_client():
    """Yield an ``httpx.AsyncClient`` wired to ``app`` through the patched transport.

    The streaming patch stays active for the whole lifetime of the client.
    """
    async with patch_asgi_transport(), httpx.AsyncClient(
        base_url="http://testserver", transport=ASGITransport(app=app)
    ) as async_client:
        yield async_client
@pytest.mark.anyio
async def test_stream_reads_incrementally(test_client: 'httpx.AsyncClient'):
    """Read the first SSE line from the streamed response and check its payload."""
    async with test_client.stream("GET", "/") as resp:
        assert resp.status_code == 200
        async for line in resp.aiter_lines():
            assert line.startswith("data:")
            # Slice off the exact prefix: ``str.lstrip("data:")`` strips a
            # *set of characters*, so it would also eat a payload that
            # begins with "a", "d", "t" or ":".
            payload = line[len("data:"):].strip()
            assert payload == "some fake data"
            break
        else:
            # Without this, an empty body would let the test pass vacuously.
            pytest.fail("Did not receive an SSE data line")
# assert got is not None, "Did not receive an SSE data line"
# assert "some fake data" in got
# agen = resp.aiter_bytes()
# first = await agen.__anext__()
# assert first == b"some fake data\n\n"
# await resp.aclose()
Alternatively, ChatGPT-5 suggested using a live uvicorn server:
import uvicorn
import pytest
import asyncio
import httpx
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
# The ASGI application served by the live uvicorn server in this example.
app = FastAPI()
async def fake_data_streamer():
    """Yield ten identical SSE chunks, pausing 0.5 s after each one."""
    event_chunk = b'data: some fake data\n\n'
    for _ in range(10):
        yield event_chunk
        print("write to kk")
        await asyncio.sleep(0.5)
@app.get('/')
async def main():
    """Serve the fake data as a ``text/event-stream`` streaming response."""
    event_source = fake_data_streamer()
    return StreamingResponse(event_source, media_type='text/event-stream')
@pytest.fixture
def anyio_backend():
    """Select the asyncio backend for ``@pytest.mark.anyio`` tests."""
    backend_name = 'asyncio'
    return backend_name
@pytest.fixture
async def live_server():
    """Run ``app`` on a real uvicorn server for the duration of one test.

    Yields the server's base URL. Startup is detected by polling
    ``server.started``; if ``serve()`` finishes before that flag is ever set
    (presumably a failed lifespan startup -- confirm against uvicorn), the
    fixture raises instead of polling forever. The server is asked to exit
    and awaited on teardown.
    """
    host = "127.0.0.1"
    # NOTE(review): fixed port -- parallel test runs on one host will collide.
    port = 8000
    config = uvicorn.Config(app, host=host, port=port, log_level="error", lifespan="on")
    server = uvicorn.Server(config)
    task = asyncio.create_task(server.serve())
    while not server.started:
        if task.done():
            # Re-raise whatever made serve() fail; if it returned cleanly,
            # the server stopped without ever starting.
            task.result()
            raise RuntimeError("uvicorn server exited before it finished starting")
        await asyncio.sleep(0.1)
    try:
        yield f"http://{host}:{port}"
    finally:
        server.should_exit = True
        await task
@pytest.mark.anyio
async def test_true_streaming(live_server):
    """Read the first SSE line from a live server and check its payload.

    ``read=None`` disables the read timeout, so waiting between chunks never
    times out; the other timeouts stay at 5 seconds.
    """
    timeout = httpx.Timeout(5, read=None)
    async with httpx.AsyncClient(base_url=live_server, timeout=timeout) as client:
        # The ``stream`` context manager closes the response on exit.
        async with client.stream("GET", "/") as resp:
            assert resp.status_code == 200  # available as soon as headers arrive
            lines = resp.aiter_lines()
            first = await lines.__anext__()  # returns after the first chunk
            # The line still carries the SSE field name, so comparing
            # ``first.strip()`` with the bare payload could never match.
            assert first.startswith("data:")
            assert first[len("data:"):].strip() == "some fake data"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Example usage: