1 /*
2 Copyright (c) 2009-2014, Jack Poulson
3 All rights reserved.
4
5 This file is part of Elemental and is under the BSD 2-Clause License,
6 which can be found in the LICENSE file in the root directory, or at
7 http://opensource.org/licenses/BSD-2-Clause
8 */
9 #pragma once
10 #ifndef ELEM_FOURIER_HPP
11 #define ELEM_FOURIER_HPP
12
13 namespace elem {
14
15 template<typename Real>
16 inline void
MakeFourier(Matrix<Complex<Real>> & A)17 MakeFourier( Matrix<Complex<Real>>& A )
18 {
19 DEBUG_ONLY(CallStackEntry cse("MakeFourier"))
20 const Int m = A.Height();
21 const Int n = A.Width();
22 if( m != n )
23 LogicError("Cannot make a non-square DFT matrix");
24
25 const Real pi = 4*Atan( Real(1) );
26 const Real nSqrt = Sqrt( Real(n) );
27 for( Int j=0; j<n; ++j )
28 {
29 for( Int i=0; i<m; ++i )
30 {
31 const Real theta = -2*pi*i*j/n;
32 const Real realPart = Cos(theta)/nSqrt;
33 const Real imagPart = Sin(theta)/nSqrt;
34 A.Set( i, j, Complex<Real>(realPart,imagPart) );
35 }
36 }
37 }
38
39 template<typename Real,Dist U,Dist V>
40 inline void
MakeFourier(DistMatrix<Complex<Real>,U,V> & A)41 MakeFourier( DistMatrix<Complex<Real>,U,V>& A )
42 {
43 DEBUG_ONLY(CallStackEntry cse("MakeFourier"))
44 const Int m = A.Height();
45 const Int n = A.Width();
46 if( m != n )
47 LogicError("Cannot make a non-square DFT matrix");
48
49 const Real pi = 4*Atan( Real(1) );
50 const Real nSqrt = Sqrt( Real(n) );
51 const Int localHeight = A.LocalHeight();
52 const Int localWidth = A.LocalWidth();
53 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
54 {
55 const Int j = A.GlobalCol(jLoc);
56 for( Int iLoc=0; iLoc<localHeight; ++iLoc )
57 {
58 const Int i = A.GlobalRow(iLoc);
59 const Real theta = -2*pi*i*j/n;
60 const Real realPart = Cos(theta)/nSqrt;
61 const Real imagPart = Sin(theta)/nSqrt;
62 A.SetLocal( iLoc, jLoc, Complex<Real>(realPart,imagPart) );
63 }
64 }
65 }
66
67 template<typename Real,Dist U,Dist V>
68 inline void
MakeFourier(BlockDistMatrix<Complex<Real>,U,V> & A)69 MakeFourier( BlockDistMatrix<Complex<Real>,U,V>& A )
70 {
71 DEBUG_ONLY(CallStackEntry cse("MakeFourier"))
72 const Int m = A.Height();
73 const Int n = A.Width();
74 if( m != n )
75 LogicError("Cannot make a non-square DFT matrix");
76
77 const Real pi = 4*Atan( Real(1) );
78 const Real nSqrt = Sqrt( Real(n) );
79 const Int localHeight = A.LocalHeight();
80 const Int localWidth = A.LocalWidth();
81 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
82 {
83 const Int j = A.GlobalCol(jLoc);
84 for( Int iLoc=0; iLoc<localHeight; ++iLoc )
85 {
86 const Int i = A.GlobalRow(iLoc);
87 const Real theta = -2*pi*i*j/n;
88 const Real realPart = Cos(theta)/nSqrt;
89 const Real imagPart = Sin(theta)/nSqrt;
90 A.SetLocal( iLoc, jLoc, Complex<Real>(realPart,imagPart) );
91 }
92 }
93 }
94
95 template<typename Real>
96 inline void
Fourier(Matrix<Complex<Real>> & A,Int n)97 Fourier( Matrix<Complex<Real>>& A, Int n )
98 {
99 DEBUG_ONLY(CallStackEntry cse("Fourier"))
100 A.Resize( n, n );
101 MakeFourier( A );
102 }
103
104 template<typename Real,Dist U,Dist V>
105 inline void
Fourier(DistMatrix<Complex<Real>,U,V> & A,Int n)106 Fourier( DistMatrix<Complex<Real>,U,V>& A, Int n )
107 {
108 DEBUG_ONLY(CallStackEntry cse("Fourier"))
109 A.Resize( n, n );
110 MakeFourier( A );
111 }
112
113 template<typename Real,Dist U,Dist V>
114 inline void
Fourier(BlockDistMatrix<Complex<Real>,U,V> & A,Int n)115 Fourier( BlockDistMatrix<Complex<Real>,U,V>& A, Int n )
116 {
117 DEBUG_ONLY(CallStackEntry cse("Fourier"))
118 A.Resize( n, n );
119 MakeFourier( A );
120 }
121
122 } // namespace elem
123
124 #endif // ifndef ELEM_FOURIER_HPP
125