1 /*
2    Copyright (c) The University of Texas at Austin, 2013.
3    Copyright (c) Jack Poulson, 2013.
4 
5    Authors: Martin Schatz (primary) and Jack Poulson (maintenance)
6 
7    This file is part of Elemental and is under the BSD 2-Clause License,
8    which can be found in the LICENSE file in the root directory, or at
9    http://opensource.org/licenses/BSD-2-Clause
10 */
11 #include <cstdio>
12 #include "elemental.hpp"
13 using namespace elem;
14 
15 // Initialize auxiliary communicators for depth dimension
InitDepthComms(int meshSize,mpi::Comm & depthComm,mpi::Comm & meshComm)16 void InitDepthComms( int meshSize, mpi::Comm& depthComm, mpi::Comm& meshComm )
17 {
18     const int rank = mpi::Rank( mpi::COMM_WORLD );
19 
20     // Build this process's meshComm (2d grid)
21     const int depthRank = rank / meshSize;
22     const int depthColor = rank % meshSize;
23     mpi::Split( mpi::COMM_WORLD, depthColor, depthRank, depthComm );
24 
25     // Build this process's depthComm (depth communicator)
26     const int meshRank = rank % meshSize;
27     const int meshColor = rank / meshSize;
28     mpi::Split( mpi::COMM_WORLD, meshColor, meshRank, meshComm );
29 }
30 
31 // Have the top layer initialize the distributed matrix, A
InitA(DistMatrix<double> & A,bool print)32 void InitA( DistMatrix<double>& A, bool print )
33 {
34     const int rank = mpi::Rank(mpi::COMM_WORLD);
35     const Grid& g = A.Grid();
36     const int meshSize = g.Size();
37     const int depthRank = rank / meshSize;
38 
39     if( depthRank == 0 )
40     {
41         MakeIdentity( A );
42         Scale( 10.0, A );
43         if( print )
44             Print( A, "A" );
45     }
46 }
47 
48 // Have the top layer initialize the distributed matrix, B
InitB(DistMatrix<double> & B,bool print)49 void InitB( DistMatrix<double>& B, bool print )
50 {
51     const int rank = mpi::Rank(mpi::COMM_WORLD);
52     const Grid& g = B.Grid();
53     const int meshSize = g.Size();
54     const int depthRank = rank / meshSize;
55 
56     if( depthRank == 0 )
57     {
58         if( B.LocalHeight() != B.LDim() )
59             throw std::logic_error("Ldim of B was too large");
60 
61         double* localBuffer = B.Buffer();
62         const int localSize = B.LocalHeight()*B.LocalWidth();
63         for( int iLocal=0; iLocal<localSize; ++iLocal )
64             localBuffer[iLocal] = iLocal*meshSize + rank;
65 
66         if( print )
67             Print( B, "B" );
68     }
69 }
70 
71 // Have the top layer initialize the distributed matrix, C
InitC(DistMatrix<double> & C,bool print)72 void InitC( DistMatrix<double>& C, bool print )
73 {
74     const int rank = mpi::Rank(mpi::COMM_WORLD);
75     const Grid& g = C.Grid();
76     const int meshSize = g.Size();
77     const int depthRank = rank / meshSize;
78 
79     if( depthRank == 0 )
80         MakeZeros( C );
81 }
82 
83 // Create a new set of distributed matrices, so that,
84 //    if depthRank == 0, B = A,
85 //    otherwise,         B = 0.
CopyOrReset(const DistMatrix<double> & A,DistMatrix<double> & B)86 void CopyOrReset
87 ( const DistMatrix<double>& A, DistMatrix<double>& B )
88 {
89     const int rank = mpi::Rank( mpi::COMM_WORLD );
90     const Grid& meshGrid = A.Grid();
91     const int meshSize = meshGrid.Size();
92     const int depthRank = rank / meshSize;
93 
94     //Layer 0
95     if( depthRank == 0 )
96         B = A;
97     else
98     {
99         B.AlignWith( A );
100         Zeros( B, A.Height(), A.Width() );
101     }
102 }
103 
104 // Broadcast a matrix from the root grid to the others
DepthBroadcast(const mpi::Comm & depthComm,const DistMatrix<double> & A,DistMatrix<double> & B)105 void DepthBroadcast
106 ( const mpi::Comm& depthComm,
107   const DistMatrix<double>& A, DistMatrix<double>& B )
108 {
109     const int rank = mpi::Rank(mpi::COMM_WORLD);
110     const Grid& meshGrid = A.Grid();
111     const int meshSize = meshGrid.Size();
112     const int depthRank = rank / meshSize;
113 
114     const int localSize = A.LocalHeight()*A.LocalWidth();
115     if( A.LocalHeight() != A.LDim() )
116         throw std::logic_error("Leading dimension did not match local height");
117 
118     B.Empty();
119     B.AlignWith( A );
120     B.Resize( A.Height(), A.Width() );
121 
122     // Have the root pack the broadcast data
123     if( depthRank == 0 )
124         MemCopy( B.Buffer(), A.LockedBuffer(), localSize );
125 
126     // Broadcast from the root
127     mpi::Broadcast( B.Buffer(), localSize, 0, depthComm );
128 }
129 
130 /*
131  * Distributes A in such a way that
132  *   Layer 0 <- A(:, 0:(n/h - 1))
133  *   Layer 1 <- A(:, (n/h):(2n/h - 1))
134  *     .
135  *     .
136  *     .
137  *   Layer h-1 <- A(:, ((h-1)n/h):n)
138  */
DistributeCols(const mpi::Comm & depthComm,const DistMatrix<double> & A,DistMatrix<double> & B)139 void DistributeCols
140 ( const mpi::Comm& depthComm,
141   const DistMatrix<double>& A, DistMatrix<double>& B )
142 {
143     const int depthSize = mpi::Size( depthComm );
144     const int depthRank = mpi::Rank( depthComm );
145 
146     const int sendCount = A.LocalHeight()*A.LocalWidth();
147     const int recvCount = sendCount / depthSize;
148 
149     // For now, we will make B as large as A...
150     // TODO: NOT DO THIS
151     if( A.LocalHeight() != A.LDim() )
152         throw std::logic_error("Local height did not match ldim");
153     B.Empty();
154     B.AlignWith( A );
155     Zeros( B, A.Height(), A.Width() );
156 
157     // Scatter
158     const int localColOffset = (A.LocalWidth()/depthSize)*depthRank;
159     mpi::Scatter
160     ( A.LockedBuffer(), recvCount,
161       B.Buffer(0,localColOffset), recvCount, 0, depthComm );
162 }
163 
164 /*
165  * Distributes A in such a way that
166  *   Layer 0 <- A(0:(m/h - 1), :)
167  *   Layer 1 <- A((m/h):(2m/h - 1), :)
168  *     .
169  *     .
170  *     .
171  *   Layer h-1 <- A(((h-1)m/h):m, :)
172  */
DistributeRows(const mpi::Comm & depthComm,const DistMatrix<double> & A,DistMatrix<double> & B)173 void DistributeRows
174 ( const mpi::Comm& depthComm,
175   const DistMatrix<double>& A, DistMatrix<double>& B )
176 {
177     const int depthRank = mpi::Rank( depthComm );
178     const int depthSize = mpi::Size( depthComm );
179     const Grid& meshGrid = A.Grid();
180 
181     const int sendCount = A.LocalHeight()*A.LocalWidth();
182     const int recvCount = sendCount / depthSize;
183 
184     // Have the root mesh pack the data for scattering
185     std::vector<double> sendBuf;
186     const int blockSize = A.Height() / depthSize;
187     if( depthRank == 0 )
188     {
189         sendBuf.resize( sendCount );
190         MemZero( &sendBuf[0], sendCount ); // TODO: Is this necessary?!?
191 
192         DistMatrix<double>
193             AT(meshGrid), A0(meshGrid),
194             AB(meshGrid), A1(meshGrid),
195                           A2(meshGrid);
196 
197         // Pack rows block by block for each layer
198         LockedPartitionDown
199         ( A, AT,
200              AB, 0 );
201         for( int i=0; i<depthSize; ++i )
202         {
203             LockedRepartitionDown
204             ( AT,  A0,
205              /**/ /**/
206                    A1,
207               AB,  A2, blockSize );
208 
209             const int dataSize = A1.LocalWidth()*A1.LocalHeight();
210             const int offset = i*dataSize;
211 
212             // TODO: Avoid the extra copy...
213             DistMatrix<double> A1Contig( A1 );
214             MemCopy( &sendBuf[offset], A1Contig.LockedBuffer(), dataSize );
215 
216             SlideLockedPartitionDown
217             ( AT,  A0,
218                    A1,
219              /**/ /**/
220               AB,  A2 );
221         }
222     }
223 
224     // Scatter the packed data
225     std::vector<double> recvBuf( recvCount );
226     mpi::Scatter
227     ( &sendBuf[0], recvCount, &recvBuf[0], recvCount, 0, depthComm );
228 
229     // Pad received data by zero
230     DistMatrix<double> dataBlock( meshGrid );
231     dataBlock.Attach
232     ( blockSize, A.Width(), meshGrid, 0, 0,
233       &recvBuf[0], blockSize/meshGrid.Height() );
234 
235     // TODO: We can probably heavily simplify this...
236     //
237     // dataBlock_T <- transpose(dataBlock)
238     // tmp_T <- padWithZeros(dataBlockT)
239     // tmp <- transpose(tmp_T)
240     // Layer x <- M((x*Mm/h):((x+1)*Mm/h - 1), :)
241     DistMatrix<double> dataBlockTrans( meshGrid );
242     Transpose( dataBlock, dataBlockTrans );
243 
244     std::vector<double> newData( sendCount );
245     MemZero( &newData[0], sendCount );
246     const int offset = depthRank*recvCount;
247 
248     MemCopy( &newData[offset], dataBlockTrans.LockedBuffer(), recvCount );
249 
250     DistMatrix<double> tmpTrans( meshGrid );
251     tmpTrans.Attach
252     ( A.Width(), A.Height(), meshGrid, 0, 0,
253       &newData[0], A.Width()/meshGrid.Width() );
254     DistMatrix<double> tmp( meshGrid );
255     Transpose( tmpTrans, tmp );
256 
257     Transpose( tmpTrans, B );
258 }
259 
260 // Initialize all matrices in order to set up for the G3D GEMM
InitializeMatrices(int type,mpi::Comm & depthComm,int m,int n,int k,DistMatrix<double> & AOut,DistMatrix<double> & BOut,DistMatrix<double> & COut,bool print)261 void InitializeMatrices
262 ( int type, mpi::Comm& depthComm,
263   int m, int n, int k,
264   DistMatrix<double>& AOut,
265   DistMatrix<double>& BOut,
266   DistMatrix<double>& COut,
267   bool print )
268 {
269     const Grid& meshGrid = AOut.Grid();
270 
271     DistMatrix<double> A( m, k, meshGrid ),
272                        B( k, n, meshGrid ),
273                        C( m, n, meshGrid );
274 
275     //Initialize top layer with desired matrices
276     InitA( A, print );
277     InitB( B, print );
278     InitC( C, print );
279 
280     //Distribute matrices according to which matrix is stationary
281     switch (type)
282     {
283     case 'A':
284         DepthBroadcast( depthComm, A, AOut );
285         DistributeCols( depthComm, B, BOut);
286         DistributeCols( depthComm, C, COut);
287         break;
288     case 'B':
289         DistributeRows( depthComm, A, AOut);
290         DepthBroadcast( depthComm, B, BOut );
291         DistributeRows( depthComm, C, COut);
292         break;
293     case 'C':
294         DistributeCols( depthComm, A, AOut);
295         DistributeRows( depthComm, B, BOut);
296         CopyOrReset( C, COut );
297         break;
298     default:
299         throw std::logic_error("Unknown stationary type");
300     }
301 }
302 
303 // Reduce across depth to get end result C
SumContributions(mpi::Comm & depthComm,const DistMatrix<double> & APartial,DistMatrix<double> & A)304 void SumContributions
305 ( mpi::Comm& depthComm,
306   const DistMatrix<double>& APartial, DistMatrix<double>& A )
307 {
308     A.Empty();
309     A.AlignWith( APartial );
310     A.Resize( APartial.Height(), APartial.Width() );
311 
312     if( APartial.LocalHeight() != APartial.LDim() )
313         throw std::logic_error
314         ("APartial did not have matching local height/ldim");
315     if( A.LocalHeight() != A.LDim() )
316         throw std::logic_error("A did not have matching local height/ldim");
317 
318     const int dataSize = APartial.LocalHeight()*APartial.LocalWidth();
319     mpi::AllReduce
320     ( APartial.LockedBuffer(), A.Buffer(), dataSize, mpi::SUM, depthComm );
321 }
322 
main(int argc,char * argv[])323 int main( int argc, char* argv[] )
324 {
325     Initialize( argc, argv );
326     mpi::Comm comm = mpi::COMM_WORLD;
327     const int commRank = mpi::Rank( comm );
328 
329     try
330     {
331         const char type = Input("--type","'A', 'B', or 'C' algorithm",'C');
332         const int r = Input<int>("--gridHeight","height of process grid");
333         const int c = Input<int>("--gridWidth","width of process grid");
334         const int depth = Input<int>("--depth","amount of redundancy");
335         const int m = Input("--m","height of result",500);
336         const int n = Input("--n","width of result",500);
337         const int k = Input("--k","inner dimension",500);
338         const bool print = Input("--print","print matrices?",false);
339         ProcessInput();
340         PrintInputReport();
341 
342         // Sanity check on inputs
343         if( m % r != 0 || m % c != 0 || m % depth != 0 ||
344             k % r != 0 || k % c != 0 || k % depth != 0 ||
345             n % r != 0 || n % c != 0 || n % depth != 0 )
346         {
347             if( commRank == 0 )
348                 std::cout << "Dimensions of matrices must be multiples of "
349                              "grid dimensions (for now)" << std::endl;
350             Finalize();
351             return 0;
352         }
353         if( type < 'A' || type > 'C' )
354         {
355             if( commRank == 0 )
356                 std::cout << "Algorithm must be 'A', 'B', or 'C'" << std::endl;
357             Finalize();
358             return 0;
359         }
360 
361         DEBUG_ONLY(
362             if( commRank == 0 )
363             {
364                 std::cout
365                      << "==========================================\n"
366                      << " In debug mode! Performance will be poor! \n"
367                      << "=========================================="
368                      << std::endl;
369             }
370         )
371 
372         mpi::Comm depthComm, meshComm;
373         InitDepthComms( r*c, depthComm, meshComm );
374         const int depthRank = mpi::Rank( depthComm );
375         const Grid meshGrid( meshComm, r, c );
376 
377         DistMatrix<double> A( m, k, meshGrid ),
378                            B( k, n, meshGrid ),
379                            CPartial( m, n, meshGrid ),
380                            C( m, n, meshGrid );
381 
382         InitializeMatrices( type, depthComm, m, n, k, A, B, CPartial, print );
383 
384         // Compute within our mesh
385         mpi::Barrier( comm );
386         const double startTime = mpi::Time();
387         Gemm( NORMAL, NORMAL, 1.0, A, B, 1.0, CPartial );
388         SumContributions( depthComm, CPartial, C );
389         mpi::Barrier( comm );
390         const double stopTime = mpi::Time();
391         if( commRank == 0 )
392             std::cout << "Runtime: " << stopTime-startTime << " seconds"
393                       << std::endl;
394 
395         if( depthRank == 0 && print )
396             Print( C, "C" );
397     }
398     catch( std::exception& e ) { ReportException(e); }
399 
400     Finalize();
401     return 0;
402 }
403