Skip to content

Instantly share code, notes, and snippets.

@hsqStephenZhang
Last active November 19, 2025 13:08
Show Gist options
  • Select an option

  • Save hsqStephenZhang/d69d2bbc10b176476f64fd8bc9bb3835 to your computer and use it in GitHub Desktop.

Select an option

Save hsqStephenZhang/d69d2bbc10b176476f64fd8bc9bb3835 to your computer and use it in GitHub Desktop.
HM (Hindley–Milner) type inference with Algorithm W (algo_w)
#![allow(unused)]
use std::collections::{HashMap, HashSet};
/// Untyped source expressions of the small lambda calculus whose types are inferred.
#[derive(Clone, Debug)]
pub enum Expr {
    /// Boolean literal (`true` / `false`).
    Bool(bool),
    /// Integer literal.
    Int(i32),
    /// Variable reference `x`.
    Var(String),
    /// Lambda abstraction `λ x. e`.
    Lam { x: String, e: Box<Expr> },
    /// Function application `e1 e2`.
    App { e1: Box<Expr>, e2: Box<Expr> },
    /// Integer addition `e1 + e2` (the only binary operator; there is no operator field).
    Binop { e1: Box<Expr>, e2: Box<Expr> },
    /// Local binding `let x = e1 in e2`.
    Let {
        x: String,
        e1: Box<Expr>,
        e2: Box<Expr>,
    },
    /// Conditional `if cond then then_branch else else_branch`.
    IfThenElse {
        cond: Box<Expr>,
        then_branch: Box<Expr>,
        else_branch: Box<Expr>,
    },
}

impl std::fmt::Display for Expr {
    /// Renders the expression with every compound form parenthesised,
    /// e.g. `(let id = (λ x. x) in (id 3))`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Bool(value) => write!(f, "{value}"),
            Self::Int(value) => write!(f, "{value}"),
            Self::Var(name) => write!(f, "{name}"),
            Self::Lam { x: param, e: body } => write!(f, "(λ {param}. {body})"),
            Self::App { e1: func, e2: arg } => write!(f, "({func} {arg})"),
            Self::Binop { e1: lhs, e2: rhs } => write!(f, "({lhs} + {rhs})"),
            Self::Let {
                x: name,
                e1: bound,
                e2: body,
            } => write!(f, "(let {name} = {bound} in {body})"),
            Self::IfThenElse {
                cond,
                then_branch,
                else_branch,
            } => write!(f, "(if {cond} then {then_branch} else {else_branch})"),
        }
    }
}
/// A type (unification) variable, identified by its index; displayed as `α<index>`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct TypeVar(usize);

/// Monotypes: `τ ::= Bool | Int | α | τ -> τ`.
///
/// `PartialEq`/`Eq` allow inferred types to be compared structurally (e.g. in tests).
/// Comparison is purely syntactic, so `α0` and `α1` are different types.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Type {
    Bool,
    Int,
    TypeVar(TypeVar),
    /// Function type `argument -> result`.
    Func(Box<Type>, Box<Type>),
}
impl Type {
    /// Applies `subst` to every type variable in `self` simultaneously (parallel substitution).
    ///
    /// Each variable is replaced by its image exactly once. An image is not itself substituted
    /// again, so `{α0 ↦ α1, α1 ↦ Int}` maps `α0` to `α1`, not to `Int`.
    pub fn apply(&self, subst: &Subst) -> Type {
        match self {
            Type::Func(arg, ret) => {
                let arg = arg.apply(subst);
                let ret = ret.apply(subst);
                Type::Func(Box::new(arg), Box::new(ret))
            }
            Type::TypeVar(var) => subst
                .as_ref()
                .get(var)
                .cloned()
                .unwrap_or_else(|| Type::TypeVar(*var)),
            Type::Int => Type::Int,
            Type::Bool => Type::Bool,
        }
    }

    /// Returns every type variable occurring in `self`. A monotype binds nothing, so all
    /// occurring variables are free.
    pub fn free_type_vars(&self) -> HashSet<TypeVar> {
        match self {
            Type::TypeVar(var) => std::iter::once(*var).collect(),
            Type::Func(arg, ret) => arg
                .free_type_vars()
                .union(&ret.free_type_vars())
                .copied()
                .collect(),
            Type::Int | Type::Bool => HashSet::new(),
        }
    }
}
impl std::fmt::Display for Type {
    /// Prints arrows fully parenthesised and type variables as `α<n>`,
    /// e.g. `(α0 -> (Int -> Bool))`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Type::Func(arg, ret) => write!(f, "({arg} -> {ret})"),
            Type::TypeVar(TypeVar(index)) => write!(f, "α{index}"),
            Type::Int => f.write_str("Int"),
            Type::Bool => f.write_str("Bool"),
        }
    }
}
/// A type scheme `∀ α1 … αn . τ`: the type `ty` with the variables in `ty_vars` universally
/// quantified. An empty `ty_vars` makes the scheme an ordinary monotype.
#[derive(Clone, Debug)]
pub struct TypeSchema {
    /// The quantified (bound) variables `α1 … αn`.
    ty_vars: Vec<TypeVar>,
    /// The body `τ`.
    ty: Type,
}

impl TypeSchema {
    /// Builds the scheme `∀ ty_vars . ty`.
    pub fn new(ty_vars: Vec<TypeVar>, ty: Type) -> Self {
        TypeSchema { ty, ty_vars }
    }

    /// Applies `subst` (parallel substitution) to the body of the scheme.
    ///
    /// Quantified variables are bound by the `∀`, so any bindings `subst` has for them are
    /// ignored; only the scheme's free variables are replaced.
    pub fn apply(&self, subst: &Subst) -> TypeSchema {
        let unbound_only = Subst {
            inner: subst
                .inner
                .iter()
                .filter(|&(var, _)| !self.ty_vars.contains(var))
                .map(|(var, ty)| (*var, ty.clone()))
                .collect(),
        };
        Self::new(self.ty_vars.clone(), self.ty.apply(&unbound_only))
    }

    /// Free variables of the scheme: those of the body minus the quantified ones.
    pub fn free_type_vars(&self) -> HashSet<TypeVar> {
        self.ty
            .free_type_vars()
            .into_iter()
            .filter(|var| !self.ty_vars.contains(var))
            .collect()
    }
}
/// An expression tree in which every node carries its inferred type.
///
/// The node is split into a type (`ty`) plus a shape (`kind`), a layout also used for typed
/// trees in rustc's source code.
#[derive(Clone, Debug)]
pub struct TypedExpr {
    pub ty: Type,
    pub kind: TypedExprKind,
}

impl TypedExpr {
    /// Pairs a node shape with its type annotation.
    pub fn new(ty: Type, kind: TypedExprKind) -> Self {
        TypedExpr { kind, ty }
    }

    /// Consumes the tree and returns it with `subst` applied to the annotation of every node.
    /// This covers the root and all descendants. Names and literals are left unchanged.
    pub fn apply(self, subst: &Subst) -> Self {
        // Substitutes throughout one boxed child, keeping it boxed.
        let sub = |child: Box<TypedExpr>| Box::new(child.apply(subst));
        let TypedExpr { ty, kind } = self;
        let kind = match kind {
            TypedExprKind::Lam { x, e } => TypedExprKind::Lam { x, e: sub(e) },
            TypedExprKind::App { e1, e2 } => TypedExprKind::App {
                e1: sub(e1),
                e2: sub(e2),
            },
            TypedExprKind::Binop { e1, e2 } => TypedExprKind::Binop {
                e1: sub(e1),
                e2: sub(e2),
            },
            TypedExprKind::Let { x, e1, e2 } => TypedExprKind::Let {
                x,
                e1: sub(e1),
                e2: sub(e2),
            },
            TypedExprKind::IfThenElse {
                cond,
                then_branch,
                else_branch,
            } => TypedExprKind::IfThenElse {
                cond: sub(cond),
                then_branch: sub(then_branch),
                else_branch: sub(else_branch),
            },
            leaf @ (TypedExprKind::Bool(_) | TypedExprKind::Int(_) | TypedExprKind::Var(_)) => leaf,
        };
        TypedExpr {
            ty: ty.apply(subst),
            kind,
        }
    }
}

impl std::fmt::Display for TypedExpr {
    /// Prints the node fully parenthesised, then appends `: <type>`. Children are printed the
    /// same way, so every sub-expression shows its own annotation, e.g.
    /// `(λ x. x: α0): (α0 -> α0)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match &self.kind {
            TypedExprKind::Bool(value) => write!(f, "{value}")?,
            TypedExprKind::Int(value) => write!(f, "{value}")?,
            TypedExprKind::Var(name) => write!(f, "{name}")?,
            TypedExprKind::Lam { x: param, e: body } => write!(f, "(λ {param}. {body})")?,
            TypedExprKind::App { e1: func, e2: arg } => write!(f, "({func} {arg})")?,
            TypedExprKind::Binop { e1: lhs, e2: rhs } => write!(f, "({lhs} + {rhs})")?,
            TypedExprKind::Let {
                x: name,
                e1: bound,
                e2: body,
            } => write!(f, "(let {name} = {bound} in {body})")?,
            TypedExprKind::IfThenElse {
                cond,
                then_branch,
                else_branch,
            } => write!(f, "(if {cond} then {then_branch} else {else_branch})")?,
        }
        write!(f, ": {}", self.ty)
    }
}

/// The node shapes of a typed expression; one variant per source-expression form, with
/// children that are themselves `TypedExpr`s.
#[derive(Clone, Debug)]
pub enum TypedExprKind {
    /// Boolean literal.
    Bool(bool),
    /// Integer literal.
    Int(i32),
    /// Variable reference.
    Var(String),
    /// `λ x. e`
    Lam { x: String, e: Box<TypedExpr> },
    /// `e1 e2`
    App {
        e1: Box<TypedExpr>,
        e2: Box<TypedExpr>,
    },
    /// `e1 + e2`
    Binop {
        e1: Box<TypedExpr>,
        e2: Box<TypedExpr>,
    },
    /// `let x = e1 in e2`
    Let {
        x: String,
        e1: Box<TypedExpr>,
        e2: Box<TypedExpr>,
    },
    /// `if cond then then_branch else else_branch`
    IfThenElse {
        cond: Box<TypedExpr>,
        then_branch: Box<TypedExpr>,
        else_branch: Box<TypedExpr>,
    },
}
#[derive(Clone, Debug)]
pub struct TypeContext {
inner: HashMap<String, TypeSchema>,
}
impl TypeContext {
pub fn new() -> Self {
Self {
inner: Default::default(),
}
}
// extend a type context in
pub fn extend(&mut self, x: String, ty: TypeSchema) {
self.inner.insert(x, ty);
}
// apply subst for each var's type
pub fn apply_subst(&self, subst: &Subst) -> Self {
let mut new_ctx = self.clone();
for (k, v) in new_ctx.inner.iter_mut() {
let new_schema = v.apply(subst);
*v = new_schema;
}
new_ctx
}
pub fn free_type_vars(&self) -> HashSet<TypeVar> {
self.inner
.values()
.map(|s| s.free_type_vars())
.flatten()
.collect::<HashSet<_>>()
}
}
impl AsRef<HashMap<String, TypeSchema>> for TypeContext {
fn as_ref(&self) -> &HashMap<String, TypeSchema> {
&self.inner
}
}
/// A substitution: a finite map from type variables to the types that replace them.
#[derive(Clone, Debug)]
pub struct Subst {
    inner: HashMap<TypeVar, Type>,
}

impl AsRef<HashMap<TypeVar, Type>> for Subst {
    fn as_ref(&self) -> &HashMap<TypeVar, Type> {
        &self.inner
    }
}

impl Subst {
    /// Creates the empty (identity) substitution.
    pub fn new() -> Self {
        Self {
            inner: HashMap::new(),
        }
    }

    /// Returns the composition `other ∘ self` (written `S2 ∘ S1` with `self = S1`).
    ///
    /// Applying the result to a type gives the same answer as applying `self` first and then
    /// `other`. Concretely:
    /// * every binding `α ↦ τ` of `self` becomes `α ↦ other(τ)`, so later solutions are
    ///   pushed through earlier ones;
    /// * every binding of `other` whose variable `self` does not bind is added unchanged.
    ///
    /// Neither operand is modified.
    pub fn compose(&self, other: &Self) -> Self {
        // Apply `other` to the *images* of `self`. A chain such as α0 ↦ α1 (self) followed by
        // α1 ↦ Int (other) must resolve α0 all the way to Int.
        let mut inner: HashMap<TypeVar, Type> = self
            .inner
            .iter()
            .map(|(var, ty)| (*var, ty.apply(other)))
            .collect();
        // Variables `self` leaves alone are mapped exactly as `other` maps them.
        for (var, ty) in &other.inner {
            inner.entry(*var).or_insert_with(|| ty.clone());
        }
        Subst { inner }
    }
}
/// Counter that hands out fresh type variables `α0, α1, α2, …` in increasing order.
#[derive(Clone, Debug, Default)]
struct NumberAllocator {
    /// Index of the variable the next call to `next` returns.
    cnt: usize,
}

impl NumberAllocator {
    /// Returns a type variable this allocator has not handed out before, then advances.
    pub fn next(&mut self) -> TypeVar {
        let allocated = TypeVar(self.cnt);
        self.cnt += 1;
        allocated
    }
}
/// Algorithm W (Damas–Milner type inference) over `Expr`.
///
/// The only state is the fresh-variable supply. Type variables are therefore numbered
/// `α0, α1, …` in the order `run` creates them, and a new `AlgoW` starts again from `α0`.
#[derive(Clone, Debug)]
pub struct AlgoW {
    fresh_type_var_gen: NumberAllocator,
}

impl AlgoW {
    /// Creates an inferencer whose first fresh type variable is `α0`.
    pub fn new() -> Self {
        Self {
            fresh_type_var_gen: Default::default(),
        }
    }

    /// Runs Algorithm W on `e` under the typing environment `ctx`.
    ///
    /// Returns `(S, e')`:
    /// * `S` is the substitution accumulated while solving the constraints of `e`.
    /// * `e'` is `e` annotated with types. The root type `e'.ty` already has `S` applied.
    ///   Inner annotations may still mention variables that `S` solves, so callers apply `S`
    ///   to the whole tree (`TypedExpr::apply`) to resolve them.
    ///
    /// # Panics
    ///
    /// Panics if `e` mentions a variable that is not bound in `ctx`, or if `e` is ill-typed
    /// (`unify` panics on a failed unification).
    pub fn run(&mut self, mut ctx: TypeContext, e: &Expr) -> (Subst, TypedExpr) {
        match e {
            // Literals have a fixed type and impose no constraints: empty substitution.
            Expr::Bool(b) => (
                Subst::new(),
                TypedExpr::new(Type::Bool, TypedExprKind::Bool(*b)),
            ),
            Expr::Int(i) => (
                Subst::new(),
                TypedExpr::new(Type::Int, TypedExprKind::Int(*i)),
            ),
            // Variable: look up the scheme of x and instantiate its quantified variables with
            // fresh ones, so each use of a let-bound polymorphic name gets an independent copy.
            // This is the elimination form of let-polymorphism.
            Expr::Var(x) => {
                // Only closed programs are supported: every variable must be bound by an
                // enclosing λ or let.
                let schema = match ctx.as_ref().get(x) {
                    Some(schema) => schema.clone(),
                    None => panic!("unbound variable `{}`", x),
                };
                let ty = self.specilization(schema);
                (
                    Subst::new(),
                    TypedExpr::new(ty, TypedExprKind::Var(x.clone())),
                )
            }
            // Lambda: bind x to a fresh, unquantified variable α (λ-bound variables are
            // monomorphic), infer the body as τ, and type the lambda as `S(α) -> τ`.
            Expr::Lam { x, e } => {
                let arg_ty = Type::TypeVar(self.fresh_var());
                ctx.extend(x.clone(), TypeSchema::new(Vec::new(), arg_ty.clone()));
                let (subst1, body) = self.run(ctx, e);
                // Solving the body may have refined α (e.g. to Int in `λx. x + 1`).
                let func_ty = Type::Func(Box::new(arg_ty.apply(&subst1)), Box::new(body.ty.clone()));
                (
                    subst1,
                    TypedExpr::new(
                        func_ty,
                        TypedExprKind::Lam {
                            x: x.clone(),
                            e: Box::new(body),
                        },
                    ),
                )
            }
            // Application.
            // * Infer e1 as τ1, then e2 as τ2 under the updated environment.
            // * The result type is unknown up front, so take a fresh β and solve
            //   `S2(τ1) = τ2 -> β`.
            // * The solution for β is the type of the application.
            Expr::App { e1, e2 } => {
                let (mut subst1, expr1) = self.run(ctx.clone(), e1);
                let ctx2 = ctx.apply_subst(&subst1);
                let (subst2, expr2) = self.run(ctx2, e2);
                let result_var = self.fresh_var();
                let fn_ty = expr1.ty.apply(&subst2);
                let subst3 = unify(
                    fn_ty,
                    Type::Func(
                        Box::new(expr2.ty.clone()),
                        Box::new(Type::TypeVar(result_var)),
                    ),
                );
                let subst = subst1.compose(&subst2).compose(&subst3);
                let ty = Type::TypeVar(result_var).apply(&subst);
                (
                    subst,
                    TypedExpr::new(
                        ty,
                        TypedExprKind::App {
                            e1: Box::new(expr1),
                            e2: Box::new(expr2),
                        },
                    ),
                )
            }
            // Let-polymorphism (the introduction form).
            // * Infer e1.
            // * Generalize its type over the variables that are not free in the (substituted)
            //   environment.
            // * Bind x to that scheme while inferring e2. Each use of x in e2 is then
            //   instantiated independently (see the `Var` case).
            Expr::Let { x, e1, e2 } => {
                let (mut subst1, expr1) = self.run(ctx.clone(), e1);
                let mut ctx2 = ctx.apply_subst(&subst1);
                let generalized = Self::generalize(expr1.ty.clone(), &ctx2);
                ctx2.extend(x.clone(), generalized);
                let (subst2, expr2) = self.run(ctx2, e2);
                let subst = subst1.compose(&subst2);
                let ty = expr2.ty.clone();
                (
                    subst,
                    TypedExpr::new(
                        ty,
                        TypedExprKind::Let {
                            x: x.clone(),
                            e1: Box::new(expr1),
                            e2: Box::new(expr2),
                        },
                    ),
                )
            }
            // Addition: both operands must unify with Int, and the result is Int. This is like
            // application with the fixed signature `Int -> Int -> Int`.
            Expr::Binop { e1, e2 } => {
                let (mut subst1, expr1) = self.run(ctx.clone(), e1);
                let ctx2 = ctx.apply_subst(&subst1);
                let (subst2, expr2) = self.run(ctx2, e2);
                let subst3 = unify(expr1.ty.apply(&subst2), Type::Int);
                let subst4 = unify(expr2.ty.apply(&subst3), Type::Int);
                let subst = subst1.compose(&subst2).compose(&subst3).compose(&subst4);
                (
                    subst,
                    TypedExpr::new(
                        Type::Int,
                        TypedExprKind::Binop {
                            e1: Box::new(expr1),
                            e2: Box::new(expr2),
                        },
                    ),
                )
            }
            // Conditional: the condition must unify with Bool, and both branches must unify
            // with each other. Their common type is the type of the whole expression. Each
            // later sub-expression is inferred under the environment refined by the earlier
            // substitutions.
            Expr::IfThenElse {
                cond,
                then_branch,
                else_branch,
            } => {
                let (mut subst1, expr1) = self.run(ctx.clone(), cond);
                let subst_cond = unify(expr1.ty.apply(&subst1), Type::Bool);
                let mut subst1 = subst1.compose(&subst_cond);
                let ctx1 = ctx.apply_subst(&subst1);
                let (subst2, expr2) = self.run(ctx1.clone(), then_branch);
                let ctx2 = ctx1.apply_subst(&subst2);
                let (subst3, expr3) = self.run(ctx2, else_branch);
                let subst_branches = unify(expr2.ty.apply(&subst3), expr3.ty.apply(&subst3));
                let subst = subst1
                    .compose(&subst2)
                    .compose(&subst3)
                    .compose(&subst_branches);
                let ty = expr2.ty.apply(&subst);
                (
                    subst,
                    TypedExpr::new(
                        ty,
                        TypedExprKind::IfThenElse {
                            cond: Box::new(expr1),
                            then_branch: Box::new(expr2),
                            else_branch: Box::new(expr3),
                        },
                    ),
                )
            }
        }
    }

    /// Allocates a type variable not used before by this inferencer.
    fn fresh_var(&mut self) -> TypeVar {
        self.fresh_type_var_gen.next()
    }

    /// Instantiation (the "big application") of a scheme `∀ α1 … αn . τ`.
    ///
    /// Draws a fresh `βi` for each `αi`, in the order of `schema.ty_vars`, and returns
    /// `τ[α1 := β1, …, αn := βn]`.
    fn specilization(&mut self, schema: TypeSchema) -> Type {
        let subst = (schema.ty_vars)
            .iter()
            .map(|ty_var| (*ty_var, Type::TypeVar(self.fresh_var())))
            .collect::<HashMap<_, _>>();
        let subst = Subst { inner: subst };
        schema.ty.apply(&subst)
    }

    /// Generalization (the "big lambda"): quantifies every type variable of `ty` that is not
    /// free in `ctx`.
    fn generalize(ty: Type, ctx: &TypeContext) -> TypeSchema {
        let env_vars = ctx.free_type_vars();
        let mut ty_vars = ty
            .free_type_vars()
            .into_iter()
            .filter(|var| !env_vars.contains(var))
            .collect::<Vec<_>>();
        // HashSet iteration order is randomized. Sorting fixes the order in which
        // `specilization` hands out fresh variables, so inferred type-variable names are
        // reproducible from run to run.
        ty_vars.sort_by_key(|var| var.0);
        TypeSchema { ty, ty_vars }
    }
}
/// Syntactic (first-order) unification.
///
/// Returns a most general substitution `S` such that `S(t1)` and `S(t2)` are identical.
/// * Base types must agree.
/// * Arrows are unified argument-first, then result-second under the argument's solution.
/// * A type variable is bound to whatever the other side is.
///
/// # Panics
///
/// Panics with a message starting `unification failed` in two cases:
/// * the type constructors clash (e.g. `Int` vs `Bool`, or a base type vs an arrow);
/// * the occurs check fails, i.e. a solution would need an infinite type such as
///   `α0 = α0 -> α1` (as in `λx. x x`).
fn unify(t1: Type, t2: Type) -> Subst {
    match (t1, t2) {
        (Type::Bool, Type::Bool) | (Type::Int, Type::Int) => Subst::new(),
        (Type::TypeVar(tv), ty) | (ty, Type::TypeVar(tv)) => {
            // α ~ α holds trivially; binding α ↦ α would only add a useless entry.
            if let Type::TypeVar(other) = &ty {
                if *other == tv {
                    return Subst::new();
                }
            }
            // Occurs check: binding α to a type that mentions α describes an infinite type,
            // which has no finite solution.
            if ty.free_type_vars().contains(&tv) {
                panic!(
                    "unification failed: occurs check: {} occurs in {}",
                    Type::TypeVar(tv),
                    ty
                );
            }
            let mut subst = Subst::new();
            subst.inner.insert(tv, ty);
            subst
        }
        (Type::Func(arg1, ret1), Type::Func(arg2, ret2)) => {
            let mut s1 = unify(*arg1, *arg2);
            // The results must agree under what was learned from the arguments.
            let s2 = unify(ret1.apply(&s1), ret2.apply(&s1));
            s1.compose(&s2)
        }
        (lhs, rhs) => panic!("unification failed: cannot unify {} with {}", lhs, rhs),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Infers the type of the closed expression `e` with a fresh `AlgoW`, applies the final
    /// substitution to every annotation, prints the typed tree and returns it.
    fn run(e: &Expr) -> TypedExpr {
        let mut algo_w = AlgoW::new();
        let ctx = TypeContext::new();
        let (subst, expr) = algo_w.run(ctx, e);
        let expr = expr.apply(&subst);
        println!("expr: {}", expr);
        expr
    }

    /// Root type of `e` rendered with `Display`; variables are numbered from α0 in each call.
    fn infer(e: &Expr) -> String {
        run(e).ty.to_string()
    }

    fn var(name: &str) -> Expr {
        Expr::Var(name.to_string())
    }

    fn lam(x: &str, body: Expr) -> Expr {
        Expr::Lam {
            x: x.to_string(),
            e: Box::new(body),
        }
    }

    fn app(func: Expr, arg: Expr) -> Expr {
        Expr::App {
            e1: Box::new(func),
            e2: Box::new(arg),
        }
    }

    /// Builds `let id = λx. x in body`, shared by the let-polymorphism tests.
    fn let_id_in(body: Expr) -> Expr {
        Expr::Let {
            x: "id".to_string(),
            e1: Box::new(lam("x", var("x"))),
            e2: Box::new(body),
        }
    }

    #[test]
    fn t0() {
        // (λx. x) 1
        let exp = app(lam("x", var("x")), Expr::Int(1));
        assert_eq!(infer(&exp), "Int");
        // The final substitution also resolves the inner annotations, not just the root.
        assert_eq!(
            run(&exp).to_string(),
            "((λ x. x: Int): (Int -> Int) 1: Int): Int"
        );
    }

    #[test]
    fn t1() {
        // λx. x
        assert_eq!(infer(&lam("x", var("x"))), "(α0 -> α0)");
        // λx. λy. x
        let true_exp = lam("x", lam("y", var("x")));
        assert_eq!(infer(&true_exp), "(α0 -> (α1 -> α0))");
        // λx. λy. y
        let false_exp = lam("x", lam("y", var("y")));
        assert_eq!(infer(&false_exp), "(α0 -> (α1 -> α1))");
        // ((λx. λy. x) 1) 0
        let choose_1 = app(app(true_exp, Expr::Int(1)), Expr::Int(0));
        assert_eq!(infer(&choose_1), "Int");
    }

    #[test]
    fn t2() {
        // (λx. x) (λy. y)
        let exp1 = app(lam("x", var("x")), lam("y", var("y")));
        assert_eq!(infer(&exp1), "(α1 -> α1)");
        // λf. λx. f x
        let exp2 = lam("f", lam("x", app(var("f"), var("x"))));
        assert_eq!(infer(&exp2), "((α1 -> α2) -> (α1 -> α2))");
    }

    #[test]
    fn t3() {
        // λx. x + 1
        let exp = lam(
            "x",
            Expr::Binop {
                e1: Box::new(var("x")),
                e2: Box::new(Expr::Int(1)),
            },
        );
        assert_eq!(infer(&exp), "(Int -> Int)");
    }

    #[test]
    fn test_let1() {
        // let id = λx. x in id
        // The use of id is a fresh instance (α1) of the generalized scheme ∀α0. α0 -> α0.
        assert_eq!(infer(&let_id_in(var("id"))), "(α1 -> α1)");
        // let id = λx. x in id 3
        assert_eq!(infer(&let_id_in(app(var("id"), Expr::Int(3)))), "Int");
    }

    #[test]
    fn test_let2() {
        // let id = λx. x in if id true then id 1 else id 0
        // id is used at both Bool -> Bool and Int -> Int, which requires let-polymorphism.
        let exp = let_id_in(Expr::IfThenElse {
            cond: Box::new(app(var("id"), Expr::Bool(true))),
            then_branch: Box::new(app(var("id"), Expr::Int(1))),
            else_branch: Box::new(app(var("id"), Expr::Int(0))),
        });
        assert_eq!(infer(&exp), "Int");
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment