Created
December 20, 2024 13:05
-
-
Save beatzxbt/98d36b652e163e2a1dd0ee113eb0066b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import xxhash | |
| import asyncio | |
| import platform | |
| from collections import deque | |
| from warnings import warn as warning | |
| from typing import List, Dict, Optional | |
| from mm_toolbox.time import time_ns, time_s | |
| from .conn.raw cimport WsConnection | |
| from libc.stdint cimport uint8_t, int64_t, uint64_t, UINT64_MAX | |
cdef class PoolQueue:
    """
    A bounded FIFO queue of raw message bytes with hash-based deduplication.

    Each stored entry is an (item, hash) pair; the set of hashes currently
    in the queue is tracked separately so duplicate messages (e.g. the same
    frame arriving over several pooled connections) can be rejected in O(1).
    """

    def __init__(self, uint64_t max_size=UINT64_MAX):
        """
        Initializes the PoolQueue with a specified maximum size.

        Parameters
        ----------
        max_size : uint64_t
            Maximum number of items the queue may hold; defaults to
            effectively unbounded (UINT64_MAX).

        Raises
        ------
        ValueError
            If max_size is 0.
        """
        # max_size is unsigned, so zero is the only invalid value
        # (the original `<= 0` comparison could never be negative).
        if max_size == 0:
            raise ValueError("max_size must be greater than 0.")

        self.max_size = max_size
        self.current_size = 0
        self._queue = deque()
        self._latest_hashes = set()

    cpdef uint64_t generate_hash(self, bytes msg):
        """
        Generates a 64-bit hash for the given message using xxhash.
        """
        return xxhash.xxh3_64_intdigest(msg)

    cpdef bint is_unique(self, uint64_t hash_value):
        """
        Checks if the hash is unique (not already in the queue).
        """
        return hash_value not in self._latest_hashes

    cpdef bint is_empty(self):
        """
        Checks if the queue is empty.
        """
        return self.current_size == 0

    cpdef void put_item(self, bytes item, uint64_t hash_value):
        """
        Adds an item to the queue if there is space.

        Raises
        ------
        IndexError
            If the queue is already at max_size.
        """
        if self.current_size < self.max_size:
            self._queue.append((item, hash_value))
            self._latest_hashes.add(hash_value)
            self.current_size += 1
        else:
            raise IndexError("PoolQueue is full. Cannot add new item.")

    cpdef void put_item_with_overwrite(self, bytes item, uint64_t hash_value):
        """
        Adds an item to the queue, overwriting the oldest item if the queue
        is full.
        """
        cdef:
            bytes old_item
            uint64_t old_hash_value

        if self.current_size >= self.max_size:
            old_item, old_hash_value = self._queue.popleft()
            self._latest_hashes.remove(old_hash_value)
            self.current_size -= 1

        self.put_item(item, hash_value)

    cpdef bytes take_item(self):
        """
        Removes and returns the oldest item from the queue.

        Raises
        ------
        IndexError
            If the queue is empty.
        """
        if self.current_size == 0:
            raise IndexError("PoolQueue is empty. Cannot take an item.")

        cdef:
            bytes item
            uint64_t hash_value

        item, hash_value = self._queue.popleft()
        self._latest_hashes.remove(hash_value)
        self.current_size -= 1
        return item

    cpdef list[bytes] take_all(self):
        """
        Empties the queue into a list of raw messages (oldest first) and
        resets all internal state.
        """
        if self.current_size == 0:
            return []

        cdef:
            uint64_t i
            bytes item
            uint64_t hash_value
            list[bytes] all_items = [None] * <Py_ssize_t>self.current_size

        # BUG FIX: the deque stores (item, hash) tuples; unpack each entry
        # so the returned list really contains bytes as the signature and
        # docstring promise (previously the raw tuples leaked out).
        for i in range(0, self.current_size):
            item, hash_value = self._queue.popleft()
            all_items[i] = item

        self.current_size = 0
        self._latest_hashes.clear()
        return all_items
cdef class WsPoolConfig:
    """
    Policy knobs controlling when pooled connections are evicted and how
    often their latency is sampled.

    Attributes
    ----------
    evict_interval_s : int
        The interval in seconds at which eviction conditions are checked.
    evict_events_threshold : int
        The number of events after which eviction conditions are checked.
    latency_interval_s : int
        The interval in seconds at which latency is measured.
    latency_events_threshold : int
        The number of events after which latency is measured.
    """

    def __init__(
        self,
        uint64_t evict_interval_s=60,
        uint64_t evict_events_threshold=3000,
        uint64_t latency_interval_s=1,
        uint64_t latency_events_threshold=500,
    ):
        # Latency sampling may never run on a coarser schedule than
        # eviction; clamp both latency settings (warning the caller)
        # before storing anything.
        if latency_interval_s > evict_interval_s:
            warning(
                "Latency interval cannot be greater than eviction interval. "
                "Adjusting latency interval to eviction interval.",
                RuntimeWarning
            )
            latency_interval_s = evict_interval_s

        if latency_events_threshold > evict_events_threshold:
            warning(
                "Latency events threshold cannot be greater than eviction events threshold. "
                "Adjusting latency events threshold to eviction events threshold.",
                RuntimeWarning
            )
            latency_events_threshold = evict_events_threshold

        self.evict_interval_s = evict_interval_s
        self.evict_events_threshold = evict_events_threshold
        self.latency_interval_s = latency_interval_s
        self.latency_events_threshold = latency_events_threshold
cdef class WsPool:
    """
    Manages a pool of fast WebSocket connections.

    Every connection is pointed at the same URL; incoming frames are
    deduplicated through a PoolQueue so the user handler sees each unique
    message exactly once, from whichever connection delivered it first.
    The slowest connections are periodically evicted and replaced.
    """

    def __init__(
        self,
        uint8_t size,
        function ws_handler,
        object logger,
        WsPoolConfig config=None,
    ) -> None:
        """
        Parameters
        ----------
        size : uint8_t
            Number of connections to maintain (minimum 2; smaller values
            are bumped with a RuntimeWarning).
        ws_handler : function
            User callback invoked with each unique message (bytes).
        logger : object
            Logger handed to each WsConnection.
        config : WsPoolConfig, optional
            Eviction/latency policy; defaults to WsPoolConfig().

        Raises
        ------
        ImportError
            If the platform-appropriate fast event loop (winloop/uvloop)
            is not installed.
        """
        self._size = size
        if self._size <= 1:
            warning("Pool size cannot be <2, defaulting to 2.", RuntimeWarning)
            self._size = 2

        self._user_ws_handler = ws_handler
        self._logger = logger
        self._config = config if config is not None else WsPoolConfig()

        self._queue = PoolQueue()
        self._conns: dict[uint64_t, WsConnection] = {}
        self._fast_conns: set[uint64_t] = set()
        self._last_conn_eviction_time = 0.0

        # Remembered from start() so evicted connections can be replaced.
        self._url: str = None
        self._on_connect: list = None

        # Names below are the single source of truth; start()/shutdown()
        # previously used mismatched attribute names (_eviction_task,
        # _data_ingress_task) and would have raised AttributeError.
        self._msg_ingress_task: asyncio.Task = None
        self._conn_eviction_task: asyncio.Task = None
        self._is_running: bool = False

        # Require high performance event loop (OS spec)
        if platform.system() == "Windows":
            try:
                import winloop
                asyncio.set_event_loop_policy(winloop.EventLoopPolicy())
            except ImportError:
                raise ImportError(
                    "Requires 'winloop', install with 'pip install winloop'"
                )
        else:
            try:
                import uvloop
                asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
            except ImportError:
                raise ImportError(
                    "Requires 'uvloop', install with 'pip install uvloop'"
                )

    cdef inline uint64_t _generate_conn_id(self):
        """
        Connection IDs generated with nanosecond timestamps are guaranteed
        to be unique. Simple and works!
        """
        return time_ns()

    cdef inline void _process_ws_frame(self, uint64_t seq_id, double time, memoryview frame):
        """
        Feeds data from ws connections into the queue if it is unique.
        """
        cdef:
            bytes msg = frame.tobytes()
            uint64_t msg_hash = self._queue.generate_hash(msg)

        # For now, the seq_id and time fields are ignored.
        # These are implemented for the future in more
        # advanced use cases, but for simplicity, are
        # disabled at the moment.
        if self._queue.is_unique(msg_hash):
            self._queue.put_item(msg, msg_hash)

    async def _msg_ingress(self) -> None:
        """
        Ingests queue data and feeds it to a user-provided
        processing function.
        """
        while self._is_running:
            try:
                # Avoid IndexError by checking size first.
                if not self._queue.is_empty():
                    msg = self._queue.take_item()
                    self._user_ws_handler(msg)
                # Yield control back to the event loop.
                else:
                    await asyncio.sleep(0)
                    continue

            except asyncio.CancelledError:
                return

            except Exception:
                # Deliberately best-effort: a failing user handler must not
                # kill ingestion. NOTE(review): consider logging via
                # self._logger here — failures are currently invisible.
                pass

    async def _conn_eviction(self) -> None:
        """
        Enforces the config to evict slow connections, replacing each
        evicted connection with a fresh one so the pool stays at full
        strength.
        """
        quartile = max(1, self._size // 4)

        # Initial sleep for connections to warm up with data.
        # (Fixed: config attributes are evict_interval_s /
        # evict_events_threshold; `interval` / `events` never existed.)
        await asyncio.sleep(self._config.evict_interval_s)

        next_eviction_time = time_s() + self._config.evict_interval_s

        while self._is_running:
            over_time_limit = time_s() >= next_eviction_time
            over_event_limit = self._queue.current_size >= self._config.evict_events_threshold

            # If neither condition is met, sleep for a small duration.
            # Checks are very cheap, so we can do them often. Hence the sleep
            # is set to a low 10ms. Can be increased in the future if deemed
            # unnecessarily fast.
            if not (over_time_limit or over_event_limit):
                await asyncio.sleep(0.01)
                continue

            # Rank every connection by mean latency, ascending: the first
            # entries are the FASTEST. Keep the top quartile as the fast
            # set and terminate the bottom quartile. (Fixed: the original
            # slices were inverted and closed the fast connections.)
            conn_latency_ranking = [
                (conn_id, conn.get_mean_latency())
                for conn_id, conn in self._conns.items()
            ]
            conn_latency_ranking.sort(key=lambda x: x[1])

            fast_conns = conn_latency_ranking[:quartile]
            slow_conns = conn_latency_ranking[-quartile:]

            for conn_id, _ in slow_conns:
                self._conns[conn_id].close()
                # Drop the closed connection so send_data() never touches it.
                del self._conns[conn_id]
                self._fast_conns.discard(conn_id)

            for conn_id, _ in fast_conns:
                self._fast_conns.add(conn_id)

            # Replace the evicted connections to keep the pool at _size.
            await asyncio.gather(
                *[
                    self._open_new_conn(self._url, self._on_connect)
                    for _ in range(len(slow_conns))
                ]
            )

            # Reset eviction timer and active queue size. We manually
            # filter through the queue and process each message before
            # clearing to prevent data loss. This purposely blocks,
            # letting the websockets back up data before resuming.
            #
            # If you get backpressure issues, speed up your processing!
            next_eviction_time = time_s() + self._config.evict_interval_s
            unprocessed_msgs = self._queue.take_all()
            for msg in unprocessed_msgs:
                self._user_ws_handler(msg)

    async def _open_new_conn(
        self,
        str url,
        list[dict] on_connect=None,
    ):
        """
        Establishes a new WebSocket connection, adds it to the connection pool,
        and begins data ingestion.

        Parameters
        ----------
        url : str
            The WebSocket URL to connect to.
        on_connect : Optional[List[Dict]], optional
            List of payloads to send upon connecting (default is None).
        """
        new_conn_id = self._generate_conn_id()
        new_conn = WsConnection(logger=self._logger, conn_id=new_conn_id)

        # Start the local object, THEN register it. The original awaited
        # self._conns[new_conn_id].start(...) before inserting the id,
        # which raised KeyError on every call.
        await new_conn.start(
            url=url,
            on_connect=on_connect,
        )
        self._conns[new_conn_id] = new_conn

    async def start(
        self,
        str url,
        list[dict] on_connect=None,
    ):
        """
        Starts all WebSocket connections in the pool plus the background
        ingress and eviction tasks.

        Parameters
        ----------
        url : str
            The WebSocket URL to connect to.
        on_connect : Optional[List[Dict]], optional
            List of payloads to send upon connecting (default is None).

        Raises
        ------
        RuntimeError
            If the pool is already running.
        """
        if self._is_running:
            raise RuntimeError("Connection already running.")

        self._is_running = True
        self._url = url
        self._on_connect = on_connect

        await asyncio.gather(
            *[
                self._open_new_conn(url, on_connect)
                for _ in range(self._size)
            ]
        )

        # Fixed: the original referenced a nonexistent method
        # (_enforce_eviction_policy) and never started message ingress.
        self._msg_ingress_task = asyncio.create_task(self._msg_ingress())
        self._conn_eviction_task = asyncio.create_task(self._conn_eviction())

    cpdef void send_data(self, msg):
        """
        Sends a payload through the known-fast WebSocket connections,
        falling back to all connections before any ranking exists.
        """
        if len(self._fast_conns) > 0:
            for conn_id in self._fast_conns:
                self._conns[conn_id].send_data(msg)
        else:
            for conn in self._conns.values():
                conn.send_data(msg)

    cpdef void shutdown(self):
        """
        Shuts down all WebSocket connections and stops the background tasks.
        """
        self._is_running = False

        # Tasks should auto-stop once _is_running is False, but cancel
        # defensively in case they are parked on an await. Guard against
        # shutdown() being called before start().
        if self._msg_ingress_task is not None and not self._msg_ingress_task.done():
            self._msg_ingress_task.cancel()

        if self._conn_eviction_task is not None and not self._conn_eviction_task.done():
            self._conn_eviction_task.cancel()

        # Fixed: iterating .items() yielded (id, conn) tuples and called
        # .close() on the tuple; .values() gives the connections.
        for conn in self._conns.values():
            conn.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment