1(**************************************************************************)
2(*                                                                        *)
3(*                                 OCaml                                  *)
4(*                                                                        *)
5(*                       Pierre Chambart, OCamlPro                        *)
6(*           Mark Shinwell and Leo White, Jane Street Europe              *)
7(*                                                                        *)
8(*   Copyright 2013--2016 OCamlPro SAS                                    *)
9(*   Copyright 2014--2016 Jane Street Group LLC                           *)
10(*                                                                        *)
11(*   All rights reserved.  This file is distributed under the terms of    *)
12(*   the GNU Lesser General Public License version 2.1, with the          *)
13(*   special exception on linking described in the file LICENSE.          *)
14(*                                                                        *)
15(**************************************************************************)
16
(* Aliases for the standard library's [Map] and [Set], so that they stay
   nameable inside the signature [S] below, which declares its own [Set]
   and [Map] submodules. *)
module Stdlib_map = Map
module Stdlib_set = Set
19
(* A type equipped with everything needed to build hash tables, maps and
   sets over it, plus printers for channels and formatters.
   NOTE(review): as with any [Hashtbl.HashedType], [equal x y] should imply
   [hash x = hash y]; presumably [equal] and [compare] should also agree. *)
module type Thing = sig
  type t

  (* Provides [equal] and [hash]. *)
  include Hashtbl.HashedType with type t := t
  (* Provides [compare]. *)
  include Map.OrderedType with type t := t

  (* Printer for an [out_channel]. *)
  val output : out_channel -> t -> unit
  (* Printer for a [Format] pretty-printer. *)
  val print : Format.formatter -> t -> unit
end
29
(* [Pair (A) (B)] makes pairs of two [Thing]s into a [Thing]; every
   operation is defined component-wise from [A]'s and [B]'s. *)
module Pair (A : Thing) (B : Thing) : Thing with type t = A.t * B.t = struct
  type t = A.t * B.t

  (* Lexicographic order: the [A] components decide, and only when they
     tie are the [B] components consulted. *)
  let compare (a1, b1) (a2, b2) =
    match A.compare a1 a2 with
    | 0 -> B.compare b1 b2
    | c -> c

  (* Both components must be equal. *)
  let equal (a1, b1) (a2, b2) = A.equal a1 a2 && B.equal b1 b2

  (* Generic hash of the pair of component hashes. *)
  let hash (a, b) = Hashtbl.hash (A.hash a, B.hash b)

  (* Printers: each prints a leading space followed by "(a, b)". *)
  let output oc (a, b) = Printf.fprintf oc " (%a, %a)" A.output a B.output b
  let print ppf (a, b) = Format.fprintf ppf " (%a, @ %a)" A.print a B.print b
end
43
(* [Make_map (T)] is [Map.Make (T)] extended with conversions, unions,
   renaming and printing helpers keyed on [T]. *)
module Make_map (T : Thing) = struct
  include Map.Make (T)

  (* Keep the bindings of [t] for which [f] returns [Some r], each rebound
     to its [r]; bindings for which [f] returns [None] are dropped. *)
  let filter_map t ~f =
    fold (fun id v map ->
        match f id v with
        | None -> map
        | Some r -> add id r map) t empty

  (* Build a map from an association list; a later pair overrides an
     earlier pair with the same key. *)
  let of_list l =
    List.fold_left (fun map (id, v) -> add id v map) empty l

  (* Union of two maps that are expected to have disjoint keys.  A key
     bound in both maps is reported through [Misc.fatal_error], unless
     [eq] is supplied and returns [true] for the two values, in which case
     the value from [m1] is kept.  When [print] is supplied, the clashing
     values are included in the error message. *)
  let disjoint_union ?eq ?print m1 m2 =
    union (fun id v1 v2 ->
        let ok = match eq with
          | None -> false
          | Some eq -> eq v1 v2
        in
        if not ok then
          let err =
            match print with
            | None ->
              Format.asprintf "Map.disjoint_union %a" T.print id
            | Some print ->
              Format.asprintf "Map.disjoint_union %a => %a <> %a"
                T.print id print v1 print v2
          in
          Misc.fatal_error err
        else Some v1)
      m1 m2

  (* Union in which [m2]'s binding wins on keys present in both maps.
     [union] only calls its function on shared keys, so there is no need
     for a [merge] that has to handle the impossible [None, None] case. *)
  let union_right m1 m2 =
    union (fun _id _v1 v2 -> Some v2) m1 m2

  (* Union in which [m1]'s binding wins on keys present in both maps. *)
  let union_left m1 m2 = union_right m2 m1

  (* Union in which a key present in both maps is bound to [f v1 v2],
     [v1] coming from [m1] and [v2] from [m2]. *)
  let union_merge f m1 m2 =
    union (fun _id v1 v2 -> Some (f v1 v2)) m1 m2

  (* The image of [v] under the renaming [m], or [v] itself when [v] is
     not bound in [m]. *)
  let rename m v =
    match find v m with
    | exception Not_found -> v
    | renamed -> renamed

  (* Apply [f] to every key.  If [f] sends two keys to the same key, the
     binding coming later in [bindings] order (the larger original key)
     wins, since [of_list] lets later pairs override earlier ones. *)
  let map_keys f m =
    of_list (List.map (fun (k, v) -> f k, v) (bindings m))

  (* Print as "{ (k1 v1) (k2 v2) ... }" using [T.print] for keys and [f]
     for values, with break hints between items. *)
  let print f ppf s =
    let elts ppf s = iter (fun id v ->
        Format.fprintf ppf "@ (@[%a@ %a@])" T.print id f v) s in
    Format.fprintf ppf "@[<1>{@[%a@ @]}@]" elts s

  module T_set = Set.Make (T)

  (* The set of keys bound in [map]. *)
  let keys map = fold (fun k _ set -> T_set.add k set) map T_set.empty

  (* The values of [t], in increasing order of their keys. *)
  let data t = List.map snd (bindings t)

  (* Bind every element [e] of [set] to [f e]. *)
  let of_set f set = T_set.fold (fun e map -> add e (f e) map) set empty

  (* Swap keys and values.  If several keys share a value, the one folded
     last (the largest key) wins. *)
  let transpose_keys_and_data map = fold (fun k v m -> add v k m) map empty

  (* Swap keys and values, collecting into a set every key that shares a
     value. *)
  let transpose_keys_and_data_set map =
    fold (fun k v m ->
        let set =
          match find v m with
          | exception Not_found ->
            T_set.singleton k
          | set ->
            T_set.add k set
        in
        add v set m)
      map empty
end
126
(* [Make_set (T)] is [Set.Make (T)] extended with printers and list
   conversion helpers. *)
module Make_set (T : Thing) = struct
  include Set.Make (T)

  (* Channel printer: " ( " then each element followed by a space, then
     ")". *)
  let output oc s =
    Printf.fprintf oc " ( ";
    iter (Printf.fprintf oc "%a " T.output) s;
    Printf.fprintf oc ")"

  (* Pretty-printer producing "{ e1 e2 ... }" with break hints. *)
  let print ppf s =
    let print_elements ppf s = iter (Format.fprintf ppf "@ %a" T.print) s in
    Format.fprintf ppf "@[<1>{@[%a@ @]}@]" print_elements s

  (* [print] rendered to a string. *)
  let to_string s = Format.asprintf "%a" print s

  (* Add the elements from left to right; [add] leaves a set unchanged when
     an equal element is already present, so the first occurrence is
     kept. *)
  let of_list l = List.fold_left (fun set elt -> add elt set) empty l

  (* Image of [s] under [f], rebuilt through [of_list]. *)
  let map f s = elements s |> List.map f |> of_list
end
148
(* [Make_tbl (T)] is [Hashtbl.Make (T)] extended with conversions to and
   from lists and [Make_map (T)] maps, plus memoization. *)
module Make_tbl (T : Thing) = struct
  include Hashtbl.Make (T)

  module T_map = Make_map (T)

  (* All bindings of [t], in unspecified order.  Bindings shadowed by a
     later [add] of the same key are included as well. *)
  let to_list t =
    fold (fun key datum elts -> (key, datum)::elts) t []

  (* Build a table from an association list.  This uses [add], so a key
     repeated in [elts] keeps every binding, and [find] returns the one from
     its last occurrence. *)
  let of_list elts =
    let t = create 42 in
    List.iter (fun (key, datum) -> add t key datum) elts;
    t

  (* The current binding of every key, i.e. the one [find] returns.
     [fold] passes a key's most recent binding first and then any older
     bindings it shadows.  Only the first binding seen per key is kept;
     letting the later (older) ones overwrite it would make the map
     disagree with [find]. *)
  let to_map v =
    fold (fun key datum map ->
        if T_map.mem key map then map else T_map.add key datum map)
      v T_map.empty

  (* A fresh table holding exactly the bindings of [m]. *)
  let of_map m =
    let t = create (T_map.cardinal m) in
    T_map.iter (fun k v -> add t k v) m;
    t

  (* Wrap [f] so its results are cached in [t].  On a hit the stored value
     is returned; on a miss [f key] is computed, stored and returned. *)
  let memoize t f = fun key ->
    try find t key with
    | Not_found ->
      let r = f key in
      add t key r;
      r

  (* A new table binding each key of [t] to [f] applied to its current
     binding.  Shadowed bindings are not carried over. *)
  let map t f =
    of_map (T_map.map f (to_map t))
end
179
(* An identifiable type [t] bundled with its [Thing] operations and set,
   map and hash-table modules keyed on it. *)
module type S = sig
  type t

  module T : Thing with type t = t
  include Thing with type t := T.t

  (* Sets of [t], extended with printers and list helpers. *)
  module Set : sig
    include Stdlib_set.S
      with type elt = T.t
      and type t = Make_set (T).t

    (* Channel printer. *)
    val output : out_channel -> t -> unit
    (* Pretty-printer. *)
    val print : Format.formatter -> t -> unit
    (* [print] rendered to a string. *)
    val to_string : t -> string
    val of_list : elt list -> t
    (* Image of a set under a function. *)
    val map : (elt -> elt) -> t -> t
  end

  (* Maps keyed on [t], extended with unions, renaming and conversions. *)
  module Map : sig
    include Stdlib_map.S
      with type key = T.t
      and type 'a t = 'a Make_map (T).t

    (* Keep the bindings for which [f] returns [Some], rebound to the
       result. *)
    val filter_map : 'a t -> f:(key -> 'a -> 'b option) -> 'b t
    (* A later pair overrides an earlier pair with the same key. *)
    val of_list : (key * 'a) list -> 'a t
    (* Union of maps expected to have disjoint keys.  A shared key is a
       fatal error unless [eq] holds for the two values; [print] is used to
       show the values in the error message. *)
    val disjoint_union : ?eq:('a -> 'a -> bool) -> ?print:(Format.formatter -> 'a -> unit) -> 'a t -> 'a t -> 'a t
    (* On shared keys, the binding from the second map wins. *)
    val union_right : 'a t -> 'a t -> 'a t
    (* On shared keys, the binding from the first map wins. *)
    val union_left : 'a t -> 'a t -> 'a t
    (* On shared keys, the two values are combined with the function. *)
    val union_merge : ('a -> 'a -> 'a) -> 'a t -> 'a t -> 'a t
    (* The image of a key under the renaming, or the key itself when it is
       unbound. *)
    val rename : key t -> key -> key
    (* Apply a function to every key. *)
    val map_keys : (key -> key) -> 'a t -> 'a t
    (* The set of bound keys. *)
    val keys : 'a t -> Make_set (T).t
    (* The values, in increasing order of their keys. *)
    val data : 'a t -> 'a list
    (* Bind every element of the set to the function's result on it. *)
    val of_set : (key -> 'a) -> Make_set (T).t -> 'a t
    (* Swap keys and values. *)
    val transpose_keys_and_data : key t -> key t
    (* Swap keys and values, grouping keys that share a value. *)
    val transpose_keys_and_data_set : key t -> Set.t t
    (* [print f] prints a map using [f] for the values. *)
    val print :
      (Format.formatter -> 'a -> unit) -> Format.formatter -> 'a t -> unit
  end

  (* Hash tables keyed on [t], extended with conversions and
     memoization. *)
  module Tbl : sig
    include Hashtbl.S
      with type key = T.t
      and type 'a t = 'a Hashtbl.Make (T).t

    (* All bindings, in unspecified order. *)
    val to_list : 'a t -> (T.t * 'a) list
    (* Bindings are added with [add], so repeated keys are not replaced. *)
    val of_list : (T.t * 'a) list -> 'a t

    (* Conversions between tables and maps. *)
    val to_map : 'a t -> 'a Make_map (T).t
    val of_map : 'a Make_map (T).t -> 'a t
    (* [memoize t f] caches the results of [f] in [t]. *)
    val memoize : 'a t -> (key -> 'a) -> key -> 'a
    (* A new table with the function applied to the data. *)
    val map : 'a t -> ('a -> 'b) -> 'b t
  end
end
234
(* [Make (T)] packages [T] with its set, map and hash-table modules; its
   components line up with those declared in [S]. *)
module Make (T : Thing) = struct
  include T
  module T = T

  (* Collections keyed on [T]. *)
  module Set = Make_set (T)
  module Map = Make_map (T)
  module Tbl = Make_tbl (T)
end
243