////////////////////////////////////////////////////////////////////////////////
// This file is distributed under the University of Illinois/NCSA Open Source
// License.  See LICENSE file in top directory for details.
//
// Copyright (c) 2016 Jeongnim Kim and QMCPACK developers.
//
// File developed by:
// Miguel A. Morales, moralessilva2@llnl.gov
//    Lawrence Livermore National Laboratory
//
// File created by:
// Miguel A. Morales, moralessilva2@llnl.gov
//    Lawrence Livermore National Laboratory
////////////////////////////////////////////////////////////////////////////////

#include <vector>
#include <map>
#include <string>
#include <iostream>
#include <tuple>

#include "AFQMC/config.h"
#include "AFQMC/Utilities/Utils.hpp"
#include "AFQMC/Propagators/WalkerSetUpdate.hpp"
#include "AFQMC/Walkers/WalkerConfig.hpp"
#include "AFQMC/Numerics/ma_blas.hpp"

#include "Utilities/Timer.h"

namespace qmcplusplus
{
namespace afqmc
{
/*
 * Propagates the walker population nsteps forward with a fixed vbias (from the initial
 * configuration).
 */
template<class WlkSet>
void AFQMCDistributedPropagatorDistCV::step(int nsteps_, WlkSet& wset, RealType Eshift, RealType dt)
{
  using ma::axpy;
  using std::copy_n;
  using std::fill_n;
  AFQMCTimers[setup_timer].get().start();
  const SPComplexType one(1.), zero(0.);
  auto walker_type        = wset.getWalkerType();
  int nsteps              = nsteps_;
  int nwalk               = wset.size();
  RealType sqrtdt         = std::sqrt(dt);
  long Gsize              = wfn.size_of_G_for_vbias();
  const int globalnCV     = wfn.global_number_of_cholesky_vectors();
  const int localnCV      = wfn.local_number_of_cholesky_vectors();
  const int global_origin = wfn.global_origin_cholesky_vector();
  const int nnodes        = TG.getNGroupsPerTG();
  const int node_number   = TG.getLocalGroupNumber();
  // if transposed_XXX_=true  --> XXX[nwalk][...],
  // if transposed_XXX_=false --> XXX[...][nwalk]
  auto vhs_ext   = iextensions<2u>{NMO * NMO, nwalk * nsteps};
  auto vhs3d_ext = iextensions<3u>{NMO, NMO, nwalk * nsteps};
  auto G_ext     = iextensions<2u>{Gsize, nwalk};
  if (transposed_vHS_)
  {
    vhs_ext   = iextensions<2u>{nwalk * nsteps, NMO * NMO};
    vhs3d_ext = iextensions<3u>{nwalk * nsteps, NMO, NMO};
  }
  if (transposed_G_)
    G_ext = iextensions<2u>{nwalk, Gsize};
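  // Layout sketch (illustrative indexing only, not used below): with transposed_vHS_ == false
  // the element for orbital pair (i,k) of walker/step column n sits at vHS[i*NMO + k][n];
  // with transposed_vHS_ == true it sits at vHS[n][i*NMO + k]. transposed_G_ plays the same
  // role for the Green's function block G.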

  if (MFfactor.size(0) != nsteps || MFfactor.size(1) != nwalk)
    MFfactor = CMatrix({long(nsteps), long(nwalk)});
  if (hybrid_weight.size(0) != nsteps || hybrid_weight.size(1) != nwalk)
    hybrid_weight = CMatrix({long(nsteps), long(nwalk)});
  if (new_overlaps.size(0) != nwalk)
    new_overlaps = CVector(iextensions<1u>{nwalk});
  if (new_energies.size(0) != nwalk || new_energies.size(1) != 3)
    new_energies = CMatrix({long(nwalk), 3});

  //  Temporary memory usage summary:
  //  G_for_vbias:     [ Gsize * nwalk ] (2 copies)
  //  vbias:           [ localnCV * nwalk ]
  //  X:               [ localnCV * nwalk * nstep ]
  //  vHS:             [ NMO*NMO * nwalk * nstep ] (3 copies)
  // memory_needs: nwalk * ( 2*nsteps + Gsize + localnCV*(nstep+1) + NMO*NMO*nstep )
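  //  Rough illustrative numbers (assumed sizes, not taken from any input): with NMO=100,
  //  nwalk=10, nsteps=1, Gsize~NMO*NMO=1e4 and localnCV~1e3, the dominant pieces are
  //  nwalk*Gsize ~ 1e5 and nwalk*NMO*NMO ~ 1e5 complex words per copy, i.e. of order a few
  //  MB in single precision; the localnCV terms are roughly an order of magnitude smaller.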

  // if timestep changed, recalculate one body propagator
  if (std::abs(dt - old_dt) > 1e-6)
    generateP1(dt, walker_type);
  TG.local_barrier();

  StaticMatrix vrecv_buff(vhs_ext, buffer_manager.get_generator().template get_allocator<ComplexType>());
  SPCMatrix_ref vrecv(sp_pointer(make_device_ptr(vrecv_buff.origin())), vhs_ext);

  { // using scope to control lifetime of StaticArrays, avoiding unnecessary buffer space

    Static3Tensor globalMFfactor({nnodes, nsteps, nwalk},
                                 buffer_manager.get_generator().template get_allocator<ComplexType>());
    Static3Tensor globalhybrid_weight({nnodes, nsteps, nwalk},
                                      buffer_manager.get_generator().template get_allocator<ComplexType>());
    StaticSPMatrix Gwork(G_ext, buffer_manager.get_generator().template get_allocator<SPComplexType>());

    // 1. Calculate Green function for all (local) walkers
    AFQMCTimers[G_for_vbias_timer].get().start();
#if defined(MIXED_PRECISION)
    { // control scope of Gc
      int Gak0, GakN;
      std::tie(Gak0, GakN) = FairDivideBoundary(TG.getLocalTGRank(), int(Gwork.num_elements()), TG.getNCoresPerTG());
      StaticMatrix Gc(G_ext, buffer_manager.get_generator().template get_allocator<ComplexType>());
      wfn.MixedDensityMatrix_for_vbias(wset, Gc);
      copy_n_cast(make_device_ptr(Gc.origin()) + Gak0, GakN - Gak0, make_device_ptr(Gwork.origin()) + Gak0);
    }
    TG.local_barrier();
#else
    wfn.MixedDensityMatrix_for_vbias(wset, Gwork);
#endif
    AFQMCTimers[G_for_vbias_timer].get().stop();


    StaticSPMatrix Grecv(G_ext, buffer_manager.get_generator().template get_allocator<SPComplexType>());
    StaticSPMatrix vbias({long(localnCV), long(nwalk)},
                         buffer_manager.get_generator().template get_allocator<SPComplexType>());
    StaticSPMatrix X({long(localnCV), long(nwalk * nsteps)},
                     buffer_manager.get_generator().template get_allocator<SPComplexType>());
#if defined(MIXED_PRECISION)
    // in MIXED_PRECISION, use second half of vrecv_buff for vsend
    SPCMatrix_ref vsend(sp_pointer(make_device_ptr(vrecv_buff.origin())) + vrecv_buff.num_elements(), vhs_ext);
#else
    StaticSPMatrix vsend(vhs_ext, buffer_manager.get_generator().template get_allocator<SPComplexType>());
#endif
    StaticSPMatrix vHS(vhs_ext, buffer_manager.get_generator().template get_allocator<SPComplexType>());

    // partition G and v for communications: all cores communicate a piece of the matrix
    int vak0, vakN;
    int Gak0, GakN;
    std::tie(Gak0, GakN) = FairDivideBoundary(TG.getLocalTGRank(), int(Gwork.num_elements()), TG.getNCoresPerTG());
    std::tie(vak0, vakN) = FairDivideBoundary(TG.getLocalTGRank(), int(vHS.num_elements()), TG.getNCoresPerTG());
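    // FairDivideBoundary(rank, ntot, nparts) splits the flat index range [0,ntot) into nparts
    // near-equal contiguous chunks and returns the [first,last) boundaries of the chunk owned
    // by this rank (remainder handling is a detail of the utility). Each core of the TG
    // therefore sends/receives and accumulates only its own slice of G and vHS below.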
    MPI_Send_init(to_address(Gwork.origin()) + Gak0, (GakN - Gak0) * sizeof(SPComplexType), MPI_CHAR, TG.prev_core(),
                  3456, TG.TG().get(), &req_Gsend);
    MPI_Recv_init(to_address(Grecv.origin()) + Gak0, (GakN - Gak0) * sizeof(SPComplexType), MPI_CHAR, TG.next_core(),
                  3456, TG.TG().get(), &req_Grecv);
    MPI_Send_init(to_address(vsend.origin()) + vak0, (vakN - vak0) * sizeof(SPComplexType), MPI_CHAR, TG.prev_core(),
                  5678, TG.TG().get(), &req_vsend);
    MPI_Recv_init(to_address(vrecv.origin()) + vak0, (vakN - vak0) * sizeof(SPComplexType), MPI_CHAR, TG.next_core(),
                  5678, TG.TG().get(), &req_vrecv);
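    // The persistent requests above implement a ring shift: every core sends its slice to the
    // matching core on the previous group (prev_core) and receives the corresponding slice from
    // the next group (next_core). Counts are given in bytes (MPI_CHAR) so the same calls work
    // for single- and double-precision complex payloads.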

    fill_n(make_device_ptr(vsend.origin()) + vak0, (vakN - vak0), zero);

    // are we back propagating?
    int bp_step = wset.getBPPos(), bp_max = wset.NumBackProp();
    bool realloc(false);
    int xx(0);
    if (bp_step >= 0 && bp_step < bp_max)
    {
      xx = 1;
      size_t m_(globalnCV * nwalk * nsteps * 2);
      if (bpX.num_elements() < m_)
      {
        realloc = true;
        bpX     = mpi3SPCVector(iextensions<1u>{m_}, shared_allocator<SPComplexType>{TG.TG_local()});
        if (TG.TG_local().root())
          fill_n(bpX.origin(), bpX.num_elements(), SPComplexType(0.0));
      }
    }
    stdSPCMatrix_ref Xsend(to_address(bpX.origin()), {long(globalnCV * xx), nwalk * nsteps});
    stdSPCMatrix_ref Xrecv(Xsend.origin() + Xsend.num_elements(), {long(globalnCV * xx), nwalk * nsteps});
    int Xg0, XgN;
    std::tie(Xg0, XgN) = FairDivideBoundary(TG.getLocalTGRank(), globalnCV * nwalk * nsteps, TG.getNCoresPerTG());
    int Xl0, XlN;
    std::tie(Xl0, XlN) = FairDivideBoundary(TG.getLocalTGRank(), localnCV * nwalk * nsteps, TG.getNCoresPerTG());
    int cv0, cvN;
    std::tie(cv0, cvN) = FairDivideBoundary(TG.getLocalTGRank(), localnCV, TG.getNCoresPerTG());

    if (bp_step >= 0 && bp_step < bp_max)
    {
      MPI_Send_init(Xsend.origin() + Xg0, (XgN - Xg0) * sizeof(SPComplexType), MPI_CHAR, TG.prev_core(), 3456, TG.TG().get(),
                    &req_X2send);
      MPI_Recv_init(Xrecv.origin() + Xg0, (XgN - Xg0) * sizeof(SPComplexType), MPI_CHAR, TG.next_core(), 3456, TG.TG().get(),
                    &req_X2recv);
    }

    TG.local_barrier();
    AFQMCTimers[setup_timer].get().stop();

    MPI_Status st;

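    // Ring algorithm over the nnodes groups of the TG: G is shifted around the ring, so at step
    // k this group holds the Green's functions originating from group (node_number + k) % nnodes.
    // Each group adds the contribution of its locally stored Cholesky vectors to the partial vHS
    // travelling with those walkers (vsend) and forwards it; after nnodes steps the fully
    // accumulated vHS for this group's own walkers arrives in vrecv (see the comment below).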
    for (int k = 0; k < nnodes; ++k)
    {
      // 2. wait for communication of previous step
      AFQMCTimers[vHS_comm_overhead_timer].get().start();
      if (k > 0)
      {
        MPI_Wait(&req_Grecv, &st);
        MPI_Wait(&req_Gsend, &st); // need to wait for Gsend in order to overwrite Gwork
        copy_n(make_device_ptr(Grecv.origin()) + Gak0, GakN - Gak0, make_device_ptr(Gwork.origin()) + Gak0);
        TG.local_barrier();
      }

      // 3. setup next communication
      if (k < nnodes - 1)
      {
        MPI_Start(&req_Gsend);
        MPI_Start(&req_Grecv);
      }
      AFQMCTimers[vHS_comm_overhead_timer].get().stop();

      // calculate vHS contribution from this node
      // 4a. Calculate vbias for initial configuration
      AFQMCTimers[vbias_timer].get().start();
      wfn.vbias(Gwork, vbias, sqrtdt);
      AFQMCTimers[vbias_timer].get().stop();

      // 4b. Assemble X(nCV,nsteps,nwalk)
      AFQMCTimers[assemble_X_timer].get().start();
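      // Gwork currently holds the walkers of group q = (node_number + k) % nnodes, so the
      // mean-field factors and hybrid weights computed here are stored in slot q of the global
      // arrays and reduced over the TG at the end of the loop.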
      int q = (node_number + k) % nnodes;
      assemble_X(nsteps, nwalk, sqrtdt, X, vbias, MFfactor, hybrid_weight);
      copy_n(make_device_ptr(MFfactor.origin()), MFfactor.num_elements(), make_device_ptr(globalMFfactor[q].origin()));
      copy_n(make_device_ptr(hybrid_weight.origin()), hybrid_weight.num_elements(),
             make_device_ptr(globalhybrid_weight[q].origin()));
      TG.local_barrier();
      AFQMCTimers[assemble_X_timer].get().stop();
      if (bp_step >= 0 && bp_step < bp_max)
      {
        // receive X
        if (k > 0)
        {
          MPI_Wait(&req_X2recv, &st);
          MPI_Wait(&req_X2send, &st);
          copy_n(Xrecv.origin() + Xg0, XgN - Xg0, Xsend.origin() + Xg0);
          TG.local_barrier();
        }
        // accumulate
        copy_n(make_device_ptr(X[cv0].origin()), nwalk * nsteps * (cvN - cv0), Xsend[global_origin + cv0].origin());
        TG.local_barrier();
        // start X communication
        MPI_Start(&req_X2send);
        MPI_Start(&req_X2recv);
      }

      // 4c. Calculate vHS(M*M,nsteps,nwalk)
      AFQMCTimers[vHS_timer].get().start();
      wfn.vHS(X, vHS, sqrtdt);
      TG.local_barrier();
      AFQMCTimers[vHS_timer].get().stop();

      AFQMCTimers[vHS_comm_overhead_timer].get().start();
      // 5. receive v
      if (k > 0)
      {
        MPI_Wait(&req_vrecv, &st);
        MPI_Wait(&req_vsend, &st);
        copy_n(make_device_ptr(vrecv.origin()) + vak0, vakN - vak0, make_device_ptr(vsend.origin()) + vak0);
      }

      // 6. add local contribution to vsend
      axpy(vakN - vak0, one, make_device_ptr(vHS.origin()) + vak0, 1, make_device_ptr(vsend.origin()) + vak0, 1);

      // 7. start v communication
      MPI_Start(&req_vsend);
      MPI_Start(&req_vrecv);
      TG.local_barrier();
      AFQMCTimers[vHS_comm_overhead_timer].get().stop();
    }

    // after the wait, vrecv (and by extension vHS3D) has the final vHS for the local walkers
    AFQMCTimers[vHS_comm_overhead_timer].get().start();
    MPI_Wait(&req_vrecv, &st);
    MPI_Wait(&req_vsend, &st);
    MPI_Wait(&req_X2recv, &st);
    MPI_Wait(&req_X2send, &st);

    MPI_Request_free(&req_Grecv);
    MPI_Request_free(&req_Gsend);
    MPI_Request_free(&req_vrecv);
    MPI_Request_free(&req_vsend);

    // store fields in walker
    if (bp_step >= 0 && bp_step < bp_max)
    {
      MPI_Request_free(&req_X2recv);
      MPI_Request_free(&req_X2send);

      int cvg0, cvgN;
      std::tie(cvg0, cvgN) = FairDivideBoundary(TG.getLocalTGRank(), globalnCV, TG.getNCoresPerTG());
      for (int ni = 0; ni < nsteps; ni++)
      {
        if (bp_step < bp_max)
        {
          auto&& V(*wset.getFields(bp_step));
          if (nsteps == 1)
          {
            copy_n(Xrecv[cvg0].origin(), nwalk * (cvgN - cvg0), V[cvg0].origin());
            ma::scal(SPComplexType(sqrtdt), V.sliced(cvg0, cvgN));
          }
          else
          {
            ma::add(SPComplexType(0.0), V.sliced(cvg0, cvgN), SPComplexType(sqrtdt),
                    Xrecv({cvg0, cvgN}, {ni * nwalk, (ni + 1) * nwalk}), V.sliced(cvg0, cvgN));
          }
          bp_step++;
        }
      }
      TG.TG_local().barrier();
    }
    // reduce MF and HWs
    if (TG.TG().size() > 1)
    {
      TG.TG().all_reduce_in_place_n(to_address(globalMFfactor.origin()), globalMFfactor.num_elements(), std::plus<>());
      TG.TG().all_reduce_in_place_n(to_address(globalhybrid_weight.origin()), globalhybrid_weight.num_elements(),
                                    std::plus<>());
    }

    // copy from global to local array
    copy_n(make_device_ptr(globalMFfactor[node_number].origin()), MFfactor.num_elements(),
           make_device_ptr(MFfactor.origin()));
    copy_n(make_device_ptr(globalhybrid_weight[node_number].origin()), hybrid_weight.num_elements(),
           make_device_ptr(hybrid_weight.origin()));

    TG.local_barrier();
    AFQMCTimers[vHS_comm_overhead_timer].get().stop();
  }

#if defined(MIXED_PRECISION)
  // is this clever or dirty? seems to work well and saves memory!
  TG.local_barrier();
  using qmcplusplus::afqmc::inplace_cast;
  if (TG.TG_local().root())
    inplace_cast<SPComplexType, ComplexType>(vrecv.origin(), vrecv.num_elements());
  TG.local_barrier();
#endif
  C3Tensor_ref vHS3D(make_device_ptr(vrecv_buff.origin()), vhs3d_ext);

  // From here on is similar to Shared
  int nx = 1;
  if (walker_type == COLLINEAR)
    nx = 2;

  // from now on, individual work on each walker/step
  const int ntasks_per_core     = int(nx * nwalk) / TG.getNCoresPerTG();
  const int ntasks_total_serial = ntasks_per_core * TG.getNCoresPerTG();
  const int nextra              = int(nx * nwalk) - ntasks_total_serial;

  // each core handles ntasks_per_core tasks serially
  const int tk0 = TG.getLocalTGRank() * ntasks_per_core;
  const int tkN = (TG.getLocalTGRank() + 1) * ntasks_per_core;

  // make new communicator if nextra changed from last setting
  reset_nextra(nextra);
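  // Work division sketch: each core handles the ntasks_per_core tasks in [tk0,tkN) serially,
  // and the nextra leftover tasks (fewer than the number of cores) are handled through the
  // auxiliary communicator managed by reset_nextra(), which is why it is rebuilt whenever
  // nextra changes.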

  for (int ni = 0; ni < nsteps_; ni++)
  {
    // 5. Propagate walkers
    AFQMCTimers[propagate_timer].get().start();
    if (nbatched_propagation != 0)
    {
      apply_propagators_batched('N', wset, ni, vHS3D);
    }
    else
    {
      apply_propagators('N', wset, ni, tk0, tkN, ntasks_total_serial, vHS3D);
    }
    AFQMCTimers[propagate_timer].get().stop();

    // 6. Calculate local energy/overlap
    AFQMCTimers[pseudo_energy_timer].get().start();
    if (hybrid)
    {
      wfn.Overlap(wset, new_overlaps);
    }
    else
    {
      wfn.Energy(wset, new_energies, new_overlaps);
    }
    TG.local_barrier();
    AFQMCTimers[pseudo_energy_timer].get().stop();

    // 7. update weights/energy/etc, apply constraints/bounds/etc
    AFQMCTimers[extra_timer].get().start();
    if (TG.TG_local().root())
    {
      if (free_projection)
      {
        free_projection_walker_update(wset, dt, new_overlaps, MFfactor[ni], Eshift, hybrid_weight[ni], work);
      }
      else if (hybrid)
      {
        hybrid_walker_update(wset, dt, apply_constrain, importance_sampling, Eshift, new_overlaps, MFfactor[ni],
                             hybrid_weight[ni], work);
      }
      else
      {
        local_energy_walker_update(wset, dt, apply_constrain, Eshift, new_overlaps, new_energies, MFfactor[ni],
                                   hybrid_weight[ni], work);
      }
      if (wset.getBPPos() >= 0 && wset.getBPPos() < wset.NumBackProp())
        wset.advanceBPPos();
      if (wset.getBPPos() >= 0)
        wset.advanceHistoryPos();
    }
    TG.local_barrier();
    AFQMCTimers[extra_timer].get().stop();
  }
}

// Distributed propagation based on collective communication
/*
 * Propagates the walker population nsteps forward with a fixed vbias (from the initial
 * configuration).
 */
template<class WlkSet>
void AFQMCDistributedPropagatorDistCV::step_collective(int nsteps_, WlkSet& wset, RealType Eshift, RealType dt)
{
  using ma::axpy;
  using std::copy_n;
  using std::fill_n;
  AFQMCTimers[setup_timer].get().start();
  const SPComplexType one(1.), zero(0.);
  auto walker_type        = wset.getWalkerType();
  int nsteps              = nsteps_;
  int nwalk               = wset.size();
  RealType sqrtdt         = std::sqrt(dt);
  long Gsize              = wfn.size_of_G_for_vbias();
  const int globalnCV     = wfn.global_number_of_cholesky_vectors();
  const int localnCV      = wfn.local_number_of_cholesky_vectors();
  const int global_origin = wfn.global_origin_cholesky_vector();
  const int nnodes        = TG.getNGroupsPerTG();
  const int ncores        = TG.getNCoresPerTG();
  const int node_number   = TG.getLocalGroupNumber();
  // if transposed_XXX_=true  --> XXX[nwalk][...],
  // if transposed_XXX_=false --> XXX[...][nwalk]
  auto vhs_ext   = iextensions<2u>{NMO * NMO, nwalk * nsteps};
  auto vhs3d_ext = iextensions<3u>{NMO, NMO, nwalk * nsteps};
  auto G_ext     = iextensions<2u>{Gsize, nwalk};
  if (transposed_vHS_)
  {
    vhs_ext   = iextensions<2u>{nwalk * nsteps, NMO * NMO};
    vhs3d_ext = iextensions<3u>{nwalk * nsteps, NMO, NMO};
  }
  if (transposed_G_)
    G_ext = iextensions<2u>{nwalk, Gsize};

  if (MFfactor.size(0) != nsteps || MFfactor.size(1) != nwalk)
    MFfactor = CMatrix({long(nsteps), long(nwalk)});
  if (hybrid_weight.size(0) != nsteps || hybrid_weight.size(1) != nwalk)
    hybrid_weight = CMatrix({long(nsteps), long(nwalk)});
  if (new_overlaps.size(0) != nwalk)
    new_overlaps = CVector(iextensions<1u>{nwalk});
  if (new_energies.size(0) != nwalk || new_energies.size(1) != 3)
    new_energies = CMatrix({long(nwalk), 3});

  //  Temporary memory usage summary:
  //  G_for_vbias:     [ Gsize * nwalk ] (2 copies)
  //  vbias:           [ localnCV * nwalk ]
  //  X:               [ localnCV * nwalk * nstep ]
  //  vHS:             [ NMO*NMO * nwalk * nstep ] (2 copies)
  // memory_needs: nwalk * ( 2*nsteps + Gsize + localnCV*(nstep+1) + NMO*NMO*nstep )

  // if timestep changed, recalculate one body propagator
  if (std::abs(dt - old_dt) > 1e-6)
    generateP1(dt, walker_type);
  TG.local_barrier();

  StaticMatrix vrecv_buff(vhs_ext, buffer_manager.get_generator().template get_allocator<ComplexType>());
  C3Tensor_ref vHS3D(make_device_ptr(vrecv_buff.origin()), vhs3d_ext);
  SPCMatrix_ref vrecv(sp_pointer(make_device_ptr(vrecv_buff.origin())), vhs_ext);

  // scope controlling lifetime of temporary arrays
  {
    Static3Tensor globalMFfactor({nnodes, nsteps, nwalk},
                                 buffer_manager.get_generator().template get_allocator<ComplexType>());
    Static3Tensor globalhybrid_weight({nnodes, nsteps, nwalk},
                                      buffer_manager.get_generator().template get_allocator<ComplexType>());
    StaticSPMatrix Gstore(G_ext, buffer_manager.get_generator().template get_allocator<SPComplexType>());

    int Gak0, GakN;
    std::tie(Gak0, GakN) = FairDivideBoundary(TG.getLocalTGRank(), int(Gstore.num_elements()), TG.getNCoresPerTG());

    // 1. Calculate Green function for all (local) walkers
    AFQMCTimers[G_for_vbias_timer].get().start();
#if defined(MIXED_PRECISION)
    {
      StaticMatrix Gstore_(G_ext, buffer_manager.get_generator().template get_allocator<ComplexType>());
      wfn.MixedDensityMatrix_for_vbias(wset, Gstore_);
      copy_n_cast(make_device_ptr(Gstore_.origin()) + Gak0, GakN - Gak0, make_device_ptr(Gstore.origin()) + Gak0);
    }
#else
    wfn.MixedDensityMatrix_for_vbias(wset, Gstore);
#endif
    TG.local_barrier();
    AFQMCTimers[G_for_vbias_timer].get().stop();

    StaticSPMatrix Gwork(G_ext, buffer_manager.get_generator().template get_allocator<SPComplexType>());
    StaticSPMatrix vbias({long(localnCV), long(nwalk)},
                         buffer_manager.get_generator().template get_allocator<SPComplexType>());
    StaticSPMatrix X({long(localnCV), long(nwalk * nsteps)},
                     buffer_manager.get_generator().template get_allocator<SPComplexType>());
// reusing the second half of the vrecv buffer reinterpreted as an SPComplexType array
#if defined(MIXED_PRECISION)
    SPCMatrix_ref vHS(sp_pointer(make_device_ptr(vrecv_buff.origin())) + vrecv_buff.num_elements(), vhs_ext);
#else
    StaticMatrix vHS(vhs_ext, buffer_manager.get_generator().template get_allocator<ComplexType>());
#endif

    // partition G and v for communications: all cores communicate a piece of the matrix
    int vak0, vakN;
    std::tie(vak0, vakN) = FairDivideBoundary(TG.getLocalTGRank(), int(vHS.num_elements()), TG.getNCoresPerTG());

    // are we back propagating?
    int bp_step = wset.getBPPos(), bp_max = wset.NumBackProp();
    bool realloc(false);
    int xx(0);
    if (bp_step >= 0 && bp_step < bp_max)
    {
      xx = 1;
      size_t m_(globalnCV * nwalk * nsteps);
      if (bpX.num_elements() < m_)
      {
        realloc = true;
        bpX     = mpi3SPCVector(iextensions<1u>{m_}, shared_allocator<SPComplexType>{TG.TG_local()});
        if (TG.TG_local().root())
          fill_n(bpX.origin(), bpX.num_elements(), SPComplexType(0.0));
      }
    }
    stdSPCMatrix_ref Xrecv(to_address(bpX.origin()), {long(globalnCV * xx), nwalk * nsteps});
    int Xg0, XgN;
    std::tie(Xg0, XgN) = FairDivideBoundary(TG.getLocalTGRank(), globalnCV * nwalk * nsteps, TG.getNCoresPerTG());
    int Xl0, XlN;
    std::tie(Xl0, XlN) = FairDivideBoundary(TG.getLocalTGRank(), localnCV * nwalk * nsteps, TG.getNCoresPerTG());
    int cv0, cvN;
    std::tie(cv0, cvN) = FairDivideBoundary(TG.getLocalTGRank(), localnCV, TG.getNCoresPerTG());

    // set up counts; it may be possible to keep the old ones if nwalk*nsteps*(cvN-cv0) doesn't change.
    // what can we assume?
    if (bp_step >= 0 && bp_step < bp_max)
    {
      bpx_counts = TG.TG().all_gather_value(int(nwalk * nsteps * (cvN - cv0)));
      if (TG.TG_local().root())
      {
        bpx_displ.reserve(bpx_counts.size());
        bpx_displ.clear();
        for (int i = 0, s = 0; i < bpx_counts.size(); i++)
        {
          bpx_displ.push_back(s);
          s += bpx_counts[i];
        }
      }
    }

    TG.local_barrier();
    AFQMCTimers[setup_timer].get().stop();

    MPI_Status st;

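    // Collective variant of the distributed loop: instead of shifting G point-to-point around a
    // ring (as in step() above), each group's G is broadcast in turn over TG_Cores() (or NCCL
    // when enabled) and the vHS contributions are reduced back to the owning group, so every
    // group leaves the loop with its own fully summed vrecv.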
    for (int k = 0; k < nnodes; ++k)
    {
      // 2. bcast G
      AFQMCTimers[vHS_comm_overhead_timer].get().start();
      if (k == node_number)
        copy_n(make_device_ptr(Gstore.origin()) + Gak0, GakN - Gak0, make_device_ptr(Gwork.origin()) + Gak0);
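      // NCCL has no complex datatype, so the complex slice is broadcast as 2*(GakN - Gak0) real
      // words (ncclFloat in mixed precision, ncclDouble otherwise); without NCCL the same slice
      // is broadcast with the mpi3 cores communicator below.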
#ifdef BUILD_AFQMC_WITH_NCCL
#ifdef ENABLE_CUDA
#if defined(MIXED_PRECISION)
      NCCLCHECK(
          ncclBcast(to_address(Gwork.origin() + Gak0), 2 * (GakN - Gak0), ncclFloat, k, TG.ncclTG(), TG.ncclStream()));
#else
      NCCLCHECK(
          ncclBcast(to_address(Gwork.origin() + Gak0), 2 * (GakN - Gak0), ncclDouble, k, TG.ncclTG(), TG.ncclStream()));
#endif
      qmc_cuda::cuda_check(cudaStreamSynchronize(TG.ncclStream()), "cudaStreamSynchronize(s)");
#else
#error "BUILD_AFQMC_WITH_NCCL only with ENABLE_CUDA"
#endif
#else
      TG.TG_Cores().broadcast_n(to_address(Gwork.origin()) + Gak0, GakN - Gak0, k);
#endif
      TG.local_barrier();
      AFQMCTimers[vHS_comm_overhead_timer].get().stop();

      // calculate vHS contribution from this node
      // 3a. Calculate vbias for initial configuration
      AFQMCTimers[vbias_timer].get().start();
      wfn.vbias(Gwork, vbias, sqrtdt);
      AFQMCTimers[vbias_timer].get().stop();

      // 3b. Assemble X(nCV,nsteps,nwalk)
      AFQMCTimers[assemble_X_timer].get().start();
      assemble_X(nsteps, nwalk, sqrtdt, X, vbias, MFfactor, hybrid_weight);
      copy_n(make_device_ptr(MFfactor.origin()), MFfactor.num_elements(), make_device_ptr(globalMFfactor[k].origin()));
      copy_n(make_device_ptr(hybrid_weight.origin()), hybrid_weight.num_elements(),
             make_device_ptr(globalhybrid_weight[k].origin()));
      TG.local_barrier();
      AFQMCTimers[assemble_X_timer].get().stop();
      if (bp_step >= 0 && bp_step < bp_max)
      {
        APP_ABORT("Finish");
        // accumulate
        //copy_n(X[cv0].origin(),nwalk*nsteps*(cvN-cv0),Xsend[global_origin+cv0].origin());
        // how do we know if these change? possible if two execute blocks use different ncores;
        // maybe store the last ncores and rebuild the counts when it changes.
        TG.TG().gatherv_n(to_address(X[cv0].origin()), nwalk * nsteps * (cvN - cv0), to_address(Xrecv.origin()),
                          bpx_counts.begin(), bpx_displ.begin(), k * ncores);
        TG.local_barrier();
      }

      // 3c. Calculate vHS(M*M,nsteps,nwalk)
      AFQMCTimers[vHS_timer].get().start();
      wfn.vHS(X, vHS, sqrtdt);
      TG.local_barrier();
      AFQMCTimers[vHS_timer].get().stop();

      AFQMCTimers[vHS_comm_overhead_timer].get().start();
      // 4. Reduce vHS
#ifdef BUILD_AFQMC_WITH_NCCL
#ifdef ENABLE_CUDA
#if defined(MIXED_PRECISION)
      NCCLCHECK(ncclReduce((const void*)to_address(vHS.origin() + vak0), (void*)to_address(vrecv.origin() + vak0),
                           2 * (vakN - vak0), ncclFloat, ncclSum, k, TG.ncclTG(), TG.ncclStream()));
#else
      NCCLCHECK(ncclReduce((const void*)to_address(vHS.origin() + vak0), (void*)to_address(vrecv.origin() + vak0),
                           2 * (vakN - vak0), ncclDouble, ncclSum, k, TG.ncclTG(), TG.ncclStream()));
#endif
      qmc_cuda::cuda_check(cudaStreamSynchronize(TG.ncclStream()), "cudaStreamSynchronize(s)");
#else
#error "BUILD_AFQMC_WITH_NCCL only with ENABLE_CUDA"
#endif
#else
      TG.TG_Cores().reduce_n(to_address(vHS.origin()) + vak0, vakN - vak0, to_address(vrecv.origin()) + vak0,
                             std::plus<>(), k);
#endif

      TG.local_barrier();
      AFQMCTimers[vHS_comm_overhead_timer].get().stop();
    }

    // after the reductions, vrecv (and by extension vHS3D) has the final vHS for the local walkers
    AFQMCTimers[vHS_comm_overhead_timer].get().start();

    // store fields in walker
    if (bp_step >= 0 && bp_step < bp_max)
    {
      int cvg0, cvgN;
      std::tie(cvg0, cvgN) = FairDivideBoundary(TG.getLocalTGRank(), globalnCV, TG.getNCoresPerTG());
      for (int ni = 0; ni < nsteps; ni++)
      {
        if (bp_step < bp_max)
        {
          auto&& V(*wset.getFields(bp_step));
          if (nsteps == 1)
          {
            copy_n(Xrecv[cvg0].origin(), nwalk * (cvgN - cvg0), V[cvg0].origin());
            ma::scal(sqrtdt, V.sliced(cvg0, cvgN));
          }
          else
          {
            ma::add(SPComplexType(0.0), V.sliced(cvg0, cvgN), SPComplexType(sqrtdt),
                    Xrecv({cvg0, cvgN}, {ni * nwalk, (ni + 1) * nwalk}), V.sliced(cvg0, cvgN));
          }
          bp_step++;
        }
      }
      TG.TG_local().barrier();
    }
    // reduce MFfactor and hybrid weights, always in double precision
    if (TG.TG().size() > 1)
    {
#ifdef BUILD_AFQMC_WITH_NCCL
#ifdef ENABLE_CUDA
      NCCLCHECK(ncclAllReduce((const void*)to_address(globalMFfactor.origin()),
                              (void*)to_address(globalMFfactor.origin()), 2 * globalMFfactor.num_elements(), ncclDouble,
                              ncclSum, TG.ncclTG(), TG.ncclStream()));
      NCCLCHECK(ncclAllReduce((const void*)to_address(globalhybrid_weight.origin()),
                              (void*)to_address(globalhybrid_weight.origin()), 2 * globalhybrid_weight.num_elements(),
                              ncclDouble, ncclSum, TG.ncclTG(), TG.ncclStream()));
      qmc_cuda::cuda_check(cudaStreamSynchronize(TG.ncclStream()), "cudaStreamSynchronize(s)");
#else
#error "BUILD_AFQMC_WITH_NCCL only with ENABLE_CUDA"
#endif
#else
      TG.TG().all_reduce_in_place_n(to_address(globalMFfactor.origin()), globalMFfactor.num_elements(), std::plus<>());
      TG.TG().all_reduce_in_place_n(to_address(globalhybrid_weight.origin()), globalhybrid_weight.num_elements(),
                                    std::plus<>());
#endif
    }
    TG.local_barrier();

    // copy from global to local array
    copy_n(make_device_ptr(globalMFfactor[node_number].origin()), MFfactor.num_elements(),
           make_device_ptr(MFfactor.origin()));
    copy_n(make_device_ptr(globalhybrid_weight[node_number].origin()), hybrid_weight.num_elements(),
           make_device_ptr(hybrid_weight.origin()));

    AFQMCTimers[vHS_comm_overhead_timer].get().stop();
  } // scope controlling lifetime of temporary arrays

#if defined(MIXED_PRECISION)
  TG.local_barrier();
  using qmcplusplus::afqmc::inplace_cast;
  if (TG.TG_local().root())
    inplace_cast<SPComplexType, ComplexType>(make_device_ptr(vrecv.origin()), vrecv.num_elements());
  TG.local_barrier();
#endif

  // From here on is similar to Shared
  int nx = 1;
  if (walker_type == COLLINEAR)
    nx = 2;

  // from now on, individual work on each walker/step
  const int ntasks_per_core     = int(nx * nwalk) / TG.getNCoresPerTG();
  const int ntasks_total_serial = ntasks_per_core * TG.getNCoresPerTG();
  const int nextra              = int(nx * nwalk) - ntasks_total_serial;

  // each core handles ntasks_per_core tasks serially
  const int tk0 = TG.getLocalTGRank() * ntasks_per_core;
  const int tkN = (TG.getLocalTGRank() + 1) * ntasks_per_core;

  // make new communicator if nextra changed from last setting
  reset_nextra(nextra);

  for (int ni = 0; ni < nsteps_; ni++)
  {
    // 5. Propagate walkers
    AFQMCTimers[propagate_timer].get().start();
    if (nbatched_propagation != 0)
    {
      apply_propagators_batched('N', wset, ni, vHS3D);
    }
    else
    {
      apply_propagators('N', wset, ni, tk0, tkN, ntasks_total_serial, vHS3D);
    }
    AFQMCTimers[propagate_timer].get().stop();

    // 6. Calculate local energy/overlap
    AFQMCTimers[pseudo_energy_timer].get().start();
    if (hybrid)
    {
      wfn.Overlap(wset, new_overlaps);
    }
    else
    {
      wfn.Energy(wset, new_energies, new_overlaps);
    }
    TG.local_barrier();
    AFQMCTimers[pseudo_energy_timer].get().stop();

    // 7. update weights/energy/etc, apply constraints/bounds/etc
    AFQMCTimers[extra_timer].get().start();
    if (TG.TG_local().root())
    {
      if (free_projection)
      {
        free_projection_walker_update(wset, dt, new_overlaps, MFfactor[ni], Eshift, hybrid_weight[ni], work);
      }
      else if (hybrid)
      {
        hybrid_walker_update(wset, dt, apply_constrain, importance_sampling, Eshift, new_overlaps, MFfactor[ni],
                             hybrid_weight[ni], work);
      }
      else
      {
        local_energy_walker_update(wset, dt, apply_constrain, Eshift, new_overlaps, new_energies, MFfactor[ni],
                                   hybrid_weight[ni], work);
      }
      if (wset.getBPPos() >= 0 && wset.getBPPos() < wset.NumBackProp())
        wset.advanceBPPos();
      if (wset.getBPPos() >= 0)
        wset.advanceHistoryPos();
    }
    TG.local_barrier();
    AFQMCTimers[extra_timer].get().stop();
  }
}

/*
 * This routine assumes that the 1 body propagator does not need updating
 */
template<class WlkSet, class CTens, class CMat>
void AFQMCDistributedPropagatorDistCV::BackPropagate(int nbpsteps,
                                                     int nStabalize,
                                                     WlkSet& wset,
                                                     CTens&& Refs,
                                                     CMat&& detR)
{
  using std::copy_n;
  using std::fill_n;
  const SPComplexType one(1.), zero(0.);
  auto walker_type        = wset.getWalkerType();
  const int nwalk         = wset.size();
  const int globalnCV     = wfn.global_number_of_cholesky_vectors();
  const int localnCV      = wfn.local_number_of_cholesky_vectors();
  const int global_origin = wfn.global_origin_cholesky_vector();
  const int nnodes        = TG.getNGroupsPerTG();

  auto vhs_ext   = iextensions<2u>{NMO * NMO, nwalk};
  auto vhs3d_ext = iextensions<3u>{NMO, NMO, nwalk};
  if (transposed_vHS_)
  {
    vhs_ext   = iextensions<2u>{nwalk, NMO * NMO};
    vhs3d_ext = iextensions<3u>{nwalk, NMO, NMO};
  }

  //  Shared buffer used for:
  //  X:               [ (localnCV + 2*globalnCV) * nwalk ]
  //  vHS:             [ NMO*NMO * nwalk ] (3 copies)
  // memory_needs: nwalk * ( localnCV + NMO*NMO )

  StaticSPMatrix X({long(localnCV), long(nwalk)},
                   buffer_manager.get_generator().template get_allocator<SPComplexType>());
  StaticSPMatrix Xsend({long(globalnCV), long(nwalk)},
                       buffer_manager.get_generator().template get_allocator<SPComplexType>());
  StaticSPMatrix Xrecv({long(globalnCV), long(nwalk)},
                       buffer_manager.get_generator().template get_allocator<SPComplexType>());
  StaticMatrix vrecv_buff(vhs_ext, buffer_manager.get_generator().template get_allocator<ComplexType>());
  SPCMatrix_ref vrecv(sp_pointer(make_device_ptr(vrecv_buff.origin())), vhs_ext);
#if defined(MIXED_PRECISION)
  SPCMatrix_ref vsend(sp_pointer(make_device_ptr(vrecv_buff.origin())) + vrecv_buff.num_elements(), vhs_ext);
#else
  StaticSPMatrix vsend(vhs_ext, buffer_manager.get_generator().template get_allocator<SPComplexType>());
#endif
  StaticSPMatrix vHS(vhs_ext, buffer_manager.get_generator().template get_allocator<SPComplexType>());

  // partition G and v for communications: all cores communicate a piece of the matrix
  int vak0, vakN;
  int X0, XN;
  std::tie(X0, XN)     = FairDivideBoundary(TG.getLocalTGRank(), int(Xsend.num_elements()), TG.getNCoresPerTG());
  std::tie(vak0, vakN) = FairDivideBoundary(TG.getLocalTGRank(), int(vHS.num_elements()), TG.getNCoresPerTG());
  MPI_Send_init(to_address(Xsend.origin()) + X0, (XN - X0) * sizeof(SPComplexType), MPI_CHAR, TG.prev_core(), 2345,
                TG.TG().get(), &req_Xsend);
  MPI_Recv_init(to_address(Xrecv.origin()) + X0, (XN - X0) * sizeof(SPComplexType), MPI_CHAR, TG.next_core(), 2345,
                TG.TG().get(), &req_Xrecv);
  MPI_Send_init(to_address(vsend.origin()) + vak0, (vakN - vak0) * sizeof(SPComplexType), MPI_CHAR, TG.prev_core(),
                6789, TG.TG().get(), &req_bpvsend);
  MPI_Recv_init(to_address(vrecv.origin()) + vak0, (vakN - vak0) * sizeof(SPComplexType), MPI_CHAR, TG.next_core(),
                6789, TG.TG().get(), &req_bpvrecv);
  TG.local_barrier();

  auto&& Fields(*wset.getFields());
  assert(Fields.size(0) >= nbpsteps);
  assert(Fields.size(1) == globalnCV);
  assert(Fields.size(2) == nwalk);

  int nrow(NMO * ((walker_type == NONCOLLINEAR) ? 2 : 1));
  int ncol(NAEA + ((walker_type == CLOSED) ? 0 : NAEB));
  assert(Refs.size(0) == nwalk);
  int nrefs = Refs.size(1);
  assert(Refs.size(2) == nrow * ncol);

  int cv0, cvN;
  std::tie(cv0, cvN) = FairDivideBoundary(TG.getLocalTGRank(), localnCV, TG.getNCoresPerTG());
  int r0, rN;
  std::tie(r0, rN) = FairDivideBoundary(TG.getLocalTGRank(), nrow * ncol, TG.getNCoresPerTG());

  MPI_Status st;

  int nx = 1;
  if (walker_type == COLLINEAR)
    nx = 2;

  assert(detR.size(0) == nwalk);
  assert(detR.size(1) == nrefs * nx);
  std::fill_n(detR.origin(), detR.num_elements(), SPComplexType(1.0, 0.0));

  // from now on, individual work on each walker/step
  const int ntasks_per_core     = int(nx * nwalk) / TG.getNCoresPerTG();
  const int ntasks_total_serial = ntasks_per_core * TG.getNCoresPerTG();
  const int nextra              = int(nx * nwalk) - ntasks_total_serial;

  // each core handles ntasks_per_core tasks serially
  const int tk0 = TG.getLocalTGRank() * ntasks_per_core;
  const int tkN = (TG.getLocalTGRank() + 1) * ntasks_per_core;

  // make new communicator if nextra changed from last setting
  reset_nextra(nextra);

  // to avoid modifying the existing routines, the walkers' SlaterMatrix is stashed in
  // SlaterMatrixAux and the back-propagated references are copied into SlaterMatrix
  // 0. copy SlaterMatrix to SlaterMatrixAux
  for (int i = 0; i < nwalk; i++)
  {
    copy_n((*wset[i].SlaterMatrix(Alpha)).origin() + r0, rN - r0, (*wset[i].SlaterMatrixAux(Alpha)).origin() + r0);
    // optimize for the single reference case
    if (nrefs == 1)
      copy_n(Refs[i][0].origin() + r0, rN - r0, (*wset[i].SlaterMatrix(Alpha)).origin() + r0);
  }
  TG.TG_local().barrier();

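  // Back propagation proper: walk the stored auxiliary fields backwards (ni = nbpsteps-1 .. 0),
  // rebuild vHS for each stored step with the same ring-accumulation pattern used in step()
  // (here circulating X instead of G), and apply the propagators with the 'H' (adjoint) variant.
  // Walkers are re-orthogonalized every nStabalize steps and always at ni == 0.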
  for (int ni = nbpsteps - 1; ni >= 0; --ni)
  {
    // 1. Get X(nCV,nwalk) from wset
    fill_n(make_device_ptr(vsend.origin()) + vak0, (vakN - vak0), zero);
    copy_n(Fields[ni].origin() + X0, (XN - X0), make_device_ptr(Xsend.origin()) + X0);
    TG.TG_local().barrier();
    copy_n(make_device_ptr(Xsend[global_origin + cv0].origin()), nwalk * (cvN - cv0), make_device_ptr(X[cv0].origin()));
    TG.TG_local().barrier();
    // 2. Calculate vHS(M*M,nwalk)/vHS(nwalk,M*M) using distributed algorithm
    for (int k = 0; k < nnodes; ++k)
    {
      // 2.1 wait for communication of previous step
      if (k > 0)
      {
        MPI_Wait(&req_Xrecv, &st);
        MPI_Wait(&req_Xsend, &st); // need to wait for Xsend before overwriting Xsend
        copy_n(make_device_ptr(Xrecv.origin()) + X0, XN - X0, make_device_ptr(Xsend.origin()) + X0);
        TG.local_barrier();
        copy_n(make_device_ptr(Xsend[global_origin + cv0].origin()), nwalk * (cvN - cv0),
               make_device_ptr(X[cv0].origin()));
        TG.local_barrier();
      }

      // 2.2 setup next communication
      if (k < nnodes - 1)
      {
        MPI_Start(&req_Xsend);
        MPI_Start(&req_Xrecv);
      }

      // 2.3 Calculate vHS
      //std::cout<<" k, X: " <<TG.Global().rank() <<" " <<k <<" " <<ma::dot(X(X.extension(0),0),X(X.extension(0),0)) <<"\n\n" <<std::endl;
      wfn.vHS(X, vHS);
      //std::cout<<" k, vHS: " <<TG.Global().rank() <<" " <<k <<" " <<ma::dot(vHS[0],vHS[0]) <<"\n\n" <<std::endl;

      // 2.4 receive v
      if (k > 0)
      {
        MPI_Wait(&req_bpvrecv, &st);
        MPI_Wait(&req_bpvsend, &st);
        copy_n(make_device_ptr(vrecv.origin()) + vak0, vakN - vak0, make_device_ptr(vsend.origin()) + vak0);
      }
      TG.local_barrier();

      // 2.5 add local contribution to vsend
      using ma::axpy;
      axpy(vakN - vak0, one, make_device_ptr(vHS.origin()) + vak0, 1, make_device_ptr(vsend.origin()) + vak0, 1);
      //std::cout<<" k vsend: " <<TG.Global().rank() <<" " <<k <<" " <<ma::dot(vsend[0],vsend[0]) <<"\n\n" <<std::endl;

      // 2.6 start v communication
      MPI_Start(&req_bpvsend);
      MPI_Start(&req_bpvrecv);
      TG.local_barrier();
    }

    MPI_Wait(&req_bpvrecv, &st);
    MPI_Wait(&req_bpvsend, &st);
    TG.local_barrier();
    //std::cout<<" vrecv: " <<TG.Global().rank() <<" " <<ma::dot(vrecv[0],vrecv[0]) <<"\n\n" <<std::endl;

#if defined(MIXED_PRECISION)
    TG.local_barrier();
    if (TG.TG_local().root())
      inplace_cast<SPComplexType, ComplexType>(vrecv.origin(), vrecv.num_elements());
    TG.local_barrier();
#endif
    C3Tensor_ref vHS3D(make_device_ptr(vrecv_buff.origin()), vhs3d_ext);

    for (int nr = 0; nr < nrefs; ++nr)
    {
      // 3. copy reference to SlaterMatrix
      if (nrefs > 1)
        for (int i = 0; i < nwalk; i++)
          copy_n(Refs[i][nr].origin() + r0, rN - r0, (*wset[i].SlaterMatrix(Alpha)).origin() + r0);
      TG.TG_local().barrier();

      // 4. Propagate walkers
      if (nbatched_propagation != 0)
        apply_propagators_batched('H', wset, 0, vHS3D);
      else
        apply_propagators('H', wset, 0, tk0, tkN, ntasks_total_serial, vHS3D);
      TG.local_barrier();

      // always end (ni==0) with orthogonalization
      if (ni == 0 || ni % nStabalize == 0)
      {
        // orthogonalize
        if (nbatched_qr != 0)
        {
          if (walker_type != COLLINEAR)
            Orthogonalize_batched(wset, detR(detR.extension(0), {nr, nr + 1}));
          else
            Orthogonalize_batched(wset, detR(detR.extension(0), {2 * nr, 2 * nr + 2}));
        }
        else
        {
          if (walker_type != COLLINEAR)
            Orthogonalize_shared(wset, detR(detR.extension(0), {nr, nr + 1}));
          else
            Orthogonalize_shared(wset, detR(detR.extension(0), {2 * nr, 2 * nr + 2}));
        }
      }

      // 5. copy reference back
      if (nrefs > 1)
        for (int i = 0; i < nwalk; i++)
          copy_n((*wset[i].SlaterMatrix(Alpha)).origin() + r0, rN - r0, Refs[i][nr].origin() + r0);
      TG.TG_local().barrier();
    }
  }

  // 6. restore the Slater Matrix
  for (int i = 0; i < nwalk; i++)
  {
    if (nrefs == 1)
      copy_n((*wset[i].SlaterMatrix(Alpha)).origin() + r0, rN - r0, Refs[i][0].origin() + r0);
    copy_n((*wset[i].SlaterMatrixAux(Alpha)).origin() + r0, rN - r0, (*wset[i].SlaterMatrix(Alpha)).origin() + r0);
  }
  MPI_Request_free(&req_Xrecv);
  MPI_Request_free(&req_Xsend);
  MPI_Request_free(&req_bpvrecv);
  MPI_Request_free(&req_bpvsend);
  TG.TG_local().barrier();
}


} // namespace afqmc

} // namespace qmcplusplus