-- Ex1.hs

{- Type inference -}

import Maps
import Result
import Misc

import Data.List

-- | Abstract syntax of the expression language.
data Expr = Integer Integer -- ^ integer literal
          | Boolean Bool -- ^ boolean literal
          | Op String Expr Expr -- ^ binary operator, named by its symbol (e.g. @+@), with its two operands
          | And Expr Expr -- ^ conjunction of two expressions
          | Leq Expr Expr -- ^ less-than-or-equal comparison
          | If Expr Expr Expr -- ^ conditional: condition, then-branch, else-branch
          | Var Variable -- ^ reference to a variable
          | Let Variable Expr Expr -- ^ @let x = e1 in e2@
          | Lam Variable Expr -- ^ lambda abstraction; the parameter has no type annotation, its type is inferred
          | App Expr Expr -- ^ application of a function to an argument
          | Fix Expr -- ^ fixed-point operator
          deriving (Show, Eq)

-- | Name of a type variable.
type TyVariable = String

-- | Types of the language, including type variables that stand for types
-- still to be determined by unification.
data Type = TyInteger -- ^ type of integers
          | TyBoolean -- ^ type of booleans
          | TyArrow Type Type -- ^ function type: argument type, result type
          | TyVar TyVariable -- ^ type variable
          deriving (Show, Eq)

-- | Typing environment: maps a variable name to the type assigned to it
-- (which may still be a type variable).
type TyEnv = Map String Type

-- | A constraint is a pair of types that are required to be equal.
-- Writing @X@ for a type variable, examples are:
--
--   X = Integer
--   X = Integer -> Integer
--   X -> X = Integer -> Integer
--   X -> Integer = Integer -> Boolean   (has no solution)
type Constraint = (Type, Type)

-- | Constraint generation. For an expression, compute a preliminary type
-- (possibly mentioning type variables) together with the list of equality
-- constraints that have to be solved for the expression to be well-typed.
--
-- Fails if the expression refers to a variable not bound in the environment.
typeOf :: TyEnv -> Expr -> Result (Type, [Constraint])
typeOf tenv expr = case expr of
    Integer _ -> return (TyInteger, [])
    Boolean _ -> return (TyBoolean, [])
    -- arithmetic: Integer operands, Integer result
    Op _ lhs rhs -> binary TyInteger TyInteger lhs rhs
    -- conjunction: Boolean operands, Boolean result
    And lhs rhs -> binary TyBoolean TyBoolean lhs rhs
    -- comparison: Integer operands, Boolean result
    Leq lhs rhs -> binary TyInteger TyBoolean lhs rhs
    If cond thenE elseE -> do
        (tCond, cCond) <- typeOf tenv cond
        (tThen, cThen) <- typeOf tenv thenE
        (tElse, cElse) <- typeOf tenv elseE
        -- the condition is Boolean and both branches agree
        return (tThen, (tCond, TyBoolean) : (tThen, tElse) : cCond ++ cThen ++ cElse)
    Var x -> do
        t <- fromMaybe' ("Variable " ++ x ++ " is not bound") $ get tenv x
        return (t, [])
    Let x bound body -> do
        (tBound, cBound) <- typeOf tenv bound
        (tBody, cBody) <- typeOf (set tenv x tBound) body
        return (tBody, cBound ++ cBody)
    Lam x body -> do
        -- the parameter gets a type variable named via gensym (presumably
        -- fresh on every call; gensym is defined outside this file)
        let tParam = TyVar $ gensym "X"
        -- e.g. a body "if x then x + 1 else 3" yields X = Boolean and X = Integer
        (tBody, cBody) <- typeOf (set tenv x tParam) body
        return (TyArrow tParam tBody, cBody)
    App fun arg -> do
        (tFun, cFun) <- typeOf tenv fun
        (tArg, cArg) <- typeOf tenv arg
        -- the result type is a type variable; the function must map the
        -- argument type to it
        let tRes = TyVar $ gensym "Y"
        return (tRes, (tFun, TyArrow tArg tRes) : cFun ++ cArg)
    Fix body -> do
        (tBody, cBody) <- typeOf tenv body
        -- the argument of fix must have type Z -> Z; the result has type Z
        let tFix = TyVar $ gensym "Z"
        return (tFix, (tBody, TyArrow tFix tFix) : cBody)
  where
    -- Binary construct whose two operands both have type operandTy and
    -- whose result has type resultTy.
    binary operandTy resultTy lhs rhs = do
        (tL, cL) <- typeOf tenv lhs
        (tR, cR) <- typeOf tenv rhs
        return (resultTy, (tL, operandTy) : (tR, operandTy) : cL ++ cR)

-- | @tySubst x t ty@ replaces every occurrence of the type variable @x@ in
-- @ty@ by the type @t@.
tySubst :: TyVariable -> Type -> Type -> Type
tySubst x t = go
  where
    go ty = case ty of
        TyArrow dom cod -> TyArrow (go dom) (go cod)
        TyVar y -> if x == y then t else ty
        _ -> ty

-- | The type variables occurring in a type, each listed once.
freeTyVars :: Type -> [TyVariable]
freeTyVars ty = case ty of
    TyVar x -> [x]
    TyArrow dom cod -> union (freeTyVars dom) (freeTyVars cod)
    _ -> []


-- | A substitution is an association list binding type variables to types.
-- Its bindings are applied one after the other in list order, head first.
type Substitution = [(TyVariable, Type)]

-- | Apply a substitution to a type, one binding at a time starting with the
-- head of the list; later bindings see the result of the earlier ones.
applySubst :: Substitution -> Type -> Type
applySubst subst ty = foldl' applyBinding ty subst
  where
    applyBinding acc (x, t) = tySubst x t acc

-- | Solve a list of constraints by unification. On success the result is a
-- substitution that makes both sides of every constraint equal; otherwise
-- the computation fails, naming the two types that could not be unified.
unify :: [Constraint] -> Result Substitution
unify [] = return []
unify ((lhs, rhs) : pending)
    -- trivially satisfied constraint
    | lhs == rhs = unify pending
    -- a variable on either side is bound to the other side, provided it does
    -- not occur in it (occurs check)
    | TyVar x <- lhs, x `notElem` freeTyVars rhs = eliminate x rhs
    | TyVar x <- rhs, x `notElem` freeTyVars lhs = eliminate x lhs
    -- two arrows are unified component-wise
    | TyArrow a b <- lhs, TyArrow a' b' <- rhs = unify ((a, a') : (b, b') : pending)
    | otherwise = fail $ "Could not unify " ++ showType lhs ++ " and " ++ showType rhs
  where
    -- Bind x to t: substitute in the remaining constraints, solve those, and
    -- put the new binding in front of their solution.
    eliminate x t = do
        s <- unify [ (tySubst x t l, tySubst x t r) | (l, r) <- pending ]
        return ((x, t) : s)

-- | Perform a single unification step on the first constraint. Returns the
-- bindings produced by this step (empty or a single binding) together with
-- the constraints that remain to be solved. An empty constraint list is
-- returned unchanged. Fails if the first constraint cannot be unified.
unifyStep :: [Constraint] -> Result (Substitution, [Constraint])
unifyStep [] = return ([], [])
unifyStep ((lhs, rhs) : pending)
    -- trivially satisfied constraint: just drop it
    | lhs == rhs = return ([], pending)
    -- bind a variable to the other side if it passes the occurs check
    | TyVar x <- lhs, x `notElem` freeTyVars rhs = return (eliminate x rhs)
    | TyVar x <- rhs, x `notElem` freeTyVars lhs = return (eliminate x lhs)
    -- split a constraint between arrows into two constraints
    | TyArrow a b <- lhs, TyArrow a' b' <- rhs = return ([], (a, a') : (b, b') : pending)
    | otherwise = fail $ "Could not unify " ++ showType lhs ++ " and " ++ showType rhs
  where
    -- The binding x := t and the remaining constraints with x replaced by t.
    eliminate x t =
        ([(x, t)], [ (tySubst x t l, tySubst x t r) | (l, r) <- pending ])


-------------- Examples ---------------------

-- abbreviations

-- tyArrow [t1, t2, ..., tn] builds the right-nested function type
-- t1 -> (t2 -> (... -> tn)); it is not defined for the empty list.
tyArrow [t] = t
tyArrow (t : rest@(_ : _)) = TyArrow t (tyArrow rest)

-- app [f, a1, ..., an] builds the left-nested application ((f a1) ...) an.
app terms = foldl1 App terms

-- lam [x1, ..., xn] body builds the nested abstraction λx1. ... λxn. body.
lam params body = foldr Lam body params


-- well-typed: Integer
ex1 = Op "+" (Integer 1) (Integer 3)
-- ill-typed: the first operand of the conjunction is an Integer
ex2 = And (Integer 1) (Boolean True)
-- well-typed: Boolean
ex3 = And (Boolean True) (And (Boolean False) (Boolean True))
-- ill-typed: the branches have different types
ex4 = If (Boolean False) (Integer 1) (Boolean True)
-- well-typed: Integer
ex5 = If (Boolean False) (Integer 1) (Integer 3)
-- ill-typed: the branches of the inner conditional have different types
ex6 = Op "+" (Integer 1) (If (Leq (Integer 4) (Integer 5)) (Integer 3) (Boolean False))
-- well-typed: Integer
ex7 = Let "x" (Op "+" (Integer 1) (Integer 3)) 
          (Op "+" (Var "x") (Var "x"))
-- identity function: its type is an arrow from a type variable to itself
ex8 = Lam "x" (Var "x")
-- identity applied to an Integer: Integer
ex9 = App ex8 (Op "+" (Integer 1) (Integer 2))
-- curried comparison: Integer -> Integer -> Boolean
ex10 = (Lam "x" (Lam "y" (Leq (Var "x") (Var "y"))))
-- partial application: Integer -> Boolean
ex11 = (App ex10 (Integer 10))
-- full application: Boolean
ex12 = (App ex11 (Integer 1))
-- ill-typed: the second argument must be an Integer
ex13 = (App ex11 (Boolean True))
-- functional whose fixed point sums the integers from 1 to n:
-- (Integer -> Integer) -> Integer -> Integer
ex14 = (Lam "sum"
         (Lam "n" 
           (If (Leq (Var "n") (Integer 0))
               (Integer 0) 
               (Op "+" (Var "n")
                    (App (Var "sum")
                         (Op "-" (Var "n") (Integer 1)))))))
-- the recursive sum function itself: Integer -> Integer
ex15 = Fix ex14
-- flip: takes a two-argument function and swaps the order of its arguments
ex16 = lam [ "f", "x", "y"] $
         app [Var "f", Var "y", Var "x"]
-- While the identity function (ex8) is initially assigned a type only
-- involving variables, a let-bound variable gets that single type (there is
-- no generalization), so there is only one way of instantiating the
-- variables. Using id at both Boolean and Integer therefore fails to
-- type-check.
ex17 = Let "id" ex8 $
           If (App (Var "id") $ Boolean True)
              (App (Var "id") $ Integer 3)
              (Integer 2)

-- | Render a type as text. Function types are always fully parenthesized;
-- type variables are shown by name.
showType :: Type -> String
showType ty = case ty of
    TyInteger -> "Integer"
    TyBoolean -> "Boolean"
    TyVar name -> name
    TyArrow dom cod -> concat ["(", showType dom, " -> ", showType cod, ")"]

-- | Render an expression in a conventional concrete syntax. Binary
-- constructs, abstractions, applications and fix are parenthesized;
-- conditionals and let-expressions are not.
showExpr :: Expr -> String
showExpr expr = case expr of
    Integer n -> show n
    Boolean flag -> show flag
    Op sym lhs rhs -> between lhs (" " ++ sym ++ " ") rhs
    And lhs rhs -> between lhs " && " rhs
    Leq lhs rhs -> between lhs " <= " rhs
    If cond thenE elseE ->
        concat ["if ", showExpr cond, " then ", showExpr thenE, " else ", showExpr elseE]
    Var name -> name
    Let name bound body ->
        concat ["let ", name, " = ", showExpr bound, " in ", showExpr body]
    Lam param body -> concat ["(λ", param, ". ", showExpr body, ")"]
    App fun arg -> between fun " " arg
    Fix body -> concat ["(fix ", showExpr body, ")"]
  where
    -- two sub-expressions joined by a separator, wrapped in parentheses
    between lhs sep rhs = concat ["(", showExpr lhs, sep, showExpr rhs, ")"]

-- Infer the type of an expression and report the progress on standard
-- output: the expression, its preliminary type, the generated constraints
-- and finally the type obtained by applying the unifier. A failure of
-- constraint generation or unification is propagated through toIO.
typeCheck e = do
    putStrLn ("Checking: " ++ showExpr e)
    (prelim, constraints) <- toIO (typeOf empty e)
    putStrLn ("Initial type: " ++ showType prelim)
    putStrLn "Initial set of constraints:"
    mapM_ (putStrLn . renderConstraint) constraints
    unifier <- toIO (unify constraints)
    putStrLn ("Final type: " ++ showType (applySubst unifier prelim))
  where
    -- one constraint per line, indented
    renderConstraint (lhs, rhs) = "   " ++ showType lhs ++ " = " ++ showType rhs

-- Interactive, step-by-step variant of type inference. After every
-- unification step the current type, the remaining constraints and the
-- substitution accumulated so far are printed, and the user is asked to
-- press enter before the next step is taken.
-- E.g., in GHCi:
--
--   > typeCheckStep ex14
--
typeCheckStep e = do
    putStrLn ("Checking: " ++ showExpr e)
    (prelim, constraints) <- toIO (typeOf empty e)
    final <- step prelim [] constraints
    putStrLn ("Done. Final type: " ++ showType final)
  where
    -- One round: show the current state, wait for the user, perform a single
    -- unification step, and stop once that step leaves the constraints
    -- unchanged.
    step ty subst pending = do
        let current = applySubst subst ty
        report current subst pending
        putStrLn "Press enter to continue..."
        _ <- getLine
        (delta, pending') <- toIO (unifyStep pending)
        if pending' /= pending
            then step current (delta ++ subst) pending'
            else return current

    -- Print the current type, constraints and substitution.
    report ty subst pending = do
        putStrLn ("Current type: " ++ showType ty)
        putStrLn "Current constraints:"
        mapM_ (putStrLn . renderConstraint) pending
        putStrLn "Current substitution: "
        putStrLn ("  " ++ concatMap renderBinding subst)

    renderConstraint (lhs, rhs) = "   " ++ showType lhs ++ " = " ++ showType rhs

    renderBinding (x, t) = x ++ " := " ++ showType t ++ ", "