1 //////////////////////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source License.
3 // See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2021 QMCPACK developers.
6 //
7 // File developed by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
8 //
9 // File created by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
10 //////////////////////////////////////////////////////////////////////////////////////
11 
12 
13 #ifndef QMCPLUSPLUS_TWFDISPATCH_H
14 #define QMCPLUSPLUS_TWFDISPATCH_H
15 
16 #include "TrialWaveFunction.h"
17 
18 namespace qmcplusplus
19 {
20 
/** Wrappers for dispatching to either the TrialWaveFunction single-walker APIs or its mw_ (multi-walker) APIs.
 * This should only be used by QMC drivers.
 * Member function names, after the flex_ prefix, must match the corresponding mw_ APIs in TrialWaveFunction.
 */
25 class TWFdispatcher
26 {
27 public:
28   using PsiValueType = TrialWaveFunction::PsiValueType;
29   using ComputeType  = TrialWaveFunction::ComputeType;
30   using ValueType    = TrialWaveFunction::ValueType;
31   using GradType     = TrialWaveFunction::GradType;
32 
33   TWFdispatcher(bool use_batch);
34 
35   void flex_evaluateLog(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
36                         const RefVectorWithLeader<ParticleSet>& p_list) const;
37 
38   void flex_recompute(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
39                       const RefVectorWithLeader<ParticleSet>& p_list,
40                       const std::vector<bool>& recompute) const;
41 
42   void flex_calcRatio(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
43                       const RefVectorWithLeader<ParticleSet>& p_list,
44                       int iat,
45                       std::vector<PsiValueType>& ratios,
46                       ComputeType ct = ComputeType::ALL) const;
47 
48   void flex_prepareGroup(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
49                          const RefVectorWithLeader<ParticleSet>& p_list,
50                          int ig) const;
51 
52   void flex_evalGrad(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
53                      const RefVectorWithLeader<ParticleSet>& p_list,
54                      int iat,
55                      std::vector<GradType>& grad_now) const;
56 
57   void flex_calcRatioGrad(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
58                           const RefVectorWithLeader<ParticleSet>& p_list,
59                           int iat,
60                           std::vector<PsiValueType>& ratios,
61                           std::vector<GradType>& grad_new) const;
62 
63   void flex_accept_rejectMove(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
64                               const RefVectorWithLeader<ParticleSet>& p_list,
65                               int iat,
66                               const std::vector<bool>& isAccepted,
67                               bool safe_to_delay) const;
68 
69   void flex_completeUpdates(const RefVectorWithLeader<TrialWaveFunction>& wf_list) const;
70 
71   void flex_evaluateGL(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
72                        const RefVectorWithLeader<ParticleSet>& p_list,
73                        bool fromscratch) const;
74 
75   void flex_evaluateRatios(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
76                            const RefVector<const VirtualParticleSet>& vp_list,
77                            const RefVector<std::vector<ValueType>>& ratios_list,
78                            ComputeType ct) const;
79 
80 private:
81   bool use_batch_;
82 };
83 } // namespace qmcplusplus
84 
85 #endif
86