1 /* 2 Copyright (c) 2009-2014, Jack Poulson 3 All rights reserved. 4 5 This file is part of Elemental and is under the BSD 2-Clause License, 6 which can be found in the LICENSE file in the root directory, or at 7 http://opensource.org/licenses/BSD-2-Clause 8 */ 9 #pragma once 10 #ifndef ELEM_DISTMATRIX_ABSTRACT_DECL_HPP 11 #define ELEM_DISTMATRIX_ABSTRACT_DECL_HPP 12 13 namespace elem { 14 15 template<typename T> 16 class AbstractDistMatrix 17 { 18 public: 19 // Typedefs 20 // ======== 21 typedef AbstractDistMatrix<T> type; 22 23 // Constructors and destructors 24 // ============================ 25 // Move constructor 26 AbstractDistMatrix( type&& A ) ELEM_NOEXCEPT; 27 28 virtual ~AbstractDistMatrix(); 29 30 // Assignment and reconfiguration 31 // ============================== 32 // Move assignment 33 type& operator=( type&& A ); 34 35 void Empty(); 36 void EmptyData(); 37 void SetGrid( const elem::Grid& grid ); 38 void Resize( Int height, Int width ); 39 void Resize( Int height, Int width, Int ldim ); 40 void MakeConsistent( bool includingViewers=false ); 41 void MakeSizeConsistent( bool includingViewers=false ); 42 43 // Realignment 44 // ----------- 45 void Align( Int colAlign, Int rowAlign, bool constrain=true ); 46 void AlignCols( Int colAlign, bool constrain=true ); 47 void AlignRows( Int rowAlign, bool constrain=true ); 48 void FreeAlignments(); 49 void SetRoot( Int root, bool constrain=true ); 50 void AlignWith( const elem::DistData& data, bool constrain=true ); 51 virtual void AlignColsWith 52 ( const elem::DistData& data, bool constrain=true ); 53 virtual void AlignRowsWith 54 ( const elem::DistData& data, bool constrain=true ); 55 void AlignAndResize 56 ( Int colAlign, Int rowAlign, Int height, Int width, 57 bool force=false, bool constrain=true ); 58 void AlignColsAndResize 59 ( Int colAlign, Int height, Int width, 60 bool force=false, bool constrain=true ); 61 void AlignRowsAndResize 62 ( Int rowAlign, Int height, Int width, 63 bool force=false, bool constrain=true ); 64 65 // Buffer attachment 66 // ----------------- 67 // (Immutable) view of a distributed matrix's buffer 68 void Attach 69 ( Int height, Int width, const elem::Grid& grid, 70 Int colAlign, Int rowAlign, T* buffer, Int ldim, Int root=0 ); 71 void LockedAttach 72 ( Int height, Int width, const elem::Grid& grid, 73 Int colAlign, Int rowAlign, const T* buffer, Int ldim, Int root=0 ); 74 void Attach 75 ( Int height, Int width, const elem::Grid& grid, 76 Int colAlign, Int rowAlign, elem::Matrix<T>& A, Int root=0 ); 77 void LockedAttach 78 ( Int height, Int width, const elem::Grid& grid, 79 Int colAlign, Int rowAlign, const elem::Matrix<T>& A, Int root=0 ); 80 81 // Basic queries 82 // ============= 83 84 // Global matrix information 85 // ------------------------- 86 Int Height() const; 87 Int Width() const; 88 Int DiagonalLength( Int offset=0 ) const; 89 bool Viewing() const; 90 bool Locked() const; 91 92 // Local matrix information 93 // ------------------------ 94 Int LocalHeight() const; 95 Int LocalWidth() const; 96 Int LDim() const; 97 elem::Matrix<T>& Matrix(); 98 const elem::Matrix<T>& LockedMatrix() const; 99 size_t AllocatedMemory() const; 100 T* Buffer(); 101 T* Buffer( Int iLoc, Int jLoc ); 102 const T* LockedBuffer() const; 103 const T* LockedBuffer( Int iLoc, Int jLoc ) const; 104 105 // Distribution information 106 // ------------------------ 107 const elem::Grid& Grid() const; 108 bool ColConstrained() const; 109 bool RowConstrained() const; 110 bool RootConstrained() const; 111 Int ColAlign() const; 112 Int RowAlign() const; 113 Int ColShift() const; 114 Int RowShift() const; 115 Int ColRank() const; 116 Int RowRank() const; 117 Int PartialColRank() const; 118 Int PartialRowRank() const; 119 Int PartialUnionColRank() const; 120 Int PartialUnionRowRank() const; 121 Int DistRank() const; 122 Int CrossRank() const; 123 Int RedundantRank() const; 124 Int Root() const; 125 bool Participating() const; 126 Int RowOwner( Int i ) const; // rank in ColComm 127 Int ColOwner( Int j ) const; // rank in RowComm 128 Int Owner( Int i, Int j ) const; // rank in DistComm 129 Int LocalRow( Int i ) const; // debug throws if row i is not locally owned 130 Int LocalCol( Int j ) const; // debug throws if col j is not locally owned 131 Int LocalRowOffset( Int i ) const; // number of local rows before row i 132 Int LocalColOffset( Int j ) const; // number of local cols before col j 133 Int GlobalRow( Int iLoc ) const; 134 Int GlobalCol( Int jLoc ) const; 135 bool IsLocalRow( Int i ) const; 136 bool IsLocalCol( Int j ) const; 137 bool IsLocal( Int i, Int j ) const; 138 // Must be overridden 139 // ^^^^^^^^^^^^^^^^^^ 140 virtual elem::DistData DistData() const = 0; 141 virtual mpi::Comm DistComm() const = 0; 142 virtual mpi::Comm CrossComm() const = 0; 143 virtual mpi::Comm RedundantComm() const = 0; 144 virtual mpi::Comm ColComm() const = 0; 145 virtual mpi::Comm RowComm() const = 0; 146 virtual mpi::Comm PartialColComm() const; 147 virtual mpi::Comm PartialRowComm() const; 148 virtual mpi::Comm PartialUnionColComm() const; 149 virtual mpi::Comm PartialUnionRowComm() const; 150 virtual Int ColStride() const = 0; 151 virtual Int RowStride() const = 0; 152 virtual Int PartialColStride() const; 153 virtual Int PartialRowStride() const; 154 virtual Int PartialUnionColStride() const; 155 virtual Int PartialUnionRowStride() const; 156 virtual Int DistSize() const = 0; 157 virtual Int CrossSize() const = 0; 158 virtual Int RedundantSize() const = 0; 159 160 // Single-entry manipulation 161 // ========================= 162 163 // Global entry manipulation 164 // ------------------------- 165 // NOTE: Local entry manipulation is often much faster and should be 166 // preferred in most circumstances where performance matters. 167 T Get( Int i, Int j ) const; 168 Base<T> GetRealPart( Int i, Int j ) const; 169 Base<T> GetImagPart( Int i, Int j ) const; 170 void Set( Int i, Int j, T alpha ); 171 void SetRealPart( Int i, Int j, Base<T> alpha ); 172 void SetImagPart( Int i, Int j, Base<T> alpha ); 173 void Update( Int i, Int j, T alpha ); 174 void UpdateRealPart( Int i, Int j, Base<T> alpha ); 175 void UpdateImagPart( Int i, Int j, Base<T> alpha ); 176 void MakeReal( Int i, Int j ); 177 void Conjugate( Int i, Int j ); 178 179 // Local entry manipulation 180 // ------------------------ 181 T GetLocal( Int iLoc, Int jLoc ) const; 182 Base<T> GetLocalRealPart( Int iLoc, Int jLoc ) const; 183 Base<T> GetLocalImagPart( Int iLoc, Int jLoc ) const; 184 void SetLocal( Int iLoc, Int jLoc, T alpha ); 185 void SetLocalRealPart( Int iLoc, Int jLoc, Base<T> alpha ); 186 void SetLocalImagPart( Int iLoc, Int jLoc, Base<T> alpha ); 187 void UpdateLocal( Int iLoc, Int jLoc, T alpha ); 188 void UpdateLocalRealPart( Int iLoc, Int jLoc, Base<T> alpha ); 189 void UpdateLocalImagPart( Int iLoc, Int jLoc, Base<T> alpha ); 190 void MakeLocalReal( Int iLoc, Int jLoc ); 191 void ConjugateLocal( Int iLoc, Int jLoc ); 192 193 // Diagonal manipulation 194 // ===================== 195 void MakeDiagonalReal( Int offset=0 ); 196 void ConjugateDiagonal( Int offset=0 ); 197 198 // Arbitrary-submatrix manipulation 199 // ================================ 200 201 // Global submatrix manipulation 202 // ----------------------------- 203 void GetSubmatrix 204 ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd, 205 DistMatrix<T,STAR,STAR>& ASub ) const; 206 void GetRealPartOfSubmatrix 207 ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd, 208 DistMatrix<Base<T>,STAR,STAR>& ASub ) const; 209 void GetImagPartOfSubmatrix 210 ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd, 211 DistMatrix<Base<T>,STAR,STAR>& ASub ) const; 212 213 DistMatrix<T,STAR,STAR> GetSubmatrix 214 ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd ) const; 215 DistMatrix<Base<T>,STAR,STAR> GetRealPartOfSubmatrix 216 ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd ) const; 217 DistMatrix<Base<T>,STAR,STAR> GetImagPartOfSubmatrix 218 ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd ) const; 219 220 void SetSubmatrix 221 ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd, 222 const DistMatrix<T,STAR,STAR>& ASub ); 223 void SetRealPartOfSubmatrix 224 ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd, 225 const DistMatrix<Base<T>,STAR,STAR>& ASub ); 226 void SetImagPartOfSubmatrix 227 ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd, 228 const DistMatrix<Base<T>,STAR,STAR>& ASub ); 229 230 void UpdateSubmatrix 231 ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd, 232 T alpha, const DistMatrix<T,STAR,STAR>& ASub ); 233 void UpdateRealPartOfSubmatrix 234 ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd, 235 Base<T> alpha, const DistMatrix<Base<T>,STAR,STAR>& ASub ); 236 void UpdateImagPartOfSubmatrix 237 ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd, 238 Base<T> alpha, const DistMatrix<Base<T>,STAR,STAR>& ASub ); 239 240 void MakeSubmatrixReal 241 ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd ); 242 void ConjugateSubmatrix 243 ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd ); 244 245 // Local submatrix manipulation 246 // ---------------------------- 247 void GetLocalSubmatrix 248 ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc, 249 elem::Matrix<T>& ASub ) const; 250 void GetRealPartOfLocalSubmatrix 251 ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc, 252 elem::Matrix<Base<T>>& ASub ) const; 253 void GetImagPartOfLocalSubmatrix 254 ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc, 255 elem::Matrix<Base<T>>& ASub ) const; 256 257 elem::Matrix<T> GetLocalSubmatrix 258 ( const std::vector<Int>& rowIndLoc, 259 const std::vector<Int>& colIndLoc ) const; 260 elem::Matrix<Base<T>> GetRealPartOfLocalSubmatrix 261 ( const std::vector<Int>& rowIndLoc, 262 const std::vector<Int>& colIndLoc ) const; 263 elem::Matrix<Base<T>> GetImagPartOfLocalSubmatrix 264 ( const std::vector<Int>& rowIndLoc, 265 const std::vector<Int>& colIndLoc ) const; 266 267 void SetLocalSubmatrix 268 ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc, 269 const elem::Matrix<T>& ASub ); 270 void SetRealPartOfLocalSubmatrix 271 ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc, 272 const elem::Matrix<Base<T>>& ASub ); 273 void SetImagPartOfLocalSubmatrix 274 ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc, 275 const elem::Matrix<Base<T>>& ASub ); 276 277 void UpdateLocalSubmatrix 278 ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc, 279 T alpha, const elem::Matrix<T>& ASub ); 280 void UpdateRealPartOfLocalSubmatrix 281 ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc, 282 Base<T> alpha, const elem::Matrix<Base<T>>& ASub ); 283 void UpdateImagPartOfLocalSubmatrix 284 ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc, 285 Base<T> alpha, const elem::Matrix<Base<T>>& ASub ); 286 287 void MakeLocalSubmatrixReal 288 ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc ); 289 void ConjugateLocalSubmatrix 290 ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc ); 291 292 // Sum over a specified communicator 293 // ================================= 294 void SumOver( mpi::Comm comm ); 295 296 // Assertions 297 // ========== 298 void ComplainIfReal() const; 299 void AssertNotLocked() const; 300 void AssertNotStoringData() const; 301 void AssertValidEntry( Int i, Int j ) const; 302 void AssertValidSubmatrix( Int i, Int j, Int height, Int width ) const; 303 void AssertSameGrid( const elem::Grid& grid ) const; 304 void AssertSameSize( Int height, Int width ) const; 305 306 protected: 307 // Member variables 308 // ================ 309 310 // Global and local matrix information 311 // ----------------------------------- 312 ViewType viewType_; 313 Int height_, width_; 314 Memory<T> auxMemory_; 315 elem::Matrix<T> matrix_; 316 317 // Process grid and distribution metadata 318 // -------------------------------------- 319 bool colConstrained_, rowConstrained_, rootConstrained_; 320 Int colAlign_, rowAlign_, 321 colShift_, rowShift_; 322 Int root_; 323 const elem::Grid* grid_; 324 325 // Private constructors 326 // ==================== 327 // Create a 0 x 0 distributed matrix 328 AbstractDistMatrix( const elem::Grid& g=DefaultGrid(), Int root=0 ); 329 330 // Exchange metadata with another matrix 331 // ===================================== 332 virtual void ShallowSwap( type& A ); 333 334 // Modify the distribution metadata 335 // ================================ 336 void SetShifts(); 337 void SetColShift(); 338 void SetRowShift(); 339 void SetGrid(); 340 341 // Friend declarations 342 // =================== 343 template<typename S,Dist J,Dist K> friend class GeneralDistMatrix; 344 template<typename S,Dist J,Dist K> friend class DistMatrix; 345 346 template<typename S,Dist J,Dist K> friend class GeneralBlockDistMatrix; 347 template<typename S,Dist J,Dist K> friend class BlockDistMatrix; 348 }; 349 350 template<typename T> 351 void AssertConforming1x2 352 ( const AbstractDistMatrix<T>& AL, const AbstractDistMatrix<T>& AR ); 353 354 template<typename T> 355 void AssertConforming2x1 356 ( const AbstractDistMatrix<T>& AT, const AbstractDistMatrix<T>& AB ); 357 358 template<typename T> 359 void AssertConforming2x2 360 ( const AbstractDistMatrix<T>& ATL, const AbstractDistMatrix<T>& ATR, 361 const AbstractDistMatrix<T>& ABL, const AbstractDistMatrix<T>& ABR ); 362 363 } // namespace elem 364 365 #endif // ifndef ELEM_DISTMATRIX_ABSTRACT_DECL_HPP 366