1(**************************************************************************)
2(*                                                                        *)
3(*                                 OCaml                                  *)
4(*                                                                        *)
5(*             Xavier Leroy, projet Cristal, INRIA Rocquencourt           *)
6(*                                                                        *)
7(*   Copyright 1996 Institut National de Recherche en Informatique et     *)
8(*     en Automatique.                                                    *)
9(*                                                                        *)
10(*   All rights reserved.  This file is distributed under the terms of    *)
11(*   the GNU Lesser General Public License version 2.1, with the          *)
12(*   special exception on linking described in the file LICENSE.          *)
13(*                                                                        *)
14(**************************************************************************)
15
16(* Translation of string matching from closed lambda to C-- *)
17
18open Lambda
19open Cmm
20
module type I = sig
  (* Expression computing the length of the string block [str]; the
     compiled code switches on its value against the number of words in
     each pattern (see compile_by_size). *)
  val string_block_length : Cmm.expression -> Cmm.expression

  (* [transl_switch arg low high cases default]: presumably emits a switch
     on [arg], whose value lies in [low, high], selecting the action paired
     with each integer key in [cases] and falling back to [default].
     NOTE(review): argument roles inferred from the single call site. *)
  val transl_switch :
    Cmm.expression -> int -> int -> (int * Cmm.expression) list ->
    Cmm.expression -> Cmm.expression
end
28
29module Make(I:I) = struct
30
31(* Debug *)
32
33  let dbg = false
34
35  let mask =
36    let open Nativeint in
37    sub (shift_left one 8) one
38
39  let pat_as_string p =
40    let rec digits k n p =
41      if n <= 0 then k
42      else
43        let d = Nativeint.to_int (Nativeint.logand mask p) in
44        let d = Char.escaped (Char.chr d) in
45        digits (d::k) (n-1) (Nativeint.shift_right_logical p  8) in
46    let ds = digits [] Arch.size_addr p in
47    let ds =
48      if Arch.big_endian then ds else List.rev ds in
49    String.concat "" ds
50
51  let do_pp_cases chan cases =
52    List.iter
53      (fun (ps,_) ->
54        Printf.fprintf chan "  [%s]\n"
55          (String.concat "; " (List.map pat_as_string ps)))
56      cases
57
58  let pp_cases chan tag cases =
59    Printf.eprintf "%s:\n" tag ;
60    do_pp_cases chan cases
61
62  let pp_match chan tag idxs cases =
63    Printf.eprintf
64      "%s: idx=[%s]\n" tag
65      (String.concat "; " (List.map string_of_int idxs)) ;
66    do_pp_cases chan cases
67
68(* Utilities *)
69
70  let gen_cell_id () = Ident.create "cell"
71  let gen_size_id () = Ident.create "size"
72
73  let mk_let_cell id str ind body =
74    let dbg = Debuginfo.none in
75    let cell =
76      Cop(Cload (Word_int, Asttypes.Mutable),
77        [Cop(Cadda,[str;Cconst_int(Arch.size_int*ind)], dbg)],
78        dbg) in
79    Clet(id, cell, body)
80
81  let mk_let_size id str body =
82    let size = I.string_block_length str in
83    Clet(id, size, body)
84
85  let mk_cmp_gen cmp_op id nat ifso ifnot =
86    let dbg = Debuginfo.none in
87    let test =
88      Cop (Ccmpi cmp_op, [ Cvar id; Cconst_natpointer nat ], dbg)
89    in
90    Cifthenelse (test, ifso, ifnot)
91
92  let mk_lt = mk_cmp_gen Clt
93  let mk_eq = mk_cmp_gen Ceq
94
95  module IntArg =
96    struct
97      type t = int
98      let compare (x:int) (y:int) =
99        if x < y then -1
100        else if x > y then 1
101        else 0
102    end
103
104  let interval m0 n =
105    let rec do_rec m =
106      if m >= n then []
107      else m::do_rec (m+1) in
108    do_rec m0
109
110
111(*****************************************************)
(* Compile strings to lists of words [native ints] *)
113(*****************************************************)
114
115  let pat_of_string str =
116    let len = String.length str in
117    let n = len / Arch.size_addr + 1 in
118    let get_byte i =
119      if i < len then int_of_char str.[i]
120      else if i < n * Arch.size_addr - 1 then 0
121      else n * Arch.size_addr - 1 - len in
122    let mk_word ind =
123      let w = ref 0n in
124      let imin = ind * Arch.size_addr
125      and imax = (ind + 1) * Arch.size_addr - 1 in
126      if Arch.big_endian then
127        for i = imin to imax do
128          w := Nativeint.logor (Nativeint.shift_left !w 8)
129              (Nativeint.of_int (get_byte i));
130        done
131      else
132        for i = imax downto imin do
133          w := Nativeint.logor (Nativeint.shift_left !w 8)
134              (Nativeint.of_int (get_byte i));
135        done;
136      !w in
137    let rec mk_words ind  =
138      if ind >= n then []
139      else mk_word ind::mk_words (ind+1) in
140    mk_words 0
141
142(*****************************)
143(* Discriminating heuristics *)
144(*****************************)
145
146  module IntSet = Set.Make(IntArg)
147  module NativeSet = Set.Make(Nativeint)
148
149  let rec add_one sets ps = match sets,ps with
150  | [],[] -> []
151  | set::sets,p::ps ->
152      let sets = add_one sets ps in
153      NativeSet.add p set::sets
154  | _,_ -> assert false
155
156  let count_arities cases = match cases with
157  | [] -> assert false
158  | (ps,_)::_ ->
159      let sets =
160        List.fold_left
161          (fun sets (ps,_) -> add_one sets ps)
162          (List.map (fun _ -> NativeSet.empty) ps) cases in
163      List.map NativeSet.cardinal sets
164
165  let count_arities_first cases =
166    let set =
167      List.fold_left
168        (fun set case -> match case with
169        | (p::_,_) -> NativeSet.add p set
170        | _ -> assert false)
171        NativeSet.empty cases in
172    NativeSet.cardinal set
173
174  let count_arities_length cases =
175    let set =
176      List.fold_left
177        (fun set (ps,_) -> IntSet.add (List.length ps) set)
178        IntSet.empty cases in
179    IntSet.cardinal set
180
181  let best_col =
182    let rec do_rec kbest best k = function
183      | [] -> kbest
184      | x::xs ->
185          if x < best then
186            do_rec k x (k+1) xs
187          else
188            do_rec kbest best (k+1) xs in
189    let smallest = do_rec (-1) max_int 0 in
190    fun cases ->
191      let ars = count_arities cases in
192      smallest ars
193
194  let swap_list =
195    let rec do_rec k xs = match xs with
196    | [] -> assert false
197    | x::xs ->
198        if k <= 0 then [],x,xs
199        else
200          let xs,mid,ys = do_rec (k-1) xs in
201          x::xs,mid,ys in
202    fun k xs ->
203      let xs,x,ys = do_rec  k xs in
204      x::xs @ ys
205
206  let swap k idxs cases =
207    if k = 0 then idxs,cases
208    else
209      let idxs = swap_list k idxs
210      and cases =
211        List.map
212          (fun (ps,act) -> swap_list k ps,act)
213          cases in
214      if dbg then begin
215        pp_match stderr "SWAP" idxs cases
216      end ;
217      idxs,cases
218
219  let best_first idxs cases = match idxs with
220  | []|[_] -> idxs,cases (* optimisation: one column only *)
221  | _ ->
222      let k = best_col cases in
223      swap k idxs cases
224
225(************************************)
226(* Divide according to first column *)
227(************************************)
228
229  module Divide(O:Set.OrderedType) = struct
230
231    module OMap = Map.Make(O)
232
233    let divide cases =
234      let env =
235        List.fold_left
236          (fun env (p,psact) ->
237            let old =
238              try OMap.find p env
239              with Not_found -> [] in
240            OMap.add p ((psact)::old) env)
241          OMap.empty cases in
242      let r =  OMap.fold (fun key v k -> (key,v)::k) env [] in
243      List.rev r (* Now sorted *)
244  end
245
246(***************)
247(* Compilation *)
248(***************)
249
250(* Group by cell *)
251
252    module DivideNative = Divide(Nativeint)
253
254    let by_cell cases =
255      DivideNative.divide
256        (List.map
257           (fun case -> match case with
258           | (p::ps),act -> p,(ps,act)
259           | [],_ -> assert false)
260           cases)
261
262(* Split into two halves *)
263
264    let rec do_split idx env = match env with
265    | [] -> assert false
266    | (midkey,_ as x)::rem ->
267        if idx <= 0 then [],midkey,env
268        else
269          let lt,midkey,ge = do_split (idx-1) rem in
270          x::lt,midkey,ge
271
272    let split_env len env = do_split (len/2) env
273
274(* Switch according to one cell *)
275
(*
  Emit the switch, here as a comparison tree.
  Argument compile_rec is called to compile the rest of the patterns,
  because match_oncell is used from two different contexts:
  do_compile_pats and top_compile below.
 *)
282    let match_oncell compile_rec str default idx env =
283      let id = gen_cell_id () in
284      let rec comp_rec env =
285        let len = List.length env in
286        if len <= 3 then
287          List.fold_right
288            (fun (key,cases) ifnot ->
289              mk_eq id key
290                (compile_rec str default cases)
291              ifnot)
292            env default
293        else
294          let lt,midkey,ge = split_env len env in
295          mk_lt id midkey (comp_rec lt) (comp_rec ge) in
296      mk_let_cell id str idx (comp_rec env)
297
(*
  Recursive 'list of cells' compile function:
  - choose the matched cell and switch on it
  - notice: patterns (and idxs) all have the same length
 *)
303
304    let rec do_compile_pats idxs str default cases =
305      if dbg then begin
306        pp_match stderr "COMPILE" idxs cases
307      end ;
308      match idxs with
309      | [] ->
310          begin match cases with
311          | [] -> default
312          | (_,e)::_ -> e
313          end
314      | _::_ ->
315          let idxs,cases = best_first idxs cases in
316          begin match idxs with
317          | [] -> assert false
318          | idx::idxs ->
319              match_oncell
320                (do_compile_pats idxs) str default idx (by_cell cases)
321          end
322
323
324(* Group by size *)
325
326    module DivideInt = Divide(IntArg)
327
328
329    let by_size cases =
330      DivideInt.divide
331        (List.map
332           (fun (ps,_ as case) -> List.length ps,case)
333           cases)
(*
  Switch according to pattern size.
  Argument from_ind is the starting index; it is either zero
  or one (when the switch on cell 0 has already been performed).
  In the latter case the pattern length is one less than the string's
  length in words, and is corrected.
 *)
340
341    let compile_by_size dbg from_ind str default cases =
342      let size_cases =
343        List.map
344          (fun (len,cases) ->
345            let len = len+from_ind in
346            let act =
347              do_compile_pats
348                (interval from_ind len)
349                str default  cases in
350            (len,act))
351          (by_size cases) in
352      let id = gen_size_id () in
353      ignore dbg;
354      let switch = I.transl_switch (Cvar id) 1 max_int size_cases default in
355      mk_let_size id str switch
356
357(*
358  Compilation entry point: we choose to switch
359  either on size or on first cell, using the
360  'least discriminant' heuristics.
361 *)
362    let top_compile debuginfo str default cases =
363      let a_len = count_arities_length cases
364      and a_fst = count_arities_first cases in
365      if a_len <= a_fst then begin
366        if dbg then pp_cases stderr "SIZE" cases ;
367        compile_by_size debuginfo 0 str default cases
368      end else begin
369        if dbg then pp_cases stderr "FIRST COL" cases ;
370        let compile_size_rest str default cases =
371          compile_by_size debuginfo 1 str default cases in
372        match_oncell compile_size_rest str default 0 (by_cell cases)
373      end
374
375(* Module entry point *)
376
377    let catch arg k = match arg with
378    | Cexit (_e,[]) ->  k arg
379    | _ ->
380        let e =  next_raise_count () in
381        ccatch (e,[],k (Cexit (e,[])),arg)
382
383    let compile dbg str default cases =
384(* We do not attempt to really optimise default=None *)
385      let cases,default = match cases,default with
386      | (_,e)::cases,None
387      | cases,Some e -> cases,e
388      | [],None -> assert false in
389      let cases =
390        List.rev_map
391          (fun (s,act) -> pat_of_string s,act)
392          cases in
393      catch default (fun default -> top_compile dbg str default cases)
394
395  end
396