(*
 * Copyright (c) 1997-1999 Massachusetts Institute of Technology
 * Copyright (c) 2003, 2007-14 Matteo Frigo
 * Copyright (c) 2003, 2007-14 Massachusetts Institute of Technology
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
 *
 *)


(* This is the part of the generator that actually computes the FFT
   in symbolic form *)

open Complex
open Util

(* Choose a suitable factor of n.

   Candidates are the divisors i of n with i * i <= n.
   First choice: the biggest candidate with gcd i (n / i) = 1.
   Second choice (used when the first yields only 1): the biggest
   candidate, whatever its gcd with the cofactor.

   The result is 1 when no candidate other than 1 exists; the
   dispatcher in dft treats that case as "n is prime". *)
let choose_factor n =
  (* biggest i with i * i <= n, i dividing n, and [accept i]; i = 1 is
     tried first, and 1 is also the value returned if nothing matches *)
  let biggest_divisor accept =
    let rec loop i best =
      if i * i > n then best
      else if n mod i = 0 && accept i then loop (i + 1) i
      else loop (i + 1) best
    in loop 1 1
  in
  let coprime_split = biggest_divisor (fun i -> gcd i (n / i) = 1) in
  if coprime_split > 1 then coprime_split
  else biggest_divisor (fun _ -> true)

(* true iff n is positive and has exactly one bit set *)
let is_power_of_two n = (n > 0) && ((n - 1) land n = 0)

(* DFT of size n written out as explicit sums, with no further
   factoring.  The dispatcher in dft selects it when choose_factor n
   is 1.

   sign is multiplied into every twiddle exponent; input maps an index
   to a symbolic value.  The result is a function from output index to
   symbolic expression.

   Sizes n >= !Magic.rader_min are delegated to dft_rader. *)
let rec dft_prime sign n input =
  (* output i as a sum over j of filter (twiddle) @* input j.
     NOTE(review): the summation range is whatever Util's sigma 0 n
     means, presumably j = 0 .. n - 1 -- confirm. *)
  let sum filter i =
    sigma 0 n (fun j ->
      let coeff = filter (exp n (sign * i * j))
      in coeff @* (input j)) in
  (* plain sums with unmodified twiddles; only used for n = 2 below *)
  let computation_even = array n (sum identity)
  (* outputs i and n - i share two partial sums, one built from the
     real parts of the twiddles and one from Complex.i times their
     imaginary parts; the two outputs differ only in the sign of the
     second.
     NOTE(review): @@ here has to be Util's operator (presumably
     function composition), not Stdlib's application -- confirm. *)
  and computation_odd =
    let sumr = array n (sum real)
    and sumi = array n (sum ((times Complex.i) @@ imag)) in
    array n (fun i ->
      if (i = 0) then
        (* expose some common subexpressions *)
        input 0 @+
        sigma 1 ((n + 1) / 2) (fun j -> input j @+ input (n - j))
      else
        let i' = min i (n - i) in
        if (i < n - i) then
          sumr i' @+ sumi i'
        else
          sumr i' @- sumi i') in
  if (n >= !Magic.rader_min) then
    dft_rader sign n input
  else if (n = 2) then
    computation_even
  else
    computation_odd


(* fft generator for prime n = p using Rader's algorithm for
   turning the fft into a convolution, which then can be
   performed in a variety of ways.

   sign and input are as for dft_prime; the result is a function from
   output index (0 .. p - 1) to symbolic expression. *)
and dft_rader sign p input =
  (* multiplication by inverse_int 2 *)
  let half =
    let one_half = inverse_int 2 in
    times one_half

  (* pointwise product of a and b, with b scaled by inverse_int n *)
  and make_product n a b =
    let scale_factor = inverse_int n in
    array n (fun i -> a i @* (scale_factor @* b i)) in

  (* Contract shared by the two convolution generators below:
     they produce the convolution of a and b, both of length n.
     addtoall is added to all of the elements of the result (it is
     injected into term 0 of the transformed product).  They return a
     (sum, convolution) pair where sum is the sum of the elements
     of a. *)

  (* generates a convolution using ffts. *)
  let gen_convolution_by_fft n a b addtoall =
    let fft_a = dft 1 n a
    and fft_b = dft 1 n b in

    let fft_ab = make_product n fft_a fft_b
    and dc_term i = if (i = 0) then addtoall else zero in

    let fft_ab1 = array n (fun i -> fft_ab i @+ dc_term i)
    and sum = fft_a 0 in
    let conv = dft (-1) n fft_ab1 in
    (sum, conv)

  (* alternate routine for convolution.  Seems to work better for
     small sizes.  I have no idea why.

     a and b are each split into the half-sum (ap, bp) and the
     half-difference (am, bm) of element i and element (n - i) mod n;
     the four cross products are transformed back in two groups and
     added. *)
  and gen_convolution_by_fft_alt n a b addtoall =
    let ap = array n (fun i -> half (a i @+ a ((n - i) mod n)))
    and am = array n (fun i -> half (a i @- a ((n - i) mod n)))
    and bp = array n (fun i -> half (b i @+ b ((n - i) mod n)))
    and bm = array n (fun i -> half (b i @- b ((n - i) mod n)))
    in

    let fft_ap = dft 1 n ap
    and fft_am = dft 1 n am
    and fft_bp = dft 1 n bp
    and fft_bm = dft 1 n bm in

    let fft_abpp = make_product n fft_ap fft_bp
    and fft_abpm = make_product n fft_ap fft_bm
    and fft_abmp = make_product n fft_am fft_bp
    and fft_abmm = make_product n fft_am fft_bm
    and sum = fft_ap 0 @+ fft_am 0
    and dc_term i = if (i = 0) then addtoall else zero in

    let fft_ab1 = array n (fun i -> (fft_abpp i @+ fft_abmm i) @+ dc_term i)
    and fft_ab2 = array n (fun i -> fft_abpm i @+ fft_abmp i) in
    let conv1 = dft (-1) n fft_ab1
    and conv2 = dft (-1) n fft_ab2 in
    let conv = array n (fun i ->
      conv1 i @+ conv2 i) in
    (sum, conv)

  (* the alternate routine is used for p up to
     !Magic.alternate_convolution, the plain one above that *)
  in let gen_convolution =
    if (p <= !Magic.alternate_convolution) then
      gen_convolution_by_fft_alt
    else
      gen_convolution_by_fft

  in
  (* g comes from Util.find_generator; ginv = g^(p-2) mod p.
     Inputs are permuted by powers of g, twiddles and outputs by
     powers of ginv. *)
  let g = find_generator p in
  let ginv = pow_mod g (p - 2) p in
  let input_perm = array p (fun i -> input (pow_mod g i p))
  and omega_perm = array p (fun i -> exp p (sign * (pow_mod ginv i p)))
  and output_perm = array p (fun i -> pow_mod ginv i p)
  (* length p - 1 convolution, with input 0 added to every element *)
  in let (sum, conv) =
    (gen_convolution (p - 1) input_perm omega_perm (input 0))
  in array p (fun i ->
    if (i = 0) then
      input 0 @+ sum
    else
      (* undo the output permutation: look up the i' that maps to i.
         NOTE(review): presumably suchthat 0 searches upward from 0
         -- confirm against Util. *)
      let i' = suchthat 0 (fun i' -> i = output_perm i')
      in conv i')

(* our modified version of the conjugate-pair split-radix algorithm,
   which reduces the number of multiplications by rescaling the
   sub-transforms (power-of-two n's only) *)
and newsplit sign n input =
  let rec s n k = (* recursive scale factor *)
    if n <= 4 then
      one
    else
      let k4 = (abs k) mod (n / 4) in
      let k4' = if k4 <= (n / 8) then k4 else (n/4 - k4) in
      (s (n / 4) k4') @* (real (exp n k4'))

  and sinv n k = (* 1 / s(n,k) *)
    if n <= 4 then
      one
    else
      let k4 = (abs k) mod (n / 4) in
      let k4' = if k4 <= (n / 8) then k4 else (n/4 - k4) in
      (sinv (n / 4) k4') @* (sec n k4')

  in let sdiv2 n k = (s n k) @* (sinv (2*n) k) (* s(n,k) / s(2*n,k) *)
  and sdiv4 n k = (* s(n,k) / s(4*n,k) *)
    let k4 = (abs k) mod n in
    sec (4*n) (if k4 <= (n / 2) then k4 else (n - k4))

  (* rescaled twiddle used by the S, S2 and S4 variants *)
  in let t n k = (exp n k) @* (sdiv4 (n/4) k)

  (* base cases: size 1 is the identity, size 2 is written out *)
  and dft1 input = input
  and dft2 input = array 2 (fun k -> (input 0) @+ ((input 1) @* exp 2 k))

  (* The four mutually recursive variants below share one shape: u is
     the half-size transform of the even inputs, z and z' are
     quarter-size transforms of inputs 4i + 1 and 4i - 1 (mod n).
     They differ only in the scaling:
       newsplit0  - twiddles are s (n/4) k times exp n (sign * k);
       newsplitS  - twiddles are t n (sign * k);
       newsplitS2 - as newsplitS, with ww further multiplied by sdiv2 n k;
       newsplitS4 - as newsplitS, with the whole output multiplied by
                    sdiv4 n k (and its size-2 base case by sinv 8 k). *)
  in let rec newsplit0 sign n input =
    if (n = 1) then dft1 input
    else if (n = 2) then dft2 input
    else let u = newsplit0 sign (n / 2) (fun i -> input (i*2))
    and z = newsplitS sign (n / 4) (fun i -> input (i*4 + 1))
    and z' = newsplitS sign (n / 4) (fun i -> input ((n + i*4 - 1) mod n))
    and twid = array n (fun k -> s (n/4) k @* exp n (sign * k)) in
    let w = array n (fun k -> twid k @* z (k mod (n / 4)))
    and w' = array n (fun k -> conj (twid k) @* z' (k mod (n / 4))) in
    let ww = array n (fun k -> w k @+ w' k) in
    array n (fun k -> u (k mod (n / 2)) @+ ww k)

  and newsplitS sign n input =
    if (n = 1) then dft1 input
    else if (n = 2) then dft2 input
    else let u = newsplitS2 sign (n / 2) (fun i -> input (i*2))
    and z = newsplitS sign (n / 4) (fun i -> input (i*4 + 1))
    and z' = newsplitS sign (n / 4) (fun i -> input ((n + i*4 - 1) mod n)) in
    let w = array n (fun k -> t n (sign * k) @* z (k mod (n / 4)))
    and w' = array n (fun k -> conj (t n (sign * k)) @* z' (k mod (n / 4))) in
    let ww = array n (fun k -> w k @+ w' k) in
    array n (fun k -> u (k mod (n / 2)) @+ ww k)

  and newsplitS2 sign n input =
    if (n = 1) then dft1 input
    else if (n = 2) then dft2 input
    else let u = newsplitS4 sign (n / 2) (fun i -> input (i*2))
    and z = newsplitS sign (n / 4) (fun i -> input (i*4 + 1))
    and z' = newsplitS sign (n / 4) (fun i -> input ((n + i*4 - 1) mod n)) in
    let w = array n (fun k -> t n (sign * k) @* z (k mod (n / 4)))
    and w' = array n (fun k -> conj (t n (sign * k)) @* z' (k mod (n / 4))) in
    let ww = array n (fun k -> (w k @+ w' k) @* (sdiv2 n k)) in
    array n (fun k -> u (k mod (n / 2)) @+ ww k)

  and newsplitS4 sign n input =
    if (n = 1) then dft1 input
    else if (n = 2) then
      let f = dft2 input
      in array 2 (fun k -> (f k) @* (sinv 8 k))
    else let u = newsplitS2 sign (n / 2) (fun i -> input (i*2))
    and z = newsplitS sign (n / 4) (fun i -> input (i*4 + 1))
    and z' = newsplitS sign (n / 4) (fun i -> input ((n + i*4 - 1) mod n)) in
    let w = array n (fun k -> t n (sign * k) @* z (k mod (n / 4)))
    and w' = array n (fun k -> conj (t n (sign * k)) @* z' (k mod (n / 4))) in
    let ww = array n (fun k -> w k @+ w' k) in
    array n (fun k -> (u (k mod (n / 2)) @+ ww k) @* (sdiv4 n k))

  in newsplit0 sign n input

(* Entry point: symbolic DFT of size n.

   sign is multiplied into the twiddle exponents (dft_rader calls this
   with 1 and -1); input maps an index to a symbolic value.  The
   result is a function from output index to symbolic expression.

   The decomposition is chosen by [algorithm] at the bottom, from n,
   choose_factor n and the Magic switches. *)
and dft sign n input =
  (* generic two-factor decomposition n = n1 * n2 with twiddle
     multiplications between the two passes; none of these local
     strategies refers to another, so the binding is not recursive
     (they recurse only through the enclosing dft) *)
  let cooley_tukey sign n1 n2 input =
    let tmp1 =
      array n2 (fun i2 ->
        dft sign n1 (fun i1 -> input (i1 * n2 + i2))) in
    let tmp2 =
      array n1 (fun i1 ->
        array n2 (fun i2 ->
          exp n (sign * i1 * i2) @* tmp1 i2 i1)) in
    let tmp3 = array n1 (fun i1 -> dft sign n2 (tmp2 i1)) in
    (fun i -> tmp3 (i mod n1) (i / n1))

  (*
   * This is "exponent -1" split-radix by Dan Bernstein.
   *)
  and split_radix_dit sign n input =
    let f0 = dft sign (n / 2) (fun i -> input (i * 2))
    and f10 = dft sign (n / 4) (fun i -> input (i * 4 + 1))
    and f11 = dft sign (n / 4) (fun i -> input ((n + i * 4 - 1) mod n)) in
    let g10 = array n (fun k ->
      exp n (sign * k) @* f10 (k mod (n / 4)))
    and g11 = array n (fun k ->
      exp n (- sign * k) @* f11 (k mod (n / 4))) in
    let g1 = array n (fun k -> g10 k @+ g11 k) in
    array n (fun k -> f0 (k mod (n / 2)) @+ g1 k)

  (* split-radix with the butterflies applied to the inputs first:
     even outputs come from one half-size transform (f0), odd outputs
     from quarter-size transforms indexed by k mod 4 (f1) *)
  and split_radix_dif sign n input =
    let n2 = n / 2 and n4 = n / 4 in
    let x0 = array n2 (fun i -> input i @+ input (i + n2))
    and x10 = array n4 (fun i -> input i @- input (i + n2))
    and x11 = array n4 (fun i ->
      input (i + n4) @- input (i + n2 + n4)) in
    let x1 k i =
      exp n (k * i * sign) @* (x10 i @+ exp 4 (k * sign) @* x11 i) in
    let f0 = dft sign n2 x0
    and f1 = array 4 (fun k -> dft sign n4 (x1 k)) in
    array n (fun k ->
      if k mod 2 = 0 then f0 (k / 2)
      else let k' = k mod 4 in f1 k' ((k - k') / 4))

  (* two-factor decomposition without twiddle multiplications; the
     dispatcher only selects it when gcd n1 n2 = 1 *)
  and prime_factor sign n1 n2 input =
    let tmp1 = array n2 (fun i2 ->
      dft sign n1 (fun i1 -> input ((i1 * n2 + i2 * n1) mod n)))
    in let tmp2 = array n1 (fun i1 ->
      dft sign n2 (fun k2 -> tmp1 k2 i1))
    in fun i -> tmp2 (i mod n1) (i mod n2)

  (* strategy selection, first match wins *)
  in let algorithm sign n =
    let r = choose_factor n in
    if List.mem n !Magic.rader_list then
      (* special cases *)
      dft_rader sign n
    else if (r = 1) then  (* n is prime *)
      dft_prime sign n
    else if (gcd r (n / r)) = 1 then
      prime_factor sign r (n / r)
    else if (n mod 4 = 0 && n > 4) then
      if !Magic.newsplit && is_power_of_two n then
        newsplit sign n
      else if !Magic.dif_split_radix then
        split_radix_dif sign n
      else
        split_radix_dit sign n
    else
      cooley_tukey sign r (n / r)
  in
  array n (algorithm sign n input)