1 /*
2    Copyright (c) 2009-2014, Jack Poulson
3    All rights reserved.
4 
5    This file is part of Elemental and is under the BSD 2-Clause License,
6    which can be found in the LICENSE file in the root directory, or at
7    http://opensource.org/licenses/BSD-2-Clause
8 */
9 #pragma once
10 #ifndef ELEM_LU_PANEL_HPP
11 #define ELEM_LU_PANEL_HPP
12 
13 #include ELEM_SCALE_INC
14 #include ELEM_SWAP_INC
15 #include ELEM_MAXABS_INC
16 #include ELEM_GERU_INC
17 
18 namespace elem {
19 namespace lu {
20 
21 template<typename F>
22 inline void
Panel(Matrix<F> & A,Matrix<Int> & pivots)23 Panel( Matrix<F>& A, Matrix<Int>& pivots )
24 {
25     DEBUG_ONLY(CallStackEntry cse("lu::Panel"))
26     const Int m = A.Height();
27     const Int n = A.Width();
28     DEBUG_ONLY(
29         if( m < n )
30             LogicError("Must be a column panel");
31     )
32     pivots.Resize( n, 1 );
33 
34     for( Int k=0; k<n; ++k )
35     {
36         auto alpha11 = ViewRange( A, k,   k,   k+1, k+1 );
37         auto a12     = ViewRange( A, k,   k+1, k+1, n   );
38         auto a21     = ViewRange( A, k+1, k,   m,   k+1 );
39         auto A22     = ViewRange( A, k+1, k+1, m,   n   );
40 
41         // Find the index and value of the pivot candidate
42         auto pivot = VectorMaxAbs( ViewRange(A,k,k,m,k+1) );
43         const Int iPiv = pivot.index + k;
44         pivots.Set( k, 0, iPiv );
45 
46         // Swap the pivot row and current row
47         if( iPiv != k )
48         {
49             auto aCurRow = ViewRange( A, k,    0, k+1,    n );
50             auto aPivRow = ViewRange( A, iPiv, 0, iPiv+1, n );
51             Swap( NORMAL, aCurRow, aPivRow );
52         }
53 
54         // Now we can perform the update of the current panel
55         const F alpha = alpha11.Get(0,0);
56         if( alpha == F(0) )
57             throw SingularMatrixException();
58         const F alpha11Inv = F(1) / alpha;
59         Scale( alpha11Inv, a21 );
60         Geru( F(-1), a21, a12, A22 );
61     }
62 }
63 
64 template<typename F>
65 inline void
Panel(DistMatrix<F,STAR,STAR> & A,DistMatrix<F,MC,STAR> & B,DistMatrix<Int,STAR,STAR> & pivots)66 Panel
67 ( DistMatrix<F,  STAR,STAR>& A,
68   DistMatrix<F,  MC,  STAR>& B,
69   DistMatrix<Int,STAR,STAR>& pivots )
70 {
71     DEBUG_ONLY(
72         CallStackEntry cse("lu::Panel");
73         if( A.Grid() != pivots.Grid() || pivots.Grid() != B.Grid() )
74             LogicError("Matrices must be distributed over the same grid");
75         if( A.Width() != B.Width() )
76             LogicError("A and B must be the same width");
77     )
78     typedef Base<F> Real;
79 
80     // For packing rows of data for pivoting
81     const Int n = A.Width();
82     const Int mB = B.Height();
83     const Int nB = B.Width();
84     std::vector<F> pivotBuffer( n );
85 
86     pivots.Resize( n, 1 );
87 
88     for( Int k=0; k<n; ++k )
89     {
90         auto alpha11 = ViewRange( A, k,   k,   k+1, k+1 );
91         auto a12     = ViewRange( A, k,   k+1, k+1, n   );
92         auto a21     = ViewRange( A, k+1, k,   n,   k+1 );
93         auto A22     = ViewRange( A, k+1, k+1, n,   n   );
94         auto b1      = ViewRange( B, 0,   k,   mB,  k+1 );
95         auto B2      = ViewRange( B, 0,   k+1, mB,  nB  );
96 
97         // Store the index/value of the local pivot candidate
98         ValueInt<Real> localPivot;
99         localPivot.value = FastAbs(alpha11.GetLocal(0,0));
100         localPivot.index = k;
101         for( Int i=0; i<a21.Height(); ++i )
102         {
103             const Real value = FastAbs(a21.GetLocal(i,0));
104             if( value > localPivot.value )
105             {
106                 localPivot.value = value;
107                 localPivot.index = k + i + 1;
108             }
109         }
110         for( Int iLoc=0; iLoc<B.LocalHeight(); ++iLoc )
111         {
112             const Real value = FastAbs(b1.GetLocal(iLoc,0));
113             if( value > localPivot.value )
114             {
115                 localPivot.value = value;
116                 localPivot.index = n + B.GlobalRow(iLoc);
117             }
118         }
119 
120         // Compute and store the location of the new pivot
121         const ValueInt<Real> pivot =
122             mpi::AllReduce( localPivot, mpi::MaxLocOp<Real>(), B.ColComm() );
123         const Int iPiv = pivot.index;
124         pivots.SetLocal( k, 0, iPiv );
125 
126         // Perform the pivot within this panel
127         if( iPiv < n )
128         {
129             // Pack pivot into temporary
130             for( Int j=0; j<n; ++j )
131                 pivotBuffer[j] = A.GetLocal( iPiv, j );
132             // Replace pivot with current
133             for( Int j=0; j<n; ++j )
134                 A.SetLocal( iPiv, j, A.GetLocal(k,j) );
135         }
136         else
137         {
138             // The owning row of the pivot row packs it into the row buffer
139             // and then overwrites with the current row
140             const Int relIndex = iPiv - n;
141             const Int ownerRow = B.RowOwner(relIndex);
142             if( B.IsLocalRow(relIndex) )
143             {
144                 const Int iLoc = B.LocalRow(relIndex);
145                 for( Int j=0; j<n; ++j )
146                     pivotBuffer[j] = B.GetLocal( iLoc, j );
147                 for( Int j=0; j<n; ++j )
148                     B.SetLocal( iLoc, j, A.GetLocal(k,j) );
149             }
150             // The owning row broadcasts within process columns
151             mpi::Broadcast( pivotBuffer.data(), n, ownerRow, B.ColComm() );
152         }
153         // Overwrite the current row with the pivot row
154         for( Int j=0; j<n; ++j )
155             A.SetLocal( k, j, pivotBuffer[j] );
156 
157         // Now we can perform the update of the current panel
158         const F alpha = alpha11.GetLocal(0,0);
159         if( alpha == F(0) )
160             throw SingularMatrixException();
161         const F alpha11Inv = F(1) / alpha;
162         Scale( alpha11Inv, a21 );
163         Scale( alpha11Inv, b1  );
164         Geru( F(-1), a21.Matrix(), a12.Matrix(), A22.Matrix() );
165         Geru( F(-1), b1.Matrix(), a12.Matrix(), B2.Matrix() );
166     }
167 }
168 
169 } // namespace lu
170 } // namespace elem
171 
172 #endif // ifndef ELEM_LU_PANEL_HPP
173