1 /*
2    Copyright (c) 2009-2014, Jack Poulson
3    All rights reserved.
4 
5    This file is part of Elemental and is under the BSD 2-Clause License,
6    which can be found in the LICENSE file in the root directory, or at
7    http://opensource.org/licenses/BSD-2-Clause
8 */
9 #pragma once
10 #ifndef ELEM_DISTMATRIX_ABSTRACT_DECL_HPP
11 #define ELEM_DISTMATRIX_ABSTRACT_DECL_HPP
12 
13 namespace elem {
14 
// AbstractDistMatrix<T>
// ---------------------
// Abstract base interface for a distributed matrix with entries of type T.
// It holds the metadata common to every distribution (global height/width,
// view type, column/row alignments and shifts, root, constraint flags and a
// pointer to the process grid) together with an elem::Matrix<T> member, and
// declares the communicator/stride/size queries that each concrete
// distribution has to supply.
//
// - The class is abstract: DistData(), DistComm(), CrossComm(),
//   RedundantComm(), ColComm(), RowComm(), ColStride(), RowStride(),
//   DistSize(), CrossSize() and RedundantSize() are pure virtual.
// - Apart from the move constructor, its only constructor is protected, so
//   objects are created through derived classes; presumably these are the
//   friend templates declared at the end of the class (GeneralDistMatrix,
//   DistMatrix, GeneralBlockDistMatrix, BlockDistMatrix) -- confirm in their
//   headers.
// - A move constructor and a move assignment operator are user-declared and
//   no copy operations are, so the implicitly-declared copy constructor and
//   copy assignment operator are defined as deleted: this base is move-only.
template<typename T>
class AbstractDistMatrix
{
public:
    // Typedefs
    // ========
    // Shorthand for this class; used by the move and ShallowSwap signatures.
    typedef AbstractDistMatrix<T> type;

    // Constructors and destructors
    // ============================
    // Move constructor. Its exception specification comes from the
    // ELEM_NOEXCEPT macro (presumably `noexcept` where the compiler supports
    // it -- check the macro's definition).
    AbstractDistMatrix( type&& A ) ELEM_NOEXCEPT;

    // Virtual so that derived matrices are destroyed correctly through a
    // base-class pointer or reference.
    virtual ~AbstractDistMatrix();

    // Assignment and reconfiguration
    // ==============================
    // Move assignment.
    // NOTE(review): unlike the move constructor this is not ELEM_NOEXCEPT;
    // adding it requires changing the out-of-line definition at the same time.
    type& operator=( type&& A );

    void Empty();
    void EmptyData();
    void SetGrid( const elem::Grid& grid );
    void Resize( Int height, Int width );
    void Resize( Int height, Int width, Int ldim );
    void MakeConsistent( bool includingViewers=false );
    void MakeSizeConsistent( bool includingViewers=false );

    // Realignment
    // -----------
    // The `constrain` arguments presumably set the matching
    // col/row/rootConstrained_ flags declared below -- confirm in the
    // implementation. AlignColsWith/AlignRowsWith are virtual, so derived
    // distributions may customize them.
    void Align( Int colAlign, Int rowAlign, bool constrain=true );
    void AlignCols( Int colAlign, bool constrain=true );
    void AlignRows( Int rowAlign, bool constrain=true );
    void FreeAlignments();
    void SetRoot( Int root, bool constrain=true );
    void AlignWith( const elem::DistData& data, bool constrain=true );
    virtual void AlignColsWith
    ( const elem::DistData& data, bool constrain=true );
    virtual void AlignRowsWith
    ( const elem::DistData& data, bool constrain=true );
    void AlignAndResize
    ( Int colAlign, Int rowAlign, Int height, Int width,
      bool force=false, bool constrain=true );
    void AlignColsAndResize
    ( Int colAlign, Int height, Int width,
      bool force=false, bool constrain=true );
    void AlignRowsAndResize
    ( Int rowAlign, Int height, Int width,
      bool force=false, bool constrain=true );

    // Buffer attachment
    // -----------------
    // Make this matrix a view of existing data. Attach takes mutable data
    // (T* buffer or Matrix<T>&); LockedAttach takes const data (const T* or
    // const Matrix<T>&). Presumably the LockedAttach variants produce a
    // read-only view that Locked() reports -- confirm in the implementation.
    void Attach
    ( Int height, Int width, const elem::Grid& grid,
      Int colAlign, Int rowAlign, T* buffer, Int ldim, Int root=0 );
    void LockedAttach
    ( Int height, Int width, const elem::Grid& grid,
      Int colAlign, Int rowAlign, const T* buffer, Int ldim, Int root=0 );
    void Attach
    ( Int height, Int width, const elem::Grid& grid,
      Int colAlign, Int rowAlign, elem::Matrix<T>& A, Int root=0 );
    void LockedAttach
    ( Int height, Int width, const elem::Grid& grid,
      Int colAlign, Int rowAlign, const elem::Matrix<T>& A, Int root=0 );

    // Basic queries
    // =============

    // Global matrix information
    // -------------------------
    Int Height() const;
    Int Width() const;
    Int DiagonalLength( Int offset=0 ) const;
    bool Viewing() const;
    bool Locked()  const;

    // Local matrix information
    // ------------------------
    // Matrix() and Buffer() are the non-const (mutable) accessors;
    // LockedMatrix() and LockedBuffer() are their const counterparts.
    Int LocalHeight() const;
    Int LocalWidth() const;
    Int LDim() const;
          elem::Matrix<T>& Matrix();
    const elem::Matrix<T>& LockedMatrix() const;
    size_t AllocatedMemory() const;
          T* Buffer();
          T* Buffer( Int iLoc, Int jLoc );
    const T* LockedBuffer() const;
    const T* LockedBuffer( Int iLoc, Int jLoc ) const;

    // Distribution information
    // ------------------------
    const elem::Grid& Grid() const;
    bool ColConstrained() const;
    bool RowConstrained() const;
    bool RootConstrained() const;
    Int ColAlign() const;
    Int RowAlign() const;
    Int ColShift() const;
    Int RowShift() const;
    Int ColRank() const;
    Int RowRank() const;
    Int PartialColRank() const;
    Int PartialRowRank() const;
    Int PartialUnionColRank() const;
    Int PartialUnionRowRank() const;
    Int DistRank() const;
    Int CrossRank() const;
    Int RedundantRank() const;
    Int Root() const;
    bool Participating() const;
    Int RowOwner( Int i ) const;     // rank in ColComm
    Int ColOwner( Int j ) const;     // rank in RowComm
    Int Owner( Int i, Int j ) const; // rank in DistComm
    Int LocalRow( Int i ) const; // debug throws if row i is not locally owned
    Int LocalCol( Int j ) const; // debug throws if col j is not locally owned
    Int LocalRowOffset( Int i ) const; // number of local rows before row i
    Int LocalColOffset( Int j ) const; // number of local cols before col j
    Int GlobalRow( Int iLoc ) const;
    Int GlobalCol( Int jLoc ) const;
    bool IsLocalRow( Int i ) const;
    bool IsLocalCol( Int j ) const;
    bool IsLocal( Int i, Int j ) const;
    // Distribution-specific queries
    // ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    // Functions marked `= 0` are pure virtual and must be overridden by every
    // concrete distribution. The Partial*Comm and Partial*Stride functions
    // are virtual but NOT pure: they have a base-class definition (not in
    // this header) which derived classes may override if needed.
    virtual elem::DistData DistData() const = 0;
    virtual mpi::Comm DistComm() const = 0;
    virtual mpi::Comm CrossComm() const = 0;
    virtual mpi::Comm RedundantComm() const = 0;
    virtual mpi::Comm ColComm() const = 0;
    virtual mpi::Comm RowComm() const = 0;
    virtual mpi::Comm PartialColComm() const;
    virtual mpi::Comm PartialRowComm() const;
    virtual mpi::Comm PartialUnionColComm() const;
    virtual mpi::Comm PartialUnionRowComm() const;
    virtual Int ColStride() const = 0;
    virtual Int RowStride() const = 0;
    virtual Int PartialColStride() const;
    virtual Int PartialRowStride() const;
    virtual Int PartialUnionColStride() const;
    virtual Int PartialUnionRowStride() const;
    virtual Int DistSize() const = 0;
    virtual Int CrossSize() const = 0;
    virtual Int RedundantSize() const = 0;

    // Single-entry manipulation
    // =========================
    // The *RealPart / *ImagPart variants exchange values of type Base<T>,
    // presumably the underlying real type of T (e.g. the real type of a
    // complex T) -- confirm against the Base<> definition.

    // Global entry manipulation
    // -------------------------
    // NOTE: Local entry manipulation is often much faster and should be
    //       preferred in most circumstances where performance matters.
    T Get( Int i, Int j ) const;
    Base<T> GetRealPart( Int i, Int j ) const;
    Base<T> GetImagPart( Int i, Int j ) const;
    void Set( Int i, Int j, T alpha );
    void SetRealPart( Int i, Int j, Base<T> alpha );
    void SetImagPart( Int i, Int j, Base<T> alpha );
    void Update( Int i, Int j, T alpha );
    void UpdateRealPart( Int i, Int j, Base<T> alpha );
    void UpdateImagPart( Int i, Int j, Base<T> alpha );
    void MakeReal( Int i, Int j );
    void Conjugate( Int i, Int j );

    // Local entry manipulation
    // ------------------------
    // (iLoc, jLoc) index the locally stored entries; see LocalRow/LocalCol
    // and GlobalRow/GlobalCol for converting between local and global indices.
    T GetLocal( Int iLoc, Int jLoc ) const;
    Base<T> GetLocalRealPart( Int iLoc, Int jLoc ) const;
    Base<T> GetLocalImagPart( Int iLoc, Int jLoc ) const;
    void SetLocal( Int iLoc, Int jLoc, T alpha );
    void SetLocalRealPart( Int iLoc, Int jLoc, Base<T> alpha );
    void SetLocalImagPart( Int iLoc, Int jLoc, Base<T> alpha );
    void UpdateLocal( Int iLoc, Int jLoc, T alpha );
    void UpdateLocalRealPart( Int iLoc, Int jLoc, Base<T> alpha );
    void UpdateLocalImagPart( Int iLoc, Int jLoc, Base<T> alpha );
    void MakeLocalReal( Int iLoc, Int jLoc );
    void ConjugateLocal( Int iLoc, Int jLoc );

    // Diagonal manipulation
    // =====================
    void MakeDiagonalReal( Int offset=0 );
    void ConjugateDiagonal( Int offset=0 );

    // Arbitrary-submatrix manipulation
    // ================================

    // Global submatrix manipulation
    // -----------------------------
    // rowInd/colInd are lists of global indices. The submatrix is exchanged
    // as a DistMatrix<.,STAR,STAR> (presumably replicated on every process).
    // Each getter comes in an out-parameter form and a return-by-value form.
    void GetSubmatrix
    ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd,
      DistMatrix<T,STAR,STAR>& ASub ) const;
    void GetRealPartOfSubmatrix
    ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd,
      DistMatrix<Base<T>,STAR,STAR>& ASub ) const;
    void GetImagPartOfSubmatrix
    ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd,
      DistMatrix<Base<T>,STAR,STAR>& ASub ) const;

    DistMatrix<T,STAR,STAR> GetSubmatrix
    ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd ) const;
    DistMatrix<Base<T>,STAR,STAR> GetRealPartOfSubmatrix
    ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd ) const;
    DistMatrix<Base<T>,STAR,STAR> GetImagPartOfSubmatrix
    ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd ) const;

    void SetSubmatrix
    ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd,
      const DistMatrix<T,STAR,STAR>& ASub );
    void SetRealPartOfSubmatrix
    ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd,
      const DistMatrix<Base<T>,STAR,STAR>& ASub );
    void SetImagPartOfSubmatrix
    ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd,
      const DistMatrix<Base<T>,STAR,STAR>& ASub );

    void UpdateSubmatrix
    ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd,
      T alpha, const DistMatrix<T,STAR,STAR>& ASub );
    void UpdateRealPartOfSubmatrix
    ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd,
      Base<T> alpha, const DistMatrix<Base<T>,STAR,STAR>& ASub );
    void UpdateImagPartOfSubmatrix
    ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd,
      Base<T> alpha, const DistMatrix<Base<T>,STAR,STAR>& ASub );

    void MakeSubmatrixReal
    ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd );
    void ConjugateSubmatrix
    ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd );

    // Local submatrix manipulation
    // ----------------------------
    // Same operations as above, but rowIndLoc/colIndLoc are local indices
    // and the submatrix is a plain elem::Matrix.
    void GetLocalSubmatrix
    ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc,
      elem::Matrix<T>& ASub ) const;
    void GetRealPartOfLocalSubmatrix
    ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc,
      elem::Matrix<Base<T>>& ASub ) const;
    void GetImagPartOfLocalSubmatrix
    ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc,
      elem::Matrix<Base<T>>& ASub ) const;

    elem::Matrix<T> GetLocalSubmatrix
    ( const std::vector<Int>& rowIndLoc,
      const std::vector<Int>& colIndLoc ) const;
    elem::Matrix<Base<T>> GetRealPartOfLocalSubmatrix
    ( const std::vector<Int>& rowIndLoc,
      const std::vector<Int>& colIndLoc ) const;
    elem::Matrix<Base<T>> GetImagPartOfLocalSubmatrix
    ( const std::vector<Int>& rowIndLoc,
      const std::vector<Int>& colIndLoc ) const;

    void SetLocalSubmatrix
    ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc,
      const elem::Matrix<T>& ASub );
    void SetRealPartOfLocalSubmatrix
    ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc,
      const elem::Matrix<Base<T>>& ASub );
    void SetImagPartOfLocalSubmatrix
    ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc,
      const elem::Matrix<Base<T>>& ASub );

    void UpdateLocalSubmatrix
    ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc,
      T alpha, const elem::Matrix<T>& ASub );
    void UpdateRealPartOfLocalSubmatrix
    ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc,
      Base<T> alpha, const elem::Matrix<Base<T>>& ASub );
    void UpdateImagPartOfLocalSubmatrix
    ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc,
      Base<T> alpha, const elem::Matrix<Base<T>>& ASub );

    void MakeLocalSubmatrixReal
    ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc );
    void ConjugateLocalSubmatrix
    ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc );

    // Sum over a specified communicator
    // =================================
    // Presumably sums the locally stored data across the processes of `comm`
    // -- confirm the exact semantics in the implementation.
    void SumOver( mpi::Comm comm );

    // Assertions
    // ==========
    // Consistency checks on this matrix; presumably they throw on failure --
    // confirm in the implementation.
    void ComplainIfReal() const;
    void AssertNotLocked() const;
    void AssertNotStoringData() const;
    void AssertValidEntry( Int i, Int j ) const;
    void AssertValidSubmatrix( Int i, Int j, Int height, Int width ) const;
    void AssertSameGrid( const elem::Grid& grid ) const;
    void AssertSameSize( Int height, Int width ) const;

protected:
    // Member variables
    // ================
    // NOTE: declaration order here fixes the initialization order used by
    // constructor mem-initializer lists; do not reorder casually.

    // Global and local matrix information
    // -----------------------------------
    // viewType_ presumably records whether the data is owned, viewed, or
    // viewed read-only (cf. Viewing()/Locked()) -- confirm against ViewType.
    ViewType viewType_;
    Int height_, width_;
    // Presumably auxiliary/scratch storage -- confirm its use in the .cpp.
    Memory<T> auxMemory_;
    elem::Matrix<T> matrix_;

    // Process grid and distribution metadata
    // --------------------------------------
    bool colConstrained_, rowConstrained_, rootConstrained_;
    Int colAlign_, rowAlign_,
        colShift_, rowShift_;
    Int root_;
    // Raw pointer to the process grid; presumably non-owning, so the Grid
    // must outlive this matrix -- confirm.
    const elem::Grid* grid_;

    // Protected constructors
    // ======================
    // Create a 0 x 0 distributed matrix
    // Protected (not private), so derived classes and friends can use it.
    // NOTE(review): it is not `explicit`, so where accessible it also acts as
    // an implicit conversion from elem::Grid.
    AbstractDistMatrix( const elem::Grid& g=DefaultGrid(), Int root=0 );

    // Exchange metadata with another matrix
    // =====================================
    // Virtual so derived classes can extend the swap with their own state.
    virtual void ShallowSwap( type& A );

    // Modify the distribution metadata
    // ================================
    void SetShifts();
    void SetColShift();
    void SetRowShift();
    void SetGrid();

    // Friend declarations
    // ===================
    // Each friend template takes an element type S and two Dist parameters.
    // Befriending every instantiation lets them reach the protected members
    // above, including those of an AbstractDistMatrix whose element type
    // differs from S.
    template<typename S,Dist J,Dist K> friend class GeneralDistMatrix;
    template<typename S,Dist J,Dist K> friend class DistMatrix;

    template<typename S,Dist J,Dist K> friend class GeneralBlockDistMatrix;
    template<typename S,Dist J,Dist K> friend class BlockDistMatrix;
};
349 
350 template<typename T>
351 void AssertConforming1x2
352 ( const AbstractDistMatrix<T>& AL, const AbstractDistMatrix<T>& AR );
353 
354 template<typename T>
355 void AssertConforming2x1
356 ( const AbstractDistMatrix<T>& AT, const AbstractDistMatrix<T>& AB );
357 
358 template<typename T>
359 void AssertConforming2x2
360 ( const AbstractDistMatrix<T>& ATL, const AbstractDistMatrix<T>& ATR,
361   const AbstractDistMatrix<T>& ABL, const AbstractDistMatrix<T>& ABR );
362 
363 } // namespace elem
364 
365 #endif // ifndef ELEM_DISTMATRIX_ABSTRACT_DECL_HPP
366