1 /*
2    Copyright (c) 2009-2014, Jack Poulson
3    All rights reserved.
4 
5    This file is part of Elemental and is under the BSD 2-Clause License,
6    which can be found in the LICENSE file in the root directory, or at
7    http://opensource.org/licenses/BSD-2-Clause
8 */
9 #pragma once
10 #ifndef ELEM_BLOCKDISTMATRIX_ABSTRACT_DECL_HPP
11 #define ELEM_BLOCKDISTMATRIX_ABSTRACT_DECL_HPP
12 
13 namespace elem {
14 
// Abstract base class template for block-distributed matrices.
//
// It declares pure virtual members (DistData, DistComm, CrossComm,
// RedundantComm, ColComm, RowComm, ColStride, RowStride, DistSize,
// CrossSize, RedundantSize), so it cannot be instantiated directly and is
// only usable through derived classes. NOTE(review): the befriended
// templates at the bottom of the class (GeneralBlockDistMatrix,
// BlockDistMatrix) are presumably the derived block-distribution classes;
// confirm against their declarations.
//
// Only declarations appear here; all member functions are defined
// out of line elsewhere.
template<typename T>
class AbstractBlockDistMatrix
{
public:
    // Typedefs
    // ========
    typedef AbstractBlockDistMatrix<T> type;

    // Constructors and destructors
    // ============================
    // Move constructor
    // NOTE: because a move constructor is user-declared, the implicit copy
    // constructor and copy assignment operator are defined as deleted, so
    // this class itself is move-only.
    AbstractBlockDistMatrix( type&& A ) ELEM_NOEXCEPT;

    // Virtual so that a derived matrix deleted through a pointer to this
    // base is destroyed correctly.
    virtual ~AbstractBlockDistMatrix();

    // Assignment and reconfiguration
    // ==============================
    // Move assignment
    // NOTE(review): unlike the move constructor, this is not marked
    // ELEM_NOEXCEPT; adding it would also require updating the out-of-line
    // definition.
    type& operator=( type&& A );

    void Empty();
    void EmptyData();
    void SetGrid( const elem::Grid& grid );
    // The second overload additionally takes an explicit `ldim`
    // (cf. LDim() under "Local matrix information").
    void Resize( Int height, Int width );
    void Resize( Int height, Int width, Int ldim );
    // NOTE(review): presumably these synchronize (all / only size) metadata
    // across processes, optionally including viewing processes; confirm
    // against the definitions.
    void MakeConsistent( bool includingViewers=false );
    void MakeSizeConsistent( bool includingViewers=false );

    // Realignment
    // -----------
    // NOTE(review): `constrain` presumably records the new alignment/root as
    // fixed (cf. ColConstrained(), RowConstrained(), RootConstrained());
    // confirm against the definitions.
    void Align
    ( Int blockHeight, Int blockWidth,
      Int colAlign, Int rowAlign, Int colCut=0, Int rowCut=0,
      bool constrain=true );
    void AlignCols
    ( Int blockHeight, Int colAlign, Int colCut=0, bool constrain=true );
    void AlignRows
    ( Int blockWidth, Int rowAlign, Int rowCut=0, bool constrain=true );
    void FreeAlignments();
    void SetRoot( Int root, bool constrain=true );
    void AlignWith( const elem::BlockDistData& data, bool constrain=true );
    // These two are virtual AND carry a default argument. Default arguments
    // are bound from the static type at the call site, not from the
    // overrider, so any override should repeat the same default
    // (constrain=true) to avoid surprising behavior.
    virtual void AlignColsWith
    ( const elem::BlockDistData& data, bool constrain=true );
    virtual void AlignRowsWith
    ( const elem::BlockDistData& data, bool constrain=true );
    // TODO: The interface for these routines could be improved
    // Unlike Align/AlignCols/AlignRows above, the cuts have no defaults here;
    // only `force` (false) and `constrain` (true) are defaulted.
    void AlignAndResize
    ( Int blockHeight, Int blockWidth,
      Int colAlign, Int rowAlign, Int colCut, Int rowCut,
      Int height, Int width, bool force=false, bool constrain=true );
    void AlignColsAndResize
    ( Int blockHeight, Int colAlign, Int colCut, Int height, Int width,
      bool force=false, bool constrain=true );
    void AlignRowsAndResize
    ( Int blockWidth, Int rowAlign, Int rowCut, Int height, Int width,
      bool force=false, bool constrain=true );

    // Buffer attachment
    // -----------------
    // Views of existing local data: either a raw buffer (with leading
    // dimension `ldim`) or a local elem::Matrix<T>. Attach takes mutable
    // data; LockedAttach takes const data and yields an immutable (locked)
    // view (cf. Locked()).
    void Attach
    ( Int height, Int width, const elem::Grid& g,
      Int blockHeight, Int blockWidth,
      Int colAlign, Int rowAlign, Int colCut, Int rowCut,
      T* buffer, Int ldim, Int root=0 );
    void LockedAttach
    ( Int height, Int width, const elem::Grid& g,
      Int blockHeight, Int blockWidth,
      Int colAlign, Int rowAlign, Int colCut, Int rowCut,
      const T* buffer, Int ldim, Int root=0 );
    void Attach
    ( Int height, Int width, const elem::Grid& g,
      Int blockHeight, Int blockWidth,
      Int colAlign, Int rowAlign, Int colCut, Int rowCut,
      elem::Matrix<T>& A, Int root=0 );
    void LockedAttach
    ( Int height, Int width, const elem::Grid& g,
      Int blockHeight, Int blockWidth,
      Int colAlign, Int rowAlign, Int colCut, Int rowCut,
      const elem::Matrix<T>& A, Int root=0 );

    // Basic queries
    // =============

    // Global matrix information
    // -------------------------
    Int Height() const;
    Int Width() const;
    Int DiagonalLength( Int offset=0 ) const;
    bool Viewing() const;
    bool Locked()  const;

    // Local matrix information
    // ------------------------
    // Matrix()/Buffer() return mutable access; LockedMatrix()/LockedBuffer()
    // return const access. The (iLoc, jLoc) overloads take local (not
    // global) indices.
    Int LocalHeight() const;
    Int LocalWidth() const;
    Int LDim() const;
          elem::Matrix<T>& Matrix();
    const elem::Matrix<T>& LockedMatrix() const;
    size_t AllocatedMemory() const;
          T* Buffer();
          T* Buffer( Int iLoc, Int jLoc );
    const T* LockedBuffer() const;
    const T* LockedBuffer( Int iLoc, Int jLoc ) const;

    // Distribution information
    // ------------------------
    const elem::Grid& Grid() const;
    bool ColConstrained() const;
    bool RowConstrained() const;
    bool RootConstrained() const;
    Int BlockHeight() const;
    Int BlockWidth() const;
    Int ColAlign() const;
    Int RowAlign() const;
    Int ColCut() const;
    Int RowCut() const;
    Int ColShift() const;
    Int RowShift() const;
    Int ColRank() const;
    Int RowRank() const;
    Int PartialColRank() const;
    Int PartialRowRank() const;
    Int PartialUnionColRank() const;
    Int PartialUnionRowRank() const;
    Int DistRank() const;
    Int CrossRank() const;
    Int RedundantRank() const;
    Int Root() const;
    bool Participating() const;
    Int RowOwner( Int i ) const;     // rank in ColComm
    Int ColOwner( Int j ) const;     // rank in RowComm
    Int Owner( Int i, Int j ) const; // rank in DistComm
    Int LocalRow( Int i ) const; // debug throws if row i is not locally owned
    Int LocalCol( Int j ) const; // debug throws if col j is not locally owned
    Int LocalRowOffset( Int i ) const; // number of local rows before row i
    Int LocalColOffset( Int j ) const; // number of local cols before col j
    Int GlobalRow( Int iLoc ) const;
    Int GlobalCol( Int jLoc ) const;
    bool IsLocalRow( Int i ) const;
    bool IsLocalCol( Int j ) const;
    bool IsLocal( Int i, Int j ) const;
    // Virtual distribution queries
    // ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    // The pure virtual (`= 0`) members below must be overridden by every
    // concrete derived class. The Partial*Comm and Partial*Stride members
    // are virtual but not pure, so this class provides definitions for them
    // and derived classes may override them only when needed.
    virtual elem::BlockDistData DistData() const = 0;
    virtual mpi::Comm DistComm() const = 0;
    virtual mpi::Comm CrossComm() const = 0;
    virtual mpi::Comm RedundantComm() const = 0;
    virtual mpi::Comm ColComm() const = 0;
    virtual mpi::Comm RowComm() const = 0;
    virtual mpi::Comm PartialColComm() const;
    virtual mpi::Comm PartialRowComm() const;
    virtual mpi::Comm PartialUnionColComm() const;
    virtual mpi::Comm PartialUnionRowComm() const;
    virtual Int ColStride() const = 0;
    virtual Int RowStride() const = 0;
    virtual Int PartialColStride() const;
    virtual Int PartialRowStride() const;
    virtual Int PartialUnionColStride() const;
    virtual Int PartialUnionRowStride() const;
    virtual Int DistSize() const = 0;
    virtual Int CrossSize() const = 0;
    virtual Int RedundantSize() const = 0;

    // Single-entry manipulation
    // =========================

    // Global entry manipulation
    // -------------------------
    // NOTE: Local entry manipulation is often much faster and should be
    //       preferred in most circumstances where performance matters.
    // The *RealPart/*ImagPart variants traffic in Base<T> (presumably the
    // underlying real type of T; confirm).
    T Get( Int i, Int j ) const;
    Base<T> GetRealPart( Int i, Int j ) const;
    Base<T> GetImagPart( Int i, Int j ) const;
    void Set( Int i, Int j, T alpha );
    void SetRealPart( Int i, Int j, Base<T> alpha );
    void SetImagPart( Int i, Int j, Base<T> alpha );
    void Update( Int i, Int j, T alpha );
    void UpdateRealPart( Int i, Int j, Base<T> alpha );
    void UpdateImagPart( Int i, Int j, Base<T> alpha );
    void MakeReal( Int i, Int j );
    void Conjugate( Int i, Int j );

    // Local entry manipulation
    // ------------------------
    // Same operations as above, addressed by local indices (iLoc, jLoc).
    T GetLocal( Int iLoc, Int jLoc ) const;
    Base<T> GetLocalRealPart( Int iLoc, Int jLoc ) const;
    Base<T> GetLocalImagPart( Int iLoc, Int jLoc ) const;
    void SetLocal( Int iLoc, Int jLoc, T alpha );
    void SetLocalRealPart( Int iLoc, Int jLoc, Base<T> alpha );
    void SetLocalImagPart( Int iLoc, Int jLoc, Base<T> alpha );
    void UpdateLocal( Int iLoc, Int jLoc, T alpha );
    void UpdateLocalRealPart( Int iLoc, Int jLoc, Base<T> alpha );
    void UpdateLocalImagPart( Int iLoc, Int jLoc, Base<T> alpha );
    void MakeLocalReal( Int iLoc, Int jLoc );
    void ConjugateLocal( Int iLoc, Int jLoc );

    // Diagonal manipulation
    // =====================
    // NOTE(review): `offset` presumably selects the diagonal the same way as
    // in DiagonalLength( offset ); confirm.
    void MakeDiagonalReal( Int offset=0 );
    void ConjugateDiagonal( Int offset=0 );

    // Arbitrary-submatrix manipulation
    // ================================

    // Global submatrix manipulation
    // -----------------------------
    // `rowInd`/`colInd` list global row/column indices. Each getter comes in
    // an out-parameter form and a form that returns the
    // DistMatrix<...,STAR,STAR> by value.
    void GetSubmatrix
    ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd,
      DistMatrix<T,STAR,STAR>& ASub ) const;
    void GetRealPartOfSubmatrix
    ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd,
      DistMatrix<Base<T>,STAR,STAR>& ASub ) const;
    void GetImagPartOfSubmatrix
    ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd,
      DistMatrix<Base<T>,STAR,STAR>& ASub ) const;
    DistMatrix<T,STAR,STAR> GetSubmatrix
    ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd ) const;
    DistMatrix<Base<T>,STAR,STAR> GetRealPartOfSubmatrix
    ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd ) const;
    DistMatrix<Base<T>,STAR,STAR> GetImagPartOfSubmatrix
    ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd ) const;

    void SetSubmatrix
    ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd,
      const DistMatrix<T,STAR,STAR>& ASub );
    void SetRealPartOfSubmatrix
    ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd,
      const DistMatrix<Base<T>,STAR,STAR>& ASub );
    void SetImagPartOfSubmatrix
    ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd,
      const DistMatrix<Base<T>,STAR,STAR>& ASub );

    // NOTE(review): `alpha` presumably scales ASub before it is added into
    // the selected entries; confirm against the definitions.
    void UpdateSubmatrix
    ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd,
      T alpha, const DistMatrix<T,STAR,STAR>& ASub );
    void UpdateRealPartOfSubmatrix
    ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd,
      Base<T> alpha, const DistMatrix<Base<T>,STAR,STAR>& ASub );
    void UpdateImagPartOfSubmatrix
    ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd,
      Base<T> alpha, const DistMatrix<Base<T>,STAR,STAR>& ASub );

    void MakeSubmatrixReal
    ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd );
    void ConjugateSubmatrix
    ( const std::vector<Int>& rowInd, const std::vector<Int>& colInd );

    // Local submatrix manipulation
    // ----------------------------
    // Same operations as above, but `rowIndLoc`/`colIndLoc` are local
    // indices and ASub is a local elem::Matrix.
    void GetLocalSubmatrix
    ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc,
      elem::Matrix<T>& ASub ) const;
    void GetRealPartOfLocalSubmatrix
    ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc,
      elem::Matrix<Base<T>>& ASub ) const;
    void GetImagPartOfLocalSubmatrix
    ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc,
      elem::Matrix<Base<T>>& ASub ) const;
    elem::Matrix<T> GetLocalSubmatrix
    ( const std::vector<Int>& rowIndLoc,
      const std::vector<Int>& colIndLoc ) const;
    elem::Matrix<Base<T>> GetRealPartOfLocalSubmatrix
    ( const std::vector<Int>& rowIndLoc,
      const std::vector<Int>& colIndLoc ) const;
    elem::Matrix<Base<T>> GetImagPartOfLocalSubmatrix
    ( const std::vector<Int>& rowIndLoc,
      const std::vector<Int>& colIndLoc ) const;

    void SetLocalSubmatrix
    ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc,
      const elem::Matrix<T>& ASub );
    void SetRealPartOfLocalSubmatrix
    ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc,
      const elem::Matrix<Base<T>>& ASub );
    void SetImagPartOfLocalSubmatrix
    ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc,
      const elem::Matrix<Base<T>>& ASub );

    void UpdateLocalSubmatrix
    ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc,
      T alpha, const elem::Matrix<T>& ASub );
    void UpdateRealPartOfLocalSubmatrix
    ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc,
      Base<T> alpha, const elem::Matrix<Base<T>>& ASub );
    void UpdateImagPartOfLocalSubmatrix
    ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc,
      Base<T> alpha, const elem::Matrix<Base<T>>& ASub );

    void MakeLocalSubmatrixReal
    ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc );
    void ConjugateLocalSubmatrix
    ( const std::vector<Int>& rowIndLoc, const std::vector<Int>& colIndLoc );

    // Sum over a specified communicator
    // =================================
    // NOTE(review): presumably sums the local data across the processes of
    // `comm`; confirm against the definition.
    void SumOver( mpi::Comm comm );

    // Assertions
    // ==========
    void ComplainIfReal() const;
    void AssertNotLocked() const;
    void AssertNotStoringData() const;
    void AssertValidEntry( Int i, Int j ) const;
    void AssertValidSubmatrix( Int i, Int j, Int height, Int width ) const;
    void AssertSameGrid( const elem::Grid& grid ) const;
    void AssertSameSize( Int height, Int width ) const;

protected:
    // Member variables
    // ================
    // NOTE: declaration order below determines both the order in which the
    // constructors initialize these members and the object layout; keep it
    // stable.

    // Global and local matrix information
    // -----------------------------------
    ViewType viewType_;
    Int height_, width_;
    Memory<T> auxMemory_;
    elem::Matrix<T> matrix_;

    // Process grid and distribution metadata
    // --------------------------------------
    bool colConstrained_, rowConstrained_, rootConstrained_;
    Int blockHeight_, blockWidth_;
    Int colAlign_, rowAlign_,
        colCut_,   rowCut_,
        colShift_, rowShift_;
    Int root_;
    // Pointer to the process grid (presumably non-owning; confirm).
    const elem::Grid* grid_;

    // Protected constructors
    // ======================
    // Only derived classes and friends can call these. Every parameter of
    // the first constructor has a default, so it also serves as the default
    // constructor.
    AbstractBlockDistMatrix( const elem::Grid& g=DefaultGrid(),  Int root=0 );
    AbstractBlockDistMatrix
    ( const elem::Grid& g, Int blockHeight, Int blockWidth, Int root=0 );

    // Exchange metadata with another matrix
    // =====================================
    virtual void ShallowSwap( type& A );

    // Modify the distribution metadata
    // ================================
    // (SetGrid() here is a protected no-argument overload, distinct from the
    // public SetGrid( const elem::Grid& ).)
    void SetShifts();
    void SetColShift();
    void SetRowShift();
    void SetGrid();

    // Friend declarations
    // ===================
    template<typename S,Dist J,Dist K> friend class GeneralDistMatrix;
    template<typename S,Dist J,Dist K> friend class DistMatrix;
    template<typename S,Dist J,Dist K> friend class GeneralBlockDistMatrix;
    template<typename S,Dist J,Dist K> friend class BlockDistMatrix;
};
368 
369 template<typename T>
370 void AssertConforming1x2
371 ( const AbstractBlockDistMatrix<T>& AL, const AbstractBlockDistMatrix<T>& AR );
372 
373 template<typename T>
374 void AssertConforming2x1
375 ( const AbstractBlockDistMatrix<T>& AT, const AbstractBlockDistMatrix<T>& AB );
376 
377 template<typename T>
378 void AssertConforming2x2
379 ( const AbstractBlockDistMatrix<T>& ATL,
380   const AbstractBlockDistMatrix<T>& ATR,
381   const AbstractBlockDistMatrix<T>& ABL,
382   const AbstractBlockDistMatrix<T>& ABR );
383 
384 } // namespace elem
385 
386 #endif // ifndef ELEM_BLOCKDISTMATRIX_ABSTRACT_DECL_HPP
387