FastAPI Middlewares
# tracer.py
# Request tracing middleware backed by a TTL cache
import time
import uuid

# types
from typing import List, Optional

from cachetools import TTLCache
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response

# logging
from kairo.utils.log import logger

class CacheMiddleware(BaseHTTPMiddleware):
    """Middleware that records a trace (request, response, timing) for each request."""

    def __init__(self, app, maxsize=512, ttl=10):
        """
        Initialize the CacheMiddleware with a TTLCache.

        Args:
            app (FastAPI): The FastAPI application instance.
            maxsize (int): Maximum number of entries kept in the cache.
            ttl (int): Time-to-live for cache entries in seconds.
        """
        super().__init__(app)
        # TTLCache's maxsize counts entries, not kilobytes, so no unit conversion is needed
        self._cache = TTLCache(maxsize=maxsize, ttl=ttl)
        self._ttl = ttl

    async def _trace_start(self, request: Request):
        """Start tracing the request and return its generated id."""
        request_id = uuid.uuid4().hex
        logger.info(f"<- Caching Request {request_id} - {request.method} {request.url}")
        # store the trace entry keyed by the request id
        self._cache[request_id] = {
            "id": request_id,
            "method": request.method,
            "url": request.url,
            "timestamp": time.time(),
            "ttl": self._ttl,
            "status": "pending",
            "request": request,
            "response": None,
        }
        return request_id

    async def _trace_stop(
        self, request_id: str, response: Response, status: str = "completed"
    ):
        """Stop tracing the request"""
        # check request exists in cache
        if request_id in self._cache:
            self._cache[request_id]["response"] = response
            self._cache[request_id]["status"] = status
            self._cache[request_id]["duration"] = (
                time.time() - self._cache[request_id]["timestamp"]
            )
        logger.info(f"-> Caching Response: {request_id} - {response.status_code}")

    async def get_trace(self, request_id: str) -> Optional[dict]:
        """Get the trace of a request"""
        return self._cache.get(request_id)

    async def get_traces(self) -> List[dict]:
        """Get all traces"""
        return list(self._cache.values())

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """Trace the request, forward it to the app, and trace the response."""
        # refer to https://stackoverflow.com/questions/71525132/how-to-write-a-custom-fastapi-middleware-class
        # and https://davidmuraya.com/blog/adding-middleware-to-fastapi-applications/
        # trace request
        request_id = await self._trace_start(request)
        # call the next middleware or route handler (this forwards the request)
        response = await call_next(request)
        # trace response
        await self._trace_stop(request_id, response)
        # expose the trace id to the client
        response.headers["X-Request-Id"] = request_id
        return response
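
A minimal usage sketch, assuming the file above is importable as tracer (the app and the /ping route below are illustrative, not part of the gist):

# main.py
from fastapi import FastAPI
from tracer import CacheMiddleware

app = FastAPI()
# keep up to 512 traces, each alive for 10 seconds
app.add_middleware(CacheMiddleware, maxsize=512, ttl=10)

@app.get("/ping")
async def ping():
    return {"ok": True}

Every response then carries an X-Request-Id header whose value keys into the middleware's trace cache. Note that add_middleware instantiates the class itself, so calling get_trace/get_traces from application code requires keeping your own reference to the instance, for example by wrapping the app manually with CacheMiddleware(app).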