1 /* -*- c++ -*- ----------------------------------------------------------
2    LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
3    https://www.lammps.org/, Sandia National Laboratories
4    Steve Plimpton, sjplimp@sandia.gov
5 
6    Copyright (2003) Sandia Corporation.  Under the terms of Contract
7    DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains
8    certain rights in this software.  This software is distributed under
9    the GNU General Public License.
10 
11    See the README file in the top-level LAMMPS directory.
12 ------------------------------------------------------------------------- */
13 
#ifdef NPAIR_CLASS
// clang-format off
// Style registration for the SSA (Shardlow splitting) half/bin/newton
// neighbor-pair builder. It is registered twice:
//   - instantiated on LMPHostType   -> style name ".../kk/host"
//   - instantiated on LMPDeviceType -> style name ".../kk/device"
// Both share the same NP_* feature mask and differ only in the KOKKOS
// host/device flag. Presumably the Neighbor class matches this mask against
// neighbor-list requests to pick a builder; confirm in neighbor.cpp.
typedef NPairSSAKokkos<LMPHostType> NPairSSAKokkosHost;
NPairStyle(half/bin/newton/ssa/kk/host,
           NPairSSAKokkosHost,
           NP_HALF | NP_BIN | NP_NEWTON | NP_ORTHO | NP_SSA | NP_GHOST | NP_KOKKOS_HOST);

typedef NPairSSAKokkos<LMPDeviceType> NPairSSAKokkosDevice;
NPairStyle(half/bin/newton/ssa/kk/device,
           NPairSSAKokkosDevice,
           NP_HALF | NP_BIN | NP_NEWTON | NP_ORTHO | NP_SSA | NP_GHOST | NP_KOKKOS_DEVICE);
// clang-format on
26 #else
27 
28 // clang-format off
29 #ifndef LMP_NPAIR_SSA_KOKKOS_H
30 #define LMP_NPAIR_SSA_KOKKOS_H
31 
32 #include "npair.h"
33 #include "neigh_list_kokkos.h"
34 
35 namespace LAMMPS_NS {
36 
37 template<class DeviceType>
38 class NPairSSAKokkos : public NPair {
39  public:
40   typedef ArrayTypes<DeviceType> AT;
41 
42   // SSA Work plan data structures
43   int ssa_phaseCt;
44   DAT::tdual_int_1d k_ssa_phaseLen;
45   DAT::tdual_int_1d_3 k_ssa_phaseOff;
46   DAT::tdual_int_2d k_ssa_itemLoc;
47   DAT::tdual_int_2d k_ssa_itemLen;
48   typename AT::t_int_1d ssa_phaseLen;
49   typename AT::t_int_1d_3 ssa_phaseOff;
50   typename AT::t_int_2d ssa_itemLoc;
51   typename AT::t_int_2d ssa_itemLen;
52 
53   const int ssa_gphaseCt;
54   DAT::tdual_int_1d k_ssa_gphaseLen;
55   DAT::tdual_int_2d k_ssa_gitemLoc;
56   DAT::tdual_int_2d k_ssa_gitemLen;
57   typename AT::t_int_1d ssa_gphaseLen;
58   typename AT::t_int_2d ssa_gitemLoc;
59   typename AT::t_int_2d ssa_gitemLen;
60 
61   NPairSSAKokkos(class LAMMPS *);
~NPairSSAKokkos()62   ~NPairSSAKokkos() {}
63   void copy_neighbor_info();
64   void copy_bin_info();
65   void copy_stencil_info();
66   void build(class NeighList *);
67  private:
68   // data from Neighbor class
69 
70   DAT::tdual_xfloat_2d k_cutneighsq;
71 
72   // exclusion data from Neighbor class
73 
74   DAT::tdual_int_1d k_ex1_type,k_ex2_type;
75   DAT::tdual_int_2d k_ex_type;
76   DAT::tdual_int_1d k_ex1_group,k_ex2_group;
77   DAT::tdual_int_1d k_ex1_bit,k_ex2_bit;
78   DAT::tdual_int_1d k_ex_mol_group;
79   DAT::tdual_int_1d k_ex_mol_bit;
80   DAT::tdual_int_1d k_ex_mol_intra;
81 
82   // data from NBinSSA class
83 
84   int atoms_per_bin;
85   DAT::tdual_int_1d k_bincount;
86   DAT::tdual_int_2d k_bins;
87   int ghosts_per_gbin;
88   DAT::tdual_int_1d k_gbincount;
89   DAT::tdual_int_2d k_gbins;
90   int lbinxlo, lbinxhi, lbinylo, lbinyhi, lbinzlo, lbinzhi;
91 
92   // data from NStencilSSA class
93 
94   int nstencil;
95   DAT::tdual_int_1d k_stencil;  // # of J neighs for each I
96   DAT::tdual_int_1d_3 k_stencilxyz;
97   DAT::tdual_int_1d k_nstencil_ssa;
98   int sx1, sy1, sz1;
99 };
100 
101 template<class DeviceType>
102 class NPairSSAKokkosExecute
103 {
104   typedef ArrayTypes<DeviceType> AT;
105 
106  public:
107   NeighListKokkos<DeviceType> neigh_list;
108 
109   // data from Neighbor class
110 
111   const typename AT::t_xfloat_2d_randomread cutneighsq;
112 
113   // exclusion data from Neighbor class
114 
115   const int exclude;
116 
117   const int nex_type;
118   const typename AT::t_int_1d_const ex1_type,ex2_type;
119   const typename AT::t_int_2d_const ex_type;
120 
121   const int nex_group;
122   const typename AT::t_int_1d_const ex1_group,ex2_group;
123   const typename AT::t_int_1d_const ex1_bit,ex2_bit;
124 
125   const int nex_mol;
126   const typename AT::t_int_1d_const ex_mol_group;
127   const typename AT::t_int_1d_const ex_mol_bit;
128   const typename AT::t_int_1d_const ex_mol_intra;
129 
130   // data from NBinSSA class
131 
132   const typename AT::t_int_1d bincount;
133   const typename AT::t_int_1d_const c_bincount;
134   typename AT::t_int_2d bins;
135   typename AT::t_int_2d_const c_bins;
136   const typename AT::t_int_1d gbincount;
137   const typename AT::t_int_1d_const c_gbincount;
138   typename AT::t_int_2d gbins;
139   typename AT::t_int_2d_const c_gbins;
140   const int lbinxlo, lbinxhi, lbinylo, lbinyhi, lbinzlo, lbinzhi;
141 
142 
143   // data from NStencil class
144 
145   const int nstencil;
146   const int sx1, sy1, sz1;
147   typename AT::t_int_1d d_stencil;  // # of J neighs for each I
148   typename AT::t_int_1d_3 d_stencilxyz;
149   typename AT::t_int_1d d_nstencil_ssa;
150 
151   // data from Atom class
152 
153   const typename AT::t_x_array_randomread x;
154   const typename AT::t_int_1d_const type,mask;
155   const typename AT::t_tagint_1d_const molecule;
156   const typename AT::t_tagint_1d_const tag;
157   const typename AT::t_tagint_2d_const special;
158   const typename AT::t_int_2d_const nspecial;
159   const int molecular;
160   int moltemplate;
161 
162   int special_flag[4];
163 
164   const int nbinx,nbiny,nbinz;
165   const int mbinx,mbiny,mbinz;
166   const int mbinxlo,mbinylo,mbinzlo;
167   const X_FLOAT bininvx,bininvy,bininvz;
168   X_FLOAT bboxhi[3],bboxlo[3];
169 
170   const int nlocal;
171 
172   typename AT::t_int_scalar resize;
173   typename AT::t_int_scalar new_maxneighs;
174   typename ArrayTypes<LMPHostType>::t_int_scalar h_resize;
175   typename ArrayTypes<LMPHostType>::t_int_scalar h_new_maxneighs;
176 
177   const int xperiodic, yperiodic, zperiodic;
178   const int xprd_half, yprd_half, zprd_half;
179 
180   // SSA Work plan data structures
181   int ssa_phaseCt;
182   typename AT::t_int_1d d_ssa_phaseLen;
183   typename AT::t_int_1d_3_const d_ssa_phaseOff;
184   typename AT::t_int_2d d_ssa_itemLoc;
185   typename AT::t_int_2d d_ssa_itemLen;
186   int ssa_gphaseCt;
187   typename AT::t_int_1d d_ssa_gphaseLen;
188   typename AT::t_int_2d d_ssa_gitemLoc;
189   typename AT::t_int_2d d_ssa_gitemLen;
190 
NPairSSAKokkosExecute(const NeighListKokkos<DeviceType> & _neigh_list,const typename AT::t_xfloat_2d_randomread & _cutneighsq,const typename AT::t_int_1d & _bincount,const typename AT::t_int_2d & _bins,const typename AT::t_int_1d & _gbincount,const typename AT::t_int_2d & _gbins,const int _lbinxlo,const int _lbinxhi,const int _lbinylo,const int _lbinyhi,const int _lbinzlo,const int _lbinzhi,const int _nstencil,const int _sx1,const int _sy1,const int _sz1,const typename AT::t_int_1d & _d_stencil,const typename AT::t_int_1d_3 & _d_stencilxyz,const typename AT::t_int_1d & _d_nstencil_ssa,const int _ssa_phaseCt,const typename AT::t_int_1d & _d_ssa_phaseLen,const typename AT::t_int_1d_3 & _d_ssa_phaseOff,const typename AT::t_int_2d & _d_ssa_itemLoc,const typename AT::t_int_2d & _d_ssa_itemLen,const int _ssa_gphaseCt,const typename AT::t_int_1d & _d_ssa_gphaseLen,const typename AT::t_int_2d & _d_ssa_gitemLoc,const typename AT::t_int_2d & _d_ssa_gitemLen,const int _nlocal,const typename AT::t_x_array_randomread & _x,const typename AT::t_int_1d_const & _type,const typename AT::t_int_1d_const & _mask,const typename AT::t_tagint_1d_const & _molecule,const typename AT::t_tagint_1d_const & _tag,const typename AT::t_tagint_2d_const & _special,const typename AT::t_int_2d_const & _nspecial,const int & _molecular,const int & _nbinx,const int & _nbiny,const int & _nbinz,const int & _mbinx,const int & _mbiny,const int & _mbinz,const int & _mbinxlo,const int & _mbinylo,const int & _mbinzlo,const X_FLOAT & _bininvx,const X_FLOAT & _bininvy,const X_FLOAT & _bininvz,const int & _exclude,const int & _nex_type,const typename AT::t_int_1d_const & _ex1_type,const typename AT::t_int_1d_const & _ex2_type,const typename AT::t_int_2d_const & _ex_type,const int & _nex_group,const typename AT::t_int_1d_const & _ex1_group,const typename AT::t_int_1d_const & _ex2_group,const typename AT::t_int_1d_const & _ex1_bit,const typename AT::t_int_1d_const & _ex2_bit,const int & _nex_mol,const typename 
AT::t_int_1d_const & _ex_mol_group,const typename AT::t_int_1d_const & _ex_mol_bit,const typename AT::t_int_1d_const & _ex_mol_intra,const X_FLOAT * _bboxhi,const X_FLOAT * _bboxlo,const int & _xperiodic,const int & _yperiodic,const int & _zperiodic,const int & _xprd_half,const int & _yprd_half,const int & _zprd_half)191   NPairSSAKokkosExecute(
192         const NeighListKokkos<DeviceType> &_neigh_list,
193         const typename AT::t_xfloat_2d_randomread &_cutneighsq,
194         const typename AT::t_int_1d &_bincount,
195         const typename AT::t_int_2d &_bins,
196         const typename AT::t_int_1d &_gbincount,
197         const typename AT::t_int_2d &_gbins,
198         const int _lbinxlo, const int _lbinxhi,
199         const int _lbinylo, const int _lbinyhi,
200         const int _lbinzlo, const int _lbinzhi,
201         const int _nstencil, const int _sx1, const int _sy1, const int _sz1,
202         const typename AT::t_int_1d &_d_stencil,
203         const typename AT::t_int_1d_3 &_d_stencilxyz,
204         const typename AT::t_int_1d &_d_nstencil_ssa,
205         const int _ssa_phaseCt,
206         const typename AT::t_int_1d &_d_ssa_phaseLen,
207         const typename AT::t_int_1d_3 &_d_ssa_phaseOff,
208         const typename AT::t_int_2d &_d_ssa_itemLoc,
209         const typename AT::t_int_2d &_d_ssa_itemLen,
210         const int _ssa_gphaseCt,
211         const typename AT::t_int_1d &_d_ssa_gphaseLen,
212         const typename AT::t_int_2d &_d_ssa_gitemLoc,
213         const typename AT::t_int_2d &_d_ssa_gitemLen,
214         const int _nlocal,
215         const typename AT::t_x_array_randomread &_x,
216         const typename AT::t_int_1d_const &_type,
217         const typename AT::t_int_1d_const &_mask,
218         const typename AT::t_tagint_1d_const &_molecule,
219         const typename AT::t_tagint_1d_const &_tag,
220         const typename AT::t_tagint_2d_const &_special,
221         const typename AT::t_int_2d_const &_nspecial,
222         const int &_molecular,
223         const int & _nbinx,const int & _nbiny,const int & _nbinz,
224         const int & _mbinx,const int & _mbiny,const int & _mbinz,
225         const int & _mbinxlo,const int & _mbinylo,const int & _mbinzlo,
226         const X_FLOAT &_bininvx,const X_FLOAT &_bininvy,const X_FLOAT &_bininvz,
227         const int & _exclude,const int & _nex_type,
228         const typename AT::t_int_1d_const & _ex1_type,
229         const typename AT::t_int_1d_const & _ex2_type,
230         const typename AT::t_int_2d_const & _ex_type,
231         const int & _nex_group,
232         const typename AT::t_int_1d_const & _ex1_group,
233         const typename AT::t_int_1d_const & _ex2_group,
234         const typename AT::t_int_1d_const & _ex1_bit,
235         const typename AT::t_int_1d_const & _ex2_bit,
236         const int & _nex_mol,
237         const typename AT::t_int_1d_const & _ex_mol_group,
238         const typename AT::t_int_1d_const & _ex_mol_bit,
239         const typename AT::t_int_1d_const & _ex_mol_intra,
240         const X_FLOAT *_bboxhi, const X_FLOAT* _bboxlo,
241         const int & _xperiodic, const int & _yperiodic, const int & _zperiodic,
242         const int & _xprd_half, const int & _yprd_half, const int & _zprd_half):
243     neigh_list(_neigh_list), cutneighsq(_cutneighsq),
244     exclude(_exclude),nex_type(_nex_type),
245     ex1_type(_ex1_type),ex2_type(_ex2_type),ex_type(_ex_type),
246     nex_group(_nex_group),
247     ex1_group(_ex1_group),ex2_group(_ex2_group),
248     ex1_bit(_ex1_bit),ex2_bit(_ex2_bit),nex_mol(_nex_mol),
249     ex_mol_group(_ex_mol_group),ex_mol_bit(_ex_mol_bit),
250     ex_mol_intra(_ex_mol_intra),
251     bincount(_bincount),c_bincount(_bincount),bins(_bins),c_bins(_bins),
252     gbincount(_gbincount),c_gbincount(_gbincount),gbins(_gbins),c_gbins(_gbins),
253     lbinxlo(_lbinxlo),lbinxhi(_lbinxhi),
254     lbinylo(_lbinylo),lbinyhi(_lbinyhi),
255     lbinzlo(_lbinzlo),lbinzhi(_lbinzhi),
256     nstencil(_nstencil),sx1(_sx1),sy1(_sy1),sz1(_sz1),
257     d_stencil(_d_stencil),d_stencilxyz(_d_stencilxyz),d_nstencil_ssa(_d_nstencil_ssa),
258     x(_x),type(_type),mask(_mask),molecule(_molecule),
259     tag(_tag),special(_special),nspecial(_nspecial),molecular(_molecular),
260     nbinx(_nbinx),nbiny(_nbiny),nbinz(_nbinz),
261     mbinx(_mbinx),mbiny(_mbiny),mbinz(_mbinz),
262     mbinxlo(_mbinxlo),mbinylo(_mbinylo),mbinzlo(_mbinzlo),
263     bininvx(_bininvx),bininvy(_bininvy),bininvz(_bininvz),
264     nlocal(_nlocal),
265     xperiodic(_xperiodic),yperiodic(_yperiodic),zperiodic(_zperiodic),
266     xprd_half(_xprd_half),yprd_half(_yprd_half),zprd_half(_zprd_half),
267     ssa_phaseCt(_ssa_phaseCt),
268     d_ssa_phaseLen(_d_ssa_phaseLen),
269     d_ssa_phaseOff(_d_ssa_phaseOff),
270     d_ssa_itemLoc(_d_ssa_itemLoc),
271     d_ssa_itemLen(_d_ssa_itemLen),
272     ssa_gphaseCt(_ssa_gphaseCt),
273     d_ssa_gphaseLen(_d_ssa_gphaseLen),
274     d_ssa_gitemLoc(_d_ssa_gitemLoc),
275     d_ssa_gitemLen(_d_ssa_gitemLen)
276     {
277 
278     if (molecular == 2) moltemplate = 1;
279     else moltemplate = 0;
280 
281     bboxlo[0] = _bboxlo[0]; bboxlo[1] = _bboxlo[1]; bboxlo[2] = _bboxlo[2];
282     bboxhi[0] = _bboxhi[0]; bboxhi[1] = _bboxhi[1]; bboxhi[2] = _bboxhi[2];
283 
284     resize = typename AT::t_int_scalar("NPairSSAKokkosExecute::resize");
285     h_resize = Kokkos::create_mirror_view(resize);
286     h_resize() = 1;
287     new_maxneighs = typename AT::
288       t_int_scalar("NPairSSAKokkosExecute::new_maxneighs");
289     h_new_maxneighs = Kokkos::create_mirror_view(new_maxneighs);
290     h_new_maxneighs() = neigh_list.maxneighs;
291   };
292 
~NPairSSAKokkosExecute()293   ~NPairSSAKokkosExecute() {neigh_list.copymode = 1;};
294 
295   KOKKOS_FUNCTION
296   void build_locals_onePhase(const bool firstTry, int me, int workPhase) const;
297 
298   KOKKOS_FUNCTION
299   void build_ghosts_onePhase(int workPhase) const;
300 
301   KOKKOS_INLINE_FUNCTION
coord2bin(const X_FLOAT & x,const X_FLOAT & y,const X_FLOAT & z,int * i)302   int coord2bin(const X_FLOAT & x,const X_FLOAT & y,const X_FLOAT & z, int* i) const
303   {
304     int ix,iy,iz;
305 
306     if (x >= bboxhi[0])
307       ix = static_cast<int> ((x-bboxhi[0])*bininvx) + nbinx;
308     else if (x >= bboxlo[0]) {
309       ix = static_cast<int> ((x-bboxlo[0])*bininvx);
310       ix = MIN(ix,nbinx-1);
311     } else
312       ix = static_cast<int> ((x-bboxlo[0])*bininvx) - 1;
313 
314     if (y >= bboxhi[1])
315       iy = static_cast<int> ((y-bboxhi[1])*bininvy) + nbiny;
316     else if (y >= bboxlo[1]) {
317       iy = static_cast<int> ((y-bboxlo[1])*bininvy);
318       iy = MIN(iy,nbiny-1);
319     } else
320       iy = static_cast<int> ((y-bboxlo[1])*bininvy) - 1;
321 
322     if (z >= bboxhi[2])
323       iz = static_cast<int> ((z-bboxhi[2])*bininvz) + nbinz;
324     else if (z >= bboxlo[2]) {
325       iz = static_cast<int> ((z-bboxlo[2])*bininvz);
326       iz = MIN(iz,nbinz-1);
327     } else
328       iz = static_cast<int> ((z-bboxlo[2])*bininvz) - 1;
329 
330     i[0] = ix - mbinxlo;
331     i[1] = iy - mbinylo;
332     i[2] = iz - mbinzlo;
333 
334     return (iz-mbinzlo)*mbiny*mbinx + (iy-mbinylo)*mbinx + (ix-mbinxlo);
335   }
336 
337   KOKKOS_INLINE_FUNCTION
338   int exclusion(const int &i,const int &j, const int &itype,const int &jtype) const;
339 
340   KOKKOS_INLINE_FUNCTION
341   int find_special(const int &i, const int &j) const;
342 
343   KOKKOS_INLINE_FUNCTION
minimum_image_check(double dx,double dy,double dz)344   int minimum_image_check(double dx, double dy, double dz) const {
345     if (xperiodic && fabs(dx) > xprd_half) return 1;
346     if (yperiodic && fabs(dy) > yprd_half) return 1;
347     if (zperiodic && fabs(dz) > zprd_half) return 1;
348     return 0;
349   }
350 
351 };
352 
353 }
354 
355 #endif
356 #endif
357 
358 /* ERROR/WARNING messages:
359 
360 */
361