Last active
August 14, 2025 06:21
-
-
Save DarinM223/4f8bcc45f526dbbb3fe46bd2eed5433c to your computer and use it in GitHub Desktop.
ANF conversion in which every call is in tail position (fully CPS)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(* ANF conversion for a small arithmetic language with calls. The converter
   is written in full CPS: every call, including every call to a
   continuation, is in tail position. *)
structure Lam =
struct
  datatype bop = Add | Sub | Mul
  val showBop = fn Add => "Add" | Sub => "Sub" | Mul => "Mul"

  (* Source expressions: integer literals, binary operations, and calls of a
     named function on a list of argument expressions. *)
  datatype exp = Lit of int | Bop of bop * exp * exp | Call of string * exp list

  type var = string
  val showVar = fn t0 => "\"" ^ t0 ^ "\""

  (* Atomic operands of the ANF target. *)
  datatype value = Var of var | Int of int
  val showValue =
    fn Var t0 => "Var " ^ "(" ^ showVar t0 ^ ")"
     | Int t1 => "Int " ^ "(" ^ Int.toString t1 ^ ")"

  (* ANF target: each operation binds its result to a variable and carries
     the rest of the program; the program ends with [Halt] of a value. *)
  datatype lexp =
    Halt of value
  | LetBop of var * bop * value * value * lexp
  | LetCall of var * string * value list * lexp

  (* Printer generated by smlgen: [lexp] is open-recursive in [lexp_0] and
     the fixpoint is tied by the second [lexp] binding. *)
  local
    val rec lexp = fn lexp_0 =>
      fn Halt t0 => "Halt " ^ "(" ^ showValue t0 ^ ")"
       | LetBop (t1, t2, t3, t4, t5) =>
           "LetBop " ^ "("
           ^
           String.concatWith ", "
             [showVar t1, showBop t2, showValue t3, showValue t4, lexp_0 t5]
           ^ ")"
       | LetCall (t6, t7, t8, t9) =>
           "LetCall " ^ "("
           ^
           String.concatWith ", "
             [ showVar t6
             , "\"" ^ t7 ^ "\""
             , "[" ^ String.concatWith ", " (List.map showValue t8) ^ "]"
             , lexp_0 t9
             ] ^ ")"
    val lexp = fn () => let val rec lexp_0 = fn ? => lexp lexp_0 ? in lexp_0 end
  in
    val showLexp = lexp ()
  end

  (* Fresh-name supply. [fresh p] returns [p] followed by the current counter
     value and then increments the counter. [reset ()] restarts numbering at
     0, so separate conversions can produce the same names. *)
  local val c = ref 0
  in
    val fresh = fn p => p ^ Int.toString (!c) before c := !c + 1
    val reset = fn () => c := 0
  end

  infix @@
  fun f @@ a = f a

  local
    (* [go exp k' k] converts [exp]. [k] receives the value naming [exp]'s
       result together with the current wrapper [k'], which wraps a finished
       tail expression in the let-bindings accumulated so far. Handing [k']
       to [k] (instead of wrapping a returned result) is what keeps every
       call in tail position. *)
    fun go (exp: exp) (k': lexp -> lexp) (k: value * (lexp -> lexp) -> lexp) :
      lexp =
      case exp of
        Lit i => k (Int i, k')
      | Bop (bop, l, r) =>
          go l k' @@ (fn (l, k') =>
          go r k' @@ (fn (r, k') =>
            let val tmp = fresh "tmp"
            in k (Var tmp, fn rest => k' (LetBop (tmp, bop, l, r, rest)))
            end))
      | Call (f, exps) => goChildren f [] exps k' k
    (* Converts call arguments left to right, accumulating their values in
       reverse order in [vs], then binds the call's result to a fresh name. *)
    and goChildren f vs [] k' k =
          let val tmp = fresh "tmp"
          in k (Var tmp, fn rest => k' (LetCall (tmp, f, List.rev vs, rest)))
          end
      | goChildren f vs (e :: es) k' k =
          go e k' (fn (v, k') => goChildren f (v :: vs) es k' k)
  in
    (* Starts with the identity wrapper and finishes the program with [Halt]. *)
    val convert = fn exp => go exp (fn a => a) (fn (v, f) => f (Halt v))
  end
end
(* Example: converts (2 - 1) + f (3 * h (4, 6), g (5)) to ANF and prints it. *)
local open Lam
in
  val result = convert (Bop (Add, Bop (Sub, Lit 2, Lit 1), Call
    ("f", [Bop (Mul, Lit 3, Call ("h", [Lit 4, Lit 6])), Call ("g", [Lit 5])])))
  val () = print (showLexp result ^ "\n")
end
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(* Source language: an untyped lambda calculus with integers, binary
   arithmetic and a conditional. *)
structure Lam =
struct
  datatype bop = Plus | Minus | Times
  val showBop = fn Plus => "Plus" | Minus => "Minus" | Times => "Times"
  type var = string
  datatype exp =
    Int of int
  | Var of var
  | Lam of var * exp            (* single-parameter lambda *)
  | App of exp * exp            (* single-argument application *)
  | Bop of bop * exp * exp
  | If of exp * exp * exp       (* condition, then-branch, else-branch *)
end
(* Conversion from [Lam] to A-normal form (ANF), developed step by step:
     convert         - direct continuation-based version (k is not always
                       called in tail position)
     convertCPS      - fully CPS version (Step 1)
     convertDefunc   - defunctionalized continuations (Steps 2-4)
     convertDefunc'  - continuations as stacks of frames (Step 5)
   Each step is intended as a behavior-preserving rewrite of the previous
   one. All of them draw names from the shared [fresh] counter. *)
structure Anf =
struct
  structure L = Lam

  type var = string
  val showVar = fn t0 => "\"" ^ t0 ^ "\""

  (* Atomic operands. [Glob] is declared but never produced by the
     converters in this structure. *)
  datatype value = Int of int | Var of var | Glob of var
  val showValue =
    fn Int t0 => "Int " ^ "(" ^ Int.toString t0 ^ ")"
     | Var t1 => "Var " ^ "(" ^ showVar t1 ^ ")"
     | Glob t2 => "Glob " ^ "(" ^ showVar t2 ^ ")"

  (* ANF expressions, as built by the converters below:
       Fun (f, params, body, rest)    bind local function [f], then [rest]
       Join (j, param, after, scope)  join point [j]; a [Jump (j, SOME v)]
                                      inside [scope] is meant to continue at
                                      [after] with [param] bound to [v]
       App (r, f, args, rest)         bind [r] to the call [f args]
       Bop (r, bop, x, y, rest)       bind [r] to [x bop y]
       If (c, t, f)                   branch on [c]
     [Tuple] and [Proj] are declared but not produced here. *)
  datatype exp =
    Halt of value
  | Fun of var * var list * exp * exp
  | Join of var * var option * exp * exp
  | Jump of var * value option
  | App of var * var * value list * exp
  | Bop of var * L.bop * value * value * exp
  | If of value * exp * exp
  | Tuple of var * value list * exp
  | Proj of var * var * int * exp

  (* Generated by smlgen *)
  local
    fun showOption f (SOME s) = "SOME " ^ f s
      | showOption _ NONE = "NONE"
    val rec exp = fn exp_0 =>
      fn Halt t0 => "Halt " ^ "(" ^ showValue t0 ^ ")"
       | Fun (t1, t2, t3, t4) =>
           "Fun " ^ "("
           ^
           String.concatWith ", "
             [ showVar t1
             , "[" ^ String.concatWith ", " (List.map showVar t2) ^ "]"
             , exp_0 t3
             , exp_0 t4
             ] ^ ")"
       | Join (t5, t6, t7, t8) =>
           "Join " ^ "("
           ^
           String.concatWith ", "
             [showVar t5, showOption showVar t6, exp_0 t7, exp_0 t8] ^ ")"
       | Jump (t9, t10) =>
           "Jump " ^ "("
           ^ String.concatWith ", " [showVar t9, showOption showValue t10] ^ ")"
       | App (t11, t12, t13, t14) =>
           "App " ^ "("
           ^
           String.concatWith ", "
             [ showVar t11
             , showVar t12
             , "[" ^ String.concatWith ", " (List.map showValue t13) ^ "]"
             , exp_0 t14
             ] ^ ")"
       | Bop (t15, t16, t17, t18, t19) =>
           "Bop " ^ "("
           ^
           String.concatWith ", "
             [showVar t15, L.showBop t16, showValue t17, showValue t18, exp_0 t19]
           ^ ")"
       | If (t20, t21, t22) =>
           "If " ^ "("
           ^ String.concatWith ", " [showValue t20, exp_0 t21, exp_0 t22] ^ ")"
       | Tuple (t23, t24, t25) =>
           "Tuple " ^ "("
           ^
           String.concatWith ", "
             [ showVar t23
             , "[" ^ String.concatWith ", " (List.map showValue t24) ^ "]"
             , exp_0 t25
             ] ^ ")"
       | Proj (t26, t27, t28, t29) =>
           "Proj " ^ "("
           ^
           String.concatWith ", "
             [showVar t26, showVar t27, Int.toString t28, exp_0 t29] ^ ")"
    val exp = fn () => let val rec exp_0 = fn ? => exp exp_0 ? in exp_0 end
  in
    val showExp = exp ()
  end

  (* Fresh-name supply. [fresh p] returns [p] followed by the current counter
     value and then increments the counter. [reset ()] restarts numbering
     at 0. *)
  local val c = ref 0
  in
    val fresh = fn p => p ^ Int.toString (!c) before c := !c + 1
    val reset = fn () => c := 0
  end

  infix @@
  fun f @@ a = f a

  (* Step 0: the easiest version to write. [go exp k] converts [exp] and
     passes the value naming its result to [k], which builds the rest of the
     program. It is not fully CPS: results of [k] are used as constructor
     arguments (e.g. [Fun (f, [v], body, k (Var f))]), so [k] is not always
     called in tail position. *)
  local
    fun go exp k =
      case exp of
        L.Int i => k (Int i)
      | L.Var v => k (Var v)
      | L.Lam (v, body) =>
          let
            val body = go body Halt
            val f = fresh "f"
          in
            Fun (f, [v], body, k (Var f))
          end
      | L.App (f, x) =>
          go f @@ (fn f =>
          go x @@ (fn x =>
            case f of
              Var f => let val r = fresh "r" in App (r, f, [x], k (Var r)) end
            | _ => raise Fail "must apply named value"))
      | L.Bop (bop, x, y) =>
          go x @@ (fn x =>
          go y @@ (fn y =>
            let val r = fresh "r"
            in Bop (r, bop, x, y, k (Var r))
            end))
      | L.If (c, t, f) =>
          go c (fn c =>
            let
              val (j, p) = (fresh "j", fresh "p")
              (* Each branch ends by jumping to [j] with its result. *)
              val jump = fn p => Jump (j, SOME p)
            in
              Join (j, SOME p, k (Var p), If (c, go t jump, go f jump))
            end)
  in
    val convert = fn exp => go exp Halt
  end

  (* Step 1: CPS form. For [k] to be called only in tail position, [k]
     itself takes a continuation [k' : exp -> exp] saying what to do next
     with the expression it builds. [go] threads its own [k'] through and
     hands it to [k]. *)
  local
    fun go (exp: L.exp) (k': exp -> exp) (k: value * (exp -> exp) -> exp) : exp =
      case exp of
        L.Int i => k (Int i, k')
      | L.Var v => k (Var v, k')
      | L.Lam (v, body) =>
          let
            (* Runs once the body is converted: names the function, then lets
               [k] continue, wrapping the [Fun] around the rest. The [k'] used
               inside is the outer one ([val] is not recursive). *)
            val k' = fn body =>
              let val f = fresh "f"
              in k (Var f, fn rest => k' (Fun (f, [v], body, rest)))
              end
          in
            go body k' (fn (value, k') => k' (Halt value))
          end
      | L.App (f, x) =>
          go f k' @@ (fn (f, k') =>
          go x k' @@ (fn (x, k') =>
            case f of
              Var f =>
                let val r = fresh "r"
                in k (Var r, fn rest => k' (App (r, f, [x], rest)))
                end
            | _ => raise Fail "must apply named value"))
      | L.Bop (bop, x, y) =>
          go x k' @@ (fn (x, k') =>
          go y k' @@ (fn (y, k') =>
            let val r = fresh "r"
            in k (Var r, fn rest => k' (Bop (r, bop, x, y, rest)))
            end))
      | L.If (c, t, f) =>
          go c k' (fn (c, k') =>
            let
              val (j, p) = (fresh "j", fresh "p")
              val jump = fn (v, k') => k' (Jump (j, SOME v))
              (* Converts a branch with wrapper [f]; the branch ends in a
                 jump to [j]. *)
              val go' = fn e => fn f => go e f jump
            in
              (* The code after the join point is built first (via [k]);
                 then both branches are converted and the [Join] assembled. *)
              k (Var p, fn rest =>
                go' t @@ (fn t =>
                go' f @@ (fn f =>
                  k' (Join (j, SOME p, rest, If (c, t, f))))))
            end)
  in
    val convertCPS: L.exp -> exp = fn e =>
      go e (fn a => a) (fn (v, k) => k (Halt v))
  end

  local
    (* Step 2: type definitions.
       The continuation types to eliminate are
         k' : exp -> exp
         k  : value * (exp -> exp) -> exp
       Every anonymous function of one of these types in Step 1 becomes a
       data constructor that stores that function's free variables.

       Closures of type exp -> exp (K'):
         K'_Convert  initial [fn a => a]                 (none)
         K'_Lam1     Lam: [fn body => ...]               k', k, v
         K'_Lam2     Lam: [fn rest => k' (Fun ...)]      k', f, v, body
         K'_App1     App: [fn rest => k' (App ...)]      r, f, x, k'
         K'_Bop1     Bop: [fn rest => k' (Bop ...)]      r, bop, x, y, k'
         K'_If1      If:  [fn rest => go' t ...]         t, f, k', j, p, c
         K'_If2      If:  [fn t => go' f ...]            f, k', j, p, c, rest
         K'_If3      If:  [fn f => k' (Join ...)]        t (now an exp), k',
                                                         j, p, c, rest
       Closures of type value * (exp -> exp) -> exp (K):
         K_Lam1      [fn (value, k') => k' (Halt value)] (none); also used as
                     the initial k
         K_App1      App: [fn (f, k') => go x ...]       x, k
         K_App2      App: [fn (x, k') => ...]            f, k
         K_Bop1      Bop: [fn (x, k') => go y ...]       y, bop, k
         K_Bop2      Bop: [fn (y, k') => ...]            x, bop, k
         K_If1       If:  [fn (c, k') => ...]            t, f, k
         K_If2       If:  [jump]                         j
       [jump] and [go'] only depend on [j], so storing [j] is enough. *)
    datatype K' =
      K'_Convert
    | K'_Lam1 of {k': K', k: K, v: string}
    | K'_Lam2 of {k': K', f: string, v: string, body: exp}
    | K'_App1 of {r: string, f: string, x: value, k': K'}
    | K'_Bop1 of {r: string, bop: L.bop, x: value, y: value, k': K'}
    | K'_If1 of {t: L.exp, f: L.exp, k': K', j: string, p: string, c: value}
    | K'_If2 of {f: L.exp, k': K', j: string, p: string, c: value, rest: exp}
    | K'_If3 of {t: exp, k': K', j: string, p: string, c: value, rest: exp}
    and K =
      K_Lam1
    | K_App1 of {x: L.exp, k: K}
    | K_App2 of {f: value, k: K}
    | K_Bop1 of {y: L.exp, bop: L.bop, k: K}
    | K_Bop2 of {x: value, bop: L.bop, k: K}
    | K_If1 of {t: L.exp, f: L.exp, k: K}
    | K_If2 of {j: string}

    (* Step 3: replace calls to k and k' with applyK and applyK', and replace
       the anonymous functions with the corresponding constructors. The
       arguments passed to the apply functions are what the Step 1 code
       passed when tail calling. *)
    fun go (exp: L.exp) (k': K') (k: K) : exp =
      case exp of
        L.Int i => applyK k (Int i) k'
      | L.Var v => applyK k (Var v) k'
      | L.Lam (v, body) => go body (K'_Lam1 {k' = k', k = k, v = v}) K_Lam1
      | L.App (f, x) => go f k' (K_App1 {x = x, k = k})
      | L.Bop (bop, x, y) => go x k' (K_Bop1 {y = y, bop = bop, k = k})
      | L.If (c, t, f) => go c k' (K_If1 {t = t, f = f, k = k})
    (* Step 4: give each constructor the body of its original closure,
       calling [go], [applyK] and [applyK'] recursively as needed. *)
    and applyK' K'_Convert exp = exp
      | applyK' (K'_Lam1 {k', k, v}) body =
          let val f = fresh "f"
          in applyK k (Var f) (K'_Lam2 {k' = k', f = f, v = v, body = body})
          end
      | applyK' (K'_Lam2 {k', f, v, body}) rest =
          applyK' k' (Fun (f, [v], body, rest))
      | applyK' (K'_App1 {r, f, x, k'}) rest =
          applyK' k' (App (r, f, [x], rest))
      | applyK' (K'_Bop1 {r, bop, x, y, k'}) rest =
          applyK' k' (Bop (r, bop, x, y, rest))
      | applyK' (K'_If1 {t, f, k', j, p, c}) rest =
          go t (K'_If2 {f = f, k' = k', j = j, p = p, c = c, rest = rest})
            (K_If2 {j = j})
      | applyK' (K'_If2 {f, k', j, p, c, rest}) t =
          go f (K'_If3 {t = t, k' = k', j = j, p = p, c = c, rest = rest})
            (K_If2 {j = j})
      | applyK' (K'_If3 {t, k', j, p, c, rest}) f =
          applyK' k' (Join (j, SOME p, rest, If (c, t, f)))
    and applyK K_Lam1 value k' =
          applyK' k' (Halt value)
      | applyK (K_App1 {x, k}) f k' =
          go x k' (K_App2 {f = f, k = k})
      | applyK (K_App2 {f, k}) x k' =
          (case f of
             Var f =>
               let val r = fresh "r"
               in applyK k (Var r) (K'_App1 {r = r, f = f, x = x, k' = k'})
               end
           | _ => raise Fail "must apply named value")
      | applyK (K_Bop1 {y, bop, k}) x k' =
          go y k' (K_Bop2 {x = x, bop = bop, k = k})
      | applyK (K_Bop2 {x, bop, k}) y k' =
          let
            val r = fresh "r"
          in
            applyK k (Var r) (K'_Bop1 {r = r, bop = bop, x = x, y = y, k' = k'})
          end
      | applyK (K_If1 {t, f, k}) c k' =
          let
            val (j, p) = (fresh "j", fresh "p")
          in
            applyK k (Var p) (K'_If1
              {t = t, f = f, k' = k', j = j, p = p, c = c})
          end
      | applyK (K_If2 {j}) v k' =
          applyK' k' (Jump (j, SOME v))
  in
    val convertDefunc: L.exp -> exp = fn e => go e K'_Convert K_Lam1
  end

  local
    (* Step 5: every K' constructor except K'_Convert stores another K', and
       every K constructor except K_Lam1 and K_If2 stores another K. So K'
       and K can be represented as lists (stacks) of frames, where a frame
       holds a closure's other free variables.
         - An empty K' stack plays the role of K'_Convert.
         - An empty K stack plays the role of K_Lam1.
         - K_If2 is represented as a one-element K stack [K_If2 {j}].
       Motivation: when lowering to C++, such a stack can be stored as a
       contiguous std::vector rather than a linked list of closures. *)
    datatype K'Frame =
      K'_Lam1 of {k: K, v: string}
    | K'_Lam2 of {f: string, v: string, body: exp}
    | K'_App1 of {r: string, f: string, x: value}
    | K'_Bop1 of {r: string, bop: L.bop, x: value, y: value}
    | K'_If1 of {t: L.exp, f: L.exp, j: string, p: string, c: value}
    | K'_If2 of {f: L.exp, j: string, p: string, c: value, rest: exp}
    | K'_If3 of {t: exp, j: string, p: string, c: value, rest: exp}
    and KFrame =
      K_App1 of {x: L.exp}
    | K_App2 of {f: value}
    | K_Bop1 of {y: L.exp, bop: L.bop}
    | K_Bop2 of {x: value, bop: L.bop}
    | K_If1 of {t: L.exp, f: L.exp}
    | K_If2 of {j: string}
    withtype K' = K'Frame list
    and K = KFrame list

    (* Where a Step 4 closure captured k' (or k), we instead push a frame
       holding its other free variables onto the K' (or K) stack.
       Where Step 4 passed K_Lam1 or K_If2, which capture no K, we start a
       new K stack instead: [] for a lambda body, and [K_If2 {j}] for an If
       branch. The Lam case saves the previous K inside its K'_Lam1 frame.
       The lists are immutable, so starting a new stack only means passing a
       different list. *)
    fun go (exp: L.exp) (k': K') (k: K) : exp =
      case exp of
        L.Int i => applyK k (Int i) k'
      | L.Var v => applyK k (Var v) k'
      | L.Lam (v, body) => go body (K'_Lam1 {k = k, v = v} :: k') []
      | L.App (f, x) => go f k' (K_App1 {x = x} :: k)
      | L.Bop (bop, x, y) => go x k' (K_Bop1 {y = y, bop = bop} :: k)
      | L.If (c, t, f) => go c k' (K_If1 {t = t, f = f} :: k)
    (* Instead of reading a captured k' out of a constructor, we use the rest
       of the stack left after popping the top frame. *)
    and applyK' [] (exp: exp) = exp
      | applyK' (K'_Lam1 {k, v} :: k') body =
          let val f = fresh "f"
          in applyK k (Var f) (K'_Lam2 {f = f, v = v, body = body} :: k')
          end
      | applyK' (K'_Lam2 {f, v, body} :: k') rest =
          applyK' k' (Fun (f, [v], body, rest))
      | applyK' (K'_App1 {r, f, x} :: k') rest =
          applyK' k' (App (r, f, [x], rest))
      | applyK' (K'_Bop1 {r, bop, x, y} :: k') rest =
          applyK' k' (Bop (r, bop, x, y, rest))
      | applyK' (K'_If1 {t, f, j, p, c} :: k') rest =
          go t (K'_If2 {f = f, j = j, p = p, c = c, rest = rest} :: k')
            [K_If2 {j = j}]
      | applyK' (K'_If2 {f, j, p, c, rest} :: k') t =
          go f (K'_If3 {t = t, j = j, p = p, c = c, rest = rest} :: k')
            [K_If2 {j = j}]
      | applyK' (K'_If3 {t, j, p, c, rest} :: k') f =
          applyK' k' (Join (j, SOME p, rest, If (c, t, f)))
    and applyK [] value k' =
          applyK' k' (Halt value)
      | applyK (K_App1 {x} :: k) f k' =
          go x k' (K_App2 {f = f} :: k)
      | applyK (K_App2 {f} :: k) x k' =
          (case f of
             Var f =>
               let val r = fresh "r"
               in applyK k (Var r) (K'_App1 {r = r, f = f, x = x} :: k')
               end
           | _ => raise Fail "must apply named value")
      | applyK (K_Bop1 {y, bop} :: k) x k' =
          go y k' (K_Bop2 {x = x, bop = bop} :: k)
      | applyK (K_Bop2 {x, bop} :: k) y k' =
          let val r = fresh "r"
          in applyK k (Var r) (K'_Bop1 {r = r, bop = bop, x = x, y = y} :: k')
          end
      | applyK (K_If1 {t, f} :: k) c k' =
          let val (j, p) = (fresh "j", fresh "p")
          in applyK k (Var p) (K'_If1 {t = t, f = f, j = j, p = p, c = c} :: k')
          end
      (* K_If2 is only ever created as the bottom of a new stack
         ([K_If2 {j}] in applyK'), so the ignored tail is always empty. *)
      | applyK (K_If2 {j} :: _) v k' =
          applyK' k' (Jump (j, SOME v))
  in
    val convertDefunc': L.exp -> exp = fn e => go e [] []
  end
end
(* Converts one sample program with each of the four converters and prints
   the results. The name counter is reset after every run, so each converter
   after the first starts numbering from 0 again, and so does any later code. *)
local
  open Lam
  (* (fn v => v * (if v + 2 then (fn x => x - 1) v
                   else if v then 2 else 3)) 5 *)
  val exp = App
    ( Lam ("v", Bop (Times, Var "v", If
        ( Bop (Plus, Var "v", Int 2)
        , App (Lam ("x", Bop (Minus, Var "x", Int 1)), Var "v")
        , If (Var "v", Int 2, Int 3)
        )))
    , Int 5
    )
in
  val anf = Anf.convert exp before Anf.reset ()
  val () = print (Anf.showExp anf ^ "\n\n")
  val anf = Anf.convertCPS exp before Anf.reset ()
  val () = print (Anf.showExp anf ^ "\n\n")
  val anf = Anf.convertDefunc exp before Anf.reset ()
  val () = print (Anf.showExp anf ^ "\n\n")
  val anf = Anf.convertDefunc' exp before Anf.reset ()
  val () = print (Anf.showExp anf ^ "\n\n")
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment