1 /*
2 Copyright (c) 2009-2014, Jack Poulson
3 All rights reserved.
4
5 This file is part of Elemental and is under the BSD 2-Clause License,
6 which can be found in the LICENSE file in the root directory, or at
7 http://opensource.org/licenses/BSD-2-Clause
8 */
9 #include "elemental-lite.hpp"
10 #include ELEM_TRANSPOSE_INC
11 #include ELEM_ZEROS_INC
12
13 namespace elem {
14
15 // Public section
16 // ##############
17
18 // Constructors and destructors
19 // ============================
20
21 template<typename T,Dist U,Dist V>
GeneralDistMatrix(const elem::Grid & grid,Int root)22 GeneralDistMatrix<T,U,V>::GeneralDistMatrix( const elem::Grid& grid, Int root )
23 : AbstractDistMatrix<T>(grid,root)
24 { }
25
26 template<typename T,Dist U,Dist V>
GeneralDistMatrix(GeneralDistMatrix<T,U,V> && A)27 GeneralDistMatrix<T,U,V>::GeneralDistMatrix( GeneralDistMatrix<T,U,V>&& A )
28 ELEM_NOEXCEPT
29 : AbstractDistMatrix<T>(std::move(A))
30 { }
31
32 // Assignment and reconfiguration
33 // ==============================
34
35 template<typename T,Dist U,Dist V>
36 GeneralDistMatrix<T,U,V>&
operator =(GeneralDistMatrix<T,U,V> && A)37 GeneralDistMatrix<T,U,V>::operator=( GeneralDistMatrix<T,U,V>&& A )
38 {
39 AbstractDistMatrix<T>::operator=( std::move(A) );
40 return *this;
41 }
42
43 template<typename T,Dist U,Dist V>
44 void
AlignColsWith(const elem::DistData & data,bool constrain)45 GeneralDistMatrix<T,U,V>::AlignColsWith
46 ( const elem::DistData& data, bool constrain )
47 {
48 DEBUG_ONLY(CallStackEntry cse("GDM::AlignColsWith"))
49 this->SetGrid( *data.grid );
50 this->SetRoot( data.root );
51 if( data.colDist == U || data.colDist == UPart )
52 this->AlignCols( data.colAlign, constrain );
53 else if( data.rowDist == U || data.rowDist == UPart )
54 this->AlignCols( data.rowAlign, constrain );
55 else if( data.colDist == UScat )
56 this->AlignCols( data.colAlign % this->ColStride(), constrain );
57 else if( data.rowDist == UScat )
58 this->AlignCols( data.rowAlign % this->ColStride(), constrain );
59 DEBUG_ONLY(
60 else if( U != UGath && data.colDist != UGath && data.rowDist != UGath )
61 LogicError("Nonsensical alignment");
62 )
63 }
64
65 template<typename T,Dist U,Dist V>
66 void
AlignRowsWith(const elem::DistData & data,bool constrain)67 GeneralDistMatrix<T,U,V>::AlignRowsWith
68 ( const elem::DistData& data, bool constrain )
69 {
70 DEBUG_ONLY(CallStackEntry cse("GDM::AlignRowsWith"))
71 this->SetGrid( *data.grid );
72 this->SetRoot( data.root );
73 if( data.colDist == V || data.colDist == VPart )
74 this->AlignRows( data.colAlign, constrain );
75 else if( data.rowDist == V || data.rowDist == VPart )
76 this->AlignRows( data.rowAlign, constrain );
77 else if( data.colDist == VScat )
78 this->AlignRows( data.colAlign % this->RowStride(), constrain );
79 else if( data.rowDist == VScat )
80 this->AlignRows( data.rowAlign % this->RowStride(), constrain );
81 DEBUG_ONLY(
82 else if( V != VGath && data.colDist != VGath && data.rowDist != VGath )
83 LogicError("Nonsensical alignment");
84 )
85 }
86
87 template<typename T,Dist U,Dist V>
88 void
Translate(DistMatrix<T,U,V> & A) const89 GeneralDistMatrix<T,U,V>::Translate( DistMatrix<T,U,V>& A ) const
90 {
91 DEBUG_ONLY(CallStackEntry cse("GDM::Translate"))
92 const Grid& g = this->Grid();
93 const Int height = this->Height();
94 const Int width = this->Width();
95 const Int colAlign = this->ColAlign();
96 const Int rowAlign = this->RowAlign();
97 const Int root = this->Root();
98 A.SetGrid( g );
99 if( !A.RootConstrained() )
100 A.SetRoot( root );
101 if( !A.ColConstrained() )
102 A.AlignCols( colAlign, false );
103 if( !A.RowConstrained() )
104 A.AlignRows( rowAlign, false );
105 A.Resize( height, width );
106 if( !g.InGrid() )
107 return;
108
109 const bool aligned = colAlign == A.ColAlign() && rowAlign == A.RowAlign();
110 if( aligned && root == A.Root() )
111 {
112 A.matrix_ = this->matrix_;
113 }
114 else
115 {
116 #ifdef ELEM_UNALIGNED_WARNINGS
117 if( g.Rank() == 0 )
118 std::cerr << "Unaligned [U,V] <- [U,V]" << std::endl;
119 #endif
120 const Int colRank = this->ColRank();
121 const Int rowRank = this->RowRank();
122 const Int crossRank = this->CrossRank();
123 const Int colStride = this->ColStride();
124 const Int rowStride = this->RowStride();
125 const Int maxHeight = MaxLength( height, colStride );
126 const Int maxWidth = MaxLength( width, rowStride );
127 const Int pkgSize = mpi::Pad( maxHeight*maxWidth );
128 T* buffer=0;
129 if( crossRank == root || crossRank == A.Root() )
130 buffer = A.auxMemory_.Require( pkgSize );
131
132 const Int colAlignA = A.ColAlign();
133 const Int rowAlignA = A.RowAlign();
134 const Int localHeightA =
135 Length( height, colRank, colAlignA, colStride );
136 const Int localWidthA = Length( width, rowRank, rowAlignA, rowStride );
137 const Int recvSize = mpi::Pad( localHeightA*localWidthA );
138
139 if( crossRank == root )
140 {
141 // Pack the local data
142 const Int localHeight = this->LocalHeight();
143 const Int localWidth = this->LocalWidth();
144 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
145 MemCopy
146 ( &buffer[jLoc*localHeight], this->LockedBuffer(0,jLoc),
147 localHeight );
148
149 if( !aligned )
150 {
151 // If we were not aligned, then SendRecv over the DistComm
152 const Int toRow = Mod(colRank+colAlignA-colAlign,colStride);
153 const Int toCol = Mod(rowRank+rowAlignA-rowAlign,rowStride);
154 const Int fromRow = Mod(colRank+colAlign-colAlignA,colStride);
155 const Int fromCol = Mod(rowRank+rowAlign-rowAlignA,rowStride);
156 const Int toRank = toRow + toCol*colStride;
157 const Int fromRank = fromRow + fromCol*colStride;
158 mpi::SendRecv
159 ( buffer, pkgSize, toRank, fromRank, this->DistComm() );
160 }
161 }
162 if( root != A.Root() )
163 {
164 // Send to the correct new root over the cross communicator
165 if( crossRank == root )
166 mpi::Send( buffer, recvSize, A.Root(), A.CrossComm() );
167 else if( crossRank == A.Root() )
168 mpi::Recv( buffer, recvSize, root, A.CrossComm() );
169 }
170 // Unpack
171 if( crossRank == A.Root() )
172 for( Int jLoc=0; jLoc<localWidthA; ++jLoc )
173 MemCopy
174 ( A.Buffer(0,jLoc), &buffer[jLoc*localHeightA], localHeightA );
175 if( crossRank == root || crossRank == A.Root() )
176 A.auxMemory_.Release();
177 }
178 }
179
180 template<typename T,Dist U,Dist V>
181 void
AllGather(DistMatrix<T,UGath,VGath> & A) const182 GeneralDistMatrix<T,U,V>::AllGather( DistMatrix<T,UGath,VGath>& A ) const
183 {
184 DEBUG_ONLY(CallStackEntry cse("GDM::AllGather"))
185 const Int height = this->Height();
186 const Int width = this->Width();
187 A.SetGrid( this->Grid() );
188 A.Resize( height, width );
189
190 if( this->Participating() )
191 {
192 const Int colStride = this->ColStride();
193 const Int rowStride = this->RowStride();
194 const Int distStride = colStride*rowStride;
195
196 const Int thisLocalHeight = this->LocalHeight();
197 const Int thisLocalWidth = this->LocalWidth();
198 const Int maxLocalHeight = MaxLength(height,colStride);
199 const Int maxLocalWidth = MaxLength(width,rowStride);
200
201 const Int portionSize = mpi::Pad( maxLocalHeight*maxLocalWidth );
202 T* buffer = A.auxMemory_.Require( (distStride+1)*portionSize );
203 T* sendBuf = &buffer[0];
204 T* recvBuf = &buffer[portionSize];
205
206 // Pack
207 const Int ldim = this->LDim();
208 const T* thisBuf = this->LockedBuffer();
209 ELEM_PARALLEL_FOR
210 for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
211 MemCopy
212 ( &sendBuf[jLoc*thisLocalHeight], &thisBuf[jLoc*ldim],
213 thisLocalHeight );
214
215 // Communicate
216 mpi::AllGather
217 ( sendBuf, portionSize, recvBuf, portionSize, this->DistComm() );
218
219 // Unpack
220 T* ABuf = A.Buffer();
221 const Int ALDim = A.LDim();
222 const Int colAlign = this->ColAlign();
223 const Int rowAlign = this->RowAlign();
224 ELEM_OUTER_PARALLEL_FOR
225 for( Int l=0; l<rowStride; ++l )
226 {
227 const Int rowShift = Shift_( l, rowAlign, rowStride );
228 const Int localWidth = Length_( width, rowShift, rowStride );
229 for( Int k=0; k<colStride; ++k )
230 {
231 const T* data = &recvBuf[(k+l*colStride)*portionSize];
232 const Int colShift = Shift_( k, colAlign, colStride );
233 const Int localHeight = Length_( height, colShift, colStride );
234 ELEM_INNER_PARALLEL_FOR
235 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
236 {
237 T* destCol =
238 &ABuf[colShift+(rowShift+jLoc*rowStride)*ALDim];
239 const T* sourceCol = &data[jLoc*localHeight];
240 for( Int iLoc=0; iLoc<localHeight; ++iLoc )
241 destCol[iLoc*colStride] = sourceCol[iLoc];
242 }
243 }
244 }
245 A.auxMemory_.Release();
246 }
247 if( this->Grid().InGrid() && this->CrossComm() != mpi::COMM_SELF )
248 {
249 // Pack from the root
250 const Int localHeight = A.LocalHeight();
251 const Int localWidth = A.LocalWidth();
252 T* buf = A.auxMemory_.Require( localHeight*localWidth );
253 if( this->CrossRank() == this->Root() )
254 {
255 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
256 MemCopy
257 ( &buf[jLoc*localHeight], A.LockedBuffer(0,jLoc), localHeight );
258 }
259
260 // Broadcast from the root
261 mpi::Broadcast
262 ( buf, localHeight*localWidth, this->Root(), this->CrossComm() );
263
264 // Unpack if not the root
265 if( this->CrossRank() != this->Root() )
266 {
267 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
268 MemCopy
269 ( A.Buffer(0,jLoc), &buf[jLoc*localHeight], localHeight );
270 }
271 A.auxMemory_.Release();
272 }
273 }
274
275 template<typename T,Dist U,Dist V>
276 void
ColAllGather(DistMatrix<T,UGath,V> & A) const277 GeneralDistMatrix<T,U,V>::ColAllGather( DistMatrix<T,UGath,V>& A ) const
278 {
279 DEBUG_ONLY(
280 CallStackEntry cse("GDM::ColAllGather");
281 this->AssertSameGrid( A.Grid() );
282 )
283 const Int height = this->Height();
284 const Int width = this->Width();
285 #ifdef ELEM_CACHE_WARNINGS
286 if( height != 1 && this->Grid().Rank() == 0 )
287 {
288 std::cerr <<
289 "The matrix redistribution [* ,V] <- [U,V] potentially causes a "
290 "large amount of cache-thrashing. If possible, avoid it by "
291 "performing the redistribution with a (conjugate)-transpose"
292 << std::endl;
293 }
294 #endif
295 A.AlignRowsAndResize( this->RowAlign(), height, width, false, false );
296
297 if( this->Participating() )
298 {
299 if( this->RowAlign() == A.RowAlign() )
300 {
301 if( height == 1 )
302 {
303 const Int localWidthA = A.LocalWidth();
304 T* bcastBuf = A.auxMemory_.Require( localWidthA );
305
306 if( this->ColRank() == this->ColAlign() )
307 {
308 A.matrix_ = this->LockedMatrix();
309 // Pack
310 const T* ABuf = A.LockedBuffer();
311 const Int ALDim = A.LDim();
312 ELEM_PARALLEL_FOR
313 for( Int jLoc=0; jLoc<localWidthA; ++jLoc )
314 bcastBuf[jLoc] = ABuf[jLoc*ALDim];
315 }
316
317 // Broadcast within the column comm
318 mpi::Broadcast
319 ( bcastBuf, localWidthA, this->ColAlign(), this->ColComm() );
320
321 // Unpack
322 T* ABuf = A.Buffer();
323 const Int ALDim = A.LDim();
324 ELEM_PARALLEL_FOR
325 for( Int jLoc=0; jLoc<localWidthA; ++jLoc )
326 ABuf[jLoc*ALDim] = bcastBuf[jLoc];
327 A.auxMemory_.Release();
328 }
329 else
330 {
331 const Int colStride = this->ColStride();
332 const Int localWidth = this->LocalWidth();
333 const Int thisLocalHeight = this->LocalHeight();
334 const Int maxLocalHeight = MaxLength(height,colStride);
335 const Int portionSize = mpi::Pad( maxLocalHeight*localWidth );
336
337 T* buffer = A.auxMemory_.Require( (colStride+1)*portionSize );
338 T* sendBuf = &buffer[0];
339 T* recvBuf = &buffer[portionSize];
340
341 // Pack
342 const Int ldim = this->LDim();
343 const T* thisBuf = this->LockedBuffer();
344 ELEM_PARALLEL_FOR
345 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
346 {
347 const T* thisCol = &thisBuf[jLoc*ldim];
348 T* sendBufCol = &sendBuf[jLoc*thisLocalHeight];
349 MemCopy( sendBufCol, thisCol, thisLocalHeight );
350 }
351
352 // Communicate
353 mpi::AllGather
354 ( sendBuf, portionSize, recvBuf, portionSize, this->ColComm() );
355
356 // Unpack
357 T* ABuf = A.Buffer();
358 const Int ALDim = A.LDim();
359 const Int colAlign = this->ColAlign();
360 ELEM_OUTER_PARALLEL_FOR
361 for( Int k=0; k<colStride; ++k )
362 {
363 const T* data = &recvBuf[k*portionSize];
364 const Int colShift = Shift_( k, colAlign, colStride );
365 const Int localHeight =
366 Length_( height, colShift, colStride );
367 ELEM_INNER_PARALLEL_FOR
368 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
369 {
370 T* destCol = &ABuf[colShift+jLoc*ALDim];
371 const T* sourceCol = &data[jLoc*localHeight];
372 for( Int iLoc=0; iLoc<localHeight; ++iLoc )
373 destCol[iLoc*colStride] = sourceCol[iLoc];
374 }
375 }
376 A.auxMemory_.Release();
377 }
378 }
379 else
380 {
381 #ifdef ELEM_UNALIGNED_WARNINGS
382 if( this->Grid().Rank() == 0 )
383 std::cerr << "Unaligned [U,V] -> [* ,V]." << std::endl;
384 #endif
385 const Int colStride = this->ColStride();
386 const Int rowStride = this->RowStride();
387 const Int rowRank = this->RowRank();
388
389 const Int rowAlign = this->RowAlign();
390 const Int rowAlignA = A.RowAlign();
391 const Int sendRowRank =
392 (rowRank+rowStride+rowAlignA-rowAlign) % rowStride;
393 const Int recvRowRank =
394 (rowRank+rowStride+rowAlign-rowAlignA) % rowStride;
395
396 if( height == 1 )
397 {
398 const Int localWidthA = A.LocalWidth();
399 T* bcastBuf;
400
401 if( this->ColRank() == this->ColAlign() )
402 {
403 const Int localWidth = this->LocalWidth();
404
405 T* buffer = A.auxMemory_.Require( localWidth+localWidthA );
406 T* sendBuf = &buffer[0];
407 bcastBuf = &buffer[localWidth];
408
409 // Pack
410 const T* thisBuf = this->LockedBuffer();
411 const Int ldim = this->LDim();
412 ELEM_PARALLEL_FOR
413 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
414 sendBuf[jLoc] = thisBuf[jLoc*ldim];
415
416 // Communicate
417 mpi::SendRecv
418 ( sendBuf, localWidth, sendRowRank,
419 bcastBuf, localWidthA, recvRowRank, this->RowComm() );
420 }
421 else
422 {
423 bcastBuf = A.auxMemory_.Require( localWidthA );
424 }
425
426 // Communicate
427 mpi::Broadcast
428 ( bcastBuf, localWidthA, this->ColAlign(), this->ColComm() );
429
430 // Unpack
431 T* ABuf = A.Buffer();
432 const Int ALDim = A.LDim();
433 ELEM_PARALLEL_FOR
434 for( Int jLoc=0; jLoc<localWidthA; ++jLoc )
435 ABuf[jLoc*ALDim] = bcastBuf[jLoc];
436 A.auxMemory_.Release();
437 }
438 else
439 {
440 const Int thisLocalWidth = this->LocalWidth();
441 const Int localWidthA = A.LocalWidth();
442 const Int thisLocalHeight = this->LocalHeight();
443 const Int maxLocalHeight = MaxLength(height,colStride);
444 const Int maxLocalWidth = MaxLength(width,rowStride);
445 const Int portionSize =
446 mpi::Pad( maxLocalHeight*maxLocalWidth );
447
448 T* buffer = A.auxMemory_.Require( (colStride+1)*portionSize );
449 T* firstBuf = &buffer[0];
450 T* secondBuf = &buffer[portionSize];
451
452 // Pack
453 const Int ldim = this->LDim();
454 const T* thisBuf = this->LockedBuffer();
455 ELEM_PARALLEL_FOR
456 for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
457 {
458 const T* thisCol = &thisBuf[jLoc*ldim];
459 T* secondBufCol = &secondBuf[jLoc*thisLocalHeight];
460 MemCopy( secondBufCol, thisCol, thisLocalHeight );
461 }
462
463 // Realign
464 mpi::SendRecv
465 ( secondBuf, portionSize, sendRowRank,
466 firstBuf, portionSize, recvRowRank, this->RowComm() );
467
468 // AllGather the aligned data
469 mpi::AllGather
470 ( firstBuf, portionSize,
471 secondBuf, portionSize, this->ColComm() );
472
473 // Unpack the contents of each member of the column team
474 T* ABuf = A.Buffer();
475 const Int ALDim = A.LDim();
476 const Int colAlign = this->ColAlign();
477 ELEM_OUTER_PARALLEL_FOR
478 for( Int k=0; k<colStride; ++k )
479 {
480 const T* data = &secondBuf[k*portionSize];
481 const Int colShift = Shift_( k, colAlign, colStride );
482 const Int localHeight =
483 Length_( height, colShift, colStride );
484 ELEM_INNER_PARALLEL_FOR
485 for( Int jLoc=0; jLoc<localWidthA; ++jLoc )
486 {
487 T* destCol = &ABuf[colShift+jLoc*ALDim];
488 const T* sourceCol = &data[jLoc*localHeight];
489 for( Int iLoc=0; iLoc<localHeight; ++iLoc )
490 destCol[iLoc*colStride] = sourceCol[iLoc];
491 }
492 }
493 A.auxMemory_.Release();
494 }
495 }
496 }
497 if( this->Grid().InGrid() && this->CrossComm() != mpi::COMM_SELF )
498 {
499 // Pack from the root
500 const Int localHeight = A.LocalHeight();
501 const Int localWidth = A.LocalWidth();
502 T* buf = A.auxMemory_.Require( localHeight*localWidth );
503 if( this->CrossRank() == this->Root() )
504 {
505 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
506 MemCopy
507 ( &buf[jLoc*localHeight], A.LockedBuffer(0,jLoc), localHeight );
508 }
509
510 // Broadcast from the root
511 mpi::Broadcast
512 ( buf, localHeight*localWidth, this->Root(), this->CrossComm() );
513
514 // Unpack if not the root
515 if( this->CrossRank() != this->Root() )
516 {
517 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
518 MemCopy
519 ( A.Buffer(0,jLoc), &buf[jLoc*localHeight], localHeight );
520 }
521 A.auxMemory_.Release();
522 }
523 }
524
525 template<typename T,Dist U,Dist V>
526 void
RowAllGather(DistMatrix<T,U,VGath> & A) const527 GeneralDistMatrix<T,U,V>::RowAllGather( DistMatrix<T,U,VGath>& A ) const
528 {
529 DEBUG_ONLY(
530 CallStackEntry cse("GDM::RowAllGather");
531 this->AssertSameGrid( A.Grid() );
532 )
533 const Int height = this->Height();
534 const Int width = this->Width();
535 A.AlignColsAndResize( this->ColAlign(), height, width, false, false );
536
537 if( this->Participating() )
538 {
539 if( this->ColAlign() == A.ColAlign() )
540 {
541 if( width == 1 )
542 {
543 if( this->RowRank() == this->RowAlign() )
544 A.matrix_ = this->LockedMatrix();
545 mpi::Broadcast
546 ( A.matrix_.Buffer(), A.LocalHeight(), this->RowAlign(),
547 this->RowComm() );
548 }
549 else
550 {
551 const Int rowStride = this->RowStride();
552 const Int thisLocalWidth = this->LocalWidth();
553 const Int localHeight = this->LocalHeight();
554 const Int maxLocalWidth = MaxLength(width,rowStride);
555
556 const Int portionSize = mpi::Pad( localHeight*maxLocalWidth );
557 T* buffer = A.auxMemory_.Require( (rowStride+1)*portionSize );
558 T* sendBuf = &buffer[0];
559 T* recvBuf = &buffer[portionSize];
560
561 // Pack
562 const Int ldim = this->LDim();
563 const T* thisBuf = this->LockedBuffer();
564 ELEM_PARALLEL_FOR
565 for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
566 {
567 const T* thisCol = &thisBuf[jLoc*ldim];
568 T* sendBufCol = &sendBuf[jLoc*localHeight];
569 MemCopy( sendBufCol, thisCol, localHeight );
570 }
571
572 // Communicate
573 mpi::AllGather
574 ( sendBuf, portionSize, recvBuf, portionSize, this->RowComm() );
575
576 // Unpack
577 T* ABuf = A.Buffer();
578 const Int ALDim = A.LDim();
579 const Int rowAlign = this->RowAlign();
580 ELEM_OUTER_PARALLEL_FOR
581 for( Int k=0; k<rowStride; ++k )
582 {
583 const T* data = &recvBuf[k*portionSize];
584 const Int rowShift = Shift_( k, rowAlign, rowStride );
585 const Int localWidth =
586 Length_( width, rowShift, rowStride );
587 ELEM_INNER_PARALLEL_FOR
588 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
589 {
590 const T* dataCol = &data[jLoc*localHeight];
591 T* ACol = &ABuf[(rowShift+jLoc*rowStride)*ALDim];
592 MemCopy( ACol, dataCol, localHeight );
593 }
594 }
595 A.auxMemory_.Release();
596 }
597 }
598 else
599 {
600 #ifdef ELEM_UNALIGNED_WARNINGS
601 if( this->Grid().Rank() == 0 )
602 std::cerr << "Unaligned RowAllGather." << std::endl;
603 #endif
604 const Int colStride = this->ColStride();
605 const Int rowStride = this->RowStride();
606 const Int colRank = this->ColRank();
607
608 const Int colAlign = this->ColAlign();
609 const Int colAlignA = A.ColAlign();
610 const Int sendColRank =
611 (colRank+colStride+colAlignA-colAlign) % colStride;
612 const Int recvColRank =
613 (colRank+colStride+colAlign-colAlignA) % colStride;
614
615 if( width == 1 )
616 {
617 const Int localHeightA = A.LocalHeight();
618 if( this->RowRank() == this->RowAlign() )
619 {
620 const Int localHeight = this->LocalHeight();
621 T* buffer = A.auxMemory_.Require( localHeight );
622
623 // Pack
624 const T* thisCol = this->LockedBuffer();
625 MemCopy( buffer, thisCol, localHeight );
626
627 // Realign
628 mpi::SendRecv
629 ( buffer, localHeight, sendColRank,
630 A.matrix_.Buffer(), localHeightA, recvColRank,
631 this->ColComm() );
632 A.auxMemory_.Release();
633 }
634
635 // Perform the row broadcast
636 mpi::Broadcast
637 ( A.matrix_.Buffer(), localHeightA, this->RowAlign(),
638 this->RowComm() );
639 }
640 else
641 {
642 const Int localHeight = this->LocalHeight();
643 const Int thisLocalWidth = this->LocalWidth();
644 const Int localHeightA = A.LocalHeight();
645 const Int maxLocalHeight = MaxLength(height,colStride);
646 const Int maxLocalWidth = MaxLength(width,rowStride);
647
648 const Int portionSize =
649 mpi::Pad( maxLocalHeight*maxLocalWidth );
650 T* buffer = A.auxMemory_.Require( (rowStride+1)*portionSize );
651 T* firstBuf = &buffer[0];
652 T* secondBuf = &buffer[portionSize];
653
654 // Pack
655 const Int ldim = this->LDim();
656 const T* thisBuf = this->LockedBuffer();
657 ELEM_PARALLEL_FOR
658 for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
659 {
660 const T* thisCol = &thisBuf[jLoc*ldim];
661 T* secondBufCol = &secondBuf[jLoc*localHeight];
662 MemCopy( secondBufCol, thisCol, localHeight );
663 }
664
665 // Realign
666 mpi::SendRecv
667 ( secondBuf, portionSize, sendColRank,
668 firstBuf, portionSize, recvColRank, this->ColComm() );
669
670 // Perform the row AllGather
671 mpi::AllGather
672 ( firstBuf, portionSize,
673 secondBuf, portionSize, this->RowComm() );
674
675 // Unpack
676 T* ABuf = A.Buffer();
677 const Int ALDim = A.LDim();
678 const Int rowAlign = this->RowAlign();
679 ELEM_OUTER_PARALLEL_FOR
680 for( Int k=0; k<rowStride; ++k )
681 {
682 const T* data = &secondBuf[k*portionSize];
683 const Int rowShift = Shift_( k, rowAlign, rowStride );
684 const Int localWidth =
685 Length_( width, rowShift, rowStride );
686 ELEM_INNER_PARALLEL_FOR
687 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
688 {
689 const T* dataCol = &data[jLoc*localHeightA];
690 T* ACol = &ABuf[(rowShift+jLoc*rowStride)*ALDim];
691 MemCopy( ACol, dataCol, localHeightA );
692 }
693 }
694 A.auxMemory_.Release();
695 }
696 }
697 }
698 if( this->Grid().InGrid() && this->CrossComm() != mpi::COMM_SELF )
699 {
700 // Pack from the root
701 const Int localHeight = A.LocalHeight();
702 const Int localWidth = A.LocalWidth();
703 T* buf = A.auxMemory_.Require( localHeight*localWidth );
704 if( this->CrossRank() == this->Root() )
705 {
706 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
707 MemCopy
708 ( &buf[jLoc*localHeight], A.LockedBuffer(0,jLoc), localHeight );
709 }
710
711 // Broadcast from the root
712 mpi::Broadcast
713 ( buf, localHeight*localWidth, this->Root(), this->CrossComm() );
714
715 // Unpack if not the root
716 if( this->CrossRank() != this->Root() )
717 {
718 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
719 MemCopy
720 ( A.Buffer(0,jLoc), &buf[jLoc*localHeight], localHeight );
721 }
722 A.auxMemory_.Release();
723 }
724 }
725
726 template<typename T,Dist U,Dist V>
727 void
PartialColAllGather(DistMatrix<T,UPart,V> & A) const728 GeneralDistMatrix<T,U,V>::PartialColAllGather( DistMatrix<T,UPart,V>& A ) const
729 {
730 DEBUG_ONLY(
731 CallStackEntry cse("GDM::PartialColAllGather");
732 this->AssertSameGrid( A.Grid() );
733 )
734 const Int height = this->Height();
735 const Int width = this->Width();
736 #ifdef ELEM_VECTOR_WARNINGS
737 if( width == 1 && this->Grid().Rank() == 0 )
738 {
739 std::cerr <<
740 "The vector version of PartialColAllGather is not yet written but "
741 "would only require modifying the vector version of "
742 "PartialRowAllGather" << std::endl;
743 }
744 #endif
745 #ifdef ELEM_CACHE_WARNINGS
746 if( width && this->Grid().Rank() == 0 )
747 {
748 std::cerr <<
749 "PartialColAllGather potentially causes a large amount of cache-"
750 "thrashing. If possible, avoid it by performing the redistribution"
751 "on the (conjugate-)transpose" << std::endl;
752 }
753 #endif
754 A.AlignColsAndResize
755 ( this->ColAlign()%A.ColStride(), height, width, false, false );
756 if( !this->Participating() )
757 return;
758
759 DEBUG_ONLY(
760 if( this->LocalWidth() != this->Width() )
761 LogicError("This routine assumes rows are not distributed");
762 )
763 const T* thisBuf = this->LockedBuffer();
764 const Int ldim = this->LDim();
765 T* ABuf = A.Buffer();
766 const Int ALDim = A.LDim();
767
768 const Int colAlign = this->ColAlign();
769 const Int colAlignA = A.ColAlign();
770 const Int colStride = this->ColStride();
771 const Int colStrideUnion = this->PartialUnionColStride();
772 const Int colStridePart = this->PartialColStride();
773 const Int colRankPart = this->PartialColRank();
774 const Int colShiftA = A.ColShift();
775
776 const Int thisLocalHeight = this->LocalHeight();
777 const Int maxLocalHeight = MaxLength(height,colStride);
778 const Int portionSize = mpi::Pad( maxLocalHeight*width );
779 T* buffer = A.auxMemory_.Require( (colStrideUnion+1)*portionSize );
780 T* firstBuf = &buffer[0];
781 T* secondBuf = &buffer[portionSize];
782
783 if( colAlignA == colAlign % colStridePart )
784 {
785 // Pack
786 ELEM_PARALLEL_FOR
787 for( Int j=0; j<width; ++j )
788 {
789 const T* thisCol = &thisBuf[j*ldim];
790 T* firstBufCol = &firstBuf[j*thisLocalHeight];
791 MemCopy( firstBufCol, thisCol, thisLocalHeight );
792 }
793
794 // Communicate
795 mpi::AllGather
796 ( firstBuf, portionSize, secondBuf, portionSize,
797 this->PartialUnionColComm() );
798
799 // Unpack
800 ELEM_OUTER_PARALLEL_FOR
801 for( Int k=0; k<colStrideUnion; ++k )
802 {
803 const T* data = &secondBuf[k*portionSize];
804 const Int colShift =
805 Shift_( colRankPart+k*colStridePart, colAlign, colStride );
806 const Int colOffset = (colShift-colShiftA) / colStridePart;
807 const Int localHeight = Length_( height, colShift, colStride );
808 ELEM_INNER_PARALLEL_FOR
809 for( Int j=0; j<width; ++j )
810 {
811 const T* dataCol = &data[j*localHeight];
812 T* ACol = &ABuf[colOffset+j*ALDim];
813 for( Int iLoc=0; iLoc<localHeight; ++iLoc )
814 ACol[iLoc*colStrideUnion] = dataCol[iLoc];
815 }
816 }
817 }
818 else
819 {
820 #ifdef ELEM_UNALIGNED_WARNINGS
821 if( this->Grid().Rank() == 0 )
822 std::cerr << "Unaligned PartialColAllGather" << std::endl;
823 #endif
824 // Perform a SendRecv to match the row alignments
825 const Int colRank = this->ColRank();
826 const Int sendColRank =
827 (colRank+colStride+colAlignA-colAlign) % colStride;
828 const Int recvColRank =
829 (colRank+colStride+colAlign-colAlignA) % colStride;
830 ELEM_PARALLEL_FOR
831 for( Int j=0; j<width; ++j )
832 {
833 const T* thisCol = &thisBuf[j*ldim];
834 T* secondBufCol = &secondBuf[j*thisLocalHeight];
835 MemCopy( secondBufCol, thisCol, thisLocalHeight );
836 }
837 mpi::SendRecv
838 ( secondBuf, portionSize, sendColRank,
839 firstBuf, portionSize, recvColRank, this->ColComm() );
840
841 // Use the SendRecv as an input to the partial union AllGather
842 mpi::AllGather
843 ( firstBuf, portionSize,
844 secondBuf, portionSize, this->PartialUnionColComm() );
845
846 // Unpack
847 ELEM_OUTER_PARALLEL_FOR
848 for( Int k=0; k<colStrideUnion; ++k )
849 {
850 const T* data = &secondBuf[k*portionSize];
851 const Int colShift =
852 Shift_( colRankPart+colStridePart*k, colAlignA, colStride );
853 const Int colOffset = (colShift-colShiftA) / colStridePart;
854 const Int localHeight = Length_( height, colShift, colStride );
855 ELEM_INNER_PARALLEL_FOR
856 for( Int j=0; j<width; ++j )
857 {
858 const T* dataCol = &data[j*localHeight];
859 T* ACol = &ABuf[colOffset+j*ALDim];
860 for( Int iLoc=0; iLoc<localHeight; ++iLoc )
861 ACol[iLoc*colStrideUnion] = dataCol[iLoc];
862 }
863 }
864 }
865 A.auxMemory_.Release();
866 }
867
868 template<typename T,Dist U,Dist V>
869 void
PartialRowAllGather(DistMatrix<T,U,VPart> & A) const870 GeneralDistMatrix<T,U,V>::PartialRowAllGather( DistMatrix<T,U,VPart>& A ) const
871 {
872 DEBUG_ONLY(
873 CallStackEntry cse("GDM::PartialRowAllGather");
874 this->AssertSameGrid( A.Grid() );
875 )
876 const Int height = this->Height();
877 const Int width = this->Width();
878 A.AlignRowsAndResize
879 ( this->RowAlign()%A.RowStride(), height, width, false, false );
880 if( !this->Participating() )
881 return;
882
883 DEBUG_ONLY(
884 if( this->LocalHeight() != this->Height() )
885 LogicError("This routine assumes columns are not distributed");
886 )
887 const T* thisBuf = this->LockedBuffer();
888 const Int ldim = this->LDim();
889 T* ABuf = A.Buffer();
890 const Int ALDim = A.LDim();
891
892 const Int rowAlign = this->RowAlign();
893 const Int rowAlignA = A.RowAlign();
894 const Int rowStride = this->RowStride();
895 const Int rowStrideUnion = this->PartialUnionRowStride();
896 const Int rowStridePart = this->PartialRowStride();
897 const Int rowRankPart = this->PartialRowRank();
898 const Int rowShiftA = A.RowShift();
899
900 const Int thisLocalWidth = this->LocalWidth();
901 const Int maxLocalWidth = MaxLength(width,rowStride);
902 const Int portionSize = mpi::Pad( height*maxLocalWidth );
903 T* buffer = A.auxMemory_.Require( (rowStrideUnion+1)*portionSize );
904 T* firstBuf = &buffer[0];
905 T* secondBuf = &buffer[portionSize];
906
907 if( rowAlignA == rowAlign % rowStridePart )
908 {
909 // Pack
910 ELEM_PARALLEL_FOR
911 for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
912 {
913 const T* thisCol = &thisBuf[jLoc*ldim];
914 T* firstBufCol = &firstBuf[jLoc*height];
915 MemCopy( firstBufCol, thisCol, height );
916 }
917
918 // Communicate
919 mpi::AllGather
920 ( firstBuf, portionSize, secondBuf, portionSize,
921 this->PartialUnionRowComm() );
922
923 // Unpack
924 ELEM_OUTER_PARALLEL_FOR
925 for( Int k=0; k<rowStrideUnion; ++k )
926 {
927 const T* data = &secondBuf[k*portionSize];
928 const Int rowShift =
929 Shift_( rowRankPart+k*rowStridePart, rowAlign, rowStride );
930 const Int rowOffset = (rowShift-rowShiftA) / rowStridePart;
931 const Int localWidth = Length_( width, rowShift, rowStride );
932 ELEM_INNER_PARALLEL_FOR
933 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
934 {
935 const T* dataCol = &data[jLoc*height];
936 T* ACol = &ABuf[(rowOffset+jLoc*rowStrideUnion)*ALDim];
937 MemCopy( ACol, dataCol, height );
938 }
939 }
940 }
941 else
942 {
943 #ifdef ELEM_UNALIGNED_WARNINGS
944 if( this->Grid().Rank() == 0 )
945 std::cerr << "Unaligned PartialRowAllGather" << std::endl;
946 #endif
947 // Perform a SendRecv to match the row alignments
948 const Int rowRank = this->RowRank();
949 const Int sendRowRank =
950 (rowRank+rowStride+rowAlignA-rowAlign) % rowStride;
951 const Int recvRowRank =
952 (rowRank+rowStride+rowAlign-rowAlignA) % rowStride;
953 ELEM_PARALLEL_FOR
954 for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
955 {
956 const T* thisCol = &thisBuf[jLoc*ldim];
957 T* secondBufCol = &secondBuf[jLoc*height];
958 MemCopy( secondBufCol, thisCol, height );
959 }
960 mpi::SendRecv
961 ( secondBuf, portionSize, sendRowRank,
962 firstBuf, portionSize, recvRowRank, this->RowComm() );
963
964 // Use the SendRecv as an input to the partial union AllGather
965 mpi::AllGather
966 ( firstBuf, portionSize,
967 secondBuf, portionSize, this->PartialUnionRowComm() );
968
969 // Unpack
970 ELEM_OUTER_PARALLEL_FOR
971 for( Int k=0; k<rowStrideUnion; ++k )
972 {
973 const T* data = &secondBuf[k*portionSize];
974 const Int rowShift =
975 Shift_( rowRankPart+rowStridePart*k, rowAlignA, rowStride );
976 const Int rowOffset = (rowShift-rowShiftA) / rowStridePart;
977 const Int localWidth = Length_( width, rowShift, rowStride );
978 ELEM_INNER_PARALLEL_FOR
979 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
980 {
981 const T* dataCol = &data[jLoc*height];
982 T* ACol = &ABuf[(rowOffset+jLoc*rowStrideUnion)*ALDim];
983 MemCopy( ACol, dataCol, height );
984 }
985 }
986 }
987 A.auxMemory_.Release();
988 }
989
990 template<typename T,Dist U,Dist V>
991 void
FilterFrom(const DistMatrix<T,UGath,VGath> & A)992 GeneralDistMatrix<T,U,V>::FilterFrom( const DistMatrix<T,UGath,VGath>& A )
993 {
994 DEBUG_ONLY(
995 CallStackEntry cse("GDM::FilterFrom");
996 this->AssertSameGrid( A.Grid() );
997 )
998 const Int height = A.Height();
999 const Int width = A.Width();
1000 this->Resize( height, width );
1001 if( !this->Participating() )
1002 return;
1003
1004 const Int colStride = this->ColStride();
1005 const Int rowStride = this->RowStride();
1006 const Int colShift = this->ColShift();
1007 const Int rowShift = this->RowShift();
1008
1009 const Int localHeight = this->LocalHeight();
1010 const Int localWidth = this->LocalWidth();
1011
1012 T* thisBuf = this->Buffer();
1013 const Int ldim = this->LDim();
1014 const T* ABuf = A.LockedBuffer();
1015 const Int ALDim = A.LDim();
1016 ELEM_PARALLEL_FOR
1017 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
1018 {
1019 T* thisCol = &thisBuf[jLoc*ldim];
1020 const T* ACol = &ABuf[colShift+(rowShift+jLoc*rowStride)*ALDim];
1021 for( Int iLoc=0; iLoc<localHeight; ++iLoc )
1022 thisCol[iLoc] = ACol[iLoc*colStride];
1023 }
1024 }
1025
1026 template<typename T,Dist U,Dist V>
1027 void
ColFilterFrom(const DistMatrix<T,UGath,V> & A)1028 GeneralDistMatrix<T,U,V>::ColFilterFrom( const DistMatrix<T,UGath,V>& A )
1029 {
1030 DEBUG_ONLY(
1031 CallStackEntry cse("GDM::ColFilterFrom");
1032 this->AssertSameGrid( A.Grid() );
1033 )
1034 const Int height = A.Height();
1035 const Int width = A.Width();
1036 this->AlignRowsAndResize( A.RowAlign(), height, width, false, false );
1037 if( !this->Participating() )
1038 return;
1039
1040 const Int colStride = this->ColStride();
1041 const Int colShift = this->ColShift();
1042 const Int rowAlign = this->RowAlign();
1043 const Int rowAlignA = A.RowAlign();
1044
1045 const Int localHeight = this->LocalHeight();
1046 const Int localWidth = this->LocalWidth();
1047
1048 T* thisBuf = this->Buffer();
1049 const Int ldim = this->LDim();
1050 const T* ABuf = A.LockedBuffer();
1051 const Int ALDim = A.LDim();
1052
1053 if( rowAlign == rowAlignA )
1054 {
1055 ELEM_PARALLEL_FOR
1056 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
1057 {
1058 T* thisCol = &thisBuf[jLoc*ldim];
1059 const T* ACol = &ABuf[colShift+jLoc*ALDim];
1060 for( Int iLoc=0; iLoc<localHeight; ++iLoc )
1061 thisCol[iLoc] = ACol[iLoc*colStride];
1062 }
1063 }
1064 else
1065 {
1066 #ifdef ELEM_UNALIGNED_WARNINGS
1067 if( this->Grid().Rank() == 0 )
1068 std::cerr << "Unaligned ColFilterFrom" << std::endl;
1069 #endif
1070 const Int rowStride = this->RowStride();
1071 const Int rowRank = this->RowRank();
1072 const Int sendRowRank =
1073 (rowRank+rowStride+rowAlign-rowAlignA) % rowStride;
1074 const Int recvRowRank =
1075 (rowRank+rowStride+rowAlignA-rowAlign) % rowStride;
1076 const Int localWidthA = A.LocalWidth();
1077 const Int sendSize = localHeight*localWidthA;
1078 const Int recvSize = localHeight*localWidth;
1079 T* buffer = this->auxMemory_.Require( sendSize+recvSize );
1080 T* sendBuf = &buffer[0];
1081 T* recvBuf = &buffer[sendSize];
1082
1083 // Pack
1084 ELEM_PARALLEL_FOR
1085 for( Int jLoc=0; jLoc<localWidthA; ++jLoc )
1086 {
1087 T* sendCol = &sendBuf[jLoc*localHeight];
1088 const T* ACol = &ABuf[colShift+jLoc*ALDim];
1089 for( Int iLoc=0; iLoc<localHeight; ++iLoc )
1090 sendCol[iLoc] = ACol[iLoc*colStride];
1091 }
1092
1093 // Realign
1094 mpi::SendRecv
1095 ( sendBuf, sendSize, sendRowRank,
1096 recvBuf, recvSize, recvRowRank, this->RowComm() );
1097
1098 // Unpack
1099 ELEM_PARALLEL_FOR
1100 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
1101 MemCopy
1102 ( &thisBuf[jLoc*ldim], &recvBuf[jLoc*localHeight], localHeight );
1103 this->auxMemory_.Release();
1104 }
1105 }
1106
1107 template<typename T,Dist U,Dist V>
1108 void
RowFilterFrom(const DistMatrix<T,U,VGath> & A)1109 GeneralDistMatrix<T,U,V>::RowFilterFrom( const DistMatrix<T,U,VGath>& A )
1110 {
1111 DEBUG_ONLY(
1112 CallStackEntry cse("GDM::RowFilterFrom");
1113 this->AssertSameGrid( A.Grid() );
1114 )
1115 const Int height = A.Height();
1116 const Int width = A.Width();
1117 this->AlignColsAndResize( A.ColAlign(), height, width, false, false );
1118 if( !this->Participating() )
1119 return;
1120
1121 const Int colAlign = this->ColAlign();
1122 const Int colAlignA = A.ColAlign();
1123 const Int rowStride = this->RowStride();
1124 const Int rowShift = this->RowShift();
1125
1126 const Int localHeight = this->LocalHeight();
1127 const Int localWidth = this->LocalWidth();
1128
1129 T* thisBuf = this->Buffer();
1130 const Int ldim = this->LDim();
1131 const T* ABuf = A.LockedBuffer();
1132 const Int ALDim = A.LDim();
1133
1134 if( colAlign == colAlignA )
1135 {
1136 ELEM_PARALLEL_FOR
1137 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
1138 {
1139 T* thisCol = &thisBuf[jLoc*ldim];
1140 const T* ACol = &ABuf[(rowShift+jLoc*rowStride)*ALDim];
1141 MemCopy( thisCol, ACol, localHeight );
1142 }
1143 }
1144 else
1145 {
1146 #ifdef ELEM_UNALIGNED_WARNINGS
1147 if( this->Grid().Rank() == 0 )
1148 std::cerr << "Unaligned RowFilterFrom" << std::endl;
1149 #endif
1150 const Int colRank = this->ColRank();
1151 const Int colStride = this->ColStride();
1152 const Int sendColRank =
1153 (colRank+colStride+colAlign-colAlignA) % colStride;
1154 const Int recvColRank =
1155 (colRank+colStride+colAlignA-colAlign) % colStride;
1156 const Int localHeightA = A.LocalHeight();
1157 const Int sendSize = localHeightA*localWidth;
1158 const Int recvSize = localHeight *localWidth;
1159
1160 T* buffer = this->auxMemory_.Require( sendSize+recvSize );
1161 T* sendBuf = &buffer[0];
1162 T* recvBuf = &buffer[sendSize];
1163
1164 // Pack
1165 ELEM_PARALLEL_FOR
1166 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
1167 MemCopy
1168 ( &sendBuf[jLoc*localHeightA],
1169 &ABuf[(rowShift+jLoc*rowStride)*ALDim], localHeightA );
1170
1171 // Realign
1172 mpi::SendRecv
1173 ( sendBuf, sendSize, sendColRank,
1174 recvBuf, recvSize, recvColRank, this->ColComm() );
1175
1176 // Unpack
1177 ELEM_PARALLEL_FOR
1178 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
1179 MemCopy
1180 ( &thisBuf[jLoc*ldim], &recvBuf[jLoc*localHeight], localHeight );
1181 this->auxMemory_.Release();
1182 }
1183 }
1184
1185 template<typename T,Dist U,Dist V>
1186 void
PartialColFilterFrom(const DistMatrix<T,UPart,V> & A)1187 GeneralDistMatrix<T,U,V>::PartialColFilterFrom( const DistMatrix<T,UPart,V>& A )
1188 {
1189 DEBUG_ONLY(
1190 CallStackEntry cse("GDM::PartialColFilterFrom");
1191 this->AssertSameGrid( A.Grid() );
1192 )
1193 const Int height = A.Height();
1194 const Int width = A.Width();
1195 this->AlignColsAndResize( A.ColAlign(), height, width, false, false );
1196 if( !this->Participating() )
1197 return;
1198
1199 const Int colAlign = this->ColAlign();
1200 const Int colAlignA = A.ColAlign();
1201 const Int colStride = this->ColStride();
1202 const Int colStridePart = this->PartialColStride();
1203 const Int colStrideUnion = this->PartialUnionColStride();
1204 const Int colShiftA = A.ColShift();
1205
1206 const Int localHeight = this->LocalHeight();
1207
1208 T* thisBuf = this->Buffer();
1209 const Int ldim = this->LDim();
1210 const T* ABuf = A.LockedBuffer();
1211 const Int ALDim = A.LDim();
1212 if( colAlign % colStridePart == colAlignA )
1213 {
1214 const Int colShift = this->ColShift();
1215 const Int colOffset = (colShift-colShiftA) / colStridePart;
1216 ELEM_PARALLEL_FOR
1217 for( Int j=0; j<width; ++j )
1218 {
1219 T* thisCol = &thisBuf[j*ldim];
1220 const T* ACol = &ABuf[colOffset+j*ALDim];
1221 for( Int iLoc=0; iLoc<localHeight; ++iLoc )
1222 thisCol[iLoc] = ACol[iLoc*colStrideUnion];
1223 }
1224 }
1225 else
1226 {
1227 #ifdef ELEM_UNALIGNED_WARNINGS
1228 if( this->Grid().Rank() == 0 )
1229 std::cerr << "Unaligned PartialColFilterFrom" << std::endl;
1230 #endif
1231 const Int colRankPart = this->PartialColRank();
1232 const Int colRankUnion = this->PartialUnionColRank();
1233 const Int colShiftA = A.ColShift();
1234
1235 // Realign
1236 // -------
1237 const Int sendColRankPart =
1238 (colRankPart+colStridePart+(colAlign%colStridePart)-colAlignA) %
1239 colStridePart;
1240 const Int recvColRankPart =
1241 (colRankPart+colStridePart+colAlignA-(colAlign%colStridePart)) %
1242 colStridePart;
1243 const Int sendColRank = sendColRankPart + colStridePart*colRankUnion;
1244 const Int sendColShift = Shift( sendColRank, colAlign, colStride );
1245 const Int sendColOffset = (sendColShift-colShiftA) / colStridePart;
1246 const Int localHeightSend = Length( height, sendColShift, colStride );
1247 const Int sendSize = localHeightSend*width;
1248 const Int recvSize = localHeight *width;
1249 T* buffer = this->auxMemory_.Require( sendSize+recvSize );
1250 T* sendBuf = &buffer[0];
1251 T* recvBuf = &buffer[sendSize];
1252 // Pack
1253 ELEM_PARALLEL_FOR
1254 for( Int j=0; j<width; ++j )
1255 {
1256 T* sendCol = &sendBuf[j*localHeightSend];
1257 const T* ACol = &ABuf[sendColOffset+j*ALDim];
1258 for( Int iLoc=0; iLoc<localHeightSend; ++iLoc )
1259 sendCol[iLoc] = ACol[iLoc*colStrideUnion];
1260 }
1261 // Change the column alignment
1262 mpi::SendRecv
1263 ( sendBuf, sendSize, sendColRankPart,
1264 recvBuf, recvSize, recvColRankPart, this->PartialColComm() );
1265
1266 // Unpack
1267 // ------
1268 ELEM_PARALLEL_FOR
1269 for( Int j=0; j<width; ++j )
1270 {
1271 const T* recvCol = &recvBuf[j*localHeight];
1272 T* thisCol = &thisBuf[j*ldim];
1273 MemCopy( thisCol, recvCol, localHeight );
1274 }
1275 this->auxMemory_.Release();
1276 }
1277 }
1278
1279 template<typename T,Dist U,Dist V>
1280 void
PartialRowFilterFrom(const DistMatrix<T,U,VPart> & A)1281 GeneralDistMatrix<T,U,V>::PartialRowFilterFrom( const DistMatrix<T,U,VPart>& A )
1282 {
1283 DEBUG_ONLY(
1284 CallStackEntry cse("GDM::PartialRowFilterFrom");
1285 this->AssertSameGrid( A.Grid() );
1286 )
1287 const Int height = A.Height();
1288 const Int width = A.Width();
1289 this->AlignRowsAndResize( A.RowAlign(), height, width, false, false );
1290 if( !this->Participating() )
1291 return;
1292
1293 const Int rowAlign = this->RowAlign();
1294 const Int rowAlignA = A.RowAlign();
1295 const Int rowStride = this->RowStride();
1296 const Int rowStridePart = this->PartialRowStride();
1297 const Int rowStrideUnion = this->PartialUnionRowStride();
1298 const Int rowShiftA = A.RowShift();
1299
1300 const Int localWidth = this->LocalWidth();
1301
1302 T* thisBuf = this->Buffer();
1303 const Int ldim = this->LDim();
1304 const T* ABuf = A.LockedBuffer();
1305 const Int ALDim = A.LDim();
1306 if( rowAlign % rowStridePart == rowAlignA )
1307 {
1308 const Int rowShift = this->RowShift();
1309 const Int rowOffset = (rowShift-rowShiftA) / rowStridePart;
1310 ELEM_PARALLEL_FOR
1311 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
1312 {
1313 T* thisCol = &thisBuf[jLoc*ldim];
1314 const T* ACol = &ABuf[(rowOffset+jLoc*rowStrideUnion)*ALDim];
1315 MemCopy( thisCol, ACol, height );
1316 }
1317 }
1318 else
1319 {
1320 #ifdef ELEM_UNALIGNED_WARNINGS
1321 if( this->Grid().Rank() == 0 )
1322 std::cerr << "Unaligned PartialRowFilterFrom" << std::endl;
1323 #endif
1324 const Int rowRankPart = this->PartialRowRank();
1325 const Int rowRankUnion = this->PartialUnionRowRank();
1326 const Int rowShiftA = A.RowShift();
1327
1328 // Realign
1329 // -------
1330 const Int sendRowRankPart =
1331 (rowRankPart+rowStridePart+(rowAlign%rowStridePart)-rowAlignA) %
1332 rowStridePart;
1333 const Int recvRowRankPart =
1334 (rowRankPart+rowStridePart+rowAlignA-(rowAlign%rowStridePart)) %
1335 rowStridePart;
1336 const Int sendRowRank = sendRowRankPart + rowStridePart*rowRankUnion;
1337 const Int sendRowShift = Shift( sendRowRank, rowAlign, rowStride );
1338 const Int sendRowOffset = (sendRowShift-rowShiftA) / rowStridePart;
1339 const Int localWidthSend = Length( width, sendRowShift, rowStride );
1340 const Int sendSize = height*localWidthSend;
1341 const Int recvSize = height*localWidth;
1342 T* buffer = this->auxMemory_.Require( sendSize+recvSize );
1343 T* sendBuf = &buffer[0];
1344 T* recvBuf = &buffer[sendSize];
1345 // Pack
1346 ELEM_PARALLEL_FOR
1347 for( Int jLoc=0; jLoc<localWidthSend; ++jLoc )
1348 {
1349 T* sendCol = &sendBuf[jLoc*height];
1350 const T* ACol = &ABuf[(sendRowOffset+jLoc*rowStrideUnion)*ALDim];
1351 MemCopy( sendCol, ACol, height );
1352 }
1353 // Change the column alignment
1354 mpi::SendRecv
1355 ( sendBuf, sendSize, sendRowRankPart,
1356 recvBuf, recvSize, recvRowRankPart, this->PartialRowComm() );
1357
1358 // Unpack
1359 // ------
1360 ELEM_PARALLEL_FOR
1361 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
1362 {
1363 const T* recvCol = &recvBuf[jLoc*height];
1364 T* thisCol = &thisBuf[jLoc*ldim];
1365 MemCopy( thisCol, recvCol, height );
1366 }
1367 this->auxMemory_.Release();
1368 }
1369 }
1370
// Fill this [U,V] matrix from A, which is distributed as [UPart,VScat].
//
// The redistribution is a pack / AllToAll / unpack over the partial-union
// column communicator. If this matrix's column alignment (mod the partial
// column stride) differs from A's, the AllToAll result is additionally
// passed through a SendRecv over the partial column communicator before
// unpacking.
template<typename T,Dist U,Dist V>
void
GeneralDistMatrix<T,U,V>::PartialColAllToAllFrom
( const DistMatrix<T,UPart,VScat>& A )
{
    DEBUG_ONLY(
        CallStackEntry cse("GDM::PartialColAllToAllFrom");
        this->AssertSameGrid( A.Grid() );
    )
    const Int height = A.Height();
    const Int width = A.Width();
    this->AlignColsAndResize( A.ColAlign(), height, width, false, false );
    if( !this->Participating() )
        return;

    const Int colAlign = this->ColAlign();
    const Int colAlignA = A.ColAlign();
    const Int rowAlignA = A.RowAlign();

    const Int colStride = this->ColStride();
    const Int colStridePart = this->PartialColStride();
    const Int colStrideUnion = this->PartialUnionColStride();
    const Int colRankPart = this->PartialColRank();

    const Int colShiftA = A.ColShift();

    // Every exchanged portion is padded to the same size, computed from
    // the maximum local height (over colStride) and the maximum local
    // width (over colStrideUnion)
    const Int thisLocalHeight = this->LocalHeight();
    const Int localWidthA = A.LocalWidth();
    const Int maxLocalHeight = MaxLength(height,colStride);
    const Int maxLocalWidth = MaxLength(width,colStrideUnion);
    const Int portionSize = mpi::Pad( maxLocalHeight*maxLocalWidth );

    T* thisBuf = this->Buffer();
    const Int ldim = this->LDim();
    const T* ABuf = A.LockedBuffer();
    const Int ALDim = A.LDim();

    // Two halves of colStrideUnion portions each; which half is the send
    // side and which the receive side differs between the branches below
    T* buffer = this->auxMemory_.Require( 2*colStrideUnion*portionSize );
    T* firstBuf = &buffer[0];
    T* secondBuf = &buffer[colStrideUnion*portionSize];

    if( colAlign % colStridePart == colAlignA )
    {
        // Pack
        // Portion k holds the rows of A's local matrix that belong to
        // column rank colRankPart+k*colStridePart: every colStrideUnion-th
        // local row starting at colOffset
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<colStrideUnion; ++k )
        {
            T* data = &firstBuf[k*portionSize];
            const Int colRank = colRankPart + k*colStridePart;
            const Int colShift = Shift_( colRank, colAlign, colStride );
            const Int colOffset = (colShift-colShiftA) / colStridePart;
            const Int localHeight = Length_( height, colShift, colStride );
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<localWidthA; ++jLoc )
            {
                T* dataCol = &data[jLoc*localHeight];
                const T* ACol = &ABuf[colOffset+jLoc*ALDim];
                for( Int iLoc=0; iLoc<localHeight; ++iLoc )
                    dataCol[iLoc] = ACol[iLoc*colStrideUnion];
            }
        }

        // Simultaneously Scatter in columns and Gather in rows
        mpi::AllToAll
        ( firstBuf, portionSize,
          secondBuf, portionSize, this->PartialUnionColComm() );

        // Unpack
        // Portion k supplies our local columns rowShift,
        // rowShift+colStrideUnion, ..., each of our full local height
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<colStrideUnion; ++k )
        {
            const T* data = &secondBuf[k*portionSize];
            const Int rowShift = Shift_( k, rowAlignA, colStrideUnion );
            const Int localWidth = Length_( width, rowShift, colStrideUnion );
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<localWidth; ++jLoc )
            {
                const T* dataCol = &data[jLoc*thisLocalHeight];
                T* thisCol = &thisBuf[(rowShift+jLoc*colStrideUnion)*ldim];
                MemCopy( thisCol, dataCol, thisLocalHeight );
            }
        }
    }
    else
    {
#ifdef ELEM_UNALIGNED_WARNINGS
        if( this->Grid().Rank() == 0 )
            std::cerr << "Unaligned PartialColAllToAllFrom" << std::endl;
#endif
        // Partner ranks within the partial column communicator
        const Int sendColRankPart =
            (colRankPart+colStridePart+(colAlign%colStridePart)-colAlignA) %
            colStridePart;
        const Int recvColRankPart =
            (colRankPart+colStridePart+colAlignA-(colAlign%colStridePart)) %
            colStridePart;

        // Pack
        // Same layout as the aligned case, but computed with
        // sendColRankPart in place of colRankPart and written into
        // secondBuf, so that after the AllToAll and the SendRecv the data
        // to unpack is again in secondBuf
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<colStrideUnion; ++k )
        {
            T* data = &secondBuf[k*portionSize];
            const Int colRank = sendColRankPart + k*colStridePart;
            const Int colShift = Shift_( colRank, colAlign, colStride );
            const Int colOffset = (colShift-colShiftA) / colStridePart;
            const Int localHeight = Length_( height, colShift, colStride );
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<localWidthA; ++jLoc )
            {
                T* dataCol = &data[jLoc*localHeight];
                const T* ACol = &ABuf[colOffset+jLoc*ALDim];
                for( Int iLoc=0; iLoc<localHeight; ++iLoc )
                    dataCol[iLoc] = ACol[iLoc*colStrideUnion];
            }
        }

        // Simultaneously Scatter in columns and Gather in rows
        mpi::AllToAll
        ( secondBuf, portionSize,
          firstBuf, portionSize, this->PartialUnionColComm() );

        // Realign the result
        mpi::SendRecv
        ( firstBuf, colStrideUnion*portionSize, sendColRankPart,
          secondBuf, colStrideUnion*portionSize, recvColRankPart,
          this->PartialColComm() );

        // Unpack
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<colStrideUnion; ++k )
        {
            const T* data = &secondBuf[k*portionSize];
            const Int rowShift = Shift_( k, rowAlignA, colStrideUnion );
            const Int localWidth = Length_( width, rowShift, colStrideUnion );
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<localWidth; ++jLoc )
            {
                const T* dataCol = &data[jLoc*thisLocalHeight];
                T* thisCol = &thisBuf[(rowShift+jLoc*colStrideUnion)*ldim];
                MemCopy( thisCol, dataCol, thisLocalHeight );
            }
        }
    }
    this->auxMemory_.Release();
}
1515
// Fill this [U,V] matrix from A, which is distributed as [UScat,VPart].
//
// The redistribution is a pack / AllToAll / unpack over the partial-union
// row communicator. If this matrix's row alignment (mod the partial row
// stride) differs from A's, the AllToAll result is additionally passed
// through a SendRecv over the partial row communicator before unpacking.
template<typename T,Dist U,Dist V>
void
GeneralDistMatrix<T,U,V>::PartialRowAllToAllFrom
( const DistMatrix<T,UScat,VPart>& A )
{
    DEBUG_ONLY(
        CallStackEntry cse("GDM::PartialRowAllToAllFrom");
        this->AssertSameGrid( A.Grid() );
    )
    const Int height = A.Height();
    const Int width = A.Width();
    this->AlignRowsAndResize( A.RowAlign(), height, width, false, false );
    if( !this->Participating() )
        return;

    const Int rowAlign = this->RowAlign();
    const Int rowAlignA = A.RowAlign();
    const Int colAlignA = A.ColAlign();

    const Int rowStride = this->RowStride();
    const Int rowStridePart = this->PartialRowStride();
    const Int rowStrideUnion = this->PartialUnionRowStride();
    const Int rowRankPart = this->PartialRowRank();

    const Int rowShiftA = A.RowShift();

    // Every exchanged portion is padded to the same size, computed from
    // the maximum local height (over rowStrideUnion) and the maximum local
    // width (over rowStride)
    const Int thisLocalWidth = this->LocalWidth();
    const Int localHeightA = A.LocalHeight();
    const Int maxLocalHeight = MaxLength(height,rowStrideUnion);
    const Int maxLocalWidth = MaxLength(width,rowStride);
    const Int portionSize = mpi::Pad( maxLocalHeight*maxLocalWidth );

    T* thisBuf = this->Buffer();
    const Int ldim = this->LDim();
    const T* ABuf = A.LockedBuffer();
    const Int ALDim = A.LDim();

    // Two halves of rowStrideUnion portions each; which half is the send
    // side and which the receive side differs between the branches below
    T* buffer = this->auxMemory_.Require( 2*rowStrideUnion*portionSize );
    T* firstBuf = &buffer[0];
    T* secondBuf = &buffer[rowStrideUnion*portionSize];

    if( rowAlign % rowStridePart == rowAlignA )
    {
        // Pack
        // Portion k holds the columns of A's local matrix that belong to
        // row rank rowRankPart+k*rowStridePart: every rowStrideUnion-th
        // local column starting at rowOffset
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<rowStrideUnion; ++k )
        {
            T* data = &firstBuf[k*portionSize];
            const Int rowRank = rowRankPart + k*rowStridePart;
            const Int rowShift = Shift_( rowRank, rowAlign, rowStride );
            const Int rowOffset = (rowShift-rowShiftA) / rowStridePart;
            const Int localWidth = Length_( width, rowShift, rowStride );
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<localWidth; ++jLoc )
            {
                T* dataCol = &data[jLoc*localHeightA];
                const T* ACol = &ABuf[(rowOffset+jLoc*rowStrideUnion)*ALDim];
                MemCopy( dataCol, ACol, localHeightA );
            }
        }

        // Simultaneously Scatter in rows and Gather in columns
        mpi::AllToAll
        ( firstBuf, portionSize,
          secondBuf, portionSize, this->PartialUnionRowComm() );

        // Unpack
        // Portion k supplies our local rows colShift,
        // colShift+rowStrideUnion, ... for each of our local columns
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<rowStrideUnion; ++k )
        {
            const T* data = &secondBuf[k*portionSize];
            const Int colShift = Shift_( k, colAlignA, rowStrideUnion );
            const Int localHeight = Length_( height, colShift, rowStrideUnion );
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
            {
                const T* dataCol = &data[jLoc*localHeight];
                T* thisCol = &thisBuf[colShift+jLoc*ldim];
                for( Int iLoc=0; iLoc<localHeight; ++iLoc )
                    thisCol[iLoc*rowStrideUnion] = dataCol[iLoc];
            }
        }
    }
    else
    {
#ifdef ELEM_UNALIGNED_WARNINGS
        if( this->Grid().Rank() == 0 )
            std::cerr << "Unaligned PartialRowAllToAllFrom" << std::endl;
#endif
        // Partner ranks within the partial row communicator
        const Int sendRowRankPart =
            (rowRankPart+rowStridePart+(rowAlign%rowStridePart)-rowAlignA) %
            rowStridePart;
        const Int recvRowRankPart =
            (rowRankPart+rowStridePart+rowAlignA-(rowAlign%rowStridePart)) %
            rowStridePart;

        // Pack
        // Same layout as the aligned case, but computed with
        // sendRowRankPart in place of rowRankPart and written into
        // secondBuf, so that after the AllToAll and the SendRecv the data
        // to unpack is again in secondBuf
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<rowStrideUnion; ++k )
        {
            T* data = &secondBuf[k*portionSize];
            const Int rowRank = sendRowRankPart + k*rowStridePart;
            const Int rowShift = Shift_( rowRank, rowAlign, rowStride );
            const Int rowOffset = (rowShift-rowShiftA) / rowStridePart;
            const Int localWidth = Length_( width, rowShift, rowStride );
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<localWidth; ++jLoc )
            {
                T* dataCol = &data[jLoc*localHeightA];
                const T* ACol = &ABuf[(rowOffset+jLoc*rowStrideUnion)*ALDim];
                MemCopy( dataCol, ACol, localHeightA );
            }
        }

        // Simultaneously Scatter in rows and Gather in columns
        mpi::AllToAll
        ( secondBuf, portionSize,
          firstBuf, portionSize, this->PartialUnionRowComm() );

        // Realign the result
        mpi::SendRecv
        ( firstBuf, rowStrideUnion*portionSize, sendRowRankPart,
          secondBuf, rowStrideUnion*portionSize, recvRowRankPart,
          this->PartialRowComm() );

        // Unpack
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<rowStrideUnion; ++k )
        {
            const T* data = &secondBuf[k*portionSize];
            const Int colShift = Shift_( k, colAlignA, rowStrideUnion );
            const Int localHeight = Length_( height, colShift, rowStrideUnion );
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
            {
                const T* dataCol = &data[jLoc*localHeight];
                T* thisCol = &thisBuf[colShift+jLoc*ldim];
                for( Int iLoc=0; iLoc<localHeight; ++iLoc )
                    thisCol[iLoc*rowStrideUnion] = dataCol[iLoc];
            }
        }
    }
    this->auxMemory_.Release();
}
1660
// Write this [U,V] matrix into A, which is distributed as [UPart,VScat].
//
// This is the reverse direction of PartialColAllToAllFrom: the local data
// of this matrix is packed, exchanged with an AllToAll over the
// partial-union column communicator, and unpacked into A. If A's column
// alignment differs from this matrix's (mod the partial column stride),
// the packed data is first passed through a SendRecv over the partial
// column communicator. The scratch space is taken from A's auxMemory_.
template<typename T,Dist U,Dist V>
void
GeneralDistMatrix<T,U,V>::PartialColAllToAll
( DistMatrix<T,UPart,VScat>& A ) const
{
    DEBUG_ONLY(
        CallStackEntry cse("GDM::PartialColAllToAll");
        this->AssertSameGrid( A.Grid() );
    )
    const Int height = this->Height();
    const Int width = this->Width();
    A.AlignColsAndResize
    ( this->ColAlign()%A.ColStride(), height, width, false, false );
    if( !A.Participating() )
        return;

    const Int colAlign = this->ColAlign();
    const Int colAlignA = A.ColAlign();
    const Int rowAlignA = A.RowAlign();

    const Int colStride = this->ColStride();
    const Int colStridePart = this->PartialColStride();
    const Int colStrideUnion = this->PartialUnionColStride();
    const Int colRankPart = this->PartialColRank();

    const Int colShiftA = A.ColShift();

    // Every exchanged portion is padded to the same size, computed from
    // the maximum local height (over colStride) and the maximum local
    // width (over colStrideUnion)
    const Int thisLocalHeight = this->LocalHeight();
    const Int localWidthA = A.LocalWidth();
    const Int maxLocalHeight = MaxLength(height,colStride);
    const Int maxLocalWidth = MaxLength(width,colStrideUnion);
    const Int portionSize = mpi::Pad( maxLocalHeight*maxLocalWidth );

    const T* thisBuf = this->LockedBuffer();
    const Int ldim = this->LDim();
    T* ABuf = A.Buffer();
    const Int ALDim = A.LDim();

    // Two halves of colStrideUnion portions each
    T* buffer = A.auxMemory_.Require( 2*colStrideUnion*portionSize );
    T* firstBuf = &buffer[0];
    T* secondBuf = &buffer[colStrideUnion*portionSize];

    if( colAlignA == colAlign % colStridePart )
    {
        // Pack
        // Portion k holds our local columns rowShift,
        // rowShift+colStrideUnion, ..., each of our full local height
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<colStrideUnion; ++k )
        {
            T* data = &firstBuf[k*portionSize];
            const Int rowShift = Shift_( k, rowAlignA, colStrideUnion );
            const Int localWidth = Length_( width, rowShift, colStrideUnion );
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<localWidth; ++jLoc )
                MemCopy
                ( &data[jLoc*thisLocalHeight],
                  &thisBuf[(rowShift+jLoc*colStrideUnion)*ldim],
                  thisLocalHeight );
        }

        // Simultaneously Gather in columns and Scatter in rows
        mpi::AllToAll
        ( firstBuf, portionSize,
          secondBuf, portionSize, this->PartialUnionColComm() );

        // Unpack
        // Portion k carries the rows belonging to column rank
        // colRankPart+k*colStridePart; they are interleaved into A's local
        // rows colOffset, colOffset+colStrideUnion, ...
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<colStrideUnion; ++k )
        {
            const T* data = &secondBuf[k*portionSize];
            const Int colRank = colRankPart + k*colStridePart;
            const Int colShift = Shift_( colRank, colAlign, colStride );
            const Int colOffset = (colShift-colShiftA) / colStridePart;
            const Int localHeight = Length_( height, colShift, colStride );
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<localWidthA; ++jLoc )
            {
                T* ACol = &ABuf[colOffset+jLoc*ALDim];
                const T* dataCol = &data[jLoc*localHeight];
                for( Int iLoc=0; iLoc<localHeight; ++iLoc )
                    ACol[iLoc*colStrideUnion] = dataCol[iLoc];
            }
        }
    }
    else
    {
#ifdef ELEM_UNALIGNED_WARNINGS
        if( this->Grid().Rank() == 0 )
            std::cerr << "Unaligned PartialColAllToAll" << std::endl;
#endif
        // Partner ranks within the partial column communicator
        const Int colAlignDiff = colAlignA - (colAlign%colStridePart);
        const Int sendColRankPart =
            (colRankPart+colStridePart+colAlignDiff) % colStridePart;
        const Int recvColRankPart =
            (colRankPart+colStridePart-colAlignDiff) % colStridePart;

        // Pack
        // Same layout as the aligned case, but into secondBuf so that the
        // SendRecv below delivers the AllToAll input into firstBuf
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<colStrideUnion; ++k )
        {
            T* data = &secondBuf[k*portionSize];
            const Int rowShift = Shift_( k, rowAlignA, colStrideUnion );
            const Int localWidth = Length_( width, rowShift, colStrideUnion );
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<localWidth; ++jLoc )
                MemCopy
                ( &data[jLoc*thisLocalHeight],
                  &thisBuf[(rowShift+jLoc*colStrideUnion)*ldim],
                  thisLocalHeight );
        }

        // Realign the input
        mpi::SendRecv
        ( secondBuf, colStrideUnion*portionSize, sendColRankPart,
          firstBuf, colStrideUnion*portionSize, recvColRankPart,
          this->PartialColComm() );

        // Simultaneously Gather in columns and Scatter in rows
        mpi::AllToAll
        ( firstBuf, portionSize,
          secondBuf, portionSize, this->PartialUnionColComm() );

        // Unpack
        // The data originated from partial column rank recvColRankPart,
        // so its rank (rather than ours) determines the shifts and heights
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<colStrideUnion; ++k )
        {
            const T* data = &secondBuf[k*portionSize];
            const Int colRank = recvColRankPart + k*colStridePart;
            const Int colShift = Shift_( colRank, colAlign, colStride );
            const Int colOffset = (colShift-colShiftA) / colStridePart;
            const Int localHeight = Length_( height, colShift, colStride );
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<localWidthA; ++jLoc )
            {
                T* ACol = &ABuf[colOffset+jLoc*ALDim];
                const T* dataCol = &data[jLoc*localHeight];
                for( Int iLoc=0; iLoc<localHeight; ++iLoc )
                    ACol[iLoc*colStrideUnion] = dataCol[iLoc];
            }
        }
    }
    A.auxMemory_.Release();
}
1803
// Write this [U,V] matrix into A, which is distributed as [UScat,VPart].
//
// This is the reverse direction of PartialRowAllToAllFrom: the local data
// of this matrix is packed, exchanged with an AllToAll over the
// partial-union row communicator, and unpacked into A. If A's row
// alignment differs from this matrix's (mod the partial row stride), the
// packed data is first passed through a SendRecv over the partial row
// communicator. The scratch space is taken from A's auxMemory_.
template<typename T,Dist U,Dist V>
void
GeneralDistMatrix<T,U,V>::PartialRowAllToAll
( DistMatrix<T,UScat,VPart>& A ) const
{
    DEBUG_ONLY(
        CallStackEntry cse("GDM::PartialRowAllToAll");
        this->AssertSameGrid( A.Grid() );
    )
    const Int height = this->Height();
    const Int width = this->Width();
    A.AlignRowsAndResize
    ( this->RowAlign()%A.RowStride(), height, width, false, false );
    if( !A.Participating() )
        return;

    const Int colAlignA = A.ColAlign();
    const Int rowAlign = this->RowAlign();
    const Int rowAlignA = A.RowAlign();

    const Int rowStride = this->RowStride();
    const Int rowStridePart = this->PartialRowStride();
    const Int rowStrideUnion = this->PartialUnionRowStride();
    const Int rowRankPart = this->PartialRowRank();

    const Int rowShiftA = A.RowShift();

    // Every exchanged portion is padded to the same size, computed from
    // the maximum local width (over rowStride) and the maximum local
    // height (over rowStrideUnion)
    const Int thisLocalWidth = this->LocalWidth();
    const Int localHeightA = A.LocalHeight();
    const Int maxLocalWidth = MaxLength(width,rowStride);
    const Int maxLocalHeight = MaxLength(height,rowStrideUnion);
    const Int portionSize = mpi::Pad( maxLocalHeight*maxLocalWidth );

    const T* thisBuf = this->LockedBuffer();
    const Int ldim = this->LDim();
    T* ABuf = A.Buffer();
    const Int ALDim = A.LDim();

    // Two halves of rowStrideUnion portions each
    T* buffer = A.auxMemory_.Require( 2*rowStrideUnion*portionSize );
    T* firstBuf = &buffer[0];
    T* secondBuf = &buffer[rowStrideUnion*portionSize];

    if( rowAlignA == rowAlign % rowStridePart )
    {
        // Pack
        // Portion k holds our local rows colShift,
        // colShift+rowStrideUnion, ... for each of our local columns
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<rowStrideUnion; ++k )
        {
            T* data = &firstBuf[k*portionSize];
            const Int colShift = Shift_( k, colAlignA, rowStrideUnion );
            const Int localHeight = Length_( height, colShift, rowStrideUnion );
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
            {
                T* dataCol = &data[jLoc*localHeight];
                const T* thisCol = &thisBuf[colShift+jLoc*ldim];
                for( Int iLoc=0; iLoc<localHeight; ++iLoc )
                    dataCol[iLoc] = thisCol[iLoc*rowStrideUnion];
            }
        }

        // Simultaneously Gather in rows and Scatter in columns
        mpi::AllToAll
        ( firstBuf, portionSize,
          secondBuf, portionSize, this->PartialUnionRowComm() );

        // Unpack
        // Portion k carries the columns belonging to row rank
        // rowRankPart+k*rowStridePart; they are placed in A's local
        // columns rowOffset, rowOffset+rowStrideUnion, ...
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<rowStrideUnion; ++k )
        {
            const T* data = &secondBuf[k*portionSize];
            const Int rowRank = rowRankPart + k*rowStridePart;
            const Int rowShift = Shift_( rowRank, rowAlign, rowStride );
            const Int rowOffset = (rowShift-rowShiftA) / rowStridePart;
            const Int localWidth = Length_( width, rowShift, rowStride );
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<localWidth; ++jLoc )
                MemCopy
                ( &ABuf[(rowOffset+jLoc*rowStrideUnion)*ALDim],
                  &data[jLoc*localHeightA], localHeightA );
        }
    }
    else
    {
#ifdef ELEM_UNALIGNED_WARNINGS
        if( this->Grid().Rank() == 0 )
            std::cerr << "Unaligned PartialRowAllToAll" << std::endl;
#endif
        // Partner ranks within the partial row communicator
        const Int rowAlignDiff = rowAlignA - (rowAlign%rowStridePart);
        const Int sendRowRankPart =
            (rowRankPart+rowStridePart+rowAlignDiff) % rowStridePart;
        const Int recvRowRankPart =
            (rowRankPart+rowStridePart-rowAlignDiff) % rowStridePart;

        // Pack
        // Same layout as the aligned case, but into secondBuf so that the
        // SendRecv below delivers the AllToAll input into firstBuf
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<rowStrideUnion; ++k )
        {
            T* data = &secondBuf[k*portionSize];
            const Int colShift = Shift_( k, colAlignA, rowStrideUnion );
            const Int localHeight = Length_( height, colShift, rowStrideUnion );
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
            {
                T* dataCol = &data[jLoc*localHeight];
                const T* sourceCol = &thisBuf[colShift+jLoc*ldim];
                for( Int iLoc=0; iLoc<localHeight; ++iLoc )
                    dataCol[iLoc] = sourceCol[iLoc*rowStrideUnion];
            }
        }

        // Realign the input
        mpi::SendRecv
        ( secondBuf, rowStrideUnion*portionSize, sendRowRankPart,
          firstBuf, rowStrideUnion*portionSize, recvRowRankPart,
          this->PartialRowComm() );

        // Simultaneously Gather in rows and Scatter in columns
        mpi::AllToAll
        ( firstBuf, portionSize,
          secondBuf, portionSize, this->PartialUnionRowComm() );

        // Unpack
        // The data originated from partial row rank recvRowRankPart, so
        // its rank (rather than ours) determines the shifts and widths
        ELEM_OUTER_PARALLEL_FOR
        for( Int k=0; k<rowStrideUnion; ++k )
        {
            const T* data = &secondBuf[k*portionSize];
            const Int rowRank = recvRowRankPart + k*rowStridePart;
            const Int rowShift = Shift_( rowRank, rowAlign, rowStride );
            const Int rowOffset = (rowShift-rowShiftA) / rowStridePart;
            const Int localWidth = Length_( width, rowShift, rowStride );
            ELEM_INNER_PARALLEL_FOR
            for( Int jLoc=0; jLoc<localWidth; ++jLoc )
                MemCopy
                ( &ABuf[(rowOffset+jLoc*rowStrideUnion)*ALDim],
                  &data[jLoc*localHeightA], localHeightA );
        }
    }
    A.auxMemory_.Release();
}
1944
1945 template<typename T,Dist U,Dist V>
1946 void
RowSumScatterFrom(const DistMatrix<T,U,VGath> & A)1947 GeneralDistMatrix<T,U,V>::RowSumScatterFrom( const DistMatrix<T,U,VGath>& A )
1948 {
1949 DEBUG_ONLY(
1950 CallStackEntry cse("GDM::RowSumScatterFrom");
1951 this->AssertSameGrid( A.Grid() );
1952 )
1953 this->AlignColsAndResize
1954 ( A.ColAlign(), A.Height(), A.Width(), false, false );
1955 // NOTE: This will be *slightly* slower than necessary due to the result
1956 // of the MPI operations being added rather than just copied
1957 Zeros( this->Matrix(), this->LocalHeight(), this->LocalWidth() );
1958 this->RowSumScatterUpdate( T(1), A );
1959 }
1960
1961 template<typename T,Dist U,Dist V>
1962 void
ColSumScatterFrom(const DistMatrix<T,UGath,V> & A)1963 GeneralDistMatrix<T,U,V>::ColSumScatterFrom( const DistMatrix<T,UGath,V>& A )
1964 {
1965 DEBUG_ONLY(
1966 CallStackEntry cse("GDM::ColSumScatterFrom");
1967 this->AssertSameGrid( A.Grid() );
1968 )
1969 this->AlignRowsAndResize
1970 ( A.RowAlign(), A.Height(), A.Width(), false, false );
1971 // NOTE: This will be *slightly* slower than necessary due to the result
1972 // of the MPI operations being added rather than just copied
1973 Zeros( this->Matrix(), this->LocalHeight(), this->LocalWidth() );
1974 this->ColSumScatterUpdate( T(1), A );
1975 }
1976
1977 template<typename T,Dist U,Dist V>
1978 void
SumScatterFrom(const DistMatrix<T,UGath,VGath> & A)1979 GeneralDistMatrix<T,U,V>::SumScatterFrom( const DistMatrix<T,UGath,VGath>& A )
1980 {
1981 DEBUG_ONLY(
1982 CallStackEntry cse("GDM::SumScatterFrom");
1983 this->AssertSameGrid( A.Grid() );
1984 )
1985 this->Resize( A.Height(), A.Width() );
1986 // NOTE: This will be *slightly* slower than necessary due to the result
1987 // of the MPI operations being added rather than just copied
1988 Zeros( this->Matrix(), this->LocalHeight(), this->LocalWidth() );
1989 this->SumScatterUpdate( T(1), A );
1990 }
1991
1992 template<typename T,Dist U,Dist V>
1993 void
PartialRowSumScatterFrom(const DistMatrix<T,U,VPart> & A)1994 GeneralDistMatrix<T,U,V>::PartialRowSumScatterFrom
1995 ( const DistMatrix<T,U,VPart>& A )
1996 {
1997 DEBUG_ONLY(
1998 CallStackEntry cse("GDM::PartialRowSumScatterFrom");
1999 this->AssertSameGrid( A.Grid() );
2000 )
2001 this->AlignAndResize
2002 ( A.ColAlign(), A.RowAlign(), A.Height(), A.Width(), false, false );
2003 // NOTE: This will be *slightly* slower than necessary due to the result
2004 // of the MPI operations being added rather than just copied
2005 Zeros( this->Matrix(), this->LocalHeight(), this->LocalWidth() );
2006 this->PartialRowSumScatterUpdate( T(1), A );
2007 }
2008
2009 template<typename T,Dist U,Dist V>
2010 void
PartialColSumScatterFrom(const DistMatrix<T,UPart,V> & A)2011 GeneralDistMatrix<T,U,V>::PartialColSumScatterFrom
2012 ( const DistMatrix<T,UPart,V>& A )
2013 {
2014 DEBUG_ONLY(
2015 CallStackEntry cse("GDM::PartialColSumScatterFrom");
2016 this->AssertSameGrid( A.Grid() );
2017 )
2018 this->AlignAndResize
2019 ( A.ColAlign(), A.RowAlign(), A.Height(), A.Width(), false, false );
2020 // NOTE: This will be *slightly* slower than necessary due to the result
2021 // of the MPI operations being added rather than just copied
2022 Zeros( this->Matrix(), this->LocalHeight(), this->LocalWidth() );
2023 this->PartialColSumScatterUpdate( T(1), A );
2024 }
2025
2026 template<typename T,Dist U,Dist V>
2027 void
RowSumScatterUpdate(T alpha,const DistMatrix<T,U,VGath> & A)2028 GeneralDistMatrix<T,U,V>::RowSumScatterUpdate
2029 ( T alpha, const DistMatrix<T,U,VGath>& A )
2030 {
2031 DEBUG_ONLY(
2032 CallStackEntry cse("GDM::RowSumScatterUpdate");
2033 this->AssertNotLocked();
2034 this->AssertSameGrid( A.Grid() );
2035 this->AssertSameSize( A.Height(), A.Width() );
2036 )
2037 if( !this->Participating() )
2038 return;
2039
2040 if( this->ColAlign() == A.ColAlign() )
2041 {
2042 if( this->Width() == 1 )
2043 {
2044 const Int rowAlign = this->RowAlign();
2045 const Int rowRank = this->RowRank();
2046
2047 const Int localHeight = this->LocalHeight();
2048 const Int portionSize = mpi::Pad( localHeight );
2049 T* buffer = this->auxMemory_.Require( 2*portionSize );
2050 T* sendBuf = &buffer[0];
2051 T* recvBuf = &buffer[portionSize];
2052
2053 // Pack
2054 const T* ACol = A.LockedBuffer();
2055 MemCopy( sendBuf, ACol, localHeight );
2056
2057 // Reduce to rowAlign
2058 mpi::Reduce
2059 ( sendBuf, recvBuf, portionSize, rowAlign, this->RowComm() );
2060
2061 if( rowRank == rowAlign )
2062 {
2063 T* thisCol = this->Buffer();
2064 ELEM_FMA_PARALLEL_FOR
2065 for( Int iLoc=0; iLoc<localHeight; ++iLoc )
2066 thisCol[iLoc] += alpha*recvBuf[iLoc];
2067 }
2068
2069 this->auxMemory_.Release();
2070 }
2071 else
2072 {
2073 const Int rowStride = this->RowStride();
2074 const Int rowAlign = this->RowAlign();
2075
2076 const Int width = this->Width();
2077 const Int localHeight = this->LocalHeight();
2078 const Int localWidth = this->LocalWidth();
2079 const Int maxLocalWidth = MaxLength(width,rowStride);
2080
2081 const Int portionSize = mpi::Pad( localHeight*maxLocalWidth );
2082 const Int sendSize = rowStride*portionSize;
2083
2084 // Pack
2085 const Int ALDim = A.LDim();
2086 const T* ABuffer = A.LockedBuffer();
2087 T* buffer = this->auxMemory_.Require( sendSize );
2088 ELEM_OUTER_PARALLEL_FOR
2089 for( Int k=0; k<rowStride; ++k )
2090 {
2091 T* data = &buffer[k*portionSize];
2092 const Int thisRowShift = Shift_( k, rowAlign, rowStride );
2093 const Int thisLocalWidth =
2094 Length_(width,thisRowShift,rowStride);
2095 ELEM_INNER_PARALLEL_FOR
2096 for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
2097 {
2098 const T* ACol =
2099 &ABuffer[(thisRowShift+jLoc*rowStride)*ALDim];
2100 T* dataCol = &data[jLoc*localHeight];
2101 MemCopy( dataCol, ACol, localHeight );
2102 }
2103 }
2104 // Communicate
2105 mpi::ReduceScatter( buffer, portionSize, this->RowComm() );
2106
2107 // Update with our received data
2108 T* thisBuffer = this->Buffer();
2109 const Int thisLDim = this->LDim();
2110 ELEM_PARALLEL_FOR
2111 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
2112 {
2113 const T* bufferCol = &buffer[jLoc*localHeight];
2114 T* thisCol = &thisBuffer[jLoc*thisLDim];
2115 blas::Axpy( localHeight, alpha, bufferCol, 1, thisCol, 1 );
2116 }
2117 this->auxMemory_.Release();
2118 }
2119 }
2120 else
2121 {
2122 #ifdef ELEM_UNALIGNED_WARNINGS
2123 if( this->Grid().Rank() == 0 )
2124 std::cerr << "Unaligned RowSumScatterUpdate" << std::endl;
2125 #endif
2126 if( this->Width() == 1 )
2127 {
2128 const Int colStride = this->ColStride();
2129 const Int rowAlign = this->RowAlign();
2130 const Int colRank = this->ColRank();
2131 const Int rowRank = this->RowRank();
2132
2133 const Int height = this->Height();
2134 const Int localHeight = this->LocalHeight();
2135 const Int localHeightA = A.LocalHeight();
2136 const Int maxLocalHeight = MaxLength(height,colStride);
2137 const Int portionSize = mpi::Pad( maxLocalHeight );
2138
2139 const Int colAlign = this->ColAlign();
2140 const Int colAlignA = A.ColAlign();
2141 const Int sendRow =
2142 (colRank+colStride+colAlign-colAlignA) % colStride;
2143 const Int recvRow =
2144 (colRank+colStride+colAlignA-colAlign) % colStride;
2145
2146 T* buffer = this->auxMemory_.Require( 2*portionSize );
2147 T* sendBuf = &buffer[0];
2148 T* recvBuf = &buffer[portionSize];
2149
2150 // Pack
2151 const T* ACol = A.LockedBuffer();
2152 MemCopy( sendBuf, ACol, localHeightA );
2153
2154 // Reduce to rowAlign
2155 mpi::Reduce
2156 ( sendBuf, recvBuf, portionSize, rowAlign, this->RowComm() );
2157
2158 if( rowRank == rowAlign )
2159 {
2160 // Perform the realignment
2161 mpi::SendRecv
2162 ( recvBuf, portionSize, sendRow,
2163 sendBuf, portionSize, recvRow, this->ColComm() );
2164
2165 T* thisCol = this->Buffer();
2166 ELEM_FMA_PARALLEL_FOR
2167 for( Int iLoc=0; iLoc<localHeight; ++iLoc )
2168 thisCol[iLoc] += alpha*sendBuf[iLoc];
2169 }
2170 this->auxMemory_.Release();
2171 }
2172 else
2173 {
2174 const Int colStride = this->ColStride();
2175 const Int rowStride = this->RowStride();
2176 const Int colRank = this->ColRank();
2177
2178 const Int colAlign = this->ColAlign();
2179 const Int rowAlign = this->RowAlign();
2180 const Int colAlignA = A.ColAlign();
2181 const Int sendRow =
2182 (colRank+colStride+colAlign-colAlignA) % colStride;
2183 const Int recvRow =
2184 (colRank+colStride+colAlignA-colAlign) % colStride;
2185
2186 const Int width = this->Width();
2187 const Int localHeight = this->LocalHeight();
2188 const Int localWidth = this->LocalWidth();
2189 const Int localHeightA = A.LocalHeight();
2190 const Int maxLocalWidth = MaxLength(width,rowStride);
2191
2192 const Int recvSize_RS = mpi::Pad( localHeightA*maxLocalWidth );
2193 const Int sendSize_RS = rowStride * recvSize_RS;
2194 const Int recvSize_SR = localHeight * localWidth;
2195
2196 T* buffer = this->auxMemory_.Require
2197 ( recvSize_RS + std::max(sendSize_RS,recvSize_SR) );
2198 T* firstBuf = &buffer[0];
2199 T* secondBuf = &buffer[recvSize_RS];
2200
2201 // Pack
2202 const T* ABuffer = A.LockedBuffer();
2203 const Int ALDim = A.LDim();
2204 ELEM_OUTER_PARALLEL_FOR
2205 for( Int k=0; k<rowStride; ++k )
2206 {
2207 T* data = &secondBuf[k*recvSize_RS];
2208 const Int thisRowShift = Shift_( k, rowAlign, rowStride );
2209 const Int thisLocalWidth =
2210 Length_(width,thisRowShift,rowStride);
2211 ELEM_INNER_PARALLEL_FOR
2212 for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
2213 {
2214 const T* ACol =
2215 &ABuffer[(thisRowShift+jLoc*rowStride)*ALDim];
2216 T* dataCol = &data[jLoc*localHeightA];
2217 MemCopy( dataCol, ACol, localHeightA );
2218 }
2219 }
2220
2221 // Reduce-scatter over each process row
2222 mpi::ReduceScatter
2223 ( secondBuf, firstBuf, recvSize_RS, this->RowComm() );
2224
2225 // Trade reduced data with the appropriate process row
2226 mpi::SendRecv
2227 ( firstBuf, localHeightA*localWidth, sendRow,
2228 secondBuf, localHeight*localWidth, recvRow, this->ColComm() );
2229
2230 // Update with our received data
2231 T* thisBuffer = this->Buffer();
2232 const Int thisLDim = this->LDim();
2233 ELEM_FMA_PARALLEL_FOR
2234 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
2235 {
2236 const T* secondBufCol = &secondBuf[jLoc*localHeight];
2237 T* thisCol = &thisBuffer[jLoc*thisLDim];
2238 for( Int iLoc=0; iLoc<localHeight; ++iLoc )
2239 thisCol[iLoc] += alpha*secondBufCol[iLoc];
2240 }
2241 this->auxMemory_.Release();
2242 }
2243 }
2244 }
2245
2246 template<typename T,Dist U,Dist V>
2247 void
ColSumScatterUpdate(T alpha,const DistMatrix<T,UGath,V> & A)2248 GeneralDistMatrix<T,U,V>::ColSumScatterUpdate
2249 ( T alpha, const DistMatrix<T,UGath,V>& A )
2250 {
2251 DEBUG_ONLY(
2252 CallStackEntry cse("GDM::ColSumScatterUpdate");
2253 this->AssertNotLocked();
2254 this->AssertSameGrid( A.Grid() );
2255 this->AssertSameSize( A.Height(), A.Width() );
2256 )
2257 #ifdef ELEM_VECTOR_WARNINGS
2258 if( A.Width() == 1 && this->Grid().Rank() == 0 )
2259 {
2260 std::cerr <<
2261 "The vector version of ColSumScatterUpdate does not"
2262 " yet have a vector version implemented, but it would only "
2263 "require a modification of the vector version of RowSumScatterUpdate"
2264 << std::endl;
2265 }
2266 #endif
2267 #ifdef ELEM_CACHE_WARNINGS
2268 if( A.Width() != 1 && this->Grid().Rank() == 0 )
2269 {
2270 std::cerr <<
2271 "ColSumScatterUpdate potentially causes a large "
2272 "amount of cache-thrashing. If possible, avoid it by forming the "
2273 "(conjugate-)transpose of the [* ,V] matrix instead." << std::endl;
2274 }
2275 #endif
2276 if( !this->Participating() )
2277 return;
2278
2279 if( this->RowAlign() == A.RowAlign() )
2280 {
2281 const Int colStride = this->ColStride();
2282 const Int colAlign = this->ColAlign();
2283 const Int height = this->Height();
2284 const Int localHeight = this->LocalHeight();
2285 const Int localWidth = this->LocalWidth();
2286 const Int maxLocalHeight = MaxLength(height,colStride);
2287
2288 const Int recvSize = mpi::Pad( maxLocalHeight*localWidth );
2289 const Int sendSize = colStride*recvSize;
2290
2291 // Pack
2292 const T* ABuffer = A.LockedBuffer();
2293 const Int ALDim = A.LDim();
2294 T* buffer = this->auxMemory_.Require( sendSize );
2295 ELEM_OUTER_PARALLEL_FOR
2296 for( Int k=0; k<colStride; ++k )
2297 {
2298 T* data = &buffer[k*recvSize];
2299 const Int thisColShift = Shift_( k, colAlign, colStride );
2300 const Int thisLocalHeight = Length_(height,thisColShift,colStride);
2301 ELEM_INNER_PARALLEL_FOR
2302 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
2303 {
2304 T* destCol = &data[jLoc*thisLocalHeight];
2305 const T* sourceCol = &ABuffer[thisColShift+jLoc*ALDim];
2306 for( Int iLoc=0; iLoc<thisLocalHeight; ++iLoc )
2307 destCol[iLoc] = sourceCol[iLoc*colStride];
2308 }
2309 }
2310
2311 // Communicate
2312 mpi::ReduceScatter( buffer, recvSize, this->ColComm() );
2313
2314 // Update with our received data
2315 T* thisBuffer = this->Buffer();
2316 const Int thisLDim = this->LDim();
2317 ELEM_FMA_PARALLEL_FOR
2318 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
2319 {
2320 const T* bufferCol = &buffer[jLoc*localHeight];
2321 T* thisCol = &thisBuffer[jLoc*thisLDim];
2322 for( Int iLoc=0; iLoc<localHeight; ++iLoc )
2323 thisCol[iLoc] += alpha*bufferCol[iLoc];
2324 }
2325 this->auxMemory_.Release();
2326 }
2327 else
2328 {
2329 #ifdef ELEM_UNALIGNED_WARNINGS
2330 if( this->Grid().Rank() == 0 )
2331 std::cerr << "Unaligned ColSumScatterUpdate" << std::endl;
2332 #endif
2333 const Int colStride = this->ColStride();
2334 const Int rowStride = this->RowStride();
2335 const Int rowRank = this->RowRank();
2336
2337 const Int colAlign = this->ColAlign();
2338 const Int rowAlign = this->RowAlign();
2339 const Int rowAlignA = A.RowAlign();
2340 const Int sendCol = (rowRank+rowStride+rowAlign-rowAlignA) % rowStride;
2341 const Int recvCol = (rowRank+rowStride+rowAlignA-rowAlign) % rowStride;
2342
2343 const Int height = this->Height();
2344 const Int localHeight = this->LocalHeight();
2345 const Int localWidth = this->LocalWidth();
2346 const Int localWidthA = A.LocalWidth();
2347 const Int maxLocalHeight = MaxLength(height,colStride);
2348
2349 const Int recvSize_RS = mpi::Pad( maxLocalHeight*localWidthA );
2350 const Int sendSize_RS = colStride * recvSize_RS;
2351 const Int recvSize_SR = localHeight * localWidth;
2352
2353 T* buffer = this->auxMemory_.Require
2354 ( recvSize_RS + std::max(sendSize_RS,recvSize_SR) );
2355 T* firstBuf = &buffer[0];
2356 T* secondBuf = &buffer[recvSize_RS];
2357
2358 // Pack
2359 const T* ABuffer = A.LockedBuffer();
2360 const Int ALDim = A.LDim();
2361 ELEM_OUTER_PARALLEL_FOR
2362 for( Int k=0; k<colStride; ++k )
2363 {
2364 T* data = &secondBuf[k*recvSize_RS];
2365 const Int thisColShift = Shift_( k, colAlign, colStride );
2366 const Int thisLocalHeight = Length_(height,thisColShift,colStride);
2367 ELEM_INNER_PARALLEL_FOR
2368 for( Int jLoc=0; jLoc<localWidthA; ++jLoc )
2369 {
2370 T* destCol = &data[jLoc*thisLocalHeight];
2371 const T* sourceCol = &ABuffer[thisColShift+jLoc*ALDim];
2372 for( Int iLoc=0; iLoc<thisLocalHeight; ++iLoc )
2373 destCol[iLoc] = sourceCol[iLoc*colStride];
2374 }
2375 }
2376
2377 // Reduce-scatter over each col
2378 mpi::ReduceScatter( secondBuf, firstBuf, recvSize_RS, this->ColComm() );
2379
2380 // Trade reduced data with the appropriate col
2381 mpi::SendRecv
2382 ( firstBuf, localHeight*localWidthA, sendCol,
2383 secondBuf, localHeight*localWidth, recvCol, this->RowComm() );
2384
2385 // Update with our received data
2386 T* thisBuffer = this->Buffer();
2387 const Int thisLDim = this->LDim();
2388 ELEM_FMA_PARALLEL_FOR
2389 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
2390 {
2391 const T* secondBufCol = &secondBuf[jLoc*localHeight];
2392 T* thisCol = &thisBuffer[jLoc*thisLDim];
2393 for( Int iLoc=0; iLoc<localHeight; ++iLoc )
2394 thisCol[iLoc] += alpha*secondBufCol[iLoc];
2395 }
2396 this->auxMemory_.Release();
2397 }
2398 }
2399
2400 template<typename T,Dist U,Dist V>
2401 void
SumScatterUpdate(T alpha,const DistMatrix<T,UGath,VGath> & A)2402 GeneralDistMatrix<T,U,V>::SumScatterUpdate
2403 ( T alpha, const DistMatrix<T,UGath,VGath>& A )
2404 {
2405 DEBUG_ONLY(
2406 CallStackEntry cse("GDM::SumScatterUpdate");
2407 this->AssertNotLocked();
2408 this->AssertSameGrid( A.Grid() );
2409 this->AssertSameSize( A.Height(), A.Width() );
2410 )
2411 if( !this->Participating() )
2412 return;
2413
2414 const Int colStride = this->ColStride();
2415 const Int rowStride = this->RowStride();
2416 const Int colAlign = this->ColAlign();
2417 const Int rowAlign = this->RowAlign();
2418
2419 const Int height = this->Height();
2420 const Int width = this->Width();
2421 const Int localHeight = this->LocalHeight();
2422 const Int localWidth = this->LocalWidth();
2423 const Int maxLocalHeight = MaxLength(height,colStride);
2424 const Int maxLocalWidth = MaxLength(width,rowStride);
2425
2426 const Int recvSize = mpi::Pad( maxLocalHeight*maxLocalWidth );
2427 const Int sendSize = colStride*rowStride*recvSize;
2428
2429 // Pack
2430 const T* ABuffer = A.LockedBuffer();
2431 const Int ALDim = A.LDim();
2432 T* buffer = this->auxMemory_.Require( sendSize );
2433 ELEM_OUTER_PARALLEL_FOR
2434 for( Int l=0; l<rowStride; ++l )
2435 {
2436 const Int thisRowShift = Shift_( l, rowAlign, rowStride );
2437 const Int thisLocalWidth = Length_( width, thisRowShift, rowStride );
2438 for( Int k=0; k<colStride; ++k )
2439 {
2440 T* data = &buffer[(k+l*colStride)*recvSize];
2441 const Int thisColShift = Shift_( k, colAlign, colStride );
2442 const Int thisLocalHeight = Length_(height,thisColShift,colStride);
2443 ELEM_INNER_PARALLEL_FOR
2444 for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
2445 {
2446 T* destCol = &data[jLoc*thisLocalHeight];
2447 const T* sourceCol =
2448 &ABuffer[thisColShift+(thisRowShift+jLoc*rowStride)*ALDim];
2449 for( Int iLoc=0; iLoc<thisLocalHeight; ++iLoc )
2450 destCol[iLoc] = sourceCol[iLoc*colStride];
2451 }
2452 }
2453 }
2454
2455 // Communicate
2456 mpi::ReduceScatter( buffer, recvSize, this->DistComm() );
2457
2458 // Unpack our received data
2459 T* thisBuffer = this->Buffer();
2460 const Int thisLDim = this->LDim();
2461 ELEM_FMA_PARALLEL_FOR
2462 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
2463 {
2464 const T* bufferCol = &buffer[jLoc*localHeight];
2465 T* thisCol = &thisBuffer[jLoc*thisLDim];
2466 for( Int iLoc=0; iLoc<localHeight; ++iLoc )
2467 thisCol[iLoc] += alpha*bufferCol[iLoc];
2468 }
2469 this->auxMemory_.Release();
2470 }
2471
2472 template<typename T,Dist U,Dist V>
2473 void
PartialRowSumScatterUpdate(T alpha,const DistMatrix<T,U,VPart> & A)2474 GeneralDistMatrix<T,U,V>::PartialRowSumScatterUpdate
2475 ( T alpha, const DistMatrix<T,U,VPart>& A )
2476 {
2477 DEBUG_ONLY(
2478 CallStackEntry cse("GDM::PartialRowSumScatterUpdate");
2479 this->AssertNotLocked();
2480 this->AssertSameGrid( A.Grid() );
2481 this->AssertSameSize( A.Height(), A.Width() );
2482 )
2483 if( !this->Participating() )
2484 return;
2485
2486 if( this->RowAlign() % A.RowStride() == A.RowAlign() )
2487 {
2488 const Int rowStride = this->RowStride();
2489 const Int rowStridePart = this->PartialRowStride();
2490 const Int rowStrideUnion = this->PartialUnionRowStride();
2491 const Int rowRankPart = this->PartialRowRank();
2492 const Int rowAlign = this->RowAlign();
2493 const Int rowShiftOfA = A.RowShift();
2494
2495 const Int height = this->Height();
2496 const Int width = this->Width();
2497 const Int localWidth = this->LocalWidth();
2498 const Int maxLocalWidth = MaxLength( width, rowStride );
2499 const Int recvSize = mpi::Pad( height*maxLocalWidth );
2500 const Int sendSize = rowStrideUnion*recvSize;
2501
2502 // Pack
2503 const T* ABuf = A.LockedBuffer();
2504 const Int ALDim = A.LDim();
2505 T* buffer = this->auxMemory_.Require( sendSize );
2506 ELEM_OUTER_PARALLEL_FOR
2507 for( Int k=0; k<rowStrideUnion; ++k )
2508 {
2509 T* data = &buffer[k*recvSize];
2510 const Int thisRank = rowRankPart+k*rowStridePart;
2511 const Int thisRowShift = Shift_( thisRank, rowAlign, rowStride );
2512 const Int thisRowOffset =
2513 (thisRowShift-rowShiftOfA) / rowStridePart;
2514 const Int thisLocalWidth =
2515 Length_( width, thisRowShift, rowStride );
2516 ELEM_INNER_PARALLEL_FOR
2517 for( Int jLoc=0; jLoc<thisLocalWidth; ++jLoc )
2518 {
2519 const T* ACol =
2520 &ABuf[(thisRowOffset+jLoc*rowStrideUnion)*ALDim];
2521 T* dataCol = &data[jLoc*height];
2522 MemCopy( dataCol, ACol, height );
2523 }
2524 }
2525
2526 // Communicate
2527 mpi::ReduceScatter( buffer, recvSize, this->PartialUnionRowComm() );
2528
2529 // Unpack our received data
2530 T* thisBuf = this->Buffer();
2531 const Int thisLDim = this->LDim();
2532 ELEM_PARALLEL_FOR
2533 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
2534 {
2535 const T* bufferCol = &buffer[jLoc*height];
2536 T* thisCol = &thisBuf[jLoc*thisLDim];
2537 for( Int i=0; i<height; ++i )
2538 thisCol[i] += alpha*bufferCol[i];
2539 }
2540 this->auxMemory_.Release();
2541 }
2542 else
2543 {
2544 LogicError("Unaligned PartialRowSumScatterUpdate not implemented");
2545 }
2546 }
2547
2548 template<typename T,Dist U,Dist V>
2549 void
PartialColSumScatterUpdate(T alpha,const DistMatrix<T,UPart,V> & A)2550 GeneralDistMatrix<T,U,V>::PartialColSumScatterUpdate
2551 ( T alpha, const DistMatrix<T,UPart,V>& A )
2552 {
2553 DEBUG_ONLY(
2554 CallStackEntry cse("GDM::PartialColSumScatterUpdate");
2555 this->AssertNotLocked();
2556 this->AssertSameGrid( A.Grid() );
2557 this->AssertSameSize( A.Height(), A.Width() );
2558 )
2559 if( !this->Participating() )
2560 return;
2561
2562 #ifdef ELEM_CACHE_WARNINGS
2563 if( A.Width() != 1 && A.Grid().Rank() == 0 )
2564 {
2565 std::cerr <<
2566 "PartialColSumScatterUpdate potentially causes a large amount"
2567 " of cache-thrashing. If possible, avoid it by forming the "
2568 "(conjugate-)transpose of the [UGath,* ] matrix instead."
2569 << std::endl;
2570 }
2571 #endif
2572 if( this->ColAlign() % A.ColStride() == A.ColAlign() )
2573 {
2574 const Int colStride = this->ColStride();
2575 const Int colStridePart = this->PartialColStride();
2576 const Int colStrideUnion = this->PartialUnionColStride();
2577 const Int colRankPart = this->PartialColRank();
2578 const Int colAlign = this->ColAlign();
2579 const Int colShiftOfA = A.ColShift();
2580
2581 const Int height = this->Height();
2582 const Int width = this->Width();
2583 const Int localHeight = this->LocalHeight();
2584 const Int maxLocalHeight = MaxLength( height, colStride );
2585 const Int recvSize = mpi::Pad( maxLocalHeight*width );
2586 const Int sendSize = colStrideUnion*recvSize;
2587
2588 T* buffer = this->auxMemory_.Require( sendSize );
2589
2590 // Pack
2591 const Int ALDim = A.LDim();
2592 const T* ABuf = A.LockedBuffer();
2593 ELEM_OUTER_PARALLEL_FOR
2594 for( Int k=0; k<colStrideUnion; ++k )
2595 {
2596 T* data = &buffer[k*recvSize];
2597 const Int thisRank = colRankPart+k*colStridePart;
2598 const Int thisColShift = Shift_( thisRank, colAlign, colStride );
2599 const Int thisColOffset =
2600 (thisColShift-colShiftOfA) / colStridePart;
2601 const Int thisLocalHeight =
2602 Length_( height, thisColShift, colStride );
2603 ELEM_INNER_PARALLEL_FOR
2604 for( Int j=0; j<width; ++j )
2605 {
2606 T* destCol = &data[j*thisLocalHeight];
2607 const T* sourceCol = &ABuf[thisColOffset+j*ALDim];
2608 for( Int iLoc=0; iLoc<thisLocalHeight; ++iLoc )
2609 destCol[iLoc] = sourceCol[iLoc*colStrideUnion];
2610 }
2611 }
2612
2613 // Communicate
2614 mpi::ReduceScatter( buffer, recvSize, this->PartialUnionColComm() );
2615
2616 // Unpack our received data
2617 T* thisBuf = this->Buffer();
2618 const Int thisLDim = this->LDim();
2619 ELEM_PARALLEL_FOR
2620 for( Int j=0; j<width; ++j )
2621 {
2622 const T* bufferCol = &buffer[j*localHeight];
2623 T* thisCol = &thisBuf[j*thisLDim];
2624 for( Int iLoc=0; iLoc<localHeight; ++iLoc )
2625 thisCol[iLoc] += alpha*bufferCol[iLoc];
2626 }
2627 this->auxMemory_.Release();
2628 }
2629 else
2630 {
2631 LogicError("Unaligned PartialColSumScatterUpdate not implemented");
2632 }
2633 }
2634
2635 template<typename T,Dist U,Dist V>
2636 void
TransposeColAllGather(DistMatrix<T,V,UGath> & A,bool conjugate) const2637 GeneralDistMatrix<T,U,V>::TransposeColAllGather
2638 ( DistMatrix<T,V,UGath>& A, bool conjugate ) const
2639 {
2640 DEBUG_ONLY(CallStackEntry cse("GDM::TransposeColAllGather"))
2641 DistMatrix<T,V,U> ATrans( this->Grid() );
2642 ATrans.AlignWith( *this );
2643 ATrans.Resize( this->Width(), this->Height() );
2644 Transpose( this->LockedMatrix(), ATrans.Matrix(), conjugate );
2645 ATrans.RowAllGather( A );
2646 }
2647
2648 template<typename T,Dist U,Dist V>
2649 void
TransposePartialColAllGather(DistMatrix<T,V,UPart> & A,bool conjugate) const2650 GeneralDistMatrix<T,U,V>::TransposePartialColAllGather
2651 ( DistMatrix<T,V,UPart>& A, bool conjugate ) const
2652 {
2653 DEBUG_ONLY(CallStackEntry cse("GDM::TransposePartialColAllGather"))
2654 DistMatrix<T,V,U> ATrans( this->Grid() );
2655 ATrans.AlignWith( *this );
2656 ATrans.Resize( this->Width(), this->Height() );
2657 Transpose( this->LockedMatrix(), ATrans.Matrix(), conjugate );
2658 ATrans.PartialRowAllGather( A );
2659 }
2660
2661 template<typename T,Dist U,Dist V>
2662 void
AdjointColAllGather(DistMatrix<T,V,UGath> & A) const2663 GeneralDistMatrix<T,U,V>::AdjointColAllGather( DistMatrix<T,V,UGath>& A ) const
2664 {
2665 DEBUG_ONLY(CallStackEntry cse("GDM::AdjointRowAllGather"))
2666 this->TransposeColAllGather( A, true );
2667 }
2668
2669 template<typename T,Dist U,Dist V>
2670 void
AdjointPartialColAllGather(DistMatrix<T,V,UPart> & A) const2671 GeneralDistMatrix<T,U,V>::AdjointPartialColAllGather
2672 ( DistMatrix<T,V,UPart>& A ) const
2673 {
2674 DEBUG_ONLY(CallStackEntry cse("GDM::AdjointPartialColAllGather"))
2675 this->TransposePartialColAllGather( A, true );
2676 }
2677
2678 template<typename T,Dist U,Dist V>
2679 void
TransposeColFilterFrom(const DistMatrix<T,V,UGath> & A,bool conjugate)2680 GeneralDistMatrix<T,U,V>::TransposeColFilterFrom
2681 ( const DistMatrix<T,V,UGath>& A, bool conjugate )
2682 {
2683 DEBUG_ONLY(CallStackEntry cse("GDM::TransposeColFilterFrom"))
2684 DistMatrix<T,V,U> AFilt( A.Grid() );
2685 if( this->ColConstrained() )
2686 AFilt.AlignRowsWith( *this, false );
2687 if( this->RowConstrained() )
2688 AFilt.AlignColsWith( *this, false );
2689 AFilt.RowFilterFrom( A );
2690 if( !this->ColConstrained() )
2691 this->AlignColsWith( AFilt, false );
2692 if( !this->RowConstrained() )
2693 this->AlignRowsWith( AFilt, false );
2694 this->Resize( A.Width(), A.Height() );
2695 Transpose( AFilt.LockedMatrix(), this->Matrix(), conjugate );
2696 }
2697
2698 template<typename T,Dist U,Dist V>
2699 void
TransposeRowFilterFrom(const DistMatrix<T,VGath,U> & A,bool conjugate)2700 GeneralDistMatrix<T,U,V>::TransposeRowFilterFrom
2701 ( const DistMatrix<T,VGath,U>& A, bool conjugate )
2702 {
2703 DEBUG_ONLY(CallStackEntry cse("GDM::TransposeRowFilterFrom"))
2704 DistMatrix<T,V,U> AFilt( A.Grid() );
2705 if( this->ColConstrained() )
2706 AFilt.AlignRowsWith( *this, false );
2707 if( this->RowConstrained() )
2708 AFilt.AlignColsWith( *this, false );
2709 AFilt.ColFilterFrom( A );
2710 if( !this->ColConstrained() )
2711 this->AlignColsWith( AFilt, false );
2712 if( !this->RowConstrained() )
2713 this->AlignRowsWith( AFilt, false );
2714 this->Resize( A.Width(), A.Height() );
2715 Transpose( AFilt.LockedMatrix(), this->Matrix(), conjugate );
2716 }
2717
2718 template<typename T,Dist U,Dist V>
2719 void
TransposePartialColFilterFrom(const DistMatrix<T,V,UPart> & A,bool conjugate)2720 GeneralDistMatrix<T,U,V>::TransposePartialColFilterFrom
2721 ( const DistMatrix<T,V,UPart>& A, bool conjugate )
2722 {
2723 DEBUG_ONLY(CallStackEntry cse("GDM::TransposePartialColFilterFrom"))
2724 DistMatrix<T,V,U> AFilt( A.Grid() );
2725 if( this->ColConstrained() )
2726 AFilt.AlignRowsWith( *this, false );
2727 if( this->RowConstrained() )
2728 AFilt.AlignColsWith( *this, false );
2729 AFilt.PartialRowFilterFrom( A );
2730 if( !this->ColConstrained() )
2731 this->AlignColsWith( AFilt, false );
2732 if( !this->RowConstrained() )
2733 this->AlignRowsWith( AFilt, false );
2734 this->Resize( A.Width(), A.Height() );
2735 Transpose( AFilt.LockedMatrix(), this->Matrix(), conjugate );
2736 }
2737
2738 template<typename T,Dist U,Dist V>
2739 void
TransposePartialRowFilterFrom(const DistMatrix<T,VPart,U> & A,bool conjugate)2740 GeneralDistMatrix<T,U,V>::TransposePartialRowFilterFrom
2741 ( const DistMatrix<T,VPart,U>& A, bool conjugate )
2742 {
2743 DEBUG_ONLY(CallStackEntry cse("GDM::TransposePartialRowFilterFrom"))
2744 DistMatrix<T,V,U> AFilt( A.Grid() );
2745 if( this->ColConstrained() )
2746 AFilt.AlignRowsWith( *this, false );
2747 if( this->RowConstrained() )
2748 AFilt.AlignColsWith( *this, false );
2749 AFilt.PartialColFilterFrom( A );
2750 if( !this->ColConstrained() )
2751 this->AlignColsWith( AFilt, false );
2752 if( !this->RowConstrained() )
2753 this->AlignRowsWith( AFilt, false );
2754 this->Resize( A.Width(), A.Height() );
2755 Transpose( AFilt.LockedMatrix(), this->Matrix(), conjugate );
2756 }
2757
2758 template<typename T,Dist U,Dist V>
2759 void
AdjointColFilterFrom(const DistMatrix<T,V,UGath> & A)2760 GeneralDistMatrix<T,U,V>::AdjointColFilterFrom( const DistMatrix<T,V,UGath>& A )
2761 {
2762 DEBUG_ONLY(CallStackEntry cse("GDM::AdjointColFilterFrom"))
2763 this->TransposeColFilterFrom( A, true );
2764 }
2765
2766 template<typename T,Dist U,Dist V>
2767 void
AdjointRowFilterFrom(const DistMatrix<T,VGath,U> & A)2768 GeneralDistMatrix<T,U,V>::AdjointRowFilterFrom( const DistMatrix<T,VGath,U>& A )
2769 {
2770 DEBUG_ONLY(CallStackEntry cse("GDM::AdjointRowFilterFrom"))
2771 this->TransposeRowFilterFrom( A, true );
2772 }
2773
2774 template<typename T,Dist U,Dist V>
2775 void
AdjointPartialColFilterFrom(const DistMatrix<T,V,UPart> & A)2776 GeneralDistMatrix<T,U,V>::AdjointPartialColFilterFrom
2777 ( const DistMatrix<T,V,UPart>& A )
2778 {
2779 DEBUG_ONLY(CallStackEntry cse("GDM::AdjointPartialColFilterFrom"))
2780 this->TransposePartialColFilterFrom( A, true );
2781 }
2782
2783 template<typename T,Dist U,Dist V>
2784 void
AdjointPartialRowFilterFrom(const DistMatrix<T,VPart,U> & A)2785 GeneralDistMatrix<T,U,V>::AdjointPartialRowFilterFrom
2786 ( const DistMatrix<T,VPart,U>& A )
2787 {
2788 DEBUG_ONLY(CallStackEntry cse("GDM::AdjointPartialRowFilterFrom"))
2789 this->TransposePartialRowFilterFrom( A, true );
2790 }
2791
2792 template<typename T,Dist U,Dist V>
2793 void
TransposeColSumScatterFrom(const DistMatrix<T,V,UGath> & A,bool conjugate)2794 GeneralDistMatrix<T,U,V>::TransposeColSumScatterFrom
2795 ( const DistMatrix<T,V,UGath>& A, bool conjugate )
2796 {
2797 DEBUG_ONLY(CallStackEntry cse("GDM::TransposeColSumScatterFrom"))
2798 DistMatrix<T,V,U> ASumFilt( A.Grid() );
2799 if( this->ColConstrained() )
2800 ASumFilt.AlignRowsWith( *this, false );
2801 if( this->RowConstrained() )
2802 ASumFilt.AlignColsWith( *this, false );
2803 ASumFilt.RowSumScatterFrom( A );
2804 if( !this->ColConstrained() )
2805 this->AlignColsWith( ASumFilt, false );
2806 if( !this->RowConstrained() )
2807 this->AlignRowsWith( ASumFilt, false );
2808 this->Resize( A.Width(), A.Height() );
2809 Transpose( ASumFilt.LockedMatrix(), this->Matrix(), conjugate );
2810 }
2811
2812 template<typename T,Dist U,Dist V>
2813 void
TransposePartialColSumScatterFrom(const DistMatrix<T,V,UPart> & A,bool conjugate)2814 GeneralDistMatrix<T,U,V>::TransposePartialColSumScatterFrom
2815 ( const DistMatrix<T,V,UPart>& A, bool conjugate )
2816 {
2817 DEBUG_ONLY(CallStackEntry cse("GDM::TransposePartialColSumScatterFrom"))
2818 DistMatrix<T,V,U> ASumFilt( A.Grid() );
2819 if( this->ColConstrained() )
2820 ASumFilt.AlignRowsWith( *this, false );
2821 if( this->RowConstrained() )
2822 ASumFilt.AlignColsWith( *this, false );
2823 ASumFilt.PartialRowSumScatterFrom( A );
2824 if( !this->ColConstrained() )
2825 this->AlignColsWith( ASumFilt, false );
2826 if( !this->RowConstrained() )
2827 this->AlignRowsWith( ASumFilt, false );
2828 this->Resize( A.Width(), A.Height() );
2829 Transpose( ASumFilt.LockedMatrix(), this->Matrix(), conjugate );
2830 }
2831
2832 template<typename T,Dist U,Dist V>
2833 void
AdjointColSumScatterFrom(const DistMatrix<T,V,UGath> & A)2834 GeneralDistMatrix<T,U,V>::AdjointColSumScatterFrom
2835 ( const DistMatrix<T,V,UGath>& A )
2836 {
2837 DEBUG_ONLY(CallStackEntry cse("GDM::AdjointColSumScatterFrom"))
2838 this->TransposeColSumScatterFrom( A, true );
2839 }
2840
2841 template<typename T,Dist U,Dist V>
2842 void
AdjointPartialColSumScatterFrom(const DistMatrix<T,V,UPart> & A)2843 GeneralDistMatrix<T,U,V>::AdjointPartialColSumScatterFrom
2844 ( const DistMatrix<T,V,UPart>& A )
2845 {
2846 DEBUG_ONLY(CallStackEntry cse("GDM::AdjointPartialColSumScatterFrom"))
2847 this->TransposePartialColSumScatterFrom( A, true );
2848 }
2849
2850 template<typename T,Dist U,Dist V>
2851 void
TransposeColSumScatterUpdate(T alpha,const DistMatrix<T,V,UGath> & A,bool conjugate)2852 GeneralDistMatrix<T,U,V>::TransposeColSumScatterUpdate
2853 ( T alpha, const DistMatrix<T,V,UGath>& A, bool conjugate )
2854 {
2855 DEBUG_ONLY(CallStackEntry cse("GDM::TransposeColSumScatterUpdate"))
2856 DistMatrix<T,V,U> ASumFilt( A.Grid() );
2857 if( this->ColConstrained() )
2858 ASumFilt.AlignRowsWith( *this, false );
2859 if( this->RowConstrained() )
2860 ASumFilt.AlignColsWith( *this, false );
2861 ASumFilt.RowSumScatterFrom( A );
2862 if( !this->ColConstrained() )
2863 this->AlignColsWith( ASumFilt, false );
2864 if( !this->RowConstrained() )
2865 this->AlignRowsWith( ASumFilt, false );
2866 // ALoc += alpha ASumFiltLoc'
2867 elem::Matrix<T>& ALoc = this->Matrix();
2868 const elem::Matrix<T>& BLoc = ASumFilt.LockedMatrix();
2869 const Int localHeight = ALoc.Height();
2870 const Int localWidth = ALoc.Width();
2871 if( conjugate )
2872 {
2873 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
2874 for( Int iLoc=0; iLoc<localHeight; ++iLoc )
2875 ALoc.Update( iLoc, jLoc, alpha*Conj(BLoc.Get(jLoc,iLoc)) );
2876 }
2877 else
2878 {
2879 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
2880 for( Int iLoc=0; iLoc<localHeight; ++iLoc )
2881 ALoc.Update( iLoc, jLoc, alpha*BLoc.Get(jLoc,iLoc) );
2882 }
2883 }
2884
2885 template<typename T,Dist U,Dist V>
2886 void
TransposePartialColSumScatterUpdate(T alpha,const DistMatrix<T,V,UPart> & A,bool conjugate)2887 GeneralDistMatrix<T,U,V>::TransposePartialColSumScatterUpdate
2888 ( T alpha, const DistMatrix<T,V,UPart>& A, bool conjugate )
2889 {
2890 DEBUG_ONLY(CallStackEntry cse("GDM::TransposePartialColSumScatterUpdate"))
2891 DistMatrix<T,V,U> ASumFilt( A.Grid() );
2892 if( this->ColConstrained() )
2893 ASumFilt.AlignRowsWith( *this, false );
2894 if( this->RowConstrained() )
2895 ASumFilt.AlignColsWith( *this, false );
2896 ASumFilt.PartialRowSumScatterFrom( A );
2897 if( !this->ColConstrained() )
2898 this->AlignColsWith( ASumFilt, false );
2899 if( !this->RowConstrained() )
2900 this->AlignRowsWith( ASumFilt, false );
2901 // ALoc += alpha ASumFiltLoc'
2902 elem::Matrix<T>& ALoc = this->Matrix();
2903 const elem::Matrix<T>& BLoc = ASumFilt.LockedMatrix();
2904 const Int localHeight = ALoc.Height();
2905 const Int localWidth = ALoc.Width();
2906 if( conjugate )
2907 {
2908 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
2909 for( Int iLoc=0; iLoc<localHeight; ++iLoc )
2910 ALoc.Update( iLoc, jLoc, alpha*Conj(BLoc.Get(jLoc,iLoc)) );
2911 }
2912 else
2913 {
2914 for( Int jLoc=0; jLoc<localWidth; ++jLoc )
2915 for( Int iLoc=0; iLoc<localHeight; ++iLoc )
2916 ALoc.Update( iLoc, jLoc, alpha*BLoc.Get(jLoc,iLoc) );
2917 }
2918 }
2919
2920 template<typename T,Dist U,Dist V>
2921 void
AdjointColSumScatterUpdate(T alpha,const DistMatrix<T,V,UGath> & A)2922 GeneralDistMatrix<T,U,V>::AdjointColSumScatterUpdate
2923 ( T alpha, const DistMatrix<T,V,UGath>& A )
2924 {
2925 DEBUG_ONLY(CallStackEntry cse("GDM::AdjointColSumScatterUpdate"))
2926 this->TransposeColSumScatterUpdate( alpha, A, true );
2927 }
2928
2929 template<typename T,Dist U,Dist V>
2930 void
AdjointPartialColSumScatterUpdate(T alpha,const DistMatrix<T,V,UPart> & A)2931 GeneralDistMatrix<T,U,V>::AdjointPartialColSumScatterUpdate
2932 ( T alpha, const DistMatrix<T,V,UPart>& A )
2933 {
2934 DEBUG_ONLY(CallStackEntry cse("GDM::AdjointPartialColSumScatterUpdate"))
2935 this->TransposePartialColSumScatterUpdate( alpha, A, true );
2936 }
2937
2938 // Diagonal manipulation
2939 // =====================
2940 template<typename T,Dist U,Dist V>
2941 bool
DiagonalAlignedWith(const elem::DistData & d,Int offset) const2942 GeneralDistMatrix<T,U,V>::DiagonalAlignedWith
2943 ( const elem::DistData& d, Int offset ) const
2944 {
2945 DEBUG_ONLY(CallStackEntry cse("GDM::DiagonalAlignedWith"))
2946 if( this->Grid() != *d.grid )
2947 return false;
2948
2949 const Int diagRoot = this->DiagonalRoot(offset);
2950 if( diagRoot != d.root )
2951 return false;
2952
2953 const Int diagAlign = this->DiagonalAlign(offset);
2954 if( d.colDist == UDiag && d.rowDist == VDiag )
2955 return d.colAlign == diagAlign;
2956 else if( d.colDist == VDiag && d.rowDist == UDiag )
2957 return d.rowAlign == diagAlign;
2958 else
2959 return false;
2960 }
2961
2962 template<typename T,Dist U,Dist V>
2963 Int
DiagonalRoot(Int offset) const2964 GeneralDistMatrix<T,U,V>::DiagonalRoot( Int offset ) const
2965 {
2966 DEBUG_ONLY(CallStackEntry cse("GDM::DiagonalRoot"))
2967 const elem::Grid& grid = this->Grid();
2968
2969 if( U == MC && V == MR )
2970 {
2971 // Result is an [MD,* ] or [* ,MD]
2972 Int owner;
2973 if( offset >= 0 )
2974 {
2975 const Int procRow = this->ColAlign();
2976 const Int procCol = (this->RowAlign()+offset) % this->RowStride();
2977 owner = procRow + this->ColStride()*procCol;
2978 }
2979 else
2980 {
2981 const Int procRow = (this->ColAlign()-offset) % this->ColStride();
2982 const Int procCol = this->RowAlign();
2983 owner = procRow + this->ColStride()*procCol;
2984 }
2985 return grid.DiagPath(owner);
2986 }
2987 else if( U == MR && V == MC )
2988 {
2989 // Result is an [MD,* ] or [* ,MD]
2990 Int owner;
2991 if( offset >= 0 )
2992 {
2993 const Int procCol = this->ColAlign();
2994 const Int procRow = (this->RowAlign()+offset) % this->RowStride();
2995 owner = procRow + this->ColStride()*procCol;
2996 }
2997 else
2998 {
2999 const Int procCol = (this->ColAlign()-offset) % this->ColStride();
3000 const Int procRow = this->RowAlign();
3001 owner = procRow + this->ColStride()*procCol;
3002 }
3003 return grid.DiagPath(owner);
3004 }
3005 else
3006 return this->Root();
3007 }
3008
3009 template<typename T,Dist U,Dist V>
3010 Int
DiagonalAlign(Int offset) const3011 GeneralDistMatrix<T,U,V>::DiagonalAlign( Int offset ) const
3012 {
3013 DEBUG_ONLY(CallStackEntry cse("GDM::DiagonalAlign"))
3014 const elem::Grid& grid = this->Grid();
3015
3016 if( U == MC && V == MR )
3017 {
3018 // Result is an [MD,* ] or [* ,MD]
3019 Int owner;
3020 if( offset >= 0 )
3021 {
3022 const Int procRow = this->ColAlign();
3023 const Int procCol = (this->RowAlign()+offset) % this->RowStride();
3024 owner = procRow + this->ColStride()*procCol;
3025 }
3026 else
3027 {
3028 const Int procRow = (this->ColAlign()-offset) % this->ColStride();
3029 const Int procCol = this->RowAlign();
3030 owner = procRow + this->ColStride()*procCol;
3031 }
3032 return grid.DiagPathRank(owner);
3033 }
3034 else if( U == MR && V == MC )
3035 {
3036 // Result is an [MD,* ] or [* ,MD]
3037 Int owner;
3038 if( offset >= 0 )
3039 {
3040 const Int procCol = this->ColAlign();
3041 const Int procRow = (this->RowAlign()+offset) % this->RowStride();
3042 owner = procRow + this->ColStride()*procCol;
3043 }
3044 else
3045 {
3046 const Int procCol = (this->ColAlign()-offset) % this->ColStride();
3047 const Int procRow = this->RowAlign();
3048 owner = procRow + this->ColStride()*procCol;
3049 }
3050 return grid.DiagPathRank(owner);
3051 }
3052 else if( U == STAR )
3053 {
3054 // Result is a [V,* ] or [* ,V]
3055 if( offset >= 0 )
3056 return (this->RowAlign()+offset) % this->RowStride();
3057 else
3058 return this->RowAlign();
3059 }
3060 else
3061 {
3062 // Result is [U,V] or [V,U], where V is either STAR or CIRC
3063 if( offset >= 0 )
3064 return this->ColAlign();
3065 else
3066 return (this->ColAlign()-offset) % this->ColStride();
3067 }
3068 }
3069
3070 template<typename T,Dist U,Dist V>
3071 void
GetDiagonal(DistMatrix<T,UDiag,VDiag> & d,Int offset) const3072 GeneralDistMatrix<T,U,V>::GetDiagonal
3073 ( DistMatrix<T,UDiag,VDiag>& d, Int offset ) const
3074 {
3075 DEBUG_ONLY(CallStackEntry cse("GDM::GetDiagonal"))
3076 this->GetDiagonalHelper
3077 ( d, offset, []( T& alpha, T beta ) { alpha = beta; } );
3078 }
3079
3080 template<typename T,Dist U,Dist V>
3081 void
GetRealPartOfDiagonal(DistMatrix<Base<T>,UDiag,VDiag> & d,Int offset) const3082 GeneralDistMatrix<T,U,V>::GetRealPartOfDiagonal
3083 ( DistMatrix<Base<T>,UDiag,VDiag>& d, Int offset ) const
3084 {
3085 DEBUG_ONLY(CallStackEntry cse("GDM::GetRealPartOfDiagonal"))
3086 this->GetDiagonalHelper
3087 ( d, offset, []( Base<T>& alpha, T beta ) { alpha = RealPart(beta); } );
3088 }
3089
3090 template<typename T,Dist U,Dist V>
3091 void
GetImagPartOfDiagonal(DistMatrix<Base<T>,UDiag,VDiag> & d,Int offset) const3092 GeneralDistMatrix<T,U,V>::GetImagPartOfDiagonal
3093 ( DistMatrix<Base<T>,UDiag,VDiag>& d, Int offset ) const
3094 {
3095 DEBUG_ONLY(CallStackEntry cse("GDM::GetImagPartOfDiagonal"))
3096 this->GetDiagonalHelper
3097 ( d, offset, []( Base<T>& alpha, T beta ) { alpha = ImagPart(beta); } );
3098 }
3099
3100 template<typename T,Dist U,Dist V>
3101 auto
GetDiagonal(Int offset) const3102 GeneralDistMatrix<T,U,V>::GetDiagonal( Int offset ) const
3103 -> DistMatrix<T,UDiag,VDiag>
3104 {
3105 DistMatrix<T,UDiag,VDiag> d( this->Grid() );
3106 GetDiagonal( d, offset );
3107 return d;
3108 }
3109
3110 template<typename T,Dist U,Dist V>
3111 auto
GetRealPartOfDiagonal(Int offset) const3112 GeneralDistMatrix<T,U,V>::GetRealPartOfDiagonal( Int offset ) const
3113 -> DistMatrix<Base<T>,UDiag,VDiag>
3114 {
3115 DistMatrix<Base<T>,UDiag,VDiag> d( this->Grid() );
3116 GetRealPartOfDiagonal( d, offset );
3117 return d;
3118 }
3119
3120 template<typename T,Dist U,Dist V>
3121 auto
GetImagPartOfDiagonal(Int offset) const3122 GeneralDistMatrix<T,U,V>::GetImagPartOfDiagonal( Int offset ) const
3123 -> DistMatrix<Base<T>,UDiag,VDiag>
3124 {
3125 DistMatrix<Base<T>,UDiag,VDiag> d( this->Grid() );
3126 GetImagPartOfDiagonal( d, offset );
3127 return d;
3128 }
3129
3130 template<typename T,Dist U,Dist V>
3131 void
SetDiagonal(const DistMatrix<T,UDiag,VDiag> & d,Int offset)3132 GeneralDistMatrix<T,U,V>::SetDiagonal
3133 ( const DistMatrix<T,UDiag,VDiag>& d, Int offset )
3134 {
3135 DEBUG_ONLY(CallStackEntry cse("GDM::SetDiagonal"))
3136 this->SetDiagonalHelper
3137 ( d, offset, []( T& alpha, T beta ) { alpha = beta; } );
3138 }
3139
3140 template<typename T,Dist U,Dist V>
3141 void
SetRealPartOfDiagonal(const DistMatrix<Base<T>,UDiag,VDiag> & d,Int offset)3142 GeneralDistMatrix<T,U,V>::SetRealPartOfDiagonal
3143 ( const DistMatrix<Base<T>,UDiag,VDiag>& d, Int offset )
3144 {
3145 DEBUG_ONLY(CallStackEntry cse("GDM::SetRealPartOfDiagonal"))
3146 this->SetDiagonalHelper
3147 ( d, offset,
3148 []( T& alpha, Base<T> beta ) { elem::SetRealPart(alpha,beta); } );
3149 }
3150
3151 template<typename T,Dist U,Dist V>
3152 void
SetImagPartOfDiagonal(const DistMatrix<Base<T>,UDiag,VDiag> & d,Int offset)3153 GeneralDistMatrix<T,U,V>::SetImagPartOfDiagonal
3154 ( const DistMatrix<Base<T>,UDiag,VDiag>& d, Int offset )
3155 {
3156 DEBUG_ONLY(CallStackEntry cse("GDM::SetImagPartOfDiagonal"))
3157 this->SetDiagonalHelper
3158 ( d, offset,
3159 []( T& alpha, Base<T> beta ) { elem::SetImagPart(alpha,beta); } );
3160 }
3161
3162 template<typename T,Dist U,Dist V>
3163 void
UpdateDiagonal(T gamma,const DistMatrix<T,UDiag,VDiag> & d,Int offset)3164 GeneralDistMatrix<T,U,V>::UpdateDiagonal
3165 ( T gamma, const DistMatrix<T,UDiag,VDiag>& d, Int offset )
3166 {
3167 DEBUG_ONLY(CallStackEntry cse("GDM::UpdateDiagonal"))
3168 this->SetDiagonalHelper
3169 ( d, offset, [gamma]( T& alpha, T beta ) { alpha += gamma*beta; } );
3170 }
3171
3172 template<typename T,Dist U,Dist V>
3173 void
UpdateRealPartOfDiagonal(Base<T> gamma,const DistMatrix<Base<T>,UDiag,VDiag> & d,Int offset)3174 GeneralDistMatrix<T,U,V>::UpdateRealPartOfDiagonal
3175 ( Base<T> gamma, const DistMatrix<Base<T>,UDiag,VDiag>& d, Int offset )
3176 {
3177 DEBUG_ONLY(CallStackEntry cse("GDM::UpdateRealPartOfDiagonal"))
3178 this->SetDiagonalHelper
3179 ( d, offset,
3180 [gamma]( T& alpha, Base<T> beta )
3181 { elem::UpdateRealPart(alpha,gamma*beta); } );
3182 }
3183
3184 template<typename T,Dist U,Dist V>
3185 void
UpdateImagPartOfDiagonal(Base<T> gamma,const DistMatrix<Base<T>,UDiag,VDiag> & d,Int offset)3186 GeneralDistMatrix<T,U,V>::UpdateImagPartOfDiagonal
3187 ( Base<T> gamma, const DistMatrix<Base<T>,UDiag,VDiag>& d, Int offset )
3188 {
3189 DEBUG_ONLY(CallStackEntry cse("GDM::UpdateImagPartOfDiagonal"))
3190 this->SetDiagonalHelper
3191 ( d, offset,
3192 [gamma]( T& alpha, Base<T> beta )
3193 { elem::UpdateImagPart(alpha,gamma*beta); } );
3194 }
3195
3196 // Private section
3197 // ###############
3198
3199 // Diagonal helper functions
3200 // =========================
3201 template<typename T,Dist U,Dist V>
3202 template<typename S,class Function>
3203 void
GetDiagonalHelper(DistMatrix<S,UDiag,VDiag> & d,Int offset,Function func) const3204 GeneralDistMatrix<T,U,V>::GetDiagonalHelper
3205 ( DistMatrix<S,UDiag,VDiag>& d, Int offset, Function func ) const
3206 {
3207 DEBUG_ONLY(CallStackEntry cse("GDM::GetDiagonalHelper"))
3208 d.SetGrid( this->Grid() );
3209 d.SetRoot( this->DiagonalRoot(offset) );
3210 d.AlignCols( this->DiagonalAlign(offset), false );
3211 d.Resize( this->DiagonalLength(offset), 1 );
3212 if( !d.Participating() )
3213 return;
3214
3215 const Int diagShift = d.ColShift();
3216 const Int diagStride = d.ColStride();
3217 const Int iStart = ( offset>=0 ? diagShift : diagShift-offset );
3218 const Int jStart = ( offset>=0 ? diagShift+offset : diagShift );
3219
3220 const Int colStride = this->ColStride();
3221 const Int rowStride = this->RowStride();
3222 const Int iLocStart = (iStart-this->ColShift()) / colStride;
3223 const Int jLocStart = (jStart-this->RowShift()) / rowStride;
3224
3225 const Int localDiagLength = d.LocalHeight();
3226 S* dBuf = d.Buffer();
3227 const T* buffer = this->LockedBuffer();
3228 const Int ldim = this->LDim();
3229
3230 ELEM_PARALLEL_FOR
3231 for( Int k=0; k<localDiagLength; ++k )
3232 {
3233 const Int iLoc = iLocStart + k*(diagStride/colStride);
3234 const Int jLoc = jLocStart + k*(diagStride/rowStride);
3235 func( dBuf[k], buffer[iLoc+jLoc*ldim] );
3236 }
3237 }
3238
3239 template<typename T,Dist U,Dist V>
3240 template<typename S,class Function>
3241 void
SetDiagonalHelper(const DistMatrix<S,UDiag,VDiag> & d,Int offset,Function func)3242 GeneralDistMatrix<T,U,V>::SetDiagonalHelper
3243 ( const DistMatrix<S,UDiag,VDiag>& d, Int offset, Function func )
3244 {
3245 DEBUG_ONLY(
3246 CallStackEntry cse("GDM::SetDiagonalHelper");
3247 if( !this->DiagonalAlignedWith( d, offset ) )
3248 LogicError("Invalid diagonal alignment");
3249 )
3250 if( !d.Participating() )
3251 return;
3252
3253 const Int diagShift = d.ColShift();
3254 const Int diagStride = d.ColStride();
3255 const Int iStart = ( offset>=0 ? diagShift : diagShift-offset );
3256 const Int jStart = ( offset>=0 ? diagShift+offset : diagShift );
3257
3258 const Int colStride = this->ColStride();
3259 const Int rowStride = this->RowStride();
3260 const Int iLocStart = (iStart-this->ColShift()) / colStride;
3261 const Int jLocStart = (jStart-this->RowShift()) / rowStride;
3262
3263 const Int localDiagLength = d.LocalHeight();
3264 const S* dBuf = d.LockedBuffer();
3265 T* buffer = this->Buffer();
3266 const Int ldim = this->LDim();
3267
3268 ELEM_PARALLEL_FOR
3269 for( Int k=0; k<localDiagLength; ++k )
3270 {
3271 const Int iLoc = iLocStart + k*(diagStride/colStride);
3272 const Int jLoc = jLocStart + k*(diagStride/rowStride);
3273 func( buffer[iLoc+jLoc*ldim], dBuf[k] );
3274 }
3275 }
3276
3277 // Instantiations for {Int,Real,Complex<Real>} for each Real in {float,double}
3278 // ###########################################################################
3279
3280 #define DISTPROTO(T,U,V) template class GeneralDistMatrix<T,U,V>
3281
3282 #define PROTO(T)\
3283 DISTPROTO(T,CIRC,CIRC);\
3284 DISTPROTO(T,MC, MR );\
3285 DISTPROTO(T,MC, STAR);\
3286 DISTPROTO(T,MD, STAR);\
3287 DISTPROTO(T,MR, MC );\
3288 DISTPROTO(T,MR, STAR);\
3289 DISTPROTO(T,STAR,MC );\
3290 DISTPROTO(T,STAR,MD );\
3291 DISTPROTO(T,STAR,MR );\
3292 DISTPROTO(T,STAR,STAR);\
3293 DISTPROTO(T,STAR,VC );\
3294 DISTPROTO(T,STAR,VR );\
3295 DISTPROTO(T,VC, STAR);\
3296 DISTPROTO(T,VR, STAR);
3297
3298 #ifndef ELEM_DISABLE_COMPLEX
3299 #ifndef ELEM_DISABLE_FLOAT
3300 PROTO(Int);
3301 PROTO(float);
3302 PROTO(double);
3303 PROTO(Complex<float>);
3304 PROTO(Complex<double>);
3305 #else // ifndef ELEM_DISABLE_FLOAT
3306 PROTO(Int);
3307 PROTO(double);
3308 PROTO(Complex<double>);
3309 #endif // ifndef ELEM_DISABLE_FLOAT
3310 #else // ifndef ELEM_DISABLE_COMPLEX
3311 #ifndef ELEM_DISABLE_FLOAT
3312 PROTO(Int);
3313 PROTO(float);
3314 PROTO(double);
3315 #else // ifndef ELEM_DISABLE_FLOAT
3316 PROTO(Int);
3317 PROTO(double);
3318 #endif // ifndef ELEM_DISABLE_FLOAT
3319 #endif // ifndef ELEM_DISABLE_COMPLEX
3320
3321 } // namespace elem
3322