1 /*
2 Copyright (c) 2009-2014, Jack Poulson
3 All rights reserved.
4
5 This file is part of Elemental and is under the BSD 2-Clause License,
6 which can be found in the LICENSE file in the root directory, or at
7 http://opensource.org/licenses/BSD-2-Clause
8 */
9 #include "elemental-lite.hpp"
10
11 #define ColDist MC
12 #define RowDist MR
13
14 #include "./setup.hpp"
15
16 namespace elem {
17
18 // Public section
19 // ##############
20
21 // Assignment and reconfiguration
22 // ==============================
23
24 template<typename T>
25 DM&
operator =(const DM & A)26 DM::operator=( const DM& A )
27 {
28 DEBUG_ONLY(CallStackEntry cse("DM[U,V] = DM[U,V]"))
29 if( this->Grid() == A.Grid() )
30 A.Translate( *this );
31 else
32 this->CopyFromDifferentGrid( A );
33 return *this;
34 }
35
36 template<typename T>
37 DM&
operator =(const DistMatrix<T,MC,STAR> & A)38 DM::operator=( const DistMatrix<T,MC,STAR>& A )
39 {
40 DEBUG_ONLY(CallStackEntry cse("[MC,MR] = [MC,STAR]"))
41 this->RowFilterFrom( A );
42 return *this;
43 }
44
45 template<typename T>
46 DM&
operator =(const DistMatrix<T,STAR,MR> & A)47 DM::operator=( const DistMatrix<T,STAR,MR>& A )
48 {
49 DEBUG_ONLY(CallStackEntry cse("[MC,MR] = [STAR,MR]"))
50 this->ColFilterFrom( A );
51 return *this;
52 }
53
54 template<typename T>
55 DM&
operator =(const DistMatrix<T,MD,STAR> & A)56 DM::operator=( const DistMatrix<T,MD,STAR>& A )
57 {
58 DEBUG_ONLY(CallStackEntry cse("[MC,MR] = [MD,STAR]"))
59 // TODO: More efficient implementation?
60 DistMatrix<T,STAR,STAR> A_STAR_STAR( A );
61 *this = A_STAR_STAR;
62 return *this;
63 }
64
65 template<typename T>
66 DM&
operator =(const DistMatrix<T,STAR,MD> & A)67 DM::operator=( const DistMatrix<T,STAR,MD>& A )
68 {
69 DEBUG_ONLY(CallStackEntry cse("[MC,MR] = [STAR,MD]"))
70 // TODO: More efficient implementation?
71 DistMatrix<T,STAR,STAR> A_STAR_STAR( A );
72 *this = A_STAR_STAR;
73 return *this;
74 }
75
// [MC,MR] <- [MR,MC].
// Resizes this matrix to A's dimensions and then picks one of three strategies
// based on the shape of A:
//   * A has a single column: pack, Scatter over g.ColComm(), SendRecv over
//     g.VCComm(), Gather over g.RowComm(), unpack;
//   * A has a single row: the mirrored sequence (Scatter over g.RowComm(),
//     SendRecv over g.VRComm(), Gather over g.ColComm());
//   * otherwise: go through two temporary matrices, [VR,STAR] then [VC,STAR]
//     when A.Height() >= A.Width(), else [STAR,VC] then [STAR,VR].
// Debug builds additionally assert that this matrix is not locked and that
// both matrices share a grid. Returns *this.
76 template<typename T>
77 DM&
operator =(const DistMatrix<T,MR,MC> & A)78 DM::operator=( const DistMatrix<T,MR,MC>& A )
79 {
80 DEBUG_ONLY(
81 CallStackEntry cse("[MC,MR] = [MR,MC]");
82 this->AssertNotLocked();
83 this->AssertSameGrid( A.Grid() );
84 )
85 const elem::Grid& g = A.Grid();
86 this->Resize( A.Height(), A.Width() );
// Nothing further is done when this->Participating() is false.
87 if( !this->Participating() )
88 return *this;
89
// Case 1: A is a single column.
90 if( A.Width() == 1 )
91 {
92 const Int r = g.Height();
93 const Int c = g.Width();
94 const Int p = g.Size();
95 const Int myRow = g.Row();
96 const Int myCol = g.Col();
97 const Int rankCM = g.VCRank();
98 const Int rankRM = g.VRRank();
// ownerCol: grid column that ends up holding the result (this matrix's row
// alignment); ownerRow: grid row whose processes pack A's data (A's row
// alignment).
99 const Int ownerCol = this->RowAlign();
100 const Int ownerRow = A.RowAlign();
101 const Int colAlign = this->ColAlign();
102 const Int colAlignA = A.ColAlign();
103 const Int colShift = this->ColShift();
104 const Int colShiftA = A.ColShift();
105
// Every message has the same padded length, derived from the largest local
// height over p processes.
106 const Int height = A.Height();
107 const Int maxLocalHeight = MaxLength(height,p);
108 const Int portionSize = mpi::Pad( maxLocalHeight );
109
// Partner ranks for the SendRecv step, computed from the difference between
// the two shifts taken modulo p; recvRankRM is converted to the equivalent
// VC-communicator rank in recvRankCM.
110 const Int colShiftVC = Shift(rankCM,colAlign,p);
111 const Int colShiftVRA = Shift(rankRM,colAlignA,p);
112 const Int sendRankCM = (rankCM+(p+colShiftVRA-colShiftVC)) % p;
113 const Int recvRankRM = (rankRM+(p+colShiftVC-colShiftVRA)) % p;
114 const Int recvRankCM = (recvRankRM/c)+r*(recvRankRM%c);
115
// Scratch layout: sendBuf is the first c*portionSize entries and recvBuf the
// remaining r*portionSize entries. The names do not track the direction of
// each step: data is packed into recvBuf, scattered into sendBuf, exchanged
// back into recvBuf, and finally gathered into sendBuf.
116 T* buffer = this->auxMemory_.Require( (r+c)*portionSize );
117 T* sendBuf = &buffer[0];
118 T* recvBuf = &buffer[c*portionSize];
119
120 if( myRow == ownerRow )
121 {
122 // Pack
// One portion per k in [0,r): the entries offset, offset+r, offset+2r, ... of
// A's local buffer are copied contiguously into portion k.
123 const T* ABuffer = A.LockedBuffer();
124 ELEM_PARALLEL_FOR
125 for( Int k=0; k<r; ++k )
126 {
127 T* data = &recvBuf[k*portionSize];
128
129 const Int shift = Shift_(myCol+c*k,colAlignA,p);
130 const Int offset = (shift-colShiftA) / c;
131 const Int thisLocalHeight = Length_(height,shift,p);
132
133 for( Int iLoc=0; iLoc<thisLocalHeight; ++iLoc )
134 data[iLoc] = ABuffer[offset+iLoc*r];
135 }
136 }
137
138 // A[VR,STAR] <- A[MR,MC]
139 mpi::Scatter
140 ( recvBuf, portionSize, sendBuf, portionSize, ownerRow, g.ColComm() );
141
142 // A[VC,STAR] <- A[VR,STAR]
143 mpi::SendRecv
144 ( sendBuf, portionSize, sendRankCM,
145 recvBuf, portionSize, recvRankCM, g.VCComm() );
146
147 // A[MC,MR] <- A[VC,STAR]
148 mpi::Gather
149 ( recvBuf, portionSize, sendBuf, portionSize, ownerCol, g.RowComm() );
150
151 if( myCol == ownerCol )
152 {
153 // Unpack
// Portion k of the gathered data is written to local entries offset,
// offset+c, offset+2c, ... of this matrix's buffer.
154 T* thisBuffer = this->Buffer();
155 ELEM_PARALLEL_FOR
156 for( Int k=0; k<c; ++k )
157 {
158 const T* data = &sendBuf[k*portionSize];
159
160 const Int shift = Shift_(myRow+r*k,colAlign,p);
161 const Int offset = (shift-colShift) / r;
162 const Int thisLocalHeight = Length_(height,shift,p);
163
164 for( Int iLoc=0; iLoc<thisLocalHeight; ++iLoc )
165 thisBuffer[offset+iLoc*c] = data[iLoc];
166 }
167 }
168 this->auxMemory_.Release();
169 }
// Case 2: A is a single row. Same structure as the column case with the roles
// of grid rows and columns exchanged; local entries are addressed through the
// leading dimension because consecutive row entries sit one column apart.
170 else if( A.Height() == 1 )
171 {
172 const Int r = g.Height();
173 const Int c = g.Width();
174 const Int p = g.Size();
175 const Int myRow = g.Row();
176 const Int myCol = g.Col();
177 const Int rankCM = g.VCRank();
178 const Int rankRM = g.VRRank();
179 const Int ownerRow = this->ColAlign();
180 const Int ownerCol = A.ColAlign();
181 const Int rowAlign = this->RowAlign();
182 const Int rowAlignA = A.RowAlign();
183 const Int rowShift = this->RowShift();
184 const Int rowShiftA = A.RowShift();
185
186 const Int width = A.Width();
187 const Int maxLocalWidth = MaxLength(width,p);
188 const Int portionSize = mpi::Pad( maxLocalWidth );
189
// Partner ranks for the SendRecv over the VR communicator; recvRankCM is
// converted to the equivalent VR-communicator rank in recvRankRM.
190 const Int rowShiftVR = Shift(rankRM,rowAlign,p);
191 const Int rowShiftVCA = Shift(rankCM,rowAlignA,p);
192 const Int sendRankRM = (rankRM+(p+rowShiftVCA-rowShiftVR)) % p;
193 const Int recvRankCM = (rankCM+(p+rowShiftVR-rowShiftVCA)) % p;
194 const Int recvRankRM = (recvRankCM/r)+c*(recvRankCM%r);
195
// Scratch layout: sendBuf is the first r*portionSize entries and recvBuf the
// remaining c*portionSize entries.
196 T* buffer = this->auxMemory_.Require( (r+c)*portionSize );
197 T* sendBuf = &buffer[0];
198 T* recvBuf = &buffer[r*portionSize];
199
200 if( myCol == ownerCol )
201 {
202 // Pack
// One portion per k in [0,c): local columns offset, offset+c, ... of A are
// copied contiguously into portion k.
203 const T* ABuffer = A.LockedBuffer();
204 const Int ALDim = A.LDim();
205 ELEM_PARALLEL_FOR
206 for( Int k=0; k<c; ++k )
207 {
208 T* data = &recvBuf[k*portionSize];
209
210 const Int shift = Shift_(myRow+r*k,rowAlignA,p);
211 const Int offset = (shift-rowShiftA) / r;
212 const Int thisLocalWidth = Length_(width,shift,p);
213
214 for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
215 data[jLoc] = ABuffer[(offset+jLoc*c)*ALDim];
216 }
217 }
218
219 // A[STAR,VC] <- A[MR,MC]
220 mpi::Scatter
221 ( recvBuf, portionSize, sendBuf, portionSize, ownerCol, g.RowComm() );
222
223 // A[STAR,VR] <- A[STAR,VC]
224 mpi::SendRecv
225 ( sendBuf, portionSize, sendRankRM,
226 recvBuf, portionSize, recvRankRM, g.VRComm() );
227
228 // A[MC,MR] <- A[STAR,VR]
229 mpi::Gather
230 ( recvBuf, portionSize, sendBuf, portionSize, ownerRow, g.ColComm() );
231
232 if( myRow == ownerRow )
233 {
234 // Unpack
// Portion k of the gathered data is written to local columns offset,
// offset+r, ... of this matrix.
235 T* thisBuffer = this->Buffer();
236 const Int thisLDim = this->LDim();
237 ELEM_PARALLEL_FOR
238 for( Int k=0; k<r; ++k )
239 {
240 const T* data = &sendBuf[k*portionSize];
241
242 const Int shift = Shift_(myCol+c*k,rowAlign,p);
243 const Int offset = (shift-rowShift) / c;
244 const Int thisLocalWidth = Length_(width,shift,p);
245
246 for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
247 thisBuffer[(offset+jLoc*r)*thisLDim] = data[jLoc];
248 }
249 }
250
251 this->auxMemory_.Release();
252 }
// Case 3: general matrix. The work is delegated to other assignment operators
// through two heap-allocated temporaries; the first temporary is destroyed as
// soon as the second has been filled.
253 else
254 {
255 if( A.Height() >= A.Width() )
256 {
// [VR,STAR] temporary constructed from A.
257 std::unique_ptr<DistMatrix<T,VR,STAR>> A_VR_STAR
258 ( new DistMatrix<T,VR,STAR>(A) );
259
// [VC,STAR] temporary on grid g, aligned with this matrix via AlignColsWith
// before being assigned from the [VR,STAR] temporary.
260 std::unique_ptr<DistMatrix<T,VC,STAR>> A_VC_STAR
261 ( new DistMatrix<T,VC,STAR>(g) );
262 A_VC_STAR->AlignColsWith(*this);
263 *A_VC_STAR = *A_VR_STAR;
264 delete A_VR_STAR.release(); // lowers memory highwater
265
266 *this = *A_VC_STAR;
267 }
268 else
269 {
// [STAR,VC] temporary constructed from A.
270 std::unique_ptr<DistMatrix<T,STAR,VC>> A_STAR_VC
271 ( new DistMatrix<T,STAR,VC>(A) );
272
// [STAR,VR] temporary on grid g, aligned with this matrix via AlignRowsWith
// before being assigned from the [STAR,VC] temporary.
273 std::unique_ptr<DistMatrix<T,STAR,VR>> A_STAR_VR
274 ( new DistMatrix<T,STAR,VR>(g) );
275 A_STAR_VR->AlignRowsWith(*this);
276 *A_STAR_VR = *A_STAR_VC;
277 delete A_STAR_VC.release(); // lowers memory highwater
278
279 *this = *A_STAR_VR;
// NOTE(review): this Resize has no counterpart in the tall branch above and
// follows an assignment from the same matrix; it looks redundant - confirm.
280 this->Resize( A_STAR_VR->Height(), A_STAR_VR->Width() );
281 }
282 }
283 return *this;
284 }
285
286 template<typename T>
287 DM&
operator =(const DistMatrix<T,MR,STAR> & A)288 DM::operator=( const DistMatrix<T,MR,STAR>& A )
289 {
290 DEBUG_ONLY(CallStackEntry cse("[MC,MR] = [MR,STAR]"))
291 std::unique_ptr<DistMatrix<T,VR,STAR>> A_VR_STAR
292 ( new DistMatrix<T,VR,STAR>(A) );
293 std::unique_ptr<DistMatrix<T,VC,STAR>> A_VC_STAR
294 ( new DistMatrix<T,VC,STAR>(this->Grid()) );
295 A_VC_STAR->AlignColsWith(*this);
296 *A_VC_STAR = *A_VR_STAR;
297 delete A_VR_STAR.release(); // lowers memory highwater
298 *this = *A_VC_STAR;
299 return *this;
300 }
301
302 template<typename T>
303 DM&
operator =(const DistMatrix<T,STAR,MC> & A)304 DM::operator=( const DistMatrix<T,STAR,MC>& A )
305 {
306 DEBUG_ONLY(CallStackEntry cse("[MC,MR] = [STAR,MC]"))
307 std::unique_ptr<DistMatrix<T,STAR,VC>> A_STAR_VC
308 ( new DistMatrix<T,STAR,VC>(A) );
309 std::unique_ptr<DistMatrix<T,STAR,VR>> A_STAR_VR
310 ( new DistMatrix<T,STAR,VR>(this->Grid()) );
311 A_STAR_VR->AlignRowsWith(*this);
312 *A_STAR_VR = *A_STAR_VC;
313 delete A_STAR_VC.release(); // lowers memory highwater
314 *this = *A_STAR_VR;
315 return *this;
316 }
317
318 template<typename T>
319 DM&
operator =(const DistMatrix<T,VC,STAR> & A)320 DM::operator=( const DistMatrix<T,VC,STAR>& A )
321 {
322 DEBUG_ONLY(CallStackEntry cse("[MC,MR] = [VC,STAR]"))
323 A.PartialColAllToAll( *this );
324 return *this;
325 }
326
327 template<typename T>
328 DM&
operator =(const DistMatrix<T,STAR,VC> & A)329 DM::operator=( const DistMatrix<T,STAR,VC>& A )
330 {
331 DEBUG_ONLY(CallStackEntry cse("[MC,MR] = [STAR,VC]"))
332 DistMatrix<T,STAR,VR> A_STAR_VR(this->Grid());
333 A_STAR_VR.AlignRowsWith(*this);
334 A_STAR_VR = A;
335 *this = A_STAR_VR;
336 return *this;
337 }
338
339 template<typename T>
340 DM&
operator =(const DistMatrix<T,VR,STAR> & A)341 DM::operator=( const DistMatrix<T,VR,STAR>& A )
342 {
343 DEBUG_ONLY(CallStackEntry cse("[MC,MR] = [VR,STAR]"))
344 DistMatrix<T,VC,STAR> A_VC_STAR(this->Grid());
345 A_VC_STAR.AlignColsWith(*this);
346 A_VC_STAR = A;
347 *this = A_VC_STAR;
348 return *this;
349 }
350
351 template<typename T>
352 DM&
operator =(const DistMatrix<T,STAR,VR> & A)353 DM::operator=( const DistMatrix<T,STAR,VR>& A )
354 {
355 DEBUG_ONLY(CallStackEntry cse("[MC,MR] = [STAR,VR]"))
356 A.PartialRowAllToAll( *this );
357 return *this;
358 }
359
360 template<typename T>
361 DM&
operator =(const DistMatrix<T,STAR,STAR> & A)362 DM::operator=( const DistMatrix<T,STAR,STAR>& A )
363 {
364 DEBUG_ONLY(CallStackEntry cse("[MC,MR] = [STAR,STAR]"))
365 this->FilterFrom( A );
366 return *this;
367 }
368
// [MC,MR] <- [CIRC,CIRC].
// The process for which A.Participating() is true packs one package per
// process of the grid and scatters the packages over g.VCComm() with root
// A.Root(); every process for which this->Participating() is true then copies
// its received package into local storage. Debug builds additionally assert
// that this matrix is not locked and that both matrices share a grid.
// Returns *this.
// NOTE(review): presumably A's entries are stored only on A.Root() - confirm.
369 template<typename T>
370 DM&
operator =(const DistMatrix<T,CIRC,CIRC> & A)371 DM::operator=( const DistMatrix<T,CIRC,CIRC>& A )
372 {
373 DEBUG_ONLY(
374 CallStackEntry cse("[MC,MR] = [CIRC,CIRC]");
375 this->AssertNotLocked();
376 this->AssertSameGrid( A.Grid() );
377 )
378 const Grid& g = A.Grid();
379 const Int m = A.Height();
380 const Int n = A.Width();
381 const Int colStride = this->ColStride();
382 const Int rowStride = this->RowStride();
383 const Int p = g.Size();
384 this->Resize( m, n );
385
// The local dimensions are read after the Resize above.
386 const Int colAlign = this->ColAlign();
387 const Int rowAlign = this->RowAlign();
388 const Int mLocal = this->LocalHeight();
389 const Int nLocal = this->LocalWidth();
// Every package has the same padded size: the largest local height times the
// largest local width. The sender needs p packages, each receiver one.
390 const Int pkgSize = mpi::Pad(MaxLength(m,colStride)*MaxLength(n,rowStride));
391 const Int recvSize = pkgSize;
392 const Int sendSize = p*pkgSize;
393 T* recvBuf=0; // some compilers (falsely) warn otherwise
394 if( A.Participating() )
395 {
// Scratch layout on the sender: p packages to send, followed by the one
// package it receives.
396 T* buffer = this->auxMemory_.Require( sendSize + recvSize );
397 T* sendBuf = &buffer[0];
398 recvBuf = &buffer[sendSize];
399
400 // Pack the send buffer
// For each pair (s,t), the entries A(i,j) with i = s + iLoc*colStride and
// j = t + jLoc*rowStride go to package q = row + col*colStride, where
// row = (colAlign+s) % colStride and col = (rowAlign+t) % rowStride. Within a
// package the entries are stored column-major with leading dimension
// sLocalHeight.
401 const Int ALDim = A.LDim();
402 const T* ABuffer = A.LockedBuffer();
403 for( Int t=0; t<rowStride; ++t )
404 {
405 const Int tLocalWidth = Length( n, t, rowStride );
406 const Int col = (rowAlign+t) % rowStride;
407 for( Int s=0; s<colStride; ++s )
408 {
409 const Int sLocalHeight = Length( m, s, colStride );
410 const Int row = (colAlign+s) % colStride;
411 const Int q = row + col*colStride;
412 for( Int jLoc=0; jLoc<tLocalWidth; ++jLoc )
413 {
414 const Int j = t + jLoc*rowStride;
415 for( Int iLoc=0; iLoc<sLocalHeight; ++iLoc )
416 {
417 const Int i = s + iLoc*colStride;
418 sendBuf[q*pkgSize+iLoc+jLoc*sLocalHeight] =
419 ABuffer[i+j*ALDim];
420 }
421 }
422 }
423 }
424
425 // Scatter from the root
426 mpi::Scatter
427 ( sendBuf, pkgSize, recvBuf, pkgSize, A.Root(), g.VCComm() );
428 }
429 else if( this->Participating() )
430 {
431 recvBuf = this->auxMemory_.Require( recvSize );
432
433 // Perform the receiving portion of the scatter from the non-root
// A null send buffer is passed because this process is not the root.
434 mpi::Scatter
435 ( static_cast<T*>(0), pkgSize,
436 recvBuf, pkgSize, A.Root(), g.VCComm() );
437 }
438
439 if( this->Participating() )
440 {
441 // Unpack
// The received package is column-major with leading dimension mLocal; it is
// copied into local storage, whose leading dimension is ldim. The scratch
// memory is released only on this path.
442 const Int ldim = this->LDim();
443 T* buffer = this->Buffer();
444 for( Int jLoc=0; jLoc<nLocal; ++jLoc )
445 for( Int iLoc=0; iLoc<mLocal; ++iLoc )
446 buffer[iLoc+jLoc*ldim] = recvBuf[iLoc+jLoc*mLocal];
447 this->auxMemory_.Release();
448 }
449
450 return *this;
451 }
452
453 // Basic queries
454 // =============
455
456 template<typename T>
DistComm() const457 mpi::Comm DM::DistComm() const { return this->grid_->VCComm(); }
458 template<typename T>
CrossComm() const459 mpi::Comm DM::CrossComm() const { return mpi::COMM_SELF; }
460 template<typename T>
RedundantComm() const461 mpi::Comm DM::RedundantComm() const { return mpi::COMM_SELF; }
462 template<typename T>
ColComm() const463 mpi::Comm DM::ColComm() const { return this->grid_->MCComm(); }
464 template<typename T>
RowComm() const465 mpi::Comm DM::RowComm() const { return this->grid_->MRComm(); }
466
467 template<typename T>
ColStride() const468 Int DM::ColStride() const { return this->grid_->MCSize(); }
469 template<typename T>
RowStride() const470 Int DM::RowStride() const { return this->grid_->MRSize(); }
471 template<typename T>
DistSize() const472 Int DM::DistSize() const { return this->grid_->VCSize(); }
473 template<typename T>
CrossSize() const474 Int DM::CrossSize() const { return 1; }
475 template<typename T>
RedundantSize() const476 Int DM::RedundantSize() const { return 1; }
477
478 // Private section
479 // ###############
480
// Copies A into this matrix when the two matrices use different grids (the
// same-type assignment operator calls this when the grids compare unequal).
// This matrix is first resized to A's dimensions. Processes participating in
// A then issue one non-blocking send per (colSend,rowSend) round, while
// processes participating in this matrix post blocking receives for that
// round and unpack them. All point-to-point traffic uses
// this->Grid().ViewingComm(). Processes participating in neither matrix
// return right after the resize.
481 template<typename T>
CopyFromDifferentGrid(const DM & A)482 void DM::CopyFromDifferentGrid( const DM& A )
483 {
484 DEBUG_ONLY(CallStackEntry cse("[MC,MR]::CopyFromDifferentGrid"))
485 this->Resize( A.Height(), A.Width() );
486 // Just need to ensure that each viewing comm contains the other team's
487 // owning comm. Congruence is too strong.
488
489 // Compute the number of process rows and columns that each process
490 // needs to send to.
491 const Int colStride = this->ColStride();
492 const Int rowStride = this->RowStride();
493 const Int colRank = this->ColRank();
494 const Int rowRank = this->RowRank();
495 const Int colStrideA = A.ColStride();
496 const Int rowStrideA = A.RowStride();
497 const Int colRankA = A.ColRank();
498 const Int rowRankA = A.RowRank();
499 const Int colGCD = GCD( colStride, colStrideA );
500 const Int rowGCD = GCD( rowStride, rowStrideA );
501 const Int colLCM = colStride*colStrideA / colGCD;
502 const Int rowLCM = rowStride*rowStrideA / rowGCD;
503 const Int numColSends = colStride / colGCD;
504 const Int numRowSends = rowStride / rowGCD;
// Step between consecutive local rows/columns belonging to one message: on
// the receiving side it is LCM/stride, on the sending side it equals the
// number of sends.
505 const Int localColStride = colLCM / colStride;
506 const Int localRowStride = rowLCM / rowStride;
507 const Int localColStrideA = numColSends;
508 const Int localRowStrideA = numRowSends;
509
510 const Int colAlign = this->ColAlign();
511 const Int rowAlign = this->RowAlign();
512 const Int colAlignA = A.ColAlign();
513 const Int rowAlignA = A.RowAlign();
514
515 const bool inThisGrid = this->Participating();
516 const bool inAGrid = A.Participating();
517 if( !inThisGrid && !inAGrid )
518 return;
519
// Size used for each of the send and receive scratch regions below.
520 const Int maxSendSize =
521 (A.Height()/(colStrideA*localColStrideA)+1) *
522 (A.Width()/(rowStrideA*localRowStrideA)+1);
523
524 // Translate the ranks from A's VC communicator to this's viewing so that
525 // we can match send/recv communicators. Since A's VC communicator is not
526 // necessarily defined on every process, we instead work with A's owning
527 // group and account for row-major ordering if necessary.
528 const int sizeA = A.Grid().Size();
529 std::vector<int> rankMap(sizeA), ranks(sizeA);
530 if( A.Grid().Order() == COLUMN_MAJOR )
531 {
532 for( int j=0; j<sizeA; ++j )
533 ranks[j] = j;
534 }
535 else
536 {
537 // The (i,j) = i + j*colStrideA rank in the column-major ordering is
538 // equal to the j + i*rowStrideA rank in a row-major ordering.
539 // Since we desire rankMap[i+j*colStrideA] to correspond to process
540 // (i,j) in A's grid's rank in this viewing group, ranks[i+j*colStrideA]
541 // should correspond to process (i,j) in A's owning group. Since the
542 // owning group is ordered row-major in this case, its rank is
543 // j+i*rowStrideA. Note that setting
544 // ranks[j+i*rowStrideA] = i+j*colStrideA is *NOT* valid.
545 for( int i=0; i<colStrideA; ++i )
546 for( int j=0; j<rowStrideA; ++j )
547 ranks[i+j*colStrideA] = j+i*rowStrideA;
548 }
549 mpi::Translate
550 ( A.Grid().OwningGroup(), sizeA, &ranks[0],
551 this->Grid().ViewingComm(), &rankMap[0] );
552
553 // Have each member of A's grid individually send to all numRow x numCol
554 // processes in order, while the members of this grid receive from all
555 // necessary processes at each step.
// Scratch layout: a send region of maxSendSize entries when this process is
// in A's grid, followed by a receive region of maxSendSize entries when it is
// in this grid. When the process is not in A's grid, sendBuf and recvBuf both
// point at offset 0 and only recvBuf is used.
556 Int requiredMemory = 0;
557 if( inAGrid )
558 requiredMemory += maxSendSize;
559 if( inThisGrid )
560 requiredMemory += maxSendSize;
561 T* auxBuf = this->auxMemory_.Require( requiredMemory );
562 Int offset = 0;
563 T* sendBuf = &auxBuf[offset];
564 if( inAGrid )
565 offset += maxSendSize;
566 T* recvBuf = &auxBuf[offset];
567
// recvRow / recvCol: grid row and column (in this matrix's grid) of the
// destination of the current send. They are advanced by colStrideA /
// rowStrideA, modulo this grid's strides, after each round.
568 Int recvRow = 0; // avoid compiler warnings...
569 if( inAGrid )
570 recvRow = (((colRankA+colStrideA-colAlignA)%colStrideA)+colAlign) %
571 colStride;
572 for( Int colSend=0; colSend<numColSends; ++colSend )
573 {
574 Int recvCol = 0; // avoid compiler warnings...
575 if( inAGrid )
576 recvCol = (((rowRankA+rowStrideA-rowAlignA)%rowStrideA)+rowAlign) %
577 rowStride;
578 for( Int rowSend=0; rowSend<numRowSends; ++rowSend )
579 {
580 mpi::Request sendRequest;
581 // Fire off this round of non-blocking sends
582 if( inAGrid )
583 {
584 // Pack the data
// Local rows colSend, colSend+localColStrideA, ... and local columns rowSend,
// rowSend+localRowStrideA, ... of A are packed column-major with leading
// dimension sendHeight.
585 Int sendHeight = Length(A.LocalHeight(),colSend,numColSends);
586 Int sendWidth = Length(A.LocalWidth(),rowSend,numRowSends);
587 const T* ABuffer = A.LockedBuffer();
588 const Int ALDim = A.LDim();
589 ELEM_PARALLEL_FOR
590 for( Int jLoc=0; jLoc<sendWidth; ++jLoc )
591 {
592 const Int j = rowSend+jLoc*localRowStrideA;
593 for( Int iLoc=0; iLoc<sendHeight; ++iLoc )
594 {
595 const Int i = colSend+iLoc*localColStrideA;
596 sendBuf[iLoc+jLoc*sendHeight] = ABuffer[i+j*ALDim];
597 }
598 }
599 // Send data
// The destination's VC rank in this grid is mapped to its rank in the viewing
// communicator before sending.
600 const Int recvVCRank = recvRow + recvCol*colStride;
601 const Int recvViewingRank =
602 this->Grid().VCToViewingMap( recvVCRank );
603 mpi::ISend
604 ( sendBuf, sendHeight*sendWidth, recvViewingRank,
605 this->Grid().ViewingComm(), sendRequest );
606 }
607 // Perform this round of recv's
608 if( inThisGrid )
609 {
610 const Int sendColOffset = (colSend*colStrideA+colAlignA) % colStrideA;
611 const Int recvColOffset = (colSend*colStrideA+colAlign) % colStride;
612 const Int sendRowOffset = (rowSend*rowStrideA+rowAlignA) % rowStrideA;
613 const Int recvRowOffset = (rowSend*rowStrideA+rowAlign) % rowStride;
614
// First sending row/column of A's grid for this process in this round.
615 const Int firstSendRow = (((colRank+colStride-recvColOffset)%colStride)+sendColOffset)%colStrideA;
616 const Int firstSendCol = (((rowRank+rowStride-recvRowOffset)%rowStride)+sendRowOffset)%rowStrideA;
617
// Number of messages this process receives in this round along each
// dimension.
618 const Int colShift = (colRank+colStride-recvColOffset)%colStride;
619 const Int rowShift = (rowRank+rowStride-recvRowOffset)%rowStride;
620 const Int numColRecvs = Length( colStrideA, colShift, colStride );
621 const Int numRowRecvs = Length( rowStrideA, rowShift, rowStride );
622
623 // Recv data
624 // For now, simply receive sequentially. Until we switch to
625 // nonblocking recv's, we won't be using much of the
626 // recvBuf
627 Int sendRow = firstSendRow;
628 for( Int colRecv=0; colRecv<numColRecvs; ++colRecv )
629 {
630 const Int sendColShift = Shift( sendRow, colAlignA, colStrideA ) + colSend*colStrideA;
631 const Int sendHeight = Length( A.Height(), sendColShift, colLCM );
632 const Int localColOffset = (sendColShift-this->ColShift()) / colStride;
633
634 Int sendCol = firstSendCol;
635 for( Int rowRecv=0; rowRecv<numRowRecvs; ++rowRecv )
636 {
637 const Int sendRowShift = Shift( sendCol, rowAlignA, rowStrideA ) + rowSend*rowStrideA;
638 const Int sendWidth = Length( A.Width(), sendRowShift, rowLCM );
639 const Int localRowOffset = (sendRowShift-this->RowShift()) / rowStride;
640
// rankMap converts the sender's column-major rank in A's grid into its rank
// in this grid's viewing communicator.
641 const Int sendVCRank = sendRow+sendCol*colStrideA;
642 mpi::Recv
643 ( recvBuf, sendHeight*sendWidth, rankMap[sendVCRank],
644 this->Grid().ViewingComm() );
645
646 // Unpack the data
// The message is column-major with leading dimension sendHeight; its entries
// land at local rows localColOffset + iLoc*localColStride and local columns
// localRowOffset + jLoc*localRowStride.
647 T* buffer = this->Buffer();
648 const Int ldim = this->LDim();
649 ELEM_PARALLEL_FOR
650 for( Int jLoc=0; jLoc<sendWidth; ++jLoc )
651 {
652 const Int j = localRowOffset+jLoc*localRowStride;
653 for( Int iLoc=0; iLoc<sendHeight; ++iLoc )
654 {
655 const Int i = localColOffset+iLoc*localColStride;
656 buffer[i+j*ldim] = recvBuf[iLoc+jLoc*sendHeight];
657 }
658 }
659 // Set up the next send col
660 sendCol = (sendCol + rowStride) % rowStrideA;
661 }
662 // Set up the next send row
663 sendRow = (sendRow + colStride) % colStrideA;
664 }
665 }
666 // Ensure that this round of non-blocking sends completes
// sendBuf is repacked in the next round, so the wait happens before the loop
// continues.
667 if( inAGrid )
668 {
669 mpi::Wait( sendRequest );
670 recvCol = (recvCol + rowStrideA) % rowStride;
671 }
672 }
673 if( inAGrid )
674 recvRow = (recvRow + colStrideA) % colStride;
675 }
676 this->auxMemory_.Release();
677 }
678
679 // Instantiate {Int,Real,Complex<Real>} for each Real in {float,double}
680 // ####################################################################
681
// PROTO(T): explicit instantiation of the whole class template
// DistMatrix<T,ColDist,RowDist>.
682 #define PROTO(T) template class DistMatrix<T,ColDist,RowDist>
// SELF(T,U,V): explicit instantiation of the constructor taking a
// DistMatrix<T,U,V>. The expansion already ends in a semicolon.
683 #define SELF(T,U,V) \
684 template DistMatrix<T,ColDist,RowDist>::DistMatrix \
685 ( const DistMatrix<T,U,V>& A );
// OTHER(T,U,V): explicit instantiation of the constructor and of the
// assignment operator taking a BlockDistMatrix<T,U,V>. The expansion does not
// end in a semicolon; the caller supplies it.
686 #define OTHER(T,U,V) \
687 template DistMatrix<T,ColDist,RowDist>::DistMatrix \
688 ( const BlockDistMatrix<T,U,V>& A ); \
689 template DistMatrix<T,ColDist,RowDist>& \
690 DistMatrix<T,ColDist,RowDist>::operator= \
691 ( const BlockDistMatrix<T,U,V>& A )
// BOTH(T,U,V): SELF followed by OTHER for the same distribution pair.
692 #define BOTH(T,U,V) \
693 SELF(T,U,V); \
694 OTHER(T,U,V)
// FULL(T): instantiates the class for element type T together with the
// conversions from every listed distribution pair. The (MC,MR) pair, which is
// this file's own ColDist/RowDist, only gets OTHER.
// NOTE(review): presumably SELF is omitted for (MC,MR) because the
// same-distribution constructor is not a member template - confirm.
695 #define FULL(T) \
696 PROTO(T); \
697 BOTH( T,CIRC,CIRC); \
698 OTHER(T,MC, MR ); \
699 BOTH( T,MC, STAR); \
700 BOTH( T,MD, STAR); \
701 BOTH( T,MR, MC ); \
702 BOTH( T,MR, STAR); \
703 BOTH( T,STAR,MC ); \
704 BOTH( T,STAR,MD ); \
705 BOTH( T,STAR,MR ); \
706 BOTH( T,STAR,STAR); \
707 BOTH( T,STAR,VC ); \
708 BOTH( T,STAR,VR ); \
709 BOTH( T,VC, STAR); \
710 BOTH( T,VR, STAR);
711
// Int and double are always instantiated; float is skipped when
// ELEM_DISABLE_FLOAT is defined.
712 FULL(Int);
713 #ifndef ELEM_DISABLE_FLOAT
714 FULL(float);
715 #endif
716 FULL(double);
717
// Complex element types are skipped entirely when ELEM_DISABLE_COMPLEX is
// defined, and Complex<float> additionally when ELEM_DISABLE_FLOAT is defined.
718 #ifndef ELEM_DISABLE_COMPLEX
719 #ifndef ELEM_DISABLE_FLOAT
720 FULL(Complex<float>);
721 #endif
722 FULL(Complex<double>);
723 #endif
724
725 } // namespace elem
726