Last active
December 4, 2025 19:29
-
-
Save deanm0000/71d2158ee490493688b5782a4f9d3ff4 to your computer and use it in GitHub Desktop.
Advent of Code 2025, Day 2, using Polars
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import polars as pl | |
| from polars import col as c | |
| from polars import lit as l | |
# Raw puzzle text: comma-separated "begin-end" ID ranges.
# NOTE: despite the name, this is the real puzzle input, not the example.
# Surrounding whitespace is stripped so a trailing newline at end of file
# cannot end up glued to the last range's end value, where it could make the
# UInt64 cast in the parsing step fail.
with open("./inputs/day2.txt", encoding="utf-8") as f:
    example_input = f.read().strip()
# One row per candidate ID (column "num"): every "begin-end" range is parsed
# and expanded into the inclusive sequence of integers it covers. Kept lazy
# so part1/part2 can build their queries on top of it.
df = (
    pl.select(
        range=pl.lit(example_input)
        .str.split(",")
        .explode()
        # Tolerate whitespace/newlines around each range (e.g. wrapped input).
        .str.strip_chars()
    )
    # A trailing comma would leave an empty item, which is not a range.
    .filter(c.range != "")
    .select(c.range.str.split("-").list.to_struct(fields=["begin", "end"]))
    .unnest("range")
    .with_columns(pl.all().cast(pl.UInt64))
    # int_ranges excludes its end bound, hence +1 to keep each range inclusive.
    .select(num=pl.int_ranges(c.begin, c.end + 1).explode())
    .lazy()
)
# Shared expressions, evaluated against any frame that has a "num" column.
# Decimal text of the ID, output column "digitstr".
DIGITSTR = pl.col("num").cast(pl.String).alias("digitstr")
# Number of characters in that decimal text, output column "n_digits".
DIGITS = DIGITSTR.str.len_chars().alias("n_digits")
def part1(return_sum=True):
    """Find IDs whose decimal form is some digit sequence written exactly twice.

    An ID qualifies when it has an even number of digits and its first half
    equals its second half (e.g. 55, 6464, 123123).

    Args:
        return_sum: When true (the default), return the sum of the matching
            IDs. When false, return the matching IDs themselves so that part2
            can reuse them.

    Returns:
        The sum as a scalar, or the matching IDs as a Series named "num".
    """
    # Exact integer half-length; avoids a float round-trip and a narrow cast.
    half = DIGITS // 2
    matches = (
        df.filter(DIGITS % 2 == 0)
        .with_columns(
            lhs=DIGITSTR.str.slice(0, half),
            rhs=DIGITSTR.str.slice(half),
        )
        .filter(c.lhs == c.rhs)
    )
    if return_sum:
        return matches.select(c.num.sum()).collect().item()
    return matches.select(c.num).collect()["num"]
def part2():
    """Sum every ID whose decimal form is a digit sequence repeated 2+ times.

    Part 1's matches (a sequence repeated exactly twice) are reused rather than
    recomputed. The remaining IDs are handled one digit-length at a time. For
    length ``n`` and every repeat count ``k >= 2`` that divides ``n``, the
    digits are viewed as ``k`` chunks of ``n // k`` digits. An ID is invalid
    when every chunk equals the first one.

    Returns:
        The sum of all invalid IDs.
    """
    # don't redo part1
    known_invalid = part1(return_sum=False)
    # A one-row list Series, so is_in / set_union treat it as a single set.
    invalids = known_invalid.implode()
    # Fixed-width arrays need the digit count upfront, so split the IDs into one
    # frame per digit length. Grouping an eagerly collected frame guarantees
    # each sub-frame holds exactly one length. Streamed batches give no such
    # guarantee: one batch may contain several groups.
    by_length = df.with_columns(DIGITS).collect().group_by("n_digits")
    for (digits,), subdf in by_length:
        subdf = (
            subdf.drop("n_digits")
            .filter(c.num.is_in(invalids).not_())
            .with_columns(
                digits=DIGITSTR.str.split("")
                .list.eval(pl.element().cast(pl.UInt8))
                .list.to_array(digits)
            )
        )
        # pl_gcd yields the divisors 1..digits // 2 of the digit count.
        for divisor in pl_gcd(digits):
            if divisor > 1:
                # `divisor` chunks of `digits // divisor` digits each, always
                # reshaped from the flat digit array.
                n_chunks = divisor
                chunks = c.digits.reshape((-1, divisor, digits // divisor))
            else:
                # Divisor 1 stands in for `digits` one-digit chunks (all digits
                # equal), a case pl_gcd never yields directly.
                n_chunks = digits
                chunks = c.digits
            repeated = subdf.with_columns(chunks=chunks).filter(
                # using n_unique is very slow; the width is known, so compare
                # every chunk to the first one with all_horizontal instead
                pl.all_horizontal(
                    c.chunks.arr.first() == c.chunks.arr.get(i)
                    for i in range(1, n_chunks)
                )
            )
            invalids = invalids.list.set_union(repeated["num"].implode())
            subdf = subdf.filter(c.num.is_in(invalids).not_())
    return invalids.explode().sum()
def pl_gcd(number: int):
    """Return the divisors of ``number`` from 1 up to ``number // 2``, in polars.

    Despite the name, this does not compute a greatest common divisor. It
    lists, in ascending order, every ``d`` in ``1..number // 2`` with
    ``number % d == 0``, as a polars Series named "gcd". ``number`` itself is
    never included, and the result is empty when ``number < 2``.
    """
    return (
        pl.select(l(number).alias("num"))
        .with_columns(gcd=pl.int_ranges(1, number // 2 + 1))
        .explode("gcd")
        # Exact integer divisibility. A float-ratio tolerance check would accept
        # near-misses for large inputs (e.g. 2_000_000 for 4_000_001).
        .filter(c.num % c.gcd == 0)
        .get_column("gcd")
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment