Skip to content

Instantly share code, notes, and snippets.

@DarinM223
Last active August 14, 2025 06:21
Show Gist options
  • Select an option

  • Save DarinM223/4f8bcc45f526dbbb3fe46bd2eed5433c to your computer and use it in GitHub Desktop.

Select an option

Save DarinM223/4f8bcc45f526dbbb3fe46bd2eed5433c to your computer and use it in GitHub Desktop.
ANF Conversion where all calls should be in tail position (fully CPS)
(* A tiny expression language (literals, binary operators, named calls) and
   its conversion to a flat let-sequence (`lexp`) in which every operand is
   atomic. The converter is written in continuation-passing style so that
   every call inside it is a tail call. *)
structure Lam =
struct
  (* Binary operators of the source language. *)
  datatype bop = Add | Sub | Mul

  (* Constructor name of a binary operator. *)
  fun showBop Add = "Add"
    | showBop Sub = "Sub"
    | showBop Mul = "Mul"

  (* Source expressions: integer literals, nested arithmetic, and calls to a
     function identified by name with a list of argument expressions. *)
  datatype exp = Lit of int | Bop of bop * exp * exp | Call of string * exp list

  type var = string

  (* The name wrapped in double quotes; no escaping is applied. *)
  fun showVar name = String.concat ["\"", name, "\""]

  (* Atomic operands of the target language. *)
  datatype value = Var of var | Int of int

  fun showValue (Var x) = "Var (" ^ showVar x ^ ")"
    | showValue (Int n) = "Int (" ^ Int.toString n ^ ")"

  (* Target language: a chain of let-bindings ended by Halt. Each binding
     names the result of one operator application or one call, and its
     operands are values. *)
  datatype lexp =
    Halt of value
  | LetBop of var * bop * value * value * lexp
  | LetCall of var * string * value list * lexp

  (* Render an lexp as "Ctor (field, field, ...)", recursively. *)
  fun showLexp term =
    case term of
      Halt v => "Halt (" ^ showValue v ^ ")"
    | LetBop (dst, oper, a, b, rest) =>
        "LetBop ("
        ^ String.concatWith ", "
            [showVar dst, showBop oper, showValue a, showValue b, showLexp rest]
        ^ ")"
    | LetCall (dst, callee, args, rest) =>
        let
          val argList =
            "[" ^ String.concatWith ", " (List.map showValue args) ^ "]"
        in
          "LetCall ("
          ^ String.concatWith ", "
              [showVar dst, showVar callee, argList, showLexp rest]
          ^ ")"
        end

  (* fresh prefix returns prefix followed by the current value of a private
     counter, then increments the counter. The counter is shared by all
     callers and is never reset. *)
  local
    val counter = ref 0
  in
    fun fresh prefix =
      let val n = !counter
      in counter := n + 1; prefix ^ Int.toString n
      end
  end

  (* Infix function application: f @@ a is f a. Declared with plain `infix`,
     i.e. precedence 0 and left-associative. *)
  infix @@
  fun op @@ (f, a) = f a

  local
    (* go e wrap k converts e.
         wrap : lexp -> lexp  places the bindings emitted so far around
                              whatever term is built next;
         k                    receives the value that denotes e together with
                              the (possibly extended) wrap.
       Every call to go, goArgs, k and wrap below is a tail call. *)
    fun go (Lit n) (wrap: lexp -> lexp) (k: value * (lexp -> lexp) -> lexp) :
      lexp = k (Int n, wrap)
      | go (Bop (oper, lhs, rhs)) wrap k =
          go lhs wrap (fn (lv, wrap) =>
            go rhs wrap (fn (rv, wrap) =>
              let val dst = fresh "tmp"
              in k (Var dst, fn rest => wrap (LetBop (dst, oper, lv, rv, rest)))
              end))
      | go (Call (callee, args)) wrap k = goArgs callee [] args wrap k

    (* Convert call arguments left to right. acc holds the converted values
       in reverse order; once the arguments are exhausted a LetCall binding
       for a fresh temporary is emitted. *)
    and goArgs callee acc [] wrap k =
          let val dst = fresh "tmp"
          in
            k ( Var dst
              , fn rest => wrap (LetCall (dst, callee, List.rev acc, rest))
              )
          end
      | goArgs callee acc (arg :: more) wrap k =
          go arg wrap (fn (v, wrap) => goArgs callee (v :: acc) more wrap k)
  in
    (* Convert a whole expression; the final value is wrapped in Halt. *)
    fun convert e = go e (fn a => a) (fn (v, wrap) => wrap (Halt v))
  end
end
(* Demo: convert (2 - 1) + f (3 * h (4, 6), g (5)) and print the result. *)
val result =
  let
    open Lam
    val hCall = Call ("h", [Lit 4, Lit 6])
    val gCall = Call ("g", [Lit 5])
    val fCall = Call ("f", [Bop (Mul, Lit 3, hCall), gCall])
  in
    convert (Bop (Add, Bop (Sub, Lit 2, Lit 1), fCall))
  end
val () = print (Lam.showLexp result ^ "\n")
(* Source language: a small lambda calculus with integers, binary arithmetic
   and a conditional. *)
structure Lam =
struct
  (* Binary arithmetic operators. *)
  datatype bop = Plus | Minus | Times

  (* Constructor name of a binary operator. *)
  fun showBop oper =
    case oper of
      Plus => "Plus"
    | Minus => "Minus"
    | Times => "Times"

  type var = string

  (* Expressions. Lam binds a single parameter; App applies a function
     expression to one argument expression; If has a condition and two
     branches. *)
  datatype exp =
    Int of int
  | Var of var
  | Lam of var * exp
  | App of exp * exp
  | Bop of bop * exp * exp
  | If of exp * exp * exp
end
structure Anf =
struct
structure L = Lam
type var = string

(* The name wrapped in double quotes; no escaping is applied. *)
fun showVar name = String.concat ["\"", name, "\""]

(* Atomic operands: integer constants, variables, and Glob, which carries a
   name like Var does. *)
datatype value = Int of int | Var of var | Glob of var

(* Render a value as "Ctor (payload)". *)
local
  fun tagged tag payload = tag ^ " (" ^ payload ^ ")"
in
  fun showValue (Int n) = tagged "Int" (Int.toString n)
    | showValue (Var x) = tagged "Var" (showVar x)
    | showValue (Glob g) = tagged "Glob" (showVar g)
end
(* ANF expressions. As built by the converters in this structure:
     Halt v                     - finish with value v
     Fun (f, params, body, e)   - define function f, continue with e
     Join (j, p, body, e)       - define join point j with optional
                                  parameter p, continue with e
     Jump (j, v)                - jump to join point j, optionally passing v
     App (r, f, args, e)        - bind r to the call f args, continue with e
     Bop (r, bop, x, y, e)      - bind r to x bop y, continue with e
     If (c, t, f)               - branch on value c
   Tuple and Proj are not produced by any converter in this structure. *)
datatype exp =
  Halt of value
| Fun of var * var list * exp * exp
| Join of var * var option * exp * exp
| Jump of var * value option
| App of var * var * value list * exp
| Bop of var * L.bop * value * value * exp
| If of value * exp * exp
| Tuple of var * value list * exp
| Proj of var * var * int * exp

(* showExp renders an expression as "Ctor (field, field, ...)", recursively.
   This is a hand-written printer producing the same text as the
   smlgen-generated one it replaces. *)
local
  (* "Ctor (f1, f2, ...)" from already-rendered fields. *)
  fun ctor tag fields = tag ^ " (" ^ String.concatWith ", " fields ^ ")"

  (* "[x1, x2, ...]" *)
  fun showList showItem items =
    "[" ^ String.concatWith ", " (List.map showItem items) ^ "]"

  fun showOption _ NONE = "NONE"
    | showOption showItem (SOME item) = "SOME " ^ showItem item
in
  fun showExp e =
    case e of
      Halt v => ctor "Halt" [showValue v]
    | Fun (name, params, body, rest) =>
        ctor "Fun"
          [showVar name, showList showVar params, showExp body, showExp rest]
    | Join (label, param, body, rest) =>
        ctor "Join"
          [showVar label, showOption showVar param, showExp body, showExp rest]
    | Jump (label, arg) =>
        ctor "Jump" [showVar label, showOption showValue arg]
    | App (dst, callee, args, rest) =>
        ctor "App"
          [showVar dst, showVar callee, showList showValue args, showExp rest]
    | Bop (dst, oper, lhs, rhs, rest) =>
        ctor "Bop"
          [ showVar dst
          , L.showBop oper
          , showValue lhs
          , showValue rhs
          , showExp rest
          ]
    | If (cond, thenE, elseE) =>
        ctor "If" [showValue cond, showExp thenE, showExp elseE]
    | Tuple (dst, items, rest) =>
        ctor "Tuple" [showVar dst, showList showValue items, showExp rest]
    | Proj (dst, src, idx, rest) =>
        ctor "Proj" [showVar dst, showVar src, Int.toString idx, showExp rest]
end
(* Name supply backed by one private counter.
     fresh prefix  returns prefix followed by the current counter value,
                   then increments the counter;
     reset ()      sets the counter back to 0. *)
local
  val counter = ref 0
in
  fun fresh prefix =
    let val n = !counter
    in counter := n + 1; prefix ^ Int.toString n
    end

  fun reset () = counter := 0
end
(* Infix function application: f @@ a is f a. Declared with plain `infix`,
   so it has precedence 0 and is left-associative; parenthesize the right
   operand when chaining. *)
infix @@
fun op @@ (f, a) = f a
(* Starting point: the easiest converter to write. It uses a continuation
   k : value -> exp, but it is not in full CPS form, because k and go are
   also called in non-tail positions (inside constructor arguments). *)
local
  (* go e k converts e and passes the value that denotes it to k, which
     builds the rest of the program. *)
  fun go (L.Int n) k = k (Int n)
    | go (L.Var x) k = k (Var x)
    | go (L.Lam (param, body)) k =
        let
          (* The body is converted first, on its own, ending in Halt;
             only then is the function name drawn. *)
          val body' = go body Halt
          val name = fresh "f"
        in
          Fun (name, [param], body', k (Var name))
        end
    | go (L.App (callee, arg)) k =
        go callee (fn calleeV =>
          go arg (fn argV =>
            case calleeV of
              Var fname =>
                let
                  val res = fresh "r"
                  val rest = k (Var res)
                in
                  App (res, fname, [argV], rest)
                end
            | _ => raise Fail "must apply named value"))
    | go (L.Bop (oper, lhs, rhs)) k =
        go lhs (fn lv =>
          go rhs (fn rv =>
            let
              val res = fresh "r"
              val rest = k (Var res)
            in
              Bop (res, oper, lv, rv, rest)
            end))
    | go (L.If (cond, thenE, elseE)) k =
        go cond (fn condV =>
          let
            val join = fresh "j"
            val param = fresh "p"
            (* Both branches finish by jumping to the join point. *)
            fun jumpTo v = Jump (join, SOME v)
            (* Order matters because fresh is stateful: the code after the
               conditional is built first, then the two branches. *)
            val rest = k (Var param)
            val thenA = go thenE jumpTo
            val elseA = go elseE jumpTo
          in
            Join (join, SOME param, rest, If (condV, thenA, elseA))
          end)
in
  (* Convert a whole expression; its final value is wrapped in Halt.
     Raises Fail "must apply named value" when the function position of an
     application does not convert to a Var. *)
  fun convert e = go e Halt
end
(* Step 1: full CPS form. For k to be called only in tail position, k itself
   must receive a continuation (exp -> exp) saying what to do next with the
   expression that gets built. go therefore threads its own k' : exp -> exp
   and hands it (possibly extended) to k.
   The variable names used here (f, v, body, r, x, y, bop, t, j, p, c, rest)
   are the ones the defunctionalization steps below refer to. *)
local
  fun go (L.Int i) (k': exp -> exp) (k: value * (exp -> exp) -> exp) : exp =
        k (Int i, k')
    | go (L.Var v) k' k = k (Var v, k')
    | go (L.Lam (v, body)) k' k =
        go body
          (* Receives the converted body: draw the function name, then let k
             build the rest, which ends up as the continuation of Fun. *)
          (fn body =>
             let val f = fresh "f"
             in k (Var f, fn rest => k' (Fun (f, [v], body, rest)))
             end)
          (* The body's own value is returned with Halt. *)
          (fn (value, k') => k' (Halt value))
    | go (L.App (f, x)) k' k =
        go f k' (fn (f, k') =>
          go x k' (fn (x, k') =>
            case f of
              Var f =>
                let val r = fresh "r"
                in k (Var r, fn rest => k' (App (r, f, [x], rest)))
                end
            | _ => raise Fail "must apply named value"))
    | go (L.Bop (bop, x, y)) k' k =
        go x k' (fn (x, k') =>
          go y k' (fn (y, k') =>
            let val r = fresh "r"
            in k (Var r, fn rest => k' (Bop (r, bop, x, y, rest)))
            end))
    | go (L.If (c, t, f)) k' k =
        go c k' (fn (c, k') =>
          let
            val j = fresh "j"
            val p = fresh "p"
            (* Both branches finish by jumping to the join point j. *)
            fun jump (v, k') = k' (Jump (j, SOME v))
          in
            (* First build the code after the conditional (rest), then the
               true branch, then the false branch, then assemble the Join. *)
            k (Var p, fn rest =>
              go t
                (fn t =>
                   go f
                     (fn f => k' (Join (j, SOME p, rest, If (c, t, f))))
                     jump)
                jump)
          end)
in
  (* Convert a whole expression; the initial k' is the identity and the
     initial k wraps the final value in Halt. *)
  fun convertCPS (e: L.exp) : exp =
    go e (fn a => a) (fn (v, k) => k (Halt v))
end
(* Defunctionalized form of the CPS converter: the two continuation function
types are replaced by the datatypes K' and K, and calling a continuation
becomes a call to applyK' or applyK. *)
local
(* Step 2: type definitions.
Types to eliminate from the CPS version:
k' : exp -> exp
k : value * (exp -> exp) -> exp
Create ADTs for these types. *)
(* Each data constructor stands for one anonymous function of the given type
and holds that function's free variables. So every closure of type
(exp -> exp) that is passed around in the CPS version gets its own data
constructor in K', and every closure of the k type gets one in K. *)
(* Free variables of the k' closures in the CPS version:
Lam case: `k`, `k'`, `v` for the closure receiving the converted body;
`k'`, `f`, `v`, `body` for the closure receiving rest
App case: `r`, `f`, `x`, `k'`
Bop case: `r`, `bop`, `x`, `y`, `k'`
If case: `k'`, `j`, `p`, `c` (Anf.value) in all three closures, plus
`t`, `f` (both L.exp) in the first,
`f` (L.exp) and `rest` (Anf.exp) in the second,
`t` (shadowed, now Anf.exp) and `rest` in the third *)
datatype K' =
K'_Convert (* Initial fn a => a passed into go *)
| K'_Lam1 of {k': K', k: K, v: string}
| K'_Lam2 of {k': K', f: string, v: string, body: exp}
| K'_App1 of {r: string, f: string, x: value, k': K'}
| K'_Bop1 of {r: string, bop: L.bop, x: value, y: value, k': K'}
| K'_If1 of {t: L.exp, f: L.exp, k': K', j: string, p: string, c: value}
| K'_If2 of {f: L.exp, k': K', j: string, p: string, c: value, rest: exp}
| K'_If3 of {t: exp, k': K', j: string, p: string, c: value, rest: exp}
(* K constructors, one per k closure of the CPS version:
K_Lam1 is fn (value, k') => k' (Halt value), which has no free variables;
it is also the initial K passed by convertDefunc.
K_App1 / K_App2 receive the converted function / argument of an App.
K_Bop1 / K_Bop2 receive the converted left / right operand of a Bop.
K_If1 receives the converted condition of an If.
K_If2 is the `jump` continuation given to both branches of an If. *)
and K =
K_Lam1
| K_App1 of {x: L.exp, k: K}
| K_App2 of {f: value, k: K}
| K_Bop1 of {y: L.exp, bop: L.bop, k: K}
| K_Bop2 of {x: value, bop: L.bop, k: K}
| K_If1 of {t: L.exp, f: L.exp, k: K}
| K_If2 of {j: string}
(* Step 3: replace calls to k and k' with applyK and applyK',
and anonymous functions with the corresponding data constructors.
The extra arguments of the apply functions are the arguments the
continuation was called with. *)
fun go (exp: L.exp) (k': K') (k: K) : exp =
case exp of
L.Int i => applyK k (Int i) k'
| L.Var v => applyK k (Var v) k'
| L.Lam (v, body) => go body (K'_Lam1 {k' = k', k = k, v = v}) K_Lam1
| L.App (f, x) => go f k' (K_App1 {x = x, k = k})
| L.Bop (bop, x, y) => go x k' (K_Bop1 {y = y, bop = bop, k = k})
| L.If (c, t, f) => go c k' (K_If1 {t = t, f = f, k = k})
(* Step 4: fill out the apply functions for each closure, calling
`go`, `applyK`, and `applyK'` recursively as needed. *)
(* applyK' k' e: run the defunctionalized (exp -> exp) continuation k' on
the expression e. *)
and applyK' K'_Convert exp = exp
| applyK' (K'_Lam1 {k', k, v}) body =
(* The lambda body is converted: draw the function name and continue. *)
let val f = fresh "f"
in applyK k (Var f) (K'_Lam2 {k' = k', f = f, v = v, body = body})
end
| applyK' (K'_Lam2 {k', f, v, body}) rest =
applyK' k' (Fun (f, [v], body, rest))
| applyK' (K'_App1 {r, f, x, k'}) rest =
applyK' k' (App (r, f, [x], rest))
| applyK' (K'_Bop1 {r, bop, x, y, k'}) rest =
applyK' k' (Bop (r, bop, x, y, rest))
| applyK' (K'_If1 {t, f, k', j, p, c}) rest =
(* The code after the conditional is built: convert the true branch. *)
go t (K'_If2 {f = f, k' = k', j = j, p = p, c = c, rest = rest})
(K_If2 {j = j})
| applyK' (K'_If2 {f, k', j, p, c, rest}) t =
(* The true branch is built: convert the false branch. *)
go f (K'_If3 {t = t, k' = k', j = j, p = p, c = c, rest = rest})
(K_If2 {j = j})
| applyK' (K'_If3 {t, k', j, p, c, rest}) f =
applyK' k' (Join (j, SOME p, rest, If (c, t, f)))
(* applyK k v k': run the defunctionalized k continuation on the value v
and the pending k'. *)
and applyK K_Lam1 value k' =
applyK' k' (Halt value)
| applyK (K_App1 {x, k}) f k' =
go x k' (K_App2 {f = f, k = k})
| applyK (K_App2 {f, k}) x k' =
(* Only a Var can be applied; any other value raises Fail. *)
(case f of
Var f =>
let val r = fresh "r"
in applyK k (Var r) (K'_App1 {r = r, f = f, x = x, k' = k'})
end
| _ => raise Fail "must apply named value")
| applyK (K_Bop1 {y, bop, k}) x k' =
go y k' (K_Bop2 {x = x, bop = bop, k = k})
| applyK (K_Bop2 {x, bop, k}) y k' =
let
val r = fresh "r"
in
applyK k (Var r) (K'_Bop1 {r = r, bop = bop, x = x, y = y, k' = k'})
end
| applyK (K_If1 {t, f, k}) c k' =
let
val (j, p) = (fresh "j", fresh "p")
in
applyK k (Var p) (K'_If1
{t = t, f = f, k' = k', j = j, p = p, c = c})
end
| applyK (K_If2 {j}) v k' =
applyK' k' (Jump (j, SOME v))
(* Entry point: K'_Convert stands for the initial fn a => a and K_Lam1 for
the initial fn (v, k) => k (Halt v). *)
in val convertDefunc: L.exp -> exp = fn e => go e K'_Convert K_Lam1
end
(* Same converter as the previous step, with the defunctionalized
continuations K' and K represented as lists (stacks) of frames. *)
local
(* Step 5: Observe that every K' data constructor except the first one
(K'_Convert) holds another K' as a free variable, and every K data
constructor except K_Lam1 and K_If2 holds another K. Because of this we
can treat K and K' as lists of frames, where each frame holds only the
other free variables. If K is an empty list we run the case for K_Lam1,
and if K' is an empty list we run the case for K'_Convert.
The reason for doing this is that when we lower to C++, we can
treat this stack of frames as a contiguous std::vector,
which is more efficient than a linked list of pointers. *)
datatype K'Frame =
K'_Lam1 of {k: K, v: string}
| K'_Lam2 of {f: string, v: string, body: exp}
| K'_App1 of {r: string, f: string, x: value}
| K'_Bop1 of {r: string, bop: L.bop, x: value, y: value}
| K'_If1 of {t: L.exp, f: L.exp, j: string, p: string, c: value}
| K'_If2 of {f: L.exp, j: string, p: string, c: value, rest: exp}
| K'_If3 of {t: exp, j: string, p: string, c: value, rest: exp}
and KFrame =
K_App1 of {x: L.exp}
| K_App2 of {f: value}
| K_Bop1 of {y: L.exp, bop: L.bop}
| K_Bop2 of {x: value, bop: L.bop}
| K_If1 of {t: L.exp, f: L.exp}
| K_If2 of {j: string}
withtype K' = K'Frame list
and K = KFrame list
(* Where the previous version stored k' (or k) as a field of a new closure,
we instead push a frame holding the remaining free variables on top of
that stack. Where the previous version passed K_Lam1 (Lam case) or a lone
K_If2 (If branches), we pass a fresh K stack: [] or [K_If2 {j = j}]. In
the Lam case the caller's K stack is saved inside the K'_Lam1 frame. *)
fun go (exp: L.exp) (k': K') (k: K) : exp =
case exp of
L.Int i => applyK k (Int i) k'
| L.Var v => applyK k (Var v) k'
| L.Lam (v, body) => go body (K'_Lam1 {k = k, v = v} :: k') []
| L.App (f, x) => go f k' (K_App1 {x = x} :: k)
| L.Bop (bop, x, y) => go x k' (K_Bop1 {y = y, bop = bop} :: k)
| L.If (c, t, f) => go c k' (K_If1 {t = t, f = f} :: k)
(* Similarly, instead of unpacking k' as a free variable, we refer to it as the
rest of the stack after popping the topmost element. *)
(* applyK' k' e: pop the top K' frame and run it on the expression e; an
empty stack returns e unchanged. *)
and applyK' [] (exp: exp) = exp
| applyK' (K'_Lam1 {k, v} :: k') body =
(* The lambda body is converted: restore the saved K stack and continue. *)
let val f = fresh "f"
in applyK k (Var f) (K'_Lam2 {f = f, v = v, body = body} :: k')
end
| applyK' (K'_Lam2 {f, v, body} :: k') rest =
applyK' k' (Fun (f, [v], body, rest))
| applyK' (K'_App1 {r, f, x} :: k') rest =
applyK' k' (App (r, f, [x], rest))
| applyK' (K'_Bop1 {r, bop, x, y} :: k') rest =
applyK' k' (Bop (r, bop, x, y, rest))
| applyK' (K'_If1 {t, f, j, p, c} :: k') rest =
(* The code after the conditional is built: convert the true branch with
a K stack that contains only the jump frame. *)
go t (K'_If2 {f = f, j = j, p = p, c = c, rest = rest} :: k')
[K_If2 {j = j}]
| applyK' (K'_If2 {f, j, p, c, rest} :: k') t =
(* The true branch is built: convert the false branch the same way. *)
go f (K'_If3 {t = t, j = j, p = p, c = c, rest = rest} :: k')
[K_If2 {j = j}]
| applyK' (K'_If3 {t, j, p, c, rest} :: k') f =
applyK' k' (Join (j, SOME p, rest, If (c, t, f)))
(* applyK k v k': pop the top K frame and run it on the value v with the
pending K' stack; an empty K stack finishes with Halt. *)
and applyK [] value k' =
applyK' k' (Halt value)
| applyK (K_App1 {x} :: k) f k' =
go x k' (K_App2 {f = f} :: k)
| applyK (K_App2 {f} :: k) x k' =
(* Only a Var can be applied; any other value raises Fail. *)
(case f of
Var f =>
let val r = fresh "r"
in applyK k (Var r) (K'_App1 {r = r, f = f, x = x} :: k')
end
| _ => raise Fail "must apply named value")
| applyK (K_Bop1 {y, bop} :: k) x k' =
go y k' (K_Bop2 {x = x, bop = bop} :: k)
| applyK (K_Bop2 {x, bop} :: k) y k' =
let val r = fresh "r"
in applyK k (Var r) (K'_Bop1 {r = r, bop = bop, x = x, y = y} :: k')
end
| applyK (K_If1 {t, f} :: k) c k' =
let val (j, p) = (fresh "j", fresh "p")
in applyK k (Var p) (K'_If1 {t = t, f = f, j = j, p = p, c = c} :: k')
end
(* K_If2 is only ever created as the sole frame of a new K stack (see
applyK' above), so the ignored tail is empty when this clause runs. *)
| applyK (K_If2 {j} :: _) v k' =
applyK' k' (Jump (j, SOME v))
(* Entry point: both stacks start empty. *)
in val convertDefunc': L.exp -> exp = fn e => go e [] []
end
end
(* Demo: run the four converters on the same program and print each result.
   The name counter is reset after each of the first three runs so that all
   runs start counting from 0; it is not reset after the last run. *)
local
  open Lam
  val decrement = Lam ("x", Bop (Minus, Var "x", Int 1))
  val branch =
    If
      ( Bop (Plus, Var "v", Int 2)
      , App (decrement, Var "v")
      , If (Var "v", Int 2, Int 3)
      )
  val exp = App (Lam ("v", Bop (Times, Var "v", branch)), Int 5)
  fun show anf = print (Anf.showExp anf ^ "\n\n")
in
  val () =
    List.app (fn conv => show (conv exp before Anf.reset ()))
      [Anf.convert, Anf.convertCPS, Anf.convertDefunc]
  val anf = Anf.convertDefunc' exp
  val () = show anf
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment