1 /*
2 Copyright (c) The University of Texas at Austin, 2013.
3 Copyright (c) Jack Poulson, 2013.
4
5 Authors: Martin Schatz (primary) and Jack Poulson (maintenance)
6
7 This file is part of Elemental and is under the BSD 2-Clause License,
8 which can be found in the LICENSE file in the root directory, or at
9 http://opensource.org/licenses/BSD-2-Clause
10 */
11 #include <cstdio>
12 #include "elemental.hpp"
13 using namespace elem;
14
15 // Initialize auxiliary communicators for depth dimension
InitDepthComms(int meshSize,mpi::Comm & depthComm,mpi::Comm & meshComm)16 void InitDepthComms( int meshSize, mpi::Comm& depthComm, mpi::Comm& meshComm )
17 {
18 const int rank = mpi::Rank( mpi::COMM_WORLD );
19
20 // Build this process's meshComm (2d grid)
21 const int depthRank = rank / meshSize;
22 const int depthColor = rank % meshSize;
23 mpi::Split( mpi::COMM_WORLD, depthColor, depthRank, depthComm );
24
25 // Build this process's depthComm (depth communicator)
26 const int meshRank = rank % meshSize;
27 const int meshColor = rank / meshSize;
28 mpi::Split( mpi::COMM_WORLD, meshColor, meshRank, meshComm );
29 }
30
31 // Have the top layer initialize the distributed matrix, A
InitA(DistMatrix<double> & A,bool print)32 void InitA( DistMatrix<double>& A, bool print )
33 {
34 const int rank = mpi::Rank(mpi::COMM_WORLD);
35 const Grid& g = A.Grid();
36 const int meshSize = g.Size();
37 const int depthRank = rank / meshSize;
38
39 if( depthRank == 0 )
40 {
41 MakeIdentity( A );
42 Scale( 10.0, A );
43 if( print )
44 Print( A, "A" );
45 }
46 }
47
48 // Have the top layer initialize the distributed matrix, B
InitB(DistMatrix<double> & B,bool print)49 void InitB( DistMatrix<double>& B, bool print )
50 {
51 const int rank = mpi::Rank(mpi::COMM_WORLD);
52 const Grid& g = B.Grid();
53 const int meshSize = g.Size();
54 const int depthRank = rank / meshSize;
55
56 if( depthRank == 0 )
57 {
58 if( B.LocalHeight() != B.LDim() )
59 throw std::logic_error("Ldim of B was too large");
60
61 double* localBuffer = B.Buffer();
62 const int localSize = B.LocalHeight()*B.LocalWidth();
63 for( int iLocal=0; iLocal<localSize; ++iLocal )
64 localBuffer[iLocal] = iLocal*meshSize + rank;
65
66 if( print )
67 Print( B, "B" );
68 }
69 }
70
71 // Have the top layer initialize the distributed matrix, C
InitC(DistMatrix<double> & C,bool print)72 void InitC( DistMatrix<double>& C, bool print )
73 {
74 const int rank = mpi::Rank(mpi::COMM_WORLD);
75 const Grid& g = C.Grid();
76 const int meshSize = g.Size();
77 const int depthRank = rank / meshSize;
78
79 if( depthRank == 0 )
80 MakeZeros( C );
81 }
82
83 // Create a new set of distributed matrices, so that,
84 // if depthRank == 0, B = A,
85 // otherwise, B = 0.
CopyOrReset(const DistMatrix<double> & A,DistMatrix<double> & B)86 void CopyOrReset
87 ( const DistMatrix<double>& A, DistMatrix<double>& B )
88 {
89 const int rank = mpi::Rank( mpi::COMM_WORLD );
90 const Grid& meshGrid = A.Grid();
91 const int meshSize = meshGrid.Size();
92 const int depthRank = rank / meshSize;
93
94 //Layer 0
95 if( depthRank == 0 )
96 B = A;
97 else
98 {
99 B.AlignWith( A );
100 Zeros( B, A.Height(), A.Width() );
101 }
102 }
103
104 // Broadcast a matrix from the root grid to the others
DepthBroadcast(const mpi::Comm & depthComm,const DistMatrix<double> & A,DistMatrix<double> & B)105 void DepthBroadcast
106 ( const mpi::Comm& depthComm,
107 const DistMatrix<double>& A, DistMatrix<double>& B )
108 {
109 const int rank = mpi::Rank(mpi::COMM_WORLD);
110 const Grid& meshGrid = A.Grid();
111 const int meshSize = meshGrid.Size();
112 const int depthRank = rank / meshSize;
113
114 const int localSize = A.LocalHeight()*A.LocalWidth();
115 if( A.LocalHeight() != A.LDim() )
116 throw std::logic_error("Leading dimension did not match local height");
117
118 B.Empty();
119 B.AlignWith( A );
120 B.Resize( A.Height(), A.Width() );
121
122 // Have the root pack the broadcast data
123 if( depthRank == 0 )
124 MemCopy( B.Buffer(), A.LockedBuffer(), localSize );
125
126 // Broadcast from the root
127 mpi::Broadcast( B.Buffer(), localSize, 0, depthComm );
128 }
129
130 /*
131 * Distributes A in such a way that
132 * Layer 0 <- A(:, 0:(n/h - 1))
133 * Layer 1 <- A(:, (n/h):(2n/h - 1))
134 * .
135 * .
136 * .
137 * Layer h-1 <- A(:, ((h-1)n/h):n)
138 */
DistributeCols(const mpi::Comm & depthComm,const DistMatrix<double> & A,DistMatrix<double> & B)139 void DistributeCols
140 ( const mpi::Comm& depthComm,
141 const DistMatrix<double>& A, DistMatrix<double>& B )
142 {
143 const int depthSize = mpi::Size( depthComm );
144 const int depthRank = mpi::Rank( depthComm );
145
146 const int sendCount = A.LocalHeight()*A.LocalWidth();
147 const int recvCount = sendCount / depthSize;
148
149 // For now, we will make B as large as A...
150 // TODO: NOT DO THIS
151 if( A.LocalHeight() != A.LDim() )
152 throw std::logic_error("Local height did not match ldim");
153 B.Empty();
154 B.AlignWith( A );
155 Zeros( B, A.Height(), A.Width() );
156
157 // Scatter
158 const int localColOffset = (A.LocalWidth()/depthSize)*depthRank;
159 mpi::Scatter
160 ( A.LockedBuffer(), recvCount,
161 B.Buffer(0,localColOffset), recvCount, 0, depthComm );
162 }
163
164 /*
165 * Distributes A in such a way that
166 * Layer 0 <- A(0:(m/h - 1), :)
167 * Layer 1 <- A((m/h):(2m/h - 1), :)
168 * .
169 * .
170 * .
171 * Layer h-1 <- A(((h-1)m/h):m, :)
172 */
DistributeRows(const mpi::Comm & depthComm,const DistMatrix<double> & A,DistMatrix<double> & B)173 void DistributeRows
174 ( const mpi::Comm& depthComm,
175 const DistMatrix<double>& A, DistMatrix<double>& B )
176 {
177 const int depthRank = mpi::Rank( depthComm );
178 const int depthSize = mpi::Size( depthComm );
179 const Grid& meshGrid = A.Grid();
180
181 const int sendCount = A.LocalHeight()*A.LocalWidth();
182 const int recvCount = sendCount / depthSize;
183
184 // Have the root mesh pack the data for scattering
185 std::vector<double> sendBuf;
186 const int blockSize = A.Height() / depthSize;
187 if( depthRank == 0 )
188 {
189 sendBuf.resize( sendCount );
190 MemZero( &sendBuf[0], sendCount ); // TODO: Is this necessary?!?
191
192 DistMatrix<double>
193 AT(meshGrid), A0(meshGrid),
194 AB(meshGrid), A1(meshGrid),
195 A2(meshGrid);
196
197 // Pack rows block by block for each layer
198 LockedPartitionDown
199 ( A, AT,
200 AB, 0 );
201 for( int i=0; i<depthSize; ++i )
202 {
203 LockedRepartitionDown
204 ( AT, A0,
205 /**/ /**/
206 A1,
207 AB, A2, blockSize );
208
209 const int dataSize = A1.LocalWidth()*A1.LocalHeight();
210 const int offset = i*dataSize;
211
212 // TODO: Avoid the extra copy...
213 DistMatrix<double> A1Contig( A1 );
214 MemCopy( &sendBuf[offset], A1Contig.LockedBuffer(), dataSize );
215
216 SlideLockedPartitionDown
217 ( AT, A0,
218 A1,
219 /**/ /**/
220 AB, A2 );
221 }
222 }
223
224 // Scatter the packed data
225 std::vector<double> recvBuf( recvCount );
226 mpi::Scatter
227 ( &sendBuf[0], recvCount, &recvBuf[0], recvCount, 0, depthComm );
228
229 // Pad received data by zero
230 DistMatrix<double> dataBlock( meshGrid );
231 dataBlock.Attach
232 ( blockSize, A.Width(), meshGrid, 0, 0,
233 &recvBuf[0], blockSize/meshGrid.Height() );
234
235 // TODO: We can probably heavily simplify this...
236 //
237 // dataBlock_T <- transpose(dataBlock)
238 // tmp_T <- padWithZeros(dataBlockT)
239 // tmp <- transpose(tmp_T)
240 // Layer x <- M((x*Mm/h):((x+1)*Mm/h - 1), :)
241 DistMatrix<double> dataBlockTrans( meshGrid );
242 Transpose( dataBlock, dataBlockTrans );
243
244 std::vector<double> newData( sendCount );
245 MemZero( &newData[0], sendCount );
246 const int offset = depthRank*recvCount;
247
248 MemCopy( &newData[offset], dataBlockTrans.LockedBuffer(), recvCount );
249
250 DistMatrix<double> tmpTrans( meshGrid );
251 tmpTrans.Attach
252 ( A.Width(), A.Height(), meshGrid, 0, 0,
253 &newData[0], A.Width()/meshGrid.Width() );
254 DistMatrix<double> tmp( meshGrid );
255 Transpose( tmpTrans, tmp );
256
257 Transpose( tmpTrans, B );
258 }
259
260 // Initialize all matrices in order to set up for the G3D GEMM
InitializeMatrices(int type,mpi::Comm & depthComm,int m,int n,int k,DistMatrix<double> & AOut,DistMatrix<double> & BOut,DistMatrix<double> & COut,bool print)261 void InitializeMatrices
262 ( int type, mpi::Comm& depthComm,
263 int m, int n, int k,
264 DistMatrix<double>& AOut,
265 DistMatrix<double>& BOut,
266 DistMatrix<double>& COut,
267 bool print )
268 {
269 const Grid& meshGrid = AOut.Grid();
270
271 DistMatrix<double> A( m, k, meshGrid ),
272 B( k, n, meshGrid ),
273 C( m, n, meshGrid );
274
275 //Initialize top layer with desired matrices
276 InitA( A, print );
277 InitB( B, print );
278 InitC( C, print );
279
280 //Distribute matrices according to which matrix is stationary
281 switch (type)
282 {
283 case 'A':
284 DepthBroadcast( depthComm, A, AOut );
285 DistributeCols( depthComm, B, BOut);
286 DistributeCols( depthComm, C, COut);
287 break;
288 case 'B':
289 DistributeRows( depthComm, A, AOut);
290 DepthBroadcast( depthComm, B, BOut );
291 DistributeRows( depthComm, C, COut);
292 break;
293 case 'C':
294 DistributeCols( depthComm, A, AOut);
295 DistributeRows( depthComm, B, BOut);
296 CopyOrReset( C, COut );
297 break;
298 default:
299 throw std::logic_error("Unknown stationary type");
300 }
301 }
302
303 // Reduce across depth to get end result C
SumContributions(mpi::Comm & depthComm,const DistMatrix<double> & APartial,DistMatrix<double> & A)304 void SumContributions
305 ( mpi::Comm& depthComm,
306 const DistMatrix<double>& APartial, DistMatrix<double>& A )
307 {
308 A.Empty();
309 A.AlignWith( APartial );
310 A.Resize( APartial.Height(), APartial.Width() );
311
312 if( APartial.LocalHeight() != APartial.LDim() )
313 throw std::logic_error
314 ("APartial did not have matching local height/ldim");
315 if( A.LocalHeight() != A.LDim() )
316 throw std::logic_error("A did not have matching local height/ldim");
317
318 const int dataSize = APartial.LocalHeight()*APartial.LocalWidth();
319 mpi::AllReduce
320 ( APartial.LockedBuffer(), A.Buffer(), dataSize, mpi::SUM, depthComm );
321 }
322
main(int argc,char * argv[])323 int main( int argc, char* argv[] )
324 {
325 Initialize( argc, argv );
326 mpi::Comm comm = mpi::COMM_WORLD;
327 const int commRank = mpi::Rank( comm );
328
329 try
330 {
331 const char type = Input("--type","'A', 'B', or 'C' algorithm",'C');
332 const int r = Input<int>("--gridHeight","height of process grid");
333 const int c = Input<int>("--gridWidth","width of process grid");
334 const int depth = Input<int>("--depth","amount of redundancy");
335 const int m = Input("--m","height of result",500);
336 const int n = Input("--n","width of result",500);
337 const int k = Input("--k","inner dimension",500);
338 const bool print = Input("--print","print matrices?",false);
339 ProcessInput();
340 PrintInputReport();
341
342 // Sanity check on inputs
343 if( m % r != 0 || m % c != 0 || m % depth != 0 ||
344 k % r != 0 || k % c != 0 || k % depth != 0 ||
345 n % r != 0 || n % c != 0 || n % depth != 0 )
346 {
347 if( commRank == 0 )
348 std::cout << "Dimensions of matrices must be multiples of "
349 "grid dimensions (for now)" << std::endl;
350 Finalize();
351 return 0;
352 }
353 if( type < 'A' || type > 'C' )
354 {
355 if( commRank == 0 )
356 std::cout << "Algorithm must be 'A', 'B', or 'C'" << std::endl;
357 Finalize();
358 return 0;
359 }
360
361 DEBUG_ONLY(
362 if( commRank == 0 )
363 {
364 std::cout
365 << "==========================================\n"
366 << " In debug mode! Performance will be poor! \n"
367 << "=========================================="
368 << std::endl;
369 }
370 )
371
372 mpi::Comm depthComm, meshComm;
373 InitDepthComms( r*c, depthComm, meshComm );
374 const int depthRank = mpi::Rank( depthComm );
375 const Grid meshGrid( meshComm, r, c );
376
377 DistMatrix<double> A( m, k, meshGrid ),
378 B( k, n, meshGrid ),
379 CPartial( m, n, meshGrid ),
380 C( m, n, meshGrid );
381
382 InitializeMatrices( type, depthComm, m, n, k, A, B, CPartial, print );
383
384 // Compute within our mesh
385 mpi::Barrier( comm );
386 const double startTime = mpi::Time();
387 Gemm( NORMAL, NORMAL, 1.0, A, B, 1.0, CPartial );
388 SumContributions( depthComm, CPartial, C );
389 mpi::Barrier( comm );
390 const double stopTime = mpi::Time();
391 if( commRank == 0 )
392 std::cout << "Runtime: " << stopTime-startTime << " seconds"
393 << std::endl;
394
395 if( depthRank == 0 && print )
396 Print( C, "C" );
397 }
398 catch( std::exception& e ) { ReportException(e); }
399
400 Finalize();
401 return 0;
402 }
403