Skip to content

Instantly share code, notes, and snippets.

@Guest0x0
Created January 14, 2025 11:37
Show Gist options
  • Select an option

  • Save Guest0x0/e0c8c8fd0974bec98aa081a2016d8570 to your computer and use it in GitHub Desktop.

Select an option

Save Guest0x0/e0c8c8fd0974bec98aa081a2016d8570 to your computer and use it in GitHub Desktop.
STLC type inference with type graph optimization via an explicit substitution
(* Names of base (atomic) types, e.g. "int". *)
type type_name = string

(* Term-level variable names. *)
type var = string

(* Terms of the simply typed lambda calculus. Lambda parameters carry no
   type annotation; [infer] reconstructs their types. *)
type term =
  | Atom of { type_name : type_name }  (* a constant of the named base type *)
  | Var of var
  | Lam of var * term
  | App of term * term

(* Identifier of a unification (type) variable. *)
type tvar_id = int

(* Types. [TAtom] reuses [type_name] so it agrees with [Atom] above. *)
type typ =
  | TVar of tvar_id
  | TAtom of type_name
  | TFunc of typ * typ  (* [TFunc (dom, cod)] is [dom -> cod] *)

(* Status of a type variable in the substitution. The type inside [Solved]
   may itself mention other (possibly solved) variables; those are resolved
   lazily, on demand. *)
type tvar_status =
  | Unsolved
  | Solved of typ

module Subst = Hashtbl

(* The substitution: a mutable table from type-variable ids to their
   status. It is updated in place rather than copied. *)
type subst = (tvar_id, tvar_status) Subst.t
(* Since we maintain the substitution lazily, a type must be "forced"
   before it is matched on. If its head is a solved type variable, the
   variable's solution is looked up (recursively) and returned instead.
   Only the head constructor is resolved; the components of a [TFunc] may
   still mention solved variables and are forced when they are inspected.

   We also perform path compression here: if [A] is solved to [B] and [B]
   is solved to [C], fetching [A] rebinds [A] to [C] directly. The
   substitution is a mutable hash table, so this rebinding happens in
   place; only the forced type is returned, not a new substitution.

   Raises [Not_found] if a [TVar] has no entry in [subst] (variables
   created through [fresh_tvar] are always registered). *)
let rec get_type subst typ =
  match typ with
  | TVar tvar_id ->
    (match Subst.find subst tvar_id with
     | Unsolved -> typ
     | Solved typ' ->
       let typ' = get_type subst typ' in
       (* path compression: point [tvar_id] straight at the final type *)
       Subst.replace subst tvar_id (Solved typ');
       typ')
  | TAtom _ | TFunc _ -> typ
(* [check_occurence subst tvar typ] is the occurs check. It walks [typ]
   left to right, forcing every node through [subst], and fails with
   "occurence check" as soon as it reaches the type variable [tvar].
   Solving [tvar] to a type that contains [tvar] would create an
   infinite type. *)
let check_occurence subst tvar typ =
  let rec walk ty =
    match get_type subst ty with
    | TAtom _ -> ()
    | TVar other ->
      if Int.equal other tvar then failwith "occurence check"
    | TFunc (dom, cod) ->
      walk dom;
      walk cod
  in
  walk typ
(* [unify subst ty1 ty2] makes [ty1] and [ty2] equal by solving type
   variables in [subst], in place. Both sides are forced first ([ty1]
   before [ty2]).
   - A variable unifies with itself trivially.
   - A variable unifies with any other type, after an occurs check. When
     both sides are distinct variables, the left one is bound to the right.
   - Atoms unify when their names are equal.
   - Function types unify component-wise, domains first.
   Otherwise it fails with "type mismatch". The occurs check may also fail
   with "occurence check". *)
let rec unify subst ty1 ty2 =
  let bind tvar target =
    check_occurence subst tvar target;
    Subst.replace subst tvar (Solved target)
  in
  let lhs = get_type subst ty1 in
  let rhs = get_type subst ty2 in
  match lhs, rhs with
  | TVar a, TVar b when Int.equal a b -> ()
  | TVar tvar, other | other, TVar tvar -> bind tvar other
  | TAtom n1, TAtom n2 when String.equal n1 n2 -> ()
  | TFunc (dom1, cod1), TFunc (dom2, cod2) ->
    unify subst dom1 dom2;
    unify subst cod1 cod2
  | _ -> failwith "type mismatch"
(* Global source of type-variable ids. It is only ever incremented in this
   file, so ids stay distinct even across separate substitution tables. *)
let tvar_id = ref 0

(* [fresh_tvar subst] takes the next unused id, records it in [subst] as
   [Unsolved], and returns it. *)
let fresh_tvar subst =
  let fresh = !tvar_id in
  tvar_id := fresh + 1;
  Subst.add subst fresh Unsolved;
  fresh
(* Typing environments: immutable maps from term variables (strings) to
   their types, ordered by [String.compare]. *)
module Env = Map.Make (String)
(* [infer subst env expr] computes the type of [expr] under the typing
   environment [env]. Any type variables it solves are recorded in [subst].
   The result is not forced: it may mention variables that [subst] has
   since solved, so inspect it through [get_type].
   Fails with "undefined variable" on a variable missing from [env].
   Failures raised by [unify] are propagated. *)
let rec infer subst env expr =
  match expr with
  | Atom { type_name } -> TAtom type_name
  | Var name ->
    (match Env.find name env with
     | found -> found
     | exception Not_found -> failwith "undefined variable")
  | Lam (param, body) ->
    (* the parameter's type is unknown until the body constrains it *)
    let param_ty = TVar (fresh_tvar subst) in
    let body_ty = infer subst (Env.add param param_ty env) body in
    TFunc (param_ty, body_ty)
  | App (fn, arg) ->
    let fn_ty = infer subst env fn in
    let arg_ty = infer subst env arg in
    let result_ty = TVar (fresh_tvar subst) in
    unify subst fn_ty (TFunc (arg_ty, result_ty));
    result_ty
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment