Created
September 30, 2025 02:48
-
-
Save etrotta/ee957810cebcf8b5c004394b472b8380 to your computer and use it in GitHub Desktop.
Demonstration of Polars streaming data via 1.34 Sink Batches + Input Plugins
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # /// script | |
| # requires-python = ">=3.12" | |
| # dependencies = [ | |
| # "httpx-ws==0.8.0", | |
| # "httpx==0.28.1", | |
| # "httpcore==1.0.9", | |
| # "polars>=1.34.0b4", | |
| # ] | |
| # [tool.uv] | |
| # prerelease = "allow" | |
| # /// | |
| import json | |
| from typing import Any, Iterator | |
| from urllib.parse import urlencode | |
| import polars as pl | |
| from httpx_ws import connect_ws | |
| from polars.io.plugins import register_io_source | |
# Jetstream websocket endpoint (its /subscribe path) that scan_jetstream_plugin connects to.
# Alternative instance, kept for convenience:
# BSKY_JETSTREAM = "wss://jetstream1.us-west.bsky.network/subscribe"
BSKY_JETSTREAM: str = "wss://jetstream2.us-east.bsky.network/subscribe"
# Example event used only to derive SCHEMA via _parse_schema.
# The concrete values are placeholders: only their Python types matter, and the
# key order determines the column / struct-field order of the resulting schema.
# Lists are typed from their first element, so every list needs one sample item.
# Keys that real events carry but that are not listed here are not described by
# the schema.
SCHEMA_BLUEPRINT: dict[str, Any] = {
    "did": "did:plc:xyz",
    "time_us": 123,
    "kind": "commit",
    "commit": {
        "rev": "abc",
        "operation": "create",
        "collection": "app.bsky.feed.post",
        "rkey": "idk",
        "record": {
            "$type": "app.bsky.feed.post",
            "createdAt": "2025-01-01T00:00:00.000Z",
            "embed": {
                "$type": "app.bsky.embed.record",
                "record": {
                    "cid": "abcxyz",
                    "uri": "at://did:plc:abc/app.bsky.feed.post/xyz",
                },
            },
            "langs": ["en"],
            "text": "Well, well, well.",
        },
        "cid": "abcxyz",
    },
}
def _parse_obj(obj: Any) -> pl.DataType:
    """Infer a Polars data type from an example Python value.

    Only the Python type of ``obj`` (and, for containers, of its contents) is
    used; the concrete value is ignored:

    * ``dict``  -> ``pl.Struct`` with one field per key, in key order
    * ``list``  -> ``pl.List`` of the type inferred from the first element
    * ``str``   -> ``pl.String``
    * ``bool``  -> ``pl.Boolean``
    * ``int``   -> ``pl.Int64``
    * ``float`` -> ``pl.Float64``

    Raises:
        ValueError: if a list is empty, because its element type cannot be inferred.
        NotImplementedError: for any other Python type.
    """
    if isinstance(obj, dict):
        return pl.Struct({key: _parse_obj(value) for key, value in obj.items()})
    if isinstance(obj, list):
        if not obj:
            raise ValueError("cannot infer the element type of an empty list")
        return pl.List(_parse_obj(obj[0]))
    if isinstance(obj, str):
        return pl.String()
    # bool is a subclass of int, so it must be tested first; otherwise
    # True/False would be typed as Int64.
    if isinstance(obj, bool):
        return pl.Boolean()
    if isinstance(obj, int):
        return pl.Int64()
    if isinstance(obj, float):
        return pl.Float64()
    raise NotImplementedError(
        f"cannot infer a Polars dtype for a value of type {type(obj).__name__!r}: {obj!r}"
    )
def _parse_schema(obj: dict[str, Any]) -> pl.Schema:
    """Build a top-level Polars schema from an example record.

    Each top-level key becomes a column. Its dtype is inferred by _parse_obj
    from the corresponding example value, and column order follows key order.
    """
    columns = {}
    for name, sample in obj.items():
        columns[name] = _parse_obj(sample)
    return pl.Schema(columns)
# Polars schema of the stream, inferred once at import time from SCHEMA_BLUEPRINT.
SCHEMA: pl.Schema = _parse_schema(SCHEMA_BLUEPRINT)
def _jetstream_url(endpoint: str, max_message_size_bytes: int) -> str:
    """Build the Jetstream subscribe URL for ``app.bsky.feed.post`` events."""
    query: list[tuple[str, str]] = [
        ("wantedCollections", "app.bsky.feed.post"),
        # Optional query parameters, currently disabled:
        # ("wantedDids", ""),
        # ("cursor", str(cursor_timestamp)),
        # ("compress", "true"),
        ("maxMessageSizeBytes", str(max_message_size_bytes)),
    ]
    return f"{endpoint}?{urlencode(query)}"


def _receive_batch(ws: Any, size: int) -> pl.DataFrame:
    """Read ``size`` text messages from ``ws`` and decode each one as JSON.

    Builds a DataFrame from the decoded messages using the module-level SCHEMA.
    """
    records = [json.loads(ws.receive_text()) for _ in range(size)]
    return pl.DataFrame(records, schema=SCHEMA)


def scan_jetstream_plugin(
    *,
    endpoint: str = BSKY_JETSTREAM,
    max_message_size_bytes: int = 2**16,
) -> pl.LazyFrame:
    """Return a LazyFrame that streams ``app.bsky.feed.post`` events from Jetstream.

    The frame is backed by a Polars IO source (``register_io_source``) with
    schema ``SCHEMA``. Scanning opens a websocket to ``endpoint`` and keeps
    reading until the requested number of rows has been produced. Without a
    row limit (e.g. ``.head(n)``), it reads indefinitely.

    Args:
        endpoint: Jetstream ``/subscribe`` websocket URL (without query string).
        max_message_size_bytes: value sent as the ``maxMessageSizeBytes`` query
            parameter.
    """
    url = _jetstream_url(endpoint, max_message_size_bytes)

    def source_generator(
        with_columns: list[str] | None,
        predicate: pl.Expr | None,
        n_rows: int | None,
        batch_size: int | None,
    ) -> Iterator[pl.DataFrame]:
        """Yield DataFrames read from the websocket. Registered as the IO source.

        Polars supplies the projection (``with_columns``), the filter
        (``predicate``), an optional row limit (``n_rows``) and a batch-size
        hint (``batch_size``).
        """
        # A missing or non-positive hint would read zero messages per batch
        # and never make progress.
        if batch_size is None or batch_size < 1:
            batch_size = 10
        if n_rows is not None:
            if n_rows <= 0:
                return  # Nothing requested: don't open a connection at all.
            batch_size = min(batch_size, n_rows)
        with connect_ws(url) as ws:
            while n_rows is None or n_rows > 0:
                df = _receive_batch(ws, batch_size)
                # Filter on the full frame so the predicate may reference
                # columns that are not part of the projection.
                if predicate is not None:
                    df = df.filter(predicate)
                # Count rows before projecting: a projection with no columns
                # may report a height of 0, and the row budget would never
                # shrink.
                if n_rows is not None:
                    n_rows -= df.height
                    batch_size = min(batch_size, n_rows)
                if with_columns is not None:
                    df = df.select(with_columns)
                yield df

    return register_io_source(io_source=source_generator, schema=SCHEMA)
def sink_batch(df: pl.DataFrame) -> None:
    """Write the textual rendering of one streamed batch to standard output."""
    rendered = str(df)
    print(rendered)
def main():
    """Take the first 200 rows of the Jetstream scan and hand them to sink_batch in chunks of 100."""
    source = scan_jetstream_plugin()
    first_rows = source.head(200)
    # lazy=False executes the sink immediately rather than returning a lazy query.
    first_rows.sink_batches(sink_batch, chunk_size=100, lazy=False)
# Run the demo only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment