1 /*
2    Copyright (c) 2009-2014, Jack Poulson
3    All rights reserved.
4 
5    This file is part of Elemental and is under the BSD 2-Clause License,
6    which can be found in the LICENSE file in the root directory, or at
7    http://opensource.org/licenses/BSD-2-Clause
8 */
9 #pragma once
10 #ifndef ELEM_FOURIER_HPP
11 #define ELEM_FOURIER_HPP
12 
13 namespace elem {
14 
15 template<typename Real>
16 inline void
MakeFourier(Matrix<Complex<Real>> & A)17 MakeFourier( Matrix<Complex<Real>>& A )
18 {
19     DEBUG_ONLY(CallStackEntry cse("MakeFourier"))
20     const Int m = A.Height();
21     const Int n = A.Width();
22     if( m != n )
23         LogicError("Cannot make a non-square DFT matrix");
24 
25     const Real pi = 4*Atan( Real(1) );
26     const Real nSqrt = Sqrt( Real(n) );
27     for( Int j=0; j<n; ++j )
28     {
29         for( Int i=0; i<m; ++i )
30         {
31             const Real theta = -2*pi*i*j/n;
32             const Real realPart = Cos(theta)/nSqrt;
33             const Real imagPart = Sin(theta)/nSqrt;
34             A.Set( i, j, Complex<Real>(realPart,imagPart) );
35         }
36     }
37 }
38 
39 template<typename Real,Dist U,Dist V>
40 inline void
MakeFourier(DistMatrix<Complex<Real>,U,V> & A)41 MakeFourier( DistMatrix<Complex<Real>,U,V>& A )
42 {
43     DEBUG_ONLY(CallStackEntry cse("MakeFourier"))
44     const Int m = A.Height();
45     const Int n = A.Width();
46     if( m != n )
47         LogicError("Cannot make a non-square DFT matrix");
48 
49     const Real pi = 4*Atan( Real(1) );
50     const Real nSqrt = Sqrt( Real(n) );
51     const Int localHeight = A.LocalHeight();
52     const Int localWidth = A.LocalWidth();
53     for( Int jLoc=0; jLoc<localWidth; ++jLoc )
54     {
55         const Int j = A.GlobalCol(jLoc);
56         for( Int iLoc=0; iLoc<localHeight; ++iLoc )
57         {
58             const Int i = A.GlobalRow(iLoc);
59             const Real theta = -2*pi*i*j/n;
60             const Real realPart = Cos(theta)/nSqrt;
61             const Real imagPart = Sin(theta)/nSqrt;
62             A.SetLocal( iLoc, jLoc, Complex<Real>(realPart,imagPart) );
63         }
64     }
65 }
66 
67 template<typename Real,Dist U,Dist V>
68 inline void
MakeFourier(BlockDistMatrix<Complex<Real>,U,V> & A)69 MakeFourier( BlockDistMatrix<Complex<Real>,U,V>& A )
70 {
71     DEBUG_ONLY(CallStackEntry cse("MakeFourier"))
72     const Int m = A.Height();
73     const Int n = A.Width();
74     if( m != n )
75         LogicError("Cannot make a non-square DFT matrix");
76 
77     const Real pi = 4*Atan( Real(1) );
78     const Real nSqrt = Sqrt( Real(n) );
79     const Int localHeight = A.LocalHeight();
80     const Int localWidth = A.LocalWidth();
81     for( Int jLoc=0; jLoc<localWidth; ++jLoc )
82     {
83         const Int j = A.GlobalCol(jLoc);
84         for( Int iLoc=0; iLoc<localHeight; ++iLoc )
85         {
86             const Int i = A.GlobalRow(iLoc);
87             const Real theta = -2*pi*i*j/n;
88             const Real realPart = Cos(theta)/nSqrt;
89             const Real imagPart = Sin(theta)/nSqrt;
90             A.SetLocal( iLoc, jLoc, Complex<Real>(realPart,imagPart) );
91         }
92     }
93 }
94 
95 template<typename Real>
96 inline void
Fourier(Matrix<Complex<Real>> & A,Int n)97 Fourier( Matrix<Complex<Real>>& A, Int n )
98 {
99     DEBUG_ONLY(CallStackEntry cse("Fourier"))
100     A.Resize( n, n );
101     MakeFourier( A );
102 }
103 
104 template<typename Real,Dist U,Dist V>
105 inline void
Fourier(DistMatrix<Complex<Real>,U,V> & A,Int n)106 Fourier( DistMatrix<Complex<Real>,U,V>& A, Int n )
107 {
108     DEBUG_ONLY(CallStackEntry cse("Fourier"))
109     A.Resize( n, n );
110     MakeFourier( A );
111 }
112 
113 template<typename Real,Dist U,Dist V>
114 inline void
Fourier(BlockDistMatrix<Complex<Real>,U,V> & A,Int n)115 Fourier( BlockDistMatrix<Complex<Real>,U,V>& A, Int n )
116 {
117     DEBUG_ONLY(CallStackEntry cse("Fourier"))
118     A.Resize( n, n );
119     MakeFourier( A );
120 }
121 
122 } // namespace elem
123 
124 #endif // ifndef ELEM_FOURIER_HPP
125