diff --git a/lib/inference.ml b/lib/inference.ml index 416098a..2e8ec50 100644 --- a/lib/inference.ml +++ b/lib/inference.ml @@ -1,32 +1,43 @@ (** Infer the type of a given term and, if exists, returns the type of the term *) let rec typeof = function - | Term.Var _ -> None + | Term.Var id -> + (match Unification.unify (Type.Var id) Type.Int with + | Some _ -> Some Type.Int + | None -> Some (Type.Var id)) | Term.IntConst _ -> Some Type.Int | Term.Binop (t1, _, t2) -> + (* Both operands must have type Int *) (match typeof t1, typeof t2 with - | ty1, ty2 when ty1 = ty2 -> Some Type.Int - | _ -> None) + | Some ty1, Some ty2 -> + (match Unification.unify ty1 Type.Int, Unification.unify ty2 Type.Int with + | Some _, Some _ -> Some Type.Int + | _ -> None) + | _, _ -> None) | Term.Pair (t1, t2) -> - (* Pair give Products *) (match typeof t1, typeof t2 with | Some ty1, Some ty2 -> Some (Type.Product (ty1, ty2)) - | _ -> None) - | Term.Proj (proj, t) -> - (* Projections returns type of product based on the projection type *) - (match proj, typeof t with - | Term.First, Some (Type.Product (ty, _)) | Term.Second, Some (Type.Product (_, ty)) - -> Some ty | _, _ -> None) - | Term.Fun (id, t) -> + | Term.Proj (proj, t) -> + (* Check if the term is a pair *) (match typeof t with - | Some body_type -> Some (Type.Arrow (Type.Int, body_type)) - | None -> Some (Type.Var id)) - | Term.App (t1, t2) -> - (match typeof t1, typeof t2 with - | Some (Type.Arrow (ty_arg, ty_res)), Some ty_arg' -> - (* Unification for application *) - (match Unification.unify ty_arg ty_arg' with - | Some subst -> Some (TypeSubstitution.apply subst ty_res) - | None -> None) + | Some (Type.Product (ty1, ty2)) -> + (match proj with + | Term.First -> Some ty1 + | Term.Second -> Some ty2) | _ -> None) + | Term.Fun (id, body) -> + (match typeof body with + | Some ty_body -> + (match typeof (Term.Var id) with + | Some tt -> Some (Type.Arrow (tt, ty_body)) + | None -> Some (Type.Arrow (Type.Var id, ty_body))) + | _ -> None) + | Term.App (t1, t2) -> + (* Check if the function type matches the arguments *) + (match typeof t1, typeof t2 with + | Some (Type.Arrow (ty_param, ty_fn)), Some ty_args -> + (match Unification.unify ty_param ty_args with + | Some _ -> Some ty_fn + | None -> None) + | _, _ -> None) ;;