////////////////////////////////////////////////////////////////////////////////
//
// Copyright (c) 2008 The Regents of the University of California
//
// This file is part of Qbox
//
// Qbox is distributed under the terms of the GNU General Public License
// as published by the Free Software Foundation, either version 2 of
// the License, or (at your option) any later version.
// See the file COPYING in the root directory of this distribution
// or <http://www.gnu.org/licenses/>.
//
////////////////////////////////////////////////////////////////////////////////
//
// PSDAWavefunctionStepper.cpp
//
////////////////////////////////////////////////////////////////////////////////

#include "PSDAWavefunctionStepper.h"
#include "Wavefunction.h"
#include "SlaterDet.h"
#include "Preconditioner.h"
#include <algorithm>
#include <iostream>
using namespace std;

26 ////////////////////////////////////////////////////////////////////////////////
PSDAWavefunctionStepper(Wavefunction & wf,Preconditioner & prec,TimerMap & tmap)27 PSDAWavefunctionStepper::PSDAWavefunctionStepper(Wavefunction& wf,
28   Preconditioner& prec, TimerMap& tmap) : prec_(prec),
29   WavefunctionStepper(wf,tmap), wf_last_(wf), dwf_last_(wf),
30   extrapolate_(false)
31 {
32   tmap_["psda_residual"].reset();
33   tmap_["psda_prec"].reset();
34   tmap_["psda_update_wf"].reset();
35   tmap_["gram"].reset();
36   tmap_["lowdin"].reset();
37   tmap_["ortho_align"].reset();
38   tmap_["riccati"].reset();
39 }
40 
41 ////////////////////////////////////////////////////////////////////////////////
~PSDAWavefunctionStepper(void)42 PSDAWavefunctionStepper::~PSDAWavefunctionStepper(void)
43 {}
44 
45 ////////////////////////////////////////////////////////////////////////////////
update(Wavefunction & dwf)46 void PSDAWavefunctionStepper::update(Wavefunction& dwf)
47 {
48   tmap_["psda_residual"].start();
49   for ( int isp_loc = 0; isp_loc < wf_.nsp_loc(); ++isp_loc )
50   {
51     for ( int ikp_loc = 0; ikp_loc < wf_.nkp_loc(); ++ikp_loc )
52     {
53       // compute A = V^T H V  and descent direction HV - VA
54       SlaterDet* sd = wf_.sd(isp_loc,ikp_loc);
55       SlaterDet* dsd = dwf.sd(isp_loc,ikp_loc);
56       if ( sd->basis().real() )
57       {
58         // proxy real matrices c, cp
59         DoubleMatrix c_proxy(sd->c());
60         DoubleMatrix cp_proxy(dsd->c());
61         DoubleMatrix a(c_proxy.context(),c_proxy.n(),c_proxy.n(),
62           c_proxy.nb(),c_proxy.nb());
63 
64         // factor 2.0 in next line: G and -G
65         a.gemm('t','n',2.0,c_proxy,cp_proxy,0.0);
66         // rank-1 update correction
67         a.ger(-1.0,c_proxy,0,cp_proxy,0);
68 
69         // cp = cp - c * a
70         cp_proxy.gemm('n','n',-1.0,c_proxy,a,1.0);
71       }
72       else
73       {
74         ComplexMatrix& c = sd->c();
75         ComplexMatrix& cp = dsd->c();
76         ComplexMatrix a(c.context(),c.n(),c.n(),c.nb(),c.nb());
77         a.gemm('c','n',1.0,c,cp,0.0);
78         // cp = cp - c * a
79         cp.gemm('n','n',-1.0,c,a,1.0);
80       }
81     }
82   }
83   tmap_["psda_residual"].stop();
84 
85   // dwf.sd->c() now contains the descent direction (HV-VA) (residual)
86   // update the preconditioner
87   prec_.update(wf_);
88 
89   for ( int isp_loc = 0; isp_loc < wf_.nsp_loc(); ++isp_loc )
90   {
91     for ( int ikp_loc = 0; ikp_loc < wf_.nkp_loc(); ++ikp_loc )
92     {
93       tmap_["psda_prec"].start();
94       SlaterDet* sd = wf_.sd(isp_loc,ikp_loc);
95       SlaterDet* dsd = dwf.sd(isp_loc,ikp_loc);
96       SlaterDet* sd_last = wf_last_.sd(isp_loc,ikp_loc);
97       SlaterDet* dsd_last = dwf_last_.sd(isp_loc,ikp_loc);
98       // Apply preconditioner K and store -K(HV-VA) in dwf
99       double* c = (double*) sd->c().valptr();
100       double* c_last = (double*) sd_last->c().valptr();
101       double* dc = (double*) dsd->c().valptr();
102       double* dc_last = (double*) dsd_last->c().valptr();
103       const int mloc = sd->c().mloc();
104       const int ngwl = sd->basis().localsize();
105       const int nloc = sd->c().nloc();
106 
107       for ( int n = 0; n < nloc; n++ )
108       {
109         // note: double mloc length for complex<double> indices
110         double* dcn = &dc[2*mloc*n];
111         for ( int i = 0; i < ngwl; i++ )
112         {
113           const double fac = prec_.diag(isp_loc,ikp_loc,n,i);
114           const double f0 = -fac * dcn[2*i];
115           const double f1 = -fac * dcn[2*i+1];
116           dcn[2*i] = f0;
117           dcn[2*i+1] = f1;
118         }
119       }
120       tmap_["psda_prec"].stop();
121 
122       // dwf now contains the preconditioned descent
123       // direction -K(HV-VA)
124 
125       tmap_["psda_update_wf"].start();
126       // Anderson extrapolation
127       if ( extrapolate_ )
128       {
129         double theta = 0.0;
130         double a = 0.0, b = 0.0;
131         for ( int i = 0; i < 2*mloc*nloc; i++ )
132         {
133           const double f = dc[i];
134           const double delta_f = f - dc_last[i];
135 
136           // accumulate partial sums of a and b
137           // a = delta_F * F
138 
139           a += f * delta_f;
140           b += delta_f * delta_f;
141         }
142 
143         if ( sd->basis().real() )
144         {
145           // correct for double counting of asum and bsum on first row
146           // factor 2.0: G and -G
147           a *= 2.0;
148           b *= 2.0;
149           if ( wf_.sd_context().myrow() == 0 )
150           {
151             for ( int n = 0; n < nloc; n++ )
152             {
153               const int i = 2*mloc*n;
154               const double f0 = dc[i];
155               const double f1 = dc[i+1];
156               const double delta_f0 = f0 - dc_last[i];
157               const double delta_f1 = f1 - dc_last[i+1];
158               a -= f0 * delta_f0 + f1 * delta_f1;
159               b -= delta_f0 * delta_f0 + delta_f1 * delta_f1;
160             }
161           }
162         }
163 
164         // a and b contain the partial sums of a and b
165         double tmpvec[2] = { a, b };
166         wf_.sd_context().dsum(2,1,&tmpvec[0],2);
167         a = tmpvec[0];
168         b = tmpvec[1];
169 
170         // compute theta = - a / b
171         if ( b != 0.0 )
172           theta = - a / b;
173 
174         if ( theta < -1.0 )
175         {
176           theta = 0.0;
177         }
178 
179         theta = min(2.0,theta);
180 
181         // extrapolation
182         for ( int i = 0; i < 2*mloc*nloc; i++ )
183         {
184           // x_bar = x_ + theta * ( x_ - xlast_ ) (store in x_)
185           const double x = c[i];
186           const double xlast = c_last[i];
187           const double xbar = x + theta * ( x - xlast );
188 
189           // f_bar = f + theta * ( f - flast ) (store in f)
190           const double f = dc[i];
191           const double flast = dc_last[i];
192           const double fbar = f + theta * ( f - flast );
193 
194           c[i] = xbar + fbar;
195           c_last[i] = x;
196           dc_last[i] = f;
197         }
198       }
199       else
200       {
201         // no extrapolation
202         for ( int i = 0; i < 2*mloc*nloc; i++ )
203         {
204           // x_ = x_ + f_
205           const double x = c[i];
206           const double f = dc[i];
207 
208           c[i] = x + f;
209           c_last[i] = x;
210           dc_last[i] = f;
211         }
212       }
213       tmap_["psda_update_wf"].stop();
214 
215       enum ortho_type { GRAM, LOWDIN, ORTHO_ALIGN, RICCATI };
216       //const ortho_type ortho = GRAM;
217       //const ortho_type ortho = LOWDIN;
218       const ortho_type ortho = ORTHO_ALIGN;
219 
220       switch ( ortho )
221       {
222         case GRAM:
223           tmap_["gram"].start();
224           sd->gram();
225           tmap_["gram"].stop();
226           break;
227 
228         case LOWDIN:
229           tmap_["lowdin"].start();
230           sd->lowdin();
231           tmap_["lowdin"].stop();
232           break;
233 
234         case ORTHO_ALIGN:
235           tmap_["ortho_align"].start();
236           sd->ortho_align(*sd_last);
237           tmap_["ortho_align"].stop();
238           break;
239 
240         case RICCATI:
241           tmap_["riccati"].start();
242           sd->riccati(*sd_last);
243           tmap_["riccati"].stop();
244           break;
245       }
246     } // ikp_loc
247   } // isp_loc
248   extrapolate_ = true;
249 }
250