1 //////////////////////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source License.
3 // See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2016 Jeongnim Kim and QMCPACK developers.
6 //
7 // File developed by: Miguel Morales, moralessilva2@llnl.gov, Lawrence Livermore National Laboratory
8 //
9 // File created by: Miguel Morales, moralessilva2@llnl.gov, Lawrence Livermore National Laboratory
10 //////////////////////////////////////////////////////////////////////////////////////
11
12 #ifndef QMCPLUSPLUS_AFQMC_WALKERCONTROL_HPP
13 #define QMCPLUSPLUS_AFQMC_WALKERCONTROL_HPP
14
15
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <memory>
#include <numeric>
#include <stack>
#include <stdexcept>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
#include <mpi.h>
#include "AFQMC/config.h"
#include "Utilities/FairDivide.h"

#include "AFQMC/Walkers/WalkerConfig.hpp"
#include "AFQMC/Walkers/WalkerUtilities.hpp"

#include "mpi3/communicator.hpp"
#include "mpi3/request.hpp"
29
30 namespace qmcplusplus
31 {
32 namespace afqmc
33 {
34 /** swap Walkers with Recv/Send
35 *
36 * The algorithm ensures that the load per node can differ only by one walker.
37 * The communication is one-dimensional.
38 * Wexcess is an object with multi::array concept which contains walkers beyond the expected
39 * pupolation target.
40 */
41 template<class WlkBucket, class Mat, class IVec = std::vector<int>>
swapWalkersSimple(WlkBucket & wset,Mat && Wexcess,IVec & CurrNumPerNode,IVec & NewNumPerNode,communicator & comm)42 inline int swapWalkersSimple(WlkBucket& wset,
43 Mat&& Wexcess,
44 IVec& CurrNumPerNode,
45 IVec& NewNumPerNode,
46 communicator& comm)
47 {
48 int wlk_size = wset.single_walker_size() + wset.single_walker_bp_size();
49 int NumContexts, MyContext;
50 NumContexts = comm.size();
51 MyContext = comm.rank();
52 static_assert(std::decay<Mat>::type::dimensionality == 2, "Wrong dimensionality");
53 if (wlk_size != Wexcess.size(1))
54 throw std::runtime_error("Array dimension error in swapWalkersSimple().");
55 if (1 != Wexcess.stride(1))
56 throw std::runtime_error("Array shape error in swapWalkersSimple().");
57 if (CurrNumPerNode.size() < NumContexts || NewNumPerNode.size() < NumContexts)
58 throw std::runtime_error("Array dimension error in swapWalkersSimple().");
59 if (wset.capacity() < NewNumPerNode[MyContext])
60 throw std::runtime_error("Insufficient capacity in swapWalkersSimple().");
61 std::vector<int> minus, plus;
62 int deltaN = 0;
63 for (int ip = 0; ip < NumContexts; ip++)
64 {
65 int dn = CurrNumPerNode[ip] - NewNumPerNode[ip];
66 if (ip == MyContext)
67 deltaN = dn;
68 if (dn > 0)
69 {
70 plus.insert(plus.end(), dn, ip);
71 }
72 else if (dn < 0)
73 {
74 minus.insert(minus.end(), -dn, ip);
75 }
76 }
77 int nswap = std::min(plus.size(), minus.size());
78 int nsend = 0;
79 if (deltaN <= 0 && wset.size() != CurrNumPerNode[MyContext])
80 throw std::runtime_error("error in swapWalkersSimple().");
81 if (deltaN > 0 && (wset.size() != NewNumPerNode[MyContext] || int(Wexcess.size(0)) != deltaN))
82 throw std::runtime_error("error in swapWalkersSimple().");
83 std::vector<ComplexType> buff;
84 if (deltaN < 0)
85 buff.resize(wlk_size);
86 for (int ic = 0; ic < nswap; ic++)
87 {
88 if (plus[ic] == MyContext)
89 {
90 comm.send_n(Wexcess[nsend].origin(), Wexcess[nsend].size(), minus[ic], plus[ic] + 999);
91 ++nsend;
92 }
93 if (minus[ic] == MyContext)
94 {
95 comm.receive_n(buff.data(), buff.size(), plus[ic], plus[ic] + 999);
96 wset.push_walkers(boost::multi::array_ref<ComplexType, 2>(buff.data(), {1, wlk_size}));
97 }
98 }
99 return nswap;
100 }
101
102 /** swap Walkers with Irecv/Send
103 *
104 * The algorithm ensures that the load per node can differ only by one walker.
105 * The communication is one-dimensional.
106 */
107 template<class WlkBucket, class Mat, class IVec = std::vector<int>>
108 // eventually generalize MPI_Comm to a MPI wrapper
swapWalkersAsync(WlkBucket & wset,Mat && Wexcess,IVec & CurrNumPerNode,IVec & NewNumPerNode,communicator & comm)109 inline int swapWalkersAsync(WlkBucket& wset,
110 Mat&& Wexcess,
111 IVec& CurrNumPerNode,
112 IVec& NewNumPerNode,
113 communicator& comm)
114 {
115 int wlk_size = wset.single_walker_size() + wset.single_walker_bp_size();
116 int NumContexts, MyContext;
117 NumContexts = comm.size();
118 MyContext = comm.rank();
119 static_assert(std::decay<Mat>::type::dimensionality == 2, "Wrong dimensionality");
120 if (wlk_size != Wexcess.size(1))
121 throw std::runtime_error("Array dimension error in swapWalkersAsync().");
122 if (1 != Wexcess.stride(1) || (Wexcess.size(0) > 0 && Wexcess.size(1) != Wexcess.stride(0)))
123 throw std::runtime_error("Array shape error in swapWalkersAsync().");
124 if (CurrNumPerNode.size() < NumContexts || NewNumPerNode.size() < NumContexts)
125 throw std::runtime_error("Array dimension error in swapWalkersAsync().");
126 if (wset.capacity() < NewNumPerNode[MyContext])
127 throw std::runtime_error("Insufficient capacity in swapWalkersAsync().");
128 std::vector<int> minus, plus;
129 int deltaN = 0;
130 for (int ip = 0; ip < NumContexts; ip++)
131 {
132 int dn = CurrNumPerNode[ip] - NewNumPerNode[ip];
133 if (ip == MyContext)
134 deltaN = dn;
135 if (dn > 0)
136 {
137 plus.insert(plus.end(), dn, ip);
138 }
139 else if (dn < 0)
140 {
141 minus.insert(minus.end(), -dn, ip);
142 }
143 }
144 int nswap = std::min(plus.size(), minus.size());
145 int nsend = 0;
146 int countSend = 1;
147 if (deltaN <= 0 && wset.size() != CurrNumPerNode[MyContext])
148 throw std::runtime_error("error(1) in swapWalkersAsync().");
149 if (deltaN > 0 && (wset.size() != NewNumPerNode[MyContext] || int(Wexcess.size(0)) != deltaN))
150 throw std::runtime_error("error(2) in swapWalkersAsync().");
151 std::vector<ComplexType*> buffers;
152 std::vector<boost::mpi3::request> requests;
153 std::vector<int> recvCounts;
154 for (int ic = 0; ic < nswap; ic++)
155 {
156 if (plus[ic] == MyContext)
157 {
158 if ((ic < nswap - 1) && (plus[ic] == plus[ic + 1]) && (minus[ic] == minus[ic + 1]))
159 {
160 countSend++;
161 }
162 else
163 {
164 requests.emplace_back(comm.isend(Wexcess[nsend].origin(), Wexcess[nsend].origin() + countSend * Wexcess.size(1),
165 minus[ic], plus[ic] + 1999));
166 nsend += countSend;
167 countSend = 1;
168 }
169 }
170 if (minus[ic] == MyContext)
171 {
172 if ((ic < nswap - 1) && (plus[ic] == plus[ic + 1]) && (minus[ic] == minus[ic + 1]))
173 {
174 countSend++;
175 }
176 else
177 {
178 ComplexType* bf = new ComplexType[countSend * wlk_size];
179 buffers.push_back(bf);
180 recvCounts.push_back(countSend);
181 requests.emplace_back(comm.ireceive_n(bf, countSend * wlk_size, plus[ic], plus[ic] + 1999));
182 countSend = 1;
183 }
184 }
185 }
186 if (deltaN < 0)
187 {
188 // receiving nodes
189 for (int ip = 0; ip < requests.size(); ++ip)
190 {
191 requests[ip].wait();
192 wset.push_walkers(boost::multi::array_ref<ComplexType, 2>(buffers[ip], {recvCounts[ip], wlk_size}));
193 delete[] buffers[ip];
194 }
195 }
196 else
197 {
198 // sending nodes
199 for (int ip = 0; ip < requests.size(); ++ip)
200 requests[ip].wait();
201 }
202 return nswap;
203 }
204
205
206 /**
207 * Implements Cafarrel's minimum branching algorithm.
208 * - buff: array of walker info (weight,num).
209 */
210 template<class Random>
min_branch(std::vector<std::pair<double,int>> & buff,Random & rng,double max_c,double min_c)211 inline void min_branch(std::vector<std::pair<double, int>>& buff, Random& rng, double max_c, double min_c)
212 {
213 APP_ABORT(" Error: min_branch not implemented yet. \n\n\n");
214 }
215
216 /**
217 * Implements Cafarrel's minimum branching algorithm.
218 * - buff: array of walker info (weight,num).
219 */
220 template<class Random>
serial_comb(std::vector<std::pair<double,int>> & buff,Random & rng)221 inline void serial_comb(std::vector<std::pair<double, int>>& buff, Random& rng)
222 {
223 APP_ABORT(" Error: serial_comb not implemented yet. \n\n\n");
224 }
225
226 /**
227 * Implements the paired branching algorithm on a popultion of walkers,
228 * given a list of walker weights. For each walker in the list, returns the weight
229 * and number of times the walker should appear in the new list.
230 * - buff: array of walker info (weight,num).
231 */
232 template<class Random>
pair_branch(std::vector<std::pair<double,int>> & buff,Random & rng,double max_c,double min_c)233 inline void pair_branch(std::vector<std::pair<double, int>>& buff, Random& rng, double max_c, double min_c)
234 {
235 typedef std::tuple<double, int, int> tp;
236 typedef std::vector<tp>::iterator tp_it;
237 // slow for now, not efficient!!!
238 int nw = buff.size();
239 std::vector<tp> wlks(nw);
240 for (int i = 0; i < nw; i++)
241 wlks[i] = tp{buff[i].first, 1, i};
242
243 std::sort(wlks.begin(), wlks.end(), [](const tp& a, const tp& b) { return std::get<0>(a) < std::get<0>(b); });
244
245 tp_it it_s = wlks.begin();
246 tp_it it_l = wlks.end() - 1;
247
248 while (it_s < it_l)
249 {
250 if (std::abs(std::get<0>(*it_s)) < min_c || std::abs(std::get<0>(*it_l)) > max_c)
251 {
252 double w12 = std::get<0>(*it_s) + std::get<0>(*it_l);
253 if (rng() < std::get<0>(*it_l) / w12)
254 {
255 std::get<0>(*it_l) = 0.5 * w12;
256 std::get<0>(*it_s) = 0.0;
257 std::get<1>(*it_l) = 2;
258 std::get<1>(*it_s) = 0;
259 }
260 else
261 {
262 std::get<0>(*it_s) = 0.5 * w12;
263 std::get<0>(*it_l) = 0.0;
264 std::get<1>(*it_s) = 2;
265 std::get<1>(*it_l) = 0;
266 }
267 it_s++;
268 it_l--;
269 }
270 else
271 break;
272 }
273
274 int nnew = 0;
275 int nzero = 0;
276 for (auto& w : wlks)
277 {
278 buff[std::get<2>(w)] = {std::get<0>(w), std::get<1>(w)};
279 nnew += std::get<1>(w);
280 if (std::get<1>(w) > 0 && std::abs(std::get<0>(w)) < 1e-7)
281 nzero++;
282 }
283 if (nzero > 0)
284 {
285 app_error() << " Error in pair_branch: nzero>0: " << nzero << std::endl;
286 app_error() << " Found walkers with zero weight after branch.\n"
287 << " Try reducing subSteps or reducing the time step." << std::endl;
288 APP_ABORT("Error in pair_branch.");
289 }
290 if (nw != nnew)
291 APP_ABORT("Error: Problems with pair_branching.\n");
292 }
293
294 /**
295 * Implements the serial branching algorithm on the set of walkers.
296 * Serial branch involves gathering the list of weights on the root node
297 * and making the decisions locally. The new list of walker weights is then bcasted.
298 * This implementation requires contiguous walkers and fixed population walker sets.
299 */
300 template<class WalkerSet,
301 class Mat,
302 class Random,
303 typename = typename std::enable_if<(WalkerSet::contiguous_walker)>::type,
304 typename = typename std::enable_if<(WalkerSet::fixed_population)>::type>
SerialBranching(WalkerSet & wset,BRANCHING_ALGORITHM type,double min_,double max_,std::vector<int> & wlk_counts,Mat & Wexcess,Random & rng,communicator & comm)305 inline void SerialBranching(WalkerSet& wset,
306 BRANCHING_ALGORITHM type,
307 double min_,
308 double max_,
309 std::vector<int>& wlk_counts,
310 Mat& Wexcess,
311 Random& rng,
312 communicator& comm)
313 {
314 std::vector<std::pair<double, int>> buffer(wset.get_global_target_population());
315
316 // assemble list of weights
317 getGlobalListOfWalkerWeights(wset, buffer, comm);
318
319 // using global weight list, use pair branching algorithm
320 if (comm.root())
321 {
322 if (type == PAIR)
323 pair_branch(buffer, rng, max_, min_);
324 else if (type == MIN_BRANCH)
325 min_branch(buffer, rng, max_, min_);
326 else if (type == SERIAL_COMB)
327 serial_comb(buffer, rng);
328 else
329 APP_ABORT("Error: Unknown branching type in SerialBranching. \n");
330 }
331
332 // bcast walker information and calculate new walker counts locally
333 comm.broadcast_n(buffer.data(), buffer.size());
334
335 int target = wset.get_TG_target_population();
336 wlk_counts.resize(comm.size());
337 for (int i = 0, p = 0; i < comm.size(); i++)
338 {
339 int cnt = 0;
340 for (int k = 0; k < target; k++, p++)
341 cnt += buffer[p].second;
342 wlk_counts[i] = cnt;
343 }
344 if (wset.get_global_target_population() != std::accumulate(wlk_counts.begin(), wlk_counts.end(), 0))
345 {
346 app_error() << " Error: targetN != nwold: " << target << " "
347 << std::accumulate(wlk_counts.begin(), wlk_counts.end(), 0) << std::endl;
348 APP_ABORT(" Error: targetN != nwold.");
349 }
350
351 // reserve space for extra walkers
352 if (wlk_counts[comm.rank()] > target)
353 Wexcess.reextent(
354 {std::max(0, wlk_counts[comm.rank()] - target), wset.single_walker_size() + wset.single_walker_bp_size()});
355
356 // perform local branching
357 // walkers beyond target go in Wexcess
358 wset.branch(buffer.begin() + target * comm.rank(), buffer.begin() + target * (comm.rank() + 1), Wexcess);
359 }
360
361 /**
362 * Implements the distributed comb branching algorithm.
363 */
364 template<class WalkerSet,
365 class Mat,
366 class Random,
367 typename = typename std::enable_if<(WalkerSet::contiguous_walker)>::type,
368 typename = typename std::enable_if<(WalkerSet::fixed_population)>::type>
CombBranching(WalkerSet & wset,BRANCHING_ALGORITHM type,std::vector<int> & wlk_counts,Mat & Wexcess,Random & rng,communicator & comm)369 inline void CombBranching(WalkerSet& wset,
370 BRANCHING_ALGORITHM type,
371 std::vector<int>& wlk_counts,
372 Mat& Wexcess,
373 Random& rng,
374 communicator& comm)
375 {
376 APP_ABORT("Error: comb not implemented yet. \n");
377 }
378
379 } // namespace afqmc
380
381 } // namespace qmcplusplus
382
383 #endif
384