1 //////////////////////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source License.
3 // See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2019 QMCPACK developers.
6 //
7 // File developed by: Peter Doak, doakpw@ornl.gov, Oak Ridge National Laboratory
8 //
9 // File created by: Peter Doak, doakpw@ornl.gov, Oak Ridge National Laboratory
10 //////////////////////////////////////////////////////////////////////////////////////
11 
12 #include "catch.hpp"
13 
14 #include "Configuration.h"
15 #include "Message/Communicate.h"
16 #include "QMCDrivers/Crowd.h"
17 #include "type_traits/template_types.hpp"
18 #include "Estimators/EstimatorManagerNew.h"
19 #include "QMCWaveFunctions/tests/MinimalWaveFunctionPool.h"
20 #include "Particle/tests/MinimalParticlePool.h"
21 #include "QMCHamiltonians/tests/MinimalHamiltonianPool.h"
22 
23 #include "QMCDrivers/tests/SetupPools.h"
24 
25 namespace qmcplusplus
26 {
27 namespace testing
28 {
29 class CrowdWithWalkers
30 {
31 public:
32   using MCPWalker = Walker<QMCTraits, PtclOnLatticeTraits>;
33 
34   EstimatorManagerNew em;
35   UPtr<Crowd> crowd_ptr;
get_crowd()36   Crowd& get_crowd() { return *crowd_ptr; }
37   UPtrVector<MCPWalker> walkers;
38   UPtrVector<ParticleSet> psets;
39   UPtrVector<TrialWaveFunction> twfs;
40   UPtrVector<QMCHamiltonian> hams;
41   std::vector<TinyVector<double, 3>> tpos;
42   DriverWalkerResourceCollection driverwalker_resource_collection_;
43   const MultiWalkerDispatchers dispatchers_;
44 
45 public:
CrowdWithWalkers(SetupPools & pools)46   CrowdWithWalkers(SetupPools& pools) : em(pools.comm), dispatchers_(true)
47   {
48     crowd_ptr    = std::make_unique<Crowd>(em, driverwalker_resource_collection_, dispatchers_);
49     Crowd& crowd = *crowd_ptr;
50     // To match the minimal particle set
51     int num_particles = 2;
52     // for testing we update the first position in the walker
53     auto makePointWalker = [this, &pools, &crowd, num_particles](TinyVector<double, 3> pos) {
54       walkers.emplace_back(std::make_unique<MCPWalker>(num_particles));
55       walkers.back()->R[0] = pos;
56       psets.emplace_back(std::make_unique<ParticleSet>(*(pools.particle_pool->getParticleSet("e"))));
57       twfs.emplace_back(pools.wavefunction_pool->getPrimary()->makeClone(*psets.back()));
58       hams.emplace_back(pools.hamiltonian_pool->getPrimary()->makeClone(*psets.back(), *twfs.back()));
59       crowd.addWalker(*walkers.back(), *psets.back(), *twfs.back(), *hams.back());
60     };
61 
62     tpos.push_back(TinyVector<double, 3>(1.0, 0.0, 0.0));
63     makePointWalker(tpos.back());
64     tpos.push_back(TinyVector<double, 3>(1.0, 2.0, 0.0));
65     makePointWalker(tpos.back());
66   }
67 
makeAnotherPointWalker()68   void makeAnotherPointWalker()
69   {
70     walkers.emplace_back(std::make_unique<MCPWalker>(*walkers.back()));
71     psets.emplace_back(std::make_unique<ParticleSet>(*psets.back()));
72     twfs.emplace_back(twfs.back()->makeClone(*psets.back()));
73     hams.emplace_back(hams.back()->makeClone(*psets.back(), *twfs.back()));
74   }
75 };
76 } // namespace testing
77 
78 TEST_CASE("Crowd integration", "[drivers]")
79 {
80   Communicate* comm = OHMMS::Controller;
81 
82   EstimatorManagerNew em(comm);
83 
84   const MultiWalkerDispatchers dispatchers(true);
85   DriverWalkerResourceCollection driverwalker_resource_collection_;
86   Crowd crowd(em, driverwalker_resource_collection_, dispatchers);
87 }
88 
89 TEST_CASE("Crowd redistribute walkers")
90 {
91   using namespace testing;
92   SetupPools pools;
93 
94   CrowdWithWalkers crowd_with_walkers(pools);
95   Crowd& crowd = crowd_with_walkers.get_crowd();
96 
97   crowd_with_walkers.makeAnotherPointWalker();
98   crowd.clearWalkers();
99   for (int iw = 0; iw < crowd_with_walkers.walkers.size(); ++iw)
100     crowd.addWalker(*crowd_with_walkers.walkers[iw], *crowd_with_walkers.psets[iw], *crowd_with_walkers.twfs[iw],
101                     *crowd_with_walkers.hams[iw]);
102   REQUIRE(crowd.size() == 3);
103 }
104 
105 } // namespace qmcplusplus
106