1 /* -*- c++ -*- ----------------------------------------------------------
2    LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
3    https://www.lammps.org/, Sandia National Laboratories
4    Steve Plimpton, sjplimp@sandia.gov
5 
6    Copyright (2003) Sandia Corporation.  Under the terms of Contract
7    DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains
8    certain rights in this software.  This software is distributed under
9    the GNU General Public License.
10 
11    See the README file in the top-level LAMMPS directory.
12 ------------------------------------------------------------------------- */
13 
14 #ifdef FIX_CLASS
15 // clang-format off
16 FixStyle(rx/kk,FixRxKokkos<LMPDeviceType>);
17 FixStyle(rx/kk/device,FixRxKokkos<LMPDeviceType>);
18 FixStyle(rx/kk/host,FixRxKokkos<LMPHostType>);
19 // clang-format on
20 #else
21 
22 // clang-format off
23 #ifndef LMP_FIX_RX_KOKKOS_H
24 #define LMP_FIX_RX_KOKKOS_H
25 
26 #include "fix_rx.h"
27 #include "pair_dpd_fdt_energy_kokkos.h"
28 #include "kokkos_type.h"
29 #include "neigh_list.h"
30 #include "neigh_list_kokkos.h"
31 
32 namespace LAMMPS_NS {
33 
// ---------------------------------------------------------------------------
// Kokkos dispatch tags for FixRxKokkos.
//
// Each tag is an empty type whose only job is to pick one of the tagged
// FixRxKokkos::operator() overloads when the functor is launched through a
// Kokkos policy.  The templated tags turn every combination of compile-time
// flags into a distinct type, so each combination selects its own kernel
// instantiation.  None of them holds any data.
// ---------------------------------------------------------------------------

// Selects operator()(Tag_FixRxKokkos_zeroCounterViews, const int&).
struct Tag_FixRxKokkos_zeroCounterViews {};

// Selects operator()(Tag_FixRxKokkos_zeroTemperatureViews, const int&).
struct Tag_FixRxKokkos_zeroTemperatureViews {};

// Selects the templated solveSystems operator(), which also receives a
// CounterType& reduction argument; ZeroRates is forwarded as a template flag.
template <bool ZeroRates>
struct Tag_FixRxKokkos_solveSystems {};

// Selects the first pair-loop operator(); all three flags are compile-time.
template <int WeightFlag, bool NewtonPair, int NeighFlag>
struct Tag_FixRxKokkos_firstPairOperator {};

// Selects the second pair-loop operator(); both flags are compile-time.
template <int WeightFlag, int LocalTempFlag>
struct Tag_FixRxKokkos_2ndPairOperator {};
45 
46 struct s_CounterType
47 {
48   int nSteps, nIters, nFuncs, nFails;
49 
50   KOKKOS_INLINE_FUNCTION
s_CounterTypes_CounterType51   s_CounterType() : nSteps(0), nIters(0), nFuncs(0), nFails(0) {};
52 
53   KOKKOS_INLINE_FUNCTION
54   s_CounterType& operator+=(const s_CounterType &rhs)
55   {
56     nSteps += rhs.nSteps;
57     nIters += rhs.nIters;
58     nFuncs += rhs.nFuncs;
59     nFails += rhs.nFails;
60     return *this;
61   }
62 
63   KOKKOS_INLINE_FUNCTION
64   volatile s_CounterType& operator+=(const volatile s_CounterType &rhs) volatile
65   {
66     nSteps += rhs.nSteps;
67     nIters += rhs.nIters;
68     nFuncs += rhs.nFuncs;
69     nFails += rhs.nFails;
70     return *this;
71   }
72 };
73 typedef struct s_CounterType CounterType;
74 
75 template <class DeviceType>
76 class FixRxKokkos : public FixRX {
77  public:
78   typedef ArrayTypes<DeviceType> AT;
79 
80   FixRxKokkos(class LAMMPS *, int, char **);
81   virtual ~FixRxKokkos();
82   virtual void init();
83   void init_list(int, class NeighList *);
84   void post_constructor();
85   virtual void setup_pre_force(int);
86   virtual void pre_force(int);
87 
88   // Define a value_type here for the reduction operator on CounterType.
89   typedef CounterType value_type;
90 
91   KOKKOS_INLINE_FUNCTION
92   void operator()(Tag_FixRxKokkos_zeroCounterViews, const int&) const;
93 
94   KOKKOS_INLINE_FUNCTION
95   void operator()(Tag_FixRxKokkos_zeroTemperatureViews, const int&) const;
96 
97   template <int WT_FLAG, bool NEWTON_PAIR, int NEIGHFLAG>
98   KOKKOS_INLINE_FUNCTION
99   void operator()(Tag_FixRxKokkos_firstPairOperator<WT_FLAG,NEWTON_PAIR,NEIGHFLAG>, const int&) const;
100 
101   template <int WT_FLAG, int LOCAL_TEMP_FLAG>
102   KOKKOS_INLINE_FUNCTION
103   void operator()(Tag_FixRxKokkos_2ndPairOperator<WT_FLAG,LOCAL_TEMP_FLAG>, const int&) const;
104 
105   template <bool ZERO_RATES>
106   KOKKOS_INLINE_FUNCTION
107   void operator()(Tag_FixRxKokkos_solveSystems<ZERO_RATES>, const int&, CounterType&) const;
108 
109  //protected:
110   PairDPDfdtEnergyKokkos<DeviceType>* pairDPDEKK;
111   double VDPD;
112 
113   double boltz;
114   double t_stop;
115 
116   template <typename T, int stride = 1>
117   struct StridedArrayType
118   {
119     typedef T value_type;
120     enum { Stride = stride };
121 
122     value_type *m_data;
123 
124     KOKKOS_INLINE_FUNCTION
StridedArrayTypeStridedArrayType125     StridedArrayType() : m_data(nullptr) {}
126     KOKKOS_INLINE_FUNCTION
StridedArrayTypeStridedArrayType127     StridedArrayType(value_type *ptr) : m_data(ptr) {}
128 
operatorStridedArrayType129     KOKKOS_INLINE_FUNCTION       value_type& operator()(const int idx)       { return m_data[Stride*idx]; }
operatorStridedArrayType130     KOKKOS_INLINE_FUNCTION const value_type& operator()(const int idx) const { return m_data[Stride*idx]; }
131     KOKKOS_INLINE_FUNCTION       value_type& operator[](const int idx)       { return m_data[Stride*idx]; }
132     KOKKOS_INLINE_FUNCTION const value_type& operator[](const int idx) const { return m_data[Stride*idx]; }
133   };
134 
135   template <int stride = 1>
136   struct UserRHSDataKokkos
137   {
138     StridedArrayType<double,1> kFor;
139     StridedArrayType<double,1> rxnRateLaw;
140   };
141 
142   void solve_reactions(const int vflag, const bool isPreForce);
143 
144   int rhs       (double, const double *, double *, void *) const;
145   int rhs_dense (double, const double *, double *, void *) const;
146   int rhs_sparse(double, const double *, double *, void *) const;
147 
148   template <typename VectorType, typename UserDataType>
149     KOKKOS_INLINE_FUNCTION
150   int k_rhs       (double, const VectorType&, VectorType&, UserDataType& ) const;
151 
152   template <typename VectorType, typename UserDataType>
153     KOKKOS_INLINE_FUNCTION
154   int k_rhs_dense (double, const VectorType&, VectorType&, UserDataType& ) const;
155 
156   template <typename VectorType, typename UserDataType>
157     KOKKOS_INLINE_FUNCTION
158   int k_rhs_sparse(double, const VectorType&, VectorType&, UserDataType& ) const;
159 
160   //!< Classic Runge-Kutta 4th-order stepper.
161   void rk4(const double t_stop, double *y, double *rwork, void *v_params) const;
162 
163   //!< Runge-Kutta-Fehlberg ODE Solver.
164   void rkf45(const int neq, const double t_stop, double *y, double *rwork, void *v_params, CounterType& counter) const;
165 
166   //!< Runge-Kutta-Fehlberg ODE stepper function.
167   void rkf45_step (const int neq, const double h, double y[], double y_out[],
168                    double rwk[], void *) const;
169 
170   //!< Initial step size estimation for the Runge-Kutta-Fehlberg ODE solver.
171   int rkf45_h0 (const int neq, const double t, const double t_stop,
172                      const double hmin, const double hmax,
173                      double& h0, double y[], double rwk[], void *v_params) const;
174 
175   //!< Classic Runge-Kutta 4th-order stepper.
176   template <typename VectorType, typename UserDataType>
177     KOKKOS_INLINE_FUNCTION
178   void k_rk4(const double t_stop, VectorType& y, VectorType& rwork, UserDataType& userData) const;
179 
180   //!< Runge-Kutta-Fehlberg ODE Solver.
181   template <typename VectorType, typename UserDataType>
182     KOKKOS_INLINE_FUNCTION
183   void k_rkf45(const int neq, const double t_stop, VectorType& y, VectorType& rwork, UserDataType& userData, CounterType& counter) const;
184 
185   //!< Runge-Kutta-Fehlberg ODE stepper function.
186   template <typename VectorType, typename UserDataType>
187     KOKKOS_INLINE_FUNCTION
188   void k_rkf45_step (const int neq, const double h, VectorType& y, VectorType& y_out,
189                      VectorType& rwk, UserDataType& userData) const;
190 
191   //!< Initial step size estimation for the Runge-Kutta-Fehlberg ODE solver.
192   template <typename VectorType, typename UserDataType>
193     KOKKOS_INLINE_FUNCTION
194   int k_rkf45_h0 (const int neq, const double t, const double t_stop,
195                   const double hmin, const double hmax,
196                   double& h0, VectorType& y, VectorType& rwk, UserDataType& userData) const;
197 
198   //!< ODE Solver diagnostics.
199   void odeDiagnostics();
200 
201   //!< Special counters per-ode.
202   int *diagnosticCounterPerODEnSteps;
203   int *diagnosticCounterPerODEnFuncs;
204   DAT::tdual_int_1d k_diagnosticCounterPerODEnSteps;
205   DAT::tdual_int_1d k_diagnosticCounterPerODEnFuncs;
206   //typename ArrayTypes<DeviceType>::t_int_1d d_diagnosticCounterPerODEnSteps;
207   //typename ArrayTypes<DeviceType>::t_int_1d d_diagnosticCounterPerODEnFuncs;
208   typename AT::t_int_1d d_diagnosticCounterPerODEnSteps;
209   typename AT::t_int_1d d_diagnosticCounterPerODEnFuncs;
210   HAT::t_int_1d h_diagnosticCounterPerODEnSteps;
211   HAT::t_int_1d h_diagnosticCounterPerODEnFuncs;
212 
213   template <typename KokkosDeviceType>
214   struct KineticsType
215   {
216     // Arrhenius rate coefficients.
217     typename ArrayTypes<KokkosDeviceType>::t_float_1d Arr, nArr, Ea;
218 
219     // Dense versions.
220     typename ArrayTypes<KokkosDeviceType>::t_float_2d stoich, stoichReactants, stoichProducts;
221 
222     // Sparse versions.
223     typename ArrayTypes<KokkosDeviceType>::t_int_2d   nuk, inu;
224     typename ArrayTypes<KokkosDeviceType>::t_float_2d nu;
225     typename ArrayTypes<KokkosDeviceType>::t_int_1d   isIntegral;
226   };
227 
228   //!< Kokkos versions of the kinetics data.
229   KineticsType<LMPHostType> h_kineticsData;
230   KineticsType<DeviceType>  d_kineticsData;
231 
232   bool update_kinetics_data;
233 
234   void create_kinetics_data();
235 
236   // Need a dual-view and device-view for dpdThetaLocal and sumWeights since they're used in several callbacks.
237   DAT::tdual_efloat_1d k_dpdThetaLocal, k_sumWeights;
238   //typename ArrayTypes<DeviceType>::t_efloat_1d d_dpdThetaLocal, d_sumWeights;
239   typename AT::t_efloat_1d d_dpdThetaLocal, d_sumWeights;
240   HAT::t_efloat_1d h_dpdThetaLocal, h_sumWeights;
241 
242   typename ArrayTypes<DeviceType>::t_x_array_randomread d_x       ;
243   typename ArrayTypes<DeviceType>::t_int_1d_randomread  d_type    ;
244   typename ArrayTypes<DeviceType>::t_efloat_1d          d_dpdTheta;
245 
246   typename ArrayTypes<DeviceType>::tdual_ffloat_2d k_cutsq;
247   typename ArrayTypes<DeviceType>::t_ffloat_2d     d_cutsq;
248   //double **h_cutsq;
249 
250   typename ArrayTypes<DeviceType>::t_neighbors_2d d_neighbors;
251   typename ArrayTypes<DeviceType>::t_int_1d       d_ilist    ;
252   typename ArrayTypes<DeviceType>::t_int_1d       d_numneigh ;
253 
254   typename ArrayTypes<DeviceType>::t_float_2d  d_dvector;
255   typename ArrayTypes<DeviceType>::t_int_1d    d_mask   ;
256 
257   typename ArrayTypes<DeviceType>::t_double_1d d_scratchSpace;
258   size_t scratchSpaceSize;
259 
260   // Error flag for any failures.
261   DAT::tdual_int_scalar k_error_flag;
262 
263   template <int WT_FLAG, int LOCAL_TEMP_FLAG, bool NEWTON_PAIR, int NEIGHFLAG>
264   void computeLocalTemperature();
265 
266   int pack_reverse_comm(int, int, double *);
267   void unpack_reverse_comm(int, int *, double *);
268   int pack_forward_comm(int , int *, double *, int, int *);
269   void unpack_forward_comm(int , int , double *);
270 
271  //private: // replicate a few from FixRX
272   int my_restartFlag;
273   int nlocal;
274 };
275 
276 }
277 
278 #endif
279 #endif
280 
281 /* ERROR/WARNING messages:
282 
283 */
284