1 /* -*- c++ -*- ---------------------------------------------------------- 2 LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator 3 https://www.lammps.org/, Sandia National Laboratories 4 Steve Plimpton, sjplimp@sandia.gov 5 6 Copyright (2003) Sandia Corporation. Under the terms of Contract 7 DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains 8 certain rights in this software. This software is distributed under 9 the GNU General Public License. 10 11 See the README file in the top-level LAMMPS directory. 12 ------------------------------------------------------------------------- */ 13 14 #ifdef FIX_CLASS 15 // clang-format off 16 FixStyle(rx/kk,FixRxKokkos<LMPDeviceType>); 17 FixStyle(rx/kk/device,FixRxKokkos<LMPDeviceType>); 18 FixStyle(rx/kk/host,FixRxKokkos<LMPHostType>); 19 // clang-format on 20 #else 21 22 // clang-format off 23 #ifndef LMP_FIX_RX_KOKKOS_H 24 #define LMP_FIX_RX_KOKKOS_H 25 26 #include "fix_rx.h" 27 #include "pair_dpd_fdt_energy_kokkos.h" 28 #include "kokkos_type.h" 29 #include "neigh_list.h" 30 #include "neigh_list_kokkos.h" 31 32 namespace LAMMPS_NS { 33 34 struct Tag_FixRxKokkos_zeroTemperatureViews {}; 35 struct Tag_FixRxKokkos_zeroCounterViews {}; 36 37 template <int WT_FLAG, bool NEWTON_PAIR, int NEIGHFLAG> 38 struct Tag_FixRxKokkos_firstPairOperator {}; 39 40 template <int WT_FLAG, int LOCAL_TEMP_FLAG> 41 struct Tag_FixRxKokkos_2ndPairOperator {}; 42 43 template <bool ZERO_RATES> 44 struct Tag_FixRxKokkos_solveSystems {}; 45 46 struct s_CounterType 47 { 48 int nSteps, nIters, nFuncs, nFails; 49 50 KOKKOS_INLINE_FUNCTION s_CounterTypes_CounterType51 s_CounterType() : nSteps(0), nIters(0), nFuncs(0), nFails(0) {}; 52 53 KOKKOS_INLINE_FUNCTION 54 s_CounterType& operator+=(const s_CounterType &rhs) 55 { 56 nSteps += rhs.nSteps; 57 nIters += rhs.nIters; 58 nFuncs += rhs.nFuncs; 59 nFails += rhs.nFails; 60 return *this; 61 } 62 63 KOKKOS_INLINE_FUNCTION 64 volatile s_CounterType& operator+=(const volatile s_CounterType &rhs) volatile 65 { 66 nSteps += rhs.nSteps; 67 nIters += rhs.nIters; 68 nFuncs += rhs.nFuncs; 69 nFails += rhs.nFails; 70 return *this; 71 } 72 }; 73 typedef struct s_CounterType CounterType; 74 75 template <class DeviceType> 76 class FixRxKokkos : public FixRX { 77 public: 78 typedef ArrayTypes<DeviceType> AT; 79 80 FixRxKokkos(class LAMMPS *, int, char **); 81 virtual ~FixRxKokkos(); 82 virtual void init(); 83 void init_list(int, class NeighList *); 84 void post_constructor(); 85 virtual void setup_pre_force(int); 86 virtual void pre_force(int); 87 88 // Define a value_type here for the reduction operator on CounterType. 89 typedef CounterType value_type; 90 91 KOKKOS_INLINE_FUNCTION 92 void operator()(Tag_FixRxKokkos_zeroCounterViews, const int&) const; 93 94 KOKKOS_INLINE_FUNCTION 95 void operator()(Tag_FixRxKokkos_zeroTemperatureViews, const int&) const; 96 97 template <int WT_FLAG, bool NEWTON_PAIR, int NEIGHFLAG> 98 KOKKOS_INLINE_FUNCTION 99 void operator()(Tag_FixRxKokkos_firstPairOperator<WT_FLAG,NEWTON_PAIR,NEIGHFLAG>, const int&) const; 100 101 template <int WT_FLAG, int LOCAL_TEMP_FLAG> 102 KOKKOS_INLINE_FUNCTION 103 void operator()(Tag_FixRxKokkos_2ndPairOperator<WT_FLAG,LOCAL_TEMP_FLAG>, const int&) const; 104 105 template <bool ZERO_RATES> 106 KOKKOS_INLINE_FUNCTION 107 void operator()(Tag_FixRxKokkos_solveSystems<ZERO_RATES>, const int&, CounterType&) const; 108 109 //protected: 110 PairDPDfdtEnergyKokkos<DeviceType>* pairDPDEKK; 111 double VDPD; 112 113 double boltz; 114 double t_stop; 115 116 template <typename T, int stride = 1> 117 struct StridedArrayType 118 { 119 typedef T value_type; 120 enum { Stride = stride }; 121 122 value_type *m_data; 123 124 KOKKOS_INLINE_FUNCTION StridedArrayTypeStridedArrayType125 StridedArrayType() : m_data(nullptr) {} 126 KOKKOS_INLINE_FUNCTION StridedArrayTypeStridedArrayType127 StridedArrayType(value_type *ptr) : m_data(ptr) {} 128 operatorStridedArrayType129 KOKKOS_INLINE_FUNCTION value_type& operator()(const int idx) { return m_data[Stride*idx]; } operatorStridedArrayType130 KOKKOS_INLINE_FUNCTION const value_type& operator()(const int idx) const { return m_data[Stride*idx]; } 131 KOKKOS_INLINE_FUNCTION value_type& operator[](const int idx) { return m_data[Stride*idx]; } 132 KOKKOS_INLINE_FUNCTION const value_type& operator[](const int idx) const { return m_data[Stride*idx]; } 133 }; 134 135 template <int stride = 1> 136 struct UserRHSDataKokkos 137 { 138 StridedArrayType<double,1> kFor; 139 StridedArrayType<double,1> rxnRateLaw; 140 }; 141 142 void solve_reactions(const int vflag, const bool isPreForce); 143 144 int rhs (double, const double *, double *, void *) const; 145 int rhs_dense (double, const double *, double *, void *) const; 146 int rhs_sparse(double, const double *, double *, void *) const; 147 148 template <typename VectorType, typename UserDataType> 149 KOKKOS_INLINE_FUNCTION 150 int k_rhs (double, const VectorType&, VectorType&, UserDataType& ) const; 151 152 template <typename VectorType, typename UserDataType> 153 KOKKOS_INLINE_FUNCTION 154 int k_rhs_dense (double, const VectorType&, VectorType&, UserDataType& ) const; 155 156 template <typename VectorType, typename UserDataType> 157 KOKKOS_INLINE_FUNCTION 158 int k_rhs_sparse(double, const VectorType&, VectorType&, UserDataType& ) const; 159 160 //!< Classic Runge-Kutta 4th-order stepper. 161 void rk4(const double t_stop, double *y, double *rwork, void *v_params) const; 162 163 //!< Runge-Kutta-Fehlberg ODE Solver. 164 void rkf45(const int neq, const double t_stop, double *y, double *rwork, void *v_params, CounterType& counter) const; 165 166 //!< Runge-Kutta-Fehlberg ODE stepper function. 167 void rkf45_step (const int neq, const double h, double y[], double y_out[], 168 double rwk[], void *) const; 169 170 //!< Initial step size estimation for the Runge-Kutta-Fehlberg ODE solver. 171 int rkf45_h0 (const int neq, const double t, const double t_stop, 172 const double hmin, const double hmax, 173 double& h0, double y[], double rwk[], void *v_params) const; 174 175 //!< Classic Runge-Kutta 4th-order stepper. 176 template <typename VectorType, typename UserDataType> 177 KOKKOS_INLINE_FUNCTION 178 void k_rk4(const double t_stop, VectorType& y, VectorType& rwork, UserDataType& userData) const; 179 180 //!< Runge-Kutta-Fehlberg ODE Solver. 181 template <typename VectorType, typename UserDataType> 182 KOKKOS_INLINE_FUNCTION 183 void k_rkf45(const int neq, const double t_stop, VectorType& y, VectorType& rwork, UserDataType& userData, CounterType& counter) const; 184 185 //!< Runge-Kutta-Fehlberg ODE stepper function. 186 template <typename VectorType, typename UserDataType> 187 KOKKOS_INLINE_FUNCTION 188 void k_rkf45_step (const int neq, const double h, VectorType& y, VectorType& y_out, 189 VectorType& rwk, UserDataType& userData) const; 190 191 //!< Initial step size estimation for the Runge-Kutta-Fehlberg ODE solver. 192 template <typename VectorType, typename UserDataType> 193 KOKKOS_INLINE_FUNCTION 194 int k_rkf45_h0 (const int neq, const double t, const double t_stop, 195 const double hmin, const double hmax, 196 double& h0, VectorType& y, VectorType& rwk, UserDataType& userData) const; 197 198 //!< ODE Solver diagnostics. 199 void odeDiagnostics(); 200 201 //!< Special counters per-ode. 202 int *diagnosticCounterPerODEnSteps; 203 int *diagnosticCounterPerODEnFuncs; 204 DAT::tdual_int_1d k_diagnosticCounterPerODEnSteps; 205 DAT::tdual_int_1d k_diagnosticCounterPerODEnFuncs; 206 //typename ArrayTypes<DeviceType>::t_int_1d d_diagnosticCounterPerODEnSteps; 207 //typename ArrayTypes<DeviceType>::t_int_1d d_diagnosticCounterPerODEnFuncs; 208 typename AT::t_int_1d d_diagnosticCounterPerODEnSteps; 209 typename AT::t_int_1d d_diagnosticCounterPerODEnFuncs; 210 HAT::t_int_1d h_diagnosticCounterPerODEnSteps; 211 HAT::t_int_1d h_diagnosticCounterPerODEnFuncs; 212 213 template <typename KokkosDeviceType> 214 struct KineticsType 215 { 216 // Arrhenius rate coefficients. 217 typename ArrayTypes<KokkosDeviceType>::t_float_1d Arr, nArr, Ea; 218 219 // Dense versions. 220 typename ArrayTypes<KokkosDeviceType>::t_float_2d stoich, stoichReactants, stoichProducts; 221 222 // Sparse versions. 223 typename ArrayTypes<KokkosDeviceType>::t_int_2d nuk, inu; 224 typename ArrayTypes<KokkosDeviceType>::t_float_2d nu; 225 typename ArrayTypes<KokkosDeviceType>::t_int_1d isIntegral; 226 }; 227 228 //!< Kokkos versions of the kinetics data. 229 KineticsType<LMPHostType> h_kineticsData; 230 KineticsType<DeviceType> d_kineticsData; 231 232 bool update_kinetics_data; 233 234 void create_kinetics_data(); 235 236 // Need a dual-view and device-view for dpdThetaLocal and sumWeights since they're used in several callbacks. 237 DAT::tdual_efloat_1d k_dpdThetaLocal, k_sumWeights; 238 //typename ArrayTypes<DeviceType>::t_efloat_1d d_dpdThetaLocal, d_sumWeights; 239 typename AT::t_efloat_1d d_dpdThetaLocal, d_sumWeights; 240 HAT::t_efloat_1d h_dpdThetaLocal, h_sumWeights; 241 242 typename ArrayTypes<DeviceType>::t_x_array_randomread d_x ; 243 typename ArrayTypes<DeviceType>::t_int_1d_randomread d_type ; 244 typename ArrayTypes<DeviceType>::t_efloat_1d d_dpdTheta; 245 246 typename ArrayTypes<DeviceType>::tdual_ffloat_2d k_cutsq; 247 typename ArrayTypes<DeviceType>::t_ffloat_2d d_cutsq; 248 //double **h_cutsq; 249 250 typename ArrayTypes<DeviceType>::t_neighbors_2d d_neighbors; 251 typename ArrayTypes<DeviceType>::t_int_1d d_ilist ; 252 typename ArrayTypes<DeviceType>::t_int_1d d_numneigh ; 253 254 typename ArrayTypes<DeviceType>::t_float_2d d_dvector; 255 typename ArrayTypes<DeviceType>::t_int_1d d_mask ; 256 257 typename ArrayTypes<DeviceType>::t_double_1d d_scratchSpace; 258 size_t scratchSpaceSize; 259 260 // Error flag for any failures. 261 DAT::tdual_int_scalar k_error_flag; 262 263 template <int WT_FLAG, int LOCAL_TEMP_FLAG, bool NEWTON_PAIR, int NEIGHFLAG> 264 void computeLocalTemperature(); 265 266 int pack_reverse_comm(int, int, double *); 267 void unpack_reverse_comm(int, int *, double *); 268 int pack_forward_comm(int , int *, double *, int, int *); 269 void unpack_forward_comm(int , int , double *); 270 271 //private: // replicate a few from FixRX 272 int my_restartFlag; 273 int nlocal; 274 }; 275 276 } 277 278 #endif 279 #endif 280 281 /* ERROR/WARNING messages: 282 283 */ 284