from __future__ import annotations

import re
from typing import Iterable, Optional, Sequence

import polars as pl
from tqdm import tqdm

# ==== Constants ====
KEY_COLS: list[str] = ["ticker", "real_date"]
DEFAULT_COMPARE_COLS: list[str] = ["value", "grading", "eop_lag", "last_updated"]


# ==== Internal: pick latest/earliest per (ticker, real_date) ====
def _pick_update(
    lf: pl.LazyFrame, *, method: str, tie_break: bool = True
) -> pl.LazyFrame:
    """
    Return exactly one row per (ticker, real_date) by max/min(last_updated).
    If multiple rows share the chosen timestamp, break ties deterministically
    with a hash over non-key, non-timestamp columns (when enabled).
    """
    if method not in {"max", "min"}:
        raise ValueError("method must be 'max' or 'min'")
    schema_names = lf.collect_schema().names()
    value_cols = [c for c in schema_names if c not in KEY_COLS]
    x = lf
    if tie_break:
        tie_inputs = [c for c in value_cols if c != "last_updated"]
        x = x.with_columns(
            (pl.struct(tie_inputs).hash() if tie_inputs else pl.lit(0)).alias("_tie")
        )
    ts_pick = (
        pl.col("last_updated").max()
        if method == "max"
        else pl.col("last_updated").min()
    ).alias("_ts_pick")
    ts_per_key = x.group_by(KEY_COLS).agg(ts_pick)
    picked = (
        x.join(ts_per_key, on=KEY_COLS, how="inner")
        .filter(pl.col("last_updated") == pl.col("_ts_pick"))
        .drop("_ts_pick")
    )
    if not tie_break:
        return picked
    tmax = picked.group_by(KEY_COLS).agg(pl.col("_tie").max().alias("_tie_pick"))
    return (
        picked.join(tmax, on=KEY_COLS, how="inner")
        .filter(pl.col("_tie") == pl.col("_tie_pick"))
        .drop("_tie_pick", "_tie")
    )


def select_latest_updates(lf: pl.LazyFrame, *, tie_break: bool = True) -> pl.LazyFrame:
    return _pick_update(lf, method="max", tie_break=tie_break)


def select_earliest_updates(
    lf: pl.LazyFrame, *, tie_break: bool = True
) -> pl.LazyFrame:
    return _pick_update(lf, method="min", tie_break=tie_break)
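
# A minimal usage sketch (illustrative in-memory data, not part of the
# dataset): two rows share a (ticker, real_date) key, and the row with the
# later last_updated survives.
#
#   from datetime import date, datetime
#   demo = pl.LazyFrame(
#       {
#           "ticker": ["AUD_FXXR_NSA", "AUD_FXXR_NSA"],
#           "real_date": [date(2024, 1, 2), date(2024, 1, 2)],
#           "value": [1.0, 1.5],
#           "last_updated": [datetime(2024, 1, 3), datetime(2024, 1, 4)],
#       }
#   )
#   select_latest_updates(demo).collect()  # keeps the value == 1.5 row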

# ==== Filtering helpers ====
def _ticker_filter_expr(
    *,
    cids: Optional[Sequence[str]],
    xcats: Optional[Sequence[str]],
    and_semantics: bool,
) -> Optional[pl.Expr]:
    """
    Build a boolean expression for JPMAQS tickers of the form CID_..._XCAT.
    - If cids provided: anchor at start (`^CID|^CID2|...`)
    - If xcats provided: anchor at end (`XCAT$|XCAT2$|...`)
    - Combine with AND (default) or OR semantics.
    Note: the anchors match prefixes/suffixes, so cid "AUD" would also match
    a hypothetical "AUDX_..." ticker; include the underscore in the cid if
    exact matching matters.
    """
    exprs: list[pl.Expr] = []
    if cids:
        cid_regex = r"^(" + "|".join(re.escape(c) for c in cids) + r")"
        exprs.append(pl.col("ticker").str.contains(cid_regex))
    if xcats:
        xcat_regex = r"(" + "|".join(re.escape(x) for x in xcats) + r")$"
        exprs.append(pl.col("ticker").str.contains(xcat_regex))
    if not exprs:
        return None
    if len(exprs) == 1:
        return exprs[0]
    return pl.all_horizontal(exprs) if and_semantics else pl.any_horizontal(exprs)
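
# Semantics sketch (hypothetical inputs): with cids=["AUD"] and
# xcats=["FXXR_NSA"], AND semantics keep only tickers that start with "AUD"
# and end with "FXXR_NSA"; OR semantics keep either.
#
#   expr = _ticker_filter_expr(cids=["AUD"], xcats=["FXXR_NSA"], and_semantics=True)
#   pl.LazyFrame({"ticker": ["AUD_FXXR_NSA", "USD_FXXR_NSA"]}).filter(expr).collect()
#   # -> only "AUD_FXXR_NSA"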


def filter_jpmaqs(
    lf: pl.LazyFrame,
    *,
    cids: Optional[Sequence[str]] = None,
    xcats: Optional[Sequence[str]] = None,
    tickers: Optional[Sequence[str]] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    start_last_updated: Optional[str] = None,
    end_last_updated: Optional[str] = None,
    latest_release: Optional[bool] = None,
    first_release: Optional[bool] = None,
    and_cid_xcat: bool = True,
    only_columns: Optional[Sequence[str]] = None,
) -> pl.LazyFrame:
    """
    Build a fully-lazy query with predicate/projection pushdown for JPMAQS parquet data.
    """
    # Optional projection pruning up front
    if only_columns:
        need = list(dict.fromkeys([*KEY_COLS, *only_columns]))
        present = lf.collect_schema().names()
        keep = [c for c in need if c in present]
        lf = lf.select(keep)
    filters: list[pl.Expr] = []
    if tickers:
        filters.append(pl.col("ticker").is_in(list(tickers)))
    if start_date:
        filters.append(
            pl.col("real_date")
            >= pl.lit(start_date).str.strptime(pl.Date, strict=False)
        )
    if end_date:
        filters.append(
            pl.col("real_date") <= pl.lit(end_date).str.strptime(pl.Date, strict=False)
        )
    if start_last_updated:
        filters.append(
            pl.col("last_updated")
            >= pl.lit(start_last_updated).str.strptime(pl.Datetime, strict=False)
        )
    if end_last_updated:
        filters.append(
            pl.col("last_updated")
            <= pl.lit(end_last_updated).str.strptime(pl.Datetime, strict=False)
        )
    ticker_expr = _ticker_filter_expr(
        cids=cids, xcats=xcats, and_semantics=and_cid_xcat
    )
    if ticker_expr is not None:
        filters.append(ticker_expr)
    if filters:
        lf = lf.filter(pl.all_horizontal(filters))
    # Release selection: exactly one of latest_release / first_release must be True.
    if latest_release is None and first_release is None:
        latest_release = True  # default
    if (bool(latest_release) + bool(first_release)) != 1:
        raise ValueError("Exactly one of latest_release or first_release must be True.")
    return select_latest_updates(lf) if latest_release else select_earliest_updates(lf)
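
# Illustrative call (glob path and filter values assumed, mirroring the
# __main__ block below):
#
#   lf = pl.scan_parquet("./data/jpmaqs-data/PROD/*.parquet")
#   latest = filter_jpmaqs(
#       lf,
#       cids=["AUD"],
#       xcats=["FXXR_NSA"],
#       start_date="2024-01-01",
#       latest_release=True,
#   )
#   latest.collect()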


def load_lazy(
    path: str,
    *,
    cid: Optional[str] = None,
    columns: Optional[Sequence[str]] = None,
    **kwargs,
) -> pl.LazyFrame:
    """
    Lazily scan parquet dataset(s) and apply JPMAQS filters.
    Extra keyword arguments are forwarded to filter_jpmaqs.
    """
    lf = pl.scan_parquet(path)
    cids = [cid] if cid else None
    return filter_jpmaqs(lf, cids=cids, only_columns=columns, **kwargs)
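
# Sketch (same glob pattern as the __main__ block; cid and latest_release
# flow through to filter_jpmaqs):
#
#   prod = load_lazy(
#       "./data/jpmaqs-data/PROD/*.parquet",
#       cid="AUD",
#       columns=DEFAULT_COMPARE_COLS,
#       latest_release=True,
#   )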


# ==== Diffing ====
def eq_missing_all(pairs: Iterable[tuple[str, str]]) -> pl.Expr:
    """Null-safe equality across many column pairs."""
    return pl.all_horizontal([pl.col(a).eq_missing(pl.col(b)) for a, b in pairs])
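
# eq_missing treats null == null as True, unlike plain eq, so a pair of
# missing values does not show up as a difference. Toy columns for
# illustration:
#
#   pl.DataFrame({"a": [1, None], "b": [1, None]}).select(
#       eq_missing_all([("a", "b")])
#   )  # -> [true, true]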


def diff_lazy(
    a: pl.LazyFrame,
    b: pl.LazyFrame,
    *,
    a_name: str = "PROD",
    b_name: str = "UAT",
    compare_cols: Optional[Sequence[str]] = None,
) -> pl.LazyFrame:
    """
    Full-outer join on KEY_COLS; flag rows only in A/B or with updated values.
    Keeps side-by-side columns for inspection. Purely lazy.
    """
    compare_cols = list(compare_cols or DEFAULT_COMPARE_COLS)
    a_schema = a.collect_schema().names()
    b_schema = b.collect_schema().names()
    a_sel = a.select([c for c in [*KEY_COLS, *compare_cols] if c in a_schema])
    b_sel = b.select([c for c in [*KEY_COLS, *compare_cols] if c in b_schema])
    suffix = f"_{b_name.lower()}"
    b_sel = b_sel.rename(
        {
            c: (c if c in KEY_COLS else f"{c}{suffix}")
            for c in b_sel.collect_schema().names()
        }
    )
    joined = a_sel.join(b_sel, on=KEY_COLS, how="full", coalesce=True)
    in_a = pl.any_horizontal([pl.col(c).is_not_null() for c in compare_cols]).alias(
        "_in_a"
    )
    in_b = pl.any_horizontal(
        [pl.col(f"{c}{suffix}").is_not_null() for c in compare_cols]
    ).alias("_in_b")
    same = eq_missing_all([(c, f"{c}{suffix}") for c in compare_cols]).alias("_same")
    out = (
        joined.with_columns(in_a, in_b, same)
        .with_columns(
            pl.when(pl.col("_in_a") & ~pl.col("_in_b"))
            .then(pl.lit(f"only_in_{a_name.lower()}"))
            .when(~pl.col("_in_a") & pl.col("_in_b"))
            .then(pl.lit(f"only_in_{b_name.lower()}"))
            .when(pl.col("_in_a") & pl.col("_in_b") & ~pl.col("_same"))
            .then(pl.lit("updated"))
            .otherwise(pl.lit("same"))
            .alias("change_type")
        )
        .filter(pl.col("change_type") != "same")
    )
    # Keep keys, change_type, and available compare columns from both sides
    out_names = out.collect_schema().names()
    keep: list[str] = [*KEY_COLS, "change_type"]
    for c in compare_cols:
        if c in out_names:
            keep.append(c)
        rhs = f"{c}{suffix}"
        if rhs in out_names:
            keep.append(rhs)
    return out.select(keep)
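
# Toy end-to-end sketch (hypothetical one-row frames):
#
#   from datetime import date
#   a = pl.LazyFrame({"ticker": ["T1"], "real_date": [date(2024, 1, 1)], "value": [1.0]})
#   b = pl.LazyFrame({"ticker": ["T1"], "real_date": [date(2024, 1, 1)], "value": [2.0]})
#   diff_lazy(a, b, compare_cols=["value"]).collect()
#   # -> one "updated" row with value and value_uat side by side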


def list_tickers(path: str) -> list[str]:
    return (
        pl.scan_parquet(path)
        .select("ticker")
        .unique()
        .sort("ticker")
        .collect()["ticker"]
        .to_list()
    )
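
# e.g. list_tickers("./data/jpmaqs-data/PROD/*.parquet") returns the sorted,
# de-duplicated ticker universe that drives the per-CID loop below.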


# ==== Example main ====
if __name__ == "__main__":
    # Assumes all parquet files (including delta files) sit in the given
    # folders; they don't need to be read one by one, since pl.scan_parquet
    # expands the glob.
    from pathlib import Path

    prod_path = "./data/jpmaqs-data/PROD/*.parquet"
    uat_path = "./data/jpmaqs-data/UAT/*.parquet"
    out_path = "./data/jpmaqs-data/diff-output/"
    Path(out_path).mkdir(parents=True, exist_ok=True)

    # Minimal discovery
    tickers = list_tickers(prod_path)
    cids = sorted(set(t.split("_", 1)[0] for t in tickers))
    print(f"Found {len(cids)} CIDs, {len(tickers)} tickers in PROD.")

    needed = list(dict.fromkeys([*KEY_COLS, *DEFAULT_COMPARE_COLS]))
    for cid in tqdm(cids, desc="Processing"):
        prod_lf = load_lazy(prod_path, cid=cid, columns=needed, latest_release=True)
        uat_lf = load_lazy(uat_path, cid=cid, columns=needed, latest_release=True)
        diff_lf = diff_lazy(prod_lf, uat_lf, a_name="PROD", b_name="UAT")
        # Prefer the streaming engine when supported by your Polars version
        result = diff_lf.collect(engine="streaming")
        print(result)
        # Save the result as {cid}-diff.csv
        csv_path = out_path + f"{cid}-diff.csv"
        result.write_csv(csv_path)