1 #ifndef HALIDE_FFT_H
2 #define HALIDE_FFT_H
3 
4 #include <algorithm>
5 #include <limits>
6 #include <string>
7 #include <vector>
8 
9 #include "Halide.h"
10 #include "complex.h"
11 
12 // This is an optional extra description for the details of computing an FFT.
13 struct Fft2dDesc {
14     // Gain to apply to the FFT. This is folded into gains already being applied
15     // to the FFT when possible.
16     Halide::Expr gain = 1.0f;
17 
18     // The following option specifies that a particular vector width should be
19     // used when the vector width can change the results of the FFT.
20     // Some parts of the FFT algorithm use the vector width to change the way
21     // floating point operations are ordered and grouped, which causes the results
22     // to vary with respect to the target architecture. Setting this option forces
23     // such stages to use the specified vector width (independent of the actual
24     // architecture's vector width), which eliminates the architecture specific
25     // behavior.
26     int vector_width = 0;
27 
28     // The following option indicates that the FFT should parallelize within a
29     // single FFT. This only makes sense to use on large FFTs, and generally only
30     // if there is no outer loop around FFTs that can be parallelized.
31     bool parallel = false;
32 
33     // This option will schedule the input to the FFT at the innermost location
34     // that makes sense.
35     bool schedule_input = false;
36 
37     // A name to prepend to the name of the Funcs the FFT defines.
38     std::string name = "";
39 };
40 
// Compute the N0 x N1 2D complex DFT of the first 2 dimensions of a complex
// valued function x. The first 2 dimensions of x should be defined on at least
// [0, N0) and [0, N1) for dimensions 0, 1, respectively. sign = -1 indicates a
// forward FFT, sign = 1 indicates an inverse FFT. There is no normalization of
// the FFT in either direction, i.e. (schematically):
//
//   X = fft2d_c2c(x, N0, N1, -1, target);
//   x = fft2d_c2c(X, N0, N1, 1, target) / (N0 * N1);
//
// target: the Halide::Target to build the FFT for. NOTE(review): presumably
//         consulted for architecture-dependent choices such as the native
//         vector width; confirm against the implementation.
// desc:   optional extra settings (gain, forced vector width, parallelism,
//         input scheduling, Func name prefix); defaults to Fft2dDesc().
ComplexFunc fft2d_c2c(ComplexFunc x, int N0, int N1, int sign,
                      const Halide::Target &target,
                      const Fft2dDesc &desc = Fft2dDesc());
52 
// Compute the N0 x N1 2D complex DFT of the first 2 dimensions of a real valued
// function r. The first 2 dimensions of r should be defined on at least [0, N0)
// and [0, N1) for dimensions 0, 1, respectively. Note that the transform domain
// (the returned ComplexFunc) has dimensions N0 x (N1 / 2 + 1) due to the
// conjugate symmetry of real DFTs. There is no normalization.
//
// target: the Halide::Target to build the FFT for. NOTE(review): presumably
//         consulted for architecture-dependent choices such as the native
//         vector width; confirm against the implementation.
// desc:   optional extra settings (gain, forced vector width, parallelism,
//         input scheduling, Func name prefix); defaults to Fft2dDesc().
ComplexFunc fft2d_r2c(Halide::Func r, int N0, int N1,
                      const Halide::Target &target,
                      const Fft2dDesc &desc = Fft2dDesc());
61 
// Compute the real valued N0 x N1 2D inverse DFT of dimensions 0, 1 of c. Note
// that the transform domain (the input c) has dimensions N0 x (N1 / 2 + 1) due
// to the conjugate symmetry of real DFTs. There is no normalization.
//
// target: the Halide::Target to build the FFT for. NOTE(review): presumably
//         consulted for architecture-dependent choices such as the native
//         vector width; confirm against the implementation.
// desc:   optional extra settings (gain, forced vector width, parallelism,
//         input scheduling, Func name prefix); defaults to Fft2dDesc().
Halide::Func fft2d_c2r(ComplexFunc c, int N0, int N1,
                       const Halide::Target &target,
                       const Fft2dDesc &desc = Fft2dDesc());
68 
69 #endif
70