1 //////////////////////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source License.
3 // See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2021 QMCPACK developers.
6 //
7 // File developed by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
8 //
9 // File created by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
10 //////////////////////////////////////////////////////////////////////////////////////
11 
12 
13 #include "TWFdispatcher.h"
14 #include <cassert>
15 #include "TrialWaveFunction.h"
16 
17 namespace qmcplusplus
18 {
TWFdispatcher(bool use_batch)19 TWFdispatcher::TWFdispatcher(bool use_batch) : use_batch_(use_batch) {}
20 
flex_evaluateLog(const RefVectorWithLeader<TrialWaveFunction> & wf_list,const RefVectorWithLeader<ParticleSet> & p_list) const21 void TWFdispatcher::flex_evaluateLog(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
22                                      const RefVectorWithLeader<ParticleSet>& p_list) const
23 {
24   assert(wf_list.size() == p_list.size());
25   if (use_batch_)
26     TrialWaveFunction::mw_evaluateLog(wf_list, p_list);
27   else
28     for (size_t iw = 0; iw < wf_list.size(); iw++)
29       wf_list[iw].evaluateLog(p_list[iw]);
30 }
31 
flex_recompute(const RefVectorWithLeader<TrialWaveFunction> & wf_list,const RefVectorWithLeader<ParticleSet> & p_list,const std::vector<bool> & recompute) const32 void TWFdispatcher::flex_recompute(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
33                                    const RefVectorWithLeader<ParticleSet>& p_list,
34                                    const std::vector<bool>& recompute) const
35 {
36   assert(wf_list.size() == p_list.size());
37   if (use_batch_)
38     TrialWaveFunction::mw_recompute(wf_list, p_list, recompute);
39   else
40     for (size_t iw = 0; iw < wf_list.size(); iw++)
41       if (recompute[iw])
42         wf_list[iw].recompute(p_list[iw]);
43 }
44 
flex_calcRatio(const RefVectorWithLeader<TrialWaveFunction> & wf_list,const RefVectorWithLeader<ParticleSet> & p_list,int iat,std::vector<PsiValueType> & ratios,ComputeType ct) const45 void TWFdispatcher::flex_calcRatio(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
46                                    const RefVectorWithLeader<ParticleSet>& p_list,
47                                    int iat,
48                                    std::vector<PsiValueType>& ratios,
49                                    ComputeType ct) const
50 {
51   assert(wf_list.size() == p_list.size());
52   if (use_batch_)
53     TrialWaveFunction::mw_calcRatio(wf_list, p_list, iat, ratios, ct);
54   else
55   {
56     const int num_wf = wf_list.size();
57     ratios.resize(num_wf);
58     for (size_t iw = 0; iw < num_wf; iw++)
59       ratios[iw] = wf_list[iw].calcRatio(p_list[iw], iat, ct);
60   }
61 }
62 
flex_prepareGroup(const RefVectorWithLeader<TrialWaveFunction> & wf_list,const RefVectorWithLeader<ParticleSet> & p_list,int ig) const63 void TWFdispatcher::flex_prepareGroup(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
64                                       const RefVectorWithLeader<ParticleSet>& p_list,
65                                       int ig) const
66 {
67   assert(wf_list.size() == p_list.size());
68   if (use_batch_)
69     TrialWaveFunction::mw_prepareGroup(wf_list, p_list, ig);
70   else
71     for (size_t iw = 0; iw < wf_list.size(); iw++)
72       wf_list[iw].prepareGroup(p_list[iw], ig);
73 }
74 
flex_evalGrad(const RefVectorWithLeader<TrialWaveFunction> & wf_list,const RefVectorWithLeader<ParticleSet> & p_list,int iat,std::vector<GradType> & grad_now) const75 void TWFdispatcher::flex_evalGrad(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
76                                   const RefVectorWithLeader<ParticleSet>& p_list,
77                                   int iat,
78                                   std::vector<GradType>& grad_now) const
79 {
80   assert(wf_list.size() == p_list.size());
81   if (use_batch_)
82     TrialWaveFunction::mw_evalGrad(wf_list, p_list, iat, grad_now);
83   else
84   {
85     const int num_wf = wf_list.size();
86     grad_now.resize(num_wf);
87     for (size_t iw = 0; iw < num_wf; iw++)
88       grad_now[iw] = wf_list[iw].evalGrad(p_list[iw], iat);
89   }
90 }
91 
flex_calcRatioGrad(const RefVectorWithLeader<TrialWaveFunction> & wf_list,const RefVectorWithLeader<ParticleSet> & p_list,int iat,std::vector<PsiValueType> & ratios,std::vector<GradType> & grad_new) const92 void TWFdispatcher::flex_calcRatioGrad(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
93                                        const RefVectorWithLeader<ParticleSet>& p_list,
94                                        int iat,
95                                        std::vector<PsiValueType>& ratios,
96                                        std::vector<GradType>& grad_new) const
97 {
98   assert(wf_list.size() == p_list.size());
99   if (use_batch_)
100     TrialWaveFunction::mw_calcRatioGrad(wf_list, p_list, iat, ratios, grad_new);
101   else
102   {
103     const int num_wf = wf_list.size();
104     ratios.resize(num_wf);
105     grad_new.resize(num_wf);
106     for (size_t iw = 0; iw < num_wf; iw++)
107       ratios[iw] = wf_list[iw].calcRatioGrad(p_list[iw], iat, grad_new[iw]);
108   }
109 }
110 
flex_accept_rejectMove(const RefVectorWithLeader<TrialWaveFunction> & wf_list,const RefVectorWithLeader<ParticleSet> & p_list,int iat,const std::vector<bool> & isAccepted,bool safe_to_delay) const111 void TWFdispatcher::flex_accept_rejectMove(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
112                                            const RefVectorWithLeader<ParticleSet>& p_list,
113                                            int iat,
114                                            const std::vector<bool>& isAccepted,
115                                            bool safe_to_delay) const
116 {
117   assert(wf_list.size() == p_list.size());
118   if (use_batch_)
119     TrialWaveFunction::mw_accept_rejectMove(wf_list, p_list, iat, isAccepted, safe_to_delay);
120   else
121     for (size_t iw = 0; iw < wf_list.size(); iw++)
122       if (isAccepted[iw])
123         wf_list[iw].acceptMove(p_list[iw], iat, safe_to_delay);
124       else
125         wf_list[iw].rejectMove(iat);
126 }
127 
flex_completeUpdates(const RefVectorWithLeader<TrialWaveFunction> & wf_list) const128 void TWFdispatcher::flex_completeUpdates(const RefVectorWithLeader<TrialWaveFunction>& wf_list) const
129 {
130   if (use_batch_)
131     TrialWaveFunction::mw_completeUpdates(wf_list);
132   else
133     for (TrialWaveFunction& wf : wf_list)
134       wf.completeUpdates();
135 }
136 
flex_evaluateGL(const RefVectorWithLeader<TrialWaveFunction> & wf_list,const RefVectorWithLeader<ParticleSet> & p_list,bool fromscratch) const137 void TWFdispatcher::flex_evaluateGL(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
138                                     const RefVectorWithLeader<ParticleSet>& p_list,
139                                     bool fromscratch) const
140 {
141   assert(wf_list.size() == p_list.size());
142   if (use_batch_)
143     TrialWaveFunction::mw_evaluateGL(wf_list, p_list, fromscratch);
144   else
145     for (size_t iw = 0; iw < wf_list.size(); iw++)
146       wf_list[iw].evaluateGL(p_list[iw], fromscratch);
147 }
148 
flex_evaluateRatios(const RefVectorWithLeader<TrialWaveFunction> & wf_list,const RefVector<const VirtualParticleSet> & vp_list,const RefVector<std::vector<ValueType>> & ratios_list,ComputeType ct) const149 void TWFdispatcher::flex_evaluateRatios(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
150                                         const RefVector<const VirtualParticleSet>& vp_list,
151                                         const RefVector<std::vector<ValueType>>& ratios_list,
152                                         ComputeType ct) const
153 {
154   assert(wf_list.size() == vp_list.size());
155   assert(wf_list.size() == ratios_list.size());
156   if (use_batch_)
157     TrialWaveFunction::mw_evaluateRatios(wf_list, vp_list, ratios_list, ct);
158   else
159     for (size_t iw = 0; iw < wf_list.size(); iw++)
160       wf_list[iw].evaluateRatios(vp_list[iw], ratios_list[iw], ct);
161 }
162 
163 } // namespace qmcplusplus
164