/*
   Copyright (c) 2009-2014, Jack Poulson
   All rights reserved.

   This file is part of Elemental and is under the BSD 2-Clause License,
   which can be found in the LICENSE file in the root directory, or at
   http://opensource.org/licenses/BSD-2-Clause
*/
// NOTE: It is also possible to simply include "elemental.hpp" instead of the
// individual headers below.
10 #include "elemental-lite.hpp"
11 #include ELEM_HEMM_INC
12 #include ELEM_TRMM_INC
13 #include ELEM_TWOSIDEDTRMM_INC
14 #include ELEM_FROBENIUSNORM_INC
15 #include ELEM_INFINITYNORM_INC
16 #include ELEM_ONENORM_INC
17 #include ELEM_HERMITIANUNIFORMSPECTRUM_INC
18 using namespace std;
19 using namespace elem;
20 
21 template<typename F>
TestCorrectness(bool print,UpperOrLower uplo,UnitOrNonUnit diag,const DistMatrix<F> & A,const DistMatrix<F> & B,const DistMatrix<F> & AOrig)22 void TestCorrectness
23 ( bool print, UpperOrLower uplo, UnitOrNonUnit diag,
24   const DistMatrix<F>& A, const DistMatrix<F>& B, const DistMatrix<F>& AOrig )
25 {
26     typedef Base<F> Real;
27     const Grid& g = A.Grid();
28     const Int m = AOrig.Height();
29 
30     const Int k=100;
31     DistMatrix<F> X(g), Y(g), Z(g);
32     Uniform( X, m, k );
33     Y = X;
34     Zeros( Z, m, k );
35 
36     if( uplo == LOWER )
37     {
38         // Test correctness by comparing the application of A against a
39         // random set of k vectors to the application of
40         // tril(B)^H AOrig tril(B)
41         Trmm( LEFT, LOWER, NORMAL, diag, F(1), B, Y );
42         Hemm( LEFT, LOWER, F(1), AOrig, Y, F(0), Z );
43         Trmm( LEFT, LOWER, ADJOINT, diag, F(1), B, Z );
44         Hemm( LEFT, LOWER, F(-1), A, X, F(1), Z );
45         Real infNormOfAOrig = HermitianInfinityNorm( uplo, AOrig );
46         Real frobNormOfAOrig = HermitianFrobeniusNorm( uplo, AOrig );
47         Real infNormOfA = HermitianInfinityNorm( uplo, A );
48         Real frobNormOfA = HermitianFrobeniusNorm( uplo, A );
49         Real oneNormOfError = OneNorm( Z );
50         Real infNormOfError = InfinityNorm( Z );
51         Real frobNormOfError = FrobeniusNorm( Z );
52         if( g.Rank() == 0 )
53         {
54             cout << "||AOrig||_1 = ||AOrig||_oo     = "
55                  << infNormOfAOrig << "\n"
56                  << "||AOrig||_F                    = "
57                  << frobNormOfAOrig << "\n"
58                  << "||A||_1 = ||A||_oo             = "
59                  << infNormOfA << "\n"
60                  << "||A||_F                        = "
61                  << frobNormOfA << "\n"
62                  << "||A X - L^H AOrig L X||_1  = "
63                  << oneNormOfError << "\n"
64                  << "||A X - L^H AOrig L X||_oo = "
65                  << infNormOfError << "\n"
66                  << "||A X - L^H AOrig L X||_F  = "
67                  << frobNormOfError << endl;
68         }
69     }
70     else
71     {
72         // Test correctness by comparing the application of A against a
73         // random set of k vectors to the application of
74         // triu(B) AOrig triu(B)^H
75         Trmm( LEFT, UPPER, ADJOINT, diag, F(1), B, Y );
76         Hemm( LEFT, UPPER, F(1), AOrig, Y, F(0), Z );
77         Trmm( LEFT, UPPER, NORMAL, diag, F(1), B, Z );
78         Hemm( LEFT, UPPER, F(-1), A, X, F(1), Z );
79         Real infNormOfAOrig = HermitianInfinityNorm( uplo, AOrig );
80         Real frobNormOfAOrig = HermitianFrobeniusNorm( uplo, AOrig );
81         Real infNormOfA = HermitianInfinityNorm( uplo, A );
82         Real frobNormOfA = HermitianFrobeniusNorm( uplo, A );
83         Real oneNormOfError = OneNorm( Z );
84         Real infNormOfError = InfinityNorm( Z );
85         Real frobNormOfError = FrobeniusNorm( Z );
86         if( g.Rank() == 0 )
87         {
88             cout << "||AOrig||_1 = ||AOrig||_oo     = "
89                  << infNormOfAOrig << "\n"
90                  << "||AOrig||_F                    = "
91                  << frobNormOfAOrig << "\n"
92                  << "||A||_1 = ||A||_oo             = "
93                  << infNormOfA << "\n"
94                  << "||A||_F                        = "
95                  << frobNormOfA << "\n"
96                  << "||A X - U AOrig U^H X||_1  = "
97                  << oneNormOfError << "\n"
98                  << "||A X - U AOrig U^H X||_oo = "
99                  << infNormOfError << "\n"
100                  << "||A X - U AOrig U^H X||_F  = "
101                  << frobNormOfError << endl;
102         }
103     }
104 }
105 
106 template<typename F>
TestTwoSidedTrmm(bool testCorrectness,bool print,UpperOrLower uplo,UnitOrNonUnit diag,Int m,const Grid & g)107 void TestTwoSidedTrmm
108 ( bool testCorrectness, bool print, UpperOrLower uplo, UnitOrNonUnit diag,
109   Int m, const Grid& g )
110 {
111     DistMatrix<F> A(g), B(g), AOrig(g);
112 
113     Zeros( A, m, m );
114     Zeros( B, m, m );
115     MakeHermitianUniformSpectrum( A, 1, 10 );
116     MakeHermitianUniformSpectrum( B, 1, 10 );
117     MakeTriangular( uplo, B );
118     if( testCorrectness )
119     {
120         if( g.Rank() == 0 )
121         {
122             cout << "  Making copy of original matrix...";
123             cout.flush();
124         }
125         AOrig = A;
126         if( g.Rank() == 0 )
127             cout << "DONE" << endl;
128     }
129     if( print )
130     {
131         Print( A, "A" );
132         Print( B, "B" );
133     }
134 
135     if( g.Rank() == 0 )
136     {
137         cout << "  Starting reduction to Hermitian standard EVP...";
138         cout.flush();
139     }
140     mpi::Barrier( g.Comm() );
141     const double startTime = mpi::Time();
142     TwoSidedTrmm( uplo, diag, A, B );
143     mpi::Barrier( g.Comm() );
144     const double runTime = mpi::Time() - startTime;
145     double gFlops = Pow(double(m),3.)/(runTime*1.e9);
146     if( IsComplex<F>::val )
147         gFlops *= 4.;
148     if( g.Rank() == 0 )
149     {
150         cout << "DONE. " << endl
151              << "  Time = " << runTime << " seconds. GFlops = "
152              << gFlops << endl;
153     }
154     if( print )
155         Print( A, "A after reduction" );
156     if( testCorrectness )
157         TestCorrectness( print, uplo, diag, A, B, AOrig );
158 }
159 
160 int
main(int argc,char * argv[])161 main( int argc, char* argv[] )
162 {
163     Initialize( argc, argv );
164     mpi::Comm comm = mpi::COMM_WORLD;
165     const Int commRank = mpi::Rank( comm );
166     const Int commSize = mpi::Size( comm );
167 
168     try
169     {
170         Int r = Input("--r","height of process grid",0);
171         const bool colMajor = Input("--colMajor","column-major ordering?",true);
172         const char uploChar = Input
173             ("--uplo","lower or upper triangular storage: L/U",'L');
174         const char diagChar = Input("--unit","(non-)unit diagonal: N/U",'N');
175         const Int m = Input("--m","height of matrix",100);
176         const Int nb = Input("--nb","algorithmic blocksize",96);
177         const bool testCorrectness = Input
178             ("--correctness","test correctness?",true);
179         const bool print = Input("--print","print matrices?",false);
180         ProcessInput();
181         PrintInputReport();
182 
183         if( r == 0 )
184             r = Grid::FindFactor( commSize );
185         const GridOrder order = ( colMajor ? COLUMN_MAJOR : ROW_MAJOR );
186         const Grid g( comm, r, order );
187         const UpperOrLower uplo = CharToUpperOrLower( uploChar );
188         const UnitOrNonUnit diag = CharToUnitOrNonUnit( diagChar );
189         SetBlocksize( nb );
190 
191         ComplainIfDebug();
192         if( commRank == 0 )
193             cout << "Will test TwoSidedTrmm" << uploChar << diagChar << endl;
194 
195         if( commRank == 0 )
196             cout << "Testing with doubles:" << endl;
197         TestTwoSidedTrmm<double>( testCorrectness, print, uplo, diag, m, g );
198 
199         if( commRank == 0 )
200             cout << "Testing with double-precision complex:" << endl;
201         TestTwoSidedTrmm<Complex<double>>
202         ( testCorrectness, print, uplo, diag, m, g );
203     }
204     catch( exception& e ) { ReportException(e); }
205 
206     Finalize();
207     return 0;
208 }
209