1 /*
2    Copyright (c) 2009-2014, Jack Poulson
3    All rights reserved.
4 
5    This file is part of Elemental and is under the BSD 2-Clause License,
6    which can be found in the LICENSE file in the root directory, or at
7    http://opensource.org/licenses/BSD-2-Clause
8 */
9 #pragma once
10 #ifndef ELEM_HYPERBOLICREFLECTOR_ROW_HPP
11 #define ELEM_HYPERBOLICREFLECTOR_ROW_HPP
12 
13 #include ELEM_NRM2_INC
14 #include ELEM_SCALE_INC
15 #include ELEM_ZERO_INC
16 
17 namespace elem {
18 namespace hyp_reflector {
19 
20 template<typename F,Dist U,Dist V>
21 inline F
Row(DistMatrix<F,U,V> & chi,DistMatrix<F,U,V> & x)22 Row( DistMatrix<F,U,V>& chi, DistMatrix<F,U,V>& x )
23 {
24     DEBUG_ONLY(
25         CallStackEntry cse("hyp_reflector::Row");
26         if( chi.Grid() != x.Grid() )
27             LogicError("chi and x must be distributed over the same grid");
28         if( chi.Height() != 1 || chi.Width() != 1 )
29             LogicError("chi must be a scalar");
30         if( x.Height() != 1 )
31             LogicError("x must be a row vector");
32         if( chi.ColRank() != chi.ColAlign() || x.ColRank() != x.ColAlign() )
33             LogicError("Reflecting from incorrect process");
34     )
35     typedef Base<F> Real;
36     mpi::Comm rowComm = x.RowComm();
37     const Int rowRank = x.RowRank();
38     const Int rowStride = x.RowStride();
39     const Int rowAlign = chi.RowAlign();
40 
41     std::vector<Real> localNorms(rowStride);
42     Real localNorm = Nrm2( x.LockedMatrix() );
43     mpi::AllGather( &localNorm, 1, localNorms.data(), 1, rowComm );
44     Real norm = blas::Nrm2( rowStride, localNorms.data(), 1 );
45 
46     Real alpha;
47     if( rowRank == rowAlign )
48     {
49         if( ImagPart(chi.GetLocal(0,0)) != Real(0) )
50             LogicError("chi is assumed to be real");
51         alpha = chi.GetLocalRealPart(0,0);
52     }
53     mpi::Broadcast( alpha, rowAlign, rowComm );
54     const Real delta = alpha*alpha - norm*norm;
55     if( delta < Real(0) )
56         LogicError("Attempted to square-root a negative number");
57     const Real lambda = ( alpha>=0 ? Sqrt(delta) : -Sqrt(delta) );
58     if( rowRank == rowAlign )
59         chi.SetLocal(0,0,-lambda);
60 
61     const Real kappa = alpha + lambda;
62     if( kappa == Real(0) )
63     {
64         Zero( x );
65         return Real(1);
66     }
67     else
68     {
69         Scale( Real(1)/kappa, x );
70         Conjugate( x );
71         return (delta+alpha*lambda)/(kappa*kappa);
72     }
73 }
74 
75 template<typename F,Dist U,Dist V>
76 inline F
Row(F & chi,DistMatrix<F,U,V> & x)77 Row( F& chi, DistMatrix<F,U,V>& x )
78 {
79     DEBUG_ONLY(
80         CallStackEntry cse("hyp_reflector::Row");
81         if( x.Height() != 1 )
82             LogicError("x must be a row vector");
83         if( x.ColRank() != x.ColAlign() )
84             LogicError("Reflecting from incorrect process");
85         if( ImagPart(chi) != Base<F>(0) )
86             LogicError("chi is assumed to be real");
87     )
88     typedef Base<F> Real;
89     mpi::Comm rowComm = x.RowComm();
90     const Int rowStride = x.RowStride();
91 
92     std::vector<Real> localNorms(rowStride);
93     Real localNorm = Nrm2( x.LockedMatrix() );
94     mpi::AllGather( &localNorm, 1, localNorms.data(), 1, rowComm );
95     Real norm = blas::Nrm2( rowStride, localNorms.data(), 1 );
96 
97     const Real alpha = RealPart(chi);
98     const Real delta = alpha*alpha - norm*norm;
99     if( delta < Real(0) )
100         LogicError("Attempted to square-root a negative number");
101     const Real lambda = ( alpha>=0 ? Sqrt(delta) : -Sqrt(delta) );
102     chi = -lambda;
103 
104     const Real kappa = alpha + lambda;
105     if( kappa == Real(0) )
106     {
107         Zero( x );
108         return Real(1);
109     }
110     else
111     {
112         Scale( Real(1)/kappa, x );
113         Conjugate( x );
114         return (delta+alpha*lambda)/(kappa*kappa);
115     }
116 }
117 
118 } // namespace hyp_reflector
119 } // namespace elem
120 
121 #endif // ifndef ELEM_HYPERBOLICREFLECTOR_ROW_HPP
122