1 /*
2   This file is part of MADNESS.
3 
4   Copyright (C) 2007,2010 Oak Ridge National Laboratory
5 
6   This program is free software; you can redistribute it and/or modify
7   it under the terms of the GNU General Public License as published by
8   the Free Software Foundation; either version 2 of the License, or
9   (at your option) any later version.
10 
11   This program is distributed in the hope that it will be useful,
12   but WITHOUT ANY WARRANTY; without even the implied warranty of
13   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14   GNU General Public License for more details.
15 
16   You should have received a copy of the GNU General Public License
17   along with this program; if not, write to the Free Software
18   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
19 
20   For more information please contact:
21 
22   Robert J. Harrison
23   Oak Ridge National Laboratory
24   One Bethel Valley Road
25   P.O. Box 2008, MS-6367
26 
27   email: harrisonrj@ornl.gov
28   tel:   865-241-3937
29   fax:   865-572-0680
30 
31   $Id$
32 */
33 /// \file tdse/tdse4.cc
/// \brief Evolves the hydrogen molecular ion in 4D: 3 electronic + 1 nuclear degrees of freedom
35 
36 
37 //#define WORLD_INSTANTIATE_STATIC_TEMPLATES
38 #include <madness/mra/mra.h>
39 #include <madness/mra/funcimpl.h>
40 #include <madness/mra/qmprop.h>
41 #include <madness/mra/operator.h>
42 #include <madness/constants.h>
43 #include <madness/tensor/vmath.h>
44 
45 #include <madness/mra/lbdeux.h>
46 
47 using namespace madness;
48 
49 
50 template <typename T, int NDIM>
51 struct lbcost {
52     double leaf_value;
53     double parent_value;
lbcostlbcost54     lbcost(double leaf_value=1.0, double parent_value=1.0) : leaf_value(leaf_value), parent_value(parent_value) {}
operator ()lbcost55     double operator()(const Key<NDIM>& key, const FunctionNode<T,NDIM>& node) const {
56         if (node.is_leaf()) {
57             return leaf_value;
58         }
59         else {
60             return parent_value;
61         }
62         //return key.level()+1.0;
63     }
64 };
65 
66 
67 // typedefs to make life less verbose
68 typedef Vector<double,4> coordT;
69 typedef std::shared_ptr< FunctionFunctorInterface<double,4> > functorT;
70 typedef Function<double,4> functionT;
71 typedef FunctionFactory<double,4> factoryT;
72 typedef SeparatedConvolution<double,4> operatorT;
73 typedef std::shared_ptr< FunctionFunctorInterface<double_complex,4> > complex_functorT;
74 typedef Function<double_complex,4> complex_functionT;
75 typedef FunctionFactory<double_complex,4> complex_factoryT;
76 typedef Convolution1D<double_complex> complex_operatorT;
77 typedef std::shared_ptr< WorldDCPmapInterface< Key<4> > > pmapT;
78 
// Identity overload so that real(x) can be applied uniformly to real-valued results.
double real(double a) {
    const double value = a;
    return value;
}
80 
// Reduced mass of the two protons in electron masses: mu = m_p/2.
static const double reduced_mass = 0.5*constants::proton_electron_mass_ratio;
// sqrt(mu); the scaled nuclear coordinate s is related to the bond length by R = R0 + s/sqrtmu.
static const double sqrtmu = sqrt(reduced_mass);
static const double R0 = 2.04; // Effective center of nuclear wave function (bond length at s = 0)
// Value of -s at which R = 0; used as the lower bound of the s-axis of the cell in doit().
static const double s0 = sqrtmu*R0;
// Nuclear charge of each proton.
static const double Z=1.0;
86 
87 struct InputParameters {
88   static const int MAXNATOM=99;
89 
90     // IF YOU ADD A NEW PARAMETER DON'T FORGET TO INCLUDE IT IN
91     // a) read()
92     // b) serialize()
93     // c) operator<<()
94 
95   double L;           // Box size for the simulation
96   double F;           // Laser field strength
97   double omega;       // Laser frequency
98   double ncycle;      // Number of laser cycles in envelope
99   int k;              // wavelet order
100   double thresh;      // precision for truncating wave function
101   double safety;      // additional precision (thresh*safety) for operators and potential
102   double cut;         // smoothing parameter for 1/r (same for all atoms for now)
103   std::string prefix;      // Prefix for filenames
104   int ndump;          // dump wave function to disk every ndump steps
105   int nprint;         // print stats every nprint steps
106   int nloadbal;       // load balance every nloadbal steps
107   int nio;            // Number of IO nodes
108 
109   double tScale;      // Scaling parameter for optimization
110 
111   double target_time;// Target end-time for the simulation
112 
readInputParameters113   void read(const char* filename) {
114     std::ifstream f(filename);
115     std::string tag;
116     printf("\n");
117     printf("       Simulation parameters\n");
118     printf("       ---------------------\n");
119     printf("             Z = %.1f\n", Z);
120     printf("            R0 = %.6f\n", R0);
121     printf("            mu = %.6f\n", reduced_mass);
122     printf("      sqrt(mu) = %.6f\n", sqrtmu);
123     while(f >> tag) {
124         if (tag[0] == '#') {
125             char ch;
126             printf("    comment  %s ",tag.c_str());
127             while (f.get(ch)) {
128                 printf("%c",ch);
129                 if (ch == '\n') break;
130             }
131         }
132         else if (tag == "L") {
133             f >> L;
134             printf("             L = %.1f\n", L);
135         }
136         else if (tag == "F") {
137             f >> F;
138             printf("             F = %.6f\n", F);
139         }
140         else if (tag == "omega") {
141             f >> omega;
142             printf("         omega = %.6f\n", omega);
143         }
144         else if (tag == "ncycle") {
145             f >> ncycle;
146             printf("         ncycle = %.6f\n", ncycle);
147         }
148         else if (tag == "k") {
149             f >> k;
150             printf("             k = %d\n", k);
151         }
152         else if (tag == "thresh") {
153             f >> thresh;
154             printf("        thresh = %.1e\n", thresh);
155         }
156         else if (tag == "safety") {
157             f >> safety;
158             printf("        safety = %.1e\n", safety);
159         }
160         else if (tag == "cut") {
161             f >> cut;
162             printf("           cut = %.2f\n", cut);
163         }
164         else if (tag == "prefix") {
165             f >> prefix;
166             printf("        prefix = %s\n", prefix.c_str());
167         }
168         else if (tag == "ndump") {
169             f >> ndump;
170             printf("         ndump = %d\n", ndump);
171         }
172         else if (tag == "nprint") {
173             f >> nprint;
174             printf("         nprint = %d\n", nprint);
175         }
176         else if (tag == "nloadbal") {
177             f >> nloadbal;
178             printf("         nloadbal = %d\n", nloadbal);
179         }
180         else if (tag == "nio") {
181             f >> nio;
182             printf("           nio = %d\n", nio);
183         }
184         else if (tag == "target_time") {
185             f >> target_time;
186             printf("   target_time = %.3f\n", target_time);
187         }
188         else if (tag == "tScale") {
189             f >> tScale;
190             printf("           tScale = %.5f\n", tScale);
191         }
192         else {
193             MADNESS_EXCEPTION("unknown input option", 0);
194         }
195     }
196   }
197 
198   template <typename Archive>
serializeInputParameters199   void serialize(Archive & ar) {
200       ar & L & F & omega & ncycle;
201       ar & k & thresh & safety & cut & prefix & ndump & nprint & nloadbal & nio & target_time & tScale;
202   }
203 };
204 
operator <<(std::ostream & s,const InputParameters & p)205 std::ostream& operator<<(std::ostream& s, const InputParameters& p) {
206     s << p.L<< " " << p.F<< " " << p.omega << " " <<
207         p.ncycle << " " << p.k<< " " <<
208         p.thresh<< " " << p.cut<< " " << p.prefix<< " " << p.ndump<< " " <<
209         p.nprint << " "  << p.nloadbal << " " << p.nio << p.tScale << std::endl;
210 return s;
211 }
212 
// Global run parameters; read on rank 0 and broadcast to all ranks in doit().
InputParameters param;

static double zero_field_time;      // Laser actually switches on after this time (set by propagate)
                                    // Delay provides for several steps with no field before start
217 
// Smoothed 1/r potential.

// Invoke as \c u(r/c)/c where \c c is the radius of the smoothed volume.
// u(r) = erf(r)/r + exp(-r^2)/sqrt(pi) (0.5641... = 1/sqrt(pi)). For r > 6.5 this
// equals 1/r to double precision. For r <= 1e-2 a Taylor expansion in r^2 replaces
// erf(r)/r, which would be 0/0 at r = 0.
static double smoothed_potential(double r) {
    if (r > 6.5) {
        return 1.0/r;
    }

    const double rsq = r*r;
    if (r > 1e-2) {
        return erf(r)/r + exp(-rsq)*0.56418958354775630;
    }

    return 1.6925687506432689-rsq*(0.94031597257959381-rsq*(0.39493270848342941-0.12089776790309064*rsq));
}
233 
234 // Potential - nuclear-nuclear repulsion
Vn(const coordT & r)235 static double Vn(const coordT& r) {
236     double s=r[3];
237     double R = R0 + s/sqrtmu;
238     if (R < 0.0) R = 0.0;// Do something vaguely sensible for non-physical bond length
239 
240     double cut = 0.5;
241     return smoothed_potential(R/cut)/cut;
242 }
243 
244 // Potential - electron-nuclear attraction
Ve(const coordT & r)245 static double Ve(const coordT& r) {
246     const double x=r[0], y=r[1], z=r[2], s=r[3];
247     double R = R0 + s/sqrtmu;
248     if (R < 0.0) R = 0.0;
249 
250     double zz = z-R*0.5;
251     double rr = sqrt(x*x+y*y+zz*zz);
252     double Va = -Z*smoothed_potential(rr/param.cut)/param.cut;
253 
254     zz = z+R*0.5;
255     rr = sqrt(x*x+y*y+zz*zz);
256     double Vb = -Z*smoothed_potential(rr/param.cut)/param.cut;
257 
258     return Va + Vb;
259 }
260 
261 // Initial guess wave function using symmetric superposition of 1s orbital on atoms and harmonic oscillator
guess(const coordT & r)262 static double guess(const coordT& r) {
263     const double x=r[0], y=r[1], z=r[2], s=r[3];
264     const double R = R0 + s/sqrtmu;
265 
266     // These from fitting to Predrag's exact function form for psinuc in BO approx
267     static const double a = 4.42162;
268     static const double alpha = 1.28164;
269     static const double beta = -1.06379;
270 
271     static const double empirical_norm = 43.0;
272 
273     // Screen on size of nuclear wave function
274     static const double Rmax = R0 + sqrt(46.0/a);
275     if (R-R0 > Rmax) return 0.0;
276 
277     // Note in electronic part we are using R0 not R ... OR SHOULD THIS BE R? TRY BOTH!
278     static const double face = sqrt(3.14*Z*Z*Z);
279     double zz = z-R0*0.5;
280     double rr = sqrt(x*x+y*y+zz*zz);
281     double psia = face*exp(-Z*rr);
282     zz = z+R0*0.5;
283     rr = sqrt(x*x+y*y+zz*zz);
284     double psib = face*exp(-Z*rr);
285 
286     // Nuclear part
287     double R2 = (R-R0)*(R-R0);
288     double psinuc = alpha*exp(-a*R2) + beta*(R-R0)*exp(-1.5*a*R2);
289 
290     return (psia + psib)*psinuc / empirical_norm;;
291 }
292 
293 // x-dipole electronic
xdipole(const coordT & r)294 double xdipole(const coordT& r) {
295     return r[0];
296 }
297 
298 // y-dipole electronic
ydipole(const coordT & r)299 double ydipole(const coordT& r) {
300     return r[1];
301 }
302 
303 // z-dipole electronic
zdipole(const coordT & r)304 double zdipole(const coordT& r) {
305     return r[2];
306 }
307 
bond_length(const coordT & r)308 double bond_length(const coordT& r) {
309     return R0 + r[3]/sqrtmu;
310 }
311 
312 // Strength of the laser field at time t
laser(double t)313 double laser(double t) {
314     double omegat = param.omega*t;
315 
316     if (omegat < 0.0 || omegat/(2*param.ncycle) > constants::pi) return 0.0;
317 
318     double envelope = sin(omegat/(2*param.ncycle));
319     envelope *= envelope;
320     return param.F*envelope*sin(omegat);
321 }
322 
// Real part of a real number: the value itself.
double myreal(double t) {
    const double re = t;
    return re;
}
324 
myreal(const double_complex & t)325 double myreal(const double_complex& t) {return real(t);}
326 
327 // Given psi and V evaluate the energy ... leaves psi compressed, potn reconstructed
328 template <typename T>
energy(World & world,const Function<T,4> & psi,const functionT & pote,const functionT & potn,const functionT & potf)329 double energy(World& world, const Function<T,4>& psi, const functionT& pote, const functionT& potn, const functionT& potf) {
330     // First do all work in the scaling function basis
331     //bool DOFENCE = false;
332     bool DOFENCE = true;
333     psi.reconstruct();
334     Derivative<T,4> Dx(world,0), Dy(world,1), Dz(world,2), Ds(world,3);
335     Function<T,4> dx = Dx(psi,DOFENCE);
336     Function<T,4> dy = Dy(psi,DOFENCE);
337     Function<T,4> dz = Dz(psi,DOFENCE);
338     Function<T,4> ds = Ds(psi,DOFENCE);
339     Function<T,4> Vepsi = psi*pote;
340     Function<T,4> Vnpsi = psi*potn;
341     Function<T,4> Vfpsi = psi*potf;
342 
343     // Now do all work in the wavelet basis
344     psi.compress(DOFENCE);
345     Vepsi.compress(DOFENCE);
346     Vnpsi.compress(DOFENCE);
347     Vfpsi.compress(DOFENCE);
348     dx.compress(DOFENCE);
349     dy.compress(DOFENCE);
350     dz.compress(DOFENCE);
351     ds.compress(true);
352     double S = real(psi.inner(psi));
353     double PEe = real(psi.inner(Vepsi))/S;
354     double PEn = real(psi.inner(Vnpsi))/S;
355     double PEf = real(psi.inner(Vfpsi))/S;
356     double KEe = real(0.5*(inner(dx,dx) + inner(dy,dy) + inner(dz,dz)))/S;
357     double KEn = real(0.5*inner(ds,ds))/S;
358     double E = (KEe + KEn + PEe + PEn + PEf);
359 
360     dx.clear(); dy.clear(); dz.clear(); ds.clear(); Vepsi.clear(); Vepsi.clear(); Vfpsi.clear(); // To free memory on return
361     world.gop.fence();
362     if (world.rank() == 0) {
363         printf("energy=%.6f overlap=%.6f KEe=%.6f KEn=%.6f PEe=%.6f PEn=%.6f PEf=%.6f\n", E, S, KEe, KEn, PEe, PEn, PEf);
364      }
365 
366     return myreal(E);
367 }
368 
fred(const coordT & r)369 double fred(const coordT& r) {
370     static const double a = 0.1;
371     double rsq = r[0]*r[0]+r[1]*r[1]+r[2]*r[2]+r[3]*r[3];
372     return 10.0*exp(-a*rsq);
373 }
374 
delsqfred(const coordT & r)375 double delsqfred(const coordT& r) {
376     static const double a = 0.1;
377     double rsq = r[0]*r[0]+r[1]*r[1]+r[2]*r[2]+r[3]*r[3];
378     return (4.0*a*a*rsq - 4.0*2.0*a)*10.0*exp(-a*rsq);
379 }
380 
381 
testbsh(World & world)382 void testbsh(World& world) {
383     double mu = 0.3;
384     functionT test = factoryT(world).f(fred); test.truncate();
385     functionT rhs = (mu*mu)*test - functionT(factoryT(world).f(delsqfred));
386     rhs.truncate();
387     operatorT op = BSHOperator<4>(world, mu, 1e-2, FunctionDefaults<4>::get_thresh());
388 
389     functionT result = apply(op,rhs);
390 
391     double err = (result - test).norm2();
392     print("ERR", err);
393     result.reconstruct();
394     for (int i=-100; i<=100; i++) {
395         coordT r(i*0.01*5.0);
396         double ok = fred(r);
397         double bad = result(r);
398         print(r[0], ok, bad, ok-bad, ok?bad/ok:0.0);
399     }
400 }
401 
converge(World & world,functionT & potn,functionT & pote,functionT & pot,functionT & psi,double & eps)402 void converge(World& world, functionT& potn, functionT& pote, functionT& pot, functionT& psi, double& eps) {
403     functionT zero = factoryT(world);
404     for (int iter=0; iter<205; iter++) {
405         if (world.rank() == 0) print("beginning iter", iter, wall_time());
406 
407         functionT Vpsi = pot*psi;// - 0.5*psi; // TRY SHIFTING POTENTIAL AND ENERGY DOWN
408         //eps -= 0.5;
409 
410         operatorT op = BSHOperator<4>(world, sqrt(-2*eps), param.cut, param.thresh);
411         if (world.rank() == 0) print("made V*psi", wall_time());
412         Vpsi.scale(-2.0).truncate();
413         if (world.rank() == 0) print("tryuncated V*psi", wall_time());
414         functionT tmp = apply(op,Vpsi).truncate(param.thresh);
415         if (world.rank() == 0) print("applied operator", wall_time());
416         double norm = tmp.norm2();
417         functionT r = tmp-psi;
418         double rnorm = r.norm2();
419         double eps_new = eps - 0.5*inner(Vpsi,r)/(norm*norm);
420         if (world.rank() == 0) {
421             print("norm=",norm," eps=",eps," err(psi)=",rnorm," err(eps)=",eps_new-eps);
422         }
423 
424         tmp.scale(1.0/norm);
425 
426         double d = 0.3;
427         psi = tmp*d + psi*(1.0-d);
428 
429         psi.scale(1.0/psi.norm2());
430 
431         eps = eps_new;
432         energy(world, psi, pote, potn, zero);
433 
434         if (rnorm < std::max(1e-5,param.thresh)) break;
435     }
436 }
437 
// Apply the separable free-particle propagator to psi one dimension at a time:
// Gn along dimension 3 (the nuclear coordinate), then Ge along dimensions 2, 1
// and 0 (the electronic coordinates), calling sum_down() after each 1D push.
// NOTE(review): r is a shallow copy of psi, so the reconstruct()/broaden() calls
// below also act on the caller's function despite the const reference.
complex_functionT APPLY(const complex_operatorT* Ge, const complex_operatorT* Gn, const complex_functionT& psi) {
    complex_functionT r = psi;  // Shallow copy violates constness !!!!!!!!!!!!!!!!!

    r.reconstruct();
    // Broaden the tree four times before the 1D pushes
    // (presumably to give the propagated function room to spread — confirm against broaden())
    r.broaden();
    r.broaden();
    r.broaden();
    r.broaden();

    r = apply_1d_realspace_push(*Gn, r, 3); r.sum_down();
    r = apply_1d_realspace_push(*Ge, r, 2); r.sum_down();
    r = apply_1d_realspace_push(*Ge, r, 1); r.sum_down();
    r = apply_1d_realspace_push(*Ge, r, 0); r.sum_down();

    return r;
}
454 
// Advance psi0 by one time step with a symmetric (Trotter/Strang) splitting:
//   psi1 = G * expV * G * psi0
// where G = APPLY(Ge,Gn,.) is the kinetic propagator and expV multiplies by the
// potential propagator. The second kinetic application is done on a temporary
// process map, load-balanced on psi1. The original default pmap is restored and
// the result is copied back onto it before returning.
complex_functionT trotter(World& world,
                          const complex_functionT& expV,
                          const complex_operatorT* Ge,
                          const complex_operatorT* Gn,
                          const complex_functionT& psi0) {
    //    psi(t) = exp(-i*T*t/2) exp(-i*V(t/2)*t) exp(-i*T*t/2) psi(0)

    complex_functionT psi1;

    unsigned long size = psi0.size();
    if (world.rank() == 0) print("APPLYING G", size);
    psi1 = APPLY(Ge,Gn,psi0);  psi1.truncate();  size = psi1.size();
    if (world.rank() == 0) print("APPLYING expV", size);
    psi1 = expV*psi1;      psi1.truncate();  size = psi1.size();

    // Temporarily switch the default pmap to one balanced on psi1
    pmapT oldpmap = FunctionDefaults<4>::get_pmap();
    LoadBalanceDeux<4> lb(world);
    lb.add_tree(psi1, lbcost<double_complex,4>(1.0,1.0));
    FunctionDefaults<4>::set_pmap(lb.load_balance());
    psi1 = copy(psi1, FunctionDefaults<4>::get_pmap(), true);

    if (world.rank() == 0) print("APPLYING G again", size);
    psi1 = APPLY(Ge,Gn,psi1);  psi1.truncate(param.thresh);  size = psi1.size();
    if (world.rank() == 0) print("DONE", size);

    // Restore the original distribution for the caller
    FunctionDefaults<4>::set_pmap(oldpmap);
    psi1 = copy(psi1, oldpmap, true);

    return psi1;
}
485 
// Element-wise exponential functor for Function::unaryop. The primary template
// is empty; only the double_complex specialization below is usable.
template<typename T, int NDIM>
struct unaryexp {};


template<int NDIM>
struct unaryexp<double_complex,NDIM> {
    // Replace every coefficient in t by its complex exponential (key is unused).
    void operator()(const Key<NDIM>& key, Tensor<double_complex>& t) const {
        //vzExp(t.size, t.ptr(), t.ptr());
        UNARY_OPTIMIZED_ITERATOR(double_complex, t, *_p0 = exp(*_p0););
    }
    // Stateless: nothing to serialize.
    template <typename Archive>
    void serialize(Archive& ar) {}
};
499 
500 
501 // Returns exp(-I*t*V)
make_exp(double t,const functionT & v)502 complex_functionT make_exp(double t, const functionT& v) {
503     v.reconstruct();
504     complex_functionT expV = double_complex(0.0,-t)*v;
505     expV.unaryop(unaryexp<double_complex,4>());
506     return expV;
507 }
508 
print_stats_header(World & world)509 void print_stats_header(World& world) {
510     if (world.rank() == 0) {
511         printf("  step       time            field           energy            norm           overlap0         x-dipole         y-dipole         z-dipole           <R>         wall-time(s)\n");
512         printf("------- ---------------- ---------------- ---------------- ---------------- ---------------- ---------------- ---------------- ---------------- ---------------- ------------\n");
513     }
514 }
515 
// Print one row of the statistics table for psi at (step, t).
// All ranks must call this, since norm2/inner/energy are collective; only rank 0 prints.
// Columns: step, t, laser(t), energy (via energy(), which also prints its
// components), norm, overlap0 = |<psi|psi0>|/|psi|, <x>, <y>, <z>, <R> (each
// normalized by |psi|^2), and wall time. Then the time spent in this routine.
// NOTE(review): overlap0 is divided by |psi| only, so it is a true normalized
// overlap only if psi0 has unit norm.
void print_stats(World& world, int step, double t, const functionT& pote,  const functionT& potn, const functionT& potf,
                 const functionT& x, const functionT& y, const functionT& z, const functionT& R,
                 const complex_functionT& psi0, const complex_functionT& psi) {
    double start = wall_time();
    double norm = psi.norm2();
    double current_energy = energy(world, psi, pote, potn, potf);
    double xdip = real(inner(psi, x*psi))/(norm*norm);
    double ydip = real(inner(psi, y*psi))/(norm*norm);
    double zdip = real(inner(psi, z*psi))/(norm*norm);
    double Ravg = real(inner(psi, R*psi))/(norm*norm);
    double overlap0 = std::abs(psi.inner(psi0))/norm;
    if (world.rank() == 0) {
        printf("%7d %16.8e %16.8e %16.8e %16.8e %16.8e %16.8e %16.8e %16.8e %16.8e %9.1f\n", step, t, laser(t), current_energy, norm, overlap0, xdip, ydip, zdip, Ravg, wall_time());
        printf("printing used %.1f\n", wall_time() - start);
    }
}
532 
wave_function_filename(int step)533 const char* wave_function_filename(int step) {
534     static char fname[1024];
535     sprintf(fname, "%s-%5.5d", param.prefix.c_str(), step);
536     return fname;
537 }
538 
wave_function_small_plot_filename(int step)539 const char* wave_function_small_plot_filename(int step) {
540     static char fname[1024];
541     sprintf(fname, "%s-%5.5dS.dx", param.prefix.c_str(), step);
542     return fname;
543 }
544 
wave_function_large_plot_filename(int step)545 const char* wave_function_large_plot_filename(int step) {
546     static char fname[1024];
547     sprintf(fname, "%s-%5.5dL.dx", param.prefix.c_str(), step);
548     return fname;
549 }
550 
wave_function_load(World & world,int step)551 complex_functionT wave_function_load(World& world, int step) {
552     complex_functionT psi;
553     archive::ParallelInputArchive ar(world, wave_function_filename(step));
554     ar & psi;
555     return psi;
556 }
557 
wave_function_store(World & world,int step,const complex_functionT & psi)558 void wave_function_store(World& world, int step, const complex_functionT& psi) {
559     archive::ParallelOutputArchive ar(world, wave_function_filename(step), param.nio);
560     ar & psi;
561 }
562 
wave_function_exists(World & world,int step)563 bool wave_function_exists(World& world, int step) {
564     return archive::ParallelInputArchive::exists(world, wave_function_filename(step));
565 }
566 
567 
// Redistribute all functions using a process map balanced on the trees of vt
// (cost 1 per node) and psi (10 per leaf, 5 per interior node).
// Does nothing when running on a single process. Must be called by all ranks.
// NOTE(review): pote, potn, pot, psi0, x, y, z and R are accepted but unused here.
// The arguments (2.0,false) are passed through to LoadBalanceDeux::load_balance.
void loadbal(World& world,
             functionT& pote, functionT& potn, functionT& pot, functionT& vt,
             complex_functionT& psi, complex_functionT& psi0,
             functionT& x, functionT& y, functionT& z, functionT& R) {
    if (world.size() < 2) return;
    if (world.rank() == 0) print("starting LB");
    LoadBalanceDeux<4> lb(world);
    lb.add_tree(vt, lbcost<double,4>(1.0,1.0));
    lb.add_tree(psi, lbcost<double_complex,4>(10.0,5.0));
    FunctionDefaults<4>::redistribute(world,lb.load_balance(2.0,false));
    world.gop.fence();
}
580 
// Initial redistribution based on the potential trees pote and potn (cost 1 per
// node) and the wave function psi (10 per leaf, 5 per interior node).
// Does nothing on a single process. Must be called by all ranks.
// NOTE(review): pot is accepted but unused.
template <typename T>
void initial_loadbal(World& world,
                     functionT& pote, functionT& potn, functionT& pot,
                     Function<T,4>& psi) {
    if (world.size() < 2) return;
    if (world.rank() == 0) print("starting initial LB");
    LoadBalanceDeux<4> lb(world);
    lb.add_tree(pote, lbcost<double,4>(1.0,1.0));
    lb.add_tree(potn, lbcost<double,4>(1.0,1.0));
    lb.add_tree(psi, lbcost<T,4>(10.0,5.0));
    FunctionDefaults<4>::redistribute(world,lb.load_balance(2.0,false));
    world.gop.fence();
}
594 
595 
596 // Evolve the wave function in real time starting from given time step on disk
propagate(World & world,functionT & pote,functionT & potn,functionT & pot,int step0)597 void propagate(World& world, functionT& pote, functionT& potn, functionT& pot, int step0) {
598     double ctarget = 5.0/param.cut;
599     double c = 1.72*ctarget;   // This for 10^4 steps
600     double tcrit = 2*constants::pi/(c*c);
601 
602     double time_step = tcrit * param.tScale;
603 
604     zero_field_time = 20.0*time_step;
605 
606     int nstep = int((param.target_time + zero_field_time)/time_step + 1);
607 
608     // Ensure everyone has the same data
609     world.gop.broadcast(c);
610     world.gop.broadcast(time_step);
611     world.gop.broadcast(nstep);
612 
613     // Free particle propagator
614     complex_operatorT* Ge = qm_1d_free_particle_propagator(param.k, c, 0.5*time_step, 2.0*param.L);
615     complex_operatorT* Gn = qm_1d_free_particle_propagator(param.k, c, 0.5*time_step,  s0+param.L);
616 
617 
618     // Dipole moment functions for laser field and for printing statistics
619     functionT x = factoryT(world).f(xdipole);
620     functionT y = factoryT(world).f(ydipole);
621     functionT z = factoryT(world).f(zdipole);
622     functionT R = factoryT(world).f(bond_length);
623 
624     // Wave function at time t=0 for printing statistics
625     complex_functionT psi0 = wave_function_load(world, 0);
626     initial_loadbal(world, pote, potn, pot, psi0);
627 
628     int step = step0;  // The current step
629     double t = step0 * time_step - zero_field_time;        // The current time
630     complex_functionT psi = wave_function_load(world, step); // The wave function at time t
631     functionT vt = pot+laser(t)*x; // The total potential at time t
632 
633     if (world.rank() == 0) {
634         printf("\n");
635         printf("        Evolution parameters\n");
636         printf("       --------------------\n");
637         printf("     bandlimit = %.2f\n", ctarget);
638         printf(" eff-bandlimit = %.2f\n", c);
639         printf("         tcrit = %.6f\n", tcrit);
640         printf("     time step = %.6f\n", time_step);
641         printf(" no field time = %.6f\n", zero_field_time);
642         printf("   target time = %.2f\n", param.target_time);
643         printf("         nstep = %d\n", nstep);
644         printf("\n");
645         printf("  restart step = %d\n", step0);
646         printf("  restart time = %.6f\n", t);
647         printf("\n");
648     }
649 
650     print_stats_header(world);
651     print_stats(world, step, t, pote, potn, laser(t)*x, x, y, z, R, psi0, psi0);
652     world.gop.fence();
653 
654     psi.truncate();
655 
656     while (step < nstep) {
657         if (step < 2 || (step%param.nloadbal) == 0)
658             loadbal(world, pote, potn, pot, vt, psi, psi0, x, y, z, R);
659 
660         long depth = psi.max_depth(); long size=psi.size();
661         if (world.rank() == 0) print("depth size", depth, size);
662 
663         // Make the potential at time t + step/2
664         functionT vhalf = pot + laser(t+0.5*time_step)*x;
665 
666         // Apply Trotter to advance from time t to time t+step
667         complex_functionT expV = make_exp(time_step, vhalf);
668         psi = trotter(world, expV, Ge, Gn, psi);
669 
670         // Update counters, print info, dump/plot as necessary
671         step++;
672         t += time_step;
673         vt = pot+laser(t)*x;
674 
675         if ((step%param.nprint)==0 || step==nstep)
676         print_stats(world, step, t, pote, potn, laser(t)*x, x, y, z, R, psi0, psi);
677 
678         if ((step%param.ndump) == 0 || step==nstep) {
679             double start = wall_time();
680             wave_function_store(world, step, psi);
681             // Update the restart file for automatic restarting
682             if (world.rank() == 0) {
683                 std::ofstream("restart4") << step << std::endl;
684                 print("dumping took", wall_time()-start);
685             }
686             world.gop.fence();
687         }
688     }
689 }
690 
doit(World & world)691 void doit(World& world) {
692     std::cout.precision(8);
693 
694     if (world.rank() == 0) param.read("input4");
695     world.gop.broadcast_serializable(param, 0);
696 
697     FunctionDefaults<4>::set_k(param.k);                        // Wavelet order
698     FunctionDefaults<4>::set_thresh(param.thresh*param.safety);       // Accuracy
699     FunctionDefaults<4>::set_initial_level(4);
700 
701     real_tensor cell(4,2);
702     cell(0,0)=-param.L; cell(0,1)=param.L;
703     cell(1,0)=-param.L; cell(1,1)=param.L;
704     cell(2,0)=-param.L; cell(2,1)=param.L;
705     cell(3,0)=-s0;      cell(3,1)=param.L;
706     FunctionDefaults<4>::set_cell(cell);
707     //FunctionDefaults<4>::set_cubic_cell(-param.L,param.L);
708     FunctionDefaults<4>::set_apply_randomize(true);
709     FunctionDefaults<4>::set_autorefine(false);
710     FunctionDefaults<4>::set_truncate_mode(1);
711     FunctionDefaults<4>::set_truncate_on_project(true);
712     FunctionDefaults<4>::set_pmap(pmapT(new SimplePmap< Key<4> >(world)));
713 
714     // Read restart information
715     int step0;               // Initial time step ... filenames are <prefix>-<step0>
716     if (world.rank() == 0) std::ifstream("restart4") >> step0;
717     world.gop.broadcast(step0);
718 
719     bool exists = wave_function_exists(world, step0);
720 
721     if (world.rank() == 0) {
722         print("EXISTS",exists,"STEP0",step0);
723         std::ofstream out("plot.dat");
724         for (int i=-100; i<=100; i++) {
725             double x = i*0.01*param.L;
726             coordT rn(0.0), re(0.0);
727             rn[3]=x;
728             re[2]=x;
729             double vn = Vn(rn);
730             double ve = Ve(re);
731             double pn = guess(rn);
732             double pe = guess(re);
733             out << x << " " << vn << " " << ve << " " << pn << " " << pe << std::endl;
734         }
735         out.close();
736     }
737 
738     // Make the potential
739     if (world.rank() == 0) print("COMPRESSING Vn",wall_time());
740     functionT potn = factoryT(world).f(Vn);  potn.truncate();
741     if (world.rank() == 0) print("COMPRESSING Ve",wall_time());
742     functionT pote = factoryT(world).f(Ve);  pote.truncate();
743     functionT pot = potn + pote;
744 
745     //LoadBalanceDeux<4> lb(world);
746     //lb.add_tree(pot, lbcost<double,4>());
747     //FunctionDefaults<4>::redistribute(world,lb.load_balance(2.0,false));
748     //world.gop.fence();
749 
750     if (!exists) {
751         if (step0 == 0) {
752             if (world.rank() == 0) print("Computing initial ground state wavefunction", wall_time());
753             functionT psi = factoryT(world).f(guess);
754             double norm0 = psi.norm2();
755             psi.scale(1.0/norm0);
756             psi.truncate();
757             if (world.rank() == 0) print("computed norm", norm0, "at", wall_time());
758             norm0 = psi.norm2();
759             psi.scale(1.0/norm0);
760 
761             initial_loadbal(world, pote, potn, pot, psi);
762 
763             double eps = energy(world, psi, pote, potn, functionT(factoryT(world)));
764             if (world.rank() == 0) print("guess energy", eps, wall_time());
765             converge(world, potn, pote, pot, psi, eps);
766 
767             psi.truncate(param.thresh);
768 
769             complex_functionT psic = double_complex(1.0,0.0)*psi;
770             wave_function_store(world, 0, psic);
771         }
772         else {
773             if (world.rank() == 0) {
774                 print("The requested restart was not found ---", step0);
775                 error("restart failed", 0);
776             }
777             world.gop.fence();
778         }
779     }
780 
781     propagate(world, pote, potn, pot, step0);
782 }
783 
// Program entry point: initialize the runtime, run doit(), and report any
// exception by category through error(). The final print_stats(world) call takes
// a single argument, so it is not the 12-argument print_stats defined in this file
// (presumably the library's statistics printer).
int main(int argc, char** argv) {
    initialize(argc,argv);
    World world(SafeMPI::COMM_WORLD);

    startup(world,argc,argv);

    // Each handler prints the exception, then reports its category via error()
    try {
        doit(world);
    } catch (const SafeMPI::Exception& e) {
        print(e); std::cout.flush();
        error("caught an MPI exception");
    } catch (const madness::MadnessException& e) {
        print(e); std::cout.flush();
        error("caught a MADNESS exception");
    } catch (const madness::TensorException& e) {
        print(e); std::cout.flush();
        error("caught a Tensor exception");
    } catch (const char* s) {
        print(s); std::cout.flush();
        error("caught a c-string exception");
    } catch (const std::string& s) {
        print(s); std::cout.flush();
        error("caught a string (class) exception");
    } catch (const std::exception& e) {
        print(e.what()); std::cout.flush();
        error("caught an STL exception");
    } catch (...) {
        error("caught unhandled exception");
    }


    world.gop.fence();

    print_stats(world);
    finalize();
    return 0;
}
821 
822 
823