1 /*
2 Copyright (c) 2009-2014, Jack Poulson
3 All rights reserved.
4
5 This file is part of Elemental and is under the BSD 2-Clause License,
6 which can be found in the LICENSE file in the root directory, or at
7 http://opensource.org/licenses/BSD-2-Clause
8 */
9 #pragma once
10 #ifndef ELEM_HYPERBOLICREFLECTOR_ROW_HPP
11 #define ELEM_HYPERBOLICREFLECTOR_ROW_HPP
12
13 #include ELEM_NRM2_INC
14 #include ELEM_SCALE_INC
15 #include ELEM_ZERO_INC
16
17 namespace elem {
18 namespace hyp_reflector {
19
20 template<typename F,Dist U,Dist V>
21 inline F
Row(DistMatrix<F,U,V> & chi,DistMatrix<F,U,V> & x)22 Row( DistMatrix<F,U,V>& chi, DistMatrix<F,U,V>& x )
23 {
24 DEBUG_ONLY(
25 CallStackEntry cse("hyp_reflector::Row");
26 if( chi.Grid() != x.Grid() )
27 LogicError("chi and x must be distributed over the same grid");
28 if( chi.Height() != 1 || chi.Width() != 1 )
29 LogicError("chi must be a scalar");
30 if( x.Height() != 1 )
31 LogicError("x must be a row vector");
32 if( chi.ColRank() != chi.ColAlign() || x.ColRank() != x.ColAlign() )
33 LogicError("Reflecting from incorrect process");
34 )
35 typedef Base<F> Real;
36 mpi::Comm rowComm = x.RowComm();
37 const Int rowRank = x.RowRank();
38 const Int rowStride = x.RowStride();
39 const Int rowAlign = chi.RowAlign();
40
41 std::vector<Real> localNorms(rowStride);
42 Real localNorm = Nrm2( x.LockedMatrix() );
43 mpi::AllGather( &localNorm, 1, localNorms.data(), 1, rowComm );
44 Real norm = blas::Nrm2( rowStride, localNorms.data(), 1 );
45
46 Real alpha;
47 if( rowRank == rowAlign )
48 {
49 if( ImagPart(chi.GetLocal(0,0)) != Real(0) )
50 LogicError("chi is assumed to be real");
51 alpha = chi.GetLocalRealPart(0,0);
52 }
53 mpi::Broadcast( alpha, rowAlign, rowComm );
54 const Real delta = alpha*alpha - norm*norm;
55 if( delta < Real(0) )
56 LogicError("Attempted to square-root a negative number");
57 const Real lambda = ( alpha>=0 ? Sqrt(delta) : -Sqrt(delta) );
58 if( rowRank == rowAlign )
59 chi.SetLocal(0,0,-lambda);
60
61 const Real kappa = alpha + lambda;
62 if( kappa == Real(0) )
63 {
64 Zero( x );
65 return Real(1);
66 }
67 else
68 {
69 Scale( Real(1)/kappa, x );
70 Conjugate( x );
71 return (delta+alpha*lambda)/(kappa*kappa);
72 }
73 }
74
75 template<typename F,Dist U,Dist V>
76 inline F
Row(F & chi,DistMatrix<F,U,V> & x)77 Row( F& chi, DistMatrix<F,U,V>& x )
78 {
79 DEBUG_ONLY(
80 CallStackEntry cse("hyp_reflector::Row");
81 if( x.Height() != 1 )
82 LogicError("x must be a row vector");
83 if( x.ColRank() != x.ColAlign() )
84 LogicError("Reflecting from incorrect process");
85 if( ImagPart(chi) != Base<F>(0) )
86 LogicError("chi is assumed to be real");
87 )
88 typedef Base<F> Real;
89 mpi::Comm rowComm = x.RowComm();
90 const Int rowStride = x.RowStride();
91
92 std::vector<Real> localNorms(rowStride);
93 Real localNorm = Nrm2( x.LockedMatrix() );
94 mpi::AllGather( &localNorm, 1, localNorms.data(), 1, rowComm );
95 Real norm = blas::Nrm2( rowStride, localNorms.data(), 1 );
96
97 const Real alpha = RealPart(chi);
98 const Real delta = alpha*alpha - norm*norm;
99 if( delta < Real(0) )
100 LogicError("Attempted to square-root a negative number");
101 const Real lambda = ( alpha>=0 ? Sqrt(delta) : -Sqrt(delta) );
102 chi = -lambda;
103
104 const Real kappa = alpha + lambda;
105 if( kappa == Real(0) )
106 {
107 Zero( x );
108 return Real(1);
109 }
110 else
111 {
112 Scale( Real(1)/kappa, x );
113 Conjugate( x );
114 return (delta+alpha*lambda)/(kappa*kappa);
115 }
116 }
117
118 } // namespace hyp_reflector
119 } // namespace elem
120
121 #endif // ifndef ELEM_HYPERBOLICREFLECTOR_ROW_HPP
122