Skip to content

Instantly share code, notes, and snippets.

@LeeeeT
Last active January 18, 2026 19:04
Show Gist options
  • Select an option

  • Save LeeeeT/4d95151fdd6afbda016b44a44a69115a to your computer and use it in GitHub Desktop.

Select an option

Save LeeeeT/4d95151fdd6afbda016b44a44a69115a to your computer and use it in GitHub Desktop.
Token-passing Optimal Reduction
import itertools
from dataclasses import dataclass
from typing import NewType
# Identifier of a node; the keys of Net.vars and Net.subs are Names.
type Name = int
# Global supply of fresh names, drawn from by the node constructors and reduce.
name = itertools.count()
# Distinct static types for names that key Net.vars (PosNam) and Net.subs (NegNam).
PosNam = NewType("PosNam", Name)
NegNam = NewType("NegNam", Name)
# Label carried by Dup and Sup nodes; reduce compares the two labels when a
# Dup meets a Sup.
type Label = int
# Global supply of fresh labels (reduce draws one in the App-Usp and Udp-Lam rules).
label = itertools.count()
# Union of the positive node classes, i.e. the values stored in Net.vars.
type Pos = Var | Nil | Lam | Sup | Usp | Cob | Wait | Call
# Union of the negative node classes, i.e. the values stored in Net.subs.
type Neg = Sub | Era | App | Dup | Udp | Amb | Hold | Decide | Eval
@dataclass(frozen=True)
class Var:
    """Positive link node that refers to another positive name.

    ``reduce`` follows ``nam`` when it is a key of ``Net.vars``; otherwise it
    stores the opposing negative node under ``nam`` in ``Net.subs``.
    """

    # Name this link points at.
    nam: PosNam
@dataclass(frozen=True)
class Sub:
    """Negative link node that refers to another negative name.

    ``reduce`` follows ``nam`` when it is a key of ``Net.subs``; otherwise it
    stores the opposing positive node under ``nam`` in ``Net.vars``.
    """

    # Name this link points at.
    nam: NegNam
@dataclass(frozen=True)
class Nil:
    """Positive node with no ports (printed as ``+_`` by ``show_pos``)."""

    pass
@dataclass(frozen=True)
class Era:
    """Negative node with no ports (printed as ``-_`` by ``show_neg``)."""

    pass
@dataclass(frozen=True)
class Lam:
    """Positive lambda node with a binder port and a body port."""

    # Negative port for the bound variable.
    bnd: NegNam
    # Positive port for the body.
    bod: PosNam
@dataclass(frozen=True)
class App:
    """Negative application node with an argument port and a result port."""

    # Positive port holding the argument.
    arg: PosNam
    # Negative port that receives the result.
    ret: NegNam
@dataclass(frozen=True)
class Dup:
    """Negative labelled duplicator with two output ports.

    In ``reduce`` a Dup meeting a Sup with an equal label connects the ports
    pairwise; with a different label the two nodes are copied through each
    other.
    """

    # Label compared against Sup.spl.
    dpl: Label
    # The two negative output ports.
    dp0: NegNam
    dp1: NegNam
@dataclass(frozen=True)
class Sup:
    """Positive labelled superposition of two positive ports."""

    # Label compared against Dup.dpl.
    spl: Label
    # The two positive component ports.
    sp0: PosNam
    sp1: PosNam
@dataclass(frozen=True)
class Udp:
    """Negative duplicator without a label.

    When it meets a Lam, ``reduce`` draws a fresh label and emits a labelled
    Sup/Dup pair for the binder and body.
    """

    # The two negative output ports.
    ud0: NegNam
    ud1: NegNam
@dataclass(frozen=True)
class Usp:
    """Positive superposition without a label.

    When it meets an App, ``reduce`` draws a fresh label and emits a labelled
    Dup/Sup pair for the argument and result.
    """

    # The two positive component ports.
    us0: PosNam
    us1: PosNam
@dataclass(frozen=True)
class Amb:
    """Negative node offering two alternative negative ports.

    ``reduce`` uses ``am0`` if it is still a key of ``Net.subs`` (moving that
    node to a fresh name, so it can only be taken once) and falls back to
    ``am1`` otherwise.
    """

    # Preferred port; taken at most once.
    am0: NegNam
    # Fallback port, used once am0 has been taken.
    am1: NegNam
@dataclass(frozen=True)
class Cob:
    """Positive node offering two alternative positive ports.

    ``reduce`` uses ``cb0`` if it is still a key of ``Net.vars`` (moving that
    node to a fresh name, so it can only be taken once) and falls back to
    ``cb1`` otherwise.
    """

    # Preferred port; taken at most once.
    cb0: PosNam
    # Fallback port, used once cb0 has been taken.
    cb1: PosNam
@dataclass(frozen=True)
class Wait:
    """Positive node pairing a positive port with a negative port.

    ``reduce`` creates it when an App meets a Lam (wrapped around the bound
    variable) and resolves it when it meets an Eval: the Eval is forwarded to
    ``wt0`` and ``wt1`` is paired with a Call.
    """

    # Positive payload port.
    wt0: PosNam
    # Negative port that is sent Nil (on erase) or Call (on Eval).
    wt1: NegNam
@dataclass(frozen=True)
class Call:
    """Positive node with no ports; ``reduce`` sends it to a Wait's ``wt1`` on Eval."""

    pass
@dataclass(frozen=True)
class Hold:
    """Negative node that keeps a value back until it receives Call or Nil.

    In ``reduce``: on Call, ``hd1`` is evaluated into ``hd0``; on Nil,
    ``hd0`` receives Nil and ``hd1`` is erased.
    """

    # Negative port that receives the outcome.
    hd0: NegNam
    # Positive value being held.
    hd1: PosNam
@dataclass(frozen=True)
class Decide:
    """Negative node that routes on whether it receives Call or Nil.

    In ``reduce``: on Nil, ``dc0`` is paired with ``dc1``; on Call, ``dc0``
    receives a Call and ``dc1`` is erased.
    """

    # Negative port that receives either dc1 or a Call.
    dc0: NegNam
    # Positive port forwarded on Nil and erased on Call.
    dc1: PosNam
@dataclass(frozen=True)
class Eval:
    """Negative node that forwards the positive node it meets to ``eva``.

    ``reduce`` additionally pushes a new Eval under a Lam's body and resolves
    a Wait by pairing its ``wt1`` with a Call.
    """

    # Negative port that receives the evaluated result.
    eva: NegNam
# A redex: a negative name paired with the positive name it interacts with.
type Rdx = tuple[NegNam, PosNam]
@dataclass(frozen=True)
class Net:
    """A net: the set of pending redexes plus the two node tables.

    The dataclass is frozen, so the three containers cannot be rebound, but
    they are mutated in place by the node constructors and by ``reduce``.
    """

    # Pending (negative name, positive name) pairs still to be reduced.
    book: set[Rdx]
    # Positive nodes keyed by name.
    vars: dict[PosNam, Pos]
    # Negative nodes keyed by name.
    subs: dict[NegNam, Neg]
def empty_net() -> Net:
    """Return a new net with no redexes and no nodes."""
    return Net(book=set(), vars={}, subs={})
def var(net: Net, nam: PosNam) -> PosNam:
    """Store a ``Var(nam)`` node in ``net.vars`` under a fresh name and return that name."""
    fresh = PosNam(next(name))
    net.vars[fresh] = Var(nam=nam)
    return fresh
def sub(net: Net, nam: NegNam) -> NegNam:
    """Store a ``Sub(nam)`` node in ``net.subs`` under a fresh name and return that name."""
    fresh = NegNam(next(name))
    net.subs[fresh] = Sub(nam=nam)
    return fresh
def nil(net: Net) -> PosNam:
    """Store a ``Nil`` node in ``net.vars`` under a fresh name and return that name."""
    fresh = PosNam(next(name))
    net.vars[fresh] = Nil()
    return fresh
def era(net: Net) -> NegNam:
    """Store an ``Era`` node in ``net.subs`` under a fresh name and return that name."""
    fresh = NegNam(next(name))
    net.subs[fresh] = Era()
    return fresh
def lam(net: Net, bnd: NegNam, bod: PosNam) -> PosNam:
    """Store a ``Lam(bnd, bod)`` node in ``net.vars`` under a fresh name and return that name."""
    fresh = PosNam(next(name))
    net.vars[fresh] = Lam(bnd=bnd, bod=bod)
    return fresh
def app(net: Net, arg: PosNam, ret: NegNam) -> NegNam:
    """Store an ``App(arg, ret)`` node in ``net.subs`` under a fresh name and return that name."""
    fresh = NegNam(next(name))
    net.subs[fresh] = App(arg=arg, ret=ret)
    return fresh
def dup(net: Net, dpl: Label, dp0: NegNam, dp1: NegNam) -> NegNam:
    """Store a ``Dup(dpl, dp0, dp1)`` node in ``net.subs`` under a fresh name and return that name."""
    fresh = NegNam(next(name))
    net.subs[fresh] = Dup(dpl=dpl, dp0=dp0, dp1=dp1)
    return fresh
def sup(net: Net, spl: Label, sp0: PosNam, sp1: PosNam) -> PosNam:
    """Store a ``Sup(spl, sp0, sp1)`` node in ``net.vars`` under a fresh name and return that name."""
    fresh = PosNam(next(name))
    net.vars[fresh] = Sup(spl=spl, sp0=sp0, sp1=sp1)
    return fresh
def udp(net: Net, ud0: NegNam, ud1: NegNam) -> NegNam:
    """Store a ``Udp(ud0, ud1)`` node in ``net.subs`` under a fresh name and return that name."""
    fresh = NegNam(next(name))
    net.subs[fresh] = Udp(ud0=ud0, ud1=ud1)
    return fresh
def usp(net: Net, us0: PosNam, us1: PosNam) -> PosNam:
    """Store a ``Usp(us0, us1)`` node in ``net.vars`` under a fresh name and return that name."""
    fresh = PosNam(next(name))
    net.vars[fresh] = Usp(us0=us0, us1=us1)
    return fresh
def amb(net: Net, am0: NegNam, am1: NegNam) -> NegNam:
    """Store an ``Amb(am0, am1)`` node in ``net.subs`` under a fresh name and return that name."""
    fresh = NegNam(next(name))
    net.subs[fresh] = Amb(am0=am0, am1=am1)
    return fresh
def cob(net: Net, cb0: PosNam, cb1: PosNam) -> PosNam:
    """Store a ``Cob(cb0, cb1)`` node in ``net.vars`` under a fresh name and return that name."""
    fresh = PosNam(next(name))
    net.vars[fresh] = Cob(cb0=cb0, cb1=cb1)
    return fresh
def wait(net: Net, wt0: PosNam, wt1: NegNam) -> PosNam:
    """Store a ``Wait(wt0, wt1)`` node in ``net.vars`` under a fresh name and return that name."""
    fresh = PosNam(next(name))
    net.vars[fresh] = Wait(wt0=wt0, wt1=wt1)
    return fresh
def call(net: Net) -> PosNam:
    """Store a ``Call`` node in ``net.vars`` under a fresh name and return that name."""
    fresh = PosNam(next(name))
    net.vars[fresh] = Call()
    return fresh
def hold(net: Net, hd0: NegNam, hd1: PosNam) -> NegNam:
    """Store a ``Hold(hd0, hd1)`` node in ``net.subs`` under a fresh name and return that name."""
    fresh = NegNam(next(name))
    net.subs[fresh] = Hold(hd0=hd0, hd1=hd1)
    return fresh
def decide(net: Net, dc0: NegNam, dc1: PosNam) -> NegNam:
    """Store a ``Decide(dc0, dc1)`` node in ``net.subs`` under a fresh name and return that name."""
    fresh = NegNam(next(name))
    net.subs[fresh] = Decide(dc0=dc0, dc1=dc1)
    return fresh
def eval(net: Net, eva: NegNam) -> NegNam:
    """Store an ``Eval(eva)`` node in ``net.subs`` under a fresh name and return that name.

    NOTE: this module-level function shadows the builtin ``eval`` within this file.
    """
    fresh = NegNam(next(name))
    net.subs[fresh] = Eval(eva=eva)
    return fresh
def wire(net: Net) -> tuple[PosNam, NegNam]:
    """Create a connected pair of link nodes and return ``(positive, negative)``.

    Both a ``Var`` and a ``Sub`` are created pointing at the same fresh name,
    which is not itself a key of either table yet.
    """
    shared = next(name)
    # The Var is allocated before the Sub, so it gets the lower name.
    pos_end = var(net, PosNam(shared))
    neg_end = sub(net, NegNam(shared))
    return pos_end, neg_end
def show_pos(net: Net, pos: PosNam, visited: set[Name] | None = None) -> str:
if visited is None:
visited = set()
if pos in visited:
return "+....."
visited.add(pos)
match net.vars[pos]:
case Var(nam) if nam in net.vars:
return show_pos(net, nam, visited)
case Var(nam):
return f"+{nam}"
case Nil():
return "+_"
case Lam(bnd, bod):
return f"+({show_neg(net, bnd, visited)} {show_pos(net, bod, visited)})"
case Sup(spl, sp0, sp1):
return f"+&{spl}{{{show_pos(net, sp0, visited)} {show_pos(net, sp1, visited)}}}"
case Usp(us0, us1):
return f"+{{{show_pos(net, us0, visited)} {show_pos(net, us1, visited)}}}"
case Cob(cb0, cb1):
return f"-cob({show_pos(net, cb0, visited) if cb0 in net.vars else "stolen"} {show_pos(net, cb1, visited) if cb1 in net.vars else "stolen"})"
case Wait(wt0, wt1):
return f"+wait({show_pos(net, wt0, visited)} {show_neg(net, wt1, visited)})"
case Call():
return "+call"
def show_neg(net: Net, neg: NegNam, visited: set[Name] | None = None) -> str:
if visited is None:
visited = set()
if neg in visited:
return "+....."
visited.add(neg)
match net.subs[neg]:
case Sub(nam) if nam in net.subs:
return show_neg(net, nam, visited)
case Sub(nam):
return f"-{nam}"
case Era():
return "-_"
case App(arg, ret):
return f"-({show_pos(net, arg, visited)} {show_neg(net, ret, visited)})"
case Dup(dpl, dp0, dp1):
return f"-&{dpl}{{{show_neg(net, dp0, visited)} {show_neg(net, dp1, visited)}}}"
case Udp(ud0, ud1):
return f"-{{{show_neg(net, ud0, visited)} {show_neg(net, ud1, visited)}}}"
case Amb(am0, am1):
return f"-amb({show_neg(net, am0, visited) if am0 in net.subs else "stolen"} {show_neg(net, am1, visited) if am1 in net.subs else "stolen"})"
case Hold(hd0, hd1):
return f"-hold({show_neg(net, hd0, visited)} {show_pos(net, hd1, visited)})"
case Decide(dc0, dc1):
return f"-decide({show_neg(net, dc0, visited)} {show_pos(net, dc1, visited)})"
case Eval(eva):
# return f"-eval({show_neg(net, eva)})"
return show_neg(net, eva, visited)
@dataclass(kw_only=True)
class Stats:
    """Counters collected by ``normalize``; all fields are keyword-only."""

    # Number of times each rule fired, keyed by "<NegClass>-<PosClass>".
    itrs: dict[str, int]
    # Largest value of len(net.vars) + len(net.subs) observed.
    max_size: int
    # Largest value of len(net.book) observed.
    parallelism: int
def bump(itrs: dict[str, int], key: str) -> None:
    """Increment the counter stored under ``key``, starting from zero if absent."""
    if key in itrs:
        itrs[key] += 1
    else:
        itrs[key] = 1
def reduce(net: Net, itrs: dict[str, int]) -> None:
    """Pop one redex from ``net.book`` and apply its rewrite rule in place.

    The negative and positive nodes of the redex are removed from the net,
    the rule selected by their classes adds new nodes and new redexes, and
    the rule's counter in ``itrs`` is incremented.

    Args:
        net: Net to rewrite; ``book``, ``vars`` and ``subs`` are mutated.
        itrs: Rule counters, keyed by ``"<NegClass>-<PosClass>"``.

    Raises:
        KeyError: if ``net.book`` is empty or a redex names a missing node.
        RuntimeError: if no rule exists for the pair of node classes.
    """
    lhs, rhs = net.book.pop()
    # Both nodes are removed up front; the link / Cob / Amb rules below put
    # back whichever side they leave untouched.
    lhsc = net.subs.pop(lhs)
    rhsc = net.vars.pop(rhs)
    match lhsc, rhsc:
        # --- Links ---------------------------------------------------------
        # Positive link whose target exists: retry lhs against the target.
        case lhsc, Var(nam) if nam in net.vars:
            net.subs[lhs] = lhsc
            net.book.add((lhs, nam))
        # Positive link whose target is empty: park the negative node under
        # that name so the matching Sub(nam) finds it in net.subs later.
        case lhsc, Var(nam):
            net.subs[NegNam(nam)] = lhsc
        # Negative link whose target exists: retry the target against rhs.
        case Sub(nam), rhsc if nam in net.subs:
            net.vars[rhs] = rhsc
            net.book.add((nam, rhs))
        # Negative link whose target is empty: park the positive node there.
        case Sub(nam), rhsc:
            net.vars[PosNam(nam)] = rhsc
        # --- Cob / Amb -----------------------------------------------------
        # First port still present: move its node to a fresh name (so `m`
        # disappears from net.vars) and pair lhs with it.
        case lhsc, Cob(m, a) if m in net.vars:
            net.subs[lhs] = lhsc
            x = next(name)
            net.vars[PosNam(x)] = net.vars.pop(m)
            net.book.add((lhs, PosNam(x)))
        # First port already taken: fall back to the second port.
        case lhsc, Cob(m, a):
            net.subs[lhs] = lhsc
            net.book.add((lhs, a))
        # Same two rules, mirrored for the negative side.
        case Amb(m, a), rhsc if m in net.subs:
            net.vars[rhs] = rhsc
            x = next(name)
            net.subs[NegNam(x)] = net.subs.pop(m)
            net.book.add((NegNam(x), rhs))
        case Amb(m, a), rhsc:
            net.vars[rhs] = rhsc
            net.book.add((a, rhs))
        # --- Era: erase the positive node, propagating to its ports ---------
        case Era(), Nil():
            pass
        case Era(), Lam(bnd, bod):
            net.book.add((bnd, nil(net)))
            net.book.add((era(net), bod))
        case Era(), Sup(spl, sp0, sp1):
            net.book.add((era(net), sp0))
            net.book.add((era(net), sp1))
        case Era(), Usp(us0, us1):
            net.book.add((era(net), us0))
            net.book.add((era(net), us1))
        case Era(), Wait(wt0, wt1):
            net.book.add((era(net), wt0))
            net.book.add((wt1, nil(net)))
        case Era(), Call():
            pass
        # --- App -----------------------------------------------------------
        case App(arg, ret), Nil():
            net.book.add((era(net), arg))
            net.book.add((ret, nil(net)))
        # The argument is not substituted directly: the binder receives a
        # Wait whose negative port is a Hold wrapping the argument.
        case App(arg, ret), Lam(bnd, bod):
            ap, an = wire(net)
            net.book.add((bnd, wait(net, ap, hold(net, an, arg))))
            net.book.add((ret, bod))
        # App through Sup: duplicate the argument with the Sup's label and
        # apply each component separately.
        case App(arg, ret), Sup(spl, sp0, sp1):
            ap, an = wire(net)
            bp, bn = wire(net)
            cp, cn = wire(net)
            dp, dn = wire(net)
            net.book.add((dup(net, spl, an, bn), arg))
            net.book.add((ret, sup(net, spl, cp, dp)))
            net.book.add((app(net, ap, cn), sp0))
            net.book.add((app(net, bp, dn), sp1))
        # Same as App-Sup, but a fresh label is drawn for the unlabelled Usp.
        case App(arg, ret), Usp(us0, us1):
            ap, an = wire(net)
            bp, bn = wire(net)
            cp, cn = wire(net)
            dp, dn = wire(net)
            lab = next(label)
            net.book.add((dup(net, lab, an, bn), arg))
            net.book.add((ret, sup(net, lab, cp, dp)))
            net.book.add((app(net, ap, cn), us0))
            net.book.add((app(net, bp, dn), us1))
        # App against a Wait: the result becomes a new Wait whose Hold keeps
        # the application pending on the rebuilt original Wait.
        case App(arg, ret), Wait(wt0, wt1):
            ap, an = wire(net)
            net.book.add((ret, wait(net, ap, hold(net, app(net, arg, an), wait(net, wt0, wt1)))))
        # --- Dup (labelled) --------------------------------------------------
        case Dup(dpl, dp0, dp1), Nil():
            net.book.add((dp0, nil(net)))
            net.book.add((dp1, nil(net)))
        case Dup(dpl, dp0, dp1), Lam(bnd, bod):
            ap, an = wire(net)
            bp, bn = wire(net)
            cp, cn = wire(net)
            dp, dn = wire(net)
            net.book.add((dp0, lam(net, an, bp)))
            net.book.add((dp1, lam(net, cn, dp)))
            net.book.add((bnd, sup(net, dpl, ap, cp)))
            net.book.add((dup(net, dpl, bn, dn), bod))
        # Equal labels: connect the ports pairwise.
        case Dup(dpl, dp0, dp1), Sup(spl, sp0, sp1) if dpl == spl:
            net.book.add((dp0, sp0))
            net.book.add((dp1, sp1))
        # Different labels: copy each node through the other.
        case Dup(dpl, dp0, dp1), Sup(spl, sp0, sp1):
            ap, an = wire(net)
            bp, bn = wire(net)
            cp, cn = wire(net)
            dp, dn = wire(net)
            net.book.add((dp0, sup(net, spl, ap, bp)))
            net.book.add((dp1, sup(net, spl, cp, dp)))
            net.book.add((dup(net, dpl, an, cn), sp0))
            net.book.add((dup(net, dpl, bn, dn), sp1))
        case Dup(dpl, dp0, dp1), Usp(us0, us1):
            ap, an = wire(net)
            bp, bn = wire(net)
            cp, cn = wire(net)
            dp, dn = wire(net)
            net.book.add((dp0, usp(net, ap, bp)))
            net.book.add((dp1, usp(net, cp, dp)))
            net.book.add((dup(net, dpl, an, cn), wt0)) if False else None
            net.book.add((dup(net, dpl, an, cn), us0))
            net.book.add((dup(net, dpl, bn, dn), us1))
        # Duplicating a Wait: each copy gets its own Wait, but both Ambs
        # refer to the same Decide (`dc`) and the same fallback (`bn`).
        case Dup(dpl, dp0, dp1), Wait(wt0, wt1):
            ap, an = wire(net)
            bp, bn = wire(net)
            cp, cn = wire(net)
            dc = decide(net, wt1, bp)
            net.book.add((dp0, wait(net, ap, amb(net, dc, bn))))
            net.book.add((dp1, wait(net, cp, amb(net, dc, bn))))
            net.book.add((dup(net, dpl, an, cn), wt0))
        # --- Udp (unlabelled) ------------------------------------------------
        case Udp(ud0, ud1), Nil():
            net.book.add((ud0, nil(net)))
            net.book.add((ud1, nil(net)))
        # A fresh label is drawn for the Sup/Dup pair pushed into the lambda.
        case Udp(ud0, ud1), Lam(bnd, bod):
            ap, an = wire(net)
            bp, bn = wire(net)
            cp, cn = wire(net)
            dp, dn = wire(net)
            lab = next(label)
            net.book.add((ud0, lam(net, an, bp)))
            net.book.add((ud1, lam(net, cn, dp)))
            net.book.add((bnd, sup(net, lab, ap, cp)))
            net.book.add((dup(net, lab, bn, dn), bod))
        case Udp(ud0, ud1), Sup(spl, sp0, sp1):
            ap, an = wire(net)
            bp, bn = wire(net)
            cp, cn = wire(net)
            dp, dn = wire(net)
            net.book.add((ud0, sup(net, spl, ap, bp)))
            net.book.add((ud1, sup(net, spl, cp, dp)))
            net.book.add((udp(net, an, cn), sp0))
            net.book.add((udp(net, bn, dn), sp1))
        case Udp(ud0, ud1), Usp(us0, us1):
            ap, an = wire(net)
            bp, bn = wire(net)
            cp, cn = wire(net)
            dp, dn = wire(net)
            net.book.add((ud0, usp(net, ap, bp)))
            net.book.add((ud1, usp(net, cp, dp)))
            net.book.add((udp(net, an, cn), us0))
            net.book.add((udp(net, bn, dn), us1))
        # Same shape as Dup-Wait, with a Udp pushed onto the payload.
        case Udp(ud0, ud1), Wait(wt0, wt1):
            ap, an = wire(net)
            bp, bn = wire(net)
            cp, cn = wire(net)
            dc = decide(net, wt1, bp)
            net.book.add((ud0, wait(net, ap, amb(net, dc, bn))))
            net.book.add((ud1, wait(net, cp, amb(net, dc, bn))))
            net.book.add((udp(net, an, cn), wt0))
        # --- Hold / Decide ---------------------------------------------------
        case Hold(hd0, hd1), Nil():
            net.book.add((hd0, nil(net)))
            net.book.add((era(net), hd1))
        # Called: evaluate the held value into hd0.
        case Hold(hd0, hd1), Call():
            net.book.add((eval(net, hd0), hd1))
        case Decide(dc0, dc1), Nil():
            net.book.add((dc0, dc1))
        case Decide(dc0, dc1), Call():
            net.book.add((dc0, call(net)))
            net.book.add((era(net), dc1))
        # --- Eval ------------------------------------------------------------
        case Eval(eva), Nil():
            net.book.add((eva, nil(net)))
        # Rebuild the lambda for the output and keep evaluating under its body.
        case Eval(eva), Lam(bnd, bod):
            ap, an = wire(net)
            net.book.add((eva, lam(net, bnd, ap)))
            net.book.add((eval(net, an), bod))
        # Sup / Usp are passed to the output as-is; their components are not
        # evaluated by this rule.
        case Eval(eva), Sup(spl, sp0, sp1):
            net.book.add((eva, sup(net, spl, sp0, sp1)))
        case Eval(eva), Usp(us0, us1):
            net.book.add((eva, usp(net, us0, us1)))
        # Evaluate the Wait's payload and send Call to its negative port.
        case Eval(eva), Wait(wt0, wt1):
            net.book.add((eval(net, eva), wt0))
            net.book.add((wt1, call(net)))
        # No rule for this pair. Both nodes have already been popped, so the
        # net is left without them when this raises.
        case lhsc, rhsc:
            raise RuntimeError(f"{lhsc.__class__.__name__.upper()}-{rhsc.__class__.__name__.upper()} undefined")
    bump(itrs, f"{lhsc.__class__.__name__}-{rhsc.__class__.__name__}")
def normalize(net: Net) -> Stats:
    """Apply ``reduce`` until ``net.book`` is empty and return the statistics.

    ``max_size`` and ``parallelism`` start from the initial net and are
    re-sampled after every reduction step.
    """
    stats = Stats(
        itrs={},
        max_size=len(net.vars) + len(net.subs),
        parallelism=len(net.book),
    )
    while net.book:
        reduce(net, stats.itrs)
        node_total = len(net.vars) + len(net.subs)
        pending = len(net.book)
        if node_total > stats.max_size:
            stats.max_size = node_total
        if pending > stats.parallelism:
            stats.parallelism = pending
    return stats
def print_state(net: Net, root: PosNam, *, heap: bool = False) -> None:
    """Print the term at ``root`` and the pending redexes of ``net``.

    With ``heap=True`` every entry of ``net.vars`` and ``net.subs`` is
    printed as well.
    """
    print("ROOT:")
    print(f" {show_pos(net, root)}")
    print()
    print("BOOK:")
    for neg_name, pos_name in net.book:
        # Each side is rendered with its own, fresh visited set.
        left = show_neg(net, neg_name)
        right = show_pos(net, pos_name)
        print(f" {left} ⋈ {right}")
    if not heap:
        return
    print()
    print("VARS:")
    for key in net.vars:
        print(f" {key} = {show_pos(net, key)}")
    print()
    print("SUBS:")
    for key in net.subs:
        print(f" {key} = {show_neg(net, key)}")
def print_itrs(itrs: dict[str, int]) -> None:
    """Print the total number of interactions, then each rule's count.

    Rules are listed from most to least frequent; entries whose count is not
    positive are skipped.
    """
    total = sum(itrs.values())
    print(f" ITRS: {total}")
    ranked = sorted(itrs.items(), key=lambda entry: entry[1], reverse=True)
    for rule, count in ranked:
        if count <= 0:
            continue
        print(f" {rule}: {count}")
def print_stats(stats: Stats) -> None:
    """Print the interaction counts, peak node count and peak book size."""
    print("STATS:")
    print_itrs(stats.itrs)
    peak_nodes = stats.max_size
    peak_redexes = stats.parallelism
    print(f" MAX SIZE: {peak_nodes}")
    print(f" PARALLELISM: {peak_redexes}")
def node_type_counts(net: Net) -> dict[str, int]:
    """Count the nodes of ``net`` by class name.

    Positive nodes are visited before negative ones, so the keys appear in
    first-seen order across ``net.vars`` followed by ``net.subs``.
    """
    tally: dict[str, int] = {}
    for node in itertools.chain(net.vars.values(), net.subs.values()):
        kind = type(node).__name__
        tally[kind] = tally.get(kind, 0) + 1
    return tally
def print_node_counts(net: Net) -> None:
    """Print how many nodes of each class remain, most frequent first.

    Ties are broken alphabetically by class name.
    """
    tally = node_type_counts(net)
    print("NODE COUNTS:")
    ordered = sorted(tally.items(), key=lambda entry: (-entry[1], entry[0]))
    for kind, total in ordered:
        print(f" {kind}: {total}")
def print_reduction(net: Net, root: PosNam, *, heap: bool = False) -> None:
    """Print ``net`` before and after normalisation, then the statistics.

    ``net`` is normalised in place. ``heap`` is forwarded to ``print_state``.
    """

    def banner(title: str) -> None:
        # Three-line boxed heading followed by a blank line.
        rule = "=" * 30
        print(rule)
        print("=", title.center(26), "=")
        print(rule)
        print()

    banner("INITIAL")
    print_state(net, root, heap=heap)
    print()
    banner("NORMALIZED")
    stats = normalize(net)
    print_state(net, root, heap=heap)
    print()
    print_stats(stats)
    print()
    print_node_counts(net)
def mk_app(net: Net, fun: PosNam, arg: PosNam) -> PosNam:
    """Queue the application of ``fun`` to ``arg`` and return the name of its result."""
    result, result_slot = wire(net)
    redex = (app(net, arg, result_slot), fun)
    net.book.add(redex)
    return result
def mk_udp(net: Net, x: PosNam) -> tuple[PosNam, PosNam]:
    """Queue an unlabelled duplication of ``x`` and return the names of the two copies."""
    first, first_slot = wire(net)
    second, second_slot = wire(net)
    splitter = udp(net, first_slot, second_slot)
    net.book.add((splitter, x))
    return first, second
def mk_eval(net: Net, x: PosNam) -> PosNam:
    """Queue an evaluation of ``x`` and return the name that will hold its result."""
    result, result_slot = wire(net)
    redex = (eval(net, result_slot), x)
    net.book.add(redex)
    return result
# λx.x
def mk_id(net: Net) -> PosNam:
    """Build a lambda whose body is its own bound variable and return its name."""
    body, binder = wire(net)
    return lam(net, bnd=binder, bod=body)
def mk_nat(net: Net, n: int) -> PosNam:
xp, xn = wire(net)
res = xp
apps: list[NegNam] = []
for _ in range(n):
res_p, res_n = wire(net)
apps.append(app(net, res, res_n))
res = res_p
def build_tree(leaves: list[NegNam]) -> NegNam:
if not leaves:
return era(net)
if len(leaves) == 1:
return leaves[0]
mid = len(leaves) // 2
return udp(net, build_tree(leaves[:mid]), build_tree(leaves[mid:]))
return lam(net, build_tree(apps), lam(net, xn, res))
def main() -> None:
    """Build a sample term, normalise it and print the result with statistics.

    The term is ``mk_nat(2)`` applied in turn to ``mk_nat(2)``, ``mk_nat(2)``
    and ``mk_nat(3)``, wrapped in ``mk_eval``.
    """
    net = empty_net()
    # root = mk_nat(net, 2)
    # root = mk_app(net, root, mk_nat(net, 2))
    # root = mk_app(net, root, mk_nat(net, 2))
    # root = mk_app(net, root, mk_nat(net, 3))
    # root = mk_app(net, root, mk_id(net))
    # print_reduction(net, mk_eval(net, root))
    root = mk_nat(net, 2)
    for operand in (2, 2, 3):
        root = mk_app(net, root, mk_nat(net, operand))
    root = mk_eval(net, root)
    stats = normalize(net)
    print_state(net, root)
    print()
    print_stats(stats)
    print()
    print_node_counts(net)
    # print()
    # root = mk_app(net, root, mk_id(net))
    # root = mk_eval(net, root)
    # stats = normalize(net)
    # print_state(net, root)
    # print()
    # print_stats(stats)
    # print()
    # print_node_counts(net)
# Run the demo only when executed as a script, so that importing this module
# does not normalise a net and print to stdout as a side effect.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment