1 //////////////////////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source License.
3 // See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2016 Jeongnim Kim and QMCPACK developers.
6 //
7 // File developed by: Miguel Morales, moralessilva2@llnl.gov, Lawrence Livermore National Laboratory
8 //
9 // File created by: Miguel Morales, moralessilva2@llnl.gov, Lawrence Livermore National Laboratory
10 //////////////////////////////////////////////////////////////////////////////////////
11 
12 #ifndef QMCPLUSPLUS_AFQMC_WALKERCONTROL_HPP
13 #define QMCPLUSPLUS_AFQMC_WALKERCONTROL_HPP
14 
15 
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <memory>
#include <numeric>
#include <stack>
#include <stdexcept>
#include <tuple>
#include <utility>
#include <vector>
#include <mpi.h>
#include "AFQMC/config.h"
#include "Utilities/FairDivide.h"

#include "AFQMC/Walkers/WalkerConfig.hpp"
#include "AFQMC/Walkers/WalkerUtilities.hpp"

#include "mpi3/communicator.hpp"
#include "mpi3/request.hpp"
29 
30 namespace qmcplusplus
31 {
32 namespace afqmc
33 {
34 /** swap Walkers with Recv/Send
35  *
36  * The algorithm ensures that the load per node can differ only by one walker.
37  * The communication is one-dimensional.
38  * Wexcess is an object with multi::array concept which contains walkers beyond the expected
39  * pupolation target.
40  */
41 template<class WlkBucket, class Mat, class IVec = std::vector<int>>
swapWalkersSimple(WlkBucket & wset,Mat && Wexcess,IVec & CurrNumPerNode,IVec & NewNumPerNode,communicator & comm)42 inline int swapWalkersSimple(WlkBucket& wset,
43                              Mat&& Wexcess,
44                              IVec& CurrNumPerNode,
45                              IVec& NewNumPerNode,
46                              communicator& comm)
47 {
48   int wlk_size = wset.single_walker_size() + wset.single_walker_bp_size();
49   int NumContexts, MyContext;
50   NumContexts = comm.size();
51   MyContext   = comm.rank();
52   static_assert(std::decay<Mat>::type::dimensionality == 2, "Wrong dimensionality");
53   if (wlk_size != Wexcess.size(1))
54     throw std::runtime_error("Array dimension error in swapWalkersSimple().");
55   if (1 != Wexcess.stride(1))
56     throw std::runtime_error("Array shape error in swapWalkersSimple().");
57   if (CurrNumPerNode.size() < NumContexts || NewNumPerNode.size() < NumContexts)
58     throw std::runtime_error("Array dimension error in swapWalkersSimple().");
59   if (wset.capacity() < NewNumPerNode[MyContext])
60     throw std::runtime_error("Insufficient capacity in swapWalkersSimple().");
61   std::vector<int> minus, plus;
62   int deltaN = 0;
63   for (int ip = 0; ip < NumContexts; ip++)
64   {
65     int dn = CurrNumPerNode[ip] - NewNumPerNode[ip];
66     if (ip == MyContext)
67       deltaN = dn;
68     if (dn > 0)
69     {
70       plus.insert(plus.end(), dn, ip);
71     }
72     else if (dn < 0)
73     {
74       minus.insert(minus.end(), -dn, ip);
75     }
76   }
77   int nswap = std::min(plus.size(), minus.size());
78   int nsend = 0;
79   if (deltaN <= 0 && wset.size() != CurrNumPerNode[MyContext])
80     throw std::runtime_error("error in swapWalkersSimple().");
81   if (deltaN > 0 && (wset.size() != NewNumPerNode[MyContext] || int(Wexcess.size(0)) != deltaN))
82     throw std::runtime_error("error in swapWalkersSimple().");
83   std::vector<ComplexType> buff;
84   if (deltaN < 0)
85     buff.resize(wlk_size);
86   for (int ic = 0; ic < nswap; ic++)
87   {
88     if (plus[ic] == MyContext)
89     {
90       comm.send_n(Wexcess[nsend].origin(), Wexcess[nsend].size(), minus[ic], plus[ic] + 999);
91       ++nsend;
92     }
93     if (minus[ic] == MyContext)
94     {
95       comm.receive_n(buff.data(), buff.size(), plus[ic], plus[ic] + 999);
96       wset.push_walkers(boost::multi::array_ref<ComplexType, 2>(buff.data(), {1, wlk_size}));
97     }
98   }
99   return nswap;
100 }
101 
102 /** swap Walkers with Irecv/Send
103  *
104  * The algorithm ensures that the load per node can differ only by one walker.
105  * The communication is one-dimensional.
106  */
107 template<class WlkBucket, class Mat, class IVec = std::vector<int>>
108 // eventually generalize MPI_Comm to a MPI wrapper
swapWalkersAsync(WlkBucket & wset,Mat && Wexcess,IVec & CurrNumPerNode,IVec & NewNumPerNode,communicator & comm)109 inline int swapWalkersAsync(WlkBucket& wset,
110                             Mat&& Wexcess,
111                             IVec& CurrNumPerNode,
112                             IVec& NewNumPerNode,
113                             communicator& comm)
114 {
115   int wlk_size = wset.single_walker_size() + wset.single_walker_bp_size();
116   int NumContexts, MyContext;
117   NumContexts = comm.size();
118   MyContext   = comm.rank();
119   static_assert(std::decay<Mat>::type::dimensionality == 2, "Wrong dimensionality");
120   if (wlk_size != Wexcess.size(1))
121     throw std::runtime_error("Array dimension error in swapWalkersAsync().");
122   if (1 != Wexcess.stride(1) || (Wexcess.size(0) > 0 && Wexcess.size(1) != Wexcess.stride(0)))
123     throw std::runtime_error("Array shape error in swapWalkersAsync().");
124   if (CurrNumPerNode.size() < NumContexts || NewNumPerNode.size() < NumContexts)
125     throw std::runtime_error("Array dimension error in swapWalkersAsync().");
126   if (wset.capacity() < NewNumPerNode[MyContext])
127     throw std::runtime_error("Insufficient capacity in swapWalkersAsync().");
128   std::vector<int> minus, plus;
129   int deltaN = 0;
130   for (int ip = 0; ip < NumContexts; ip++)
131   {
132     int dn = CurrNumPerNode[ip] - NewNumPerNode[ip];
133     if (ip == MyContext)
134       deltaN = dn;
135     if (dn > 0)
136     {
137       plus.insert(plus.end(), dn, ip);
138     }
139     else if (dn < 0)
140     {
141       minus.insert(minus.end(), -dn, ip);
142     }
143   }
144   int nswap     = std::min(plus.size(), minus.size());
145   int nsend     = 0;
146   int countSend = 1;
147   if (deltaN <= 0 && wset.size() != CurrNumPerNode[MyContext])
148     throw std::runtime_error("error(1) in swapWalkersAsync().");
149   if (deltaN > 0 && (wset.size() != NewNumPerNode[MyContext] || int(Wexcess.size(0)) != deltaN))
150     throw std::runtime_error("error(2) in swapWalkersAsync().");
151   std::vector<ComplexType*> buffers;
152   std::vector<boost::mpi3::request> requests;
153   std::vector<int> recvCounts;
154   for (int ic = 0; ic < nswap; ic++)
155   {
156     if (plus[ic] == MyContext)
157     {
158       if ((ic < nswap - 1) && (plus[ic] == plus[ic + 1]) && (minus[ic] == minus[ic + 1]))
159       {
160         countSend++;
161       }
162       else
163       {
164         requests.emplace_back(comm.isend(Wexcess[nsend].origin(), Wexcess[nsend].origin() + countSend * Wexcess.size(1),
165                                          minus[ic], plus[ic] + 1999));
166         nsend += countSend;
167         countSend = 1;
168       }
169     }
170     if (minus[ic] == MyContext)
171     {
172       if ((ic < nswap - 1) && (plus[ic] == plus[ic + 1]) && (minus[ic] == minus[ic + 1]))
173       {
174         countSend++;
175       }
176       else
177       {
178         ComplexType* bf = new ComplexType[countSend * wlk_size];
179         buffers.push_back(bf);
180         recvCounts.push_back(countSend);
181         requests.emplace_back(comm.ireceive_n(bf, countSend * wlk_size, plus[ic], plus[ic] + 1999));
182         countSend = 1;
183       }
184     }
185   }
186   if (deltaN < 0)
187   {
188     // receiving nodes
189     for (int ip = 0; ip < requests.size(); ++ip)
190     {
191       requests[ip].wait();
192       wset.push_walkers(boost::multi::array_ref<ComplexType, 2>(buffers[ip], {recvCounts[ip], wlk_size}));
193       delete[] buffers[ip];
194     }
195   }
196   else
197   {
198     // sending nodes
199     for (int ip = 0; ip < requests.size(); ++ip)
200       requests[ip].wait();
201   }
202   return nswap;
203 }
204 
205 
206 /**
207  * Implements Cafarrel's minimum branching algorithm.
208  *   - buff: array of walker info (weight,num).
209  */
210 template<class Random>
min_branch(std::vector<std::pair<double,int>> & buff,Random & rng,double max_c,double min_c)211 inline void min_branch(std::vector<std::pair<double, int>>& buff, Random& rng, double max_c, double min_c)
212 {
213   APP_ABORT(" Error: min_branch not implemented yet. \n\n\n");
214 }
215 
216 /**
217  * Implements Cafarrel's minimum branching algorithm.
218  *   - buff: array of walker info (weight,num).
219  */
220 template<class Random>
serial_comb(std::vector<std::pair<double,int>> & buff,Random & rng)221 inline void serial_comb(std::vector<std::pair<double, int>>& buff, Random& rng)
222 {
223   APP_ABORT(" Error: serial_comb not implemented yet. \n\n\n");
224 }
225 
226 /**
227  * Implements the paired branching algorithm on a popultion of walkers,
228  * given a list of walker weights. For each walker in the list, returns the weight
229  * and number of times the walker should appear in the new list.
230  *   - buff: array of walker info (weight,num).
231  */
232 template<class Random>
pair_branch(std::vector<std::pair<double,int>> & buff,Random & rng,double max_c,double min_c)233 inline void pair_branch(std::vector<std::pair<double, int>>& buff, Random& rng, double max_c, double min_c)
234 {
235   typedef std::tuple<double, int, int> tp;
236   typedef std::vector<tp>::iterator tp_it;
237   // slow for now, not efficient!!!
238   int nw = buff.size();
239   std::vector<tp> wlks(nw);
240   for (int i = 0; i < nw; i++)
241     wlks[i] = tp{buff[i].first, 1, i};
242 
243   std::sort(wlks.begin(), wlks.end(), [](const tp& a, const tp& b) { return std::get<0>(a) < std::get<0>(b); });
244 
245   tp_it it_s = wlks.begin();
246   tp_it it_l = wlks.end() - 1;
247 
248   while (it_s < it_l)
249   {
250     if (std::abs(std::get<0>(*it_s)) < min_c || std::abs(std::get<0>(*it_l)) > max_c)
251     {
252       double w12 = std::get<0>(*it_s) + std::get<0>(*it_l);
253       if (rng() < std::get<0>(*it_l) / w12)
254       {
255         std::get<0>(*it_l) = 0.5 * w12;
256         std::get<0>(*it_s) = 0.0;
257         std::get<1>(*it_l) = 2;
258         std::get<1>(*it_s) = 0;
259       }
260       else
261       {
262         std::get<0>(*it_s) = 0.5 * w12;
263         std::get<0>(*it_l) = 0.0;
264         std::get<1>(*it_s) = 2;
265         std::get<1>(*it_l) = 0;
266       }
267       it_s++;
268       it_l--;
269     }
270     else
271       break;
272   }
273 
274   int nnew  = 0;
275   int nzero = 0;
276   for (auto& w : wlks)
277   {
278     buff[std::get<2>(w)] = {std::get<0>(w), std::get<1>(w)};
279     nnew += std::get<1>(w);
280     if (std::get<1>(w) > 0 && std::abs(std::get<0>(w)) < 1e-7)
281       nzero++;
282   }
283   if (nzero > 0)
284   {
285     app_error() << " Error in pair_branch: nzero>0: " << nzero << std::endl;
286     app_error() << " Found walkers with zero weight after branch.\n"
287                 << " Try reducing subSteps or reducing the time step." << std::endl;
288     APP_ABORT("Error in pair_branch.");
289   }
290   if (nw != nnew)
291     APP_ABORT("Error: Problems with pair_branching.\n");
292 }
293 
294 /**
295  * Implements the serial branching algorithm on the set of walkers.
296  * Serial branch involves gathering the list of weights on the root node
297  * and making the decisions locally. The new list of walker weights is then bcasted.
298  * This implementation requires contiguous walkers and fixed population walker sets.
299  */
300 template<class WalkerSet,
301          class Mat,
302          class Random,
303          typename = typename std::enable_if<(WalkerSet::contiguous_walker)>::type,
304          typename = typename std::enable_if<(WalkerSet::fixed_population)>::type>
SerialBranching(WalkerSet & wset,BRANCHING_ALGORITHM type,double min_,double max_,std::vector<int> & wlk_counts,Mat & Wexcess,Random & rng,communicator & comm)305 inline void SerialBranching(WalkerSet& wset,
306                             BRANCHING_ALGORITHM type,
307                             double min_,
308                             double max_,
309                             std::vector<int>& wlk_counts,
310                             Mat& Wexcess,
311                             Random& rng,
312                             communicator& comm)
313 {
314   std::vector<std::pair<double, int>> buffer(wset.get_global_target_population());
315 
316   // assemble list of weights
317   getGlobalListOfWalkerWeights(wset, buffer, comm);
318 
319   // using global weight list, use pair branching algorithm
320   if (comm.root())
321   {
322     if (type == PAIR)
323       pair_branch(buffer, rng, max_, min_);
324     else if (type == MIN_BRANCH)
325       min_branch(buffer, rng, max_, min_);
326     else if (type == SERIAL_COMB)
327       serial_comb(buffer, rng);
328     else
329       APP_ABORT("Error: Unknown branching type in SerialBranching. \n");
330   }
331 
332   // bcast walker information and calculate new walker counts locally
333   comm.broadcast_n(buffer.data(), buffer.size());
334 
335   int target = wset.get_TG_target_population();
336   wlk_counts.resize(comm.size());
337   for (int i = 0, p = 0; i < comm.size(); i++)
338   {
339     int cnt = 0;
340     for (int k = 0; k < target; k++, p++)
341       cnt += buffer[p].second;
342     wlk_counts[i] = cnt;
343   }
344   if (wset.get_global_target_population() != std::accumulate(wlk_counts.begin(), wlk_counts.end(), 0))
345   {
346     app_error() << " Error: targetN != nwold: " << target << " "
347                 << std::accumulate(wlk_counts.begin(), wlk_counts.end(), 0) << std::endl;
348     APP_ABORT(" Error: targetN != nwold.");
349   }
350 
351   // reserve space for extra walkers
352   if (wlk_counts[comm.rank()] > target)
353     Wexcess.reextent(
354         {std::max(0, wlk_counts[comm.rank()] - target), wset.single_walker_size() + wset.single_walker_bp_size()});
355 
356   // perform local branching
357   // walkers beyond target go in Wexcess
358   wset.branch(buffer.begin() + target * comm.rank(), buffer.begin() + target * (comm.rank() + 1), Wexcess);
359 }
360 
361 /**
362  * Implements the distributed comb branching algorithm.
363  */
364 template<class WalkerSet,
365          class Mat,
366          class Random,
367          typename = typename std::enable_if<(WalkerSet::contiguous_walker)>::type,
368          typename = typename std::enable_if<(WalkerSet::fixed_population)>::type>
CombBranching(WalkerSet & wset,BRANCHING_ALGORITHM type,std::vector<int> & wlk_counts,Mat & Wexcess,Random & rng,communicator & comm)369 inline void CombBranching(WalkerSet& wset,
370                           BRANCHING_ALGORITHM type,
371                           std::vector<int>& wlk_counts,
372                           Mat& Wexcess,
373                           Random& rng,
374                           communicator& comm)
375 {
376   APP_ABORT("Error: comb not implemented yet. \n");
377 }
378 
379 } // namespace afqmc
380 
381 } // namespace qmcplusplus
382 
383 #endif
384