Last active
May 19, 2025 21:35
-
-
Save fabian-paul/e936f7e1a498e18eec152b68e7f9a73a to your computer and use it in GitHub Desktop.
publish-subscribe with Starlette
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import typing | |
| import asyncio | |
| import uvicorn | |
| import random | |
| import weakref | |
| from pydantic import BaseModel | |
| #from starlette.applications import Starlette | |
| #from starlette.routing import Route | |
| from fastapi import FastAPI, Query | |
| from sse_starlette.sse import EventSourceResponse | |
| #from pydantic import BaseModel | |
| #class Subscription(BaseModel): | |
| # file_ids: typing.List[int] | |
| # Use Bloom filter for subscribers? https://github.com/remram44/python-bloom-filter#readme | |
class Subscriber:
    """One connected client's subscription to a collection of file ids.

    Matching notifications are buffered in an ``asyncio.Queue`` and handed to
    the client through :meth:`event_generator`.
    """

    def __init__(self, file_ids, pusher_ref=None):
        """Create a subscriber.

        Args:
            file_ids: Container of file ids this subscriber is interested in
                (only ``in`` membership tests are performed on it).
            pusher_ref: Optional zero-argument callable (e.g. a
                ``weakref.ref``) returning the owning pusher or ``None``.
                ``Pusher.add_subscriber`` overwrites this attribute.
        """
        self.file_ids = file_ids
        self.queue = asyncio.Queue()
        self.pusher_ref = pusher_ref
        # Running event id; incremented once per queued event.
        self.counter = 0

    def notify(self, file_id):
        """Queue a ``file_operation`` event if ``file_id`` is subscribed to."""
        if file_id not in self.file_ids:
            return
        self.queue.put_nowait(
            dict(
                data=f'{{"file_id": "{file_id}"}}',
                event="file_operation",
                id=self.counter,
            )
        )
        self.counter += 1

    def remove_from_pusher(self):
        """Unregister this subscriber from its pusher, if it still exists."""
        if not self.pusher_ref:
            print("Subscriber wasn't registered with any Pusher.")
            return
        pusher = self.pusher_ref()
        if pusher is None:
            print("Pusher reference is already disconnected.")
            return
        pusher.remove_subscriber(self)

    async def event_generator(self):
        """Yield queued events forever, unregistering when the stream ends.

        Yields:
            dict: Events with ``data``, ``event`` and ``id`` keys, in the
            order they were queued by :meth:`notify`.

        Raises:
            asyncio.CancelledError: Re-raised after logging when the consuming
                task is cancelled.
        """
        try:
            while True:
                yield await self.queue.get()
        # see https://github.com/sysid/sse-starlette
        except asyncio.CancelledError:
            print("Disconnected from client (via refresh/close)")
            raise
        finally:
            # Runs for cancellation *and* for aclose()/GeneratorExit or any
            # other error, so the subscriber never stays registered (and
            # keeps filling its queue) after the stream has ended.
            self.remove_from_pusher()
class Pusher:
    """Keeps a list of subscribers and fans file notifications out to them."""

    def __init__(self):
        self.subscribers = []

    def do_work(self, file_id):
        """Log the (mock) work for ``file_id`` and notify all subscribers."""
        print(f"Doing work for file with id {file_id}.")
        self.notify_subscribers(file_id)

    def notify_subscribers(self, file_id=1):
        """Call ``notify(file_id)`` on every registered subscriber, in order."""
        for listener in self.subscribers:
            listener.notify(file_id)

    def add_subscriber(self, subscriber):
        """Register ``subscriber`` and give it a weak back-reference to us."""
        back_reference = weakref.ref(self)
        self.subscribers.append(subscriber)
        subscriber.pusher_ref = back_reference

    def remove_subscriber(self, subscriber):
        """Drop the first registration of ``subscriber``; log the outcome."""
        try:
            self.subscribers.remove(subscriber)
        except ValueError:
            print("failed deleting subscriber facade")
        else:
            print("successfully deleted subscriber facade")
# Module-level Pusher instance shared by the background producer
# (mock_file_processing) and the /pubsub/ endpoint.
pusher = Pusher()
async def mock_file_processing():
    """Simulate file work forever.

    Each iteration picks a random file id in the inclusive range 0..20, hands
    it to the module-level ``pusher`` and then sleeps for one second.
    """
    while True:
        file_id = random.randint(0, 20)
        pusher.do_work(file_id)
        await asyncio.sleep(1)
| #routes = [ | |
| # Route('/pubsub/', sse), | |
| #] | |
# Strong references to background tasks. The event loop itself only keeps
# weak references, so a task nobody else references may be garbage-collected
# before it finishes.
_background_tasks = set()


async def startup_event():
    """Startup hook: schedule the mock file-processing loop as a task.

    The task is stored in ``_background_tasks`` until it is done so that it
    cannot be collected while still running.
    """
    # see https://github.com/encode/starlette/issues/915
    print("mock file processing not started yet, starting it now.")
    # We are inside a coroutine, so a loop is running; asyncio.create_task
    # schedules on it without the deprecated get_event_loop() lookup.
    task = asyncio.create_task(mock_file_processing())
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
# FastAPI application (debug mode on); startup_event is registered as a
# startup hook, which is where the mock producer task gets scheduled.
app = FastAPI(debug=True, on_startup=[startup_event])
# OpenAPI schema for a single server-sent event as documented for /pubsub/.
_SSE_EVENT_SCHEMA = {
    "type": "object",
    "properties": {
        "id": {"type": "string"},  # TODO
        "event": {"type": "string"},  # TODO
        "data": {"type": "string"},  # TODO
    },
}

# OpenAPI "responses" entry: the 200 response is a text/event-stream whose
# items follow _SSE_EVENT_SCHEMA.
_SSE_RESPONSES = {
    200: {
        "content": {
            "text/event-stream": {
                "schema": {
                    "type": "array",
                    "format": "text/event-stream",
                    "items": _SSE_EVENT_SCHEMA,
                },
            },
        },
    },
}


@app.get("/pubsub/", response_class=EventSourceResponse, responses=_SSE_RESPONSES)
async def sse(file_ids: typing.List[int] = Query(default=...)) -> EventSourceResponse:
    """Subscribe the caller to events for the given file ids.

    A new ``Subscriber`` is created for the required ``file_ids`` query
    parameter, registered with the module-level ``pusher``, and its event
    generator is returned wrapped in an ``EventSourceResponse``.
    """
    # see https://github.com/sysid/sse-starlette
    client = Subscriber(file_ids=file_ids)
    pusher.add_subscriber(client)
    return EventSourceResponse(client.event_generator())
# Serve the app with uvicorn on port 8000 when this file is run as a script.
# NOTE(review): host "0.0.0.0" listens on all interfaces — confirm that is
# intended outside local testing, especially with debug=True on the app.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level='info')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment