diff --git a/lib/inference.ml b/lib/inference.ml index 2e8ec50..4a3ad5b 100644 --- a/lib/inference.ml +++ b/lib/inference.ml @@ -1,43 +1,47 @@ (** Infer the type of a given term and, if exists, returns the type of the term *) -let rec typeof = function - | Term.Var id -> - (match Unification.unify (Type.Var id) Type.Int with - | Some _ -> Some Type.Int - | None -> Some (Type.Var id)) - | Term.IntConst _ -> Some Type.Int - | Term.Binop (t1, _, t2) -> - (* Both operands must have type Int *) - (match typeof t1, typeof t2 with - | Some ty1, Some ty2 -> - (match Unification.unify ty1 Type.Int, Unification.unify ty2 Type.Int with - | Some _, Some _ -> Some Type.Int - | _ -> None) - | _, _ -> None) - | Term.Pair (t1, t2) -> - (match typeof t1, typeof t2 with - | Some ty1, Some ty2 -> Some (Type.Product (ty1, ty2)) - | _, _ -> None) - | Term.Proj (proj, t) -> - (* Check if the term is a pair *) - (match typeof t with - | Some (Type.Product (ty1, ty2)) -> - (match proj with - | Term.First -> Some ty1 - | Term.Second -> Some ty2) - | _ -> None) - | Term.Fun (id, body) -> - (match typeof body with - | Some ty_body -> - (match typeof (Term.Var id) with - | Some tt -> Some (Type.Arrow (tt, ty_body)) - | None -> Some (Type.Arrow (Type.Var id, ty_body))) - | _ -> None) - | Term.App (t1, t2) -> - (* Check if the function type matches the arguments *) - (match typeof t1, typeof t2 with - | Some (Type.Arrow (ty_param, ty_fn)), Some ty_args -> - (match Unification.unify ty_param ty_args with - | Some _ -> Some ty_fn - | None -> None) - | _, _ -> None) +let typeof t = + let rec infer env = function + | Term.Var id -> + ( (match TypeSubstitution.find id env with + | Some ty -> Some ty + | None -> Some (Type.Var id)) + , env ) + | Term.IntConst _ -> Some Type.Int, env + | Term.Binop (t1, _, t2) -> + (* Both operands must have type Int *) + (match infer env t1, infer env t2 with + | (Some ty1, _), (Some ty2, _) -> + (match Unification.unify ty1 Type.Int, Unification.unify ty2 Type.Int with + | Some env1, Some env2 -> Some Type.Int, TypeSubstitution.compose env1 env2 + | _ -> None, env) + | _, _ -> None, env) + | Term.Pair (t1, t2) -> + (match infer env t1, infer env t2 with + | (Some ty1, _), (Some ty2, _) -> Some (Type.Product (ty1, ty2)), env + | _, _ -> None, env) + | Term.Proj (proj, t) -> + (* Check if the term is a pair *) + (match infer env t with + | Some (Type.Product (ty1, ty2)), _ -> + (match proj with + | Term.First -> Some ty1, env + | Term.Second -> Some ty2, env) + | _ -> None, env) + | Term.Fun (id, body) -> + (match infer env body with + | Some ty_body, env' -> + (match infer env' (Term.Var id) with + | Some ty, _ -> Some (Type.Arrow (ty, ty_body)), env' + | None, _ -> Some (Type.Arrow (Type.Var id, ty_body)), env') + | _ -> None, env) + | Term.App (t1, t2) -> + (* Check if the function type matches the arguments *) + (match infer env t1, infer env t2 with + | (Some (Type.Arrow (ty_param, ty_fn)), _), (Some ty_args, _) -> + (match Unification.unify ty_param ty_args with + | Some _ -> Some ty_fn, env + | None -> None, env) + | _, _ -> None, env) + in + fst (infer TypeSubstitution.empty t) ;; diff --git a/lib/typeSubstitution.ml b/lib/typeSubstitution.ml index 007655f..6d07155 100644 --- a/lib/typeSubstitution.ml +++ b/lib/typeSubstitution.ml @@ -37,3 +37,5 @@ let compose s2 s1 = s1 s2 ;; + +let find = IdentifierMap.find_opt diff --git a/lib/typeSubstitution.mli b/lib/typeSubstitution.mli index 2053180..7a7154f 100644 --- a/lib/typeSubstitution.mli +++ b/lib/typeSubstitution.mli @@ -6,3 +6,4 @@ val apply : t -> Type.t -> Type.t val compose : t -> t -> t val empty : t val singleton : Identifier.t -> Type.t -> t +val find : Identifier.t -> t -> Type.t option diff --git a/test/test_projet_pfa_23_24.ml b/test/test_projet_pfa_23_24.ml index c296e03..60af70a 100644 --- a/test/test_projet_pfa_23_24.ml +++ b/test/test_projet_pfa_23_24.ml @@ -11,7 +11,7 @@ let tests_typeof = , Term.(Fun (x, Fun (y, Binop (Var x, Plus, Var y)))) , Some Type.(Arrow (Int, Arrow (Int, Int))) ) ; (* Not typed variable *) - "x", Term.(Var "x"), None + "x", Term.(Var "x"), Some (Type.Var "x") ; (* Binary operation *) "1 + 2", Term.(Binop (IntConst 1, Plus, IntConst 2)), Some Type.Int ; (* Pair *)