Last active
January 18, 2026 19:04
-
-
Save LeeeeT/4d95151fdd6afbda016b44a44a69115a to your computer and use it in GitHub Desktop.
Token-passing Optimal Reduction
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools
from collections import Counter
from dataclasses import dataclass
from typing import NewType
| type Name = int | |
| name = itertools.count() | |
| PosNam = NewType("PosNam", Name) | |
| NegNam = NewType("NegNam", Name) | |
| type Label = int | |
| label = itertools.count() | |
| type Pos = Var | Nil | Lam | Sup | Usp | Cob | Wait | Call | |
| type Neg = Sub | Era | App | Dup | Udp | Amb | Hold | Decide | Eval | |
| @dataclass(frozen=True) | |
| class Var: | |
| nam: PosNam | |
| @dataclass(frozen=True) | |
| class Sub: | |
| nam: NegNam | |
| @dataclass(frozen=True) | |
| class Nil: | |
| pass | |
| @dataclass(frozen=True) | |
| class Era: | |
| pass | |
| @dataclass(frozen=True) | |
| class Lam: | |
| bnd: NegNam | |
| bod: PosNam | |
| @dataclass(frozen=True) | |
| class App: | |
| arg: PosNam | |
| ret: NegNam | |
| @dataclass(frozen=True) | |
| class Dup: | |
| dpl: Label | |
| dp0: NegNam | |
| dp1: NegNam | |
| @dataclass(frozen=True) | |
| class Sup: | |
| spl: Label | |
| sp0: PosNam | |
| sp1: PosNam | |
| @dataclass(frozen=True) | |
| class Udp: | |
| ud0: NegNam | |
| ud1: NegNam | |
| @dataclass(frozen=True) | |
| class Usp: | |
| us0: PosNam | |
| us1: PosNam | |
| @dataclass(frozen=True) | |
| class Amb: | |
| am0: NegNam | |
| am1: NegNam | |
| @dataclass(frozen=True) | |
| class Cob: | |
| cb0: PosNam | |
| cb1: PosNam | |
| @dataclass(frozen=True) | |
| class Wait: | |
| wt0: PosNam | |
| wt1: NegNam | |
| @dataclass(frozen=True) | |
| class Call: | |
| pass | |
| @dataclass(frozen=True) | |
| class Hold: | |
| hd0: NegNam | |
| hd1: PosNam | |
| @dataclass(frozen=True) | |
| class Decide: | |
| dc0: NegNam | |
| dc1: PosNam | |
| @dataclass(frozen=True) | |
| class Eval: | |
| eva: NegNam | |
| type Rdx = tuple[NegNam, PosNam] | |
| @dataclass(frozen=True) | |
| class Net: | |
| book: set[Rdx] | |
| vars: dict[PosNam, Pos] | |
| subs: dict[NegNam, Neg] | |
| def empty_net() -> Net: | |
| return Net(set(), {}, {}) | |
def _alloc_pos(net: Net, node: Pos) -> PosNam:
    """Store ``node`` in ``net.vars`` under a fresh name and return that name."""
    key = PosNam(next(name))
    net.vars[key] = node
    return key


def _alloc_neg(net: Net, node: Neg) -> NegNam:
    """Store ``node`` in ``net.subs`` under a fresh name and return that name."""
    key = NegNam(next(name))
    net.subs[key] = node
    return key


def var(net: Net, nam: PosNam) -> PosNam:
    """Allocate a ``Var`` wire end referring to ``nam``."""
    return _alloc_pos(net, Var(nam))


def sub(net: Net, nam: NegNam) -> NegNam:
    """Allocate a ``Sub`` wire end referring to ``nam``."""
    return _alloc_neg(net, Sub(nam))


def nil(net: Net) -> PosNam:
    """Allocate a ``Nil`` node."""
    return _alloc_pos(net, Nil())


def era(net: Net) -> NegNam:
    """Allocate an ``Era`` node."""
    return _alloc_neg(net, Era())


def lam(net: Net, bnd: NegNam, bod: PosNam) -> PosNam:
    """Allocate a ``Lam`` with binder ``bnd`` and body ``bod``."""
    return _alloc_pos(net, Lam(bnd, bod))


def app(net: Net, arg: PosNam, ret: NegNam) -> NegNam:
    """Allocate an ``App`` with argument ``arg`` and result port ``ret``."""
    return _alloc_neg(net, App(arg, ret))


def dup(net: Net, dpl: Label, dp0: NegNam, dp1: NegNam) -> NegNam:
    """Allocate a ``Dup`` labelled ``dpl`` with outputs ``dp0``/``dp1``."""
    return _alloc_neg(net, Dup(dpl, dp0, dp1))


def sup(net: Net, spl: Label, sp0: PosNam, sp1: PosNam) -> PosNam:
    """Allocate a ``Sup`` labelled ``spl`` over ``sp0``/``sp1``."""
    return _alloc_pos(net, Sup(spl, sp0, sp1))


def udp(net: Net, ud0: NegNam, ud1: NegNam) -> NegNam:
    """Allocate an unlabelled duplicator ``Udp``."""
    return _alloc_neg(net, Udp(ud0, ud1))


def usp(net: Net, us0: PosNam, us1: PosNam) -> PosNam:
    """Allocate an unlabelled superposition ``Usp``."""
    return _alloc_pos(net, Usp(us0, us1))


def amb(net: Net, am0: NegNam, am1: NegNam) -> NegNam:
    """Allocate an ``Amb`` node."""
    return _alloc_neg(net, Amb(am0, am1))


def cob(net: Net, cb0: PosNam, cb1: PosNam) -> PosNam:
    """Allocate a ``Cob`` node."""
    return _alloc_pos(net, Cob(cb0, cb1))


def wait(net: Net, wt0: PosNam, wt1: NegNam) -> PosNam:
    """Allocate a ``Wait`` node."""
    return _alloc_pos(net, Wait(wt0, wt1))


def call(net: Net) -> PosNam:
    """Allocate a ``Call`` node."""
    return _alloc_pos(net, Call())


def hold(net: Net, hd0: NegNam, hd1: PosNam) -> NegNam:
    """Allocate a ``Hold`` node."""
    return _alloc_neg(net, Hold(hd0, hd1))


def decide(net: Net, dc0: NegNam, dc1: PosNam) -> NegNam:
    """Allocate a ``Decide`` node."""
    return _alloc_neg(net, Decide(dc0, dc1))


# NOTE(review): this module-level name shadows the builtin ``eval``.
def eval(net: Net, eva: NegNam) -> NegNam:
    """Allocate an ``Eval`` node whose result flows to ``eva``."""
    return _alloc_neg(net, Eval(eva))


def wire(net: Net) -> tuple[PosNam, NegNam]:
    """Create a wire: a ``Var`` and a ``Sub`` that share one new name.

    The shared name is drawn first; each end then gets its own node name.
    Returns ``(positive_end, negative_end)``.
    """
    shared = next(name)
    pos_end = var(net, PosNam(shared))
    neg_end = sub(net, NegNam(shared))
    return pos_end, neg_end
| def show_pos(net: Net, pos: PosNam, visited: set[Name] | None = None) -> str: | |
| if visited is None: | |
| visited = set() | |
| if pos in visited: | |
| return "+....." | |
| visited.add(pos) | |
| match net.vars[pos]: | |
| case Var(nam) if nam in net.vars: | |
| return show_pos(net, nam, visited) | |
| case Var(nam): | |
| return f"+{nam}" | |
| case Nil(): | |
| return "+_" | |
| case Lam(bnd, bod): | |
| return f"+({show_neg(net, bnd, visited)} {show_pos(net, bod, visited)})" | |
| case Sup(spl, sp0, sp1): | |
| return f"+&{spl}{{{show_pos(net, sp0, visited)} {show_pos(net, sp1, visited)}}}" | |
| case Usp(us0, us1): | |
| return f"+{{{show_pos(net, us0, visited)} {show_pos(net, us1, visited)}}}" | |
| case Cob(cb0, cb1): | |
| return f"-cob({show_pos(net, cb0, visited) if cb0 in net.vars else "stolen"} {show_pos(net, cb1, visited) if cb1 in net.vars else "stolen"})" | |
| case Wait(wt0, wt1): | |
| return f"+wait({show_pos(net, wt0, visited)} {show_neg(net, wt1, visited)})" | |
| case Call(): | |
| return "+call" | |
| def show_neg(net: Net, neg: NegNam, visited: set[Name] | None = None) -> str: | |
| if visited is None: | |
| visited = set() | |
| if neg in visited: | |
| return "+....." | |
| visited.add(neg) | |
| match net.subs[neg]: | |
| case Sub(nam) if nam in net.subs: | |
| return show_neg(net, nam, visited) | |
| case Sub(nam): | |
| return f"-{nam}" | |
| case Era(): | |
| return "-_" | |
| case App(arg, ret): | |
| return f"-({show_pos(net, arg, visited)} {show_neg(net, ret, visited)})" | |
| case Dup(dpl, dp0, dp1): | |
| return f"-&{dpl}{{{show_neg(net, dp0, visited)} {show_neg(net, dp1, visited)}}}" | |
| case Udp(ud0, ud1): | |
| return f"-{{{show_neg(net, ud0, visited)} {show_neg(net, ud1, visited)}}}" | |
| case Amb(am0, am1): | |
| return f"-amb({show_neg(net, am0, visited) if am0 in net.subs else "stolen"} {show_neg(net, am1, visited) if am1 in net.subs else "stolen"})" | |
| case Hold(hd0, hd1): | |
| return f"-hold({show_neg(net, hd0, visited)} {show_pos(net, hd1, visited)})" | |
| case Decide(dc0, dc1): | |
| return f"-decide({show_neg(net, dc0, visited)} {show_pos(net, dc1, visited)})" | |
| case Eval(eva): | |
| # return f"-eval({show_neg(net, eva)})" | |
| return show_neg(net, eva, visited) | |
| @dataclass(kw_only=True) | |
| class Stats: | |
| itrs: dict[str, int] | |
| max_size: int | |
| parallelism: int | |
def bump(itrs: dict[str, int], key: str) -> None:
    """Increment the counter for ``key`` in ``itrs``, creating it at 1 if new."""
    if key in itrs:
        itrs[key] += 1
    else:
        itrs[key] = 1
def reduce(net: Net, itrs: dict[str, int]) -> None:
    """Perform one interaction step on ``net``.

    Pops one redex ``(lhs, rhs)`` from ``net.book``.  Because the book is a
    ``set``, which redex is popped is unspecified.  Both nodes are removed
    from their heaps and the rule for that pair of node types is applied.  A
    rule may restore nodes, allocate new ones and add new redexes to the
    book.  On success the interaction is counted in ``itrs`` under
    ``"<NegClass>-<PosClass>"``.

    Fresh names and labels come from the global ``name`` / ``label`` counters
    in the exact order written below.

    Raises:
        KeyError: if the book is empty (``set.pop``), or if a redex refers to
            a name missing from the heaps.
        RuntimeError: if no rule is defined for the pair of node types.
    """
    lhs, rhs = net.book.pop()
    lhsc = net.subs.pop(lhs)
    rhsc = net.vars.pop(rhs)
    match lhsc, rhsc:
        # --- Wire ends -------------------------------------------------
        # Var: if the wire's positive side is already filled, forward the
        # redex to it.  Otherwise park the negative node under the wire's
        # name so the matching Sub can pick it up.
        case lhsc, Var(nam) if nam in net.vars:
            net.subs[lhs] = lhsc
            net.book.add((lhs, nam))
        case lhsc, Var(nam):
            net.subs[NegNam(nam)] = lhsc
        # Sub: symmetric to Var, for the positive node.
        case Sub(nam), rhsc if nam in net.subs:
            net.vars[rhs] = rhsc
            net.book.add((nam, rhs))
        case Sub(nam), rhsc:
            net.vars[PosNam(nam)] = rhsc
        # --- Steal-or-fallback nodes ------------------------------------
        # Cob: if ``m`` is still in the heap, move it to a fresh name
        # ("steal" it) and connect it.  Otherwise connect the fallback ``a``.
        # The Cob node itself is consumed.
        case lhsc, Cob(m, a) if m in net.vars:
            net.subs[lhs] = lhsc
            x = next(name)
            net.vars[PosNam(x)] = net.vars.pop(m)
            net.book.add((lhs, PosNam(x)))
        case lhsc, Cob(m, a):
            net.subs[lhs] = lhsc
            net.book.add((lhs, a))
        # Amb: negative counterpart of Cob.
        case Amb(m, a), rhsc if m in net.subs:
            net.vars[rhs] = rhsc
            x = next(name)
            net.subs[NegNam(x)] = net.subs.pop(m)
            net.book.add((NegNam(x), rhs))
        case Amb(m, a), rhsc:
            net.vars[rhs] = rhsc
            net.book.add((a, rhs))
        # --- Eraser ----------------------------------------------------
        # Erase each positive sub-port; negative sub-ports receive Nil.
        case Era(), Nil():
            pass
        case Era(), Lam(bnd, bod):
            net.book.add((bnd, nil(net)))
            net.book.add((era(net), bod))
        case Era(), Sup(spl, sp0, sp1):
            net.book.add((era(net), sp0))
            net.book.add((era(net), sp1))
        case Era(), Usp(us0, us1):
            net.book.add((era(net), us0))
            net.book.add((era(net), us1))
        case Era(), Wait(wt0, wt1):
            net.book.add((era(net), wt0))
            net.book.add((wt1, nil(net)))
        case Era(), Call():
            pass
        # --- Application -----------------------------------------------
        case App(arg, ret), Nil():
            net.book.add((era(net), arg))
            net.book.add((ret, nil(net)))
        # The binder does not receive ``arg`` directly.  It receives
        # Wait(ap, Hold(an, arg)), where ap/an are the two ends of a new wire.
        case App(arg, ret), Lam(bnd, bod):
            ap, an = wire(net)
            net.book.add((bnd, wait(net, ap, hold(net, an, arg))))
            net.book.add((ret, bod))
        # Application of a superposition: duplicate the argument with the
        # same label and superpose the two results.
        case App(arg, ret), Sup(spl, sp0, sp1):
            ap, an = wire(net)
            bp, bn = wire(net)
            cp, cn = wire(net)
            dp, dn = wire(net)
            net.book.add((dup(net, spl, an, bn), arg))
            net.book.add((ret, sup(net, spl, cp, dp)))
            net.book.add((app(net, ap, cn), sp0))
            net.book.add((app(net, bp, dn), sp1))
        # Same as App-Sup, but with a freshly drawn label.
        case App(arg, ret), Usp(us0, us1):
            ap, an = wire(net)
            bp, bn = wire(net)
            cp, cn = wire(net)
            dp, dn = wire(net)
            lab = next(label)
            net.book.add((dup(net, lab, an, bn), arg))
            net.book.add((ret, sup(net, lab, cp, dp)))
            net.book.add((app(net, ap, cn), us0))
            net.book.add((app(net, bp, dn), us1))
        # The Wait is rebuilt and stored inside a Hold together with a new
        # App of ``arg``; ``ret`` receives a new Wait.
        case App(arg, ret), Wait(wt0, wt1):
            ap, an = wire(net)
            net.book.add((ret, wait(net, ap, hold(net, app(net, arg, an), wait(net, wt0, wt1)))))
        # --- Labelled duplicator ---------------------------------------
        case Dup(dpl, dp0, dp1), Nil():
            net.book.add((dp0, nil(net)))
            net.book.add((dp1, nil(net)))
        # Copy the lambda: two new lambdas; the binder gets a Sup and the
        # body is duplicated with the same label.
        case Dup(dpl, dp0, dp1), Lam(bnd, bod):
            ap, an = wire(net)
            bp, bn = wire(net)
            cp, cn = wire(net)
            dp, dn = wire(net)
            net.book.add((dp0, lam(net, an, bp)))
            net.book.add((dp1, lam(net, cn, dp)))
            net.book.add((bnd, sup(net, dpl, ap, cp)))
            net.book.add((dup(net, dpl, bn, dn), bod))
        # Same label: connect the ports pairwise.
        case Dup(dpl, dp0, dp1), Sup(spl, sp0, sp1) if dpl == spl:
            net.book.add((dp0, sp0))
            net.book.add((dp1, sp1))
        # Different labels: each node is copied through the other.
        case Dup(dpl, dp0, dp1), Sup(spl, sp0, sp1):
            ap, an = wire(net)
            bp, bn = wire(net)
            cp, cn = wire(net)
            dp, dn = wire(net)
            net.book.add((dp0, sup(net, spl, ap, bp)))
            net.book.add((dp1, sup(net, spl, cp, dp)))
            net.book.add((dup(net, dpl, an, cn), sp0))
            net.book.add((dup(net, dpl, bn, dn), sp1))
        case Dup(dpl, dp0, dp1), Usp(us0, us1):
            ap, an = wire(net)
            bp, bn = wire(net)
            cp, cn = wire(net)
            dp, dn = wire(net)
            net.book.add((dp0, usp(net, ap, bp)))
            net.book.add((dp1, usp(net, cp, dp)))
            net.book.add((dup(net, dpl, an, cn), us0))
            net.book.add((dup(net, dpl, bn, dn), us1))
        # Both copies of the Wait get an Amb sharing one Decide (``dc``) and
        # one fallback (``bn``).  Whichever Amb fires first steals ``dc``.
        case Dup(dpl, dp0, dp1), Wait(wt0, wt1):
            ap, an = wire(net)
            bp, bn = wire(net)
            cp, cn = wire(net)
            dc = decide(net, wt1, bp)
            net.book.add((dp0, wait(net, ap, amb(net, dc, bn))))
            net.book.add((dp1, wait(net, cp, amb(net, dc, bn))))
            net.book.add((dup(net, dpl, an, cn), wt0))
        # --- Unlabelled duplicator -------------------------------------
        case Udp(ud0, ud1), Nil():
            net.book.add((ud0, nil(net)))
            net.book.add((ud1, nil(net)))
        # Like Dup-Lam, but with a freshly drawn label.
        case Udp(ud0, ud1), Lam(bnd, bod):
            ap, an = wire(net)
            bp, bn = wire(net)
            cp, cn = wire(net)
            dp, dn = wire(net)
            lab = next(label)
            net.book.add((ud0, lam(net, an, bp)))
            net.book.add((ud1, lam(net, cn, dp)))
            net.book.add((bnd, sup(net, lab, ap, cp)))
            net.book.add((dup(net, lab, bn, dn), bod))
        case Udp(ud0, ud1), Sup(spl, sp0, sp1):
            ap, an = wire(net)
            bp, bn = wire(net)
            cp, cn = wire(net)
            dp, dn = wire(net)
            net.book.add((ud0, sup(net, spl, ap, bp)))
            net.book.add((ud1, sup(net, spl, cp, dp)))
            net.book.add((udp(net, an, cn), sp0))
            net.book.add((udp(net, bn, dn), sp1))
        case Udp(ud0, ud1), Usp(us0, us1):
            ap, an = wire(net)
            bp, bn = wire(net)
            cp, cn = wire(net)
            dp, dn = wire(net)
            net.book.add((ud0, usp(net, ap, bp)))
            net.book.add((ud1, usp(net, cp, dp)))
            net.book.add((udp(net, an, cn), us0))
            net.book.add((udp(net, bn, dn), us1))
        # Same shape as Dup-Wait, using Udp for the value copy.
        case Udp(ud0, ud1), Wait(wt0, wt1):
            ap, an = wire(net)
            bp, bn = wire(net)
            cp, cn = wire(net)
            dc = decide(net, wt1, bp)
            net.book.add((ud0, wait(net, ap, amb(net, dc, bn))))
            net.book.add((ud1, wait(net, cp, amb(net, dc, bn))))
            net.book.add((udp(net, an, cn), wt0))
        # --- Hold / Decide (react to Nil or Call) ----------------------
        case Hold(hd0, hd1), Nil():
            net.book.add((hd0, nil(net)))
            net.book.add((era(net), hd1))
        # ``eval`` here is this module's node constructor, not the builtin.
        case Hold(hd0, hd1), Call():
            net.book.add((eval(net, hd0), hd1))
        case Decide(dc0, dc1), Nil():
            net.book.add((dc0, dc1))
        case Decide(dc0, dc1), Call():
            net.book.add((dc0, call(net)))
            net.book.add((era(net), dc1))
        # --- Eval ------------------------------------------------------
        case Eval(eva), Nil():
            net.book.add((eva, nil(net)))
        # Rebuild the lambda around a new wire and keep evaluating its body.
        case Eval(eva), Lam(bnd, bod):
            ap, an = wire(net)
            net.book.add((eva, lam(net, bnd, ap)))
            net.book.add((eval(net, an), bod))
        # Superpositions are passed through rebuilt; their children are
        # not evaluated here.
        case Eval(eva), Sup(spl, sp0, sp1):
            net.book.add((eva, sup(net, spl, sp0, sp1)))
        case Eval(eva), Usp(us0, us1):
            net.book.add((eva, usp(net, us0, us1)))
        # Evaluate the waited value and send Call to the Wait's port.
        case Eval(eva), Wait(wt0, wt1):
            net.book.add((eval(net, eva), wt0))
            net.book.add((wt1, call(net)))
        case lhsc, rhsc:
            raise RuntimeError(f"{lhsc.__class__.__name__.upper()}-{rhsc.__class__.__name__.upper()} undefined")
    bump(itrs, f"{lhsc.__class__.__name__}-{rhsc.__class__.__name__}")
def normalize(net: Net) -> Stats:
    """Reduce ``net`` in place until its book is empty; return the statistics.

    ``max_size`` is the peak of ``len(vars) + len(subs)`` and ``parallelism``
    the peak number of pending redexes.  Both are sampled before the first
    step and after every step.
    """
    def heap_size() -> int:
        return len(net.vars) + len(net.subs)

    stats = Stats(itrs={}, max_size=heap_size(), parallelism=len(net.book))
    while len(net.book) > 0:
        reduce(net, stats.itrs)
        if heap_size() > stats.max_size:
            stats.max_size = heap_size()
        if len(net.book) > stats.parallelism:
            stats.parallelism = len(net.book)
    return stats
def print_state(net: Net, root: PosNam, *, heap: bool = False) -> None:
    """Print the term at ``root`` followed by every pending redex.

    With ``heap=True`` also dump every entry of ``net.vars`` and
    ``net.subs``.  Each rendered item starts with a fresh visited set.
    """
    def dump(title: str, table: dict, render) -> None:
        print()
        print(title)
        for key in table:
            print(f" {key} = {render(net, key)}")

    print("ROOT:")
    print(f" {show_pos(net, root)}")
    print()
    print("BOOK:")
    for neg_end, pos_end in net.book:
        print(f" {show_neg(net, neg_end)} ⋈ {show_pos(net, pos_end)}")
    if not heap:
        return
    dump("VARS:", net.vars, show_pos)
    dump("SUBS:", net.subs, show_neg)
def print_itrs(itrs: dict[str, int]) -> None:
    """Print the total interaction count, then each non-zero entry.

    Entries are listed by descending count.  Ties are broken alphabetically
    by key (the same ordering ``print_node_counts`` uses), so the output does
    not depend on the order in which interactions were first recorded.
    """
    print(f" ITRS: {sum(itrs.values())}")
    for key, count in sorted(itrs.items(), key=lambda kv: (-kv[1], kv[0])):
        if count > 0:
            print(f" {key}: {count}")
def print_stats(stats: Stats) -> None:
    """Print a ``STATS:`` section: interactions, peak size and peak parallelism."""
    print("STATS:")
    print_itrs(stats.itrs)
    for caption, value in ((" MAX SIZE", stats.max_size), (" PARALLELISM", stats.parallelism)):
        print(f"{caption}: {value}")
def node_type_counts(net: Net) -> dict[str, int]:
    """Count the live nodes in both heaps of ``net``, keyed by class name.

    Keys appear in first-seen order: all ``vars`` nodes are scanned before
    any ``subs`` node.  Counter is used instead of a manual loop.  The old
    loop variable ``name`` also shadowed the module-level name counter.
    """
    all_nodes = itertools.chain(net.vars.values(), net.subs.values())
    return dict(Counter(node.__class__.__name__ for node in all_nodes))
def print_node_counts(net: Net) -> None:
    """Print live node counts by class: most frequent first, ties alphabetical."""
    ranked = sorted(
        node_type_counts(net).items(),
        key=lambda kv: (-kv[1], kv[0]),
    )
    print("NODE COUNTS:")
    for kind, count in ranked:
        print(f" {kind}: {count}")
def print_reduction(net: Net, root: PosNam, *, heap: bool = False) -> None:
    """Print ``net`` before and after normalization, then stats and node counts.

    ``net`` is normalized in place.
    """
    def banner(title: str) -> None:
        rule = "=" * 30
        print(rule)
        print("=", title.center(26), "=")
        print(rule)
        print()

    banner("INITIAL")
    print_state(net, root, heap=heap)
    print()
    banner("NORMALIZED")
    stats = normalize(net)
    print_state(net, root, heap=heap)
    print()
    print_stats(stats)
    print()
    print_node_counts(net)
def mk_app(net: Net, fun: PosNam, arg: PosNam) -> PosNam:
    """Build the application ``fun arg`` and return the wire carrying its result.

    Queues the redex ``App(arg, result) ⋈ fun`` in the book.
    """
    result_pos, result_neg = wire(net)
    caller = app(net, arg, result_neg)
    net.book.add((caller, fun))
    return result_pos
def mk_udp(net: Net, x: PosNam) -> tuple[PosNam, PosNam]:
    """Share ``x`` through an unlabelled duplicator and return both copies."""
    (copy0, end0), (copy1, end1) = [wire(net) for _ in range(2)]
    splitter = udp(net, end0, end1)
    net.book.add((splitter, x))
    return copy0, copy1
def mk_eval(net: Net, x: PosNam) -> PosNam:
    """Queue an evaluation of ``x`` and return the wire carrying the result."""
    out_pos, out_neg = wire(net)
    # ``eval`` is this module's Eval-node constructor, not the builtin.
    request = eval(net, out_neg)
    net.book.add((request, x))
    return out_pos
# Identity function: λx.x
def mk_id(net: Net) -> PosNam:
    """Build ``λx.x``: the lambda's binder and body are the two ends of one wire."""
    body, binder = wire(net)
    return lam(net, binder, body)
| def mk_nat(net: Net, n: int) -> PosNam: | |
| xp, xn = wire(net) | |
| res = xp | |
| apps: list[NegNam] = [] | |
| for _ in range(n): | |
| res_p, res_n = wire(net) | |
| apps.append(app(net, res, res_n)) | |
| res = res_p | |
| def build_tree(leaves: list[NegNam]) -> NegNam: | |
| if not leaves: | |
| return era(net) | |
| if len(leaves) == 1: | |
| return leaves[0] | |
| mid = len(leaves) // 2 | |
| return udp(net, build_tree(leaves[:mid]), build_tree(leaves[mid:])) | |
| return lam(net, build_tree(apps), lam(net, xn, res)) | |
def main() -> None:
    """Evaluate ``((2 2) 2) 3`` built from Church numerals and print the result.

    Prints the normalized root and book, the reduction statistics, and the
    counts of nodes left in the net.
    """
    net = empty_net()
    # Other experiments: apply the result to ``mk_id`` before evaluating, or
    # use ``print_reduction`` for a before/after dump.
    root = mk_nat(net, 2)
    for size in (2, 2, 3):
        root = mk_app(net, root, mk_nat(net, size))
    root = mk_eval(net, root)
    stats = normalize(net)
    print_state(net, root)
    print()
    print_stats(stats)
    print()
    print_node_counts(net)


# Only run the demo when executed as a script.  Importing the module
# previously triggered a full normalization and printed output.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment