1(*
2 * Copyright (c) 1997-1999 Massachusetts Institute of Technology
3 * Copyright (c) 2003, 2007-14 Matteo Frigo
4 * Copyright (c) 2003, 2007-14 Massachusetts Institute of Technology
5 *
6 * This program is free software; you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation; either version 2 of the License, or
9 * (at your option) any later version.
10 *
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with this program; if not, write to the Free Software
18 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
19 *
20 *)
21
22(* generation of trigonometric transforms *)
23
24open Util
25open Genutil
26open C
27
28
(* One-line usage message shown by the argument parser. *)
let usage = Printf.sprintf "Usage: %s -n <number>" Sys.argv.(0)
30
(* Stride specializations set by the -with-* command-line flags.  Each one
   starts out as Stride_variable until its flag overrides it. *)
let uistride = ref Stride_variable
and uostride = ref Stride_variable
and uivstride = ref Stride_variable
and uovstride = ref Stride_variable

(* Divisor applied to every transform output (set with -normalization). *)
let normalization = ref 1
36
(* Transform computed by the generated codelet, chosen by the command-line
   flags.  [generate] rejects NONE (no flag given) and also rejects the two
   Vorbis modes, which have neither a flag nor an implementation. *)
type mode =
  | MDCT              (* forward, no window *)
  | MDCT_MP3          (* forward, MP3 sine window applied to the 2n inputs *)
  | MDCT_VORBIS       (* declared but not implemented *)
  | MDCT_WINDOW       (* forward, window read from a 2n-entry array W *)
  | MDCT_WINDOW_SYM   (* forward, n entries of W mirrored to 2n *)
  | IMDCT             (* inverse, no window *)
  | IMDCT_MP3         (* inverse, MP3 sine window applied to the 2n outputs *)
  | IMDCT_VORBIS      (* declared but not implemented *)
  | IMDCT_WINDOW      (* inverse, window read from a 2n-entry array W *)
  | IMDCT_WINDOW_SYM  (* inverse, n entries of W mirrored to 2n *)
  | NONE              (* no transform selected yet *)

(* Mode requested on the command line; stays NONE until a mode flag is seen. *)
let mode = ref NONE
51
(* Command-line flags specific to this generator; handed to [parse] by
   [main].  The stride flags store their parsed value in the u*stride refs.
   Each mode flag overwrites [mode], so the last one given wins. *)
let speclist = [
  "-with-istride",
  Arg.String(fun x -> uistride := arg_to_stride x),
  " specialize for given input stride";

  "-with-ostride",
  Arg.String(fun x -> uostride := arg_to_stride x),
  " specialize for given output stride";

  "-with-ivstride",
  Arg.String(fun x -> uivstride := arg_to_stride x),
  " specialize for given input vector stride";

  "-with-ovstride",
  Arg.String(fun x -> uovstride := arg_to_stride x),
  " specialize for given output vector stride";

  (* Arg.Int reports a non-integer as an ordinary usage error instead of
     letting int_of_string's Failure escape.  Zero is rejected because
     every output is later scaled by 1 / normalization. *)
  "-normalization",
  Arg.Int(fun x ->
    if x = 0 then raise (Arg.Bad "-normalization must be nonzero");
    normalization := x),
  " normalization integer to divide by";

  "-mdct",
  Arg.Unit(fun () -> mode := MDCT),
  " generate an MDCT codelet";

  "-mdct-mp3",
  Arg.Unit(fun () -> mode := MDCT_MP3),
  " generate an MDCT codelet with MP3 windowing";

  "-mdct-window",
  Arg.Unit(fun () -> mode := MDCT_WINDOW),
  " generate an MDCT codelet with window array";

  "-mdct-window-sym",
  Arg.Unit(fun () -> mode := MDCT_WINDOW_SYM),
  " generate an MDCT codelet with symmetric window array";

  "-imdct",
  Arg.Unit(fun () -> mode := IMDCT),
  " generate an IMDCT codelet";

  "-imdct-mp3",
  Arg.Unit(fun () -> mode := IMDCT_MP3),
  " generate an IMDCT codelet with MP3 windowing";

  "-imdct-window",
  Arg.Unit(fun () -> mode := IMDCT_WINDOW),
  " generate an IMDCT codelet with window array";

  "-imdct-window-sym",
  Arg.Unit(fun () -> mode := IMDCT_WINDOW_SYM),
  " generate an IMDCT codelet with symmetric window array";
]
105
(* Window that is identically one, so every sample passes through unscaled.
   Both the size and the index arguments are ignored. *)
let unity_window _n _i = Complex.one
107
(* MP3 (sine) window: window(k) = sin(pi/(2n) * (k + 1/2)) = sin(pi (2k+1) / (4n)).
   The angle is written as the fraction (2k + 1) / (8n) of a full turn,
   and the window value is the imaginary part of Complex.exp at that
   fraction.  NOTE(review): this assumes Complex.exp m j denotes
   exp(2 pi i j / m); confirm against the Complex module. *)
let mp3_window n k =
  let turn_numerator = 2 * k + 1 in
  let turn_denominator = 8 * n in
  Complex.imag (Complex.exp turn_denominator turn_numerator)
111
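(* Vorbis window(k) = sin(pi/2 * (mp3_window(k))^2)
   ... this is transcendental, though: its argument is in general not a
   rational multiple of pi, so it cannot be expressed with our current
   Complex.exp function.  The MDCT_VORBIS and IMDCT_VORBIS modes are
   therefore left unimplemented. *)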
115
(* [window_array n w] describes the n window coefficients read from the C
   array named [w] at unit stride.  Each entry i is the subscript w[i],
   wrapped as a constant variable in its own fresh Unique class and then
   loaded with load_r, which is given that variable for both of its
   components. *)
let window_array n w =
  array n (fun i ->
    let klass = Unique.make () in
    let unit_stride = C.SInteger 1 in
    let coefficient = C.array_subscript w unit_stride i in
    let var = Variable.make_constant klass coefficient in
    load_r (var, var))
123
(* Full-length window: entry i is read directly from [w]; the size is unused. *)
let load_window w _n = fun i -> w i
(* Symmetric window: [w] holds only the first n entries, and the entries at
   n..2n-1 mirror them (index i >= n maps to 2n - 1 - i). *)
let load_window_sym w n i =
  let source = if i >= n then 2 * n - 1 - i else i in
  w source
126
127(* fixme: use same locations for input and output so that it works in-place? *)
128
(* [load_array_mdct window n rarr iarr locations] loads the 2n inputs of a
   size-n MDCT, multiplies input i by [window n i], and folds the windowed
   sequence x into the n values y that [generate] feeds to Trig.dctIV:
     y(i) = -(x(3n/2 + i) + x(3n/2 - 1 - i)) / 2    for 0   <= i < n/2
     y(i) =  (x(i - n/2)  - x(3n/2 - 1 - i)) / 2    for n/2 <= i < n
   The fold is only correct for even n, so odd n is rejected rather than
   silently producing a wrong codelet.
   @raise Failure if [n] is odd. *)
let load_array_mdct window n rarr iarr locations =
  if n mod 2 <> 0 then failwith "load_array_mdct: n must be even";
  let twon = 2 * n in
  let half_n = n / 2 in
  (* The imaginary-part array is named "BUG"; presumably a placeholder that
     never appears in real-data codelets -- confirm. *)
  let arr = load_array_c twon
      (locative_array_c twon rarr iarr locations "BUG") in
  let arrw i = Complex.times (window n i) (arr i) in
  let folded i =
    if i < half_n then
      Complex.uminus (Complex.plus [arrw (i + n + half_n);
                                    arrw (n + half_n - 1 - i)])
    else
      Complex.plus [arrw (i - half_n);
                    Complex.uminus (arrw (n + half_n - 1 - i))] in
  array n (fun i -> Complex.times Complex.half (folded i))
144
(* Stores the n transform outputs [arr] directly as the n real MDCT outputs.
   The window argument is ignored.  Every MDCT mode in [generate] passes
   unity_window here; presumably the parameter exists only so that all
   load/store adapters share one shape. *)
let store_array_mdct _window n rarr iarr locations arr =
  let targets = locative_array_c n rarr iarr locations "BUG" in
  store_array_r n targets arr
147
(* Loads the n IMDCT inputs unchanged; the window argument is ignored (every
   IMDCT mode applies its window when storing, not when loading). *)
let load_array_imdct _window n rarr iarr locations =
  let sources = locative_array_c n rarr iarr locations "BUG" in
  load_array_c n sources
150
(* [store_array_imdct window n rarr iarr locations arr] unfolds the n
   transform outputs [arr] into the 2n outputs z of a size-n IMDCT,
   multiplies output i by [window n i], and stores them:
     z(i) =  arr(i + n/2)         for 0    <= i < n/2
     z(i) = -arr(3n/2 - 1 - i)    for n/2  <= i < 3n/2
     z(i) = -arr(i - 3n/2)        for 3n/2 <= i < 2n
   Like the MDCT fold, this is only correct for even n (for odd n some
   inputs would be used twice and others never), so odd n is rejected.
   @raise Failure if [n] is odd. *)
let store_array_imdct window n rarr iarr locations arr =
  if n mod 2 <> 0 then failwith "store_array_imdct: n must be even";
  let n2 = n / 2 in
  let threen2 = 3 * n2 in
  let unfolded i =
    if i < n2 then arr (i + n2)
    else if i < threen2 then Complex.uminus (arr (threen2 - 1 - i))
    else Complex.uminus (arr (i - threen2)) in
  let windowed i = Complex.times (window n i) (unfolded i) in
  let twon = 2 * n in
  store_array_r twon (locative_array_c twon rarr iarr locations "BUG") windowed
165
(* True exactly for the modes whose codelet takes the extra const window
   array parameter.  Every constructor is listed, so adding a mode forces
   a decision here. *)
let window_param m =
  match m with
  | MDCT_WINDOW | MDCT_WINDOW_SYM | IMDCT_WINDOW | IMDCT_WINDOW_SYM -> true
  | MDCT | MDCT_MP3 | MDCT_VORBIS
  | IMDCT | IMDCT_MP3 | IMDCT_VORBIS
  | NONE -> false
172
(* [generate n mode] returns the C source text, followed by a newline, of a
   codelet named !Magic.codelet_name that computes the transform selected
   by [mode] for size [n].  The codelet reads from array I (stride istride)
   and writes to array O (stride ostride).  The windowed modes also take a
   const array W.  Every transform output is multiplied by 1/!normalization
   before it is stored.
   @raise Failure if [mode] does not select an implemented transform. *)
let generate n mode =
  let iarray = "I"
  and oarray = "O"
  and istride = "istride"
  and ostride = "ostride"
  and window = "W"
  and name = !Magic.codelet_name in

  let vistride = either_stride (!uistride) (C.SVar istride)
  and vostride = either_stride (!uostride) (C.SVar ostride)
  in

  (* Vector-stride parameter names.  Each name comes from the stride it
     describes and is declared below under that same stride's "not fixed"
     test.  Crossing them would, whenever exactly one vector stride is fixed,
     emit a bogus parameter such as "int 2" and drop the needed one. *)
  let sivs = stride_to_string "ivs" !uivstride in
  let sovs = stride_to_string "ovs" !uovstride in

  let (transform, load_input, store_output) = match mode with
    | MDCT -> Trig.dctIV, load_array_mdct unity_window,
        store_array_mdct unity_window
    | MDCT_MP3 -> Trig.dctIV, load_array_mdct mp3_window,
        store_array_mdct unity_window
    | MDCT_WINDOW -> Trig.dctIV, load_array_mdct
        (load_window (window_array (2 * n) window)),
        store_array_mdct unity_window
    | MDCT_WINDOW_SYM -> Trig.dctIV, load_array_mdct
        (load_window_sym (window_array n window)),
        store_array_mdct unity_window
    | IMDCT -> Trig.dctIV, load_array_imdct unity_window,
        store_array_imdct unity_window
    | IMDCT_MP3 -> Trig.dctIV, load_array_imdct unity_window,
        store_array_imdct mp3_window
    | IMDCT_WINDOW -> Trig.dctIV, load_array_imdct unity_window,
        store_array_imdct (load_window (window_array (2 * n) window))
    | IMDCT_WINDOW_SYM -> Trig.dctIV, load_array_imdct unity_window,
        store_array_imdct (load_window_sym (window_array n window))
    | MDCT_VORBIS | IMDCT_VORBIS ->
        failwith "Vorbis windows are not supported"
    | NONE -> failwith "must specify transform kind"
  in

  let locations = unique_array_c (2*n) in
  (* The imaginary-part array is named "BUG"; presumably a placeholder that
     never appears in real-data codelets -- confirm. *)
  let input =
    load_input n
      (C.array_subscript iarray vistride)
      (C.array_subscript "BUG" vistride)
      locations
  in
  (* The transform is built once and each output is then scaled; this is
     the same as composing Complex.times scale with the transform. *)
  let transformed = transform n input in
  let scale = Complex.inverse_int !normalization in
  let output i = Complex.times scale (transformed i) in
  let odag =
    store_output n
      (C.array_subscript oarray vostride)
      (C.array_subscript "BUG" vostride)
      locations
      output
  in
  let annot = standard_optimizer odag in

  let tree =
    Fcn ("void", name,
         ([Decl (C.constrealtypep, iarray);
           Decl (C.realtypep, oarray)]
          @ (if stride_fixed !uistride then []
             else [Decl (C.stridetype, istride)])
          @ (if stride_fixed !uostride then []
             else [Decl (C.stridetype, ostride)])
          (* ovs is declared before ivs, keeping the parameter order that
             codelets with both vector strides variable have always had. *)
          @ (choose_simd []
               (if stride_fixed !uovstride then [] else
                  [Decl ("int", sovs)]))
          @ (choose_simd []
               (if stride_fixed !uivstride then [] else
                  [Decl ("int", sivs)]))
          @ (if window_param mode then [Decl (C.constrealtypep, window)]
             else [])
         ),
         finalize_fcn (Asch annot))

  in
  (unparse tree) ^ "\n"
249
250
(* Entry point: parse the command line, then print the generated codelet
   for the requested size and mode on standard output. *)
let main () =
  parse speclist usage;
  let n = check_size () in
  print_string (generate n !mode)

(* [let ()] (rather than [let _]) makes the compiler check that main
   returns unit, so a partial application cannot be silently discarded. *)
let () = main ()
258