1 /* 2 Copyright (c) 2009-2014, Jack Poulson 3 All rights reserved. 4 5 This file is part of Elemental and is under the BSD 2-Clause License, 6 which can be found in the LICENSE file in the root directory, or at 7 http://opensource.org/licenses/BSD-2-Clause 8 */ 9 #pragma once 10 #ifndef ELEM_BLOCKDISTMATRIX_ABSTRACT_DECL_HPP 11 #define ELEM_BLOCKDISTMATRIX_ABSTRACT_DECL_HPP 12 13 namespace elem { 14 15 template<typename T> 16 class AbstractBlockDistMatrix 17 { 18 public: 19 // Typedefs 20 // ======== 21 typedef AbstractBlockDistMatrix<T> type; 22 23 // Constructors and destructors 24 // ============================ 25 // Move constructor 26 AbstractBlockDistMatrix( type&& A ) ELEM_NOEXCEPT; 27 28 virtual ~AbstractBlockDistMatrix(); 29 30 // Assignment and reconfiguration 31 // ============================== 32 // Move assignment 33 type& operator=( type&& A ); 34 35 void Empty(); 36 void EmptyData(); 37 void SetGrid( const elem::Grid& grid ); 38 void Resize( Int height, Int width ); 39 void Resize( Int height, Int width, Int ldim ); 40 void MakeConsistent( bool includingViewers=false ); 41 void MakeSizeConsistent( bool includingViewers=false ); 42 43 // Realignment 44 // ----------- 45 void Align 46 ( Int blockHeight, Int blockWidth, 47 Int colAlign, Int rowAlign, Int colCut=0, Int rowCut=0, 48 bool constrain=true ); 49 void AlignCols 50 ( Int blockHeight, Int colAlign, Int colCut=0, bool constrain=true ); 51 void AlignRows 52 ( Int blockWidth, Int rowAlign, Int rowCut=0, bool constrain=true ); 53 void FreeAlignments(); 54 void SetRoot( Int root, bool constrain=true ); 55 void AlignWith( const elem::BlockDistData& data, bool constrain=true ); 56 virtual void AlignColsWith 57 ( const elem::BlockDistData& data, bool constrain=true ); 58 virtual void AlignRowsWith 59 ( const elem::BlockDistData& data, bool constrain=true ); 60 // TODO: The interface for these routines could be improved 61 void AlignAndResize 62 ( Int blockHeight, Int blockWidth, 63 Int colAlign, Int rowAlign, Int colCut, Int rowCut, 64 Int height, Int width, bool force=false, bool constrain=true ); 65 void AlignColsAndResize 66 ( Int blockHeight, Int colAlign, Int colCut, Int height, Int width, 67 bool force=false, bool constrain=true ); 68 void AlignRowsAndResize 69 ( Int blockWidth, Int rowAlign, Int rowCut, Int height, Int width, 70 bool force=false, bool constrain=true ); 71 72 // Buffer attachment 73 // ----------------- 74 // (Immutable) view of a distributed matrix's buffer 75 void Attach 76 ( Int height, Int width, const elem::Grid& g, 77 Int blockHeight, Int blockWidth, 78 Int colAlign, Int rowAlign, Int colCut, Int rowCut, 79 T* buffer, Int ldim, Int root=0 ); 80 void LockedAttach 81 ( Int height, Int width, const elem::Grid& g, 82 Int blockHeight, Int blockWidth, 83 Int colAlign, Int rowAlign, Int colCut, Int rowCut, 84 const T* buffer, Int ldim, Int root=0 ); 85 void Attach 86 ( Int height, Int width, const elem::Grid& g, 87 Int blockHeight, Int blockWidth, 88 Int colAlign, Int rowAlign, Int colCut, Int rowCut, 89 elem::Matrix<T>& A, Int root=0 ); 90 void LockedAttach 91 ( Int height, Int width, const elem::Grid& g, 92 Int blockHeight, Int blockWidth, 93 Int colAlign, Int rowAlign, Int colCut, Int rowCut, 94 const elem::Matrix<T>& A, Int root=0 ); 95 96 // Basic queries 97 // ============= 98 99 // Global matrix information 100 // ------------------------- 101 Int Height() const; 102 Int Width() const; 103 Int DiagonalLength( Int offset=0 ) const; 104 bool Viewing() const; 105 bool Locked() const; 106 107 // Local matrix information 108 // ------------------------ 109 Int LocalHeight() const; 110 Int LocalWidth() const; 111 Int LDim() const; 112 elem::Matrix<T>& Matrix(); 113 const elem::Matrix<T>& LockedMatrix() const; 114 size_t AllocatedMemory() const; 115 T* Buffer(); 116 T* Buffer( Int iLoc, Int jLoc ); 117 const T* LockedBuffer() const; 118 const T* LockedBuffer( Int iLoc, Int jLoc ) const; 119 120 // Distribution information 121 // ------------------------ 122 const elem::Grid& Grid() const; 123 bool ColConstrained() const; 124 bool RowConstrained() const; 125 bool RootConstrained() const; 126 Int BlockHeight() const; 127 Int BlockWidth() const; 128 Int ColAlign() const; 129 Int RowAlign() const; 130 Int ColCut() const; 131 Int RowCut() const; 132 Int ColShift() const; 133 Int RowShift() const; 134 Int ColRank() const; 135 Int RowRank() const; 136 Int PartialColRank() const; 137 Int PartialRowRank() const; 138 Int PartialUnionColRank() const; 139 Int PartialUnionRowRank() const; 140 Int DistRank() const; 141 Int CrossRank() const; 142 Int RedundantRank() const; 143 Int Root() const; 144 bool Participating() const; 145 Int RowOwner( Int i ) const; // rank in ColComm 146 Int ColOwner( Int j ) const; // rank in RowComm 147 Int Owner( Int i, Int j ) const; // rank in DistComm 148 Int LocalRow( Int i ) const; // debug throws if row i is not locally owned 149 Int LocalCol( Int j ) const; // debug throws if col j is not locally owned 150 Int LocalRowOffset( Int i ) const; // number of local rows before row i 151 Int LocalColOffset( Int j ) const; // number of local cols before col j 152 Int GlobalRow( Int iLoc ) const; 153 Int GlobalCol( Int jLoc ) const; 154 bool IsLocalRow( Int i ) const; 155 bool IsLocalCol( Int j ) const; 156 bool IsLocal( Int i, Int j ) const; 157 // Must be overridden 158 // ^^^^^^^^^^^^^^^^^^ 159 virtual elem::BlockDistData DistData() const = 0; 160 virtual mpi::Comm DistComm() const = 0; 161 virtual mpi::Comm CrossComm() const = 0; 162 virtual mpi::Comm RedundantComm() const = 0; 163 virtual mpi::Comm ColComm() const = 0; 164 virtual mpi::Comm RowComm() const = 0; 165 virtual mpi::Comm PartialColComm() const; 166 virtual mpi::Comm PartialRowComm() const; 167 virtual mpi::Comm PartialUnionColComm() const; 168 virtual mpi::Comm PartialUnionRowComm() const; 169 virtual Int ColStride() const = 0; 170 virtual Int RowStride() const = 0; 171 virtual Int PartialColStride() const; 172 virtual Int PartialRowStride() const; 173 virtual Int PartialUnionColStride() const; 174 virtual Int PartialUnionRowStride() const; 175 virtual Int DistSize() const = 0; 176 virtual Int CrossSize() const = 0; 177 virtual Int RedundantSize() const = 0; 178 179 // Single-entry manipulation 180 // ========================= 181 182 // Global entry manipulation 183 // ------------------------- 184 // NOTE: Local entry manipulation is often much faster and should be 185 // preferred in most circumstances where performance matters. 186 T Get( Int i, Int j ) const; 187 Base<T> GetRealPart( Int i, Int j ) const; 188 Base<T> GetImagPart( Int i, Int j ) const; 189 void Set( Int i, Int j, T alpha ); 190 void SetRealPart( Int i, Int j, Base<T> alpha ); 191 void SetImagPart( Int i, Int j, Base<T> alpha ); 192 void Update( Int i, Int j, T alpha ); 193 void UpdateRealPart( Int i, Int j, Base<T> alpha ); 194 void UpdateImagPart( Int i, Int j, Base<T> alpha ); 195 void MakeReal( Int i, Int j ); 196 void Conjugate( Int i, Int j ); 197 198 // Local entry manipulation 199 // ------------------------ 200 T GetLocal( Int iLoc, Int jLoc ) const; 201 Base<T> GetLocalRealPart( Int iLoc, Int jLoc ) const; 202 Base<T> GetLocalImagPart( Int iLoc, Int jLoc ) const; 203 void SetLocal( Int iLoc, Int jLoc, T alpha ); 204 void SetLocalRealPart( Int iLoc, Int jLoc, Base<T> alpha ); 205 void SetLocalImagPart( Int iLoc, Int jLoc, Base<T> alpha ); 206 void UpdateLocal( Int iLoc, Int jLoc, T alpha ); 207 void UpdateLocalRealPart( Int iLoc, Int jLoc, Base<T> alpha ); 208 void UpdateLocalImagPart( Int iLoc, Int jLoc, Base<T> alpha ); 209 void MakeLocalReal( Int iLoc, Int jLoc ); 210 void ConjugateLocal( Int iLoc, Int jLoc ); 211 212 // Diagonal manipulation 213 // ===================== 214 void MakeDiagonalReal( Int offset=0 ); 215 void ConjugateDiagonal( Int offset=0 ); 216 217 // Arbitrary-submatrix manipulation 218 // ================================ 219 220 // Global submatrix manipulation 221 // ----------------------------- 222 void GetSubmatrix 223 ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd, 224 DistMatrix<T,STAR,STAR>& ASub ) const; 225 void GetRealPartOfSubmatrix 226 ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd, 227 DistMatrix<Base<T>,STAR,STAR>& ASub ) const; 228 void GetImagPartOfSubmatrix 229 ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd, 230 DistMatrix<Base<T>,STAR,STAR>& ASub ) const; 231 DistMatrix<T,STAR,STAR> GetSubmatrix 232 ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd ) const; 233 DistMatrix<Base<T>,STAR,STAR> GetRealPartOfSubmatrix 234 ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd ) const; 235 DistMatrix<Base<T>,STAR,STAR> GetImagPartOfSubmatrix 236 ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd ) const; 237 238 void SetSubmatrix 239 ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd, 240 const DistMatrix<T,STAR,STAR>& ASub ); 241 void SetRealPartOfSubmatrix 242 ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd, 243 const DistMatrix<Base<T>,STAR,STAR>& ASub ); 244 void SetImagPartOfSubmatrix 245 ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd, 246 const DistMatrix<Base<T>,STAR,STAR>& ASub ); 247 248 void UpdateSubmatrix 249 ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd, 250 T alpha, const DistMatrix<T,STAR,STAR>& ASub ); 251 void UpdateRealPartOfSubmatrix 252 ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd, 253 Base<T> alpha, const DistMatrix<Base<T>,STAR,STAR>& ASub ); 254 void UpdateImagPartOfSubmatrix 255 ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd, 256 Base<T> alpha, const DistMatrix<Base<T>,STAR,STAR>& ASub ); 257 258 void MakeSubmatrixReal 259 ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd ); 260 void ConjugateSubmatrix 261 ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd ); 262 263 // Local submatrix manipulation 264 // ---------------------------- 265 void GetLocalSubmatrix 266 ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc, 267 elem::Matrix<T>& ASub ) const; 268 void GetRealPartOfLocalSubmatrix 269 ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc, 270 elem::Matrix<Base<T>>& ASub ) const; 271 void GetImagPartOfLocalSubmatrix 272 ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc, 273 elem::Matrix<Base<T>>& ASub ) const; 274 elem::Matrix<T> GetLocalSubmatrix 275 ( const std::vector<Int>& rowIndLoc, 276 const std::vector<Int>& colIndLoc ) const; 277 elem::Matrix<Base<T>> GetRealPartOfLocalSubmatrix 278 ( const std::vector<Int>& rowIndLoc, 279 const std::vector<Int>& colIndLoc ) const; 280 elem::Matrix<Base<T>> GetImagPartOfLocalSubmatrix 281 ( const std::vector<Int>& rowIndLoc, 282 const std::vector<Int>& colIndLoc ) const; 283 284 void SetLocalSubmatrix 285 ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc, 286 const elem::Matrix<T>& ASub ); 287 void SetRealPartOfLocalSubmatrix 288 ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc, 289 const elem::Matrix<Base<T>>& ASub ); 290 void SetImagPartOfLocalSubmatrix 291 ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc, 292 const elem::Matrix<Base<T>>& ASub ); 293 294 void UpdateLocalSubmatrix 295 ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc, 296 T alpha, const elem::Matrix<T>& ASub ); 297 void UpdateRealPartOfLocalSubmatrix 298 ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc, 299 Base<T> alpha, const elem::Matrix<Base<T>>& ASub ); 300 void UpdateImagPartOfLocalSubmatrix 301 ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc, 302 Base<T> alpha, const elem::Matrix<Base<T>>& ASub ); 303 304 void MakeLocalSubmatrixReal 305 ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc ); 306 void ConjugateLocalSubmatrix 307 ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc ); 308 309 // Sum over a specified communicator 310 // ================================= 311 void SumOver( mpi::Comm comm ); 312 313 // Assertions 314 // ========== 315 void ComplainIfReal() const; 316 void AssertNotLocked() const; 317 void AssertNotStoringData() const; 318 void AssertValidEntry( Int i, Int j ) const; 319 void AssertValidSubmatrix( Int i, Int j, Int height, Int width ) const; 320 void AssertSameGrid( const elem::Grid& grid ) const; 321 void AssertSameSize( Int height, Int width ) const; 322 323 protected: 324 // Member variables 325 // ================ 326 327 // Global and local matrix information 328 // ----------------------------------- 329 ViewType viewType_; 330 Int height_, width_; 331 Memory<T> auxMemory_; 332 elem::Matrix<T> matrix_; 333 334 // Process grid and distribution metadata 335 // -------------------------------------- 336 bool colConstrained_, rowConstrained_, rootConstrained_; 337 Int blockHeight_, blockWidth_; 338 Int colAlign_, rowAlign_, 339 colCut_, rowCut_, 340 colShift_, rowShift_; 341 Int root_; 342 const elem::Grid* grid_; 343 344 // Private constructors 345 // ==================== 346 AbstractBlockDistMatrix( const elem::Grid& g=DefaultGrid(), Int root=0 ); 347 AbstractBlockDistMatrix 348 ( const elem::Grid& g, Int blockHeight, Int blockWidth, Int root=0 ); 349 350 // Exchange metadata with another matrix 351 // ===================================== 352 virtual void ShallowSwap( type& A ); 353 354 // Modify the distribution metadata 355 // ================================ 356 void SetShifts(); 357 void SetColShift(); 358 void SetRowShift(); 359 void SetGrid(); 360 361 // Friend declarations 362 // =================== 363 template<typename S,Dist J,Dist K> friend class GeneralDistMatrix; 364 template<typename S,Dist J,Dist K> friend class DistMatrix; 365 template<typename S,Dist J,Dist K> friend class GeneralBlockDistMatrix; 366 template<typename S,Dist J,Dist K> friend class BlockDistMatrix; 367 }; 368 369 template<typename T> 370 void AssertConforming1x2 371 ( const AbstractBlockDistMatrix<T>& AL, const AbstractBlockDistMatrix<T>& AR ); 372 373 template<typename T> 374 void AssertConforming2x1 375 ( const AbstractBlockDistMatrix<T>& AT, const AbstractBlockDistMatrix<T>& AB ); 376 377 template<typename T> 378 void AssertConforming2x2 379 ( const AbstractBlockDistMatrix<T>& ATL, 380 const AbstractBlockDistMatrix<T>& ATR, 381 const AbstractBlockDistMatrix<T>& ABL, 382 const AbstractBlockDistMatrix<T>& ABR ); 383 384 } // namespace elem 385 386 #endif // ifndef ELEM_BLOCKDISTMATRIX_ABSTRACT_DECL_HPP 387