/*
   Copyright (c) 2009-2014, Jack Poulson
   All rights reserved.

   This file is part of Elemental and is under the BSD 2-Clause License,
   which can be found in the LICENSE file in the root directory, or at
   http://opensource.org/licenses/BSD-2-Clause
*/
9 // NOTE: It is possible to simply include "elemental.hpp" instead
10 #include "elemental-lite.hpp"
11 #include ELEM_HEMM_INC
12 #include ELEM_TRMM_INC
13 #include ELEM_TWOSIDEDTRMM_INC
14 #include ELEM_FROBENIUSNORM_INC
15 #include ELEM_INFINITYNORM_INC
16 #include ELEM_ONENORM_INC
17 #include ELEM_HERMITIANUNIFORMSPECTRUM_INC
18 using namespace std;
19 using namespace elem;
20
21 template<typename F>
TestCorrectness(bool print,UpperOrLower uplo,UnitOrNonUnit diag,const DistMatrix<F> & A,const DistMatrix<F> & B,const DistMatrix<F> & AOrig)22 void TestCorrectness
23 ( bool print, UpperOrLower uplo, UnitOrNonUnit diag,
24 const DistMatrix<F>& A, const DistMatrix<F>& B, const DistMatrix<F>& AOrig )
25 {
26 typedef Base<F> Real;
27 const Grid& g = A.Grid();
28 const Int m = AOrig.Height();
29
30 const Int k=100;
31 DistMatrix<F> X(g), Y(g), Z(g);
32 Uniform( X, m, k );
33 Y = X;
34 Zeros( Z, m, k );
35
36 if( uplo == LOWER )
37 {
38 // Test correctness by comparing the application of A against a
39 // random set of k vectors to the application of
40 // tril(B)^H AOrig tril(B)
41 Trmm( LEFT, LOWER, NORMAL, diag, F(1), B, Y );
42 Hemm( LEFT, LOWER, F(1), AOrig, Y, F(0), Z );
43 Trmm( LEFT, LOWER, ADJOINT, diag, F(1), B, Z );
44 Hemm( LEFT, LOWER, F(-1), A, X, F(1), Z );
45 Real infNormOfAOrig = HermitianInfinityNorm( uplo, AOrig );
46 Real frobNormOfAOrig = HermitianFrobeniusNorm( uplo, AOrig );
47 Real infNormOfA = HermitianInfinityNorm( uplo, A );
48 Real frobNormOfA = HermitianFrobeniusNorm( uplo, A );
49 Real oneNormOfError = OneNorm( Z );
50 Real infNormOfError = InfinityNorm( Z );
51 Real frobNormOfError = FrobeniusNorm( Z );
52 if( g.Rank() == 0 )
53 {
54 cout << "||AOrig||_1 = ||AOrig||_oo = "
55 << infNormOfAOrig << "\n"
56 << "||AOrig||_F = "
57 << frobNormOfAOrig << "\n"
58 << "||A||_1 = ||A||_oo = "
59 << infNormOfA << "\n"
60 << "||A||_F = "
61 << frobNormOfA << "\n"
62 << "||A X - L^H AOrig L X||_1 = "
63 << oneNormOfError << "\n"
64 << "||A X - L^H AOrig L X||_oo = "
65 << infNormOfError << "\n"
66 << "||A X - L^H AOrig L X||_F = "
67 << frobNormOfError << endl;
68 }
69 }
70 else
71 {
72 // Test correctness by comparing the application of A against a
73 // random set of k vectors to the application of
74 // triu(B) AOrig triu(B)^H
75 Trmm( LEFT, UPPER, ADJOINT, diag, F(1), B, Y );
76 Hemm( LEFT, UPPER, F(1), AOrig, Y, F(0), Z );
77 Trmm( LEFT, UPPER, NORMAL, diag, F(1), B, Z );
78 Hemm( LEFT, UPPER, F(-1), A, X, F(1), Z );
79 Real infNormOfAOrig = HermitianInfinityNorm( uplo, AOrig );
80 Real frobNormOfAOrig = HermitianFrobeniusNorm( uplo, AOrig );
81 Real infNormOfA = HermitianInfinityNorm( uplo, A );
82 Real frobNormOfA = HermitianFrobeniusNorm( uplo, A );
83 Real oneNormOfError = OneNorm( Z );
84 Real infNormOfError = InfinityNorm( Z );
85 Real frobNormOfError = FrobeniusNorm( Z );
86 if( g.Rank() == 0 )
87 {
88 cout << "||AOrig||_1 = ||AOrig||_oo = "
89 << infNormOfAOrig << "\n"
90 << "||AOrig||_F = "
91 << frobNormOfAOrig << "\n"
92 << "||A||_1 = ||A||_oo = "
93 << infNormOfA << "\n"
94 << "||A||_F = "
95 << frobNormOfA << "\n"
96 << "||A X - U AOrig U^H X||_1 = "
97 << oneNormOfError << "\n"
98 << "||A X - U AOrig U^H X||_oo = "
99 << infNormOfError << "\n"
100 << "||A X - U AOrig U^H X||_F = "
101 << frobNormOfError << endl;
102 }
103 }
104 }
105
106 template<typename F>
TestTwoSidedTrmm(bool testCorrectness,bool print,UpperOrLower uplo,UnitOrNonUnit diag,Int m,const Grid & g)107 void TestTwoSidedTrmm
108 ( bool testCorrectness, bool print, UpperOrLower uplo, UnitOrNonUnit diag,
109 Int m, const Grid& g )
110 {
111 DistMatrix<F> A(g), B(g), AOrig(g);
112
113 Zeros( A, m, m );
114 Zeros( B, m, m );
115 MakeHermitianUniformSpectrum( A, 1, 10 );
116 MakeHermitianUniformSpectrum( B, 1, 10 );
117 MakeTriangular( uplo, B );
118 if( testCorrectness )
119 {
120 if( g.Rank() == 0 )
121 {
122 cout << " Making copy of original matrix...";
123 cout.flush();
124 }
125 AOrig = A;
126 if( g.Rank() == 0 )
127 cout << "DONE" << endl;
128 }
129 if( print )
130 {
131 Print( A, "A" );
132 Print( B, "B" );
133 }
134
135 if( g.Rank() == 0 )
136 {
137 cout << " Starting reduction to Hermitian standard EVP...";
138 cout.flush();
139 }
140 mpi::Barrier( g.Comm() );
141 const double startTime = mpi::Time();
142 TwoSidedTrmm( uplo, diag, A, B );
143 mpi::Barrier( g.Comm() );
144 const double runTime = mpi::Time() - startTime;
145 double gFlops = Pow(double(m),3.)/(runTime*1.e9);
146 if( IsComplex<F>::val )
147 gFlops *= 4.;
148 if( g.Rank() == 0 )
149 {
150 cout << "DONE. " << endl
151 << " Time = " << runTime << " seconds. GFlops = "
152 << gFlops << endl;
153 }
154 if( print )
155 Print( A, "A after reduction" );
156 if( testCorrectness )
157 TestCorrectness( print, uplo, diag, A, B, AOrig );
158 }
159
160 int
main(int argc,char * argv[])161 main( int argc, char* argv[] )
162 {
163 Initialize( argc, argv );
164 mpi::Comm comm = mpi::COMM_WORLD;
165 const Int commRank = mpi::Rank( comm );
166 const Int commSize = mpi::Size( comm );
167
168 try
169 {
170 Int r = Input("--r","height of process grid",0);
171 const bool colMajor = Input("--colMajor","column-major ordering?",true);
172 const char uploChar = Input
173 ("--uplo","lower or upper triangular storage: L/U",'L');
174 const char diagChar = Input("--unit","(non-)unit diagonal: N/U",'N');
175 const Int m = Input("--m","height of matrix",100);
176 const Int nb = Input("--nb","algorithmic blocksize",96);
177 const bool testCorrectness = Input
178 ("--correctness","test correctness?",true);
179 const bool print = Input("--print","print matrices?",false);
180 ProcessInput();
181 PrintInputReport();
182
183 if( r == 0 )
184 r = Grid::FindFactor( commSize );
185 const GridOrder order = ( colMajor ? COLUMN_MAJOR : ROW_MAJOR );
186 const Grid g( comm, r, order );
187 const UpperOrLower uplo = CharToUpperOrLower( uploChar );
188 const UnitOrNonUnit diag = CharToUnitOrNonUnit( diagChar );
189 SetBlocksize( nb );
190
191 ComplainIfDebug();
192 if( commRank == 0 )
193 cout << "Will test TwoSidedTrmm" << uploChar << diagChar << endl;
194
195 if( commRank == 0 )
196 cout << "Testing with doubles:" << endl;
197 TestTwoSidedTrmm<double>( testCorrectness, print, uplo, diag, m, g );
198
199 if( commRank == 0 )
200 cout << "Testing with double-precision complex:" << endl;
201 TestTwoSidedTrmm<Complex<double>>
202 ( testCorrectness, print, uplo, diag, m, g );
203 }
204 catch( exception& e ) { ReportException(e); }
205
206 Finalize();
207 return 0;
208 }
209