(******************************************************************************
 * SNU 4190.310 Programming Languages (Fall 2006)
 *
 * HW 6: M, static type inference
 ******************************************************************************)

(******************************************************************************
 * Exercise 1.
 * M, static type inference
 ******************************************************************************)

// Stringifier for M expressions and types, used to build readable type-error
// messages.  Every printer appends to the single shared buffer `obuf`;
// `string_of` / `string_of_types` reset it before printing and return its
// contents.
structure M_String =
struct
  open M

  val obuf = Buffer.create 80
  val ps = Buffer.add_string obuf
  val pi = fn i => ps (string_of_int i)
  val nl = fn () => ps "\n"

  // Emit a newline followed by `i` spaces.
  fun indent i =
    let
      fun it 0 = ()
        | it n = (ps " "; it (n - 1))
    in
      nl (); it i
    end

  // pp n e: print expression `e` into `obuf`; `n` is the current indentation
  // depth used for line breaks inside FN / IF / LET bodies.
  fun pp n =
    fn CONST (S s) => ps s
     | CONST (N m) => pi m
     | CONST (B true) => ps "true"
     | CONST (B false) => ps "false"
     | VAR s => ps s
     | FN (x, e) =>
         (ps ("fn " ^ x ^ " => ");
          // Curried functions stay on one line; other bodies start a new line.
          (case e of
             FN _ => pp (n + 1) e
           | _ => (indent (n + 1); pp (n + 1) e)))
     | APP (e, e') => (pp n e; ps " "; pp n e')
     | IF (e1, e2, e3) =>
         (ps "if "; pp n e1; ps " then ";
          indent (n + 1); pp (n + 1) e2;
          indent n; ps "else";
          indent (n + 1); pp (n + 1) e3)
     | READ => ps "read "
     | WRITE (e) => (ps "write("; pp n e; ps ")")
     | LET (d, e) =>
         let
           // Flatten directly nested LETs into one `let d1 ... dk in body end`.
           fun sugaring (LET (d, LET (d', e))) acc = sugaring (LET (d', e)) (d :: acc)
             | sugaring (LET (d, e)) acc = (List.rev (d :: acc), e)
             | sugaring _ _ = raise Invalid_argument "impossible"
           val (decls, body) = sugaring (LET (d, e)) []
         in
           ps "let ";
           List.iter (fn x => (indent (n + 1); printDecl (n + 1) x)) decls;
           indent n; ps "in";
           indent (n + 1); pp (n + 1) body;
           indent n; ps "end"
         end
     | MALLOC e => (ps "malloc "; pp (n + 1) e)
     | ASSIGN (e, e') => (pp n e; ps " := "; pp n e')
     | BANG e => (ps "!"; pp n e)
     | SEQ (e, e') => (pp n e; ps ";"; indent n; pp n e')
     | PAIR (e1, e2) => (ps "("; pp n e1; ps ", "; pp n e2; ps ")")
     | SEL1 e => (pp n e; ps ".1")
     | SEL2 e => (pp n e; ps ".2")
     | BOP (bop, e1, e2) =>
         (ps "("; pp n e1;
          ps (" " ^ (case bop of
                       ADD => "+"
                     | SUB => "-"
                     | EQ => "="
                     | AND => "and"
                     | OR => "or") ^ " ");
          pp n e2; ps ")")

  // Print a single LET declaration (`val x = e` or `rec x = e`).
  and printDecl n =
    fn NREC (x, e) => (ps "val "; ps (x ^ " = "); pp (n + 1) e)
     | REC (x, e) => (ps ("rec " ^ x ^ " = "); pp (n + 1) e)

  // Print an M `types` value into `obuf`.
  fun pp_type ty =
    case ty of
      TyInt => ps "int"
    | TyBool => ps "bool"
    | TyString => ps "string"
    | TyPair (tau1, tau2) => (ps "("; pp_type tau1; ps " , "; pp_type tau2; ps ")")
    | TyLoc tau1 => (ps "loc("; pp_type tau1; ps ")")
    | TyArrow (tau1, tau2) =>
        (ps "("; pp_type tau1; ps ")"; ps "->"; ps "("; pp_type tau2; ps ")")

  val print = pp 0
  fun printTypes ty = (pp_type ty; nl ())

  // Render an expression / a type as a string (both reuse `obuf`).
  fun string_of e = (Buffer.reset obuf; print e; Buffer.contents obuf)
  fun string_of_types t = (Buffer.reset obuf; printTypes t; Buffer.contents obuf)
end

// Static type inference for M: generate equality (U) and one-of (Or)
// constraints from the expression, then solve them by unification.
structure M_Checker : M_TypeChecker =
struct
  open M

  /// definition of types part

  // Internal type terms; `TVar n` is the n-th inference variable.
  type ty = TVar of int | TInt | TString | TBool
          | TPair of ty * ty | TLoc of ty | TArrow of ty * ty

  // Counter for fresh type variables (never reset in this structure).
  val next_tvar = ref 0
  fun newt () = (next_tvar := !next_tvar + 1; TVar (!next_tvar))

  val rec string_of_types =
    fn TInt => "int"
     | TString => "string"
     | TBool => "bool"
     | TVar x => "'" ^ string_of_int x
     | TPair (t1, t2) => "(" ^ string_of_types t1 ^ " * " ^ string_of_types t2 ^ ")"
     | TLoc t => "loc " ^ string_of_types t
     | TArrow (t1, t2) => "(" ^ string_of_types t1 ^ "->" ^ string_of_types t2 ^ ")"

  // Convert a fully resolved internal type to M's `types`.
  // Raises TypeError if any type variable is left in the term.
  val rec programType =
    fn TInt => TyInt
     | TString => TyString
     | TBool => TyBool
     | TPair (t1, t2) => TyPair (programType t1, programType t2)
     | TLoc t => TyLoc (programType t)
     // Function-typed programs (e.g. `fn x => x + 1`) are valid results too.
     | TArrow (t1, t2) => TyArrow (programType t1, programType t2)
     | t => raise TypeError ("Invalid program type: " ^ string_of_types t)

  // primitive operators, values

  // g @@ f: composition -- apply f first, then g.
  fun ( @@ ) g f = (fn t => g (f t))
  // g @+ (x, t): environment g extended with x bound to t.
  fun ( @+ ) g (x, t) = fn y => if y = x then t else g y
  val emptyG = fn x => raise TypeError ("Unknown id: " ^ x)
  val id = (fn x => x)

  /// substitution, unification part

  // subst: int -> ty -> subst
  // The substitution replacing type variable x by tau everywhere.
  fun subst x tau =
    let
      fun s t =
        case t of
          TVar y => if y = x then tau else t
        | TPair (t1, t2) => TPair (s t1, s t2)
        | TArrow (t1, t2) => TArrow (s t1, s t2)
        | TLoc t' => TLoc (s t')
        | TInt => t
        | TString => t
        | TBool => t
    in
      s
    end

  // occurs: int -> ty -> bool
  // Does type variable x appear inside tau?
  fun occurs x tau =
    case tau of
      TVar y => y = x
    | TPair (t1, t2) => occurs x t1 orelse occurs x t2
    | TArrow (t1, t2) => occurs x t1 orelse occurs x t2
    | TLoc t => occurs x t
    | _ => false

  // unify: ty -> ty -> subst
  // Most general unifier of two types; raises TypeError on mismatch or when
  // the occurs check fails.
  fun unify (TVar x) tau =
        if TVar x = tau then id
        else if not (occurs x tau) then subst x tau
        else raise TypeError ("infinite type: " ^ string_of_types (TVar x)
                              ^ " = " ^ string_of_types tau)
    | unify tau (TVar x) = unify (TVar x) tau
    | unify (TPair p1) (TPair p2) = unifypair p1 p2
    | unify (TArrow p1) (TArrow p2) = unifypair p1 p2
    | unify (TLoc t) (TLoc t') = unify t t'
    | unify tau tau' =
        if tau = tau' then id
        else raise TypeError ("mismatch between " ^ string_of_types tau
                              ^ " and " ^ string_of_types tau')
  // Unify two pairs of types component-wise, threading the first substitution
  // into the second unification.
  and unifypair (t1, t2) (t1', t2') =
    let
      val s = unify t1 t1'
      val s' = unify (s t2) (s t2')
    in
      s' @@ s
    end

  /// constraint derivation and solving part

  // U (e, t, t')  : t must equal t'.
  // Or (e, t, ts) : t must equal one of ts.
  // `e` is the expression the constraint came from, kept for error messages.
  type constraint = U of exp * ty * ty | Or of exp * ty * ty list
  exception TypeError' of string * exp
  exception AbortCheck of string * exp

  // v: (id -> ty) -> exp -> ty -> constraint list -> constraint list
  // v g exp tau returns a function that prepends the constraints stating
  // "exp has type tau under environment g".
  fun v g exp tau =
    let
      fun u tau tau' = (fn c => U (exp, tau, tau') :: c)
      fun oneOf tau taul = (fn c => Or (exp, tau, taul) :: c)
    in
      case exp of
        CONST (S _) => u tau TString
      | CONST (N _) => u tau TInt
      | CONST (B _) => u tau TBool
      | VAR x => u tau (g x)
      | FN (x, e) =>
          let
            val t1 = newt ()
            val t2 = newt ()
          in
            u tau (TArrow (t1, t2)) @@ v (g @+ (x, t1)) e t2
          end
      | APP (e1, e2) =>
          let
            val t1 = newt ()
          in
            v g e1 (TArrow (t1, tau)) @@ v g e2 t1
          end
      | LET (NREC (x, e1), e2) =>
          let
            val t1 = newt ()
            val g' = g @+ (x, t1)
          in
            v g' e2 tau @@ v g e1 t1
          end
      | LET (REC (x, e1), e2) =>
          let
            val t1 = newt ()
            val g' = g @+ (x, t1)
          in
            // x is visible inside its own definition.
            v g' e2 tau @@ v g' e1 t1
          end
      | IF (e1, e2, e3) => v g e3 tau @@ v g e2 tau @@ v g e1 TBool
      | BOP (bop, e1, e2) =>
          let
            // (operand type, result type)
            val (t, t') =
              case bop of
                ADD => (TInt, TInt)
              | SUB => (TInt, TInt)
              | AND => (TBool, TBool)
              | OR => (TBool, TBool)
              | EQ => (newt (), TBool)
          in
            u tau t'
            @@ oneOf t [TInt, TString, TBool, TLoc (newt ())]
            @@ v g e2 t
            @@ v g e1 t
          end
      | READ => u tau TInt
      | WRITE e => oneOf tau [TInt, TString, TBool] @@ v g e tau
      | MALLOC e =>
          let
            val t = newt ()
          in
            u tau (TLoc t) @@ v g e t
          end
      | ASSIGN (e1, e2) => v g e2 tau @@ v g e1 (TLoc tau)
      | BANG e => v g e (TLoc tau)
      | SEQ (e1, e2) => v g e2 tau @@ v g e1 (newt ())
      | PAIR (e1, e2) =>
          let
            val t1 = newt ()
            val t2 = newt ()
          in
            u tau (TPair (t1, t2)) @@ v g e2 t2 @@ v g e1 t1
          end
      | SEL1 e => v g e (TPair (tau, newt ()))
      | SEL2 e => v g e (TPair (newt (), tau))
    end

  // c2s: subst -> constraint list -> subst
  // Solve the constraints left to right, extending substitution s.
  // A failing U raises TypeError'.  For an Or constraint, each alternative is
  // tried in order together with the remaining constraints; a TypeError'
  // from that attempt moves on to the next alternative, and AbortCheck is
  // raised when none is left.
  fun c2s s [] = s
    | c2s s (U (e, t1, t2) :: c) =
        c2s ((unify (s t1) (s t2) @@ s)
             handle TypeError msg => raise TypeError' (msg, e)) c
    | c2s s (Or (e, t1, tl) :: c) =
        let
          fun attempt [] =
                raise AbortCheck ("none matches " ^ string_of_types (s t1) ^ " among "
                                  ^ String.concat ", " (List.map (string_of_types @@ s) tl),
                                  e)
            | attempt (h :: t) =
                c2s s (U (e, t1, h) :: c) handle TypeError' _ => attempt t
        in
          attempt tl
        end

  // reorder: constraint list -> constraint list
  // Put every U constraint before every Or constraint.
  fun reorder c =
    let
      val (us, ors) = List.partition (fn Or _ => false | _ => true) c
    in
      us @ ors
    end

  // Convert an internal solving failure into M's TypeError, quoting the
  // offending expression.
  fun typeErrorAt msg e =
    raise TypeError ("For `" ^ M_String.string_of e ^ "', " ^ msg ^ ".")

  // check: exp -> types
  fun check exp =
    (let
       val tau = newt ()
     in
       programType ((c2s id (reorder (v emptyG exp tau []))) tau)
     end)
    handle TypeError' (msg, e) => typeErrorAt msg e
         | AbortCheck (msg, e) => typeErrorAt msg e
end

// Interpreter for M with no dynamic type checks: ill-kinded values hit
// partial pattern matches (non-exhaustive `val` bindings, partial op2fn).
// NOTE(review): presumably meant to run only on programs accepted by
// M_Checker -- confirm.
structure M_LowFat : M_Runner =
struct
  open M

  // domains
  type loc = int
  type value = Num of int
             | String of string
             | Bool of bool
             | Loc of loc
             | Pair of value * value
             | Closure of closure
  and closure = fexpr * env
  and fexpr = Fun of id * exp | RecFun of id * id * exp
  and env = id -> value
  // (next free location, store)
  type memory = int * (loc -> value)

  // f @+ (x, v): function f updated at x to return v.
  fun (@+) f (x, v) = (fn y => if y = x then v else f y)
  fun store (l, m) p = (l, m @+ p)
  fun fetch (_, m) l = m l
  fun malloc (l, m) = (l, (l + 1, m))

  // auxiliary functions
  fun error msg = raise RuntimeError msg

  val op2fn =
    fn ADD => (fn (Num n1, Num n2) => Num (n1 + n2))
     | SUB => (fn (Num n1, Num n2) => Num (n1 - n2))
     | AND => (fn (Bool b1, Bool b2) => Bool (b1 andalso b2))
     | OR => (fn (Bool b1, Bool b2) => Bool (b1 orelse b2))
     | EQ => (fn (v1, v2) => Bool (v1 = v2))

  val rec printValue =
    fn Num n => (print_int n; print_newline ())
     | Bool b => print_endline (if b then "true" else "false")
     | String s => print_endline s

  // eval env mem exp: evaluate exp, returning (value, updated memory).
  fun eval env mem exp =
    case exp of
      CONST c =>
        ((case c of S s => String s | N n => Num n | B b => Bool b), mem)
    | VAR x => (env x, mem)
    | FN (x, e) => (Closure (Fun (x, e), env), mem)
    | APP (e1, e2) =>
        let
          val (v1, m') = eval env mem e1
          val Closure (c, env') = v1
          val (v2, m'') = eval env m' e2
        in
          case c of
            Fun (x, e) => eval (env' @+ (x, v2)) m'' e
            // Bind the function name first so that a parameter with the same
            // name shadows it, matching the scoping used by M_Checker.
          | RecFun (f, x, e) => eval ((env' @+ (f, v1)) @+ (x, v2)) m'' e
        end
    | LET (NREC (x, e1), e2) =>
        let
          val (v1, m') = eval env mem e1
        in
          eval (env @+ (x, v1)) m' e2
        end
    | LET (REC (f, e1), e2) =>
        let
          val (v1, m') = eval env mem e1
          val Closure (c, env') = v1
        in
          case c of
            Fun (x, e) => eval (env @+ (f, Closure (RecFun (f, x, e), env'))) m' e2
          | _ => raise Invalid_argument "redundant let rec"
        end
    | IF (e1, e2, e3) =>
        let
          val (Bool b, m') = eval env mem e1
        in
          eval env m' (if b then e2 else e3)
        end
    | BOP (bop, e1, e2) =>
        let
          val (v1, m') = eval env mem e1
          val (v2, m'') = eval env m' e2
        in
          ((op2fn bop) (v1, v2), m'')
        end
    | READ =>
        let
          val n = read_int () handle _ => error "read error"
        in
          (Num n, mem)
        end
    | WRITE e =>
        let
          val (v, m') = eval env mem e
        in
          printValue v; (v, m')
        end
    | MALLOC e =>
        let
          val (v, m') = eval env mem e
          val (l, m'') = malloc m'
        in
          (Loc l, store m'' (l, v))
        end
    | ASSIGN (e1, e2) =>
        let
          val (v1, m') = eval env mem e1
          val Loc l = v1
          val (v, m'') = eval env m' e2
        in
          (v, store m'' (l, v))
        end
    | BANG e =>
        let
          val (Loc l, m') = eval env mem e
        in
          (fetch m' l, m')
        end
    | SEQ (e1, e2) =>
        let
          val (_, m1) = eval env mem e1
        in
          eval env m1 e2
        end
    | PAIR (e1, e2) =>
        let
          val (v1, m1) = eval env mem e1
          val (v2, m2) = eval env m1 e2
        in
          (Pair (v1, v2), m2)
        end
    | SEL1 e =>
        let
          val (Pair (v1, _), m') = eval env mem e
        in
          (v1, m')
        end
    | SEL2 e =>
        let
          val (Pair (_, v2), m') = eval env mem e
        in
          (v2, m')
        end

  val emptyEnv = (fn x => raise Invalid_argument ("unknown id: " ^ x))
  val emptyMem = (0, fn l => raise Invalid_argument ("unknown loc: " ^ string_of_int l))

  fun run exp = ignore (eval emptyEnv emptyMem exp)
end

// Interpreter for M with dynamic checks: every value is inspected through the
// get* helpers below, which raise RuntimeError (via `error`) on a kind
// mismatch.
structure M_Vanilla : M_Runner =
struct
  open M

  // domains
  type loc = int
  type value = Num of int
             | String of string
             | Bool of bool
             | Loc of loc
             | Pair of value * value
             | Closure of closure
  and closure = fexpr * env
  and fexpr = Fun of id * exp | RecFun of id * id * exp
  and env = id -> value
  // (next free location, store)
  type memory = int * (loc -> value)

  // f @+ (x, v): function f updated at x to return v.
  fun (@+) f (x, v) = (fn y => if y = x then v else f y)
  fun store (l, m) p = (l, m @+ p)
  fun fetch (_, m) l = m l
  fun malloc (l, m) = (l, (l + 1, m))

  // auxiliary functions
  fun error msg = raise RuntimeError msg

  val getClosure = fn Closure c => c | _ => error "not a function"
  val getNum = fn Num n => n | _ => error "not a number value"
  val getBool = fn Bool b => b | _ => error "not a boolean value"
  val getLoc = fn Loc l => l | _ => error "not a location value"
  val getPair = fn Pair p => p | _ => error "not a pair"

  val op2fn =
    fn ADD => (fn (v1, v2) => Num (getNum v1 + getNum v2))
     | SUB => (fn (v1, v2) => Num (getNum v1 - getNum v2))
     | AND => (fn (v1, v2) => Bool (getBool v1 andalso getBool v2))
     | OR => (fn (v1, v2) => Bool (getBool v1 orelse getBool v2))
     | EQ => (fn (Num n1, Num n2) => Bool (n1 = n2)
               | (String s1, String s2) => Bool (s1 = s2)
               | (Bool b1, Bool b2) => Bool (b1 = b2)
               | (Loc l1, Loc l2) => Bool (l1 = l2)
               | (_, _) => error "uncomparable")

  val rec printValue =
    fn Num n => (print_int n; print_newline ())
     | Bool b => print_endline (if b then "true" else "false")
     | String s => print_endline s
     | _ => error "unprintable"

  // eval env mem exp: evaluate exp, returning (value, updated memory).
  fun eval env mem exp =
    case exp of
      CONST c =>
        ((case c of S s => String s | N n => Num n | B b => Bool b), mem)
    | VAR x => (env x, mem)
    | FN (x, e) => (Closure (Fun (x, e), env), mem)
    | APP (e1, e2) =>
        let
          val (v1, m') = eval env mem e1
          val (c, env') = getClosure v1
          val (v2, m'') = eval env m' e2
        in
          case c of
            Fun (x, e) => eval (env' @+ (x, v2)) m'' e
            // Bind the function name first so that a parameter with the same
            // name shadows it, matching the scoping used by M_Checker.
          | RecFun (f, x, e) => eval ((env' @+ (f, v1)) @+ (x, v2)) m'' e
        end
    | LET (NREC (x, e1), e2) =>
        let
          val (v1, m') = eval env mem e1
        in
          eval (env @+ (x, v1)) m' e2
        end
    | LET (REC (f, e1), e2) =>
        let
          val (v1, m') = eval env mem e1
          val (c, env') = getClosure v1
        in
          case c of
            Fun (x, e) => eval (env @+ (f, Closure (RecFun (f, x, e), env'))) m' e2
          | _ => raise Invalid_argument "redundant let rec"
        end
    | IF (e1, e2, e3) =>
        let
          val (v1, m') = eval env mem e1
        in
          eval env m' (if getBool v1 then e2 else e3)
        end
    | BOP (bop, e1, e2) =>
        let
          val (v1, m') = eval env mem e1
          val (v2, m'') = eval env m' e2
        in
          ((op2fn bop) (v1, v2), m'')
        end
    | READ =>
        let
          val n = read_int () handle _ => error "read error"
        in
          (Num n, mem)
        end
    | WRITE e =>
        let
          val (v, m') = eval env mem e
        in
          printValue v; (v, m')
        end
    | MALLOC e =>
        let
          val (v, m') = eval env mem e
          val (l, m'') = malloc m'
        in
          (Loc l, store m'' (l, v))
        end
    | ASSIGN (e1, e2) =>
        let
          val (v1, m') = eval env mem e1
          val l = getLoc v1
          val (v, m'') = eval env m' e2
        in
          (v, store m'' (l, v))
        end
    | BANG e =>
        let
          val (v, m') = eval env mem e
        in
          (fetch m' (getLoc v), m')
        end
    | SEQ (e1, e2) =>
        let
          val (_, m1) = eval env mem e1
        in
          eval env m1 e2
        end
    | PAIR (e1, e2) =>
        let
          val (v1, m1) = eval env mem e1
          val (v2, m2) = eval env m1 e2
        in
          (Pair (v1, v2), m2)
        end
    | SEL1 e =>
        let
          val (v, m') = eval env mem e
          val (v1, _) = getPair v
        in
          (v1, m')
        end
    | SEL2 e =>
        let
          val (v, m') = eval env mem e
          val (_, v2) = getPair v
        in
          (v2, m')
        end

  val emptyEnv = (fn x => error ("unknown id: " ^ x))
  val emptyMem = (0, fn l => error ("unknown loc: " ^ string_of_int l))

  fun run exp = ignore (eval emptyEnv emptyMem exp)
end