1 ////////////////////////////////////////////////////////////////////////////////////// 2 // This file is distributed under the University of Illinois/NCSA Open Source License. 3 // See LICENSE file in top directory for details. 4 // 5 // Copyright (c) 2021 QMCPACK developers. 6 // 7 // File developed by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory 8 // 9 // File created by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory 10 ////////////////////////////////////////////////////////////////////////////////////// 11 12 13 #ifndef QMCPLUSPLUS_TWFDISPATCH_H 14 #define QMCPLUSPLUS_TWFDISPATCH_H 15 16 #include "TrialWaveFunction.h" 17 18 namespace qmcplusplus 19 { 20 21 /** Wrappers for dispatching to TrialWaveFunction single walker APIs or mw_ APIs. 22 * This should be only used by QMC drivers. 23 * member function names must match mw_ APIs in TrialWaveFunction 24 */ 25 class TWFdispatcher 26 { 27 public: 28 using PsiValueType = TrialWaveFunction::PsiValueType; 29 using ComputeType = TrialWaveFunction::ComputeType; 30 using ValueType = TrialWaveFunction::ValueType; 31 using GradType = TrialWaveFunction::GradType; 32 33 TWFdispatcher(bool use_batch); 34 35 void flex_evaluateLog(const RefVectorWithLeader<TrialWaveFunction>& wf_list, 36 const RefVectorWithLeader<ParticleSet>& p_list) const; 37 38 void flex_recompute(const RefVectorWithLeader<TrialWaveFunction>& wf_list, 39 const RefVectorWithLeader<ParticleSet>& p_list, 40 const std::vector<bool>& recompute) const; 41 42 void flex_calcRatio(const RefVectorWithLeader<TrialWaveFunction>& wf_list, 43 const RefVectorWithLeader<ParticleSet>& p_list, 44 int iat, 45 std::vector<PsiValueType>& ratios, 46 ComputeType ct = ComputeType::ALL) const; 47 48 void flex_prepareGroup(const RefVectorWithLeader<TrialWaveFunction>& wf_list, 49 const RefVectorWithLeader<ParticleSet>& p_list, 50 int ig) const; 51 52 void flex_evalGrad(const RefVectorWithLeader<TrialWaveFunction>& wf_list, 53 const RefVectorWithLeader<ParticleSet>& p_list, 54 int iat, 55 std::vector<GradType>& grad_now) const; 56 57 void flex_calcRatioGrad(const RefVectorWithLeader<TrialWaveFunction>& wf_list, 58 const RefVectorWithLeader<ParticleSet>& p_list, 59 int iat, 60 std::vector<PsiValueType>& ratios, 61 std::vector<GradType>& grad_new) const; 62 63 void flex_accept_rejectMove(const RefVectorWithLeader<TrialWaveFunction>& wf_list, 64 const RefVectorWithLeader<ParticleSet>& p_list, 65 int iat, 66 const std::vector<bool>& isAccepted, 67 bool safe_to_delay) const; 68 69 void flex_completeUpdates(const RefVectorWithLeader<TrialWaveFunction>& wf_list) const; 70 71 void flex_evaluateGL(const RefVectorWithLeader<TrialWaveFunction>& wf_list, 72 const RefVectorWithLeader<ParticleSet>& p_list, 73 bool fromscratch) const; 74 75 void flex_evaluateRatios(const RefVectorWithLeader<TrialWaveFunction>& wf_list, 76 const RefVector<const VirtualParticleSet>& vp_list, 77 const RefVector<std::vector<ValueType>>& ratios_list, 78 ComputeType ct) const; 79 80 private: 81 bool use_batch_; 82 }; 83 } // namespace qmcplusplus 84 85 #endif 86