1(*
2 * Copyright (c) 1997-1999 Massachusetts Institute of Technology
3 * Copyright (c) 2003, 2007-14 Matteo Frigo
4 * Copyright (c) 2003, 2007-14 Massachusetts Institute of Technology
5 *
6 * This program is free software; you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation; either version 2 of the License, or
9 * (at your option) any later version.
10 *
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with this program; if not, write to the Free Software
18 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
19 *
20 *)
21
22
23(* This is the part of the generator that actually computes the FFT
24   in symbolic form *)
25
26open Complex
27open Util
28
(* Choose a factor [r] of [n] used to split a size-[n] transform into
   [r] x [n / r].  Preference order:
   1. the largest [i] with [i * i <= n], [i] dividing [n] and
      [gcd i (n / i) = 1] (a coprime split);
   2. otherwise the largest divisor [i] of [n] with [i * i <= n].
   Returns 1 when neither search finds a divisor larger than 1, i.e.
   when [n] is prime or [n <= 1]. *)
let choose_factor n =
  (* Largest i with i * i <= n, i | n and gcd(i, n / i) = 1.  i = 1
     always qualifies, so the result is at least 1. *)
  let choose1 n =
    let rec loop i f =
      if i * i > n then f
      else if n mod i = 0 && gcd i (n / i) = 1 then loop (i + 1) i
      else loop (i + 1) f
    in
    loop 1 1
  (* Largest divisor i of n with i * i <= n (note: i = sqrt n is
     accepted); 1 if n has no such divisor above 1. *)
  and choose2 n =
    let rec loop i f =
      if i * i > n then f
      else if n mod i = 0 then loop (i + 1) i
      else loop (i + 1) f
    in
    loop 1 1
  in
  let i = choose1 n in
  if i > 1 then i else choose2 n
50
(* [is_power_of_two n] is true exactly for n = 1, 2, 4, 8, ...:
   non-positive values are rejected up front, and a positive value is a
   power of two iff clearing its lowest set bit ([n land (n - 1)])
   leaves zero. *)
let is_power_of_two n =
  if n <= 0 then false
  else n land (n - 1) = 0
52
(* Symbolic DFT of size [n], where [n] is expected to be prime.
   Output k is sum_j exp n (sign * k * j) * input j (no normalization).
   Sizes of at least [!Magic.rader_min] are delegated to [dft_rader];
   smaller sizes are expanded as explicit sums over all inputs. *)
let rec dft_prime sign n input =
  (* [direct filter k]: sum over j of filter (exp n (sign*k*j)) * input j *)
  let direct filter k =
    sigma 0 n (fun j -> filter (exp n (sign * k * j)) @* input j)
  in
  (* Unfolded sums; selected for n = 2. *)
  let plain = array n (direct identity) in
  (* Folded form: outputs k and n - k reuse the same real-part and
     imaginary-part sums (indexed by min k (n - k)), added for the lower
     index and subtracted for the upper one -- relying on the real part
     of the twiddle being symmetric and the imaginary part antisymmetric
     under k -> n - k. *)
  let folded =
    let re_sum = array n (direct real)
    and im_sum = array n (direct ((times Complex.i) @@ imag)) in
    array n (fun k ->
      if k = 0 then
        (* expose some common subexpressions *)
        input 0 @+ sigma 1 ((n + 1) / 2) (fun j -> input j @+ input (n - j))
      else if k < n - k then re_sum k @+ im_sum k
      else re_sum (n - k) @- im_sum (n - k))
  in
  if n >= !Magic.rader_min then dft_rader sign n input
  else (match n with 2 -> plain | _ -> folded)
79
80
(* Rader's algorithm for a prime size [p].  The nonzero input indices are
   permuted by powers of a generator [g] mod [p].  This turns the DFT into
   a cyclic convolution of length [p - 1], which is then computed with
   DFTs of size [p - 1] via [dft].  [sign] and [input] are as for [dft].
   Assumes [p] is prime: both [find_generator] and the Fermat inverse
   computed below rely on it. *)
and dft_rader sign p input =
  (* [half x]: the symbolic value [x] multiplied by the constant 1/2
     (assuming [times c] scales by the constant [c]). *)
  let half =
    let one_half = inverse_int 2 in
    times one_half

  (* Pointwise product a_i * (b_i / n) of two length-[n] arrays.  The 1/n
     factor compensates for [dft], which does not normalize, so the
     inverse transform applied afterwards yields the true convolution. *)
  and make_product n a b =
    let scale_factor = inverse_int n in
    array n (fun i -> a i @* (scale_factor @* b i)) in

  (* generates a convolution using ffts.  (all arguments are the
     same as to gen_convolution, below)
     Forward DFTs (sign 1) of [a] and [b], a scaled pointwise product,
     and an inverse DFT (sign -1).  [addtoall] is placed in the DC bin
     (index 0) of the product before the inverse transform.  The inverse
     DFT of a lone DC value is that value at every position, so
     [addtoall] ends up added to every output.  [sum] is the DC term of
     DFT(a), i.e. the sum of the elements of [a]. *)
  let gen_convolution_by_fft n a b addtoall =
    let fft_a = dft 1 n a
    and fft_b = dft 1 n b in

    let fft_ab = make_product n fft_a fft_b
    and dc_term i = if (i == 0) then addtoall else zero in

    let fft_ab1 = array n (fun i -> fft_ab i @+ dc_term i)
    and sum = fft_a 0 in
    let conv = dft (-1) n fft_ab1 in
    (sum, conv)

  (* alternate routine for convolution.  Seems to work better for
     small sizes.  I have no idea why.
     Same contract as gen_convolution_by_fft.  [a] and [b] are first split
     into halves that are even (ap, bp) and odd (am, bm) under the index
     reversal i -> (n - i) mod n.  The convolution is then assembled from
     the four cross products: (pp + mm + DC term) and (pm + mp) are
     inverse-transformed separately and added.  [sum] is again the sum of
     the elements of [a], since a = ap + am. *)
  and gen_convolution_by_fft_alt n a b addtoall =
    let ap = array n (fun i -> half (a i @+ a ((n - i) mod n)))
    and am = array n (fun i -> half (a i @- a ((n - i) mod n)))
    and bp = array n (fun i -> half (b i @+ b ((n - i) mod n)))
    and bm = array n (fun i -> half (b i @- b ((n - i) mod n)))
    in

    let fft_ap = dft 1 n ap
    and fft_am = dft 1 n am
    and fft_bp = dft 1 n bp
    and fft_bm = dft 1 n bm in

    let fft_abpp = make_product n fft_ap fft_bp
    and fft_abpm = make_product n fft_ap fft_bm
    and fft_abmp = make_product n fft_am fft_bp
    and fft_abmm = make_product n fft_am fft_bm
    and sum = fft_ap 0 @+ fft_am 0
    and dc_term i = if (i == 0) then addtoall else zero in

    let fft_ab1 = array n (fun i -> (fft_abpp i @+ fft_abmm i) @+ dc_term i)
    and fft_ab2 = array n (fun i -> fft_abpm i @+ fft_abmp i) in
    let conv1 = dft (-1) n fft_ab1
    and conv2 = dft (-1) n fft_ab2 in
    let conv = array n (fun i ->
      conv1 i @+ conv2 i) in
    (sum, conv)

    (* gen_convolution n a b addtoall returns a (sum, conv) pair.  [conv]
       is the cyclic convolution of [a] and [b] (both of length n) with
       [addtoall] added to every element.  [sum] is the sum of the
       elements of [a].  The alternate routine is used when
       p <= !Magic.alternate_convolution. *)

  in let gen_convolution =
    if (p <= !Magic.alternate_convolution) then
      gen_convolution_by_fft_alt
    else
      gen_convolution_by_fft

  (* fft generator for prime n = p using Rader's algorithm for
     turning the fft into a convolution, which then can be
     performed in a variety of ways.
     Assumes [find_generator p] returns a primitive root mod p, and
     [pow_mod b e m] computes b^e mod m.
     - ginv = g^(p-2) = g^(-1) mod p, by Fermat's little theorem.
     - input_perm i = input (g^i mod p); for i = 0 .. p-2 this visits
       every nonzero index exactly once.
     - omega_perm i = exp p (sign * g^(-i) mod p).
     - output_perm i = g^(-i) mod p.
     Convolving input_perm with omega_perm gives, at position m, the
     contribution of all nonzero inputs to output g^(-m).  Output 0 is
     input 0 plus the sum of all other inputs. *)
  in
    let g = find_generator p in
    let ginv = pow_mod g (p - 2) p in
    let input_perm = array p (fun i -> input (pow_mod g i p))
    and omega_perm = array p (fun i -> exp p (sign * (pow_mod ginv i p)))
    and output_perm = array p (fun i -> pow_mod ginv i p)
    in let (sum, conv) =
      (* input 0 is passed as addtoall, so every conv entry includes it *)
      (gen_convolution (p - 1)  input_perm omega_perm (input 0))
    in array p (fun i ->
      if (i = 0) then
	input 0 @+ sum
      else
	(* discrete log by linear search: the i' with g^(-i') = i mod p
	   (assuming [suchthat 0 pred] returns the first i' >= 0 satisfying
	   [pred]) *)
	let i' = suchthat 0 (fun i' -> i = output_perm i')
	in conv i')
161
(* our modified version of the conjugate-pair split-radix algorithm,
   which reduces the number of multiplications by rescaling the
   sub-transforms (power-of-two n's only).

   Four mutually recursive variants are used.  Each returns the DFT of
   its input divided by a k-dependent scale factor built from [s] below:
     newsplit0 n  : X_k                (unscaled; the entry point)
     newsplitS n  : X_k / s(n, k)
     newsplitS2 n : X_k / s(2n, k)
     newsplitS4 n : X_k / s(4n, k)
   These invariants follow from the recursion: the twiddle factors in
   each variant are chosen so the sub-transform scalings cancel.  As a
   result, newsplit0 returns the ordinary (unnormalized) DFT. *)
and newsplit sign n input =
  (* s(n, k) = s(n/4, k') * real (exp n k'), where k' is |k| mod (n/4)
     folded into [0, n/8].  s(n, k) = 1 for n <= 4. *)
  let rec s n k = (* recursive scale factor *)
    if n <= 4 then
      one
    else
      let k4 = (abs k) mod (n / 4) in
      let k4' = if k4 <= (n / 8) then k4 else (n/4 - k4) in
      (s (n / 4) k4') @* (real (exp n k4'))

  (* Same recursion as [s], but multiplies by [sec n k'] instead of
     real (exp n k'); presumably sec is 1 / real (exp n k) -- confirm
     in Complex. *)
  and sinv n k = (* 1 / s(n,k) *)
    if n <= 4 then
      one
    else
      let k4 = (abs k) mod (n / 4) in
      let k4' = if k4 <= (n / 8) then k4 else (n/4 - k4) in
      (sinv (n / 4) k4') @* (sec n k4')

  in let sdiv2 n k = (s n k) @* (sinv (2*n) k) (* s(n,k) / s(2*n,k) *)
  and sdiv4 n k = (* s(n,k) / s(4*n,k) *)
    (* one level of the s recursion for size 4n: k reduced mod n and
       folded into [0, n/2] *)
    let k4 = (abs k) mod n in
    sec (4*n) (if k4 <= (n / 2) then k4 else (n - k4))

  (* Twiddle exp n k merged with the rescaling s(n/4, k) / s(n, k). *)
  in let t n k = (exp n k) @* (sdiv4 (n/4) k)

  (* Base cases.  dft2 does not take [sign]: the size-2 twiddle
     exp 2 k is the same for either sign. *)
  and dft1 input = input
  and dft2 input = array 2 (fun k -> (input 0) @+ ((input 1) @* exp 2 k))

  (* Unscaled size-n transform.  u is the half-size transform of the even
     inputs.  z and z' are quarter-size scaled transforms of the inputs at
     4i+1 and 4i-1 (the conjugate pair).  twid = s(n/4, k) * exp n (sign*k)
     undoes the 1/s(n/4, k) scaling of z and z'. *)
  in let rec newsplit0 sign n input =
    if (n == 1) then dft1 input
    else if (n == 2) then dft2 input
    else let u = newsplit0 sign (n / 2) (fun i -> input (i*2))
    and z = newsplitS sign (n / 4) (fun i -> input (i*4 + 1))
    and z' = newsplitS sign (n / 4) (fun i -> input ((n + i*4 - 1) mod n))
    and twid = array n (fun k -> s (n/4) k @* exp n (sign * k)) in
    let w = array n (fun k -> twid k @* z (k mod (n / 4)))
    and w' = array n (fun k -> conj (twid k) @* z' (k mod (n / 4))) in
    let ww = array n (fun k -> w k @+ w' k) in
    array n (fun k -> u (k mod (n / 2)) @+ ww k)

  (* Size-n transform scaled by 1 / s(n, k).  The even half comes from
     newsplitS2, which is already scaled by 1 / s(n, k).  The t twiddles
     turn the 1 / s(n/4, k) of z and z' into 1 / s(n, k). *)
  and newsplitS sign n input =
    if (n == 1) then dft1 input
    else if (n == 2) then dft2 input
    else let u = newsplitS2 sign (n / 2) (fun i -> input (i*2))
    and z = newsplitS sign (n / 4) (fun i -> input (i*4 + 1))
    and z' = newsplitS sign (n / 4) (fun i -> input ((n + i*4 - 1) mod n)) in
    let w = array n (fun k -> t n (sign * k) @* z (k mod (n / 4)))
    and w' = array n (fun k -> conj (t n (sign * k)) @* z' (k mod (n / 4))) in
    let ww = array n (fun k -> w k @+ w' k) in
    array n (fun k -> u (k mod (n / 2)) @+ ww k)

  (* Size-n transform scaled by 1 / s(2n, k).  The odd-index part is
     rescaled from 1 / s(n, k) to 1 / s(2n, k) by sdiv2. *)
  and newsplitS2 sign n input =
    if (n == 1) then dft1 input
    else if (n == 2) then dft2 input
    else let u = newsplitS4 sign (n / 2) (fun i -> input (i*2))
    and z = newsplitS sign (n / 4) (fun i -> input (i*4 + 1))
    and z' = newsplitS sign (n / 4) (fun i -> input ((n + i*4 - 1) mod n)) in
    let w = array n (fun k -> t n (sign * k) @* z (k mod (n / 4)))
    and w' = array n (fun k -> conj (t n (sign * k)) @* z' (k mod (n / 4))) in
    let ww = array n (fun k -> (w k @+ w' k) @* (sdiv2 n k)) in
    array n (fun k -> u (k mod (n / 2)) @+ ww k)

  (* Size-n transform scaled by 1 / s(4n, k).  The whole 1 / s(n, k)
     result is rescaled by sdiv4.  The n = 2 base case applies
     sinv 8 = 1 / s(4*2, k) directly. *)
  and newsplitS4 sign n input =
    if (n == 1) then dft1 input
    else if (n == 2) then
      let f = dft2 input
      in array 2 (fun k -> (f k) @* (sinv 8 k))
    else let u = newsplitS2 sign (n / 2) (fun i -> input (i*2))
    and z = newsplitS sign (n / 4) (fun i -> input (i*4 + 1))
    and z' = newsplitS sign (n / 4) (fun i -> input ((n + i*4 - 1) mod n)) in
    let w = array n (fun k -> t n (sign * k) @* z (k mod (n / 4)))
    and w' = array n (fun k -> conj (t n (sign * k)) @* z' (k mod (n / 4))) in
    let ww = array n (fun k -> w k @+ w' k) in
    array n (fun k -> (u (k mod (n / 2)) @+ ww k) @* (sdiv4 n k))

  in newsplit0 sign n input
240
(* Symbolic DFT of size [n]: returns [k -> X_k] with
   X_k = sum_j exp n (sign * j * k) * input j (no 1/n normalization).
   Each (sub)transform picks its algorithm via [algorithm] below. *)
and dft sign n input =
  (* Generic Cooley-Tukey step for n = n1 * n2.  First, n2 DFTs of size n1
     over the inputs i1 * n2 + i2.  Then multiply by the twiddles
     exp n (sign * k1 * i2).  Finally, n1 DFTs of size n2.  Output index
     k = k1 + n1 * k2.  [n] here is the enclosing transform size. *)
  let cooley_tukey sign n1 n2 input =
    let tmp1 =
      array n2 (fun i2 ->
        dft sign n1 (fun i1 -> input (i1 * n2 + i2))) in
    let tmp2 =
      array n1 (fun i1 ->
        array n2 (fun i2 ->
          exp n (sign * i1 * i2) @* tmp1 i2 i1)) in
    let tmp3 = array n1 (fun i1 -> dft sign n2 (tmp2 i1)) in
    (fun i -> tmp3 (i mod n1) (i / n1))

  (*
   * This is "exponent -1" split-radix by Dan Bernstein.
   * Decimation in time: one size-n/2 DFT of the even inputs, plus two
   * size-n/4 DFTs of the inputs at 4i+1 and 4i-1, combined with the
   * conjugate twiddles exp n (+-sign * k).  Requires 4 | n.
   *)
  and split_radix_dit sign n input =
    let f0 = dft sign (n / 2) (fun i -> input (i * 2))
    and f10 = dft sign (n / 4) (fun i -> input (i * 4 + 1))
    and f11 = dft sign (n / 4) (fun i -> input ((n + i * 4 - 1) mod n)) in
    let g10 = array n (fun k ->
      exp n (sign * k) @* f10 (k mod (n / 4)))
    and g11 = array n (fun k ->
      exp n (- sign * k) @* f11 (k mod (n / 4))) in
    let g1 = array n (fun k -> g10 k @+ g11 k) in
    array n (fun k -> f0 (k mod (n / 2)) @+ g1 k)

  (* Decimation-in-frequency split-radix.  Even outputs come from one
     size-n/2 DFT of x_i + x_{i+n/2}.  Outputs with k mod 4 = 1 or 3 come
     from a size-n/4 DFT of the twiddled combination x1 (k mod 4).
     Requires 4 | n. *)
  and split_radix_dif sign n input =
    let n2 = n / 2 and n4 = n / 4 in
    let x0 = array n2 (fun i -> input i @+ input (i + n2))
    and x10 = array n4 (fun i -> input i @- input (i + n2))
    and x11 = array n4 (fun i ->
      input (i + n4) @- input (i + n2 + n4)) in
    let x1 k i =
      exp n (k * i * sign) @* (x10 i @+ exp 4 (k * sign) @* x11 i) in
    let f0 = dft sign n2 x0
    and f1 = array 4 (fun k -> dft sign n4 (x1 k)) in
    array n (fun k ->
      if k mod 2 = 0 then f0 (k / 2)
      else let k' = k mod 4 in f1 k' ((k - k') / 4))

  (* Good-Thomas prime-factor algorithm for n = n1 * n2 with
     gcd n1 n2 = 1.  No twiddle factors are needed.  Inputs are read at
     (i1 * n2 + i2 * n1) mod n, and output k comes from the pair
     (k mod n1, k mod n2). *)
  and prime_factor sign n1 n2 input =
    let tmp1 = array n2 (fun i2 ->
      dft sign n1 (fun i1 -> input ((i1 * n2 + i2 * n1) mod n)))
    in let tmp2 = array n1 (fun i1 ->
      dft sign n2 (fun k2 -> tmp1 k2 i1))
    in fun i -> tmp2 (i mod n1) (i mod n2)

  (* Algorithm selection for size n, in priority order:
     - sizes listed in !Magic.rader_list: Rader's algorithm;
     - choose_factor n = 1 (n prime, or n = 1): dft_prime;
     - coprime split r * (n / r): prime_factor;
     - multiples of 4 above 4: newsplit for powers of two when
       !Magic.newsplit is set, otherwise DIF or DIT split radix
       according to !Magic.dif_split_radix;
     - anything else: Cooley-Tukey with factor r. *)
  in let algorithm sign n =
    let r = choose_factor n in
    if List.mem n !Magic.rader_list then
      (* special cases *)
      dft_rader sign n
    else if r = 1 then (* n is prime (or n = 1) *)
      dft_prime sign n
    else if gcd r (n / r) = 1 then
      prime_factor sign r (n / r)
    else if n mod 4 = 0 && n > 4 then
      (if !Magic.newsplit && is_power_of_two n then
         newsplit sign n
       else if !Magic.dif_split_radix then
         split_radix_dif sign n
       else
         split_radix_dit sign n)
    else
      cooley_tukey sign r (n / r)
  in
  array n (algorithm sign n input)
308