diff --git a/lib/inference.ml b/lib/inference.ml index e7f78d5..416098a 100644 --- a/lib/inference.ml +++ b/lib/inference.ml @@ -1,19 +1,32 @@ +(** Infer the type of a given term and, if exists, returns the type of the term *) let rec typeof = function - | Term.Var _ -> - (* Une variable n'a pas de type *) - None + | Term.Var _ -> None | Term.IntConst _ -> Some Type.Int | Term.Binop (t1, _, t2) -> - (* Les 2 types de l'opération sont égaux *) (match typeof t1, typeof t2 with - | Some (_ as ty1), Some (_ as ty2) -> if ty1 = ty2 then Some ty1 else None + | ty1, ty2 when ty1 = ty2 -> Some Type.Int | _ -> None) | Term.Pair (t1, t2) -> - (* On forme le produit *) + (* Pair give Products *) (match typeof t1, typeof t2 with | Some ty1, Some ty2 -> Some (Type.Product (ty1, ty2)) | _ -> None) - | Term.Proj (_proj, _t) -> failwith "TODO" - | Term.Fun (_, _) -> failwith "TODO" - | Term.App (_t1, _t2) -> failwith "TODO" + | Term.Proj (proj, t) -> + (* Projections returns type of product based on the projection type *) + (match proj, typeof t with + | Term.First, Some (Type.Product (ty, _)) | Term.Second, Some (Type.Product (_, ty)) + -> Some ty + | _, _ -> None) + | Term.Fun (id, t) -> + (match typeof t with + | Some body_type -> Some (Type.Arrow (Type.Int, body_type)) + | None -> Some (Type.Var id)) + | Term.App (t1, t2) -> + (match typeof t1, typeof t2 with + | Some (Type.Arrow (ty_arg, ty_res)), Some ty_arg' -> + (* Unification for application *) + (match Unification.unify ty_arg ty_arg' with + | Some subst -> Some (TypeSubstitution.apply subst ty_res) + | None -> None) + | _ -> None) ;; diff --git a/lib/typeSubstitution.ml b/lib/typeSubstitution.ml index 9137223..efc38b4 100644 --- a/lib/typeSubstitution.ml +++ b/lib/typeSubstitution.ml @@ -1,4 +1,38 @@ -type t = Type.t Map.Make(Identifier).t +(** Map of an id to a type *) +module IdentifierMap = Map.Make (Identifier) -let apply _t _tt = failwith "TODO" -let compose _s2 _s1 = failwith "TODO" +(** Map type *) +type t = Type.t IdentifierMap.t + +(* Empty substitution *) +let empty = IdentifierMap.empty + +(** Create a substitution with one element *) +let singleton id ty = IdentifierMap.singleton id ty + +(** Apply substitution to a type *) +let rec apply subst = function + | Type.Var id as t -> + (match IdentifierMap.find_opt id subst with + | Some ty' -> apply subst ty' + | None -> t) + | Type.Int -> Type.Int + | Type.Product (ty1, ty2) -> Type.Product (apply subst ty1, apply subst ty2) + | Type.Arrow (ty1, ty2) -> Type.Arrow (apply subst ty1, apply subst ty2) +;; + +(** Compose two substitutions *) +let compose s2 s1 = + IdentifierMap.merge + (fun _ ty1 ty2 -> + match ty1, ty2 with + (* If we have 2, we pick one of them *) + | Some ty1, Some _ -> Some (apply s2 ty1) + (* If we have 1, we pick the one we have *) + | Some ty1, None -> Some (apply s2 ty1) + | None, Some ty2 -> Some (apply s2 ty2) + (* If we have 0, we return nothing *) + | None, None -> None) + s1 + s2 +;; diff --git a/lib/typeSubstitution.mli b/lib/typeSubstitution.mli index 92362c7..5b852aa 100644 --- a/lib/typeSubstitution.mli +++ b/lib/typeSubstitution.mli @@ -4,3 +4,5 @@ val apply : t -> Type.t -> Type.t (* compose s2 s1 : first s1, then s2 *) val compose : t -> t -> t +val empty : t +val singleton : Identifier.t -> Type.t -> Type.t Map.Make(Identifier).t diff --git a/lib/unification.ml b/lib/unification.ml index d89600c..4b70f6b 100644 --- a/lib/unification.ml +++ b/lib/unification.ml @@ -1 +1,16 @@ -let unify _ty1 _ty2 = failwith "TODO" +(** Unify 2 types and, if exists, returns the substitution *) +let rec unify ty1 ty2 = + match ty1, ty2 with + | Type.Product (p1_ty1, p1_ty2), Type.Product (p2_ty1, p2_ty2) + | Type.Arrow (p1_ty1, p1_ty2), Type.Arrow (p2_ty1, p2_ty2) -> + (match unify p1_ty1 p2_ty1 with + | Some s1 -> + (match + unify (TypeSubstitution.apply s1 p1_ty2) (TypeSubstitution.apply s1 p2_ty2) + with + | Some s2 -> Some (TypeSubstitution.compose s2 s1) + | None -> None) + | None -> None) + | ty1, ty2 when ty1 = ty2 -> Some TypeSubstitution.empty + | _ -> None +;; diff --git a/test/test_projet_pfa_23_24.ml b/test/test_projet_pfa_23_24.ml index 591711c..c296e03 100644 --- a/test/test_projet_pfa_23_24.ml +++ b/test/test_projet_pfa_23_24.ml @@ -3,30 +3,38 @@ open TypeInference let tests_typeof = let x = Identifier.fresh () in let y = Identifier.fresh () in - [ (* Int Const *) + let z = Identifier.fresh () in + [ (* IntConst *) "0", Term.IntConst 0, Some Type.Int - ; (* Correct function *) + ; (* int -> int -> int = *) ( "fun x -> fun y -> x + y" , Term.(Fun (x, Fun (y, Binop (Var x, Plus, Var y)))) , Some Type.(Arrow (Int, Arrow (Int, Int))) ) ; (* Not typed variable *) - "x", Var "x", None - ; (* Operation *) - "1 + 2", Binop (IntConst 1, Plus, IntConst 2), Some Int + "x", Term.(Var "x"), None + ; (* Binary operation *) + "1 + 2", Term.(Binop (IntConst 1, Plus, IntConst 2)), Some Type.Int ; (* Pair *) - "(1, 2)", Pair (IntConst 1, IntConst 2), Some (Product (Int, Int)) + "(1, 2)", Term.(Pair (IntConst 1, IntConst 2)), Some Type.(Product (Int, Int)) ; (* Projection with first *) - "fst (1, 2)", Proj (First, Pair (IntConst 1, IntConst 2)), Some Int + "fst (1, 2)", Term.(Proj (First, Pair (IntConst 1, IntConst 2))), Some Type.Int ; (* Projection with second *) - "snd (1, 2)", Proj (Second, Pair (IntConst 1, IntConst 2)), Some Int + "snd (1, 2)", Term.(Proj (Second, Pair (IntConst 1, IntConst 2))), Some Type.Int ; (* Apply (int) into (fun : int -> int) *) ( "(fun x -> x + 1) 5" - , App (Fun (x, Binop (Var x, Plus, IntConst 1)), IntConst 5) + , Term.(App (Fun (x, Binop (Var x, Plus, IntConst 1)), IntConst 5)) , Some Type.Int ) ; (* Apply product (int * int) into a not compatible function (fun : int -> int) *) ( "(fun x -> x + 1) (1, 2)" - , App (Fun (x, Binop (Var x, Plus, IntConst 1)), Pair (IntConst 1, IntConst 2)) + , Term.(App (Fun (x, Binop (Var x, Plus, IntConst 1)), Pair (IntConst 1, IntConst 2))) , None ) + ; (* x -> y -> (x -> y -> z) -> z *) + ( "fun x y -> fun z -> z x y" + , Term.(Fun (x, Fun (y, Fun (z, App (Var z, App (Var x, Var y)))))) + , Some + Type.( + Arrow (Var x, Arrow (Var y, Arrow (Arrow (Var x, Arrow (Var y, Var z)), Var z)))) + ) ] ;;