1 /* -*- c++ -*- ---------------------------------------------------------- 2 LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator 3 https://www.lammps.org/, Sandia National Laboratories 4 Steve Plimpton, sjplimp@sandia.gov 5 6 Copyright (2003) Sandia Corporation. Under the terms of Contract 7 DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains 8 certain rights in this software. This software is distributed under 9 the GNU General Public License. 10 11 See the README file in the top-level LAMMPS directory. 12 ------------------------------------------------------------------------- */ 13 14 #ifdef NPAIR_CLASS 15 // clang-format off 16 typedef NPairSSAKokkos<LMPHostType> NPairSSAKokkosHost; 17 NPairStyle(half/bin/newton/ssa/kk/host, 18 NPairSSAKokkosHost, 19 NP_HALF | NP_BIN | NP_NEWTON | NP_ORTHO | NP_SSA | NP_GHOST | NP_KOKKOS_HOST); 20 21 typedef NPairSSAKokkos<LMPDeviceType> NPairSSAKokkosDevice; 22 NPairStyle(half/bin/newton/ssa/kk/device, 23 NPairSSAKokkosDevice, 24 NP_HALF | NP_BIN | NP_NEWTON | NP_ORTHO | NP_SSA | NP_GHOST | NP_KOKKOS_DEVICE); 25 // clang-format on 26 #else 27 28 // clang-format off 29 #ifndef LMP_NPAIR_SSA_KOKKOS_H 30 #define LMP_NPAIR_SSA_KOKKOS_H 31 32 #include "npair.h" 33 #include "neigh_list_kokkos.h" 34 35 namespace LAMMPS_NS { 36 37 template<class DeviceType> 38 class NPairSSAKokkos : public NPair { 39 public: 40 typedef ArrayTypes<DeviceType> AT; 41 42 // SSA Work plan data structures 43 int ssa_phaseCt; 44 DAT::tdual_int_1d k_ssa_phaseLen; 45 DAT::tdual_int_1d_3 k_ssa_phaseOff; 46 DAT::tdual_int_2d k_ssa_itemLoc; 47 DAT::tdual_int_2d k_ssa_itemLen; 48 typename AT::t_int_1d ssa_phaseLen; 49 typename AT::t_int_1d_3 ssa_phaseOff; 50 typename AT::t_int_2d ssa_itemLoc; 51 typename AT::t_int_2d ssa_itemLen; 52 53 const int ssa_gphaseCt; 54 DAT::tdual_int_1d k_ssa_gphaseLen; 55 DAT::tdual_int_2d k_ssa_gitemLoc; 56 DAT::tdual_int_2d k_ssa_gitemLen; 57 typename AT::t_int_1d ssa_gphaseLen; 58 typename AT::t_int_2d ssa_gitemLoc; 59 typename AT::t_int_2d ssa_gitemLen; 60 61 NPairSSAKokkos(class LAMMPS *); ~NPairSSAKokkos()62 ~NPairSSAKokkos() {} 63 void copy_neighbor_info(); 64 void copy_bin_info(); 65 void copy_stencil_info(); 66 void build(class NeighList *); 67 private: 68 // data from Neighbor class 69 70 DAT::tdual_xfloat_2d k_cutneighsq; 71 72 // exclusion data from Neighbor class 73 74 DAT::tdual_int_1d k_ex1_type,k_ex2_type; 75 DAT::tdual_int_2d k_ex_type; 76 DAT::tdual_int_1d k_ex1_group,k_ex2_group; 77 DAT::tdual_int_1d k_ex1_bit,k_ex2_bit; 78 DAT::tdual_int_1d k_ex_mol_group; 79 DAT::tdual_int_1d k_ex_mol_bit; 80 DAT::tdual_int_1d k_ex_mol_intra; 81 82 // data from NBinSSA class 83 84 int atoms_per_bin; 85 DAT::tdual_int_1d k_bincount; 86 DAT::tdual_int_2d k_bins; 87 int ghosts_per_gbin; 88 DAT::tdual_int_1d k_gbincount; 89 DAT::tdual_int_2d k_gbins; 90 int lbinxlo, lbinxhi, lbinylo, lbinyhi, lbinzlo, lbinzhi; 91 92 // data from NStencilSSA class 93 94 int nstencil; 95 DAT::tdual_int_1d k_stencil; // # of J neighs for each I 96 DAT::tdual_int_1d_3 k_stencilxyz; 97 DAT::tdual_int_1d k_nstencil_ssa; 98 int sx1, sy1, sz1; 99 }; 100 101 template<class DeviceType> 102 class NPairSSAKokkosExecute 103 { 104 typedef ArrayTypes<DeviceType> AT; 105 106 public: 107 NeighListKokkos<DeviceType> neigh_list; 108 109 // data from Neighbor class 110 111 const typename AT::t_xfloat_2d_randomread cutneighsq; 112 113 // exclusion data from Neighbor class 114 115 const int exclude; 116 117 const int nex_type; 118 const typename AT::t_int_1d_const ex1_type,ex2_type; 119 const typename AT::t_int_2d_const ex_type; 120 121 const int nex_group; 122 const typename AT::t_int_1d_const ex1_group,ex2_group; 123 const typename AT::t_int_1d_const ex1_bit,ex2_bit; 124 125 const int nex_mol; 126 const typename AT::t_int_1d_const ex_mol_group; 127 const typename AT::t_int_1d_const ex_mol_bit; 128 const typename AT::t_int_1d_const ex_mol_intra; 129 130 // data from NBinSSA class 131 132 const typename AT::t_int_1d bincount; 133 const typename AT::t_int_1d_const c_bincount; 134 typename AT::t_int_2d bins; 135 typename AT::t_int_2d_const c_bins; 136 const typename AT::t_int_1d gbincount; 137 const typename AT::t_int_1d_const c_gbincount; 138 typename AT::t_int_2d gbins; 139 typename AT::t_int_2d_const c_gbins; 140 const int lbinxlo, lbinxhi, lbinylo, lbinyhi, lbinzlo, lbinzhi; 141 142 143 // data from NStencil class 144 145 const int nstencil; 146 const int sx1, sy1, sz1; 147 typename AT::t_int_1d d_stencil; // # of J neighs for each I 148 typename AT::t_int_1d_3 d_stencilxyz; 149 typename AT::t_int_1d d_nstencil_ssa; 150 151 // data from Atom class 152 153 const typename AT::t_x_array_randomread x; 154 const typename AT::t_int_1d_const type,mask; 155 const typename AT::t_tagint_1d_const molecule; 156 const typename AT::t_tagint_1d_const tag; 157 const typename AT::t_tagint_2d_const special; 158 const typename AT::t_int_2d_const nspecial; 159 const int molecular; 160 int moltemplate; 161 162 int special_flag[4]; 163 164 const int nbinx,nbiny,nbinz; 165 const int mbinx,mbiny,mbinz; 166 const int mbinxlo,mbinylo,mbinzlo; 167 const X_FLOAT bininvx,bininvy,bininvz; 168 X_FLOAT bboxhi[3],bboxlo[3]; 169 170 const int nlocal; 171 172 typename AT::t_int_scalar resize; 173 typename AT::t_int_scalar new_maxneighs; 174 typename ArrayTypes<LMPHostType>::t_int_scalar h_resize; 175 typename ArrayTypes<LMPHostType>::t_int_scalar h_new_maxneighs; 176 177 const int xperiodic, yperiodic, zperiodic; 178 const int xprd_half, yprd_half, zprd_half; 179 180 // SSA Work plan data structures 181 int ssa_phaseCt; 182 typename AT::t_int_1d d_ssa_phaseLen; 183 typename AT::t_int_1d_3_const d_ssa_phaseOff; 184 typename AT::t_int_2d d_ssa_itemLoc; 185 typename AT::t_int_2d d_ssa_itemLen; 186 int ssa_gphaseCt; 187 typename AT::t_int_1d d_ssa_gphaseLen; 188 typename AT::t_int_2d d_ssa_gitemLoc; 189 typename AT::t_int_2d d_ssa_gitemLen; 190 NPairSSAKokkosExecute(const NeighListKokkos<DeviceType> & _neigh_list,const typename AT::t_xfloat_2d_randomread & _cutneighsq,const typename AT::t_int_1d & _bincount,const typename AT::t_int_2d & _bins,const typename AT::t_int_1d & _gbincount,const typename AT::t_int_2d & _gbins,const int _lbinxlo,const int _lbinxhi,const int _lbinylo,const int _lbinyhi,const int _lbinzlo,const int _lbinzhi,const int _nstencil,const int _sx1,const int _sy1,const int _sz1,const typename AT::t_int_1d & _d_stencil,const typename AT::t_int_1d_3 & _d_stencilxyz,const typename AT::t_int_1d & _d_nstencil_ssa,const int _ssa_phaseCt,const typename AT::t_int_1d & _d_ssa_phaseLen,const typename AT::t_int_1d_3 & _d_ssa_phaseOff,const typename AT::t_int_2d & _d_ssa_itemLoc,const typename AT::t_int_2d & _d_ssa_itemLen,const int _ssa_gphaseCt,const typename AT::t_int_1d & _d_ssa_gphaseLen,const typename AT::t_int_2d & _d_ssa_gitemLoc,const typename AT::t_int_2d & _d_ssa_gitemLen,const int _nlocal,const typename AT::t_x_array_randomread & _x,const typename AT::t_int_1d_const & _type,const typename AT::t_int_1d_const & _mask,const typename AT::t_tagint_1d_const & _molecule,const typename AT::t_tagint_1d_const & _tag,const typename AT::t_tagint_2d_const & _special,const typename AT::t_int_2d_const & _nspecial,const int & _molecular,const int & _nbinx,const int & _nbiny,const int & _nbinz,const int & _mbinx,const int & _mbiny,const int & _mbinz,const int & _mbinxlo,const int & _mbinylo,const int & _mbinzlo,const X_FLOAT & _bininvx,const X_FLOAT & _bininvy,const X_FLOAT & _bininvz,const int & _exclude,const int & _nex_type,const typename AT::t_int_1d_const & _ex1_type,const typename AT::t_int_1d_const & _ex2_type,const typename AT::t_int_2d_const & _ex_type,const int & _nex_group,const typename AT::t_int_1d_const & _ex1_group,const typename AT::t_int_1d_const & _ex2_group,const typename AT::t_int_1d_const & _ex1_bit,const typename AT::t_int_1d_const & _ex2_bit,const int & _nex_mol,const typename AT::t_int_1d_const & _ex_mol_group,const typename AT::t_int_1d_const & _ex_mol_bit,const typename AT::t_int_1d_const & _ex_mol_intra,const X_FLOAT * _bboxhi,const X_FLOAT * _bboxlo,const int & _xperiodic,const int & _yperiodic,const int & _zperiodic,const int & _xprd_half,const int & _yprd_half,const int & _zprd_half)191 NPairSSAKokkosExecute( 192 const NeighListKokkos<DeviceType> &_neigh_list, 193 const typename AT::t_xfloat_2d_randomread &_cutneighsq, 194 const typename AT::t_int_1d &_bincount, 195 const typename AT::t_int_2d &_bins, 196 const typename AT::t_int_1d &_gbincount, 197 const typename AT::t_int_2d &_gbins, 198 const int _lbinxlo, const int _lbinxhi, 199 const int _lbinylo, const int _lbinyhi, 200 const int _lbinzlo, const int _lbinzhi, 201 const int _nstencil, const int _sx1, const int _sy1, const int _sz1, 202 const typename AT::t_int_1d &_d_stencil, 203 const typename AT::t_int_1d_3 &_d_stencilxyz, 204 const typename AT::t_int_1d &_d_nstencil_ssa, 205 const int _ssa_phaseCt, 206 const typename AT::t_int_1d &_d_ssa_phaseLen, 207 const typename AT::t_int_1d_3 &_d_ssa_phaseOff, 208 const typename AT::t_int_2d &_d_ssa_itemLoc, 209 const typename AT::t_int_2d &_d_ssa_itemLen, 210 const int _ssa_gphaseCt, 211 const typename AT::t_int_1d &_d_ssa_gphaseLen, 212 const typename AT::t_int_2d &_d_ssa_gitemLoc, 213 const typename AT::t_int_2d &_d_ssa_gitemLen, 214 const int _nlocal, 215 const typename AT::t_x_array_randomread &_x, 216 const typename AT::t_int_1d_const &_type, 217 const typename AT::t_int_1d_const &_mask, 218 const typename AT::t_tagint_1d_const &_molecule, 219 const typename AT::t_tagint_1d_const &_tag, 220 const typename AT::t_tagint_2d_const &_special, 221 const typename AT::t_int_2d_const &_nspecial, 222 const int &_molecular, 223 const int & _nbinx,const int & _nbiny,const int & _nbinz, 224 const int & _mbinx,const int & _mbiny,const int & _mbinz, 225 const int & _mbinxlo,const int & _mbinylo,const int & _mbinzlo, 226 const X_FLOAT &_bininvx,const X_FLOAT &_bininvy,const X_FLOAT &_bininvz, 227 const int & _exclude,const int & _nex_type, 228 const typename AT::t_int_1d_const & _ex1_type, 229 const typename AT::t_int_1d_const & _ex2_type, 230 const typename AT::t_int_2d_const & _ex_type, 231 const int & _nex_group, 232 const typename AT::t_int_1d_const & _ex1_group, 233 const typename AT::t_int_1d_const & _ex2_group, 234 const typename AT::t_int_1d_const & _ex1_bit, 235 const typename AT::t_int_1d_const & _ex2_bit, 236 const int & _nex_mol, 237 const typename AT::t_int_1d_const & _ex_mol_group, 238 const typename AT::t_int_1d_const & _ex_mol_bit, 239 const typename AT::t_int_1d_const & _ex_mol_intra, 240 const X_FLOAT *_bboxhi, const X_FLOAT* _bboxlo, 241 const int & _xperiodic, const int & _yperiodic, const int & _zperiodic, 242 const int & _xprd_half, const int & _yprd_half, const int & _zprd_half): 243 neigh_list(_neigh_list), cutneighsq(_cutneighsq), 244 exclude(_exclude),nex_type(_nex_type), 245 ex1_type(_ex1_type),ex2_type(_ex2_type),ex_type(_ex_type), 246 nex_group(_nex_group), 247 ex1_group(_ex1_group),ex2_group(_ex2_group), 248 ex1_bit(_ex1_bit),ex2_bit(_ex2_bit),nex_mol(_nex_mol), 249 ex_mol_group(_ex_mol_group),ex_mol_bit(_ex_mol_bit), 250 ex_mol_intra(_ex_mol_intra), 251 bincount(_bincount),c_bincount(_bincount),bins(_bins),c_bins(_bins), 252 gbincount(_gbincount),c_gbincount(_gbincount),gbins(_gbins),c_gbins(_gbins), 253 lbinxlo(_lbinxlo),lbinxhi(_lbinxhi), 254 lbinylo(_lbinylo),lbinyhi(_lbinyhi), 255 lbinzlo(_lbinzlo),lbinzhi(_lbinzhi), 256 nstencil(_nstencil),sx1(_sx1),sy1(_sy1),sz1(_sz1), 257 d_stencil(_d_stencil),d_stencilxyz(_d_stencilxyz),d_nstencil_ssa(_d_nstencil_ssa), 258 x(_x),type(_type),mask(_mask),molecule(_molecule), 259 tag(_tag),special(_special),nspecial(_nspecial),molecular(_molecular), 260 nbinx(_nbinx),nbiny(_nbiny),nbinz(_nbinz), 261 mbinx(_mbinx),mbiny(_mbiny),mbinz(_mbinz), 262 mbinxlo(_mbinxlo),mbinylo(_mbinylo),mbinzlo(_mbinzlo), 263 bininvx(_bininvx),bininvy(_bininvy),bininvz(_bininvz), 264 nlocal(_nlocal), 265 xperiodic(_xperiodic),yperiodic(_yperiodic),zperiodic(_zperiodic), 266 xprd_half(_xprd_half),yprd_half(_yprd_half),zprd_half(_zprd_half), 267 ssa_phaseCt(_ssa_phaseCt), 268 d_ssa_phaseLen(_d_ssa_phaseLen), 269 d_ssa_phaseOff(_d_ssa_phaseOff), 270 d_ssa_itemLoc(_d_ssa_itemLoc), 271 d_ssa_itemLen(_d_ssa_itemLen), 272 ssa_gphaseCt(_ssa_gphaseCt), 273 d_ssa_gphaseLen(_d_ssa_gphaseLen), 274 d_ssa_gitemLoc(_d_ssa_gitemLoc), 275 d_ssa_gitemLen(_d_ssa_gitemLen) 276 { 277 278 if (molecular == 2) moltemplate = 1; 279 else moltemplate = 0; 280 281 bboxlo[0] = _bboxlo[0]; bboxlo[1] = _bboxlo[1]; bboxlo[2] = _bboxlo[2]; 282 bboxhi[0] = _bboxhi[0]; bboxhi[1] = _bboxhi[1]; bboxhi[2] = _bboxhi[2]; 283 284 resize = typename AT::t_int_scalar("NPairSSAKokkosExecute::resize"); 285 h_resize = Kokkos::create_mirror_view(resize); 286 h_resize() = 1; 287 new_maxneighs = typename AT:: 288 t_int_scalar("NPairSSAKokkosExecute::new_maxneighs"); 289 h_new_maxneighs = Kokkos::create_mirror_view(new_maxneighs); 290 h_new_maxneighs() = neigh_list.maxneighs; 291 }; 292 ~NPairSSAKokkosExecute()293 ~NPairSSAKokkosExecute() {neigh_list.copymode = 1;}; 294 295 KOKKOS_FUNCTION 296 void build_locals_onePhase(const bool firstTry, int me, int workPhase) const; 297 298 KOKKOS_FUNCTION 299 void build_ghosts_onePhase(int workPhase) const; 300 301 KOKKOS_INLINE_FUNCTION coord2bin(const X_FLOAT & x,const X_FLOAT & y,const X_FLOAT & z,int * i)302 int coord2bin(const X_FLOAT & x,const X_FLOAT & y,const X_FLOAT & z, int* i) const 303 { 304 int ix,iy,iz; 305 306 if (x >= bboxhi[0]) 307 ix = static_cast<int> ((x-bboxhi[0])*bininvx) + nbinx; 308 else if (x >= bboxlo[0]) { 309 ix = static_cast<int> ((x-bboxlo[0])*bininvx); 310 ix = MIN(ix,nbinx-1); 311 } else 312 ix = static_cast<int> ((x-bboxlo[0])*bininvx) - 1; 313 314 if (y >= bboxhi[1]) 315 iy = static_cast<int> ((y-bboxhi[1])*bininvy) + nbiny; 316 else if (y >= bboxlo[1]) { 317 iy = static_cast<int> ((y-bboxlo[1])*bininvy); 318 iy = MIN(iy,nbiny-1); 319 } else 320 iy = static_cast<int> ((y-bboxlo[1])*bininvy) - 1; 321 322 if (z >= bboxhi[2]) 323 iz = static_cast<int> ((z-bboxhi[2])*bininvz) + nbinz; 324 else if (z >= bboxlo[2]) { 325 iz = static_cast<int> ((z-bboxlo[2])*bininvz); 326 iz = MIN(iz,nbinz-1); 327 } else 328 iz = static_cast<int> ((z-bboxlo[2])*bininvz) - 1; 329 330 i[0] = ix - mbinxlo; 331 i[1] = iy - mbinylo; 332 i[2] = iz - mbinzlo; 333 334 return (iz-mbinzlo)*mbiny*mbinx + (iy-mbinylo)*mbinx + (ix-mbinxlo); 335 } 336 337 KOKKOS_INLINE_FUNCTION 338 int exclusion(const int &i,const int &j, const int &itype,const int &jtype) const; 339 340 KOKKOS_INLINE_FUNCTION 341 int find_special(const int &i, const int &j) const; 342 343 KOKKOS_INLINE_FUNCTION minimum_image_check(double dx,double dy,double dz)344 int minimum_image_check(double dx, double dy, double dz) const { 345 if (xperiodic && fabs(dx) > xprd_half) return 1; 346 if (yperiodic && fabs(dy) > yprd_half) return 1; 347 if (zperiodic && fabs(dz) > zprd_half) return 1; 348 return 0; 349 } 350 351 }; 352 353 } 354 355 #endif 356 #endif 357 358 /* ERROR/WARNING messages: 359 360 */ 361