1(*
2 * Copyright (c) 1997-1999 Massachusetts Institute of Technology
3 * Copyright (c) 2003, 2007-14 Matteo Frigo
4 * Copyright (c) 2003, 2007-14 Massachusetts Institute of Technology
5 *
6 * This program is free software; you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation; either version 2 of the License, or
9 * (at your option) any later version.
10 *
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with this program; if not, write to the Free Software
18 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
19 *
20 *)
21
22(* generation of trigonometric transforms *)
23
24open Util
25open Genutil
26open C
27
28
(* One-line usage banner shown by the argument parser on bad input or -help. *)
let usage = Printf.sprintf "Usage: %s -n <number>" Sys.argv.(0)
30
(* Stride specializations requested on the command line.  Each cell starts as
   [Stride_variable] and is overwritten by the matching -with-*stride option.
   Every name gets its own, independent ref cell. *)
let uistride, uostride, uivstride, uovstride =
  let fresh_cell () = ref Stride_variable in
  (fresh_cell (), fresh_cell (), fresh_cell (), fresh_cell ())
35
(* Transform kind selected on the command line; [NONE] until one of the
   -rdft / -hdft / -dht / -redftXY / -rodftXY flags is given.
   REDFTxy are the DCTs (00 = DCT-I, 10 = DCT-II, 01 = DCT-III,
   11 = DCT-IV) and RODFTxy the corresponding DSTs. *)
type mode =
    RDFT | HDFT | DHT
  | REDFT00 | REDFT10 | REDFT01 | REDFT11
  | RODFT00 | RODFT10 | RODFT01 | RODFT11
  | NONE
49
(* Mutable settings written by the command-line options. *)
let mode = ref NONE      (* transform kind; NONE until a kind flag is seen *)
and normsqr = ref 1      (* square of the integer normalization to divide by *)
and unitary = ref false  (* rescale selected entries for unitary normalization *)
and noloop = ref false   (* emit straight-line code without the vector loop *)
54
(* Command-line options specific to the real-to-real generator.
   Integer-valued options use [Arg.Int], so a malformed number is reported
   as a usage error by [Arg] instead of escaping as an uncaught
   [Failure "int_of_string"] from inside the callback. *)
let speclist = [
  "-with-istride",
  Arg.String (fun x -> uistride := arg_to_stride x),
  " specialize for given input stride";

  "-with-ostride",
  Arg.String (fun x -> uostride := arg_to_stride x),
  " specialize for given output stride";

  "-with-ivstride",
  Arg.String (fun x -> uivstride := arg_to_stride x),
  " specialize for given input vector stride";

  "-with-ovstride",
  Arg.String (fun x -> uovstride := arg_to_stride x),
  " specialize for given output vector stride";

  "-rdft",
  Arg.Unit (fun () -> mode := RDFT),
  " generate a real DFT codelet";

  "-hdft",
  Arg.Unit (fun () -> mode := HDFT),
  " generate a Hermitian DFT codelet";

  "-dht",
  Arg.Unit (fun () -> mode := DHT),
  " generate a DHT codelet";

  "-redft00",
  Arg.Unit (fun () -> mode := REDFT00),
  " generate a DCT-I codelet";

  "-redft10",
  Arg.Unit (fun () -> mode := REDFT10),
  " generate a DCT-II codelet";

  "-redft01",
  Arg.Unit (fun () -> mode := REDFT01),
  " generate a DCT-III codelet";

  "-redft11",
  Arg.Unit (fun () -> mode := REDFT11),
  " generate a DCT-IV codelet";

  "-rodft00",
  Arg.Unit (fun () -> mode := RODFT00),
  " generate a DST-I codelet";

  "-rodft10",
  Arg.Unit (fun () -> mode := RODFT10),
  " generate a DST-II codelet";

  "-rodft01",
  Arg.Unit (fun () -> mode := RODFT01),
  " generate a DST-III codelet";

  "-rodft11",
  Arg.Unit (fun () -> mode := RODFT11),
  " generate a DST-IV codelet";

  (* Stored squared: the generator divides by sqrt(!normsqr). *)
  "-normalization",
  Arg.Int (fun ix -> normsqr := ix * ix),
  " normalization integer to divide by";

  "-normsqr",
  Arg.Int (fun x -> normsqr := x),
  " integer square of normalization to divide by";

  "-unitary",
  Arg.Set unitary,
  " unitary normalization (up to an overall scale factor)";

  "-noloop",
  Arg.Set noloop,
  " no vector loop";
]
132
(* Scale factors used by [-unitary] rescaling: 1/sqrt(2) for selected output
   entries and sqrt(2) for selected input entries.  Built in the same order
   as before (1/sqrt(2) first). *)
let sqrt_half, sqrt_two =
  let two = 2 in
  let inv_root = Complex.inverse_int_sqrt two in
  let root = Complex.int_sqrt two in
  (inv_root, root)
135
(* [rescale sc s1 s2 input] is an indexing function that agrees with [input],
   except that while [-unitary] is in effect the entries at indices [s1] and
   [s2] are multiplied by [sc].

   Pass -1 for [s1]/[s2] when no entry should be rescaled; -1 is used as a
   sentinel that no array index matches.  The [!unitary] flag is read on
   every lookup, not when [rescale] is partially applied.

   Uses structural [=] on the integer indices; the original physical [==]
   happened to work for immediate ints but is not the intended comparison. *)
let rescale sc s1 s2 input i =
  if !unitary && (i = s1 || i = s2) then
    Complex.times (input i) sc
  else
    input i
141
(* [generate n mode] returns the C source of an [n]-point real-to-real
   codelet of kind [mode].

   The codelet reads array [I] with stride [is] and writes array [O] with
   stride [os]; either stride may be specialized by -with-istride /
   -with-ostride.  Unless -noloop is given, the straight-line code is wrapped
   in a loop running [v] iterations that advances [I]/[O] by [ivs]/[ovs].
   Unless [Magic.standalone] is set, the [kr2r_desc] descriptor and the
   registration function are appended.

   Raises [Failure] when no transform kind was selected ([mode = NONE]). *)
let generate n mode =
  let iarray = "I"
  and oarray = "O"
  and istride = "is"
  and ostride = "os"
  and i = "i"
  and v = "v"
  in

  let sign = !Genutil.sign
  and name = !Magic.codelet_name in

  let vistride = either_stride (!uistride) (C.SVar istride)
  and vostride = either_stride (!uostride) (C.SVar ostride)
  in

  let sovs = stride_to_string "ovs" !uovstride in
  let sivs = stride_to_string "ivs" !uivstride in

  (* Per kind, select:
     - the symbolic transform;
     - how the input is loaded and the output stored;
     - the input indices (si1, si2) to scale by sqrt(2) under -unitary;
     - the output indices (so1, so2) to scale by 1/sqrt(2) under -unitary.
     -1 marks "no index". *)
  let (transform, load_input, store_output, si1, si2, so1, so2) =
    match mode with
    | RDFT -> Trig.rdft sign, load_array_r, store_array_hc, -1,-1,-1,-1
    | HDFT -> Trig.hdft sign, load_array_c, store_array_r, -1,-1,-1,-1 (* TODO *)
    | DHT -> Trig.dht 1, load_array_r, store_array_r, -1,-1,-1,-1
    | REDFT00 -> Trig.dctI, load_array_r, store_array_r, 0,n-1,0,n-1
    | REDFT10 -> Trig.dctII, load_array_r, store_array_r, -1,-1,0,-1
    | REDFT01 -> Trig.dctIII, load_array_r, store_array_r, 0,-1,-1,-1
    | REDFT11 -> Trig.dctIV, load_array_r, store_array_r, -1,-1,-1,-1
    | RODFT00 -> Trig.dstI, load_array_r, store_array_r, -1,-1,-1,-1
    | RODFT10 -> Trig.dstII, load_array_r, store_array_r, -1,-1,n-1,-1
    | RODFT01 -> Trig.dstIII, load_array_r, store_array_r, n-1,-1,-1,-1
    | RODFT11 -> Trig.dstIV, load_array_r, store_array_r, -1,-1,-1,-1
    | NONE -> failwith "must specify transform kind"
  in

  let locations = unique_array_c n in
  (* The imaginary half of each locative points at the placeholder "BUG"
     array: a real-input codelet must never load or store through it. *)
  let input = locative_array_c n
      (C.array_subscript iarray vistride)
      (C.array_subscript "BUG" vistride)
      locations sivs in
  (* Load with the mode-specific [load_input]; using [load_array_c] here
     would also load the imaginary part, i.e. emit reads from "BUG". *)
  let output = rescale sqrt_half so1 so2
      ((Complex.times (Complex.inverse_int_sqrt !normsqr))
       @@ (transform n (rescale sqrt_two si1 si2 (load_input n input)))) in
  let oloc =
    locative_array_c n
      (C.array_subscript oarray vostride)
      (C.array_subscript "BUG" vostride)
      locations sovs in
  let odag = store_output n oloc output in
  let annot = standard_optimizer odag in

  let body = if !noloop then Block ([], [Asch annot]) else Block (
    [Decl ("INT", i)],
    [For (Expr_assign (CVar i, CVar v),
          Binop (" > ", CVar i, Integer 0),
          list_to_comma
            [Expr_assign (CVar i, CPlus [CVar i; CUminus (Integer 1)]);
             Expr_assign (CVar iarray, CPlus [CVar iarray; CVar sivs]);
             Expr_assign (CVar oarray, CPlus [CVar oarray; CVar sovs]);
             make_volatile_stride (2*n) (CVar istride);
             make_volatile_stride (2*n) (CVar ostride)
            ],
          Asch annot)
    ])
  in

  (* Stride and loop-count parameters are declared only when they are not
     fixed at generation time. *)
  let tree =
    Fcn ((if !Magic.standalone then "void" else "static void"), name,
         ([Decl (C.constrealtypep, iarray);
           Decl (C.realtypep, oarray)]
          @ (if stride_fixed !uistride then []
             else [Decl (C.stridetype, istride)])
          @ (if stride_fixed !uostride then []
             else [Decl (C.stridetype, ostride)])
          @ (if !noloop then [] else
               [Decl ("INT", v)]
               @ (if stride_fixed !uivstride then []
                  else [Decl ("INT", "ivs")])
               @ (if stride_fixed !uovstride then []
                  else [Decl ("INT", "ovs")]))),
         finalize_fcn body)

  in let desc =
    Printf.sprintf
      "static const kr2r_desc desc = { %d, \"%s\", %s, &GENUS, %s };\n\n"
      n name (flops_of tree)
      (* NOTE(review): confirm that "RDFT00"/"HDFT00" are valid names of the
         C-side r2r kind enum; only the DCT/DST names are checked here. *)
      (match mode with
      | RDFT -> "RDFT00"
      | HDFT -> "HDFT00"
      | DHT  -> "DHT"
      | REDFT00 -> "REDFT00"
      | REDFT10 -> "REDFT10"
      | REDFT01 -> "REDFT01"
      | REDFT11 -> "REDFT11"
      | RODFT00 -> "RODFT00"
      | RODFT10 -> "RODFT10"
      | RODFT01 -> "RODFT01"
      | RODFT11 -> "RODFT11"
      | NONE -> failwith "must specify a transform kind")

  and init =
    (declare_register_fcn name) ^
    "{" ^
    "  X(kr2r_register)(p, " ^ name ^ ", &desc);\n" ^
    "}\n"

  in
  (unparse tree) ^ "\n" ^ (if !Magic.standalone then "" else desc ^ init)
249
250
(* Entry point: parse the command line, then print the codelet for the
   requested size and transform kind on standard output. *)
let main () =
  parse speclist usage;
  let size = check_size () in
  print_string (generate size !mode)
256
(* Run the generator.  [let () =] (rather than [let _ =]) makes the compiler
   check that [main ()] returns unit, so a partial application or a
   discarded non-unit result cannot slip through. *)
let () = main ()
258