2024-03-28 19:20:37 +01:00
|
|
|
(** Ordered, persistent maps keyed by [Identifier.t], built with [Map.Make].
    The value type is left polymorphic; the substitution type [t] below
    fixes it to [Type.t]. *)
module IdentifierMap = Map.Make (Identifier)

(** A substitution: a finite map from the identifiers carried by
    [Type.Var] to the types they stand for. *)
type t = Type.t IdentifierMap.t

(** The empty substitution, which binds no identifier. *)
let empty = IdentifierMap.empty

(** [singleton id ty] is the substitution that binds only [id], to [ty]. *)
let singleton id ty = IdentifierMap.singleton id ty
|
|
|
|
|
|
|
|
(** [apply subst ty] replaces the type variables of [ty] that are bound in
    [subst]. Each replacement is passed through [subst] again, so chains of
    bindings are followed to the end. [Int] and unbound variables are
    returned unchanged; [Product] and [Arrow] are rebuilt from their
    rewritten components.

    NOTE(review): this does not terminate when [subst] binds a variable to a
    type that mentions that same variable, directly or through other
    bindings. Presumably an occurs check upstream rules this out; confirm. *)
let rec apply subst ty =
  let recur = apply subst in
  match ty with
  | Type.Int -> Type.Int
  | Type.Var id ->
    (match IdentifierMap.find_opt id subst with
     | None -> ty
     | Some bound -> recur bound)
  | Type.Product (left, right) -> Type.Product (recur left, recur right)
  | Type.Arrow (arg, res) -> Type.Arrow (recur arg, recur res)
;;
|
|
|
|
|
|
|
|
(** [compose s2 s1] combines two substitutions into one.

    The domain of the result is the union of the domains of [s1] and [s2].
    When an identifier is bound by both, the binding from [s1] is kept.
    Every surviving binding, including one that comes only from [s2], is
    rewritten with [apply s2]. *)
let compose s2 s1 =
  (* Decide the binding for one identifier: prefer [s1], fall back to [s2],
     and push whichever is kept through [s2]. *)
  let choose _id from_s1 from_s2 =
    match from_s1, from_s2 with
    | Some ty, _ | None, Some ty -> Some (apply s2 ty)
    | None, None -> None
  in
  IdentifierMap.merge choose s1 s2
;;
|
2024-04-13 15:51:57 +02:00
|
|
|
|
2024-04-13 20:15:39 +02:00
|
|
|
(** [to_string map] renders the substitution [map] as text, for debugging.

    The result is an opening brace, then one entry per binding in
    increasing key order with entries separated by newlines, then a closing
    brace. An entry reads ['<id>' typed as <type>], where the type is
    printed constructor-style, e.g. [Arrow(Var('a'), Int)]. *)
let to_string map =
  let rec render = function
    | Type.Int -> "Int"
    | Type.Var name -> String.concat "" [ "Var('"; name; "')" ]
    | Type.Product (left, right) -> binary "Product(" left right
    | Type.Arrow (left, right) -> binary "Arrow(" left right
  (* [binary opener left right] prints a two-argument constructor; [opener]
     already contains the constructor name and its opening parenthesis. *)
  and binary opener left right =
    String.concat "" [ opener; render left; ", "; render right; ")" ]
  in
  let entries =
    IdentifierMap.fold
      (fun id ty acc -> Printf.sprintf "'%s' typed as %s" id (render ty) :: acc)
      map
      []
  in
  (* [fold] visits keys in increasing order, so the accumulated list comes
     out descending; reverse it to get ascending key order. *)
  "{" ^ String.concat "\n" (List.rev entries) ^ "}"
;;
|
|
|
|
|
2024-04-13 15:51:57 +02:00
|
|
|
(** [find id subst] is [Some ty] when [subst] binds [id] to [ty], and [None]
    when [id] is unbound. The stored binding is returned as is, without
    being resolved further through [subst]. *)
let find id subst = IdentifierMap.find_opt id subst
|