1 /*
2    Copyright (c) 2009-2014, Jack Poulson
3    All rights reserved.
4 
5    This file is part of Elemental and is under the BSD 2-Clause License,
6    which can be found in the LICENSE file in the root directory, or at
7    http://opensource.org/licenses/BSD-2-Clause
8 */
9 #include "elemental-lite.hpp"
10 #include ELEM_TRANSPOSE_INC
11 #include ELEM_ZEROS_INC
12 
13 namespace elem {
14 
15 // Public section
16 // ##############
17 
18 // Constructors and destructors
19 // ============================
20 
21 template<typename T,Dist U,Dist V>
GeneralDistMatrix(const elem::Grid & grid,Int root)22 GeneralDistMatrix<T,U,V>::GeneralDistMatrix( const elem::Grid& grid, Int root )
23 : AbstractDistMatrix<T>(grid,root)
24 { }
25 
26 template<typename T,Dist U,Dist V>
GeneralDistMatrix(GeneralDistMatrix<T,U,V> && A)27 GeneralDistMatrix<T,U,V>::GeneralDistMatrix( GeneralDistMatrix<T,U,V>&& A )
28 ELEM_NOEXCEPT
29 : AbstractDistMatrix<T>(std::move(A))
30 { }
31 
32 // Assignment and reconfiguration
33 // ==============================
34 
35 template<typename T,Dist U,Dist V>
36 GeneralDistMatrix<T,U,V>&
operator =(GeneralDistMatrix<T,U,V> && A)37 GeneralDistMatrix<T,U,V>::operator=( GeneralDistMatrix<T,U,V>&& A )
38 {
39     AbstractDistMatrix<T>::operator=( std::move(A) );
40     return *this;
41 }
42 
// Align this matrix's columns with the distribution described by `data`.
//
// The grid and root are always taken from `data`. The column alignment is
// then chosen from whichever of data's two distributions matches U:
//  * a distribution equal to U or UPart supplies its alignment directly;
//  * a distribution equal to UScat supplies its alignment modulo
//    this->ColStride().
// data's column distribution is checked before its row distribution.
// `constrain` is forwarded unchanged to AlignCols.
//
// If nothing matches, the column alignment is left untouched. When the
// DEBUG_ONLY section is compiled in, that case raises
// LogicError("Nonsensical alignment"), unless U or one of data's
// distributions is UGath.
// NOTE(review): UPart/UScat/UGath are presumably the partial, scattered and
// gathered counterparts of U declared by the class -- confirm.
template<typename T,Dist U,Dist V>
void
GeneralDistMatrix<T,U,V>::AlignColsWith
( const elem::DistData& data, bool constrain )
{
    DEBUG_ONLY(CallStackEntry cse("GDM::AlignColsWith"))
    this->SetGrid( *data.grid );
    this->SetRoot( data.root );
    if( data.colDist == U || data.colDist == UPart )
        this->AlignCols( data.colAlign, constrain );
    else if( data.rowDist == U || data.rowDist == UPart )
        this->AlignCols( data.rowAlign, constrain );
    else if( data.colDist == UScat )
        this->AlignCols( data.colAlign % this->ColStride(), constrain );
    else if( data.rowDist == UScat )
        this->AlignCols( data.rowAlign % this->ColStride(), constrain );
    // The macro below appends one more `else if` to the chain above, so it
    // must stay immediately after it. Presumably DEBUG_ONLY expands to
    // nothing in release builds, ending the chain at the previous branch.
    DEBUG_ONLY(
        else if( U != UGath && data.colDist != UGath && data.rowDist != UGath )
            LogicError("Nonsensical alignment");
    )
}
64 
// Align this matrix's rows with the distribution described by `data`.
//
// The grid and root are always taken from `data`. The row alignment is
// then chosen from whichever of data's two distributions matches V:
//  * a distribution equal to V or VPart supplies its alignment directly;
//  * a distribution equal to VScat supplies its alignment modulo
//    this->RowStride().
// data's column distribution is checked before its row distribution.
// `constrain` is forwarded unchanged to AlignRows.
//
// If nothing matches, the row alignment is left untouched. When the
// DEBUG_ONLY section is compiled in, that case raises
// LogicError("Nonsensical alignment"), unless V or one of data's
// distributions is VGath.
// NOTE(review): VPart/VScat/VGath are presumably the partial, scattered and
// gathered counterparts of V declared by the class -- confirm.
template<typename T,Dist U,Dist V>
void
GeneralDistMatrix<T,U,V>::AlignRowsWith
( const elem::DistData& data, bool constrain )
{
    DEBUG_ONLY(CallStackEntry cse("GDM::AlignRowsWith"))
    this->SetGrid( *data.grid );
    this->SetRoot( data.root );
    if( data.colDist == V || data.colDist == VPart )
        this->AlignRows( data.colAlign, constrain );
    else if( data.rowDist == V || data.rowDist == VPart )
        this->AlignRows( data.rowAlign, constrain );
    else if( data.colDist == VScat )
        this->AlignRows( data.colAlign % this->RowStride(), constrain );
    else if( data.rowDist == VScat )
        this->AlignRows( data.rowAlign % this->RowStride(), constrain );
    // The macro below appends one more `else if` to the chain above, so it
    // must stay immediately after it. Presumably DEBUG_ONLY expands to
    // nothing in release builds, ending the chain at the previous branch.
    DEBUG_ONLY(
        else if( V != VGath && data.colDist != VGath && data.rowDist != VGath )
            LogicError("Nonsensical alignment");
    )
}
86 
87 template<typename T,Dist U,Dist V>
88 void
Translate(DistMatrix<T,U,V> & A) const89 GeneralDistMatrix<T,U,V>::Translate( DistMatrix<T,U,V>& A ) const
90 {
91     DEBUG_ONLY(CallStackEntry cse("GDM::Translate"))
92     const Grid& g = this->Grid();
93     const Int height = this->Height();
94     const Int width = this->Width();
95     const Int colAlign = this->ColAlign();
96     const Int rowAlign = this->RowAlign();
97     const Int root = this->Root();
98     A.SetGrid( g );
99     if( !A.RootConstrained() )
100         A.SetRoot( root );
101     if( !A.ColConstrained() )
102         A.AlignCols( colAlign, false );
103     if( !A.RowConstrained() )
104         A.AlignRows( rowAlign, false );
105     A.Resize( height, width );
106     if( !g.InGrid() )
107         return;
108 
109     const bool aligned = colAlign == A.ColAlign() && rowAlign == A.RowAlign();
110     if( aligned && root == A.Root() )
111     {
112         A.matrix_ = this->matrix_;
113     }
114     else
115     {
116 #ifdef ELEM_UNALIGNED_WARNINGS
117         if( g.Rank() == 0 )
118             std::cerr << "Unaligned [U,V] <- [U,V]" << std::endl;
119 #endif
120         const Int colRank = this->ColRank();
121         const Int rowRank = this->RowRank();
122         const Int crossRank = this->CrossRank();
123         const Int colStride = this->ColStride();
124         const Int rowStride = this->RowStride();
125         const Int maxHeight = MaxLength( height, colStride );
126         const Int maxWidth  = MaxLength( width,  rowStride );
127         const Int pkgSize = mpi::Pad( maxHeight*maxWidth );
128         T* buffer=0;
129         if( crossRank == root || crossRank == A.Root() )
130             buffer = A.auxMemory_.Require( pkgSize );
131 
132         const Int colAlignA = A.ColAlign();
133         const Int rowAlignA = A.RowAlign();
134         const Int localHeightA =
135             Length( height, colRank, colAlignA, colStride );
136         const Int localWidthA = Length( width, rowRank, rowAlignA, rowStride );
137         const Int recvSize = mpi::Pad( localHeightA*localWidthA );
138 
139         if( crossRank == root )
140         {
141             // Pack the local data
142             const Int localHeight = this->LocalHeight();
143             const Int localWidth = this->LocalWidth();
144             for( Int jLoc=0; jLoc<localWidth; ++jLoc )
145                 MemCopy
146                 ( &buffer[jLoc*localHeight], this->LockedBuffer(0,jLoc),
147                   localHeight );
148 
149             if( !aligned )
150             {
151                 // If we were not aligned, then SendRecv over the DistComm
152                 const Int toRow = Mod(colRank+colAlignA-colAlign,colStride);
153                 const Int toCol = Mod(rowRank+rowAlignA-rowAlign,rowStride);
154                 const Int fromRow = Mod(colRank+colAlign-colAlignA,colStride);
155                 const Int fromCol = Mod(rowRank+rowAlign-rowAlignA,rowStride);
156                 const Int toRank = toRow + toCol*colStride;
157                 const Int fromRank = fromRow + fromCol*colStride;
158                 mpi::SendRecv
159                 ( buffer, pkgSize, toRank, fromRank, this->DistComm() );
160             }
161         }
162         if( root != A.Root() )
163         {
164             // Send to the correct new root over the cross communicator
165             if( crossRank == root )
166                 mpi::Send( buffer, recvSize, A.Root(), A.CrossComm() );
167             else if( crossRank == A.Root() )
168                 mpi::Recv( buffer, recvSize, root, A.CrossComm() );
169         }
170         // Unpack
171         if( crossRank == A.Root() )
172             for( Int jLoc=0; jLoc<localWidthA; ++jLoc )
173                 MemCopy
174                 ( A.Buffer(0,jLoc), &buffer[jLoc*localHeightA], localHeightA );
175         if( crossRank == root || crossRank == A.Root() )
176             A.auxMemory_.Release();
177     }
178 }
179 
// Gather this [U,V] matrix into A, which is distributed as [UGath,VGath].
//
// A takes this matrix's grid and is resized to height x width. The unpack
// loop below writes each entry into A's local buffer at its global
// (row, column) position, so A's local matrix receives the whole matrix.
//
// Two phases:
//  1. Participating processes contribute their local blocks to an
//     AllGather over the distribution communicator, then scatter the
//     gathered blocks into A.
//  2. A's local data is broadcast from this matrix's root over the cross
//     communicator, so in-grid processes whose cross rank is not the root
//     receive it too. This phase is skipped when the cross communicator is
//     COMM_SELF.
// NOTE(review): "Participating" presumably means in the grid and at this
// matrix's root -- confirm.
template<typename T,Dist U,Dist V>
void
GeneralDistMatrix<T,U,V>::AllGather( DistMatrix<T,UGath,VGath>& A ) const
{
    DEBUG_ONLY(CallStackEntry cse("GDM::AllGather"))
    const Int height = this->Height();
    const Int width = this->Width();
    A.SetGrid( this->Grid() );
    A.Resize( height, width );

    if( this->Participating() )
    {
        const Int colStride = this->ColStride();
        const Int rowStride = this->RowStride();
        const Int distStride = colStride*rowStride;

        const Int thisLocalHeight = this->LocalHeight();
        const Int thisLocalWidth = this->LocalWidth();
        const Int maxLocalHeight = MaxLength(height,colStride);
        const Int maxLocalWidth = MaxLength(width,rowStride);

        // Every process sends a fixed-size portion large enough for the
        // biggest local block. The buffer holds one send portion followed
        // by distStride receive portions.
        const Int portionSize = mpi::Pad( maxLocalHeight*maxLocalWidth );
        T* buffer = A.auxMemory_.Require( (distStride+1)*portionSize );
        T* sendBuf = &buffer[0];
        T* recvBuf = &buffer[portionSize];

        // Pack (column-major, leading dimension = thisLocalHeight)
        const Int ldim = this->LDim();
        const T* thisBuf = this->LockedBuffer();
        ELEM_PARALLEL_FOR
        for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
            MemCopy
            ( &sendBuf[jLoc*thisLocalHeight], &thisBuf[jLoc*ldim],
              thisLocalHeight );

        // Communicate
        mpi::AllGather
        ( sendBuf, portionSize, recvBuf, portionSize, this->DistComm() );

        // Unpack: portion k+l*colStride is treated as the block of the
        // process with column rank k and row rank l. Its entries are spread
        // with strides colStride/rowStride starting at that process's shifts.
        T* ABuf = A.Buffer();
        const Int ALDim = A.LDim();
        const Int colAlign = this->ColAlign();
        const Int rowAlign = this->RowAlign();
        ELEM_OUTER_PARALLEL_FOR
        for( Int l=0; l<rowStride; ++l )
        {
            const Int rowShift = Shift_( l, rowAlign, rowStride );
            const Int localWidth = Length_( width, rowShift, rowStride );
            for( Int k=0; k<colStride; ++k )
            {
                const T* data = &recvBuf[(k+l*colStride)*portionSize];
                const Int colShift = Shift_( k, colAlign, colStride );
                const Int localHeight = Length_( height, colShift, colStride );
                ELEM_INNER_PARALLEL_FOR
                for( Int jLoc=0; jLoc<localWidth; ++jLoc )
                {
                    T* destCol =
                        &ABuf[colShift+(rowShift+jLoc*rowStride)*ALDim];
                    const T* sourceCol = &data[jLoc*localHeight];
                    for( Int iLoc=0; iLoc<localHeight; ++iLoc )
                        destCol[iLoc*colStride] = sourceCol[iLoc];
                }
            }
        }
        A.auxMemory_.Release();
    }
    // Replicate the result from the root across the cross communicator.
    if( this->Grid().InGrid() && this->CrossComm() != mpi::COMM_SELF )
    {
        // Pack from the root
        const Int localHeight = A.LocalHeight();
        const Int localWidth = A.LocalWidth();
        T* buf = A.auxMemory_.Require( localHeight*localWidth );
        if( this->CrossRank() == this->Root() )
        {
            for( Int jLoc=0; jLoc<localWidth; ++jLoc )
                MemCopy
                ( &buf[jLoc*localHeight], A.LockedBuffer(0,jLoc), localHeight );
        }

        // Broadcast from the root
        mpi::Broadcast
        ( buf, localHeight*localWidth, this->Root(), this->CrossComm() );

        // Unpack if not the root
        if( this->CrossRank() != this->Root() )
        {
            for( Int jLoc=0; jLoc<localWidth; ++jLoc )
                MemCopy
                ( A.Buffer(0,jLoc), &buf[jLoc*localHeight], localHeight );
        }
        A.auxMemory_.Release();
    }
}
274 
// Gather the rows of this [U,V] matrix into A, which is distributed as
// [UGath,V].
//
// A is asked to take this matrix's row alignment and is resized to
// height x width. Each participating process ends up with every row of
// its local columns: the unpack loops write row colShift + iLoc*colStride
// for every column rank.
//
// Four paths, chosen by whether A's row alignment really matches this
// matrix's (A may be constrained) and whether height == 1:
//  * aligned, height == 1: the process with ColRank == ColAlign broadcasts
//    its single row over the column communicator;
//  * aligned, general: AllGather over the column communicator, then
//    interleave the gathered blocks;
//  * unaligned: first shift the columns with a SendRecv over the row
//    communicator to match A's row alignment, then do the corresponding
//    aligned step.
// Finally, A's local data is broadcast from this matrix's root over the
// cross communicator (skipped when that communicator is COMM_SELF).
// NOTE(review): the two trailing `false` arguments to AlignRowsAndResize are
// presumably force/constrain flags -- confirm against its declaration.
template<typename T,Dist U,Dist V>
void
GeneralDistMatrix<T,U,V>::ColAllGather( DistMatrix<T,UGath,V>& A ) const
{
    DEBUG_ONLY(
        CallStackEntry cse("GDM::ColAllGather");
        this->AssertSameGrid( A.Grid() );
    )
    const Int height = this->Height();
    const Int width = this->Width();
#ifdef ELEM_CACHE_WARNINGS
    if( height != 1 && this->Grid().Rank() == 0 )
    {
        std::cerr <<
          "The matrix redistribution [* ,V] <- [U,V] potentially causes a "
          "large amount of cache-thrashing. If possible, avoid it by "
          "performing the redistribution with a (conjugate)-transpose"
          << std::endl;
    }
#endif
    A.AlignRowsAndResize( this->RowAlign(), height, width, false, false );

    if( this->Participating() )
    {
        if( this->RowAlign() == A.RowAlign() )
        {
            if( height == 1 )
            {
                // Row vector: only ColRank == ColAlign holds data, so a
                // broadcast over the column communicator suffices.
                const Int localWidthA = A.LocalWidth();
                T* bcastBuf = A.auxMemory_.Require( localWidthA );

                if( this->ColRank() == this->ColAlign() )
                {
                    A.matrix_ = this->LockedMatrix();
                    // Pack (the single row is strided by ALDim in A)
                    const T* ABuf = A.LockedBuffer();
                    const Int ALDim = A.LDim();
                    ELEM_PARALLEL_FOR
                    for( Int jLoc=0; jLoc<localWidthA; ++jLoc )
                        bcastBuf[jLoc] = ABuf[jLoc*ALDim];
                }

                // Broadcast within the column comm
                mpi::Broadcast
                ( bcastBuf, localWidthA, this->ColAlign(), this->ColComm() );

                // Unpack
                T* ABuf = A.Buffer();
                const Int ALDim = A.LDim();
                ELEM_PARALLEL_FOR
                for( Int jLoc=0; jLoc<localWidthA; ++jLoc )
                    ABuf[jLoc*ALDim] = bcastBuf[jLoc];
                A.auxMemory_.Release();
            }
            else
            {
                const Int colStride = this->ColStride();
                const Int localWidth = this->LocalWidth();
                const Int thisLocalHeight = this->LocalHeight();
                const Int maxLocalHeight = MaxLength(height,colStride);
                // Fixed-size portions: one to send, colStride to receive.
                const Int portionSize = mpi::Pad( maxLocalHeight*localWidth );

                T* buffer = A.auxMemory_.Require( (colStride+1)*portionSize );
                T* sendBuf = &buffer[0];
                T* recvBuf = &buffer[portionSize];

                // Pack
                const Int ldim = this->LDim();
                const T* thisBuf = this->LockedBuffer();
                ELEM_PARALLEL_FOR
                for( Int jLoc=0; jLoc<localWidth; ++jLoc )
                {
                    const T* thisCol = &thisBuf[jLoc*ldim];
                    T* sendBufCol = &sendBuf[jLoc*thisLocalHeight];
                    MemCopy( sendBufCol, thisCol, thisLocalHeight );
                }

                // Communicate
                mpi::AllGather
                ( sendBuf, portionSize, recvBuf, portionSize, this->ColComm() );

                // Unpack: portion k holds the rows owned by column rank k,
                // which start at its shift and are strided by colStride.
                T* ABuf = A.Buffer();
                const Int ALDim = A.LDim();
                const Int colAlign = this->ColAlign();
                ELEM_OUTER_PARALLEL_FOR
                for( Int k=0; k<colStride; ++k )
                {
                    const T* data = &recvBuf[k*portionSize];
                    const Int colShift = Shift_( k, colAlign, colStride );
                    const Int localHeight =
                        Length_( height, colShift, colStride );
                    ELEM_INNER_PARALLEL_FOR
                    for( Int jLoc=0; jLoc<localWidth; ++jLoc )
                    {
                        T* destCol = &ABuf[colShift+jLoc*ALDim];
                        const T* sourceCol = &data[jLoc*localHeight];
                        for( Int iLoc=0; iLoc<localHeight; ++iLoc )
                            destCol[iLoc*colStride] = sourceCol[iLoc];
                    }
                }
                A.auxMemory_.Release();
            }
        }
        else
        {
#ifdef ELEM_UNALIGNED_WARNINGS
            if( this->Grid().Rank() == 0 )
                std::cerr << "Unaligned [U,V] -> [* ,V]." << std::endl;
#endif
            const Int colStride = this->ColStride();
            const Int rowStride = this->RowStride();
            const Int rowRank = this->RowRank();

            // Partners in the row communicator for realigning the columns
            // (adding rowStride keeps the operand of % non-negative).
            const Int rowAlign = this->RowAlign();
            const Int rowAlignA = A.RowAlign();
            const Int sendRowRank =
                (rowRank+rowStride+rowAlignA-rowAlign) % rowStride;
            const Int recvRowRank =
                (rowRank+rowStride+rowAlign-rowAlignA) % rowStride;

            if( height == 1 )
            {
                const Int localWidthA = A.LocalWidth();
                T* bcastBuf;

                if( this->ColRank() == this->ColAlign() )
                {
                    const Int localWidth = this->LocalWidth();

                    T* buffer = A.auxMemory_.Require( localWidth+localWidthA );
                    T* sendBuf = &buffer[0];
                    bcastBuf   = &buffer[localWidth];

                    // Pack
                    const T* thisBuf = this->LockedBuffer();
                    const Int ldim = this->LDim();
                    ELEM_PARALLEL_FOR
                    for( Int jLoc=0; jLoc<localWidth; ++jLoc )
                        sendBuf[jLoc] = thisBuf[jLoc*ldim];

                    // Communicate (realign the row within the row comm)
                    mpi::SendRecv
                    ( sendBuf,  localWidth,  sendRowRank,
                      bcastBuf, localWidthA, recvRowRank, this->RowComm() );
                }
                else
                {
                    bcastBuf = A.auxMemory_.Require( localWidthA );
                }

                // Communicate
                mpi::Broadcast
                ( bcastBuf, localWidthA, this->ColAlign(), this->ColComm() );

                // Unpack
                T* ABuf = A.Buffer();
                const Int ALDim = A.LDim();
                ELEM_PARALLEL_FOR
                for( Int jLoc=0; jLoc<localWidthA; ++jLoc )
                    ABuf[jLoc*ALDim] = bcastBuf[jLoc];
                A.auxMemory_.Release();
            }
            else
            {
                const Int thisLocalWidth = this->LocalWidth();
                const Int localWidthA = A.LocalWidth();
                const Int thisLocalHeight = this->LocalHeight();
                const Int maxLocalHeight = MaxLength(height,colStride);
                const Int maxLocalWidth = MaxLength(width,rowStride);
                const Int portionSize =
                    mpi::Pad( maxLocalHeight*maxLocalWidth );

                // firstBuf receives the realigned block; secondBuf is first
                // the outgoing block, then the colStride gathered portions.
                T* buffer = A.auxMemory_.Require( (colStride+1)*portionSize );
                T* firstBuf  = &buffer[0];
                T* secondBuf = &buffer[portionSize];

                // Pack
                const Int ldim = this->LDim();
                const T* thisBuf = this->LockedBuffer();
                ELEM_PARALLEL_FOR
                for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
                {
                    const T* thisCol = &thisBuf[jLoc*ldim];
                    T* secondBufCol = &secondBuf[jLoc*thisLocalHeight];
                    MemCopy( secondBufCol, thisCol, thisLocalHeight );
                }

                // Realign
                mpi::SendRecv
                ( secondBuf, portionSize, sendRowRank,
                  firstBuf,  portionSize, recvRowRank, this->RowComm() );

                // AllGather the aligned data
                mpi::AllGather
                ( firstBuf, portionSize,
                  secondBuf, portionSize, this->ColComm() );

                // Unpack the contents of each member of the column team
                T* ABuf = A.Buffer();
                const Int ALDim = A.LDim();
                const Int colAlign = this->ColAlign();
                ELEM_OUTER_PARALLEL_FOR
                for( Int k=0; k<colStride; ++k )
                {
                    const T* data = &secondBuf[k*portionSize];
                    const Int colShift = Shift_( k, colAlign, colStride );
                    const Int localHeight =
                        Length_( height, colShift, colStride );
                    ELEM_INNER_PARALLEL_FOR
                    for( Int jLoc=0; jLoc<localWidthA; ++jLoc )
                    {
                        T* destCol = &ABuf[colShift+jLoc*ALDim];
                        const T* sourceCol = &data[jLoc*localHeight];
                        for( Int iLoc=0; iLoc<localHeight; ++iLoc )
                            destCol[iLoc*colStride] = sourceCol[iLoc];
                    }
                }
                A.auxMemory_.Release();
            }
        }
    }
    // Replicate the result from the root across the cross communicator.
    if( this->Grid().InGrid() && this->CrossComm() != mpi::COMM_SELF )
    {
        // Pack from the root
        const Int localHeight = A.LocalHeight();
        const Int localWidth = A.LocalWidth();
        T* buf = A.auxMemory_.Require( localHeight*localWidth );
        if( this->CrossRank() == this->Root() )
        {
            for( Int jLoc=0; jLoc<localWidth; ++jLoc )
                MemCopy
                ( &buf[jLoc*localHeight], A.LockedBuffer(0,jLoc), localHeight );
        }

        // Broadcast from the root
        mpi::Broadcast
        ( buf, localHeight*localWidth, this->Root(), this->CrossComm() );

        // Unpack if not the root
        if( this->CrossRank() != this->Root() )
        {
            for( Int jLoc=0; jLoc<localWidth; ++jLoc )
                MemCopy
                ( A.Buffer(0,jLoc), &buf[jLoc*localHeight], localHeight );
        }
        A.auxMemory_.Release();
    }
}
524 
// Gather the columns of this [U,V] matrix into A, which is distributed as
// [U,VGath].
//
// A is asked to take this matrix's column alignment and is resized to
// height x width. Each participating process ends up with every column of
// its local rows: the unpack loops write column rowShift + jLoc*rowStride
// for every row rank.
//
// Four paths, chosen by whether A's column alignment really matches this
// matrix's (A may be constrained) and whether width == 1:
//  * aligned, width == 1: the process with RowRank == RowAlign broadcasts
//    its single column over the row communicator;
//  * aligned, general: AllGather over the row communicator, then place
//    each gathered column;
//  * unaligned: first shift the rows with a SendRecv over the column
//    communicator to match A's column alignment, then do the corresponding
//    aligned step.
// Finally, A's local data is broadcast from this matrix's root over the
// cross communicator (skipped when that communicator is COMM_SELF).
// NOTE(review): the two trailing `false` arguments to AlignColsAndResize are
// presumably force/constrain flags -- confirm against its declaration.
template<typename T,Dist U,Dist V>
void
GeneralDistMatrix<T,U,V>::RowAllGather( DistMatrix<T,U,VGath>& A ) const
{
    DEBUG_ONLY(
        CallStackEntry cse("GDM::RowAllGather");
        this->AssertSameGrid( A.Grid() );
    )
    const Int height = this->Height();
    const Int width = this->Width();
    A.AlignColsAndResize( this->ColAlign(), height, width, false, false );

    if( this->Participating() )
    {
        if( this->ColAlign() == A.ColAlign() )
        {
            if( width == 1 )
            {
                // Column vector: copy at the owning row rank, then
                // broadcast A's contiguous local column over the row comm.
                if( this->RowRank() == this->RowAlign() )
                    A.matrix_ = this->LockedMatrix();
                mpi::Broadcast
                ( A.matrix_.Buffer(), A.LocalHeight(), this->RowAlign(),
                  this->RowComm() );
            }
            else
            {
                const Int rowStride = this->RowStride();
                const Int thisLocalWidth = this->LocalWidth();
                const Int localHeight = this->LocalHeight();
                const Int maxLocalWidth = MaxLength(width,rowStride);

                // Fixed-size portions: one to send, rowStride to receive.
                const Int portionSize = mpi::Pad( localHeight*maxLocalWidth );
                T* buffer = A.auxMemory_.Require( (rowStride+1)*portionSize );
                T* sendBuf = &buffer[0];
                T* recvBuf = &buffer[portionSize];

                // Pack
                const Int ldim = this->LDim();
                const T* thisBuf = this->LockedBuffer();
                ELEM_PARALLEL_FOR
                for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
                {
                    const T* thisCol = &thisBuf[jLoc*ldim];
                    T* sendBufCol = &sendBuf[jLoc*localHeight];
                    MemCopy( sendBufCol, thisCol, localHeight );
                }

                // Communicate
                mpi::AllGather
                ( sendBuf, portionSize, recvBuf, portionSize, this->RowComm() );

                // Unpack: portion k holds the columns owned by row rank k,
                // which start at its shift and are strided by rowStride.
                T* ABuf = A.Buffer();
                const Int ALDim = A.LDim();
                const Int rowAlign = this->RowAlign();
                ELEM_OUTER_PARALLEL_FOR
                for( Int k=0; k<rowStride; ++k )
                {
                    const T* data = &recvBuf[k*portionSize];
                    const Int rowShift = Shift_( k, rowAlign, rowStride );
                    const Int localWidth =
                        Length_( width, rowShift, rowStride );
                    ELEM_INNER_PARALLEL_FOR
                    for( Int jLoc=0; jLoc<localWidth; ++jLoc )
                    {
                        const T* dataCol = &data[jLoc*localHeight];
                        T* ACol = &ABuf[(rowShift+jLoc*rowStride)*ALDim];
                        MemCopy( ACol, dataCol, localHeight );
                    }
                }
                A.auxMemory_.Release();
            }
        }
        else
        {
#ifdef ELEM_UNALIGNED_WARNINGS
            if( this->Grid().Rank() == 0 )
                std::cerr << "Unaligned RowAllGather." << std::endl;
#endif
            const Int colStride = this->ColStride();
            const Int rowStride = this->RowStride();
            const Int colRank = this->ColRank();

            // Partners in the column communicator for realigning the rows
            // (adding colStride keeps the operand of % non-negative).
            const Int colAlign = this->ColAlign();
            const Int colAlignA = A.ColAlign();
            const Int sendColRank =
                (colRank+colStride+colAlignA-colAlign) % colStride;
            const Int recvColRank =
                (colRank+colStride+colAlign-colAlignA) % colStride;

            if( width == 1 )
            {
                const Int localHeightA = A.LocalHeight();
                if( this->RowRank() == this->RowAlign() )
                {
                    const Int localHeight = this->LocalHeight();
                    T* buffer = A.auxMemory_.Require( localHeight );

                    // Pack
                    const T* thisCol = this->LockedBuffer();
                    MemCopy( buffer, thisCol, localHeight );

                    // Realign (receiving straight into A's local column)
                    mpi::SendRecv
                    ( buffer, localHeight, sendColRank,
                      A.matrix_.Buffer(), localHeightA, recvColRank,
                      this->ColComm() );
                    A.auxMemory_.Release();
                }

                // Perform the row broadcast
                mpi::Broadcast
                ( A.matrix_.Buffer(), localHeightA, this->RowAlign(),
                  this->RowComm() );
            }
            else
            {
                const Int localHeight = this->LocalHeight();
                const Int thisLocalWidth = this->LocalWidth();
                const Int localHeightA = A.LocalHeight();
                const Int maxLocalHeight = MaxLength(height,colStride);
                const Int maxLocalWidth = MaxLength(width,rowStride);

                // firstBuf receives the realigned block; secondBuf is first
                // the outgoing block, then the rowStride gathered portions.
                const Int portionSize =
                    mpi::Pad( maxLocalHeight*maxLocalWidth );
                T* buffer = A.auxMemory_.Require( (rowStride+1)*portionSize );
                T* firstBuf = &buffer[0];
                T* secondBuf = &buffer[portionSize];

                // Pack
                const Int ldim = this->LDim();
                const T* thisBuf = this->LockedBuffer();
                ELEM_PARALLEL_FOR
                for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
                {
                    const T* thisCol = &thisBuf[jLoc*ldim];
                    T* secondBufCol = &secondBuf[jLoc*localHeight];
                    MemCopy( secondBufCol, thisCol, localHeight );
                }

                // Realign
                mpi::SendRecv
                ( secondBuf, portionSize, sendColRank,
                  firstBuf,  portionSize, recvColRank, this->ColComm() );

                // Perform the row AllGather
                mpi::AllGather
                ( firstBuf,  portionSize,
                  secondBuf, portionSize, this->RowComm() );

                // Unpack (realigned blocks have A's local height)
                T* ABuf = A.Buffer();
                const Int ALDim = A.LDim();
                const Int rowAlign = this->RowAlign();
                ELEM_OUTER_PARALLEL_FOR
                for( Int k=0; k<rowStride; ++k )
                {
                    const T* data = &secondBuf[k*portionSize];
                    const Int rowShift = Shift_( k, rowAlign, rowStride );
                    const Int localWidth =
                        Length_( width, rowShift, rowStride );
                    ELEM_INNER_PARALLEL_FOR
                    for( Int jLoc=0; jLoc<localWidth; ++jLoc )
                    {
                        const T* dataCol = &data[jLoc*localHeightA];
                        T* ACol = &ABuf[(rowShift+jLoc*rowStride)*ALDim];
                        MemCopy( ACol, dataCol, localHeightA );
                    }
                }
                A.auxMemory_.Release();
            }
        }
    }
    // Replicate the result from the root across the cross communicator.
    if( this->Grid().InGrid() && this->CrossComm() != mpi::COMM_SELF )
    {
        // Pack from the root
        const Int localHeight = A.LocalHeight();
        const Int localWidth = A.LocalWidth();
        T* buf = A.auxMemory_.Require( localHeight*localWidth );
        if( this->CrossRank() == this->Root() )
        {
            for( Int jLoc=0; jLoc<localWidth; ++jLoc )
                MemCopy
                ( &buf[jLoc*localHeight], A.LockedBuffer(0,jLoc), localHeight );
        }

        // Broadcast from the root
        mpi::Broadcast
        ( buf, localHeight*localWidth, this->Root(), this->CrossComm() );

        // Unpack if not the root
        if( this->CrossRank() != this->Root() )
        {
            for( Int jLoc=0; jLoc<localWidth; ++jLoc )
                MemCopy
                ( A.Buffer(0,jLoc), &buf[jLoc*localHeight], localHeight );
        }
        A.auxMemory_.Release();
    }
}
725 
726 template<typename T,Dist U,Dist V>
727 void
PartialColAllGather(DistMatrix<T,UPart,V> & A) const728 GeneralDistMatrix<T,U,V>::PartialColAllGather( DistMatrix<T,UPart,V>& A ) const
729 {
730     DEBUG_ONLY(
731         CallStackEntry cse("GDM::PartialColAllGather");
732         this->AssertSameGrid( A.Grid() );
733     )
734     const Int height = this->Height();
735     const Int width = this->Width();
736 #ifdef ELEM_VECTOR_WARNINGS
737     if( width == 1 && this->Grid().Rank() == 0 )
738     {
739         std::cerr <<
740           "The vector version of PartialColAllGather is not yet written but "
741           "would only require modifying the vector version of "
742           "PartialRowAllGather" << std::endl;
743     }
744 #endif
745 #ifdef ELEM_CACHE_WARNINGS
746     if( width && this->Grid().Rank() == 0 )
747     {
748         std::cerr <<
749           "PartialColAllGather potentially causes a large amount of cache-"
750           "thrashing. If possible, avoid it by performing the redistribution"
751           "on the (conjugate-)transpose" << std::endl;
752     }
753 #endif
754     A.AlignColsAndResize
755     ( this->ColAlign()%A.ColStride(), height, width, false, false );
756     if( !this->Participating() )
757         return;
758 
759     DEBUG_ONLY(
760         if( this->LocalWidth() != this->Width() )
761             LogicError("This routine assumes rows are not distributed");
762     )
763     const T* thisBuf = this->LockedBuffer();
764     const Int ldim = this->LDim();
765     T* ABuf = A.Buffer();
766     const Int ALDim = A.LDim();
767 
768     const Int colAlign = this->ColAlign();
769     const Int colAlignA = A.ColAlign();
770     const Int colStride = this->ColStride();
771     const Int colStrideUnion = this->PartialUnionColStride();
772     const Int colStridePart = this->PartialColStride();
773     const Int colRankPart = this->PartialColRank();
774     const Int colShiftA = A.ColShift();
775 
776     const Int thisLocalHeight = this->LocalHeight();
777     const Int maxLocalHeight = MaxLength(height,colStride);
778     const Int portionSize = mpi::Pad( maxLocalHeight*width );
779     T* buffer = A.auxMemory_.Require( (colStrideUnion+1)*portionSize );
780     T* firstBuf = &buffer[0];
781     T* secondBuf = &buffer[portionSize];
782 
783     if( colAlignA == colAlign % colStridePart )
784     {
785         // Pack
786         ELEM_PARALLEL_FOR
787         for( Int j=0; j<width; ++j )
788         {
789             const T* thisCol = &thisBuf[j*ldim];
790             T* firstBufCol = &firstBuf[j*thisLocalHeight];
791             MemCopy( firstBufCol, thisCol, thisLocalHeight );
792         }
793 
794         // Communicate
795         mpi::AllGather
796         ( firstBuf, portionSize, secondBuf, portionSize,
797           this->PartialUnionColComm() );
798 
799         // Unpack
800         ELEM_OUTER_PARALLEL_FOR
801         for( Int k=0; k<colStrideUnion; ++k )
802         {
803             const T* data = &secondBuf[k*portionSize];
804             const Int colShift =
805                 Shift_( colRankPart+k*colStridePart, colAlign, colStride );
806             const Int colOffset = (colShift-colShiftA) / colStridePart;
807             const Int localHeight = Length_( height, colShift, colStride );
808             ELEM_INNER_PARALLEL_FOR
809             for( Int j=0; j<width; ++j )
810             {
811                 const T* dataCol = &data[j*localHeight];
812                 T* ACol = &ABuf[colOffset+j*ALDim];
813                 for( Int iLoc=0; iLoc<localHeight; ++iLoc )
814                     ACol[iLoc*colStrideUnion] = dataCol[iLoc];
815             }
816         }
817     }
818     else
819     {
820 #ifdef ELEM_UNALIGNED_WARNINGS
821         if( this->Grid().Rank() == 0 )
822             std::cerr << "Unaligned PartialColAllGather" << std::endl;
823 #endif
824         // Perform a SendRecv to match the row alignments
825         const Int colRank = this->ColRank();
826         const Int sendColRank =
827             (colRank+colStride+colAlignA-colAlign) % colStride;
828         const Int recvColRank =
829             (colRank+colStride+colAlign-colAlignA) % colStride;
830         ELEM_PARALLEL_FOR
831         for( Int j=0; j<width; ++j )
832         {
833             const T* thisCol = &thisBuf[j*ldim];
834             T* secondBufCol = &secondBuf[j*thisLocalHeight];
835             MemCopy( secondBufCol, thisCol, thisLocalHeight );
836         }
837         mpi::SendRecv
838         ( secondBuf, portionSize, sendColRank,
839           firstBuf,  portionSize, recvColRank, this->ColComm() );
840 
841         // Use the SendRecv as an input to the partial union AllGather
842         mpi::AllGather
843         ( firstBuf,  portionSize,
844           secondBuf, portionSize, this->PartialUnionColComm() );
845 
846         // Unpack
847         ELEM_OUTER_PARALLEL_FOR
848         for( Int k=0; k<colStrideUnion; ++k )
849         {
850             const T* data = &secondBuf[k*portionSize];
851             const Int colShift =
852                 Shift_( colRankPart+colStridePart*k, colAlignA, colStride );
853             const Int colOffset = (colShift-colShiftA) / colStridePart;
854             const Int localHeight = Length_( height, colShift, colStride );
855             ELEM_INNER_PARALLEL_FOR
856             for( Int j=0; j<width; ++j )
857             {
858                 const T* dataCol = &data[j*localHeight];
859                 T* ACol = &ABuf[colOffset+j*ALDim];
860                 for( Int iLoc=0; iLoc<localHeight; ++iLoc )
861                     ACol[iLoc*colStrideUnion] = dataCol[iLoc];
862             }
863         }
864     }
865     A.auxMemory_.Release();
866 }
867 
868 template<typename T,Dist U,Dist V>
869 void
PartialRowAllGather(DistMatrix<T,U,VPart> & A) const870 GeneralDistMatrix<T,U,V>::PartialRowAllGather( DistMatrix<T,U,VPart>& A ) const
871 {
872     DEBUG_ONLY(
873         CallStackEntry cse("GDM::PartialRowAllGather");
874         this->AssertSameGrid( A.Grid() );
875     )
876     const Int height = this->Height();
877     const Int width = this->Width();
878     A.AlignRowsAndResize
879     ( this->RowAlign()%A.RowStride(), height, width, false, false );
880     if( !this->Participating() )
881         return;
882 
883     DEBUG_ONLY(
884         if( this->LocalHeight() != this->Height() )
885             LogicError("This routine assumes columns are not distributed");
886     )
887     const T* thisBuf = this->LockedBuffer();
888     const Int ldim = this->LDim();
889     T* ABuf = A.Buffer();
890     const Int ALDim = A.LDim();
891 
892     const Int rowAlign = this->RowAlign();
893     const Int rowAlignA = A.RowAlign();
894     const Int rowStride = this->RowStride();
895     const Int rowStrideUnion = this->PartialUnionRowStride();
896     const Int rowStridePart = this->PartialRowStride();
897     const Int rowRankPart = this->PartialRowRank();
898     const Int rowShiftA = A.RowShift();
899 
900     const Int thisLocalWidth = this->LocalWidth();
901     const Int maxLocalWidth = MaxLength(width,rowStride);
902     const Int portionSize = mpi::Pad( height*maxLocalWidth );
903     T* buffer = A.auxMemory_.Require( (rowStrideUnion+1)*portionSize );
904     T* firstBuf = &buffer[0];
905     T* secondBuf = &buffer[portionSize];
906 
907     if( rowAlignA == rowAlign % rowStridePart )
908     {
909         // Pack
910         ELEM_PARALLEL_FOR
911         for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
912         {
913             const T* thisCol = &thisBuf[jLoc*ldim];
914             T* firstBufCol = &firstBuf[jLoc*height];
915             MemCopy( firstBufCol, thisCol, height );
916         }
917 
918         // Communicate
919         mpi::AllGather
920         ( firstBuf, portionSize, secondBuf, portionSize,
921           this->PartialUnionRowComm() );
922 
923         // Unpack
924         ELEM_OUTER_PARALLEL_FOR
925         for( Int k=0; k<rowStrideUnion; ++k )
926         {
927             const T* data = &secondBuf[k*portionSize];
928             const Int rowShift =
929                 Shift_( rowRankPart+k*rowStridePart, rowAlign, rowStride );
930             const Int rowOffset = (rowShift-rowShiftA) / rowStridePart;
931             const Int localWidth = Length_( width, rowShift, rowStride );
932             ELEM_INNER_PARALLEL_FOR
933             for( Int jLoc=0; jLoc<localWidth; ++jLoc )
934             {
935                 const T* dataCol = &data[jLoc*height];
936                 T* ACol = &ABuf[(rowOffset+jLoc*rowStrideUnion)*ALDim];
937                 MemCopy( ACol, dataCol, height );
938             }
939         }
940     }
941     else
942     {
943 #ifdef ELEM_UNALIGNED_WARNINGS
944         if( this->Grid().Rank() == 0 )
945             std::cerr << "Unaligned PartialRowAllGather" << std::endl;
946 #endif
947         // Perform a SendRecv to match the row alignments
948         const Int rowRank = this->RowRank();
949         const Int sendRowRank =
950             (rowRank+rowStride+rowAlignA-rowAlign) % rowStride;
951         const Int recvRowRank =
952             (rowRank+rowStride+rowAlign-rowAlignA) % rowStride;
953         ELEM_PARALLEL_FOR
954         for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
955         {
956             const T* thisCol = &thisBuf[jLoc*ldim];
957             T* secondBufCol = &secondBuf[jLoc*height];
958             MemCopy( secondBufCol, thisCol, height );
959         }
960         mpi::SendRecv
961         ( secondBuf, portionSize, sendRowRank,
962           firstBuf,  portionSize, recvRowRank, this->RowComm() );
963 
964         // Use the SendRecv as an input to the partial union AllGather
965         mpi::AllGather
966         ( firstBuf,  portionSize,
967           secondBuf, portionSize, this->PartialUnionRowComm() );
968 
969         // Unpack
970         ELEM_OUTER_PARALLEL_FOR
971         for( Int k=0; k<rowStrideUnion; ++k )
972         {
973             const T* data = &secondBuf[k*portionSize];
974             const Int rowShift =
975                 Shift_( rowRankPart+rowStridePart*k, rowAlignA, rowStride );
976             const Int rowOffset = (rowShift-rowShiftA) / rowStridePart;
977             const Int localWidth = Length_( width, rowShift, rowStride );
978             ELEM_INNER_PARALLEL_FOR
979             for( Int jLoc=0; jLoc<localWidth; ++jLoc )
980             {
981                 const T* dataCol = &data[jLoc*height];
982                 T* ACol = &ABuf[(rowOffset+jLoc*rowStrideUnion)*ALDim];
983                 MemCopy( ACol, dataCol, height );
984             }
985         }
986     }
987     A.auxMemory_.Release();
988 }
989 
990 template<typename T,Dist U,Dist V>
991 void
FilterFrom(const DistMatrix<T,UGath,VGath> & A)992 GeneralDistMatrix<T,U,V>::FilterFrom( const DistMatrix<T,UGath,VGath>& A )
993 {
994     DEBUG_ONLY(
995         CallStackEntry cse("GDM::FilterFrom");
996         this->AssertSameGrid( A.Grid() );
997     )
998     const Int height = A.Height();
999     const Int width = A.Width();
1000     this->Resize( height, width );
1001     if( !this->Participating() )
1002         return;
1003 
1004     const Int colStride = this->ColStride();
1005     const Int rowStride = this->RowStride();
1006     const Int colShift = this->ColShift();
1007     const Int rowShift = this->RowShift();
1008 
1009     const Int localHeight = this->LocalHeight();
1010     const Int localWidth = this->LocalWidth();
1011 
1012     T* thisBuf = this->Buffer();
1013     const Int ldim = this->LDim();
1014     const T* ABuf = A.LockedBuffer();
1015     const Int ALDim = A.LDim();
1016     ELEM_PARALLEL_FOR
1017     for( Int jLoc=0; jLoc<localWidth; ++jLoc )
1018     {
1019         T* thisCol = &thisBuf[jLoc*ldim];
1020         const T* ACol = &ABuf[colShift+(rowShift+jLoc*rowStride)*ALDim];
1021         for( Int iLoc=0; iLoc<localHeight; ++iLoc )
1022             thisCol[iLoc] = ACol[iLoc*colStride];
1023     }
1024 }
1025 
1026 template<typename T,Dist U,Dist V>
1027 void
ColFilterFrom(const DistMatrix<T,UGath,V> & A)1028 GeneralDistMatrix<T,U,V>::ColFilterFrom( const DistMatrix<T,UGath,V>& A )
1029 {
1030     DEBUG_ONLY(
1031         CallStackEntry cse("GDM::ColFilterFrom");
1032         this->AssertSameGrid( A.Grid() );
1033     )
1034     const Int height = A.Height();
1035     const Int width = A.Width();
1036     this->AlignRowsAndResize( A.RowAlign(), height, width, false, false );
1037     if( !this->Participating() )
1038         return;
1039 
1040     const Int colStride = this->ColStride();
1041     const Int colShift = this->ColShift();
1042     const Int rowAlign = this->RowAlign();
1043     const Int rowAlignA = A.RowAlign();
1044 
1045     const Int localHeight = this->LocalHeight();
1046     const Int localWidth = this->LocalWidth();
1047 
1048     T* thisBuf = this->Buffer();
1049     const Int ldim = this->LDim();
1050     const T* ABuf = A.LockedBuffer();
1051     const Int ALDim = A.LDim();
1052 
1053     if( rowAlign == rowAlignA )
1054     {
1055         ELEM_PARALLEL_FOR
1056         for( Int jLoc=0; jLoc<localWidth; ++jLoc )
1057         {
1058             T* thisCol = &thisBuf[jLoc*ldim];
1059             const T* ACol = &ABuf[colShift+jLoc*ALDim];
1060             for( Int iLoc=0; iLoc<localHeight; ++iLoc )
1061                 thisCol[iLoc] = ACol[iLoc*colStride];
1062         }
1063     }
1064     else
1065     {
1066 #ifdef ELEM_UNALIGNED_WARNINGS
1067         if( this->Grid().Rank() == 0 )
1068             std::cerr << "Unaligned ColFilterFrom" << std::endl;
1069 #endif
1070         const Int rowStride = this->RowStride();
1071         const Int rowRank = this->RowRank();
1072         const Int sendRowRank =
1073             (rowRank+rowStride+rowAlign-rowAlignA) % rowStride;
1074         const Int recvRowRank =
1075             (rowRank+rowStride+rowAlignA-rowAlign) % rowStride;
1076         const Int localWidthA = A.LocalWidth();
1077         const Int sendSize = localHeight*localWidthA;
1078         const Int recvSize = localHeight*localWidth;
1079         T* buffer = this->auxMemory_.Require( sendSize+recvSize );
1080         T* sendBuf = &buffer[0];
1081         T* recvBuf = &buffer[sendSize];
1082 
1083         // Pack
1084         ELEM_PARALLEL_FOR
1085         for( Int jLoc=0; jLoc<localWidthA; ++jLoc )
1086         {
1087             T* sendCol = &sendBuf[jLoc*localHeight];
1088             const T* ACol = &ABuf[colShift+jLoc*ALDim];
1089             for( Int iLoc=0; iLoc<localHeight; ++iLoc )
1090                 sendCol[iLoc] = ACol[iLoc*colStride];
1091         }
1092 
1093         // Realign
1094         mpi::SendRecv
1095         ( sendBuf, sendSize, sendRowRank,
1096           recvBuf, recvSize, recvRowRank, this->RowComm() );
1097 
1098         // Unpack
1099         ELEM_PARALLEL_FOR
1100         for( Int jLoc=0; jLoc<localWidth; ++jLoc )
1101             MemCopy
1102             ( &thisBuf[jLoc*ldim], &recvBuf[jLoc*localHeight], localHeight );
1103         this->auxMemory_.Release();
1104     }
1105 }
1106 
1107 template<typename T,Dist U,Dist V>
1108 void
RowFilterFrom(const DistMatrix<T,U,VGath> & A)1109 GeneralDistMatrix<T,U,V>::RowFilterFrom( const DistMatrix<T,U,VGath>& A )
1110 {
1111     DEBUG_ONLY(
1112         CallStackEntry cse("GDM::RowFilterFrom");
1113         this->AssertSameGrid( A.Grid() );
1114     )
1115     const Int height = A.Height();
1116     const Int width = A.Width();
1117     this->AlignColsAndResize( A.ColAlign(), height, width, false, false );
1118     if( !this->Participating() )
1119         return;
1120 
1121     const Int colAlign = this->ColAlign();
1122     const Int colAlignA = A.ColAlign();
1123     const Int rowStride = this->RowStride();
1124     const Int rowShift = this->RowShift();
1125 
1126     const Int localHeight = this->LocalHeight();
1127     const Int localWidth = this->LocalWidth();
1128 
1129     T* thisBuf = this->Buffer();
1130     const Int ldim = this->LDim();
1131     const T* ABuf = A.LockedBuffer();
1132     const Int ALDim = A.LDim();
1133 
1134     if( colAlign == colAlignA )
1135     {
1136         ELEM_PARALLEL_FOR
1137         for( Int jLoc=0; jLoc<localWidth; ++jLoc )
1138         {
1139             T* thisCol = &thisBuf[jLoc*ldim];
1140             const T* ACol = &ABuf[(rowShift+jLoc*rowStride)*ALDim];
1141             MemCopy( thisCol, ACol, localHeight );
1142         }
1143     }
1144     else
1145     {
1146 #ifdef ELEM_UNALIGNED_WARNINGS
1147         if( this->Grid().Rank() == 0 )
1148             std::cerr << "Unaligned RowFilterFrom" << std::endl;
1149 #endif
1150         const Int colRank = this->ColRank();
1151         const Int colStride = this->ColStride();
1152         const Int sendColRank =
1153             (colRank+colStride+colAlign-colAlignA) % colStride;
1154         const Int recvColRank =
1155             (colRank+colStride+colAlignA-colAlign) % colStride;
1156         const Int localHeightA = A.LocalHeight();
1157         const Int sendSize = localHeightA*localWidth;
1158         const Int recvSize = localHeight *localWidth;
1159 
1160         T* buffer = this->auxMemory_.Require( sendSize+recvSize );
1161         T* sendBuf = &buffer[0];
1162         T* recvBuf = &buffer[sendSize];
1163 
1164         // Pack
1165         ELEM_PARALLEL_FOR
1166         for( Int jLoc=0; jLoc<localWidth; ++jLoc )
1167             MemCopy
1168             ( &sendBuf[jLoc*localHeightA],
1169               &ABuf[(rowShift+jLoc*rowStride)*ALDim], localHeightA );
1170 
1171         // Realign
1172         mpi::SendRecv
1173         ( sendBuf, sendSize, sendColRank,
1174           recvBuf, recvSize, recvColRank, this->ColComm() );
1175 
1176         // Unpack
1177         ELEM_PARALLEL_FOR
1178         for( Int jLoc=0; jLoc<localWidth; ++jLoc )
1179             MemCopy
1180             ( &thisBuf[jLoc*ldim], &recvBuf[jLoc*localHeight], localHeight );
1181         this->auxMemory_.Release();
1182     }
1183 }
1184 
1185 template<typename T,Dist U,Dist V>
1186 void
PartialColFilterFrom(const DistMatrix<T,UPart,V> & A)1187 GeneralDistMatrix<T,U,V>::PartialColFilterFrom( const DistMatrix<T,UPart,V>& A )
1188 {
1189     DEBUG_ONLY(
1190         CallStackEntry cse("GDM::PartialColFilterFrom");
1191         this->AssertSameGrid( A.Grid() );
1192     )
1193     const Int height = A.Height();
1194     const Int width = A.Width();
1195     this->AlignColsAndResize( A.ColAlign(), height, width, false, false );
1196     if( !this->Participating() )
1197         return;
1198 
1199     const Int colAlign = this->ColAlign();
1200     const Int colAlignA = A.ColAlign();
1201     const Int colStride = this->ColStride();
1202     const Int colStridePart = this->PartialColStride();
1203     const Int colStrideUnion = this->PartialUnionColStride();
1204     const Int colShiftA = A.ColShift();
1205 
1206     const Int localHeight = this->LocalHeight();
1207 
1208     T* thisBuf = this->Buffer();
1209     const Int ldim = this->LDim();
1210     const T* ABuf = A.LockedBuffer();
1211     const Int ALDim = A.LDim();
1212     if( colAlign % colStridePart == colAlignA )
1213     {
1214         const Int colShift = this->ColShift();
1215         const Int colOffset = (colShift-colShiftA) / colStridePart;
1216         ELEM_PARALLEL_FOR
1217         for( Int j=0; j<width; ++j )
1218         {
1219             T* thisCol = &thisBuf[j*ldim];
1220             const T* ACol = &ABuf[colOffset+j*ALDim];
1221             for( Int iLoc=0; iLoc<localHeight; ++iLoc )
1222                 thisCol[iLoc] = ACol[iLoc*colStrideUnion];
1223         }
1224     }
1225     else
1226     {
1227 #ifdef ELEM_UNALIGNED_WARNINGS
1228         if( this->Grid().Rank() == 0 )
1229             std::cerr << "Unaligned PartialColFilterFrom" << std::endl;
1230 #endif
1231         const Int colRankPart = this->PartialColRank();
1232         const Int colRankUnion = this->PartialUnionColRank();
1233         const Int colShiftA = A.ColShift();
1234 
1235         // Realign
1236         // -------
1237         const Int sendColRankPart =
1238             (colRankPart+colStridePart+(colAlign%colStridePart)-colAlignA) %
1239             colStridePart;
1240         const Int recvColRankPart =
1241             (colRankPart+colStridePart+colAlignA-(colAlign%colStridePart)) %
1242             colStridePart;
1243         const Int sendColRank = sendColRankPart + colStridePart*colRankUnion;
1244         const Int sendColShift = Shift( sendColRank, colAlign, colStride );
1245         const Int sendColOffset = (sendColShift-colShiftA) / colStridePart;
1246         const Int localHeightSend = Length( height, sendColShift, colStride );
1247         const Int sendSize = localHeightSend*width;
1248         const Int recvSize = localHeight    *width;
1249         T* buffer = this->auxMemory_.Require( sendSize+recvSize );
1250         T* sendBuf = &buffer[0];
1251         T* recvBuf = &buffer[sendSize];
1252         // Pack
1253         ELEM_PARALLEL_FOR
1254         for( Int j=0; j<width; ++j )
1255         {
1256             T* sendCol = &sendBuf[j*localHeightSend];
1257             const T* ACol = &ABuf[sendColOffset+j*ALDim];
1258             for( Int iLoc=0; iLoc<localHeightSend; ++iLoc )
1259                 sendCol[iLoc] = ACol[iLoc*colStrideUnion];
1260         }
1261         // Change the column alignment
1262         mpi::SendRecv
1263         ( sendBuf, sendSize, sendColRankPart,
1264           recvBuf, recvSize, recvColRankPart, this->PartialColComm() );
1265 
1266         // Unpack
1267         // ------
1268         ELEM_PARALLEL_FOR
1269         for( Int j=0; j<width; ++j )
1270         {
1271             const T* recvCol = &recvBuf[j*localHeight];
1272             T* thisCol = &thisBuf[j*ldim];
1273             MemCopy( thisCol, recvCol, localHeight );
1274         }
1275         this->auxMemory_.Release();
1276     }
1277 }
1278 
1279 template<typename T,Dist U,Dist V>
1280 void
PartialRowFilterFrom(const DistMatrix<T,U,VPart> & A)1281 GeneralDistMatrix<T,U,V>::PartialRowFilterFrom( const DistMatrix<T,U,VPart>& A )
1282 {
1283     DEBUG_ONLY(
1284         CallStackEntry cse("GDM::PartialRowFilterFrom");
1285         this->AssertSameGrid( A.Grid() );
1286     )
1287     const Int height = A.Height();
1288     const Int width = A.Width();
1289     this->AlignRowsAndResize( A.RowAlign(), height, width, false, false );
1290     if( !this->Participating() )
1291         return;
1292 
1293     const Int rowAlign = this->RowAlign();
1294     const Int rowAlignA = A.RowAlign();
1295     const Int rowStride = this->RowStride();
1296     const Int rowStridePart = this->PartialRowStride();
1297     const Int rowStrideUnion = this->PartialUnionRowStride();
1298     const Int rowShiftA = A.RowShift();
1299 
1300     const Int localWidth = this->LocalWidth();
1301 
1302     T* thisBuf = this->Buffer();
1303     const Int ldim = this->LDim();
1304     const T* ABuf = A.LockedBuffer();
1305     const Int ALDim = A.LDim();
1306     if( rowAlign % rowStridePart == rowAlignA )
1307     {
1308         const Int rowShift = this->RowShift();
1309         const Int rowOffset = (rowShift-rowShiftA) / rowStridePart;
1310         ELEM_PARALLEL_FOR
1311         for( Int jLoc=0; jLoc<localWidth; ++jLoc )
1312         {
1313             T* thisCol = &thisBuf[jLoc*ldim];
1314             const T* ACol = &ABuf[(rowOffset+jLoc*rowStrideUnion)*ALDim];
1315             MemCopy( thisCol, ACol, height );
1316         }
1317     }
1318     else
1319     {
1320 #ifdef ELEM_UNALIGNED_WARNINGS
1321         if( this->Grid().Rank() == 0 )
1322             std::cerr << "Unaligned PartialRowFilterFrom" << std::endl;
1323 #endif
1324         const Int rowRankPart = this->PartialRowRank();
1325         const Int rowRankUnion = this->PartialUnionRowRank();
1326         const Int rowShiftA = A.RowShift();
1327 
1328         // Realign
1329         // -------
1330         const Int sendRowRankPart =
1331             (rowRankPart+rowStridePart+(rowAlign%rowStridePart)-rowAlignA) %
1332             rowStridePart;
1333         const Int recvRowRankPart =
1334             (rowRankPart+rowStridePart+rowAlignA-(rowAlign%rowStridePart)) %
1335             rowStridePart;
1336         const Int sendRowRank = sendRowRankPart + rowStridePart*rowRankUnion;
1337         const Int sendRowShift = Shift( sendRowRank, rowAlign, rowStride );
1338         const Int sendRowOffset = (sendRowShift-rowShiftA) / rowStridePart;
1339         const Int localWidthSend = Length( width, sendRowShift, rowStride );
1340         const Int sendSize = height*localWidthSend;
1341         const Int recvSize = height*localWidth;
1342         T* buffer = this->auxMemory_.Require( sendSize+recvSize );
1343         T* sendBuf = &buffer[0];
1344         T* recvBuf = &buffer[sendSize];
1345         // Pack
1346         ELEM_PARALLEL_FOR
1347         for( Int jLoc=0; jLoc<localWidthSend; ++jLoc )
1348         {
1349             T* sendCol = &sendBuf[jLoc*height];
1350             const T* ACol = &ABuf[(sendRowOffset+jLoc*rowStrideUnion)*ALDim];
1351             MemCopy( sendCol, ACol, height );
1352         }
1353         // Change the column alignment
1354         mpi::SendRecv
1355         ( sendBuf, sendSize, sendRowRankPart,
1356           recvBuf, recvSize, recvRowRankPart, this->PartialRowComm() );
1357 
1358         // Unpack
1359         // ------
1360         ELEM_PARALLEL_FOR
1361         for( Int jLoc=0; jLoc<localWidth; ++jLoc )
1362         {
1363             const T* recvCol = &recvBuf[jLoc*height];
1364             T* thisCol = &thisBuf[jLoc*ldim];
1365             MemCopy( thisCol, recvCol, height );
1366         }
1367         this->auxMemory_.Release();
1368     }
1369 }
1370 
1371 template<typename T,Dist U,Dist V>
1372 void
PartialColAllToAllFrom(const DistMatrix<T,UPart,VScat> & A)1373 GeneralDistMatrix<T,U,V>::PartialColAllToAllFrom
1374 ( const DistMatrix<T,UPart,VScat>& A )
1375 {
1376     DEBUG_ONLY(
1377         CallStackEntry cse("GDM::PartialColAllToAllFrom");
1378         this->AssertSameGrid( A.Grid() );
1379     )
1380     const Int height = A.Height();
1381     const Int width = A.Width();
1382     this->AlignColsAndResize( A.ColAlign(), height, width, false, false );
1383     if( !this->Participating() )
1384         return;
1385 
1386     const Int colAlign = this->ColAlign();
1387     const Int colAlignA = A.ColAlign();
1388     const Int rowAlignA = A.RowAlign();
1389 
1390     const Int colStride = this->ColStride();
1391     const Int colStridePart = this->PartialColStride();
1392     const Int colStrideUnion = this->PartialUnionColStride();
1393     const Int colRankPart = this->PartialColRank();
1394 
1395     const Int colShiftA = A.ColShift();
1396 
1397     const Int thisLocalHeight = this->LocalHeight();
1398     const Int localWidthA = A.LocalWidth();
1399     const Int maxLocalHeight = MaxLength(height,colStride);
1400     const Int maxLocalWidth = MaxLength(width,colStrideUnion);
1401     const Int portionSize = mpi::Pad( maxLocalHeight*maxLocalWidth );
1402 
1403     T* thisBuf = this->Buffer();
1404     const Int ldim = this->LDim();
1405     const T* ABuf = A.LockedBuffer();
1406     const Int ALDim = A.LDim();
1407 
1408     T* buffer = this->auxMemory_.Require( 2*colStrideUnion*portionSize );
1409     T* firstBuf  = &buffer[0];
1410     T* secondBuf = &buffer[colStrideUnion*portionSize];
1411 
1412     if( colAlign % colStridePart == colAlignA )
1413     {
1414         // Pack
1415         ELEM_OUTER_PARALLEL_FOR
1416         for( Int k=0; k<colStrideUnion; ++k )
1417         {
1418             T* data = &firstBuf[k*portionSize];
1419             const Int colRank = colRankPart + k*colStridePart;
1420             const Int colShift = Shift_( colRank, colAlign, colStride );
1421             const Int colOffset = (colShift-colShiftA) / colStridePart;
1422             const Int localHeight = Length_( height, colShift, colStride );
1423             ELEM_INNER_PARALLEL_FOR
1424             for( Int jLoc=0; jLoc<localWidthA; ++jLoc )
1425             {
1426                 T* dataCol = &data[jLoc*localHeight];
1427                 const T* ACol = &ABuf[colOffset+jLoc*ALDim];
1428                 for( Int iLoc=0; iLoc<localHeight; ++iLoc )
1429                     dataCol[iLoc] = ACol[iLoc*colStrideUnion];
1430             }
1431         }
1432 
1433         // Simultaneously Scatter in columns and Gather in rows
1434         mpi::AllToAll
1435         ( firstBuf,  portionSize,
1436           secondBuf, portionSize, this->PartialUnionColComm() );
1437 
1438         // Unpack
1439         ELEM_OUTER_PARALLEL_FOR
1440         for( Int k=0; k<colStrideUnion; ++k )
1441         {
1442             const T* data = &secondBuf[k*portionSize];
1443             const Int rowShift = Shift_( k, rowAlignA, colStrideUnion );
1444             const Int localWidth = Length_( width, rowShift, colStrideUnion );
1445             ELEM_INNER_PARALLEL_FOR
1446             for( Int jLoc=0; jLoc<localWidth; ++jLoc )
1447             {
1448                 const T* dataCol = &data[jLoc*thisLocalHeight];
1449                 T* thisCol = &thisBuf[(rowShift+jLoc*colStrideUnion)*ldim];
1450                 MemCopy( thisCol, dataCol, thisLocalHeight );
1451             }
1452         }
1453     }
1454     else
1455     {
1456 #ifdef ELEM_UNALIGNED_WARNINGS
1457         if( this->Grid().Rank() == 0 )
1458             std::cerr << "Unaligned PartialColAllToAllFrom" << std::endl;
1459 #endif
1460         const Int sendColRankPart =
1461             (colRankPart+colStridePart+(colAlign%colStridePart)-colAlignA) %
1462             colStridePart;
1463         const Int recvColRankPart =
1464             (colRankPart+colStridePart+colAlignA-(colAlign%colStridePart)) %
1465             colStridePart;
1466 
1467         // Pack
1468         ELEM_OUTER_PARALLEL_FOR
1469         for( Int k=0; k<colStrideUnion; ++k )
1470         {
1471             T* data = &secondBuf[k*portionSize];
1472             const Int colRank = sendColRankPart + k*colStridePart;
1473             const Int colShift = Shift_( colRank, colAlign, colStride );
1474             const Int colOffset = (colShift-colShiftA) / colStridePart;
1475             const Int localHeight = Length_( height, colShift, colStride );
1476             ELEM_INNER_PARALLEL_FOR
1477             for( Int jLoc=0; jLoc<localWidthA; ++jLoc )
1478             {
1479                 T* dataCol = &data[jLoc*localHeight];
1480                 const T* ACol = &ABuf[colOffset+jLoc*ALDim];
1481                 for( Int iLoc=0; iLoc<localHeight; ++iLoc )
1482                     dataCol[iLoc] = ACol[iLoc*colStrideUnion];
1483             }
1484         }
1485 
1486         // Simultaneously Scatter in columns and Gather in rows
1487         mpi::AllToAll
1488         ( secondBuf, portionSize,
1489           firstBuf,  portionSize, this->PartialUnionColComm() );
1490 
1491         // Realign the result
1492         mpi::SendRecv
1493         ( firstBuf,  colStrideUnion*portionSize, sendColRankPart,
1494           secondBuf, colStrideUnion*portionSize, recvColRankPart,
1495           this->PartialColComm() );
1496 
1497         // Unpack
1498         ELEM_OUTER_PARALLEL_FOR
1499         for( Int k=0; k<colStrideUnion; ++k )
1500         {
1501             const T* data = &secondBuf[k*portionSize];
1502             const Int rowShift = Shift_( k, rowAlignA, colStrideUnion );
1503             const Int localWidth = Length_( width, rowShift, colStrideUnion );
1504             ELEM_INNER_PARALLEL_FOR
1505             for( Int jLoc=0; jLoc<localWidth; ++jLoc )
1506             {
1507                 const T* dataCol = &data[jLoc*thisLocalHeight];
1508                 T* thisCol = &thisBuf[(rowShift+jLoc*colStrideUnion)*ldim];
1509                 MemCopy( thisCol, dataCol, thisLocalHeight );
1510             }
1511         }
1512     }
1513     this->auxMemory_.Release();
1514 }
1515 
// Fills this [U,V] matrix from A, a [UScat,VPart] matrix, using one AllToAll
// over the partial union row communicator.
//
// Each process splits its local columns of A among the members of the union
// communicator (scattering in rows). It then interleaves the rows received
// from every member into its full local columns (gathering in columns).
//
// This matrix is resized to A's dimensions and its row alignment is requested
// as A.RowAlign(); the trailing `false, false` are forwarded unchanged to
// AlignRowsAndResize. If the resulting alignments are incompatible
// (RowAlign() % PartialRowStride() != A.RowAlign()), the data received from
// the AllToAll is additionally forwarded with a SendRecv over the partial row
// communicator before unpacking.
//
// Scratch: 2*PartialUnionRowStride()*portionSize entries of auxMemory_,
// released before returning. Non-participating processes return before any
// scratch is requested.
template<typename T,Dist U,Dist V>
void
GeneralDistMatrix<T,U,V>::PartialRowAllToAllFrom
( const DistMatrix<T,UScat,VPart>& A )
{
    DEBUG_ONLY(
        CallStackEntry cse("GDM::PartialRowAllToAllFrom");
        this->AssertSameGrid( A.Grid() );
    )
    const Int height = A.Height();
    const Int width = A.Width();
    this->AlignRowsAndResize( A.RowAlign(), height, width, false, false );
    if( !this->Participating() )
        return;

    const Int rowAlign = this->RowAlign();
    const Int rowAlignA = A.RowAlign();
    const Int colAlignA = A.ColAlign();

    const Int rowStride = this->RowStride();
    const Int rowStridePart = this->PartialRowStride();
    const Int rowStrideUnion = this->PartialUnionRowStride();
    const Int rowRankPart = this->PartialRowRank();

    const Int rowShiftA = A.RowShift();

    // Every per-destination block is sized for the largest possible piece
    // (padded via mpi::Pad) so that a fixed-count AllToAll can be used.
    const Int thisLocalWidth = this->LocalWidth();
    const Int localHeightA = A.LocalHeight();
    const Int maxLocalHeight = MaxLength(height,rowStrideUnion);
    const Int maxLocalWidth = MaxLength(width,rowStride);
    const Int portionSize = mpi::Pad( maxLocalHeight*maxLocalWidth );

    T* thisBuf = this->Buffer();
    const Int ldim = this->LDim();
    const T* ABuf = A.LockedBuffer();
    const Int ALDim = A.LDim();

    // Two halves of rowStrideUnion blocks each: one is packed/sent from, the
    // other is received into (their roles swap in the unaligned branch).
    T* buffer = this->auxMemory_.Require( 2*rowStrideUnion*portionSize );
    T* firstBuf  = &buffer[0];
    T* secondBuf = &buffer[rowStrideUnion*portionSize];

    if( rowAlign % rowStridePart == rowAlignA )
    {
        // Pack: block k receives the local columns of A that, in this
        // matrix's row distribution, belong to row rank
        // rowRankPart + k*rowStridePart. rowShift/localWidth (via Shift_ and
        // Length_) give that rank's first global column and column count;
        // rowOffset is the matching first local column of A. Successive
        // columns are rowStrideUnion apart in A's local storage, which assumes
        // rowStride == rowStridePart*rowStrideUnion.
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<rowStrideUnion; ++k )
        {
            T* data = &firstBuf[k*portionSize];
            const Int rowRank = rowRankPart + k*rowStridePart;
            const Int rowShift = Shift_( rowRank, rowAlign, rowStride );
            const Int rowOffset = (rowShift-rowShiftA) / rowStridePart;
            const Int localWidth = Length_( width, rowShift, rowStride );
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<localWidth; ++jLoc )
            {
                T* dataCol = &data[jLoc*localHeightA];
                const T* ACol = &ABuf[(rowOffset+jLoc*rowStrideUnion)*ALDim];
                MemCopy( dataCol, ACol, localHeightA );
            }
        }

        // Simultaneously Scatter in rows and Gather in columns
        mpi::AllToAll
        ( firstBuf,  portionSize,
          secondBuf, portionSize, this->PartialUnionRowComm() );

        // Unpack: block k came from union rank k, whose rows of A start at
        // colShift and are interleaved here with stride rowStrideUnion.
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<rowStrideUnion; ++k )
        {
            const T* data = &secondBuf[k*portionSize];
            const Int colShift = Shift_( k, colAlignA, rowStrideUnion );
            const Int localHeight = Length_( height, colShift, rowStrideUnion );
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
            {
                const T* dataCol = &data[jLoc*localHeight];
                T* thisCol = &thisBuf[colShift+jLoc*ldim];
                for( Int iLoc=0; iLoc<localHeight; ++iLoc )
                    thisCol[iLoc*rowStrideUnion] = dataCol[iLoc];
            }
        }
    }
    else
    {
#ifdef ELEM_UNALIGNED_WARNINGS
        if( this->Grid().Rank() == 0 )
            std::cerr << "Unaligned PartialRowAllToAllFrom" << std::endl;
#endif
        // Our columns of A are packed for the processes with partial row rank
        // sendRowRankPart rather than our own. After the AllToAll the result
        // is forwarded to sendRowRankPart, and we receive the data meant for
        // us from recvRowRankPart. The "+rowStridePart" keeps the numerators
        // non-negative before taking the remainder.
        const Int sendRowRankPart =
            (rowRankPart+rowStridePart+(rowAlign%rowStridePart)-rowAlignA) %
            rowStridePart;
        const Int recvRowRankPart =
            (rowRankPart+rowStridePart+rowAlignA-(rowAlign%rowStridePart)) %
            rowStridePart;

        // Pack (same layout as the aligned case, but for partial row rank
        // sendRowRankPart)
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<rowStrideUnion; ++k )
        {
            T* data = &secondBuf[k*portionSize];
            const Int rowRank = sendRowRankPart + k*rowStridePart;
            const Int rowShift = Shift_( rowRank, rowAlign, rowStride );
            const Int rowOffset = (rowShift-rowShiftA) / rowStridePart;
            const Int localWidth = Length_( width, rowShift, rowStride );
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<localWidth; ++jLoc )
            {
                T* dataCol = &data[jLoc*localHeightA];
                const T* ACol = &ABuf[(rowOffset+jLoc*rowStrideUnion)*ALDim];
                MemCopy( dataCol, ACol, localHeightA );
            }
        }

        // Simultaneously Scatter in rows and Gather in columns
        mpi::AllToAll
        ( secondBuf, portionSize,
          firstBuf,  portionSize, this->PartialUnionRowComm() );

        // Realign the result: the whole received buffer moves across the
        // partial row communicator.
        mpi::SendRecv
        ( firstBuf,  rowStrideUnion*portionSize, sendRowRankPart,
          secondBuf, rowStrideUnion*portionSize, recvRowRankPart,
          this->PartialRowComm() );

        // Unpack (identical to the aligned case)
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<rowStrideUnion; ++k )
        {
            const T* data = &secondBuf[k*portionSize];
            const Int colShift = Shift_( k, colAlignA, rowStrideUnion );
            const Int localHeight = Length_( height, colShift, rowStrideUnion );
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
            {
                const T* dataCol = &data[jLoc*localHeight];
                T* thisCol = &thisBuf[colShift+jLoc*ldim];
                for( Int iLoc=0; iLoc<localHeight; ++iLoc )
                    thisCol[iLoc*rowStrideUnion] = dataCol[iLoc];
            }
        }
    }
    this->auxMemory_.Release();
}
1660 
// Redistributes this [U,V] matrix into A, a [UPart,VScat] matrix, using one
// AllToAll over the partial union column communicator.
//
// Each process splits its local columns among the members of the union
// communicator (scattering in rows). The rows received from every member are
// interleaved into A's local columns (gathering in columns).
//
// A is resized to this matrix's dimensions and its column alignment is
// requested as ColAlign() % A.ColStride(); the trailing `false, false` are
// forwarded unchanged to AlignColsAndResize. If the resulting alignment is
// incompatible (A.ColAlign() != ColAlign() % PartialColStride()), the packed
// data is first moved with a SendRecv over the partial column communicator.
//
// Scratch: 2*PartialUnionColStride()*portionSize entries of A.auxMemory_,
// released before returning. Processes where A is not participating return
// before requesting scratch.
template<typename T,Dist U,Dist V>
void
GeneralDistMatrix<T,U,V>::PartialColAllToAll
( DistMatrix<T,UPart,VScat>& A ) const
{
    DEBUG_ONLY(
        CallStackEntry cse("GDM::PartialColAllToAll");
        this->AssertSameGrid( A.Grid() );
    )
    const Int height = this->Height();
    const Int width = this->Width();
    A.AlignColsAndResize
    ( this->ColAlign()%A.ColStride(), height, width, false, false );
    if( !A.Participating() )
        return;

    const Int colAlign = this->ColAlign();
    const Int colAlignA = A.ColAlign();
    const Int rowAlignA = A.RowAlign();

    const Int colStride = this->ColStride();
    const Int colStridePart = this->PartialColStride();
    const Int colStrideUnion = this->PartialUnionColStride();
    const Int colRankPart = this->PartialColRank();

    const Int colShiftA = A.ColShift();

    // Every per-destination block is sized for the largest possible piece
    // (padded via mpi::Pad) so that a fixed-count AllToAll can be used.
    const Int thisLocalHeight = this->LocalHeight();
    const Int localWidthA = A.LocalWidth();
    const Int maxLocalHeight = MaxLength(height,colStride);
    const Int maxLocalWidth = MaxLength(width,colStrideUnion);
    const Int portionSize = mpi::Pad( maxLocalHeight*maxLocalWidth );

    const T* thisBuf = this->LockedBuffer();
    const Int ldim = this->LDim();
    T* ABuf = A.Buffer();
    const Int ALDim = A.LDim();

    // Scratch is taken from the destination matrix since this method is const.
    T* buffer = A.auxMemory_.Require( 2*colStrideUnion*portionSize );
    T* firstBuf  = &buffer[0];
    T* secondBuf = &buffer[colStrideUnion*portionSize];

    if( colAlignA == colAlign % colStridePart )
    {
        // Pack: block k gets the columns that A's union rank k owns, i.e.
        // columns rowShift, rowShift+colStrideUnion, ... of our local data,
        // each copied whole (thisLocalHeight entries).
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<colStrideUnion; ++k )
        {
            T* data = &firstBuf[k*portionSize];
            const Int rowShift = Shift_( k, rowAlignA, colStrideUnion );
            const Int localWidth = Length_( width, rowShift, colStrideUnion );
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<localWidth; ++jLoc )
                MemCopy
                ( &data[jLoc*thisLocalHeight],
                  &thisBuf[(rowShift+jLoc*colStrideUnion)*ldim],
                  thisLocalHeight );
        }

        // Simultaneously Gather in columns and Scatter in rows
        mpi::AllToAll
        ( firstBuf,  portionSize,
          secondBuf, portionSize, this->PartialUnionColComm() );

        // Unpack: block k came from the process whose column rank (in this
        // matrix's distribution) is colRankPart + k*colStridePart. Its rows
        // start at local row colOffset of A and are interleaved with stride
        // colStrideUnion.
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<colStrideUnion; ++k )
        {
            const T* data = &secondBuf[k*portionSize];
            const Int colRank = colRankPart + k*colStridePart;
            const Int colShift = Shift_( colRank, colAlign, colStride );
            const Int colOffset = (colShift-colShiftA) / colStridePart;
            const Int localHeight = Length_( height, colShift, colStride );
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<localWidthA; ++jLoc )
            {
                T* ACol = &ABuf[colOffset+jLoc*ALDim];
                const T* dataCol = &data[jLoc*localHeight];
                for( Int iLoc=0; iLoc<localHeight; ++iLoc )
                    ACol[iLoc*colStrideUnion] = dataCol[iLoc];
            }
        }
    }
    else
    {
#ifdef ELEM_UNALIGNED_WARNINGS
        if( this->Grid().Rank() == 0 )
            std::cerr << "Unaligned PartialColAllToAll" << std::endl;
#endif
        // Our rows belong in A to partial column rank sendColRankPart. We
        // therefore ship the packed data there first and receive the rows
        // meant for us from recvColRankPart. The "+colStridePart" keeps the
        // numerators non-negative before taking the remainder.
        const Int colAlignDiff = colAlignA - (colAlign%colStridePart);
        const Int sendColRankPart =
            (colRankPart+colStridePart+colAlignDiff) % colStridePart;
        const Int recvColRankPart =
            (colRankPart+colStridePart-colAlignDiff) % colStridePart;

        // Pack (same layout as the aligned case)
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<colStrideUnion; ++k )
        {
            T* data = &secondBuf[k*portionSize];
            const Int rowShift = Shift_( k, rowAlignA, colStrideUnion );
            const Int localWidth = Length_( width, rowShift, colStrideUnion );
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<localWidth; ++jLoc )
                MemCopy
                ( &data[jLoc*thisLocalHeight],
                  &thisBuf[(rowShift+jLoc*colStrideUnion)*ldim],
                  thisLocalHeight );
        }

        // Realign the input
        mpi::SendRecv
        ( secondBuf, colStrideUnion*portionSize, sendColRankPart,
          firstBuf,  colStrideUnion*portionSize, recvColRankPart,
          this->PartialColComm() );

        // Simultaneously Gather in columns and Scatter in rows
        mpi::AllToAll
        ( firstBuf,  portionSize,
          secondBuf, portionSize, this->PartialUnionColComm() );

        // Unpack: same as the aligned case, except that the data originated
        // from partial column rank recvColRankPart rather than our own.
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<colStrideUnion; ++k )
        {
            const T* data = &secondBuf[k*portionSize];
            const Int colRank = recvColRankPart + k*colStridePart;
            const Int colShift = Shift_( colRank, colAlign, colStride );
            const Int colOffset = (colShift-colShiftA) / colStridePart;
            const Int localHeight = Length_( height, colShift, colStride );
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<localWidthA; ++jLoc )
            {
                T* ACol = &ABuf[colOffset+jLoc*ALDim];
                const T* dataCol = &data[jLoc*localHeight];
                for( Int iLoc=0; iLoc<localHeight; ++iLoc )
                    ACol[iLoc*colStrideUnion] = dataCol[iLoc];
            }
        }
    }
    A.auxMemory_.Release();
}
1803 
// Redistributes this [U,V] matrix into A, a [UScat,VPart] matrix, using one
// AllToAll over the partial union row communicator.
//
// Each process splits its local rows among the members of the union
// communicator, taking every rowStrideUnion-th row starting at that member's
// column shift (scattering in columns). The columns received from every member
// are placed into A's local columns with stride rowStrideUnion (gathering in
// rows).
//
// A is resized to this matrix's dimensions and its row alignment is requested
// as RowAlign() % A.RowStride(); the trailing `false, false` are forwarded
// unchanged to AlignRowsAndResize. If the resulting alignment is incompatible
// (A.RowAlign() != RowAlign() % PartialRowStride()), the packed data is first
// moved with a SendRecv over the partial row communicator.
//
// Scratch: 2*PartialUnionRowStride()*portionSize entries of A.auxMemory_,
// released before returning. Processes where A is not participating return
// before requesting scratch.
template<typename T,Dist U,Dist V>
void
GeneralDistMatrix<T,U,V>::PartialRowAllToAll
( DistMatrix<T,UScat,VPart>& A ) const
{
    DEBUG_ONLY(
        CallStackEntry cse("GDM::PartialRowAllToAll");
        this->AssertSameGrid( A.Grid() );
    )
    const Int height = this->Height();
    const Int width = this->Width();
    A.AlignRowsAndResize
    ( this->RowAlign()%A.RowStride(), height, width, false, false );
    if( !A.Participating() )
        return;

    const Int colAlignA = A.ColAlign();
    const Int rowAlign = this->RowAlign();
    const Int rowAlignA = A.RowAlign();

    const Int rowStride = this->RowStride();
    const Int rowStridePart = this->PartialRowStride();
    const Int rowStrideUnion = this->PartialUnionRowStride();
    const Int rowRankPart = this->PartialRowRank();

    const Int rowShiftA = A.RowShift();

    // Every per-destination block is sized for the largest possible piece
    // (padded via mpi::Pad) so that a fixed-count AllToAll can be used.
    const Int thisLocalWidth = this->LocalWidth();
    const Int localHeightA = A.LocalHeight();
    const Int maxLocalWidth = MaxLength(width,rowStride);
    const Int maxLocalHeight = MaxLength(height,rowStrideUnion);
    const Int portionSize = mpi::Pad( maxLocalHeight*maxLocalWidth );

    const T* thisBuf = this->LockedBuffer();
    const Int ldim = this->LDim();
    T* ABuf = A.Buffer();
    const Int ALDim = A.LDim();

    // Scratch is taken from the destination matrix since this method is const.
    T* buffer = A.auxMemory_.Require( 2*rowStrideUnion*portionSize );
    T* firstBuf  = &buffer[0];
    T* secondBuf = &buffer[rowStrideUnion*portionSize];

    if( rowAlignA == rowAlign % rowStridePart )
    {
        // Pack: block k gets, for every local column, the rows that A's union
        // rank k owns (starting at colShift, stride rowStrideUnion).
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<rowStrideUnion; ++k )
        {
            T* data = &firstBuf[k*portionSize];
            const Int colShift = Shift_( k, colAlignA, rowStrideUnion );
            const Int localHeight = Length_( height, colShift, rowStrideUnion );
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
            {
                T* dataCol = &data[jLoc*localHeight];
                const T* thisCol = &thisBuf[colShift+jLoc*ldim];
                for( Int iLoc=0; iLoc<localHeight; ++iLoc )
                    dataCol[iLoc] = thisCol[iLoc*rowStrideUnion];
            }
        }

        // Simultaneously Gather in rows and Scatter in columns
        mpi::AllToAll
        ( firstBuf,  portionSize,
          secondBuf, portionSize, this->PartialUnionRowComm() );

        // Unpack: block k came from the process whose row rank (in this
        // matrix's distribution) is rowRankPart + k*rowStridePart. Its columns
        // start at local column rowOffset of A, stride rowStrideUnion.
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<rowStrideUnion; ++k )
        {
            const T* data = &secondBuf[k*portionSize];
            const Int rowRank = rowRankPart + k*rowStridePart;
            const Int rowShift = Shift_( rowRank, rowAlign, rowStride );
            const Int rowOffset = (rowShift-rowShiftA) / rowStridePart;
            const Int localWidth = Length_( width, rowShift, rowStride );
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<localWidth; ++jLoc )
                MemCopy
                ( &ABuf[(rowOffset+jLoc*rowStrideUnion)*ALDim],
                  &data[jLoc*localHeightA], localHeightA );
        }
    }
    else
    {
#ifdef ELEM_UNALIGNED_WARNINGS
        if( this->Grid().Rank() == 0 )
            std::cerr << "Unaligned PartialRowAllToAll" << std::endl;
#endif
        // Our columns belong in A to partial row rank sendRowRankPart. We
        // therefore ship the packed data there first and receive the columns
        // meant for us from recvRowRankPart. The "+rowStridePart" keeps the
        // numerators non-negative before taking the remainder.
        const Int rowAlignDiff = rowAlignA - (rowAlign%rowStridePart);
        const Int sendRowRankPart =
            (rowRankPart+rowStridePart+rowAlignDiff) % rowStridePart;
        const Int recvRowRankPart =
            (rowRankPart+rowStridePart-rowAlignDiff) % rowStridePart;

        // Pack (same layout as the aligned case)
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<rowStrideUnion; ++k )
        {
            T* data = &secondBuf[k*portionSize];
            const Int colShift = Shift_( k, colAlignA, rowStrideUnion );
            const Int localHeight = Length_( height, colShift, rowStrideUnion );
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
            {
                T* dataCol = &data[jLoc*localHeight];
                const T* sourceCol = &thisBuf[colShift+jLoc*ldim];
                for( Int iLoc=0; iLoc<localHeight; ++iLoc )
                    dataCol[iLoc] = sourceCol[iLoc*rowStrideUnion];
            }
        }

        // Realign the input
        mpi::SendRecv
        ( secondBuf, rowStrideUnion*portionSize, sendRowRankPart,
          firstBuf,  rowStrideUnion*portionSize, recvRowRankPart,
          this->PartialRowComm() );

        // Simultaneously Gather in rows and Scatter in columns
        mpi::AllToAll
        ( firstBuf,  portionSize,
          secondBuf, portionSize, this->PartialUnionRowComm() );

        // Unpack: same as the aligned case, except that the data originated
        // from partial row rank recvRowRankPart rather than our own.
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<rowStrideUnion; ++k )
        {
            const T* data = &secondBuf[k*portionSize];
            const Int rowRank = recvRowRankPart + k*rowStridePart;
            const Int rowShift = Shift_( rowRank, rowAlign, rowStride );
            const Int rowOffset = (rowShift-rowShiftA) / rowStridePart;
            const Int localWidth = Length_( width, rowShift, rowStride );
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<localWidth; ++jLoc )
                MemCopy
                ( &ABuf[(rowOffset+jLoc*rowStrideUnion)*ALDim],
                  &data[jLoc*localHeightA], localHeightA );
        }
    }
    A.auxMemory_.Release();
}
1944 
1945 template<typename T,Dist U,Dist V>
1946 void
RowSumScatterFrom(const DistMatrix<T,U,VGath> & A)1947 GeneralDistMatrix<T,U,V>::RowSumScatterFrom( const DistMatrix<T,U,VGath>& A )
1948 {
1949     DEBUG_ONLY(
1950         CallStackEntry cse("GDM::RowSumScatterFrom");
1951         this->AssertSameGrid( A.Grid() );
1952     )
1953     this->AlignColsAndResize
1954     ( A.ColAlign(), A.Height(), A.Width(), false, false );
1955     // NOTE: This will be *slightly* slower than necessary due to the result
1956     //       of the MPI operations being added rather than just copied
1957     Zeros( this->Matrix(), this->LocalHeight(), this->LocalWidth() );
1958     this->RowSumScatterUpdate( T(1), A );
1959 }
1960 
1961 template<typename T,Dist U,Dist V>
1962 void
ColSumScatterFrom(const DistMatrix<T,UGath,V> & A)1963 GeneralDistMatrix<T,U,V>::ColSumScatterFrom( const DistMatrix<T,UGath,V>& A )
1964 {
1965     DEBUG_ONLY(
1966         CallStackEntry cse("GDM::ColSumScatterFrom");
1967         this->AssertSameGrid( A.Grid() );
1968     )
1969     this->AlignRowsAndResize
1970     ( A.RowAlign(), A.Height(), A.Width(), false, false );
1971     // NOTE: This will be *slightly* slower than necessary due to the result
1972     //       of the MPI operations being added rather than just copied
1973     Zeros( this->Matrix(), this->LocalHeight(), this->LocalWidth() );
1974     this->ColSumScatterUpdate( T(1), A );
1975 }
1976 
1977 template<typename T,Dist U,Dist V>
1978 void
SumScatterFrom(const DistMatrix<T,UGath,VGath> & A)1979 GeneralDistMatrix<T,U,V>::SumScatterFrom( const DistMatrix<T,UGath,VGath>& A )
1980 {
1981     DEBUG_ONLY(
1982         CallStackEntry cse("GDM::SumScatterFrom");
1983         this->AssertSameGrid( A.Grid() );
1984     )
1985     this->Resize( A.Height(), A.Width() );
1986     // NOTE: This will be *slightly* slower than necessary due to the result
1987     //       of the MPI operations being added rather than just copied
1988     Zeros( this->Matrix(), this->LocalHeight(), this->LocalWidth() );
1989     this->SumScatterUpdate( T(1), A );
1990 }
1991 
1992 template<typename T,Dist U,Dist V>
1993 void
PartialRowSumScatterFrom(const DistMatrix<T,U,VPart> & A)1994 GeneralDistMatrix<T,U,V>::PartialRowSumScatterFrom
1995 ( const DistMatrix<T,U,VPart>& A )
1996 {
1997     DEBUG_ONLY(
1998         CallStackEntry cse("GDM::PartialRowSumScatterFrom");
1999         this->AssertSameGrid( A.Grid() );
2000     )
2001     this->AlignAndResize
2002     ( A.ColAlign(), A.RowAlign(), A.Height(), A.Width(), false, false );
2003     // NOTE: This will be *slightly* slower than necessary due to the result
2004     //       of the MPI operations being added rather than just copied
2005     Zeros( this->Matrix(), this->LocalHeight(), this->LocalWidth() );
2006     this->PartialRowSumScatterUpdate( T(1), A );
2007 }
2008 
2009 template<typename T,Dist U,Dist V>
2010 void
PartialColSumScatterFrom(const DistMatrix<T,UPart,V> & A)2011 GeneralDistMatrix<T,U,V>::PartialColSumScatterFrom
2012 ( const DistMatrix<T,UPart,V>& A )
2013 {
2014     DEBUG_ONLY(
2015         CallStackEntry cse("GDM::PartialColSumScatterFrom");
2016         this->AssertSameGrid( A.Grid() );
2017     )
2018     this->AlignAndResize
2019     ( A.ColAlign(), A.RowAlign(), A.Height(), A.Width(), false, false );
2020     // NOTE: This will be *slightly* slower than necessary due to the result
2021     //       of the MPI operations being added rather than just copied
2022     Zeros( this->Matrix(), this->LocalHeight(), this->LocalWidth() );
2023     this->PartialColSumScatterUpdate( T(1), A );
2024 }
2025 
// Adds alpha times the row-communicator sum of A into this [U,V] matrix.
//
// A is a [U,VGath] matrix: the pack loops index A's local buffer by global
// column, so each process holds every column of its local rows. The
// contributions of all processes in a row communicator are summed and
// scattered so that each process receives only the columns it owns.
//
// Four paths:
//  * aligned columns, single column:  Reduce to row rank rowAlign, which
//    then adds the result into its local column;
//  * aligned columns, several columns: in-place ReduceScatter over the row
//    communicator, then an Axpy per local column;
//  * unaligned columns, single column: Reduce to rowAlign, then a SendRecv
//    over the column communicator to shift rows to their owners;
//  * unaligned columns, several columns: ReduceScatter, then a SendRecv over
//    the column communicator, then accumulate.
// Each path borrows scratch from auxMemory_ and releases it before leaving.
template<typename T,Dist U,Dist V>
void
GeneralDistMatrix<T,U,V>::RowSumScatterUpdate
( T alpha, const DistMatrix<T,U,VGath>& A )
{
    DEBUG_ONLY(
        CallStackEntry cse("GDM::RowSumScatterUpdate");
        this->AssertNotLocked();
        this->AssertSameGrid( A.Grid() );
        this->AssertSameSize( A.Height(), A.Width() );
    )
    if( !this->Participating() )
        return;

    if( this->ColAlign() == A.ColAlign() )
    {
        if( this->Width() == 1 )
        {
            // The code treats row rank rowAlign as the holder of the single
            // column: it is the Reduce root and the only rank that updates.
            const Int rowAlign = this->RowAlign();
            const Int rowRank = this->RowRank();

            const Int localHeight = this->LocalHeight();
            const Int portionSize = mpi::Pad( localHeight );
            T* buffer = this->auxMemory_.Require( 2*portionSize );
            T* sendBuf = &buffer[0];
            T* recvBuf = &buffer[portionSize];

            // Pack (aligned columns, so A's local height equals ours)
            const T* ACol = A.LockedBuffer();
            MemCopy( sendBuf, ACol, localHeight );

            // Reduce to rowAlign
            mpi::Reduce
            ( sendBuf, recvBuf, portionSize, rowAlign, this->RowComm() );

            if( rowRank == rowAlign )
            {
                T* thisCol = this->Buffer();
                ELEM_FMA_PARALLEL_FOR
                for( Int iLoc=0; iLoc<localHeight; ++iLoc )
                    thisCol[iLoc] += alpha*recvBuf[iLoc];
            }

            this->auxMemory_.Release();
        }
        else
        {
            const Int rowStride = this->RowStride();
            const Int rowAlign = this->RowAlign();

            const Int width = this->Width();
            const Int localHeight = this->LocalHeight();
            const Int localWidth = this->LocalWidth();
            const Int maxLocalWidth = MaxLength(width,rowStride);

            // One fixed-size block per row rank so that ReduceScatter can use
            // a uniform count.
            const Int portionSize = mpi::Pad( localHeight*maxLocalWidth );
            const Int sendSize = rowStride*portionSize;

            // Pack: block k holds the columns owned by row rank k
            // (thisRowShift, thisRowShift+rowStride, ...), each localHeight
            // entries long.
            const Int ALDim = A.LDim();
            const T* ABuffer = A.LockedBuffer();
            T* buffer = this->auxMemory_.Require( sendSize );
            ELEM_OUTER_PARALLEL_FOR
            for( Int k=0; k<rowStride; ++k )
            {
                T* data = &buffer[k*portionSize];
                const Int thisRowShift = Shift_( k, rowAlign, rowStride );
                const Int thisLocalWidth =
                    Length_(width,thisRowShift,rowStride);
                ELEM_INNER_PARALLEL_FOR
                for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
                {
                    const T* ACol =
                        &ABuffer[(thisRowShift+jLoc*rowStride)*ALDim];
                    T* dataCol = &data[jLoc*localHeight];
                    MemCopy( dataCol, ACol, localHeight );
                }
            }
            // Communicate: in-place variant; our reduced block is read back
            // from the start of `buffer` below.
            mpi::ReduceScatter( buffer, portionSize, this->RowComm() );

            // Update with our received data
            T* thisBuffer = this->Buffer();
            const Int thisLDim = this->LDim();
            ELEM_PARALLEL_FOR
            for( Int jLoc=0; jLoc<localWidth; ++jLoc )
            {
                const T* bufferCol = &buffer[jLoc*localHeight];
                T* thisCol = &thisBuffer[jLoc*thisLDim];
                blas::Axpy( localHeight, alpha, bufferCol, 1, thisCol, 1 );
            }
            this->auxMemory_.Release();
        }
    }
    else
    {
#ifdef ELEM_UNALIGNED_WARNINGS
        if( this->Grid().Rank() == 0 )
            std::cerr << "Unaligned RowSumScatterUpdate" << std::endl;
#endif
        if( this->Width() == 1 )
        {
            const Int colStride = this->ColStride();
            const Int rowAlign = this->RowAlign();
            const Int colRank = this->ColRank();
            const Int rowRank = this->RowRank();

            const Int height = this->Height();
            const Int localHeight = this->LocalHeight();
            const Int localHeightA = A.LocalHeight();
            const Int maxLocalHeight = MaxLength(height,colStride);
            const Int portionSize = mpi::Pad( maxLocalHeight );

            // Column-communicator partners for shifting rows from A's column
            // alignment to ours. The "+colStride" keeps the numerators
            // non-negative before taking the remainder.
            const Int colAlign = this->ColAlign();
            const Int colAlignA = A.ColAlign();
            const Int sendRow =
                (colRank+colStride+colAlign-colAlignA) % colStride;
            const Int recvRow =
                (colRank+colStride+colAlignA-colAlign) % colStride;

            T* buffer = this->auxMemory_.Require( 2*portionSize );
            T* sendBuf = &buffer[0];
            T* recvBuf = &buffer[portionSize];

            // Pack
            // NOTE(review): only localHeightA of the portionSize entries are
            // initialised here, yet portionSize entries are reduced and sent.
            // Presumably only the first localHeight received entries are
            // meaningful (they are the only ones read below) — confirm.
            const T* ACol = A.LockedBuffer();
            MemCopy( sendBuf, ACol, localHeightA );

            // Reduce to rowAlign
            mpi::Reduce
            ( sendBuf, recvBuf, portionSize, rowAlign, this->RowComm() );

            if( rowRank == rowAlign )
            {
                // Perform the realignment (sendBuf is reused as the receive
                // buffer)
                mpi::SendRecv
                ( recvBuf, portionSize, sendRow,
                  sendBuf, portionSize, recvRow, this->ColComm() );

                T* thisCol = this->Buffer();
                ELEM_FMA_PARALLEL_FOR
                for( Int iLoc=0; iLoc<localHeight; ++iLoc )
                    thisCol[iLoc] += alpha*sendBuf[iLoc];
            }
            this->auxMemory_.Release();
        }
        else
        {
            const Int colStride = this->ColStride();
            const Int rowStride = this->RowStride();
            const Int colRank = this->ColRank();

            // Column-communicator partners for shifting rows from A's column
            // alignment to ours.
            const Int colAlign = this->ColAlign();
            const Int rowAlign = this->RowAlign();
            const Int colAlignA = A.ColAlign();
            const Int sendRow =
                (colRank+colStride+colAlign-colAlignA) % colStride;
            const Int recvRow =
                (colRank+colStride+colAlignA-colAlign) % colStride;

            const Int width = this->Width();
            const Int localHeight = this->LocalHeight();
            const Int localWidth = this->LocalWidth();
            const Int localHeightA = A.LocalHeight();
            const Int maxLocalWidth = MaxLength(width,rowStride);

            // _RS sizes are for the ReduceScatter, _SR for the SendRecv.
            // secondBuf is first the ReduceScatter send buffer and is later
            // reused to receive the SendRecv, so it gets the larger of the
            // two sizes.
            const Int recvSize_RS = mpi::Pad( localHeightA*maxLocalWidth );
            const Int sendSize_RS = rowStride * recvSize_RS;
            const Int recvSize_SR = localHeight * localWidth;

            T* buffer = this->auxMemory_.Require
                ( recvSize_RS + std::max(sendSize_RS,recvSize_SR) );
            T* firstBuf = &buffer[0];
            T* secondBuf = &buffer[recvSize_RS];

            // Pack: block k holds the columns owned by row rank k, each
            // localHeightA entries long (A's local rows).
            const T* ABuffer = A.LockedBuffer();
            const Int ALDim = A.LDim();
            ELEM_OUTER_PARALLEL_FOR
            for( Int k=0; k<rowStride; ++k )
            {
                T* data = &secondBuf[k*recvSize_RS];
                const Int thisRowShift = Shift_( k, rowAlign, rowStride );
                const Int thisLocalWidth =
                    Length_(width,thisRowShift,rowStride);
                ELEM_INNER_PARALLEL_FOR
                for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
                {
                    const T* ACol =
                        &ABuffer[(thisRowShift+jLoc*rowStride)*ALDim];
                    T* dataCol = &data[jLoc*localHeightA];
                    MemCopy( dataCol, ACol, localHeightA );
                }
            }

            // Reduce-scatter over each process row
            mpi::ReduceScatter
            ( secondBuf, firstBuf, recvSize_RS, this->RowComm() );

            // Trade reduced data with the appropriate process row
            mpi::SendRecv
            ( firstBuf,  localHeightA*localWidth, sendRow,
              secondBuf, localHeight*localWidth,  recvRow, this->ColComm() );

            // Update with our received data
            T* thisBuffer = this->Buffer();
            const Int thisLDim = this->LDim();
            ELEM_FMA_PARALLEL_FOR
            for( Int jLoc=0; jLoc<localWidth; ++jLoc )
            {
                const T* secondBufCol = &secondBuf[jLoc*localHeight];
                T* thisCol = &thisBuffer[jLoc*thisLDim];
                for( Int iLoc=0; iLoc<localHeight; ++iLoc )
                    thisCol[iLoc] += alpha*secondBufCol[iLoc];
            }
            this->auxMemory_.Release();
        }
    }
}
2245 
// ColSumScatterUpdate: *this += alpha * S, where S is A summed over the
// column communicator (mpi::ReduceScatter over ColComm) and scattered so that
// each process keeps only the rows it owns in *this.  Every process holds all
// rows of its local columns of A (UGath), which is why the pack loops index A
// with the global row shift/stride of *this.
//
// Aligned case (row alignments of *this and A match): the summed block lands
// directly on the process that owns those columns in *this.
// Unaligned case: the summed block is computed for A's local columns and then
// exchanged over RowComm with the process that owns those columns in *this.
template<typename T,Dist U,Dist V>
void
GeneralDistMatrix<T,U,V>::ColSumScatterUpdate
( T alpha, const DistMatrix<T,UGath,V>& A )
{
    DEBUG_ONLY(
        CallStackEntry cse("GDM::ColSumScatterUpdate");
        this->AssertNotLocked();
        this->AssertSameGrid( A.Grid() );
        this->AssertSameSize( A.Height(), A.Width() );
    )
#ifdef ELEM_VECTOR_WARNINGS
    if( A.Width() == 1 && this->Grid().Rank() == 0 )
    {
        std::cerr <<
          "The vector version of ColSumScatterUpdate does not"
          " yet have a vector version implemented, but it would only "
          "require a modification of the vector version of RowSumScatterUpdate"
          << std::endl;
    }
#endif
#ifdef ELEM_CACHE_WARNINGS
    if( A.Width() != 1 && this->Grid().Rank() == 0 )
    {
        std::cerr <<
          "ColSumScatterUpdate potentially causes a large "
          "amount of cache-thrashing. If possible, avoid it by forming the "
          "(conjugate-)transpose of the [* ,V] matrix instead." << std::endl;
    }
#endif
    if( !this->Participating() )
        return;

    if( this->RowAlign() == A.RowAlign() )
    {
        const Int colStride = this->ColStride();
        const Int colAlign = this->ColAlign();
        const Int height = this->Height();
        const Int localHeight = this->LocalHeight();
        const Int localWidth = this->LocalWidth();
        const Int maxLocalHeight = MaxLength(height,colStride);

        // One padded slot per column-team member, each large enough for the
        // biggest possible local block (maxLocalHeight x localWidth).
        const Int recvSize = mpi::Pad( maxLocalHeight*localWidth );
        const Int sendSize = colStride*recvSize;

        // Pack: slot k receives the rows that column-rank k owns in *this,
        // i.e. rows thisColShift, thisColShift+colStride, ..., stored
        // contiguously column by column (leading dimension thisLocalHeight).
        const T* ABuffer = A.LockedBuffer();
        const Int ALDim = A.LDim();
        T* buffer = this->auxMemory_.Require( sendSize );
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<colStride; ++k )
        {
            T* data = &buffer[k*recvSize];
            const Int thisColShift = Shift_( k, colAlign, colStride );
            const Int thisLocalHeight = Length_(height,thisColShift,colStride);
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<localWidth; ++jLoc )
            {
                T* destCol = &data[jLoc*thisLocalHeight];
                const T* sourceCol = &ABuffer[thisColShift+jLoc*ALDim];
                for( Int iLoc=0; iLoc<thisLocalHeight; ++iLoc )
                    destCol[iLoc] = sourceCol[iLoc*colStride];
            }
        }

        // Communicate: in-place reduce-scatter over the column team.  The
        // update below reads our summed block from the start of `buffer`, as
        // the in-place overload is presumed to leave it there.
        mpi::ReduceScatter( buffer, recvSize, this->ColComm() );

        // Update with our received data: thisCol += alpha * bufferCol
        T* thisBuffer = this->Buffer();
        const Int thisLDim = this->LDim();
        ELEM_FMA_PARALLEL_FOR
        for( Int jLoc=0; jLoc<localWidth; ++jLoc )
        {
            const T* bufferCol = &buffer[jLoc*localHeight];
            T* thisCol = &thisBuffer[jLoc*thisLDim];
            for( Int iLoc=0; iLoc<localHeight; ++iLoc )
                thisCol[iLoc] += alpha*bufferCol[iLoc];
        }
        this->auxMemory_.Release();
    }
    else
    {
#ifdef ELEM_UNALIGNED_WARNINGS
        if( this->Grid().Rank() == 0 )
            std::cerr << "Unaligned ColSumScatterUpdate" << std::endl;
#endif
        const Int colStride = this->ColStride();
        const Int rowStride = this->RowStride();
        const Int rowRank = this->RowRank();

        const Int colAlign = this->ColAlign();
        const Int rowAlign = this->RowAlign();
        const Int rowAlignA = A.RowAlign();
        // Partners in RowComm: sendCol owns (in *this) the columns we hold
        // of A; recvCol holds (in A) the columns we own in *this.  Adding
        // rowStride keeps the left operand of % non-negative.
        const Int sendCol = (rowRank+rowStride+rowAlign-rowAlignA) % rowStride;
        const Int recvCol = (rowRank+rowStride+rowAlignA-rowAlign) % rowStride;

        const Int height = this->Height();
        const Int localHeight = this->LocalHeight();
        const Int localWidth = this->LocalWidth();
        const Int localWidthA = A.LocalWidth();
        const Int maxLocalHeight = MaxLength(height,colStride);

        // _RS: reduce-scatter sizes (based on A's local width);
        // _SR: size of the block received in the SendRecv exchange.
        const Int recvSize_RS = mpi::Pad( maxLocalHeight*localWidthA );
        const Int sendSize_RS = colStride * recvSize_RS;
        const Int recvSize_SR = localHeight * localWidth;

        // firstBuf receives the reduce-scatter result; secondBuf is first the
        // reduce-scatter send buffer and later the SendRecv receive buffer,
        // so it is sized for the larger of the two uses.
        T* buffer = this->auxMemory_.Require
            ( recvSize_RS + std::max(sendSize_RS,recvSize_SR) );
        T* firstBuf = &buffer[0];
        T* secondBuf = &buffer[recvSize_RS];

        // Pack: same layout as the aligned case, but over A's local columns.
        const T* ABuffer = A.LockedBuffer();
        const Int ALDim = A.LDim();
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<colStride; ++k )
        {
            T* data = &secondBuf[k*recvSize_RS];
            const Int thisColShift = Shift_( k, colAlign, colStride );
            const Int thisLocalHeight = Length_(height,thisColShift,colStride);
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<localWidthA; ++jLoc )
            {
                T* destCol = &data[jLoc*thisLocalHeight];
                const T* sourceCol = &ABuffer[thisColShift+jLoc*ALDim];
                for( Int iLoc=0; iLoc<thisLocalHeight; ++iLoc )
                    destCol[iLoc] = sourceCol[iLoc*colStride];
            }
        }

        // Reduce-scatter over each col
        mpi::ReduceScatter( secondBuf, firstBuf, recvSize_RS, this->ColComm() );

        // Trade reduced data with the appropriate col: send our
        // localHeight x localWidthA block, receive a localHeight x localWidth
        // block covering the columns we own in *this.
        mpi::SendRecv
        ( firstBuf,  localHeight*localWidthA, sendCol,
          secondBuf, localHeight*localWidth,  recvCol, this->RowComm() );

        // Update with our received data: thisCol += alpha * secondBufCol
        T* thisBuffer = this->Buffer();
        const Int thisLDim = this->LDim();
        ELEM_FMA_PARALLEL_FOR
        for( Int jLoc=0; jLoc<localWidth; ++jLoc )
        {
            const T* secondBufCol = &secondBuf[jLoc*localHeight];
            T* thisCol = &thisBuffer[jLoc*thisLDim];
            for( Int iLoc=0; iLoc<localHeight; ++iLoc )
                thisCol[iLoc] += alpha*secondBufCol[iLoc];
        }
        this->auxMemory_.Release();
    }
}
2399 
// SumScatterUpdate: *this += alpha * S, where S is A summed over DistComm
// (mpi::ReduceScatter) and scattered so that each process keeps only its own
// local block of *this.  Every process holds all of A ([UGath,VGath]); the
// pack loops index A with the global shifts/strides of *this in both
// dimensions.
template<typename T,Dist U,Dist V>
void
GeneralDistMatrix<T,U,V>::SumScatterUpdate
( T alpha, const DistMatrix<T,UGath,VGath>& A )
{
    DEBUG_ONLY(
        CallStackEntry cse("GDM::SumScatterUpdate");
        this->AssertNotLocked();
        this->AssertSameGrid( A.Grid() );
        this->AssertSameSize( A.Height(), A.Width() );
    )
    if( !this->Participating() )
        return;

    const Int colStride = this->ColStride();
    const Int rowStride = this->RowStride();
    const Int colAlign = this->ColAlign();
    const Int rowAlign = this->RowAlign();

    const Int height = this->Height();
    const Int width = this->Width();
    const Int localHeight = this->LocalHeight();
    const Int localWidth = this->LocalWidth();
    const Int maxLocalHeight = MaxLength(height,colStride);
    const Int maxLocalWidth = MaxLength(width,rowStride);

    // One padded slot per process, each large enough for any local block.
    const Int recvSize = mpi::Pad( maxLocalHeight*maxLocalWidth );
    const Int sendSize = colStride*rowStride*recvSize;

    // Pack: slot (k + l*colStride) receives the block owned by the process
    // with column rank k and row rank l, stored contiguously column by
    // column.  NOTE(review): this slot order assumes DistComm ranks are
    // numbered colRank + colStride*rowRank — confirm against DistComm.
    const T* ABuffer = A.LockedBuffer();
    const Int ALDim = A.LDim();
    T* buffer = this->auxMemory_.Require( sendSize );
    ELEM_OUTER_PARALLEL_FOR
    for( Int l=0; l<rowStride; ++l )
    {
        const Int thisRowShift = Shift_( l, rowAlign, rowStride );
        const Int thisLocalWidth = Length_( width, thisRowShift, rowStride );
        for( Int k=0; k<colStride; ++k )
        {
            T* data = &buffer[(k+l*colStride)*recvSize];
            const Int thisColShift = Shift_( k, colAlign, colStride );
            const Int thisLocalHeight = Length_(height,thisColShift,colStride);
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
            {
                T* destCol = &data[jLoc*thisLocalHeight];
                const T* sourceCol =
                    &ABuffer[thisColShift+(thisRowShift+jLoc*rowStride)*ALDim];
                for( Int iLoc=0; iLoc<thisLocalHeight; ++iLoc )
                    destCol[iLoc] = sourceCol[iLoc*colStride];
            }
        }
    }

    // Communicate: in-place reduce-scatter over all processes in DistComm;
    // the unpack below reads our summed block from the start of `buffer`.
    mpi::ReduceScatter( buffer, recvSize, this->DistComm() );

    // Unpack our received data: thisCol += alpha * bufferCol
    T* thisBuffer = this->Buffer();
    const Int thisLDim = this->LDim();
    ELEM_FMA_PARALLEL_FOR
    for( Int jLoc=0; jLoc<localWidth; ++jLoc )
    {
        const T* bufferCol = &buffer[jLoc*localHeight];
        T* thisCol = &thisBuffer[jLoc*thisLDim];
        for( Int iLoc=0; iLoc<localHeight; ++iLoc )
            thisCol[iLoc] += alpha*bufferCol[iLoc];
    }
    this->auxMemory_.Release();
}
2471 
// PartialRowSumScatterUpdate: *this += alpha * S, where A is distributed over
// the coarser partial row distribution VPart and S is A summed over
// PartialUnionRowComm and scattered so each process keeps its own local
// columns of *this.  Only the aligned case (this->RowAlign() modulo
// A.RowStride() equals A.RowAlign()) is implemented; otherwise LogicError.
template<typename T,Dist U,Dist V>
void
GeneralDistMatrix<T,U,V>::PartialRowSumScatterUpdate
( T alpha, const DistMatrix<T,U,VPart>& A )
{
    DEBUG_ONLY(
        CallStackEntry cse("GDM::PartialRowSumScatterUpdate");
        this->AssertNotLocked();
        this->AssertSameGrid( A.Grid() );
        this->AssertSameSize( A.Height(), A.Width() );
    )
    if( !this->Participating() )
        return;

    if( this->RowAlign() % A.RowStride() == A.RowAlign() )
    {
        const Int rowStride = this->RowStride();
        const Int rowStridePart = this->PartialRowStride();
        const Int rowStrideUnion = this->PartialUnionRowStride();
        const Int rowRankPart = this->PartialRowRank();
        const Int rowAlign = this->RowAlign();
        const Int rowShiftOfA = A.RowShift();

        const Int height = this->Height();
        const Int width = this->Width();
        const Int localWidth = this->LocalWidth();
        const Int maxLocalWidth = MaxLength( width, rowStride );
        // One padded slot per union-team member: full height times the
        // largest possible local width.
        const Int recvSize = mpi::Pad( height*maxLocalWidth );
        const Int sendSize = rowStrideUnion*recvSize;

        // Pack: slot k receives the columns owned in *this by row rank
        // rowRankPart + k*rowStridePart.  In A's local storage those columns
        // start at local index thisRowOffset and recur every rowStrideUnion
        // local columns.  Each column is copied whole (height entries), which
        // assumes A's local columns span every row — presumably U is
        // undistributed (e.g. STAR) for the instantiations reaching here;
        // verify.
        const T* ABuf = A.LockedBuffer();
        const Int ALDim = A.LDim();
        T* buffer = this->auxMemory_.Require( sendSize );
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<rowStrideUnion; ++k )
        {
            T* data = &buffer[k*recvSize];
            const Int thisRank = rowRankPart+k*rowStridePart;
            const Int thisRowShift = Shift_( thisRank, rowAlign, rowStride );
            const Int thisRowOffset =
                (thisRowShift-rowShiftOfA) / rowStridePart;
            const Int thisLocalWidth =
                Length_( width, thisRowShift, rowStride );
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
            {
                const T* ACol =
                    &ABuf[(thisRowOffset+jLoc*rowStrideUnion)*ALDim];
                T* dataCol = &data[jLoc*height];
                MemCopy( dataCol, ACol, height );
            }
        }

        // Communicate: in-place reduce-scatter over the partial-union row
        // team; the unpack below reads our summed block from the start of
        // `buffer`.
        mpi::ReduceScatter( buffer, recvSize, this->PartialUnionRowComm() );

        // Unpack our received data: thisCol += alpha * bufferCol.
        // NOTE(review): indexes `height` rows of the local buffer, i.e. it
        // relies on the local height of *this equaling the global height.
        T* thisBuf = this->Buffer();
        const Int thisLDim = this->LDim();
        ELEM_PARALLEL_FOR
        for( Int jLoc=0; jLoc<localWidth; ++jLoc )
        {
            const T* bufferCol = &buffer[jLoc*height];
            T* thisCol = &thisBuf[jLoc*thisLDim];
            for( Int i=0; i<height; ++i )
                thisCol[i] += alpha*bufferCol[i];
        }
        this->auxMemory_.Release();
    }
    else
    {
        LogicError("Unaligned PartialRowSumScatterUpdate not implemented");
    }
}
2547 
// PartialColSumScatterUpdate: *this += alpha * S, where A is distributed over
// the coarser partial column distribution UPart and S is A summed over
// PartialUnionColComm and scattered so each process keeps its own local rows
// of *this.  Only the aligned case (this->ColAlign() modulo A.ColStride()
// equals A.ColAlign()) is implemented; otherwise LogicError.
template<typename T,Dist U,Dist V>
void
GeneralDistMatrix<T,U,V>::PartialColSumScatterUpdate
( T alpha, const DistMatrix<T,UPart,V>& A )
{
    DEBUG_ONLY(
        CallStackEntry cse("GDM::PartialColSumScatterUpdate");
        this->AssertNotLocked();
        this->AssertSameGrid( A.Grid() );
        this->AssertSameSize( A.Height(), A.Width() );
    )
    if( !this->Participating() )
        return;

    // Unlike ColSumScatterUpdate, this warning is only reachable by
    // participating processes.
#ifdef ELEM_CACHE_WARNINGS
    if( A.Width() != 1 && A.Grid().Rank() == 0 )
    {
        std::cerr <<
          "PartialColSumScatterUpdate potentially causes a large amount"
          " of cache-thrashing. If possible, avoid it by forming the "
          "(conjugate-)transpose of the [UGath,* ] matrix instead."
          << std::endl;
    }
#endif
    if( this->ColAlign() % A.ColStride() == A.ColAlign() )
    {
        const Int colStride = this->ColStride();
        const Int colStridePart = this->PartialColStride();
        const Int colStrideUnion = this->PartialUnionColStride();
        const Int colRankPart = this->PartialColRank();
        const Int colAlign = this->ColAlign();
        const Int colShiftOfA = A.ColShift();

        const Int height = this->Height();
        const Int width = this->Width();
        const Int localHeight = this->LocalHeight();
        const Int maxLocalHeight = MaxLength( height, colStride );
        // One padded slot per union-team member: the largest possible local
        // height times the full width.
        const Int recvSize = mpi::Pad( maxLocalHeight*width );
        const Int sendSize = colStrideUnion*recvSize;

        T* buffer = this->auxMemory_.Require( sendSize );

        // Pack: slot k receives the rows owned in *this by column rank
        // colRankPart + k*colStridePart.  In A's local storage those rows
        // start at local row thisColOffset and recur every colStrideUnion
        // local rows.  The loop runs over the global `width`, which assumes
        // A's local width equals the global width — presumably V is
        // undistributed (e.g. STAR) for the instantiations reaching here;
        // verify.
        const Int ALDim = A.LDim();
        const T* ABuf = A.LockedBuffer();
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<colStrideUnion; ++k )
        {
            T* data = &buffer[k*recvSize];
            const Int thisRank = colRankPart+k*colStridePart;
            const Int thisColShift = Shift_( thisRank, colAlign, colStride );
            const Int thisColOffset =
                (thisColShift-colShiftOfA) / colStridePart;
            const Int thisLocalHeight =
                Length_( height, thisColShift, colStride );
            ELEM_INNER_PARALLEL_FOR
            for( Int j=0; j<width; ++j )
            {
                T* destCol = &data[j*thisLocalHeight];
                const T* sourceCol = &ABuf[thisColOffset+j*ALDim];
                for( Int iLoc=0; iLoc<thisLocalHeight; ++iLoc )
                    destCol[iLoc] = sourceCol[iLoc*colStrideUnion];
            }
        }

        // Communicate: in-place reduce-scatter over the partial-union column
        // team; the unpack below reads our summed block from the start of
        // `buffer`.
        mpi::ReduceScatter( buffer, recvSize, this->PartialUnionColComm() );

        // Unpack our received data: thisCol += alpha * bufferCol.
        // NOTE(review): iterates `width` columns of the local buffer, i.e. it
        // relies on the local width of *this equaling the global width.
        T* thisBuf = this->Buffer();
        const Int thisLDim = this->LDim();
        ELEM_PARALLEL_FOR
        for( Int j=0; j<width; ++j )
        {
            const T* bufferCol = &buffer[j*localHeight];
            T* thisCol = &thisBuf[j*thisLDim];
            for( Int iLoc=0; iLoc<localHeight; ++iLoc )
                thisCol[iLoc] += alpha*bufferCol[iLoc];
        }
        this->auxMemory_.Release();
    }
    else
    {
        LogicError("Unaligned PartialColSumScatterUpdate not implemented");
    }
}
2634 
2635 template<typename T,Dist U,Dist V>
2636 void
TransposeColAllGather(DistMatrix<T,V,UGath> & A,bool conjugate) const2637 GeneralDistMatrix<T,U,V>::TransposeColAllGather
2638 ( DistMatrix<T,V,UGath>& A, bool conjugate ) const
2639 {
2640     DEBUG_ONLY(CallStackEntry cse("GDM::TransposeColAllGather"))
2641     DistMatrix<T,V,U> ATrans( this->Grid() );
2642     ATrans.AlignWith( *this );
2643     ATrans.Resize( this->Width(), this->Height() );
2644     Transpose( this->LockedMatrix(), ATrans.Matrix(), conjugate );
2645     ATrans.RowAllGather( A );
2646 }
2647 
2648 template<typename T,Dist U,Dist V>
2649 void
TransposePartialColAllGather(DistMatrix<T,V,UPart> & A,bool conjugate) const2650 GeneralDistMatrix<T,U,V>::TransposePartialColAllGather
2651 ( DistMatrix<T,V,UPart>& A, bool conjugate ) const
2652 {
2653     DEBUG_ONLY(CallStackEntry cse("GDM::TransposePartialColAllGather"))
2654     DistMatrix<T,V,U> ATrans( this->Grid() );
2655     ATrans.AlignWith( *this );
2656     ATrans.Resize( this->Width(), this->Height() );
2657     Transpose( this->LockedMatrix(), ATrans.Matrix(), conjugate );
2658     ATrans.PartialRowAllGather( A );
2659 }
2660 
2661 template<typename T,Dist U,Dist V>
2662 void
AdjointColAllGather(DistMatrix<T,V,UGath> & A) const2663 GeneralDistMatrix<T,U,V>::AdjointColAllGather( DistMatrix<T,V,UGath>& A ) const
2664 {
2665     DEBUG_ONLY(CallStackEntry cse("GDM::AdjointRowAllGather"))
2666     this->TransposeColAllGather( A, true );
2667 }
2668 
2669 template<typename T,Dist U,Dist V>
2670 void
AdjointPartialColAllGather(DistMatrix<T,V,UPart> & A) const2671 GeneralDistMatrix<T,U,V>::AdjointPartialColAllGather
2672 ( DistMatrix<T,V,UPart>& A ) const
2673 {
2674     DEBUG_ONLY(CallStackEntry cse("GDM::AdjointPartialColAllGather"))
2675     this->TransposePartialColAllGather( A, true );
2676 }
2677 
2678 template<typename T,Dist U,Dist V>
2679 void
TransposeColFilterFrom(const DistMatrix<T,V,UGath> & A,bool conjugate)2680 GeneralDistMatrix<T,U,V>::TransposeColFilterFrom
2681 ( const DistMatrix<T,V,UGath>& A, bool conjugate )
2682 {
2683     DEBUG_ONLY(CallStackEntry cse("GDM::TransposeColFilterFrom"))
2684     DistMatrix<T,V,U> AFilt( A.Grid() );
2685     if( this->ColConstrained() )
2686         AFilt.AlignRowsWith( *this, false );
2687     if( this->RowConstrained() )
2688         AFilt.AlignColsWith( *this, false );
2689     AFilt.RowFilterFrom( A );
2690     if( !this->ColConstrained() )
2691         this->AlignColsWith( AFilt, false );
2692     if( !this->RowConstrained() )
2693         this->AlignRowsWith( AFilt, false );
2694     this->Resize( A.Width(), A.Height() );
2695     Transpose( AFilt.LockedMatrix(), this->Matrix(), conjugate );
2696 }
2697 
2698 template<typename T,Dist U,Dist V>
2699 void
TransposeRowFilterFrom(const DistMatrix<T,VGath,U> & A,bool conjugate)2700 GeneralDistMatrix<T,U,V>::TransposeRowFilterFrom
2701 ( const DistMatrix<T,VGath,U>& A, bool conjugate )
2702 {
2703     DEBUG_ONLY(CallStackEntry cse("GDM::TransposeRowFilterFrom"))
2704     DistMatrix<T,V,U> AFilt( A.Grid() );
2705     if( this->ColConstrained() )
2706         AFilt.AlignRowsWith( *this, false );
2707     if( this->RowConstrained() )
2708         AFilt.AlignColsWith( *this, false );
2709     AFilt.ColFilterFrom( A );
2710     if( !this->ColConstrained() )
2711         this->AlignColsWith( AFilt, false );
2712     if( !this->RowConstrained() )
2713         this->AlignRowsWith( AFilt, false );
2714     this->Resize( A.Width(), A.Height() );
2715     Transpose( AFilt.LockedMatrix(), this->Matrix(), conjugate );
2716 }
2717 
2718 template<typename T,Dist U,Dist V>
2719 void
TransposePartialColFilterFrom(const DistMatrix<T,V,UPart> & A,bool conjugate)2720 GeneralDistMatrix<T,U,V>::TransposePartialColFilterFrom
2721 ( const DistMatrix<T,V,UPart>& A, bool conjugate )
2722 {
2723     DEBUG_ONLY(CallStackEntry cse("GDM::TransposePartialColFilterFrom"))
2724     DistMatrix<T,V,U> AFilt( A.Grid() );
2725     if( this->ColConstrained() )
2726         AFilt.AlignRowsWith( *this, false );
2727     if( this->RowConstrained() )
2728         AFilt.AlignColsWith( *this, false );
2729     AFilt.PartialRowFilterFrom( A );
2730     if( !this->ColConstrained() )
2731         this->AlignColsWith( AFilt, false );
2732     if( !this->RowConstrained() )
2733         this->AlignRowsWith( AFilt, false );
2734     this->Resize( A.Width(), A.Height() );
2735     Transpose( AFilt.LockedMatrix(), this->Matrix(), conjugate );
2736 }
2737 
2738 template<typename T,Dist U,Dist V>
2739 void
TransposePartialRowFilterFrom(const DistMatrix<T,VPart,U> & A,bool conjugate)2740 GeneralDistMatrix<T,U,V>::TransposePartialRowFilterFrom
2741 ( const DistMatrix<T,VPart,U>& A, bool conjugate )
2742 {
2743     DEBUG_ONLY(CallStackEntry cse("GDM::TransposePartialRowFilterFrom"))
2744     DistMatrix<T,V,U> AFilt( A.Grid() );
2745     if( this->ColConstrained() )
2746         AFilt.AlignRowsWith( *this, false );
2747     if( this->RowConstrained() )
2748         AFilt.AlignColsWith( *this, false );
2749     AFilt.PartialColFilterFrom( A );
2750     if( !this->ColConstrained() )
2751         this->AlignColsWith( AFilt, false );
2752     if( !this->RowConstrained() )
2753         this->AlignRowsWith( AFilt, false );
2754     this->Resize( A.Width(), A.Height() );
2755     Transpose( AFilt.LockedMatrix(), this->Matrix(), conjugate );
2756 }
2757 
2758 template<typename T,Dist U,Dist V>
2759 void
AdjointColFilterFrom(const DistMatrix<T,V,UGath> & A)2760 GeneralDistMatrix<T,U,V>::AdjointColFilterFrom( const DistMatrix<T,V,UGath>& A )
2761 {
2762     DEBUG_ONLY(CallStackEntry cse("GDM::AdjointColFilterFrom"))
2763     this->TransposeColFilterFrom( A, true );
2764 }
2765 
2766 template<typename T,Dist U,Dist V>
2767 void
AdjointRowFilterFrom(const DistMatrix<T,VGath,U> & A)2768 GeneralDistMatrix<T,U,V>::AdjointRowFilterFrom( const DistMatrix<T,VGath,U>& A )
2769 {
2770     DEBUG_ONLY(CallStackEntry cse("GDM::AdjointRowFilterFrom"))
2771     this->TransposeRowFilterFrom( A, true );
2772 }
2773 
2774 template<typename T,Dist U,Dist V>
2775 void
AdjointPartialColFilterFrom(const DistMatrix<T,V,UPart> & A)2776 GeneralDistMatrix<T,U,V>::AdjointPartialColFilterFrom
2777 ( const DistMatrix<T,V,UPart>& A )
2778 {
2779     DEBUG_ONLY(CallStackEntry cse("GDM::AdjointPartialColFilterFrom"))
2780     this->TransposePartialColFilterFrom( A, true );
2781 }
2782 
2783 template<typename T,Dist U,Dist V>
2784 void
AdjointPartialRowFilterFrom(const DistMatrix<T,VPart,U> & A)2785 GeneralDistMatrix<T,U,V>::AdjointPartialRowFilterFrom
2786 ( const DistMatrix<T,VPart,U>& A )
2787 {
2788     DEBUG_ONLY(CallStackEntry cse("GDM::AdjointPartialRowFilterFrom"))
2789     this->TransposePartialRowFilterFrom( A, true );
2790 }
2791 
2792 template<typename T,Dist U,Dist V>
2793 void
TransposeColSumScatterFrom(const DistMatrix<T,V,UGath> & A,bool conjugate)2794 GeneralDistMatrix<T,U,V>::TransposeColSumScatterFrom
2795 ( const DistMatrix<T,V,UGath>& A, bool conjugate )
2796 {
2797     DEBUG_ONLY(CallStackEntry cse("GDM::TransposeColSumScatterFrom"))
2798     DistMatrix<T,V,U> ASumFilt( A.Grid() );
2799     if( this->ColConstrained() )
2800         ASumFilt.AlignRowsWith( *this, false );
2801     if( this->RowConstrained() )
2802         ASumFilt.AlignColsWith( *this, false );
2803     ASumFilt.RowSumScatterFrom( A );
2804     if( !this->ColConstrained() )
2805         this->AlignColsWith( ASumFilt, false );
2806     if( !this->RowConstrained() )
2807         this->AlignRowsWith( ASumFilt, false );
2808     this->Resize( A.Width(), A.Height() );
2809     Transpose( ASumFilt.LockedMatrix(), this->Matrix(), conjugate );
2810 }
2811 
2812 template<typename T,Dist U,Dist V>
2813 void
TransposePartialColSumScatterFrom(const DistMatrix<T,V,UPart> & A,bool conjugate)2814 GeneralDistMatrix<T,U,V>::TransposePartialColSumScatterFrom
2815 ( const DistMatrix<T,V,UPart>& A, bool conjugate )
2816 {
2817     DEBUG_ONLY(CallStackEntry cse("GDM::TransposePartialColSumScatterFrom"))
2818     DistMatrix<T,V,U> ASumFilt( A.Grid() );
2819     if( this->ColConstrained() )
2820         ASumFilt.AlignRowsWith( *this, false );
2821     if( this->RowConstrained() )
2822         ASumFilt.AlignColsWith( *this, false );
2823     ASumFilt.PartialRowSumScatterFrom( A );
2824     if( !this->ColConstrained() )
2825         this->AlignColsWith( ASumFilt, false );
2826     if( !this->RowConstrained() )
2827         this->AlignRowsWith( ASumFilt, false );
2828     this->Resize( A.Width(), A.Height() );
2829     Transpose( ASumFilt.LockedMatrix(), this->Matrix(), conjugate );
2830 }
2831 
2832 template<typename T,Dist U,Dist V>
2833 void
AdjointColSumScatterFrom(const DistMatrix<T,V,UGath> & A)2834 GeneralDistMatrix<T,U,V>::AdjointColSumScatterFrom
2835 ( const DistMatrix<T,V,UGath>& A )
2836 {
2837     DEBUG_ONLY(CallStackEntry cse("GDM::AdjointColSumScatterFrom"))
2838     this->TransposeColSumScatterFrom( A, true );
2839 }
2840 
2841 template<typename T,Dist U,Dist V>
2842 void
AdjointPartialColSumScatterFrom(const DistMatrix<T,V,UPart> & A)2843 GeneralDistMatrix<T,U,V>::AdjointPartialColSumScatterFrom
2844 ( const DistMatrix<T,V,UPart>& A )
2845 {
2846     DEBUG_ONLY(CallStackEntry cse("GDM::AdjointPartialColSumScatterFrom"))
2847     this->TransposePartialColSumScatterFrom( A, true );
2848 }
2849 
2850 template<typename T,Dist U,Dist V>
2851 void
TransposeColSumScatterUpdate(T alpha,const DistMatrix<T,V,UGath> & A,bool conjugate)2852 GeneralDistMatrix<T,U,V>::TransposeColSumScatterUpdate
2853 ( T alpha, const DistMatrix<T,V,UGath>& A, bool conjugate )
2854 {
2855     DEBUG_ONLY(CallStackEntry cse("GDM::TransposeColSumScatterUpdate"))
2856     DistMatrix<T,V,U> ASumFilt( A.Grid() );
2857     if( this->ColConstrained() )
2858         ASumFilt.AlignRowsWith( *this, false );
2859     if( this->RowConstrained() )
2860         ASumFilt.AlignColsWith( *this, false );
2861     ASumFilt.RowSumScatterFrom( A );
2862     if( !this->ColConstrained() )
2863         this->AlignColsWith( ASumFilt, false );
2864     if( !this->RowConstrained() )
2865         this->AlignRowsWith( ASumFilt, false );
2866     // ALoc += alpha ASumFiltLoc'
2867     elem::Matrix<T>& ALoc = this->Matrix();
2868     const elem::Matrix<T>& BLoc = ASumFilt.LockedMatrix();
2869     const Int localHeight = ALoc.Height();
2870     const Int localWidth = ALoc.Width();
2871     if( conjugate )
2872     {
2873         for( Int jLoc=0; jLoc<localWidth; ++jLoc )
2874             for( Int iLoc=0; iLoc<localHeight; ++iLoc )
2875                 ALoc.Update( iLoc, jLoc, alpha*Conj(BLoc.Get(jLoc,iLoc)) );
2876     }
2877     else
2878     {
2879         for( Int jLoc=0; jLoc<localWidth; ++jLoc )
2880             for( Int iLoc=0; iLoc<localHeight; ++iLoc )
2881                 ALoc.Update( iLoc, jLoc, alpha*BLoc.Get(jLoc,iLoc) );
2882     }
2883 }
2884 
2885 template<typename T,Dist U,Dist V>
2886 void
TransposePartialColSumScatterUpdate(T alpha,const DistMatrix<T,V,UPart> & A,bool conjugate)2887 GeneralDistMatrix<T,U,V>::TransposePartialColSumScatterUpdate
2888 ( T alpha, const DistMatrix<T,V,UPart>& A, bool conjugate )
2889 {
2890     DEBUG_ONLY(CallStackEntry cse("GDM::TransposePartialColSumScatterUpdate"))
2891     DistMatrix<T,V,U> ASumFilt( A.Grid() );
2892     if( this->ColConstrained() )
2893         ASumFilt.AlignRowsWith( *this, false );
2894     if( this->RowConstrained() )
2895         ASumFilt.AlignColsWith( *this, false );
2896     ASumFilt.PartialRowSumScatterFrom( A );
2897     if( !this->ColConstrained() )
2898         this->AlignColsWith( ASumFilt, false );
2899     if( !this->RowConstrained() )
2900         this->AlignRowsWith( ASumFilt, false );
2901     // ALoc += alpha ASumFiltLoc'
2902     elem::Matrix<T>& ALoc = this->Matrix();
2903     const elem::Matrix<T>& BLoc = ASumFilt.LockedMatrix();
2904     const Int localHeight = ALoc.Height();
2905     const Int localWidth = ALoc.Width();
2906     if( conjugate )
2907     {
2908         for( Int jLoc=0; jLoc<localWidth; ++jLoc )
2909             for( Int iLoc=0; iLoc<localHeight; ++iLoc )
2910                 ALoc.Update( iLoc, jLoc, alpha*Conj(BLoc.Get(jLoc,iLoc)) );
2911     }
2912     else
2913     {
2914         for( Int jLoc=0; jLoc<localWidth; ++jLoc )
2915             for( Int iLoc=0; iLoc<localHeight; ++iLoc )
2916                 ALoc.Update( iLoc, jLoc, alpha*BLoc.Get(jLoc,iLoc) );
2917     }
2918 }
2919 
2920 template<typename T,Dist U,Dist V>
2921 void
AdjointColSumScatterUpdate(T alpha,const DistMatrix<T,V,UGath> & A)2922 GeneralDistMatrix<T,U,V>::AdjointColSumScatterUpdate
2923 ( T alpha, const DistMatrix<T,V,UGath>& A )
2924 {
2925     DEBUG_ONLY(CallStackEntry cse("GDM::AdjointColSumScatterUpdate"))
2926     this->TransposeColSumScatterUpdate( alpha, A, true );
2927 }
2928 
2929 template<typename T,Dist U,Dist V>
2930 void
AdjointPartialColSumScatterUpdate(T alpha,const DistMatrix<T,V,UPart> & A)2931 GeneralDistMatrix<T,U,V>::AdjointPartialColSumScatterUpdate
2932 ( T alpha, const DistMatrix<T,V,UPart>& A )
2933 {
2934     DEBUG_ONLY(CallStackEntry cse("GDM::AdjointPartialColSumScatterUpdate"))
2935     this->TransposePartialColSumScatterUpdate( alpha, A, true );
2936 }
2937 
2938 // Diagonal manipulation
2939 // =====================
2940 template<typename T,Dist U,Dist V>
2941 bool
DiagonalAlignedWith(const elem::DistData & d,Int offset) const2942 GeneralDistMatrix<T,U,V>::DiagonalAlignedWith
2943 ( const elem::DistData& d, Int offset ) const
2944 {
2945     DEBUG_ONLY(CallStackEntry cse("GDM::DiagonalAlignedWith"))
2946     if( this->Grid() != *d.grid )
2947         return false;
2948 
2949     const Int diagRoot = this->DiagonalRoot(offset);
2950     if( diagRoot != d.root )
2951         return false;
2952 
2953     const Int diagAlign = this->DiagonalAlign(offset);
2954     if( d.colDist == UDiag && d.rowDist == VDiag )
2955         return d.colAlign == diagAlign;
2956     else if( d.colDist == VDiag && d.rowDist == UDiag )
2957         return d.rowAlign == diagAlign;
2958     else
2959         return false;
2960 }
2961 
2962 template<typename T,Dist U,Dist V>
2963 Int
DiagonalRoot(Int offset) const2964 GeneralDistMatrix<T,U,V>::DiagonalRoot( Int offset ) const
2965 {
2966     DEBUG_ONLY(CallStackEntry cse("GDM::DiagonalRoot"))
2967     const elem::Grid& grid = this->Grid();
2968 
2969     if( U == MC && V == MR )
2970     {
2971         // Result is an [MD,* ] or [* ,MD]
2972         Int owner;
2973         if( offset >= 0 )
2974         {
2975             const Int procRow = this->ColAlign();
2976             const Int procCol = (this->RowAlign()+offset) % this->RowStride();
2977             owner = procRow + this->ColStride()*procCol;
2978         }
2979         else
2980         {
2981             const Int procRow = (this->ColAlign()-offset) % this->ColStride();
2982             const Int procCol = this->RowAlign();
2983             owner = procRow + this->ColStride()*procCol;
2984         }
2985         return grid.DiagPath(owner);
2986     }
2987     else if( U == MR && V == MC )
2988     {
2989         // Result is an [MD,* ] or [* ,MD]
2990         Int owner;
2991         if( offset >= 0 )
2992         {
2993             const Int procCol = this->ColAlign();
2994             const Int procRow = (this->RowAlign()+offset) % this->RowStride();
2995             owner = procRow + this->ColStride()*procCol;
2996         }
2997         else
2998         {
2999             const Int procCol = (this->ColAlign()-offset) % this->ColStride();
3000             const Int procRow = this->RowAlign();
3001             owner = procRow + this->ColStride()*procCol;
3002         }
3003         return grid.DiagPath(owner);
3004     }
3005     else
3006         return this->Root();
3007 }
3008 
3009 template<typename T,Dist U,Dist V>
3010 Int
DiagonalAlign(Int offset) const3011 GeneralDistMatrix<T,U,V>::DiagonalAlign( Int offset ) const
3012 {
3013     DEBUG_ONLY(CallStackEntry cse("GDM::DiagonalAlign"))
3014     const elem::Grid& grid = this->Grid();
3015 
3016     if( U == MC && V == MR )
3017     {
3018         // Result is an [MD,* ] or [* ,MD]
3019         Int owner;
3020         if( offset >= 0 )
3021         {
3022             const Int procRow = this->ColAlign();
3023             const Int procCol = (this->RowAlign()+offset) % this->RowStride();
3024             owner = procRow + this->ColStride()*procCol;
3025         }
3026         else
3027         {
3028             const Int procRow = (this->ColAlign()-offset) % this->ColStride();
3029             const Int procCol = this->RowAlign();
3030             owner = procRow + this->ColStride()*procCol;
3031         }
3032         return grid.DiagPathRank(owner);
3033     }
3034     else if( U == MR && V == MC )
3035     {
3036         // Result is an [MD,* ] or [* ,MD]
3037         Int owner;
3038         if( offset >= 0 )
3039         {
3040             const Int procCol = this->ColAlign();
3041             const Int procRow = (this->RowAlign()+offset) % this->RowStride();
3042             owner = procRow + this->ColStride()*procCol;
3043         }
3044         else
3045         {
3046             const Int procCol = (this->ColAlign()-offset) % this->ColStride();
3047             const Int procRow = this->RowAlign();
3048             owner = procRow + this->ColStride()*procCol;
3049         }
3050         return grid.DiagPathRank(owner);
3051     }
3052     else if( U == STAR )
3053     {
3054         // Result is a [V,* ] or [* ,V]
3055         if( offset >= 0 )
3056             return (this->RowAlign()+offset) % this->RowStride();
3057         else
3058             return this->RowAlign();
3059     }
3060     else
3061     {
3062         // Result is [U,V] or [V,U], where V is either STAR or CIRC
3063         if( offset >= 0 )
3064             return this->ColAlign();
3065         else
3066             return (this->ColAlign()-offset) % this->ColStride();
3067     }
3068 }
3069 
3070 template<typename T,Dist U,Dist V>
3071 void
GetDiagonal(DistMatrix<T,UDiag,VDiag> & d,Int offset) const3072 GeneralDistMatrix<T,U,V>::GetDiagonal
3073 ( DistMatrix<T,UDiag,VDiag>& d, Int offset ) const
3074 {
3075     DEBUG_ONLY(CallStackEntry cse("GDM::GetDiagonal"))
3076     this->GetDiagonalHelper
3077     ( d, offset, []( T& alpha, T beta ) { alpha = beta; } );
3078 }
3079 
3080 template<typename T,Dist U,Dist V>
3081 void
GetRealPartOfDiagonal(DistMatrix<Base<T>,UDiag,VDiag> & d,Int offset) const3082 GeneralDistMatrix<T,U,V>::GetRealPartOfDiagonal
3083 ( DistMatrix<Base<T>,UDiag,VDiag>& d, Int offset ) const
3084 {
3085     DEBUG_ONLY(CallStackEntry cse("GDM::GetRealPartOfDiagonal"))
3086     this->GetDiagonalHelper
3087     ( d, offset, []( Base<T>& alpha, T beta ) { alpha = RealPart(beta); } );
3088 }
3089 
3090 template<typename T,Dist U,Dist V>
3091 void
GetImagPartOfDiagonal(DistMatrix<Base<T>,UDiag,VDiag> & d,Int offset) const3092 GeneralDistMatrix<T,U,V>::GetImagPartOfDiagonal
3093 ( DistMatrix<Base<T>,UDiag,VDiag>& d, Int offset ) const
3094 {
3095     DEBUG_ONLY(CallStackEntry cse("GDM::GetImagPartOfDiagonal"))
3096     this->GetDiagonalHelper
3097     ( d, offset, []( Base<T>& alpha, T beta ) { alpha = ImagPart(beta); } );
3098 }
3099 
3100 template<typename T,Dist U,Dist V>
3101 auto
GetDiagonal(Int offset) const3102 GeneralDistMatrix<T,U,V>::GetDiagonal( Int offset ) const
3103 -> DistMatrix<T,UDiag,VDiag>
3104 {
3105     DistMatrix<T,UDiag,VDiag> d( this->Grid() );
3106     GetDiagonal( d, offset );
3107     return d;
3108 }
3109 
3110 template<typename T,Dist U,Dist V>
3111 auto
GetRealPartOfDiagonal(Int offset) const3112 GeneralDistMatrix<T,U,V>::GetRealPartOfDiagonal( Int offset ) const
3113 -> DistMatrix<Base<T>,UDiag,VDiag>
3114 {
3115     DistMatrix<Base<T>,UDiag,VDiag> d( this->Grid() );
3116     GetRealPartOfDiagonal( d, offset );
3117     return d;
3118 }
3119 
3120 template<typename T,Dist U,Dist V>
3121 auto
GetImagPartOfDiagonal(Int offset) const3122 GeneralDistMatrix<T,U,V>::GetImagPartOfDiagonal( Int offset ) const
3123 -> DistMatrix<Base<T>,UDiag,VDiag>
3124 {
3125     DistMatrix<Base<T>,UDiag,VDiag> d( this->Grid() );
3126     GetImagPartOfDiagonal( d, offset );
3127     return d;
3128 }
3129 
3130 template<typename T,Dist U,Dist V>
3131 void
SetDiagonal(const DistMatrix<T,UDiag,VDiag> & d,Int offset)3132 GeneralDistMatrix<T,U,V>::SetDiagonal
3133 ( const DistMatrix<T,UDiag,VDiag>& d, Int offset )
3134 {
3135     DEBUG_ONLY(CallStackEntry cse("GDM::SetDiagonal"))
3136     this->SetDiagonalHelper
3137     ( d, offset, []( T& alpha, T beta ) { alpha = beta; } );
3138 }
3139 
3140 template<typename T,Dist U,Dist V>
3141 void
SetRealPartOfDiagonal(const DistMatrix<Base<T>,UDiag,VDiag> & d,Int offset)3142 GeneralDistMatrix<T,U,V>::SetRealPartOfDiagonal
3143 ( const DistMatrix<Base<T>,UDiag,VDiag>& d, Int offset )
3144 {
3145     DEBUG_ONLY(CallStackEntry cse("GDM::SetRealPartOfDiagonal"))
3146     this->SetDiagonalHelper
3147     ( d, offset,
3148       []( T& alpha, Base<T> beta ) { elem::SetRealPart(alpha,beta); } );
3149 }
3150 
3151 template<typename T,Dist U,Dist V>
3152 void
SetImagPartOfDiagonal(const DistMatrix<Base<T>,UDiag,VDiag> & d,Int offset)3153 GeneralDistMatrix<T,U,V>::SetImagPartOfDiagonal
3154 ( const DistMatrix<Base<T>,UDiag,VDiag>& d, Int offset )
3155 {
3156     DEBUG_ONLY(CallStackEntry cse("GDM::SetImagPartOfDiagonal"))
3157     this->SetDiagonalHelper
3158     ( d, offset,
3159       []( T& alpha, Base<T> beta ) { elem::SetImagPart(alpha,beta); } );
3160 }
3161 
3162 template<typename T,Dist U,Dist V>
3163 void
UpdateDiagonal(T gamma,const DistMatrix<T,UDiag,VDiag> & d,Int offset)3164 GeneralDistMatrix<T,U,V>::UpdateDiagonal
3165 ( T gamma, const DistMatrix<T,UDiag,VDiag>& d, Int offset )
3166 {
3167     DEBUG_ONLY(CallStackEntry cse("GDM::UpdateDiagonal"))
3168     this->SetDiagonalHelper
3169     ( d, offset, [gamma]( T& alpha, T beta ) { alpha += gamma*beta; } );
3170 }
3171 
3172 template<typename T,Dist U,Dist V>
3173 void
UpdateRealPartOfDiagonal(Base<T> gamma,const DistMatrix<Base<T>,UDiag,VDiag> & d,Int offset)3174 GeneralDistMatrix<T,U,V>::UpdateRealPartOfDiagonal
3175 ( Base<T> gamma, const DistMatrix<Base<T>,UDiag,VDiag>& d, Int offset )
3176 {
3177     DEBUG_ONLY(CallStackEntry cse("GDM::UpdateRealPartOfDiagonal"))
3178     this->SetDiagonalHelper
3179     ( d, offset,
3180       [gamma]( T& alpha, Base<T> beta )
3181       { elem::UpdateRealPart(alpha,gamma*beta); } );
3182 }
3183 
3184 template<typename T,Dist U,Dist V>
3185 void
UpdateImagPartOfDiagonal(Base<T> gamma,const DistMatrix<Base<T>,UDiag,VDiag> & d,Int offset)3186 GeneralDistMatrix<T,U,V>::UpdateImagPartOfDiagonal
3187 ( Base<T> gamma, const DistMatrix<Base<T>,UDiag,VDiag>& d, Int offset )
3188 {
3189     DEBUG_ONLY(CallStackEntry cse("GDM::UpdateImagPartOfDiagonal"))
3190     this->SetDiagonalHelper
3191     ( d, offset,
3192       [gamma]( T& alpha, Base<T> beta )
3193       { elem::UpdateImagPart(alpha,gamma*beta); } );
3194 }
3195 
3196 // Private section
3197 // ###############
3198 
3199 // Diagonal helper functions
3200 // =========================
3201 template<typename T,Dist U,Dist V>
3202 template<typename S,class Function>
3203 void
GetDiagonalHelper(DistMatrix<S,UDiag,VDiag> & d,Int offset,Function func) const3204 GeneralDistMatrix<T,U,V>::GetDiagonalHelper
3205 ( DistMatrix<S,UDiag,VDiag>& d, Int offset, Function func ) const
3206 {
3207     DEBUG_ONLY(CallStackEntry cse("GDM::GetDiagonalHelper"))
3208     d.SetGrid( this->Grid() );
3209     d.SetRoot( this->DiagonalRoot(offset) );
3210     d.AlignCols( this->DiagonalAlign(offset), false );
3211     d.Resize( this->DiagonalLength(offset), 1 );
3212     if( !d.Participating() )
3213         return;
3214 
3215     const Int diagShift = d.ColShift();
3216     const Int diagStride = d.ColStride();
3217     const Int iStart = ( offset>=0 ? diagShift        : diagShift-offset );
3218     const Int jStart = ( offset>=0 ? diagShift+offset : diagShift        );
3219 
3220     const Int colStride = this->ColStride();
3221     const Int rowStride = this->RowStride();
3222     const Int iLocStart = (iStart-this->ColShift()) / colStride;
3223     const Int jLocStart = (jStart-this->RowShift()) / rowStride;
3224 
3225     const Int localDiagLength = d.LocalHeight();
3226     S* dBuf = d.Buffer();
3227     const T* buffer = this->LockedBuffer();
3228     const Int ldim = this->LDim();
3229 
3230     ELEM_PARALLEL_FOR
3231     for( Int k=0; k<localDiagLength; ++k )
3232     {
3233         const Int iLoc = iLocStart + k*(diagStride/colStride);
3234         const Int jLoc = jLocStart + k*(diagStride/rowStride);
3235         func( dBuf[k], buffer[iLoc+jLoc*ldim] );
3236     }
3237 }
3238 
3239 template<typename T,Dist U,Dist V>
3240 template<typename S,class Function>
3241 void
SetDiagonalHelper(const DistMatrix<S,UDiag,VDiag> & d,Int offset,Function func)3242 GeneralDistMatrix<T,U,V>::SetDiagonalHelper
3243 ( const DistMatrix<S,UDiag,VDiag>& d, Int offset, Function func )
3244 {
3245     DEBUG_ONLY(
3246         CallStackEntry cse("GDM::SetDiagonalHelper");
3247         if( !this->DiagonalAlignedWith( d, offset ) )
3248             LogicError("Invalid diagonal alignment");
3249     )
3250     if( !d.Participating() )
3251         return;
3252 
3253     const Int diagShift = d.ColShift();
3254     const Int diagStride = d.ColStride();
3255     const Int iStart = ( offset>=0 ? diagShift        : diagShift-offset );
3256     const Int jStart = ( offset>=0 ? diagShift+offset : diagShift        );
3257 
3258     const Int colStride = this->ColStride();
3259     const Int rowStride = this->RowStride();
3260     const Int iLocStart = (iStart-this->ColShift()) / colStride;
3261     const Int jLocStart = (jStart-this->RowShift()) / rowStride;
3262 
3263     const Int localDiagLength = d.LocalHeight();
3264     const S* dBuf = d.LockedBuffer();
3265     T* buffer = this->Buffer();
3266     const Int ldim = this->LDim();
3267 
3268     ELEM_PARALLEL_FOR
3269     for( Int k=0; k<localDiagLength; ++k )
3270     {
3271         const Int iLoc = iLocStart + k*(diagStride/colStride);
3272         const Int jLoc = jLocStart + k*(diagStride/rowStride);
3273         func( buffer[iLoc+jLoc*ldim], dBuf[k] );
3274     }
3275 }
3276 
3277 // Instantiations for {Int,Real,Complex<Real>} for each Real in {float,double}
3278 // ###########################################################################
3279 
3280 #define DISTPROTO(T,U,V) template class GeneralDistMatrix<T,U,V>
3281 
3282 #define PROTO(T)\
3283   DISTPROTO(T,CIRC,CIRC);\
3284   DISTPROTO(T,MC,  MR  );\
3285   DISTPROTO(T,MC,  STAR);\
3286   DISTPROTO(T,MD,  STAR);\
3287   DISTPROTO(T,MR,  MC  );\
3288   DISTPROTO(T,MR,  STAR);\
3289   DISTPROTO(T,STAR,MC  );\
3290   DISTPROTO(T,STAR,MD  );\
3291   DISTPROTO(T,STAR,MR  );\
3292   DISTPROTO(T,STAR,STAR);\
3293   DISTPROTO(T,STAR,VC  );\
3294   DISTPROTO(T,STAR,VR  );\
3295   DISTPROTO(T,VC,  STAR);\
3296   DISTPROTO(T,VR,  STAR);
3297 
3298 #ifndef ELEM_DISABLE_COMPLEX
3299  #ifndef ELEM_DISABLE_FLOAT
3300   PROTO(Int);
3301   PROTO(float);
3302   PROTO(double);
3303   PROTO(Complex<float>);
3304   PROTO(Complex<double>);
3305  #else // ifndef ELEM_DISABLE_FLOAT
3306   PROTO(Int);
3307   PROTO(double);
3308   PROTO(Complex<double>);
3309  #endif // ifndef ELEM_DISABLE_FLOAT
3310 #else // ifndef ELEM_DISABLE_COMPLEX
3311  #ifndef ELEM_DISABLE_FLOAT
3312   PROTO(Int);
3313   PROTO(float);
3314   PROTO(double);
3315  #else // ifndef ELEM_DISABLE_FLOAT
3316   PROTO(Int);
3317   PROTO(double);
3318  #endif // ifndef ELEM_DISABLE_FLOAT
3319 #endif // ifndef ELEM_DISABLE_COMPLEX
3320 
3321 } // namespace elem
3322