(** Map from a type-variable identifier to a value (here, a type). *)
module IdentifierMap = Map.Make (Identifier)

(** A substitution: a finite map from type-variable identifiers to types. *)
type t = Type.t IdentifierMap.t

(** The empty substitution (binds no variable). *)
let empty = IdentifierMap.empty

(** [singleton id ty] is the substitution binding only [id] to [ty]. *)
let singleton id ty = IdentifierMap.singleton id ty

(** [apply subst ty] replaces every type variable of [ty] that is bound in
    [subst] by its image, and then applies [subst] again to that image.
    Chains such as ['a -> Var 'b], ['b -> Int] are therefore resolved all
    the way down to [Int]. Unbound variables are left untouched.

    NOTE(review): because a variable's image is re-substituted, this only
    terminates when [subst] is acyclic. A binding like ['a -> Var 'a] or
    ['a -> Arrow (Var 'a, Int)] makes it diverge. Presumably an occurs
    check in the unifier guarantees acyclicity — confirm. *)
let rec apply subst = function
  | Type.Int -> Type.Int
  | Type.Var id as t ->
    (match IdentifierMap.find_opt id subst with
     | Some ty' -> apply subst ty'
     | None -> t)
  | Type.Product (ty1, ty2) -> Type.Product (apply subst ty1, apply subst ty2)
  | Type.Arrow (ty1, ty2) -> Type.Arrow (apply subst ty1, apply subst ty2)
;;

(** [compose s2 s1] is the union of the bindings of [s1] and [s2]. When an
    identifier is bound in both, the binding from [s1] (the last argument)
    takes priority and the one from [s2] is dropped.

    Note that the images of [s1] are stored as-is; [s2] is not applied to
    them. Callers presumably rely on {!apply} following variable chains
    recursively to resolve them — confirm. *)
let compose s2 s1 = IdentifierMap.union (fun _id ty1 _ty2 -> Some ty1) s1 s2

(** [to_string subst] renders [subst] for debugging as
    ["{'a' typed as Arrow(Int, Var('b'))}"]. Multiple bindings are
    separated by a newline and listed in increasing identifier order (the
    order of [IdentifierMap.bindings]). The empty substitution renders as
    ["{}"]. *)
let to_string map =
  let rec ty_str = function
    | Type.Var s -> "Var('" ^ s ^ "')"
    | Type.Int -> "Int"
    | Type.Product (a, b) -> "Product(" ^ ty_str a ^ ", " ^ ty_str b ^ ")"
    | Type.Arrow (a, b) -> "Arrow(" ^ ty_str a ^ ", " ^ ty_str b ^ ")"
  in
  "{"
  ^ (IdentifierMap.bindings map
     |> List.map (fun (id, ty) ->
       Printf.sprintf "'%s' typed as %s" id (ty_str ty))
     |> String.concat "\n")
  ^ "}"
;;

(** [find id subst] is [Some ty] if [id] is bound to [ty] in [subst], and
    [None] otherwise. *)
let find = IdentifierMap.find_opt

(** [equal map1 map2] is [true] iff both maps bind exactly the same
    identifiers, and each identifier is bound to structurally equal values
    (compared with polymorphic [=]). Two empty maps are equal.

    Uses [IdentifierMap.equal], which walks both maps once. The previous
    version built two binding lists just to count them and then looked up
    every key separately. *)
let equal map1 map2 = IdentifierMap.equal ( = ) map1 map2
;;