1 /*
2    Copyright (c) 1992-2008 The University of Tennessee.
3    All rights reserved.
4 
5    Copyright (c) 2009-2014, Jack Poulson
6    All rights reserved.
7 
8    This file is loosely based upon the LAPACK routines dlarfg.f and zlarfg.f.
9 
10    This file is part of Elemental and is under the BSD 2-Clause License,
11    which can be found in the LICENSE file in the root directory, or at
12    http://opensource.org/licenses/BSD-2-Clause
13 */
14 #pragma once
15 #ifndef ELEM_REFLECTOR_ROW_HPP
16 #define ELEM_REFLECTOR_ROW_HPP
17 
18 #include ELEM_NRM2_INC
19 #include ELEM_SCALE_INC
20 
21 namespace elem {
22 namespace reflector {
23 
24 template<typename F,Dist U,Dist V>
25 inline F
Row(DistMatrix<F,U,V> & chi,DistMatrix<F,U,V> & x)26 Row( DistMatrix<F,U,V>& chi, DistMatrix<F,U,V>& x )
27 {
28     DEBUG_ONLY(
29         CallStackEntry cse("reflector::Row");
30         if( chi.Grid() != x.Grid() )
31             LogicError("chi and x must be distributed over the same grid");
32         if( chi.Height() != 1 || chi.Width() != 1 )
33             LogicError("chi must be a scalar");
34         if( x.Height() != 1 )
35             LogicError("x must be a row vector");
36         if( chi.ColRank() != chi.ColAlign() || x.ColRank() != x.ColAlign() )
37             LogicError("Reflecting from incorrect process");
38     )
39     typedef Base<F> Real;
40     mpi::Comm rowComm = x.RowComm();
41     const Int rowRank = x.RowRank();
42     const Int rowStride = x.RowStride();
43     const Int rowAlign = chi.RowAlign();
44 
45     std::vector<Real> localNorms(rowStride);
46     Real localNorm = Nrm2( x.LockedMatrix() );
47     mpi::AllGather( &localNorm, 1, localNorms.data(), 1, rowComm );
48     Real norm = blas::Nrm2( rowStride, localNorms.data(), 1 );
49 
50     F alpha;
51     if( rowRank == rowAlign )
52         alpha = chi.GetLocal(0,0);
53     mpi::Broadcast( alpha, rowAlign, rowComm );
54 
55     if( norm == Real(0) && ImagPart(alpha) == Real(0) )
56     {
57         if( rowRank == rowAlign )
58             chi.SetLocal(0,0,-chi.GetLocal(0,0));
59         return F(2);
60     }
61 
62     Real beta;
63     if( RealPart(alpha) <= 0 )
64         beta = lapack::SafeNorm( alpha, norm );
65     else
66         beta = -lapack::SafeNorm( alpha, norm );
67 
68     // Rescale if the vector is too small
69     const Real safeMin = lapack::MachineSafeMin<Real>();
70     const Real epsilon = lapack::MachineEpsilon<Real>();
71     const Real safeInv = safeMin/epsilon;
72     Int count = 0;
73     if( Abs(beta) < safeInv )
74     {
75         Real invOfSafeInv = Real(1)/safeInv;
76         do
77         {
78             ++count;
79             Scale( invOfSafeInv, x );
80             alpha *= invOfSafeInv;
81             beta *= invOfSafeInv;
82         } while( Abs(beta) < safeInv );
83 
84         localNorm = Nrm2( x.LockedMatrix() );
85         mpi::AllGather( &localNorm, 1, localNorms.data(), 1, rowComm );
86         norm = blas::Nrm2( rowStride, localNorms.data(), 1 );
87         if( RealPart(alpha) <= 0 )
88             beta = lapack::SafeNorm( alpha, norm );
89         else
90             beta = -lapack::SafeNorm( alpha, norm );
91     }
92 
93     F tau = (beta-Conj(alpha)) / beta;
94     Scale( Real(1)/(alpha-beta), x );
95 
96     // Undo the scaling
97     for( Int j=0; j<count; ++j )
98         beta *= safeInv;
99 
100     if( rowRank == rowAlign )
101         chi.SetLocal(0,0,beta);
102 
103     // This is to make this a reflector meant to be applied from the right;
104     // there is no need to conjugate chi, as it is real
105     Conjugate( x );
106 
107     return tau;
108 }
109 
110 template<typename F,Dist U,Dist V>
111 inline F
Row(F & chi,DistMatrix<F,U,V> & x)112 Row( F& chi, DistMatrix<F,U,V>& x )
113 {
114     DEBUG_ONLY(
115         CallStackEntry cse("reflector::Row");
116         if( x.Height() != 1 )
117             LogicError("x must be a row vector");
118         if( x.ColRank() != x.ColAlign() )
119             LogicError("Reflecting from incorrect process");
120     )
121     typedef Base<F> Real;
122     mpi::Comm rowComm = x.RowComm();
123     const Int rowStride = x.RowStride();
124 
125     std::vector<Real> localNorms(rowStride);
126     Real localNorm = Nrm2( x.LockedMatrix() );
127     mpi::AllGather( &localNorm, 1, localNorms.data(), 1, rowComm );
128     Real norm = blas::Nrm2( rowStride, localNorms.data(), 1 );
129 
130     F alpha = chi;
131     if( norm == Real(0) && ImagPart(alpha) == Real(0) )
132     {
133         chi = -chi;
134         return F(2);
135     }
136 
137     Real beta;
138     if( RealPart(alpha) <= 0 )
139         beta = lapack::SafeNorm( alpha, norm );
140     else
141         beta = -lapack::SafeNorm( alpha, norm );
142 
143     // Rescale if the vector is too small
144     const Real safeMin = lapack::MachineSafeMin<Real>();
145     const Real epsilon = lapack::MachineEpsilon<Real>();
146     const Real safeInv = safeMin/epsilon;
147     Int count = 0;
148     if( Abs(beta) < safeInv )
149     {
150         Real invOfSafeInv = Real(1)/safeInv;
151         do
152         {
153             ++count;
154             Scale( invOfSafeInv, x );
155             alpha *= invOfSafeInv;
156             beta *= invOfSafeInv;
157         } while( Abs(beta) < safeInv );
158 
159         localNorm = Nrm2( x.LockedMatrix() );
160         mpi::AllGather( &localNorm, 1, localNorms.data(), 1, rowComm );
161         norm = blas::Nrm2( rowStride, localNorms.data(), 1 );
162         if( RealPart(alpha) <= 0 )
163             beta = lapack::SafeNorm( alpha, norm );
164         else
165             beta = -lapack::SafeNorm( alpha, norm );
166     }
167 
168     F tau = (beta-Conj(alpha)) / beta;
169     Scale( Real(1)/(alpha-beta), x );
170 
171     // Undo the scaling
172     for( Int j=0; j<count; ++j )
173         beta *= safeInv;
174 
175     chi = beta;
176 
177     // This is to make this a reflector meant to be applied from the right;
178     // there is no need to conjugate chi, as it is real
179     Conjugate( x );
180 
181     return tau;
182 }
183 
184 } // namespace reflector
185 } // namespace elem
186 
187 #endif // ifndef ELEM_REFLECTOR_ROW_HPP
188