Skip to content

Instantly share code, notes, and snippets.

@beatzxbt
Created December 20, 2024 13:05
Show Gist options
  • Select an option

  • Save beatzxbt/98d36b652e163e2a1dd0ee113eb0066b to your computer and use it in GitHub Desktop.

Select an option

Save beatzxbt/98d36b652e163e2a1dd0ee113eb0066b to your computer and use it in GitHub Desktop.
import xxhash
import asyncio
import platform
from collections import deque
from warnings import warn as warning
from typing import List, Dict, Optional
from mm_toolbox.time import time_ns, time_s
from .conn.raw cimport WsConnection
from libc.stdint cimport uint8_t, int64_t, uint64_t, UINT64_MAX
cdef class PoolQueue:
    """
    Bounded FIFO queue of (message, hash) pairs with O(1) duplicate
    detection via a companion set of the hashes currently enqueued.
    """

    def __init__(self, uint64_t max_size=UINT64_MAX):
        """
        Initializes the PoolQueue with a specified maximum size.

        Parameters
        ----------
        max_size : uint64_t
            Maximum number of items the queue may hold (default UINT64_MAX).

        Raises
        ------
        ValueError
            If max_size is 0 (the parameter is unsigned, so 0 is the only
            reachable failing case).
        """
        if max_size <= 0:
            raise ValueError("max_size must be greater than 0.")
        self.max_size = max_size
        self.current_size = 0
        self._queue = deque()
        self._latest_hashes = set()

    cpdef uint64_t generate_hash(self, bytes msg):
        """
        Generates a 64-bit hash for the given message using xxhash.
        """
        return xxhash.xxh3_64_intdigest(msg)

    cpdef bint is_unique(self, uint64_t hash_value):
        """
        Checks if the hash is unique (not already in the queue).
        """
        return hash_value not in self._latest_hashes

    cpdef bint is_empty(self):
        """
        Checks if the queue is empty.
        """
        return self.current_size == 0

    cpdef void put_item(self, bytes item, uint64_t hash_value):
        """
        Adds an item to the queue if there is space. Raises IndexError
        if the queue is full.
        """
        if self.current_size < self.max_size:
            self._queue.append((item, hash_value))
            self._latest_hashes.add(hash_value)
            self.current_size += 1
        else:
            raise IndexError("PoolQueue is full. Cannot add new item.")

    cpdef void put_item_with_overwrite(self, bytes item, uint64_t hash_value):
        """
        Adds an item to the queue, overwriting the oldest item if the queue
        is full.
        """
        cdef:
            bytes old_item
            uint64_t old_hash_value

        if self.current_size >= self.max_size:
            old_item, old_hash_value = self._queue.popleft()
            self._latest_hashes.remove(old_hash_value)
            self.current_size -= 1

        self.put_item(item, hash_value)

    cpdef bytes take_item(self):
        """
        Removes and returns the oldest item from the queue.
        Raises IndexError if the queue is empty.
        """
        if self.current_size <= 0:
            raise IndexError("PoolQueue is empty. Cannot take an item.")

        cdef:
            bytes item
            uint64_t hash_value

        item, hash_value = self._queue.popleft()
        self._latest_hashes.remove(hash_value)
        self.current_size -= 1
        return item

    cpdef list[bytes] take_all(self):
        """
        Empties the queue into a list of raw messages and resets all
        internal state. Returns [] if the queue is empty.
        """
        if self.current_size <= 0:
            return []

        cdef:
            uint64_t i
            bytes item
            uint64_t hash_value
            list[bytes] all_items = [None] * <Py_ssize_t>self.current_size

        for i in range(0, self.current_size):
            # BUG FIX: popleft() yields an (item, hash) tuple; unpack it so
            # the returned list contains bytes, consistent with take_item().
            # The original stored the raw tuple into a list[bytes].
            item, hash_value = self._queue.popleft()
            all_items[i] = item

        self.current_size = 0
        self._latest_hashes.clear()
        return all_items
cdef class WsPoolConfig:
    """
    Configuration for WebSocket connection eviction and latency measurement policies.

    Attributes
    ----------
    evict_interval_s : int
        The interval in seconds at which eviction conditions are checked.
    evict_events_threshold : int
        The number of events after which eviction conditions are checked.
    latency_interval_s : int
        The interval in seconds at which latency is measured.
    latency_events_threshold : int
        The number of events after which latency is measured.
    """
    def __init__(
        self,
        uint64_t evict_interval_s=60,
        uint64_t evict_events_threshold=3000,
        uint64_t latency_interval_s=1,
        uint64_t latency_events_threshold=500,
    ):
        self.evict_interval_s = evict_interval_s
        self.evict_events_threshold = evict_events_threshold
        self.latency_interval_s = latency_interval_s
        self.latency_events_threshold = latency_events_threshold

        # Latency measurement must not be coarser than eviction checks;
        # clamp (and warn) rather than raising so callers keep working.
        if self.latency_interval_s > self.evict_interval_s:
            warning(
                "Latency interval cannot be greater than eviction interval. "
                "Adjusting latency interval to eviction interval.",
                RuntimeWarning
            )
            self.latency_interval_s = self.evict_interval_s

        # Same clamp-and-warn policy for the event-count thresholds.
        if self.latency_events_threshold > self.evict_events_threshold:
            warning(
                "Latency events threshold cannot be greater than eviction events threshold. "
                "Adjusting latency events threshold to eviction events threshold.",
                RuntimeWarning
            )
            self.latency_events_threshold = self.evict_events_threshold
cdef class WsPool:
    """
    Manages a pool of fast WebSocket connections.

    Frames arriving on any connection are de-duplicated through a shared
    PoolQueue; a background task periodically evicts the slowest quartile
    of connections and marks the fastest quartile for outbound sends.
    """
    def __init__(
        self,
        uint8_t size,
        function ws_handler,
        object logger,
        WsPoolConfig config=None,
    ) -> None:
        """
        Parameters
        ----------
        size : uint8_t
            Number of WebSocket connections in the pool (clamped to >= 2).
        ws_handler : function
            User callback invoked with each unique raw message (bytes).
        logger : object
            Logger forwarded to each WsConnection.
        config : WsPoolConfig, optional
            Eviction/latency policy (default: WsPoolConfig()).

        Raises
        ------
        ImportError
            If the required high-performance event loop (winloop on
            Windows, uvloop elsewhere) is not installed.
        """
        self._size = size
        if self._size <= 1:
            warning("Pool size cannot be <2, defaulting to 2.", RuntimeWarning)
            self._size = 2

        self._user_ws_handler = ws_handler
        self._logger = logger
        self._config = config if config is not None else WsPoolConfig()
        self._queue = PoolQueue()
        self._conns: dict[uint64_t, WsConnection] = {}
        self._fast_conns: set[uint64_t] = set()
        self._last_conn_eviction_time = 0.0

        self._msg_ingress_task: asyncio.Task = None
        self._conn_eviction_task: asyncio.Task = None
        self._is_running: bool = False

        # Require high performance event loop (OS spec)
        if platform.system() == "Windows":
            try:
                import winloop
                asyncio.set_event_loop_policy(winloop.EventLoopPolicy())
            except ImportError:
                raise ImportError(
                    "Requires 'winloop', install with 'pip install winloop'"
                )
        else:
            try:
                import uvloop
                asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
            except ImportError:
                raise ImportError(
                    "Requires 'uvloop', install with 'pip install uvloop'"
                )

    cdef inline uint64_t _generate_conn_id(self):
        """
        Connection IDs generated with nanosecond timestamps are guaranteed
        to be unique. Simple and works!
        """
        return time_ns()

    cdef inline void _process_ws_frame(self, uint64_t seq_id, double time, memoryview frame):
        """
        Feeds data from ws connections into the queue if it is unique.
        """
        cdef:
            bytes msg = frame.tobytes()
            uint64_t msg_hash = self._queue.generate_hash(msg)

        # For now, the seq_id and time fields are ignored.
        # These are implemented for the future in more
        # advanced use cases, but for simplicity, are
        # disabled at the moment.
        if self._queue.is_unique(msg_hash):
            self._queue.put_item(msg, msg_hash)

    async def _msg_ingress(self) -> None:
        """
        Ingests queue data and feeds it to a user-provided
        processing function.
        """
        while self._is_running:
            try:
                # Avoid IndexError by checking size first.
                if not self._queue.is_empty():
                    msg = self._queue.take_item()
                    self._user_ws_handler(msg)
                # Yield control back to the event loop.
                else:
                    await asyncio.sleep(0)
                    continue
            except asyncio.CancelledError:
                return
            except Exception:
                # Best-effort by design: a failing user handler must not
                # kill ingestion. NOTE(review): consider logging here —
                # errors are currently dropped silently.
                pass

    async def _conn_eviction(self) -> None:
        """
        Enforces the config to evict slow connections.
        """
        # Initial sleep for connections to warm up with data.
        # BUG FIX: WsPoolConfig exposes evict_interval_s / evict_events_threshold;
        # the original referenced nonexistent .interval / .events attributes.
        await asyncio.sleep(self._config.evict_interval_s)
        next_eviction_time = time_s() + self._config.evict_interval_s

        # Honor the shutdown flag (the original looped forever).
        while self._is_running:
            over_time_limit = time_s() >= next_eviction_time
            over_event_limit = self._queue.current_size >= self._config.evict_events_threshold

            # If neither conditions are met, sleep for a small duration.
            # Checks are very cheap, so we can do them often. Hence the sleep
            # is set to a low 10ms. Can be increased in the future if deemed
            # unnecessarily fast.
            if not (over_time_limit or over_event_limit):
                await asyncio.sleep(0.01)
                continue

            # A list of [(conn_id, conn_latency)] for each connection,
            # sorted ascending so the fastest connections come first.
            conn_latency_ranking = [
                (conn_id, conn.get_mean_latency())
                for conn_id, conn in self._conns.items()
            ]
            conn_latency_ranking.sort(key=lambda x: x[1])

            quartile = max(1, self._size // 4)
            # BUG FIX: with an ascending sort, the fastest quartile is at the
            # FRONT and the slowest at the BACK. The original slices were
            # inverted and would have closed ~75% of the pool, including the
            # fastest connections.
            fast_conns = conn_latency_ranking[:quartile]
            slow_conns = conn_latency_ranking[-quartile:]

            for conn_id, _latency in slow_conns:
                # BUG FIX: the ranking holds (conn_id, latency) pairs, so the
                # second element is a float — look the connection up before
                # closing it, and drop it from the pool so it is not reused.
                # TODO(review): evicted connections are not replaced; the pool
                # shrinks over time.
                self._conns[conn_id].close()
                del self._conns[conn_id]
                self._fast_conns.discard(conn_id)

            for conn_id, _latency in fast_conns:
                self._fast_conns.add(conn_id)

            # Reset eviction timer and active queue size. We manually
            # filter through the queue and process each message before
            # clearing to prevent data loss. This purposely blocks,
            # letting the websockets back up data before resuming.
            #
            # If you get backpressure issues, speed up your processing!
            next_eviction_time = time_s() + self._config.evict_interval_s
            unprocessed_msgs = self._queue.take_all()
            for msg in unprocessed_msgs:
                self._user_ws_handler(msg)

    async def _open_new_conn(
        self,
        str url,
        list[dict] on_connect=None,
    ):
        """
        Establishes a new WebSocket connection, adds it to the connection pool,
        and begins data ingestion.

        Parameters
        ----------
        url : str
            The WebSocket URL to connect to.
        on_connect : Optional[List[Dict]], optional
            List of payloads to send upon connecting (default is None).
        """
        new_conn_id = self._generate_conn_id()
        new_conn = WsConnection(logger=self._logger, conn_id=new_conn_id)
        # BUG FIX: the original awaited self._conns[new_conn_id].start(...)
        # before the connection was inserted into the dict -> KeyError.
        # Register first, then start the local object.
        self._conns[new_conn_id] = new_conn
        await new_conn.start(
            url=url,
            on_connect=on_connect,
        )

    async def start(
        self,
        str url,
        list[dict] on_connect=None,
    ):
        """
        Starts all WebSocket connections in the pool and spawns the
        background ingress and eviction tasks.

        Parameters
        ----------
        url : str
            The WebSocket URL to connect to.
        on_connect : Optional[List[Dict]], optional
            List of payloads to send upon connecting (default is None).

        Raises
        ------
        RuntimeError
            If the pool is already running.
        """
        if self._is_running:
            raise RuntimeError("Connection already running.")

        self._is_running = True
        await asyncio.gather(
            *[
                self._open_new_conn(url, on_connect)
                for _ in range(self._size)
            ]
        )
        # BUG FIX: use the task attributes declared in __init__; the original
        # referenced a nonexistent _enforce_eviction_policy() and never
        # started the _msg_ingress consumer at all.
        self._msg_ingress_task = asyncio.create_task(self._msg_ingress())
        self._conn_eviction_task = asyncio.create_task(self._conn_eviction())

    cpdef void send_data(self, msg):
        """
        Sends a payload through the known-fast connections, falling back to
        every connection in the pool before the first latency ranking exists.
        """
        if len(self._fast_conns) > 0:
            for conn_id in self._fast_conns:
                self._conns[conn_id].send_data(msg)
        else:
            for conn in self._conns.values():
                conn.send_data(msg)

    cpdef void shutdown(self):
        """
        Shuts down all WebSocket connections and stops the background tasks.
        """
        self._is_running = False

        # Tasks should auto stop with _is_running set to False, but it
        # doesn't hurt to be sure. Guard against a pool that was never
        # started (tasks still None). BUG FIX: the original referenced
        # _eviction_task/_data_ingress_task, which __init__ never defines.
        if self._conn_eviction_task is not None and not self._conn_eviction_task.done():
            self._conn_eviction_task.cancel()
        if self._msg_ingress_task is not None and not self._msg_ingress_task.done():
            self._msg_ingress_task.cancel()

        # BUG FIX: iterate values(), not items() — items() yields
        # (conn_id, conn) tuples, which have no close() method.
        for conn in self._conns.values():
            conn.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment