1//////////////////////////////////////////////////////////////////////////////////////
2// This file is distributed under the University of Illinois/NCSA Open Source License.
3// See LICENSE file in top directory for details.
4//
5// Copyright (c) 2016 Jeongnim Kim and QMCPACK developers.
6//
7// File developed by: Miguel Morales, moralessilva2@llnl.gov, Lawrence Livermore National Laboratory
8//
9// File created by: Miguel Morales, moralessilva2@llnl.gov, Lawrence Livermore National Laboratory
10//////////////////////////////////////////////////////////////////////////////////////
11
12#ifndef QMCPLUSPLUS_AFQMC_SHAREDWALKERSET_ICC
13#define QMCPLUSPLUS_AFQMC_SHAREDWALKERSET_ICC
14
15#include <cassert>
16#include <cstdlib>
17
18namespace qmcplusplus
19{
20namespace afqmc
21{
22template<class Alloc, typename Ptr>
23void WalkerSetBase<Alloc, Ptr>::parse(xmlNodePtr cur)
24{
25  if (cur == NULL)
26    APP_ABORT(" Error: Empty Walker xml-node pointer. \n");
27
28  app_log() << "\n****************************************************\n";
29  app_log() << "           Initializing Walker Set \n";
30  app_log() << "****************************************************\n";
31
32  xmlNodePtr curRoot = cur;
33  OhmmsAttributeSet oAttrib;
34  oAttrib.add(name, "name");
35  oAttrib.put(cur);
36
37  std::string type              = "collinear";
38  std::string load_balance_type = "async";
39  std::string pop_control_type  = "pair";
40
41  ParameterSet m_param;
42  m_param.add(max_weight, "max_weight");
43  m_param.add(min_weight, "min_weight");
44  m_param.add(type, "walker_type");
45  m_param.add(load_balance_type, "load_balance");
46  m_param.add(pop_control_type, "pop_control");
47  //    m_param.add(nback_prop,"back_propagation_steps");
48  m_param.put(cur);
49
50  std::for_each(type.begin(), type.end(), [](char& c) { c = ::tolower(c); });
51  if (type.find("closed") != std::string::npos)
52  {
53    app_log() << " Using a closed-shell (closed-shell RHF) walker. \n";
54    walkerType = CLOSED;
55  }
56  else if (type.find("non-collinear") != std::string::npos)
57  {
58    app_log() << " Using a non-collinear (GHF) walker. \n";
59    walkerType = NONCOLLINEAR;
60  }
61  else if (type.find("noncollinear") != std::string::npos)
62  {
63    app_log() << " Using a non-collinear (GHF) walker. \n";
64    walkerType = NONCOLLINEAR;
65  }
66  else if (type.find("collinear") != std::string::npos)
67  {
68    app_log() << " Using a collinear (UHF/ROHF) walker. \n";
69    walkerType = COLLINEAR;
70  }
71  else
72  {
73    app_error() << " Error: Unknown walker type: " << type << std::endl;
74    APP_ABORT("");
75  }
76
77  std::for_each(load_balance_type.begin(), load_balance_type.end(), [](char& c) { c = ::tolower(c); });
78  if (load_balance_type.find("simple") != std::string::npos)
79  {
80    app_log() << " Using blocking (1-1) swap load balancing algorithm. "
81              << "\n";
82    load_balance = SIMPLE;
83  }
84  else if (load_balance_type.find("async") != std::string::npos)
85  {
86    app_log() << " Using asynchronous non-blocking swap load balancing algorithm. "
87              << "\n";
88    load_balance = ASYNC;
89  }
90  else
91  {
92    app_error() << " Error: Unknown load balancing algorithm: " << load_balance_type << " \n";
93    APP_ABORT("");
94  }
95
96  std::for_each(pop_control_type.begin(), pop_control_type.end(), [](char& c) { c = ::tolower(c); });
97  if (pop_control_type.find("pair") != std::string::npos)
98  {
99    app_log() << " Using population control algorithm based on paired walker branching ( a la QWalk). \n";
100    pop_control = PAIR;
101  }
102  else if (pop_control_type.find("serial_comb") != std::string::npos)
103  {
104    app_log() << " Using population control algorithm based on comb method (See Booth, Gubernatis, PRE 2009). \n";
105    pop_control = SERIAL_COMB;
106  }
107  else if (pop_control_type.find("comb") != std::string::npos)
108  {
109    app_log() << " Using population control algorithm based on comb method (See Booth, Gubernatis, PRE 2009). \n";
110    pop_control = COMB;
111  }
112  else if (pop_control_type.find("min") != std::string::npos)
113  {
114    app_log() << " Using population control algorithm based on minimum reconfiguration (Caffarel et al., 2000). \n";
115    pop_control = MIN_BRANCH;
116  }
117  else
118  {
119    app_error() << " Error: Unknown population control algorithm: " << pop_control_type << "\n";
120    APP_ABORT("");
121  }
122
123  cur = curRoot->children;
124  while (cur != NULL)
125  {
126    std::string cname((const char*)(cur->name));
127    if (cname == "something") {}
128    cur = cur->next;
129  }
130  app_log() << std::endl;
131}
132
133template<class Alloc, typename Ptr>
134void WalkerSetBase<Alloc, Ptr>::setup()
135{
136  TimerNameList_t<WalkerSetBaseTimers> WalkerSetBaseTimerNames = {{LoadBalance_t, "WalkerSetBase::loadBalance"},
137                                                                  {Branching_t, "WalkerSetBase::branching"}};
138
139  setup_timers(Timers, WalkerSetBaseTimerNames, timer_level_coarse);
140
141  // careful! These are only used to calculate memory needs and access points/partitionings
142  int ncol = NAEA;
143  int nrow = NMO;
144  // wlk_descriptor: {nmo, naea, naeb, nback_prop, nCV, nRefs, nHist}
145  if (walkerType == CLOSED)
146  {
147    wlk_desc = {NMO, NAEA, 0, 0, 0, 0, 0};
148  }
149  else if (walkerType == COLLINEAR)
150  {
151    wlk_desc = {NMO, NAEA, NAEB, 0, 0, 0, 0};
152    ncol += NAEB;
153  }
154  else if (walkerType == NONCOLLINEAR)
155  {
156    wlk_desc = {2 * NMO, NAEA + NAEB, 0, 0, 0, 0, 0};
157    nrow += NMO;
158    ncol += NAEB;
159  }
160  else
161  {
162    app_error() << " Error: Incorrect walker_type on WalkerSetBase::setup \n";
163    APP_ABORT("");
164  }
165
166  //   - SlaterMatrix:         NCOL*NROW
167  //   - weight:               1
168  //   - phase:                1
169  //   - pseudo energy:        1
170  //   - E1:                   1
171  //   - EXX:                  1
172  //   - EJ:                   1
173  //   - overlap:              1
174  //   - SlaterMatrixN:        Same size as Slater Matrix
175  //   - SlaterMatrixAux:        Same size as Slater Matrix
176  //   Total: 7+2*NROW*NCOL+BP_SIZE+2*NBACK_PROP
177  int cnt        = 0;
178  data_displ[SM] = cnt;
179  cnt += nrow * ncol;
180  data_displ[WEIGHT] = cnt;
181  cnt += 1; // weight
182  data_displ[PHASE] = cnt;
183  cnt += 1; // phase
184  data_displ[PSEUDO_ELOC_] = cnt;
185  cnt += 1; // pseudo energy
186  data_displ[E1_] = cnt;
187  cnt += 1; // E1
188  data_displ[EXX_] = cnt;
189  cnt += 1; // EXX
190  data_displ[EJ_] = cnt;
191  cnt += 1; // EJ
192  data_displ[OVLP] = cnt;
193  cnt += 1; // overlap
194  walker_size                = cnt;
195  walker_memory_usage        = walker_size * sizeof(ComplexType);
196  data_displ[SMN]            = -1;
197  data_displ[SM_AUX]         = -1;
198  data_displ[FIELDS]         = -1;
199  data_displ[WEIGHT_FAC]     = -1;
200  data_displ[WEIGHT_HISTORY] = -1;
201  bp_walker_size             = 0;
202  bp_walker_memory_usage     = bp_walker_size * sizeof(ComplexType);
203
204  tot_num_walkers = 0;
205
206  min_weight = std::max(std::abs(min_weight), 1e-2);
207}
208
209template<class Alloc, typename Ptr>
210bool WalkerSetBase<Alloc, Ptr>::clean()
211{
212  walker_buffer.reextent({0, walker_size});
213  bp_buffer.reextent({bp_walker_size, 0});
214  tot_num_walkers = targetN = targetN_per_TG = 0;
215  return true;
216}
217
218/*
219 * Increases the capacity of the containers to n.
220 */
221template<class Alloc, typename Ptr>
222void WalkerSetBase<Alloc, Ptr>::reserve(int n)
223{
224  if (walker_buffer.size(0) < n || walker_buffer.size(1) != walker_size)
225    walker_buffer.reextent({n, walker_size});
226  if (bp_buffer.size(1) < n || bp_buffer.size(0) != bp_walker_size)
227  {
228    bp_buffer.reextent({bp_walker_size, n});
229    using std::fill_n;
230    fill_n(bp_buffer.origin(), bp_buffer.num_elements(), bp_element(0));
231  }
232}
233
234/*
235 * Adds/removes the number of walkers in the set to match the requested value.
236 * Walkers are removed from the end of the set
237 *     and buffer capacity remains unchanged in this case.
238 * New walkers are initialized from already existing walkers in a round-robin fashion.
239 * If the set is empty, calling this routine will abort.
240 * Capacity is increased if necessary.
241 * Target Populations are set to n.
242 */
243template<class Alloc, typename Ptr>
244void WalkerSetBase<Alloc, Ptr>::resize(int n)
245{
246  if (tot_num_walkers == 0)
247    APP_ABORT("error: empty set in resize(n).\n");
248
249  reserve(n);
250  if (n > tot_num_walkers)
251  {
252    if (TG.TG_local().root())
253    {
254      auto pos = tot_num_walkers;
255      auto i0  = 0;
256      while (pos < n)
257      {
258        walker_buffer[pos++] = walker_buffer[i0];
259        i0                   = (i0 + 1) % tot_num_walkers;
260      }
261    }
262  }
263  tot_num_walkers = n;
264  targetN_per_TG  = tot_num_walkers;
265  targetN         = GlobalPopulation();
266  if (targetN != targetN_per_TG * TG.getNumberOfTGs())
267  {
268    app_error() << " targetN, targetN_per_TG, # of TGs: " << targetN << " " << targetN_per_TG << " "
269                << TG.getNumberOfTGs() << std::endl;
270    APP_ABORT("Error in WalkerSetBase::resize(n).\n");
271  }
272}
273
274//  curData:
275//  0: factor used to rescale the weights
276//  1: sum_i w_i * Eloc_i   (where w_i is the unnormalized weight)
277//  2: sum_i w_i            (where w_i is the unnormalized weight)
278//  3: sum_i abs(w_i)       (where w_i is the unnormalized weight)
279//  4: sum_i abs(<psi_T|phi_i>)
280//  5: total number of walkers
281//  6: total number of "healthy" walkers (those with weight > 1e-6, ovlp>1e-8, etc)
282template<class Alloc, typename Ptr>
283void WalkerSetBase<Alloc, Ptr>::popControl(std::vector<ComplexType>& curData)
284{
285  Timers[Branching_t].get().start();
286  ComplexType minus = ComplexType(-1.0, 0.0);
287
288  curData.resize(7);
289  using std::fill;
290  fill(curData.begin(), curData.begin() + 7, ComplexType(0));
291
292  // safety check
293  if (tot_num_walkers != targetN_per_TG)
294    APP_ABORT("Error: tot_num_walkers!=targetN_per_TG");
295
296  // gather data and walker information
297  if (TG.TG_local().root())
298  {
299    afqmc::BasicWalkerData(*this, curData, TG.TG_heads());
300    RealType scl = 1.0 / curData[0].real();
301    scaleWeight(scl, true);
302  }
303  if (TG.TG_local().size() > 1)
304    TG.TG_local().broadcast_n(curData.data(), curData.size());
305  // by default, LogOverlapFactor is set to the walker mean in popControl
306  // this comes at no extra cost and keeps the values of overlaps stable
307  adjustLogOverlapFactor(std::log(std::abs(curData[4])));
308
309  // matrix to hold walkers beyond targetN_per_TG
310  // doing this to avoid resizing SHMBuffer, instead use local memory
311  // will be resized later
312  boost::multi::array<ComplexType, 2> Wexcess({0, walker_size + (wlk_desc[3] > 0 ? bp_walker_size : 0)});
313
314  if (TG.TG_local().root())
315  {
316    nwalk_counts_new.resize(TG.TG_heads().size());
317    std::fill(nwalk_counts_new.begin(), nwalk_counts_new.end(), targetN_per_TG);
318  }
319
320  // population control on master node
321  if (pop_control == PAIR || pop_control == SERIAL_COMB || pop_control == MIN_BRANCH)
322  {
323    if (TG.TG_local().root())
324      SerialBranching(*this, pop_control, min_weight, max_weight, nwalk_counts_old, Wexcess, *rng, TG.TG_heads());
325
326    // distributed routines from here
327  }
328  else if (pop_control == COMB)
329  {
330    APP_ABORT(" Error: Distributed comb not implemented yet. \n\n\n");
331    //afqmc::DistCombBranching(*this,rng_heads,nwalk_counts_old);
332  }
333  Timers[Branching_t].get().stop();
334
335  Timers[LoadBalance_t].get().start();
336  // load balance after population control events
337  loadBalance(Wexcess);
338  Timers[LoadBalance_t].get().stop();
339
340  if (tot_num_walkers != targetN_per_TG)
341    APP_ABORT(" Error: tot_num_walkers != targetN_per_TG");
342}
343
344template<class Alloc, typename Ptr>
345void WalkerSetBase<Alloc, Ptr>::benchmark(std::string& blist, int maxnW, int delnW, int repeat)
346{
347  if (blist.find("comm") != std::string::npos)
348  {
349    app_log() << " Testing communication times in WalkerHandler. This should be done using a single TG per node, to "
350                 "avoid timing communication between cores on the same node. \n";
351    std::ofstream out;
352    if (TG.getGlobalRank() == 0)
353      out.open("benchmark.icomm.dat");
354
355    std::vector<std::string> tags(3);
356    tags[0] = "M1";
357    tags[1] = "M2";
358    tags[2] = "M3";
359
360    //    for( std::string& str: tags) Timer.reset(str);
361
362    int nw = 1;
363    while (nw <= maxnW)
364    {
365      if (TG.TG_local().root() && (TG.TG_heads().rank() == 0 || TG.TG_heads().rank() == 1))
366      {
367        int sz = nw * walker_size;
368        std::vector<ComplexType> Cbuff(sz);
369        MPI_Request req;
370        MPI_Status st;
371        TG.TG_heads().barrier();
372        for (int i = 0; i < repeat; i++)
373        {
374          if (TG.TG_heads().rank() == 0)
375          {
376            //            Timer.start("M1");
377            MPI_Isend(Cbuff.data(), 2 * Cbuff.size(), MPI_DOUBLE, 1, 999, TG.TG_heads().get(), &req);
378            MPI_Wait(&req, &st);
379            //            Timer.stop("M1");
380          }
381          else
382          {
383            MPI_Irecv(Cbuff.data(), 2 * Cbuff.size(), MPI_DOUBLE, 0, 999, TG.TG_heads().get(), &req);
384            MPI_Wait(&req, &st);
385          }
386        }
387
388        if (TG.TG_heads().rank() == 0)
389        {
390          out << nw << " ";
391          //          for( std::string& str: tags) out<<Timer.total(str)/double(repeat) <<" ";
392          out << std::endl;
393        }
394      }
395      else if (TG.TG_local().root())
396      {
397        TG.TG_heads().barrier();
398      }
399
400      if (delnW <= 0)
401        nw *= 2;
402      else
403        nw += delnW;
404    }
405  }
406  else if (blist.find("comm") != std::string::npos)
407  {
408    std::ofstream out;
409    if (TG.getGlobalRank() == 0)
410      out.open("benchmark.comm.dat");
411  }
412}
413
414} // namespace afqmc
415
416} // namespace qmcplusplus
417
418#endif
419