1 // clang-format off
2 /* ----------------------------------------------------------------------
3    LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
4    https://www.lammps.org/, Sandia National Laboratories
5    Steve Plimpton, sjplimp@sandia.gov
6 
7    Copyright (2003) Sandia Corporation.  Under the terms of Contract
8    DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains
9    certain rights in this software.  This software is distributed under
10    the GNU General Public License.
11 
12    See the README file in the top-level LAMMPS directory.
13 ------------------------------------------------------------------------- */
14 
15 #include <mpi.h>
16 #include <cmath>
17 #include <cstdlib>
18 #include <cstring>
19 #include <cstdio>
20 #include "fix_shake_kokkos.h"
21 #include "fix_rattle.h"
22 #include "atom_kokkos.h"
23 #include "atom_vec.h"
24 #include "molecule.h"
25 #include "update.h"
26 #include "respa.h"
27 #include "modify.h"
28 #include "domain.h"
29 #include "force.h"
30 #include "bond.h"
31 #include "angle.h"
32 #include "comm.h"
33 #include "group.h"
34 #include "fix_respa.h"
35 #include "math_const.h"
36 #include "memory_kokkos.h"
37 #include "error.h"
38 #include "kokkos.h"
39 #include "atom_masks.h"
40 
41 using namespace LAMMPS_NS;
42 using namespace FixConst;
43 using namespace MathConst;
44 
// NOTE(review): RVOUS is not referenced in the part of this file that was
// reviewed; the trailing comment is the only description of its values
#define RVOUS 1   // 0 for irregular, 1 for all2all

// NOTE(review): BIG and MASSDELTA are likewise not referenced in the part
// of this file that was reviewed - confirm they are still needed here
#define BIG 1.0e20
#define MASSDELTA 0.1
49 
50 /* ---------------------------------------------------------------------- */
51 
52 template<class DeviceType>
FixShakeKokkos(LAMMPS * lmp,int narg,char ** arg)53 FixShakeKokkos<DeviceType>::FixShakeKokkos(LAMMPS *lmp, int narg, char **arg) :
54   FixShake(lmp, narg, arg)
55 {
56   kokkosable = 1;
57   forward_comm_device = 1;
58   atomKK = (AtomKokkos *)atom;
59   execution_space = ExecutionSpaceFromDevice<DeviceType>::space;
60 
61   datamask_read = EMPTY_MASK;
62   datamask_modify = EMPTY_MASK;
63 
64   shake_flag_tmp = shake_flag;
65   shake_atom_tmp = shake_atom;
66   shake_type_tmp = shake_type;
67 
68   shake_flag = nullptr;
69   shake_atom = nullptr;
70   shake_type = nullptr;
71 
72   int nmax = atom->nmax;
73 
74   grow_arrays(nmax);
75 
76   for (int i = 0; i < nmax; i++) {
77     k_shake_flag.h_view[i] = shake_flag_tmp[i];
78     k_shake_atom.h_view(i,0) = shake_atom_tmp[i][0];
79     k_shake_atom.h_view(i,1) = shake_atom_tmp[i][1];
80     k_shake_atom.h_view(i,2) = shake_atom_tmp[i][2];
81     k_shake_atom.h_view(i,3) = shake_atom_tmp[i][3];
82     k_shake_type.h_view(i,0) = shake_type_tmp[i][0];
83     k_shake_type.h_view(i,1) = shake_type_tmp[i][1];
84     k_shake_type.h_view(i,2) = shake_type_tmp[i][2];
85   }
86 
87   k_shake_flag.modify_host();
88   k_shake_atom.modify_host();
89   k_shake_type.modify_host();
90 
91   k_bond_distance = DAT::tdual_float_1d("fix_shake:bond_distance",atom->nbondtypes+1);
92   k_angle_distance = DAT::tdual_float_1d("fix_shake:angle_distance",atom->nangletypes+1);
93 
94   d_bond_distance = k_bond_distance.view<DeviceType>();
95   d_angle_distance = k_angle_distance.view<DeviceType>();
96 
97   // use 1D view for scalars to reduce GPU memory operations
98 
99   d_scalars = typename AT::t_int_1d("neighbor:scalars",2);
100   h_scalars = HAT::t_int_1d("neighbor:scalars_mirror",2);
101 
102   d_error_flag = Kokkos::subview(d_scalars,0);
103   d_nlist = Kokkos::subview(d_scalars,1);
104 
105   h_error_flag = Kokkos::subview(h_scalars,0);
106   h_nlist = Kokkos::subview(h_scalars,1);
107 
108   memory->destroy(shake_flag_tmp);
109   memory->destroy(shake_atom_tmp);
110   memory->destroy(shake_type_tmp);
111 }
112 
113 /* ---------------------------------------------------------------------- */
114 
115 template<class DeviceType>
~FixShakeKokkos()116 FixShakeKokkos<DeviceType>::~FixShakeKokkos()
117 {
118   if (copymode) return;
119 
120   k_shake_flag.sync_host();
121   k_shake_atom.sync_host();
122 
123   for (int i = 0; i < nlocal; i++) {
124     if (shake_flag[i] == 0) continue;
125     else if (shake_flag[i] == 1) {
126       bondtype_findset(i,shake_atom[i][0],shake_atom[i][1],1);
127       bondtype_findset(i,shake_atom[i][0],shake_atom[i][2],1);
128       angletype_findset(i,shake_atom[i][1],shake_atom[i][2],1);
129     } else if (shake_flag[i] == 2) {
130       bondtype_findset(i,shake_atom[i][0],shake_atom[i][1],1);
131     } else if (shake_flag[i] == 3) {
132       bondtype_findset(i,shake_atom[i][0],shake_atom[i][1],1);
133       bondtype_findset(i,shake_atom[i][0],shake_atom[i][2],1);
134     } else if (shake_flag[i] == 4) {
135       bondtype_findset(i,shake_atom[i][0],shake_atom[i][1],1);
136       bondtype_findset(i,shake_atom[i][0],shake_atom[i][2],1);
137       bondtype_findset(i,shake_atom[i][0],shake_atom[i][3],1);
138     }
139   }
140 
141   memoryKK->destroy_kokkos(k_shake_flag,shake_flag);
142   memoryKK->destroy_kokkos(k_shake_atom,shake_atom);
143   memoryKK->destroy_kokkos(k_shake_type,shake_type);
144   memoryKK->destroy_kokkos(k_xshake,xshake);
145   memoryKK->destroy_kokkos(k_list,list);
146 
147   memoryKK->destroy_kokkos(k_vatom,vatom);
148 }
149 
150 /* ----------------------------------------------------------------------
151    set bond and angle distances
152    this init must happen after force->bond and force->angle inits
153 ------------------------------------------------------------------------- */
154 
155 template<class DeviceType>
init()156 void FixShakeKokkos<DeviceType>::init()
157 {
158   FixShake::init();
159 
160   if (utils::strmatch(update->integrate_style,"^respa"))
161     error->all(FLERR,"Cannot yet use respa with Kokkos");
162 
163   if (rattle)
164     error->all(FLERR,"Cannot yet use KOKKOS package with fix rattle");
165 
166   // set equilibrium bond distances
167 
168   for (int i = 1; i <= atom->nbondtypes; i++)
169     k_bond_distance.h_view[i] = bond_distance[i];
170 
171   // set equilibrium angle distances
172 
173   for (int i = 1; i <= atom->nangletypes; i++)
174     k_angle_distance.h_view[i] = angle_distance[i];
175 
176   k_bond_distance.modify_host();
177   k_angle_distance.modify_host();
178 
179   k_bond_distance.sync<DeviceType>();
180   k_angle_distance.sync<DeviceType>();
181 }
182 
183 
184 /* ----------------------------------------------------------------------
185    build list of SHAKE clusters to constrain
186    if one or more atoms in cluster are on this proc,
187      this proc lists the cluster exactly once
188 ------------------------------------------------------------------------- */
189 
190 template<class DeviceType>
pre_neighbor()191 void FixShakeKokkos<DeviceType>::pre_neighbor()
192 {
193   // local copies of atom quantities
194   // used by SHAKE until next re-neighboring
195 
196   x = atom->x;
197   v = atom->v;
198   f = atom->f;
199   mass = atom->mass;
200   rmass = atom->rmass;
201   type = atom->type;
202   nlocal = atom->nlocal;
203 
204   map_style = atom->map_style;
205   if (map_style == Atom::MAP_ARRAY) {
206     k_map_array = atomKK->k_map_array;
207     k_map_array.template sync<DeviceType>();
208   } else if (map_style == Atom::MAP_HASH) {
209     k_map_hash = atomKK->k_map_hash;
210   }
211 
212   k_shake_flag.sync<DeviceType>();
213   k_shake_atom.sync<DeviceType>();
214 
215   // extend size of SHAKE list if necessary
216 
217   if (nlocal > maxlist) {
218     maxlist = nlocal;
219     memoryKK->destroy_kokkos(k_list,list);
220     memoryKK->create_kokkos(k_list,list,maxlist,"shake:list");
221     d_list = k_list.view<DeviceType>();
222   }
223 
224   // Atom Map
225 
226   map_style = atom->map_style;
227 
228   if (map_style == Atom::MAP_ARRAY) {
229     k_map_array = atomKK->k_map_array;
230     k_map_array.template sync<DeviceType>();
231   } else if (map_style == Atom::MAP_HASH) {
232     k_map_hash = atomKK->k_map_hash;
233   }
234 
235   // build list of SHAKE clusters I compute
236 
237   Kokkos::deep_copy(d_scalars,0);
238 
239   {
240     // local variables for lambda capture
241 
242     auto d_shake_flag = this->d_shake_flag;
243     auto d_shake_atom = this->d_shake_atom;
244     auto d_list = this->d_list;
245     auto d_error_flag = this->d_error_flag;
246     auto d_nlist = this->d_nlist;
247     auto map_style = atom->map_style;
248     auto k_map_array = this->k_map_array;
249     auto k_map_hash = this->k_map_hash;
250 
251     Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType>(0,nlocal),
252      LAMMPS_LAMBDA(const int& i) {
253       if (d_shake_flag[i]) {
254         if (d_shake_flag[i] == 2) {
255           const int atom1 = AtomKokkos::map_kokkos<DeviceType>(d_shake_atom(i,0),map_style,k_map_array,k_map_hash);
256           const int atom2 = AtomKokkos::map_kokkos<DeviceType>(d_shake_atom(i,1),map_style,k_map_array,k_map_hash);
257           if (atom1 == -1 || atom2 == -1) {
258             d_error_flag() = 1;
259           }
260           if (i <= atom1 && i <= atom2) {
261             const int nlist = Kokkos::atomic_fetch_add(&d_nlist(),1);
262             d_list[nlist] = i;
263           }
264         } else if (d_shake_flag[i] % 2 == 1) {
265           const int atom1 = AtomKokkos::map_kokkos<DeviceType>(d_shake_atom(i,0),map_style,k_map_array,k_map_hash);
266           const int atom2 = AtomKokkos::map_kokkos<DeviceType>(d_shake_atom(i,1),map_style,k_map_array,k_map_hash);
267           const int atom3 = AtomKokkos::map_kokkos<DeviceType>(d_shake_atom(i,2),map_style,k_map_array,k_map_hash);
268           if (atom1 == -1 || atom2 == -1 || atom3 == -1)
269             d_error_flag() = 1;
270           if (i <= atom1 && i <= atom2 && i <= atom3) {
271             const int nlist = Kokkos::atomic_fetch_add(&d_nlist(),1);
272             d_list[nlist] = i;
273           }
274         } else {
275           const int atom1 = AtomKokkos::map_kokkos<DeviceType>(d_shake_atom(i,0),map_style,k_map_array,k_map_hash);
276           const int atom2 = AtomKokkos::map_kokkos<DeviceType>(d_shake_atom(i,1),map_style,k_map_array,k_map_hash);
277           const int atom3 = AtomKokkos::map_kokkos<DeviceType>(d_shake_atom(i,2),map_style,k_map_array,k_map_hash);
278           const int atom4 = AtomKokkos::map_kokkos<DeviceType>(d_shake_atom(i,3),map_style,k_map_array,k_map_hash);
279           if (atom1 == -1 || atom2 == -1 || atom3 == -1 || atom4 == -1)
280             d_error_flag() = 1;
281           if (i <= atom1 && i <= atom2 && i <= atom3 && i <= atom4) {
282             const int nlist = Kokkos::atomic_fetch_add(&d_nlist(),1);
283             d_list[nlist] = i;
284           }
285         }
286       }
287     });
288   }
289 
290   Kokkos::deep_copy(h_scalars,d_scalars);
291   nlist = h_nlist();
292 
293   if (h_error_flag() == 1) {
294     error->one(FLERR,"Shake atoms missing on proc "
295                                  "{} at step {}",me,update->ntimestep);
296   }
297 }
298 
299 /* ----------------------------------------------------------------------
300    compute the force adjustment for SHAKE constraint
301 ------------------------------------------------------------------------- */
302 
303 template<class DeviceType>
post_force(int vflag)304 void FixShakeKokkos<DeviceType>::post_force(int vflag)
305 {
306   copymode = 1;
307 
308   d_x = atomKK->k_x.view<DeviceType>();
309   d_f = atomKK->k_f.view<DeviceType>();
310   d_type = atomKK->k_type.view<DeviceType>();
311   d_rmass = atomKK->k_rmass.view<DeviceType>();
312   d_mass = atomKK->k_mass.view<DeviceType>();
313   nlocal = atomKK->nlocal;
314 
315   map_style = atom->map_style;
316   if (map_style == Atom::MAP_ARRAY) {
317     k_map_array = atomKK->k_map_array;
318     k_map_array.template sync<DeviceType>();
319   } else if (map_style == Atom::MAP_HASH) {
320     k_map_hash = atomKK->k_map_hash;
321   }
322 
323   if (d_rmass.data())
324     atomKK->sync(execution_space,X_MASK|F_MASK|RMASS_MASK);
325   else
326     atomKK->sync(execution_space,X_MASK|F_MASK|TYPE_MASK);
327 
328   k_shake_flag.sync<DeviceType>();
329   k_shake_atom.sync<DeviceType>();
330   k_shake_type.sync<DeviceType>();
331 
332   if (update->ntimestep == next_output) {
333     atomKK->sync(Host,X_MASK);
334     k_shake_flag.sync_host();
335     k_shake_atom.sync_host();
336     k_shake_type.sync_host();
337     stats();
338   }
339 
340   // xshake = unconstrained move with current v,f
341   // communicate results if necessary
342 
343   unconstrained_update();
344   if (nprocs > 1) comm->forward_comm_fix(this);
345   k_xshake.sync<DeviceType>();
346 
347   // virial setup
348 
349   v_init(vflag);
350 
351   // reallocate per-atom arrays if necessary
352 
353   if (vflag_atom) {
354     memoryKK->destroy_kokkos(k_vatom,vatom);
355     memoryKK->create_kokkos(k_vatom,vatom,maxvatom,"improper:vatom");
356     d_vatom = k_vatom.template view<KKDeviceType>();
357   }
358 
359 
360   neighflag = lmp->kokkos->neighflag;
361 
362   // FULL neighlist still needs atomics in fix shake
363 
364   if (neighflag == FULL) {
365     if (lmp->kokkos->nthreads > 1 || lmp->kokkos->ngpus > 0)
366       neighflag = HALFTHREAD;
367     else
368       neighflag = HALF;
369   }
370 
371   need_dup = 0;
372   if (neighflag != HALF)
373     need_dup = std::is_same<typename NeedDup<HALFTHREAD,DeviceType>::value,Kokkos::Experimental::ScatterDuplicated>::value;
374 
375   // allocate duplicated memory
376 
377   if (need_dup) {
378     dup_f            = Kokkos::Experimental::create_scatter_view<Kokkos::Experimental::ScatterSum, Kokkos::Experimental::ScatterDuplicated>(d_f);
379     dup_vatom        = Kokkos::Experimental::create_scatter_view<Kokkos::Experimental::ScatterSum, Kokkos::Experimental::ScatterDuplicated>(d_vatom);
380   } else {
381     ndup_f            = Kokkos::Experimental::create_scatter_view<Kokkos::Experimental::ScatterSum, Kokkos::Experimental::ScatterNonDuplicated>(d_f);
382     ndup_vatom        = Kokkos::Experimental::create_scatter_view<Kokkos::Experimental::ScatterSum, Kokkos::Experimental::ScatterNonDuplicated>(d_vatom);
383   }
384 
385   Kokkos::deep_copy(d_error_flag,0);
386 
387   update_domain_variables();
388 
389   EV_FLOAT ev;
390 
391   // loop over clusters to add constraint forces
392 
393   if (neighflag == HALF) {
394    if (evflag)
395       Kokkos::parallel_reduce(Kokkos::RangePolicy<DeviceType, TagFixShakePostForce<HALF,1> >(0,nlist),*this,ev);
396     else
397       Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType, TagFixShakePostForce<HALF,0> >(0,nlist),*this);
398   } else {
399     if (evflag)
400       Kokkos::parallel_reduce(Kokkos::RangePolicy<DeviceType, TagFixShakePostForce<HALFTHREAD,1> >(0,nlist),*this,ev);
401     else
402       Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType, TagFixShakePostForce<HALFTHREAD,0> >(0,nlist),*this);
403   }
404 
405   copymode = 0;
406 
407   Kokkos::deep_copy(h_error_flag,d_error_flag);
408 
409   if (h_error_flag() == 2)
410     error->warning(FLERR,"Shake determinant < 0.0");
411   else if (h_error_flag() == 3)
412     error->one(FLERR,"Shake determinant = 0.0");
413 
414   // store vflag for coordinate_constraints_end_of_step()
415 
416   vflag_post_force = vflag;
417 
418   // reduction over duplicated memory
419 
420   if (need_dup)
421     Kokkos::Experimental::contribute(d_f,dup_f);
422 
423   atomKK->modified(execution_space,F_MASK);
424 
425   if (vflag_global) {
426     virial[0] += ev.v[0];
427     virial[1] += ev.v[1];
428     virial[2] += ev.v[2];
429     virial[3] += ev.v[3];
430     virial[4] += ev.v[4];
431     virial[5] += ev.v[5];
432   }
433 
434   if (vflag_atom) {
435     k_vatom.template modify<DeviceType>();
436     k_vatom.template sync<LMPHostType>();
437   }
438 
439   // free duplicated memory
440 
441   if (need_dup) {
442     dup_f = decltype(dup_f)();
443     dup_vatom = decltype(dup_vatom)();
444   }
445 }
446 
447 /* ---------------------------------------------------------------------- */
448 
449 template<class DeviceType>
450 template<int NEIGHFLAG, int EVFLAG>
451 KOKKOS_INLINE_FUNCTION
operator ()(TagFixShakePostForce<NEIGHFLAG,EVFLAG>,const int & i,EV_FLOAT & ev) const452 void FixShakeKokkos<DeviceType>::operator()(TagFixShakePostForce<NEIGHFLAG,EVFLAG>, const int &i, EV_FLOAT& ev) const {
453   const int m = d_list[i];
454   if (d_shake_flag[m] == 2) shake<NEIGHFLAG,EVFLAG>(m,ev);
455   else if (d_shake_flag[m] == 3) shake3<NEIGHFLAG,EVFLAG>(m,ev);
456   else if (d_shake_flag[m] == 4) shake4<NEIGHFLAG,EVFLAG>(m,ev);
457   else shake3angle<NEIGHFLAG,EVFLAG>(m,ev);
458 }
459 
460 template<class DeviceType>
461 template<int NEIGHFLAG, int EVFLAG>
462 KOKKOS_INLINE_FUNCTION
operator ()(TagFixShakePostForce<NEIGHFLAG,EVFLAG>,const int & i) const463 void FixShakeKokkos<DeviceType>::operator()(TagFixShakePostForce<NEIGHFLAG,EVFLAG>, const int &i) const {
464   EV_FLOAT ev;
465   this->template operator()<NEIGHFLAG,EVFLAG>(TagFixShakePostForce<NEIGHFLAG,EVFLAG>(), i, ev);
466 }
467 
468 /* ----------------------------------------------------------------------
469    count # of degrees-of-freedom removed by SHAKE for atoms in igroup
470 ------------------------------------------------------------------------- */
471 
472 template<class DeviceType>
dof(int igroup)473 int FixShakeKokkos<DeviceType>::dof(int igroup)
474 {
475 
476   d_mask = atomKK->k_mask.view<DeviceType>();
477   d_tag = atomKK->k_tag.view<DeviceType>();
478   nlocal = atom->nlocal;
479 
480   atomKK->sync(execution_space,MASK_MASK|TAG_MASK);
481   k_shake_flag.sync<DeviceType>();
482   k_shake_atom.sync<DeviceType>();
483 
484   // count dof in a cluster if and only if
485   // the central atom is in group and atom i is the central atom
486 
487   int n = 0;
488   {
489     // local variables for lambda capture
490 
491     auto d_shake_flag = this->d_shake_flag;
492     auto d_shake_atom = this->d_shake_atom;
493     auto tag = this->d_tag;
494     auto mask = this->d_mask;
495     auto groupbit = group->bitmask[igroup];
496 
497     Kokkos::parallel_reduce(Kokkos::RangePolicy<DeviceType>(0,nlocal),
498      LAMMPS_LAMBDA(const int& i, int& n) {
499       if (!(mask[i] & groupbit)) return;
500       if (d_shake_flag[i] == 0) return;
501       if (d_shake_atom(i,0) != tag[i]) return;
502       if (d_shake_flag[i] == 1) n += 3;
503       else if (d_shake_flag[i] == 2) n += 1;
504       else if (d_shake_flag[i] == 3) n += 2;
505       else if (d_shake_flag[i] == 4) n += 3;
506     },n);
507   }
508 
509   int nall;
510   MPI_Allreduce(&n,&nall,1,MPI_INT,MPI_SUM,world);
511   return nall;
512 }
513 
514 
515 /* ----------------------------------------------------------------------
516    assumes NVE update, seems to be accurate enough for NVT,NPT,NPH as well
517 ------------------------------------------------------------------------- */
518 
519 template<class DeviceType>
unconstrained_update()520 void FixShakeKokkos<DeviceType>::unconstrained_update()
521 {
522   d_x = atomKK->k_x.view<DeviceType>();
523   d_v = atomKK->k_v.view<DeviceType>();
524   d_f = atomKK->k_f.view<DeviceType>();
525   d_type = atomKK->k_type.view<DeviceType>();
526   d_rmass = atomKK->k_rmass.view<DeviceType>();
527   d_mass = atomKK->k_mass.view<DeviceType>();
528   nlocal = atom->nlocal;
529 
530   if (d_rmass.data())
531     atomKK->sync(execution_space,X_MASK|V_MASK|F_MASK|RMASS_MASK);
532   else
533     atomKK->sync(execution_space,X_MASK|V_MASK|F_MASK|TYPE_MASK);
534 
535 
536   k_shake_flag.sync<DeviceType>();
537   k_xshake.sync<DeviceType>();
538 
539   {
540     // local variables for lambda capture
541 
542     auto d_shake_flag = this->d_shake_flag;
543     auto d_xshake = this->d_xshake;
544     auto x = this->d_x;
545     auto v = this->d_v;
546     auto f = this->d_f;
547     auto dtfsq = this->dtfsq;
548     auto dtv = this->dtv;
549 
550     if (d_rmass.data()) {
551 
552       auto rmass = this->d_rmass;
553 
554       Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType>(0,nlocal),
555        LAMMPS_LAMBDA(const int& i) {
556         if (d_shake_flag[i]) {
557           const double dtfmsq = dtfsq / rmass[i];
558           d_xshake(i,0) = x(i,0) + dtv*v(i,0) + dtfmsq*f(i,0);
559           d_xshake(i,1) = x(i,1) + dtv*v(i,1) + dtfmsq*f(i,1);
560           d_xshake(i,2) = x(i,2) + dtv*v(i,2) + dtfmsq*f(i,2);
561         } else d_xshake(i,2) = d_xshake(i,1) = d_xshake(i,0) = 0.0;
562       });
563     } else {
564 
565       auto mass = this->d_mass;
566       auto type = this->d_type;
567 
568       Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType>(0,nlocal),
569        LAMMPS_LAMBDA(const int& i) {
570         if (d_shake_flag[i]) {
571           const double dtfmsq = dtfsq / mass[type[i]];
572           d_xshake(i,0) = x(i,0) + dtv*v(i,0) + dtfmsq*f(i,0);
573           d_xshake(i,1) = x(i,1) + dtv*v(i,1) + dtfmsq*f(i,1);
574           d_xshake(i,2) = x(i,2) + dtv*v(i,2) + dtfmsq*f(i,2);
575         } else d_xshake(i,2) = d_xshake(i,1) = d_xshake(i,0) = 0.0;
576       });
577     }
578   }
579 
580   k_xshake.modify<DeviceType>();
581 }
582 
583 /* ---------------------------------------------------------------------- */
584 
585 template<class DeviceType>
586 template<int NEIGHFLAG, int EVFLAG>
587 KOKKOS_INLINE_FUNCTION
shake(int m,EV_FLOAT & ev) const588 void FixShakeKokkos<DeviceType>::shake(int m, EV_FLOAT& ev) const
589 {
590 
591   // The f array is duplicated for OpenMP, atomic for CUDA, and neither for Serial
592 
593   auto v_f = ScatterViewHelper<typename NeedDup<NEIGHFLAG,DeviceType>::value,decltype(dup_f),decltype(ndup_f)>::get(dup_f,ndup_f);
594   auto a_f = v_f.template access<typename AtomicDup<NEIGHFLAG,DeviceType>::value>();
595 
596   int nlist,list[2];
597   double v[6];
598   double invmass0,invmass1;
599 
600   // local atom IDs and constraint distances
601 
602   int i0 = AtomKokkos::map_kokkos<DeviceType>(d_shake_atom(m,0),map_style,k_map_array,k_map_hash);
603   int i1 = AtomKokkos::map_kokkos<DeviceType>(d_shake_atom(m,1),map_style,k_map_array,k_map_hash);
604   double bond1 = d_bond_distance[d_shake_type(m,0)];
605 
606   // r01 = distance vec between atoms, with PBC
607 
608   double r01[3];
609   r01[0] = d_x(i0,0) - d_x(i1,0);
610   r01[1] = d_x(i0,1) - d_x(i1,1);
611   r01[2] = d_x(i0,2) - d_x(i1,2);
612   minimum_image(r01);
613 
614   // s01 = distance vec after unconstrained update, with PBC
615   // use Domain::minimum_image_once(), not minimum_image()
616   // b/c xshake values might be huge, due to e.g. fix gcmc
617 
618   double s01[3];
619   s01[0] = d_xshake(i0,0) - d_xshake(i1,0);
620   s01[1] = d_xshake(i0,1) - d_xshake(i1,1);
621   s01[2] = d_xshake(i0,2) - d_xshake(i1,2);
622   minimum_image_once(s01);
623 
624   // scalar distances between atoms
625 
626   double r01sq = r01[0]*r01[0] + r01[1]*r01[1] + r01[2]*r01[2];
627   double s01sq = s01[0]*s01[0] + s01[1]*s01[1] + s01[2]*s01[2];
628 
629   // a,b,c = coeffs in quadratic equation for lamda
630 
631   if (d_rmass.data()) {
632     invmass0 = 1.0/d_rmass[i0];
633     invmass1 = 1.0/d_rmass[i1];
634   } else {
635     invmass0 = 1.0/d_mass[d_type[i0]];
636     invmass1 = 1.0/d_mass[d_type[i1]];
637   }
638 
639   double a = (invmass0+invmass1)*(invmass0+invmass1) * r01sq;
640   double b = 2.0 * (invmass0+invmass1) *
641     (s01[0]*r01[0] + s01[1]*r01[1] + s01[2]*r01[2]);
642   double c = s01sq - bond1*bond1;
643 
644   // error check
645 
646   double determ = b*b - 4.0*a*c;
647   if (determ < 0.0) {
648     //error->warning(FLERR,"Shake determinant < 0.0",0);
649     d_error_flag() = 2;
650     determ = 0.0;
651   }
652 
653   // exact quadratic solution for lamda
654 
655   double lamda,lamda1,lamda2;
656   lamda1 = (-b+sqrt(determ)) / (2.0*a);
657   lamda2 = (-b-sqrt(determ)) / (2.0*a);
658 
659   if (fabs(lamda1) <= fabs(lamda2)) lamda = lamda1;
660   else lamda = lamda2;
661 
662   // update forces if atom is owned by this processor
663 
664   lamda /= dtfsq;
665 
666   if (i0 < nlocal) {
667     a_f(i0,0) += lamda*r01[0];
668     a_f(i0,1) += lamda*r01[1];
669     a_f(i0,2) += lamda*r01[2];
670   }
671 
672   if (i1 < nlocal) {
673     a_f(i1,0) -= lamda*r01[0];
674     a_f(i1,1) -= lamda*r01[1];
675     a_f(i1,2) -= lamda*r01[2];
676   }
677 
678   if (EVFLAG) {
679     nlist = 0;
680     if (i0 < nlocal) list[nlist++] = i0;
681     if (i1 < nlocal) list[nlist++] = i1;
682 
683     v[0] = lamda*r01[0]*r01[0];
684     v[1] = lamda*r01[1]*r01[1];
685     v[2] = lamda*r01[2]*r01[2];
686     v[3] = lamda*r01[0]*r01[1];
687     v[4] = lamda*r01[0]*r01[2];
688     v[5] = lamda*r01[1]*r01[2];
689 
690     v_tally<NEIGHFLAG>(ev,nlist,list,2.0,v);
691   }
692 }
693 
694 /* ---------------------------------------------------------------------- */
695 
696 template<class DeviceType>
697 template<int NEIGHFLAG, int EVFLAG>
698 KOKKOS_INLINE_FUNCTION
shake3(int m,EV_FLOAT & ev) const699 void FixShakeKokkos<DeviceType>::shake3(int m, EV_FLOAT& ev) const
700 {
701 
702   // The f array is duplicated for OpenMP, atomic for CUDA, and neither for Serial
703 
704   auto v_f = ScatterViewHelper<typename NeedDup<NEIGHFLAG,DeviceType>::value,decltype(dup_f),decltype(ndup_f)>::get(dup_f,ndup_f);
705   auto a_f = v_f.template access<typename AtomicDup<NEIGHFLAG,DeviceType>::value>();
706 
707   int nlist,list[3];
708   double v[6];
709   double invmass0,invmass1,invmass2;
710 
711   // local atom IDs and constraint distances
712 
713   int i0 = AtomKokkos::map_kokkos<DeviceType>(d_shake_atom(m,0),map_style,k_map_array,k_map_hash);
714   int i1 = AtomKokkos::map_kokkos<DeviceType>(d_shake_atom(m,1),map_style,k_map_array,k_map_hash);
715   int i2 = AtomKokkos::map_kokkos<DeviceType>(d_shake_atom(m,2),map_style,k_map_array,k_map_hash);
716   double bond1 = d_bond_distance[d_shake_type(m,0)];
717   double bond2 = d_bond_distance[d_shake_type(m,1)];
718 
719   // r01,r02 = distance vec between atoms, with PBC
720 
721   double r01[3];
722   r01[0] = d_x(i0,0) - d_x(i1,0);
723   r01[1] = d_x(i0,1) - d_x(i1,1);
724   r01[2] = d_x(i0,2) - d_x(i1,2);
725   minimum_image(r01);
726 
727   double r02[3];
728   r02[0] = d_x(i0,0) - d_x(i2,0);
729   r02[1] = d_x(i0,1) - d_x(i2,1);
730   r02[2] = d_x(i0,2) - d_x(i2,2);
731   minimum_image(r02);
732 
733   // s01,s02 = distance vec after unconstrained update, with PBC
734   // use Domain::minimum_image_once(), not minimum_image()
735   // b/c xshake values might be huge, due to e.g. fix gcmc
736 
737   double s01[3];
738   s01[0] = d_xshake(i0,0) - d_xshake(i1,0);
739   s01[1] = d_xshake(i0,1) - d_xshake(i1,1);
740   s01[2] = d_xshake(i0,2) - d_xshake(i1,2);
741   minimum_image_once(s01);
742 
743   double s02[3];
744   s02[0] = d_xshake(i0,0) - d_xshake(i2,0);
745   s02[1] = d_xshake(i0,1) - d_xshake(i2,1);
746   s02[2] = d_xshake(i0,2) - d_xshake(i2,2);
747   minimum_image_once(s02);
748 
749   // scalar distances between atoms
750 
751   double r01sq = r01[0]*r01[0] + r01[1]*r01[1] + r01[2]*r01[2];
752   double r02sq = r02[0]*r02[0] + r02[1]*r02[1] + r02[2]*r02[2];
753   double s01sq = s01[0]*s01[0] + s01[1]*s01[1] + s01[2]*s01[2];
754   double s02sq = s02[0]*s02[0] + s02[1]*s02[1] + s02[2]*s02[2];
755 
756   // matrix coeffs and rhs for lamda equations
757 
758   if (d_rmass.data()) {
759     invmass0 = 1.0/d_rmass[i0];
760     invmass1 = 1.0/d_rmass[i1];
761     invmass2 = 1.0/d_rmass[i2];
762   } else {
763     invmass0 = 1.0/d_mass[d_type[i0]];
764     invmass1 = 1.0/d_mass[d_type[i1]];
765     invmass2 = 1.0/d_mass[d_type[i2]];
766   }
767 
768   double a11 = 2.0 * (invmass0+invmass1) *
769     (s01[0]*r01[0] + s01[1]*r01[1] + s01[2]*r01[2]);
770   double a12 = 2.0 * invmass0 *
771     (s01[0]*r02[0] + s01[1]*r02[1] + s01[2]*r02[2]);
772   double a21 = 2.0 * invmass0 *
773     (s02[0]*r01[0] + s02[1]*r01[1] + s02[2]*r01[2]);
774   double a22 = 2.0 * (invmass0+invmass2) *
775     (s02[0]*r02[0] + s02[1]*r02[1] + s02[2]*r02[2]);
776 
777   // inverse of matrix
778 
779   double determ = a11*a22 - a12*a21;
780   if (determ == 0.0) d_error_flag() = 3;
781   //error->one(FLERR,"Shake determinant = 0.0");
782   double determinv = 1.0/determ;
783 
784   double a11inv = a22*determinv;
785   double a12inv = -a12*determinv;
786   double a21inv = -a21*determinv;
787   double a22inv = a11*determinv;
788 
789   // quadratic correction coeffs
790 
791   double r0102 = (r01[0]*r02[0] + r01[1]*r02[1] + r01[2]*r02[2]);
792 
793   double quad1_0101 = (invmass0+invmass1)*(invmass0+invmass1) * r01sq;
794   double quad1_0202 = invmass0*invmass0 * r02sq;
795   double quad1_0102 = 2.0 * (invmass0+invmass1)*invmass0 * r0102;
796 
797   double quad2_0202 = (invmass0+invmass2)*(invmass0+invmass2) * r02sq;
798   double quad2_0101 = invmass0*invmass0 * r01sq;
799   double quad2_0102 = 2.0 * (invmass0+invmass2)*invmass0 * r0102;
800 
801   // iterate until converged
802 
803   double lamda01 = 0.0;
804   double lamda02 = 0.0;
805   int niter = 0;
806   int done = 0;
807 
808   double quad1,quad2,b1,b2,lamda01_new,lamda02_new;
809 
810   while (!done && niter < max_iter) {
811     quad1 = quad1_0101 * lamda01*lamda01 + quad1_0202 * lamda02*lamda02 +
812       quad1_0102 * lamda01*lamda02;
813     quad2 = quad2_0101 * lamda01*lamda01 + quad2_0202 * lamda02*lamda02 +
814       quad2_0102 * lamda01*lamda02;
815 
816     b1 = bond1*bond1 - s01sq - quad1;
817     b2 = bond2*bond2 - s02sq - quad2;
818 
819     lamda01_new = a11inv*b1 + a12inv*b2;
820     lamda02_new = a21inv*b1 + a22inv*b2;
821 
822     done = 1;
823     if (fabs(lamda01_new-lamda01) > tolerance) done = 0;
824     if (fabs(lamda02_new-lamda02) > tolerance) done = 0;
825 
826     lamda01 = lamda01_new;
827     lamda02 = lamda02_new;
828 
829     // stop iterations before we have a floating point overflow
830     // max double is < 1.0e308, so 1e150 is a reasonable cutoff
831 
832     if (fabs(lamda01) > 1e150 || fabs(lamda02) > 1e150) done = 1;
833 
834     niter++;
835   }
836 
837   // update forces if atom is owned by this processor
838 
839   lamda01 = lamda01/dtfsq;
840   lamda02 = lamda02/dtfsq;
841 
842   if (i0 < nlocal) {
843     a_f(i0,0) += lamda01*r01[0] + lamda02*r02[0];
844     a_f(i0,1) += lamda01*r01[1] + lamda02*r02[1];
845     a_f(i0,2) += lamda01*r01[2] + lamda02*r02[2];
846   }
847 
848   if (i1 < nlocal) {
849     a_f(i1,0) -= lamda01*r01[0];
850     a_f(i1,1) -= lamda01*r01[1];
851     a_f(i1,2) -= lamda01*r01[2];
852   }
853 
854   if (i2 < nlocal) {
855     a_f(i2,0) -= lamda02*r02[0];
856     a_f(i2,1) -= lamda02*r02[1];
857     a_f(i2,2) -= lamda02*r02[2];
858   }
859 
860   if (EVFLAG) {
861     nlist = 0;
862     if (i0 < nlocal) list[nlist++] = i0;
863     if (i1 < nlocal) list[nlist++] = i1;
864     if (i2 < nlocal) list[nlist++] = i2;
865 
866     v[0] = lamda01*r01[0]*r01[0] + lamda02*r02[0]*r02[0];
867     v[1] = lamda01*r01[1]*r01[1] + lamda02*r02[1]*r02[1];
868     v[2] = lamda01*r01[2]*r01[2] + lamda02*r02[2]*r02[2];
869     v[3] = lamda01*r01[0]*r01[1] + lamda02*r02[0]*r02[1];
870     v[4] = lamda01*r01[0]*r01[2] + lamda02*r02[0]*r02[2];
871     v[5] = lamda01*r01[1]*r01[2] + lamda02*r02[1]*r02[2];
872 
873     v_tally<NEIGHFLAG>(ev,nlist,list,3.0,v);
874   }
875 }
876 
/* ----------------------------------------------------------------------
   SHAKE for one cluster of 4 atoms with 3 distance constraints
   m = row of d_shake_atom / d_shake_type describing the cluster
   atom 0 is the central atom, constrained distances are 0-1, 0-2, 0-3
   iteratively solves for the multipliers lamda01, lamda02, lamda03 and
     adds the resulting constraint forces to f for atoms with index < nlocal
   NEIGHFLAG selects how the ScatterView of f is accessed and is
     forwarded to v_tally()
   if EVFLAG is set, the virial of the constraint forces is tallied into ev
   a zero determinant is reported by setting d_error_flag() to 3
------------------------------------------------------------------------- */

template<class DeviceType>
template<int NEIGHFLAG, int EVFLAG>
KOKKOS_INLINE_FUNCTION
void FixShakeKokkos<DeviceType>::shake4(int m, EV_FLOAT& ev) const
{

  // The f array is duplicated for OpenMP, atomic for CUDA, and neither for Serial

  auto v_f = ScatterViewHelper<typename NeedDup<NEIGHFLAG,DeviceType>::value,decltype(dup_f),decltype(ndup_f)>::get(dup_f,ndup_f);
  auto a_f = v_f.template access<typename AtomicDup<NEIGHFLAG,DeviceType>::value>();

 int nlist,list[4];
  double v[6];
  double invmass0,invmass1,invmass2,invmass3;

  // local atom IDs and constraint distances
  // the IDs stored in d_shake_atom are translated to local indices by map_kokkos()

  int i0 = AtomKokkos::map_kokkos<DeviceType>(d_shake_atom(m,0),map_style,k_map_array,k_map_hash);
  int i1 = AtomKokkos::map_kokkos<DeviceType>(d_shake_atom(m,1),map_style,k_map_array,k_map_hash);
  int i2 = AtomKokkos::map_kokkos<DeviceType>(d_shake_atom(m,2),map_style,k_map_array,k_map_hash);
  int i3 = AtomKokkos::map_kokkos<DeviceType>(d_shake_atom(m,3),map_style,k_map_array,k_map_hash);
  double bond1 = d_bond_distance[d_shake_type(m,0)];
  double bond2 = d_bond_distance[d_shake_type(m,1)];
  double bond3 = d_bond_distance[d_shake_type(m,2)];

  // r01,r02,r03 = distance vec between atoms, with PBC

  double r01[3];
  r01[0] = d_x(i0,0) - d_x(i1,0);
  r01[1] = d_x(i0,1) - d_x(i1,1);
  r01[2] = d_x(i0,2) - d_x(i1,2);
  minimum_image(r01);

  double r02[3];
  r02[0] = d_x(i0,0) - d_x(i2,0);
  r02[1] = d_x(i0,1) - d_x(i2,1);
  r02[2] = d_x(i0,2) - d_x(i2,2);
  minimum_image(r02);

  double r03[3];
  r03[0] = d_x(i0,0) - d_x(i3,0);
  r03[1] = d_x(i0,1) - d_x(i3,1);
  r03[2] = d_x(i0,2) - d_x(i3,2);
  minimum_image(r03);

  // s01,s02,s03 = distance vec after unconstrained update, with PBC
  // use Domain::minimum_image_once(), not minimum_image()
  // b/c xshake values might be huge, due to e.g. fix gcmc

  double s01[3];
  s01[0] = d_xshake(i0,0) - d_xshake(i1,0);
  s01[1] = d_xshake(i0,1) - d_xshake(i1,1);
  s01[2] = d_xshake(i0,2) - d_xshake(i1,2);
  minimum_image_once(s01);

  double s02[3];
  s02[0] = d_xshake(i0,0) - d_xshake(i2,0);
  s02[1] = d_xshake(i0,1) - d_xshake(i2,1);
  s02[2] = d_xshake(i0,2) - d_xshake(i2,2);
  minimum_image_once(s02);

  double s03[3];
  s03[0] = d_xshake(i0,0) - d_xshake(i3,0);
  s03[1] = d_xshake(i0,1) - d_xshake(i3,1);
  s03[2] = d_xshake(i0,2) - d_xshake(i3,2);
  minimum_image_once(s03);

  // scalar distances between atoms

  double r01sq = r01[0]*r01[0] + r01[1]*r01[1] + r01[2]*r01[2];
  double r02sq = r02[0]*r02[0] + r02[1]*r02[1] + r02[2]*r02[2];
  double r03sq = r03[0]*r03[0] + r03[1]*r03[1] + r03[2]*r03[2];
  double s01sq = s01[0]*s01[0] + s01[1]*s01[1] + s01[2]*s01[2];
  double s02sq = s02[0]*s02[0] + s02[1]*s02[1] + s02[2]*s02[2];
  double s03sq = s03[0]*s03[0] + s03[1]*s03[1] + s03[2]*s03[2];

  // matrix coeffs and rhs for lamda equations
  // per-atom masses (d_rmass) are used when that view is allocated,
  // otherwise the per-type masses (d_mass)

  if (d_rmass.data()) {
    invmass0 = 1.0/d_rmass[i0];
    invmass1 = 1.0/d_rmass[i1];
    invmass2 = 1.0/d_rmass[i2];
    invmass3 = 1.0/d_rmass[i3];
  } else {
    invmass0 = 1.0/d_mass[d_type[i0]];
    invmass1 = 1.0/d_mass[d_type[i1]];
    invmass2 = 1.0/d_mass[d_type[i2]];
    invmass3 = 1.0/d_mass[d_type[i3]];
  }

  // a(k,l) couples constraint k (row, uses s) to multiplier l (column, uses r)
  // all three constraints share atom 0, hence invmass0 in every off-diagonal term

  double a11 = 2.0 * (invmass0+invmass1) *
    (s01[0]*r01[0] + s01[1]*r01[1] + s01[2]*r01[2]);
  double a12 = 2.0 * invmass0 *
    (s01[0]*r02[0] + s01[1]*r02[1] + s01[2]*r02[2]);
  double a13 = 2.0 * invmass0 *
    (s01[0]*r03[0] + s01[1]*r03[1] + s01[2]*r03[2]);
  double a21 = 2.0 * invmass0 *
    (s02[0]*r01[0] + s02[1]*r01[1] + s02[2]*r01[2]);
  double a22 = 2.0 * (invmass0+invmass2) *
    (s02[0]*r02[0] + s02[1]*r02[1] + s02[2]*r02[2]);
  double a23 = 2.0 * invmass0 *
    (s02[0]*r03[0] + s02[1]*r03[1] + s02[2]*r03[2]);
  double a31 = 2.0 * invmass0 *
    (s03[0]*r01[0] + s03[1]*r01[1] + s03[2]*r01[2]);
  double a32 = 2.0 * invmass0 *
    (s03[0]*r02[0] + s03[1]*r02[1] + s03[2]*r02[2]);
  double a33 = 2.0 * (invmass0+invmass3) *
    (s03[0]*r03[0] + s03[1]*r03[1] + s03[2]*r03[2]);

  // inverse of matrix
  // a zero determinant is only flagged here; the division below is still
  // executed, so the multipliers are not finite in that case
  // NOTE(review): presumably the caller inspects d_error_flag afterwards - confirm

  double determ = a11*a22*a33 + a12*a23*a31 + a13*a21*a32 -
    a11*a23*a32 - a12*a21*a33 - a13*a22*a31;
  if (determ == 0.0) d_error_flag() = 3;
  //error->one(FLERR,"Shake determinant = 0.0");
  double determinv = 1.0/determ;

  // explicit 3x3 inverse: each entry is a signed 2x2 minor divided by determ

  double a11inv = determinv * (a22*a33 - a23*a32);
  double a12inv = -determinv * (a12*a33 - a13*a32);
  double a13inv = determinv * (a12*a23 - a13*a22);
  double a21inv = -determinv * (a21*a33 - a23*a31);
  double a22inv = determinv * (a11*a33 - a13*a31);
  double a23inv = -determinv * (a11*a23 - a13*a21);
  double a31inv = determinv * (a21*a32 - a22*a31);
  double a32inv = -determinv * (a11*a32 - a12*a31);
  double a33inv = determinv * (a11*a22 - a12*a21);

  // quadratic correction coeffs
  // quadK_XXYY multiplies lamdaXX*lamdaYY in the equation for constraint K

  double r0102 = (r01[0]*r02[0] + r01[1]*r02[1] + r01[2]*r02[2]);
  double r0103 = (r01[0]*r03[0] + r01[1]*r03[1] + r01[2]*r03[2]);
  double r0203 = (r02[0]*r03[0] + r02[1]*r03[1] + r02[2]*r03[2]);

  double quad1_0101 = (invmass0+invmass1)*(invmass0+invmass1) * r01sq;
  double quad1_0202 = invmass0*invmass0 * r02sq;
  double quad1_0303 = invmass0*invmass0 * r03sq;
  double quad1_0102 = 2.0 * (invmass0+invmass1)*invmass0 * r0102;
  double quad1_0103 = 2.0 * (invmass0+invmass1)*invmass0 * r0103;
  double quad1_0203 = 2.0 * invmass0*invmass0 * r0203;

  double quad2_0101 = invmass0*invmass0 * r01sq;
  double quad2_0202 = (invmass0+invmass2)*(invmass0+invmass2) * r02sq;
  double quad2_0303 = invmass0*invmass0 * r03sq;
  double quad2_0102 = 2.0 * (invmass0+invmass2)*invmass0 * r0102;
  double quad2_0103 = 2.0 * invmass0*invmass0 * r0103;
  double quad2_0203 = 2.0 * (invmass0+invmass2)*invmass0 * r0203;

  double quad3_0101 = invmass0*invmass0 * r01sq;
  double quad3_0202 = invmass0*invmass0 * r02sq;
  double quad3_0303 = (invmass0+invmass3)*(invmass0+invmass3) * r03sq;
  double quad3_0102 = 2.0 * invmass0*invmass0 * r0102;
  double quad3_0103 = 2.0 * (invmass0+invmass3)*invmass0 * r0103;
  double quad3_0203 = 2.0 * (invmass0+invmass3)*invmass0 * r0203;

  // iterate until converged
  // each pass evaluates the quadratic terms with the current multipliers,
  // moves them to the rhs, and applies the precomputed inverse
  // the loop ends when all three multipliers change by no more than tolerance,
  // when max_iter passes have been made, or when the overflow guard trips

  double lamda01 = 0.0;
  double lamda02 = 0.0;
  double lamda03 = 0.0;
  int niter = 0;
  int done = 0;

  double quad1,quad2,quad3,b1,b2,b3,lamda01_new,lamda02_new,lamda03_new;

  while (!done && niter < max_iter) {
    quad1 = quad1_0101 * lamda01*lamda01 +
      quad1_0202 * lamda02*lamda02 +
      quad1_0303 * lamda03*lamda03 +
      quad1_0102 * lamda01*lamda02 +
      quad1_0103 * lamda01*lamda03 +
      quad1_0203 * lamda02*lamda03;

    quad2 = quad2_0101 * lamda01*lamda01 +
      quad2_0202 * lamda02*lamda02 +
      quad2_0303 * lamda03*lamda03 +
      quad2_0102 * lamda01*lamda02 +
      quad2_0103 * lamda01*lamda03 +
      quad2_0203 * lamda02*lamda03;

    quad3 = quad3_0101 * lamda01*lamda01 +
      quad3_0202 * lamda02*lamda02 +
      quad3_0303 * lamda03*lamda03 +
      quad3_0102 * lamda01*lamda02 +
      quad3_0103 * lamda01*lamda03 +
      quad3_0203 * lamda02*lamda03;

    // rhs = target distance^2 - unconstrained distance^2 - quadratic term

    b1 = bond1*bond1 - s01sq - quad1;
    b2 = bond2*bond2 - s02sq - quad2;
    b3 = bond3*bond3 - s03sq - quad3;

    lamda01_new = a11inv*b1 + a12inv*b2 + a13inv*b3;
    lamda02_new = a21inv*b1 + a22inv*b2 + a23inv*b3;
    lamda03_new = a31inv*b1 + a32inv*b2 + a33inv*b3;

    done = 1;
    if (fabs(lamda01_new-lamda01) > tolerance) done = 0;
    if (fabs(lamda02_new-lamda02) > tolerance) done = 0;
    if (fabs(lamda03_new-lamda03) > tolerance) done = 0;

    lamda01 = lamda01_new;
    lamda02 = lamda02_new;
    lamda03 = lamda03_new;

    // stop iterations before we have a floating point overflow
    // max double is < 1.0e308, so 1e150 is a reasonable cutoff

    if (fabs(lamda01) > 1e150 || fabs(lamda02) > 1e150
        || fabs(lamda03) > 1e150) done = 1;

    niter++;
  }

  // update forces if atom is owned by this processor
  // the multipliers are divided by dtfsq before being applied as forces

  lamda01 = lamda01/dtfsq;
  lamda02 = lamda02/dtfsq;
  lamda03 = lamda03/dtfsq;

  if (i0 < nlocal) {
    a_f(i0,0) += lamda01*r01[0] + lamda02*r02[0] + lamda03*r03[0];
    a_f(i0,1) += lamda01*r01[1] + lamda02*r02[1] + lamda03*r03[1];
    a_f(i0,2) += lamda01*r01[2] + lamda02*r02[2] + lamda03*r03[2];
  }

  if (i1 < nlocal) {
    a_f(i1,0) -= lamda01*r01[0];
    a_f(i1,1) -= lamda01*r01[1];
    a_f(i1,2) -= lamda01*r01[2];
  }

  if (i2 < nlocal) {
    a_f(i2,0) -= lamda02*r02[0];
    a_f(i2,1) -= lamda02*r02[1];
    a_f(i2,2) -= lamda02*r02[2];
  }

  if (i3 < nlocal) {
    a_f(i3,0) -= lamda03*r03[0];
    a_f(i3,1) -= lamda03*r03[1];
    a_f(i3,2) -= lamda03*r03[2];
  }

  if (EVFLAG) {

    // list = cluster atoms with index < nlocal, nlist = how many of them

    nlist = 0;
    if (i0 < nlocal) list[nlist++] = i0;
    if (i1 < nlocal) list[nlist++] = i1;
    if (i2 < nlocal) list[nlist++] = i2;
    if (i3 < nlocal) list[nlist++] = i3;

    // virial components in the order xx, yy, zz, xy, xz, yz

    v[0] = lamda01*r01[0]*r01[0]+lamda02*r02[0]*r02[0]+lamda03*r03[0]*r03[0];
    v[1] = lamda01*r01[1]*r01[1]+lamda02*r02[1]*r02[1]+lamda03*r03[1]*r03[1];
    v[2] = lamda01*r01[2]*r01[2]+lamda02*r02[2]*r02[2]+lamda03*r03[2]*r03[2];
    v[3] = lamda01*r01[0]*r01[1]+lamda02*r02[0]*r02[1]+lamda03*r03[0]*r03[1];
    v[4] = lamda01*r01[0]*r01[2]+lamda02*r02[0]*r02[2]+lamda03*r03[0]*r03[2];
    v[5] = lamda01*r01[1]*r01[2]+lamda02*r02[1]*r02[2]+lamda03*r03[1]*r03[2];

    // 4.0 = total number of atoms in the cluster

    v_tally<NEIGHFLAG>(ev,nlist,list,4.0,v);
  }
}
1138 
/* ----------------------------------------------------------------------
   SHAKE for one cluster of 3 atoms with 2 bond constraints and 1 angle
     constraint, the angle being held via the 1-2 distance
   m = row of d_shake_atom / d_shake_type describing the cluster
   constrained distances are 0-1 and 0-2 (from d_bond_distance) and
     1-2 (from d_angle_distance)
   iteratively solves for the multipliers lamda01, lamda02, lamda12 and
     adds the resulting constraint forces to f for atoms with index < nlocal
   NEIGHFLAG selects how the ScatterView of f is accessed and is
     forwarded to v_tally()
   if EVFLAG is set, the virial of the constraint forces is tallied into ev
   a zero determinant is reported by setting d_error_flag() to 3
------------------------------------------------------------------------- */

template<class DeviceType>
template<int NEIGHFLAG, int EVFLAG>
KOKKOS_INLINE_FUNCTION
void FixShakeKokkos<DeviceType>::shake3angle(int m, EV_FLOAT& ev) const
{

  // The f array is duplicated for OpenMP, atomic for CUDA, and neither for Serial

  auto v_f = ScatterViewHelper<typename NeedDup<NEIGHFLAG,DeviceType>::value,decltype(dup_f),decltype(ndup_f)>::get(dup_f,ndup_f);
  auto a_f = v_f.template access<typename AtomicDup<NEIGHFLAG,DeviceType>::value>();

  int nlist,list[3];
  double v[6];
  double invmass0,invmass1,invmass2;

  // local atom IDs and constraint distances
  // the IDs stored in d_shake_atom are translated to local indices by map_kokkos()

  int i0 = AtomKokkos::map_kokkos<DeviceType>(d_shake_atom(m,0),map_style,k_map_array,k_map_hash);
  int i1 = AtomKokkos::map_kokkos<DeviceType>(d_shake_atom(m,1),map_style,k_map_array,k_map_hash);
  int i2 = AtomKokkos::map_kokkos<DeviceType>(d_shake_atom(m,2),map_style,k_map_array,k_map_hash);
  double bond1 = d_bond_distance[d_shake_type(m,0)];
  double bond2 = d_bond_distance[d_shake_type(m,1)];
  double bond12 = d_angle_distance[d_shake_type(m,2)];

  // r01,r02,r12 = distance vec between atoms, with PBC

  double r01[3];
  r01[0] = d_x(i0,0) - d_x(i1,0);
  r01[1] = d_x(i0,1) - d_x(i1,1);
  r01[2] = d_x(i0,2) - d_x(i1,2);
  minimum_image(r01);

  double r02[3];
  r02[0] = d_x(i0,0) - d_x(i2,0);
  r02[1] = d_x(i0,1) - d_x(i2,1);
  r02[2] = d_x(i0,2) - d_x(i2,2);
  minimum_image(r02);

  double r12[3];
  r12[0] = d_x(i1,0) - d_x(i2,0);
  r12[1] = d_x(i1,1) - d_x(i2,1);
  r12[2] = d_x(i1,2) - d_x(i2,2);
  minimum_image(r12);

  // s01,s02,s12 = distance vec after unconstrained update, with PBC
  // use Domain::minimum_image_once(), not minimum_image()
  // b/c xshake values might be huge, due to e.g. fix gcmc

  double s01[3];
  s01[0] = d_xshake(i0,0) - d_xshake(i1,0);
  s01[1] = d_xshake(i0,1) - d_xshake(i1,1);
  s01[2] = d_xshake(i0,2) - d_xshake(i1,2);
  minimum_image_once(s01);

  double s02[3];
  s02[0] = d_xshake(i0,0) - d_xshake(i2,0);
  s02[1] = d_xshake(i0,1) - d_xshake(i2,1);
  s02[2] = d_xshake(i0,2) - d_xshake(i2,2);
  minimum_image_once(s02);

  double s12[3];
  s12[0] = d_xshake(i1,0) - d_xshake(i2,0);
  s12[1] = d_xshake(i1,1) - d_xshake(i2,1);
  s12[2] = d_xshake(i1,2) - d_xshake(i2,2);
  minimum_image_once(s12);

  // scalar distances between atoms

  double r01sq = r01[0]*r01[0] + r01[1]*r01[1] + r01[2]*r01[2];
  double r02sq = r02[0]*r02[0] + r02[1]*r02[1] + r02[2]*r02[2];
  double r12sq = r12[0]*r12[0] + r12[1]*r12[1] + r12[2]*r12[2];
  double s01sq = s01[0]*s01[0] + s01[1]*s01[1] + s01[2]*s01[2];
  double s02sq = s02[0]*s02[0] + s02[1]*s02[1] + s02[2]*s02[2];
  double s12sq = s12[0]*s12[0] + s12[1]*s12[1] + s12[2]*s12[2];

  // matrix coeffs and rhs for lamda equations
  // per-atom masses (d_rmass) are used when that view is allocated,
  // otherwise the per-type masses (d_mass)

  if (d_rmass.data()) {
    invmass0 = 1.0/d_rmass[i0];
    invmass1 = 1.0/d_rmass[i1];
    invmass2 = 1.0/d_rmass[i2];
  } else {
    invmass0 = 1.0/d_mass[d_type[i0]];
    invmass1 = 1.0/d_mass[d_type[i1]];
    invmass2 = 1.0/d_mass[d_type[i2]];
  }

  // a(k,l) couples constraint k (row, uses s) to multiplier l (column, uses r)
  // each off-diagonal term carries the inverse mass of the atom the two
  // constraints have in common; terms coupling 0-1 with 1-2 are negative

  double a11 = 2.0 * (invmass0+invmass1) *
    (s01[0]*r01[0] + s01[1]*r01[1] + s01[2]*r01[2]);
  double a12 = 2.0 * invmass0 *
    (s01[0]*r02[0] + s01[1]*r02[1] + s01[2]*r02[2]);
  double a13 = - 2.0 * invmass1 *
    (s01[0]*r12[0] + s01[1]*r12[1] + s01[2]*r12[2]);
  double a21 = 2.0 * invmass0 *
    (s02[0]*r01[0] + s02[1]*r01[1] + s02[2]*r01[2]);
  double a22 = 2.0 * (invmass0+invmass2) *
    (s02[0]*r02[0] + s02[1]*r02[1] + s02[2]*r02[2]);
  double a23 = 2.0 * invmass2 *
    (s02[0]*r12[0] + s02[1]*r12[1] + s02[2]*r12[2]);
  double a31 = - 2.0 * invmass1 *
    (s12[0]*r01[0] + s12[1]*r01[1] + s12[2]*r01[2]);
  double a32 = 2.0 * invmass2 *
    (s12[0]*r02[0] + s12[1]*r02[1] + s12[2]*r02[2]);
  double a33 = 2.0 * (invmass1+invmass2) *
    (s12[0]*r12[0] + s12[1]*r12[1] + s12[2]*r12[2]);

  // inverse of matrix
  // a zero determinant is only flagged here; the division below is still
  // executed, so the multipliers are not finite in that case
  // NOTE(review): presumably the caller inspects d_error_flag afterwards - confirm

  double determ = a11*a22*a33 + a12*a23*a31 + a13*a21*a32 -
    a11*a23*a32 - a12*a21*a33 - a13*a22*a31;
  if (determ == 0.0) d_error_flag() = 3;
  //error->one(FLERR,"Shake determinant = 0.0");
  double determinv = 1.0/determ;

  // explicit 3x3 inverse: each entry is a signed 2x2 minor divided by determ

  double a11inv = determinv * (a22*a33 - a23*a32);
  double a12inv = -determinv * (a12*a33 - a13*a32);
  double a13inv = determinv * (a12*a23 - a13*a22);
  double a21inv = -determinv * (a21*a33 - a23*a31);
  double a22inv = determinv * (a11*a33 - a13*a31);
  double a23inv = -determinv * (a11*a23 - a13*a21);
  double a31inv = determinv * (a21*a32 - a22*a31);
  double a32inv = -determinv * (a11*a32 - a12*a31);
  double a33inv = determinv * (a11*a22 - a12*a21);

  // quadratic correction coeffs
  // quadK_XXYY multiplies lamdaXX*lamdaYY in the equation for constraint K

  double r0102 = (r01[0]*r02[0] + r01[1]*r02[1] + r01[2]*r02[2]);
  double r0112 = (r01[0]*r12[0] + r01[1]*r12[1] + r01[2]*r12[2]);
  double r0212 = (r02[0]*r12[0] + r02[1]*r12[1] + r02[2]*r12[2]);

  double quad1_0101 = (invmass0+invmass1)*(invmass0+invmass1) * r01sq;
  double quad1_0202 = invmass0*invmass0 * r02sq;
  double quad1_1212 = invmass1*invmass1 * r12sq;
  double quad1_0102 = 2.0 * (invmass0+invmass1)*invmass0 * r0102;
  double quad1_0112 = - 2.0 * (invmass0+invmass1)*invmass1 * r0112;
  double quad1_0212 = - 2.0 * invmass0*invmass1 * r0212;

  double quad2_0101 = invmass0*invmass0 * r01sq;
  double quad2_0202 = (invmass0+invmass2)*(invmass0+invmass2) * r02sq;
  double quad2_1212 = invmass2*invmass2 * r12sq;
  double quad2_0102 = 2.0 * (invmass0+invmass2)*invmass0 * r0102;
  double quad2_0112 = 2.0 * invmass0*invmass2 * r0112;
  double quad2_0212 = 2.0 * (invmass0+invmass2)*invmass2 * r0212;

  double quad3_0101 = invmass1*invmass1 * r01sq;
  double quad3_0202 = invmass2*invmass2 * r02sq;
  double quad3_1212 = (invmass1+invmass2)*(invmass1+invmass2) * r12sq;
  double quad3_0102 = - 2.0 * invmass1*invmass2 * r0102;
  double quad3_0112 = - 2.0 * (invmass1+invmass2)*invmass1 * r0112;
  double quad3_0212 = 2.0 * (invmass1+invmass2)*invmass2 * r0212;

  // iterate until converged
  // each pass evaluates the quadratic terms with the current multipliers,
  // moves them to the rhs, and applies the precomputed inverse
  // the loop ends when all three multipliers change by no more than tolerance,
  // when max_iter passes have been made, or when the overflow guard trips

  double lamda01 = 0.0;
  double lamda02 = 0.0;
  double lamda12 = 0.0;
  int niter = 0;
  int done = 0;

  double quad1,quad2,quad3,b1,b2,b3,lamda01_new,lamda02_new,lamda12_new;

  while (!done && niter < max_iter) {

    quad1 = quad1_0101 * lamda01*lamda01 +
      quad1_0202 * lamda02*lamda02 +
      quad1_1212 * lamda12*lamda12 +
      quad1_0102 * lamda01*lamda02 +
      quad1_0112 * lamda01*lamda12 +
      quad1_0212 * lamda02*lamda12;

    quad2 = quad2_0101 * lamda01*lamda01 +
      quad2_0202 * lamda02*lamda02 +
      quad2_1212 * lamda12*lamda12 +
      quad2_0102 * lamda01*lamda02 +
      quad2_0112 * lamda01*lamda12 +
      quad2_0212 * lamda02*lamda12;

    quad3 = quad3_0101 * lamda01*lamda01 +
      quad3_0202 * lamda02*lamda02 +
      quad3_1212 * lamda12*lamda12 +
      quad3_0102 * lamda01*lamda02 +
      quad3_0112 * lamda01*lamda12 +
      quad3_0212 * lamda02*lamda12;

    // rhs = target distance^2 - unconstrained distance^2 - quadratic term

    b1 = bond1*bond1 - s01sq - quad1;
    b2 = bond2*bond2 - s02sq - quad2;
    b3 = bond12*bond12 - s12sq - quad3;

    lamda01_new = a11inv*b1 + a12inv*b2 + a13inv*b3;
    lamda02_new = a21inv*b1 + a22inv*b2 + a23inv*b3;
    lamda12_new = a31inv*b1 + a32inv*b2 + a33inv*b3;

    done = 1;
    if (fabs(lamda01_new-lamda01) > tolerance) done = 0;
    if (fabs(lamda02_new-lamda02) > tolerance) done = 0;
    if (fabs(lamda12_new-lamda12) > tolerance) done = 0;

    lamda01 = lamda01_new;
    lamda02 = lamda02_new;
    lamda12 = lamda12_new;

    // stop iterations before we have a floating point overflow
    // max double is < 1.0e308, so 1e150 is a reasonable cutoff

    if (fabs(lamda01) > 1e150 || fabs(lamda02) > 1e150
        || fabs(lamda12) > 1e150) done = 1;

    niter++;
  }

  // update forces if atom is owned by this processor
  // the multipliers are divided by dtfsq before being applied as forces

  lamda01 = lamda01/dtfsq;
  lamda02 = lamda02/dtfsq;
  lamda12 = lamda12/dtfsq;

  if (i0 < nlocal) {
    a_f(i0,0) += lamda01*r01[0] + lamda02*r02[0];
    a_f(i0,1) += lamda01*r01[1] + lamda02*r02[1];
    a_f(i0,2) += lamda01*r01[2] + lamda02*r02[2];
  }

  // atom 1 is the second atom of 0-1 and the first atom of 1-2,
  // so the two contributions enter with opposite sign

  if (i1 < nlocal) {
    a_f(i1,0) -= lamda01*r01[0] - lamda12*r12[0];
    a_f(i1,1) -= lamda01*r01[1] - lamda12*r12[1];
    a_f(i1,2) -= lamda01*r01[2] - lamda12*r12[2];
  }

  // atom 2 is the second atom of both 0-2 and 1-2

  if (i2 < nlocal) {
    a_f(i2,0) -= lamda02*r02[0] + lamda12*r12[0];
    a_f(i2,1) -= lamda02*r02[1] + lamda12*r12[1];
    a_f(i2,2) -= lamda02*r02[2] + lamda12*r12[2];
  }

  if (EVFLAG) {

    // list = cluster atoms with index < nlocal, nlist = how many of them

    nlist = 0;
    if (i0 < nlocal) list[nlist++] = i0;
    if (i1 < nlocal) list[nlist++] = i1;
    if (i2 < nlocal) list[nlist++] = i2;

    // virial components in the order xx, yy, zz, xy, xz, yz

    v[0] = lamda01*r01[0]*r01[0]+lamda02*r02[0]*r02[0]+lamda12*r12[0]*r12[0];
    v[1] = lamda01*r01[1]*r01[1]+lamda02*r02[1]*r02[1]+lamda12*r12[1]*r12[1];
    v[2] = lamda01*r01[2]*r01[2]+lamda02*r02[2]*r02[2]+lamda12*r12[2]*r12[2];
    v[3] = lamda01*r01[0]*r01[1]+lamda02*r02[0]*r02[1]+lamda12*r12[0]*r12[1];
    v[4] = lamda01*r01[0]*r01[2]+lamda02*r02[0]*r02[2]+lamda12*r12[0]*r12[2];
    v[5] = lamda01*r01[1]*r01[2]+lamda02*r02[1]*r02[2]+lamda12*r12[1]*r12[2];

    // 3.0 = total number of atoms in the cluster

    v_tally<NEIGHFLAG>(ev,nlist,list,3.0,v);
  }
}
1391 
1392 /* ----------------------------------------------------------------------
1393    allocate local atom-based arrays
1394 ------------------------------------------------------------------------- */
1395 
1396 template<class DeviceType>
grow_arrays(int nmax)1397 void FixShakeKokkos<DeviceType>::grow_arrays(int nmax)
1398 {
1399   memoryKK->grow_kokkos(k_shake_flag,shake_flag,nmax,"shake:shake_flag");
1400   memoryKK->grow_kokkos(k_shake_atom,shake_atom,nmax,4,"shake:shake_atom");
1401   memoryKK->grow_kokkos(k_shake_type,shake_type,nmax,3,"shake:shake_type");
1402   memoryKK->destroy_kokkos(k_xshake,xshake);
1403   memoryKK->create_kokkos(k_xshake,xshake,nmax,"shake:xshake");
1404 
1405   d_shake_flag = k_shake_flag.view<DeviceType>();
1406   d_shake_atom = k_shake_atom.view<DeviceType>();
1407   d_shake_type = k_shake_type.view<DeviceType>();
1408   d_xshake = k_xshake.view<DeviceType>();
1409 
1410   memory->destroy(ftmp);
1411   memory->create(ftmp,nmax,3,"shake:ftmp");
1412   memory->destroy(vtmp);
1413   memory->create(vtmp,nmax,3,"shake:vtmp");
1414 }
1415 
1416 /* ----------------------------------------------------------------------
1417    copy values within local atom-based arrays
1418 ------------------------------------------------------------------------- */
1419 
1420 template<class DeviceType>
copy_arrays(int i,int j,int delflag)1421 void FixShakeKokkos<DeviceType>::copy_arrays(int i, int j, int delflag)
1422 {
1423   k_shake_flag.sync_host();
1424   k_shake_atom.sync_host();
1425   k_shake_type.sync_host();
1426 
1427   FixShake::copy_arrays(i,j,delflag);
1428 
1429   k_shake_flag.modify_host();
1430   k_shake_atom.modify_host();
1431   k_shake_type.modify_host();
1432 }
1433 
1434 /* ----------------------------------------------------------------------
1435    initialize one atom's array values, called when atom is created
1436 ------------------------------------------------------------------------- */
1437 
1438 template<class DeviceType>
set_arrays(int i)1439 void FixShakeKokkos<DeviceType>::set_arrays(int i)
1440 {
1441   k_shake_flag.sync_host();
1442 
1443   shake_flag[i] = 0;
1444 
1445   k_shake_flag.modify_host();
1446 }
1447 
1448 /* ----------------------------------------------------------------------
1449    update one atom's array values
1450    called when molecule is created from fix gcmc
1451 ------------------------------------------------------------------------- */
1452 
1453 template<class DeviceType>
update_arrays(int i,int atom_offset)1454 void FixShakeKokkos<DeviceType>::update_arrays(int i, int atom_offset)
1455 {
1456   k_shake_flag.sync_host();
1457   k_shake_atom.sync_host();
1458 
1459   FixShake::update_arrays(i,atom_offset);
1460 
1461   k_shake_flag.modify_host();
1462   k_shake_atom.modify_host();
1463 }
1464 
1465 /* ----------------------------------------------------------------------
1466    initialize a molecule inserted by another fix, e.g. deposit or pour
1467    called when molecule is created
1468    nlocalprev = # of atoms on this proc before molecule inserted
1469    tagprev = atom ID previous to new atoms in the molecule
1470    xgeom,vcm,quat ignored
1471 ------------------------------------------------------------------------- */
1472 
1473 template<class DeviceType>
set_molecule(int nlocalprev,tagint tagprev,int imol,double * xgeom,double * vcm,double * quat)1474 void FixShakeKokkos<DeviceType>::set_molecule(int nlocalprev, tagint tagprev, int imol,
1475                             double * xgeom, double * vcm, double * quat)
1476 {
1477   atomKK->sync(Host,TAG_MASK);
1478   k_shake_flag.sync_host();
1479   k_shake_atom.sync_host();
1480   k_shake_type.sync_host();
1481 
1482   FixShake::set_molecule(nlocalprev,tagprev,imol,xgeom,vcm,quat);
1483 
1484   k_shake_atom.modify_host();
1485   k_shake_type.modify_host();
1486 }
1487 
1488 /* ----------------------------------------------------------------------
1489    pack values in local atom-based arrays for exchange with another proc
1490 ------------------------------------------------------------------------- */
1491 
1492 template<class DeviceType>
pack_exchange(int i,double * buf)1493 int FixShakeKokkos<DeviceType>::pack_exchange(int i, double *buf)
1494 {
1495   k_shake_flag.sync_host();
1496   k_shake_atom.sync_host();
1497   k_shake_type.sync_host();
1498 
1499   int m = FixShake::pack_exchange(i,buf);
1500 
1501   k_shake_flag.modify_host();
1502   k_shake_atom.modify_host();
1503   k_shake_type.modify_host();
1504 
1505   return m;
1506 }
1507 
1508 /* ----------------------------------------------------------------------
1509    unpack values in local atom-based arrays from exchange with another proc
1510 ------------------------------------------------------------------------- */
1511 
1512 template<class DeviceType>
unpack_exchange(int nlocal,double * buf)1513 int FixShakeKokkos<DeviceType>::unpack_exchange(int nlocal, double *buf)
1514 {
1515   k_shake_flag.sync_host();
1516   k_shake_atom.sync_host();
1517   k_shake_type.sync_host();
1518 
1519   int m = FixShake::unpack_exchange(nlocal,buf);
1520 
1521   k_shake_flag.modify_host();
1522   k_shake_atom.modify_host();
1523   k_shake_type.modify_host();
1524 
1525   return m;
1526 }
1527 
1528 /* ---------------------------------------------------------------------- */
1529 
1530 template<class DeviceType>
pack_forward_comm_fix_kokkos(int n,DAT::tdual_int_2d k_sendlist,int iswap_in,DAT::tdual_xfloat_1d & k_buf,int pbc_flag,int * pbc)1531 int FixShakeKokkos<DeviceType>::pack_forward_comm_fix_kokkos(int n, DAT::tdual_int_2d k_sendlist,
1532                                                         int iswap_in, DAT::tdual_xfloat_1d &k_buf,
1533                                                         int pbc_flag, int* pbc)
1534 {
1535   d_sendlist = k_sendlist.view<DeviceType>();
1536   iswap = iswap_in;
1537   d_buf = k_buf.view<DeviceType>();
1538 
1539   if (domain->triclinic == 0) {
1540     dx = pbc[0]*domain->xprd;
1541     dy = pbc[1]*domain->yprd;
1542     dz = pbc[2]*domain->zprd;
1543   } else {
1544     dx = pbc[0]*domain->xprd + pbc[5]*domain->xy + pbc[4]*domain->xz;
1545     dy = pbc[1]*domain->yprd + pbc[3]*domain->yz;
1546     dz = pbc[2]*domain->zprd;
1547   }
1548 
1549   if (pbc_flag)
1550     Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType, TagFixShakePackForwardComm<1> >(0,n),*this);
1551   else
1552     Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType, TagFixShakePackForwardComm<0> >(0,n),*this);
1553   return n*3;
1554 }
1555 
1556 template<class DeviceType>
1557 template<int PBC_FLAG>
1558 KOKKOS_INLINE_FUNCTION
operator ()(TagFixShakePackForwardComm<PBC_FLAG>,const int & i) const1559 void FixShakeKokkos<DeviceType>::operator()(TagFixShakePackForwardComm<PBC_FLAG>, const int &i) const {
1560   const int j = d_sendlist(iswap, i);
1561 
1562   if (PBC_FLAG == 0) {
1563     d_buf[3*i] = d_xshake(j,0);
1564     d_buf[3*i+1] = d_xshake(j,1);
1565     d_buf[3*i+2] = d_xshake(j,2);
1566   } else {
1567     d_buf[3*i] = d_xshake(j,0) + dx;
1568     d_buf[3*i+1] = d_xshake(j,1) + dy;
1569     d_buf[3*i+2] = d_xshake(j,2) + dz;
1570   }
1571 }
1572 
1573 /* ---------------------------------------------------------------------- */
1574 
1575 template<class DeviceType>
pack_forward_comm(int n,int * list,double * buf,int pbc_flag,int * pbc)1576 int FixShakeKokkos<DeviceType>::pack_forward_comm(int n, int *list, double *buf,
1577                                 int pbc_flag, int *pbc)
1578 {
1579   k_xshake.sync_host();
1580 
1581   int m = FixShake::pack_forward_comm(n,list,buf,pbc_flag,pbc);
1582 
1583   k_xshake.modify_host();
1584 
1585   return m;
1586 }
1587 
1588 /* ---------------------------------------------------------------------- */
1589 
1590 template<class DeviceType>
unpack_forward_comm_fix_kokkos(int n,int first_in,DAT::tdual_xfloat_1d & buf)1591 void FixShakeKokkos<DeviceType>::unpack_forward_comm_fix_kokkos(int n, int first_in, DAT::tdual_xfloat_1d &buf)
1592 {
1593   first = first_in;
1594   d_buf = buf.view<DeviceType>();
1595   Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType, TagFixShakeUnpackForwardComm>(0,n),*this);
1596 }
1597 
1598 template<class DeviceType>
1599 KOKKOS_INLINE_FUNCTION
operator ()(TagFixShakeUnpackForwardComm,const int & i) const1600 void FixShakeKokkos<DeviceType>::operator()(TagFixShakeUnpackForwardComm, const int &i) const {
1601   d_xshake(i + first,0) = d_buf[3*i];
1602   d_xshake(i + first,1) = d_buf[3*i+1];
1603   d_xshake(i + first,2) = d_buf[3*i+2];
1604 }
1605 
1606 /* ---------------------------------------------------------------------- */
1607 
1608 template<class DeviceType>
unpack_forward_comm(int n,int first,double * buf)1609 void FixShakeKokkos<DeviceType>::unpack_forward_comm(int n, int first, double *buf)
1610 {
1611   k_xshake.sync_host();
1612 
1613   FixShake::unpack_forward_comm(n,first,buf);
1614 
1615   k_xshake.modify_host();
1616 }
1617 
1618 /* ----------------------------------------------------------------------
1619    add coordinate constraining forces
1620    this method is called at the end of a timestep
1621 ------------------------------------------------------------------------- */
1622 
1623 template<class DeviceType>
shake_end_of_step(int vflag)1624 void FixShakeKokkos<DeviceType>::shake_end_of_step(int vflag) {
1625   dtv     = update->dt;
1626   dtfsq   = 0.5 * update->dt * update->dt * force->ftm2v;
1627   FixShakeKokkos<DeviceType>::post_force(vflag);
1628   if (!rattle) dtfsq = update->dt * update->dt * force->ftm2v;
1629 }
1630 
1631 /* ----------------------------------------------------------------------
1632    calculate constraining forces based on the current configuration
1633    change coordinates
1634 ------------------------------------------------------------------------- */
1635 
1636 template<class DeviceType>
correct_coordinates(int vflag)1637 void FixShakeKokkos<DeviceType>::correct_coordinates(int vflag) {
1638   atomKK->sync(Host,X_MASK|V_MASK|F_MASK);
1639 
1640   // save current forces and velocities so that you
1641   // initialize them to zero such that FixShake::unconstrained_coordinate_update has no effect
1642 
1643   for (int j=0; j<nlocal; j++) {
1644     for (int k=0; k<3; k++) {
1645 
1646       // store current value of forces and velocities
1647       ftmp[j][k] = f[j][k];
1648       vtmp[j][k] = v[j][k];
1649 
1650       // set f and v to zero for SHAKE
1651 
1652       v[j][k] = 0;
1653       f[j][k] = 0;
1654     }
1655   }
1656 
1657   atomKK->modified(Host,V_MASK|F_MASK);
1658 
1659   // call SHAKE to correct the coordinates which were updated without constraints
1660   // IMPORTANT: use 1 as argument and thereby enforce velocity Verlet
1661 
1662   dtfsq   = 0.5 * update->dt * update->dt * force->ftm2v;
1663   FixShakeKokkos<DeviceType>::post_force(vflag);
1664 
1665   atomKK->sync(Host,X_MASK|F_MASK);
1666 
1667   // integrate coordinates: x' = xnp1 + dt^2/2m_i * f, where f is the constraining force
1668   // NOTE: After this command, the coordinates geometry of the molecules will be correct!
1669 
1670   double dtfmsq;
1671   if (rmass) {
1672     for (int i = 0; i < nlocal; i++) {
1673       dtfmsq = dtfsq/ rmass[i];
1674       x[i][0] = x[i][0] + dtfmsq*f[i][0];
1675       x[i][1] = x[i][1] + dtfmsq*f[i][1];
1676       x[i][2] = x[i][2] + dtfmsq*f[i][2];
1677     }
1678   }
1679   else {
1680     for (int i = 0; i < nlocal; i++) {
1681       dtfmsq = dtfsq / mass[type[i]];
1682       x[i][0] = x[i][0] + dtfmsq*f[i][0];
1683       x[i][1] = x[i][1] + dtfmsq*f[i][1];
1684       x[i][2] = x[i][2] + dtfmsq*f[i][2];
1685     }
1686   }
1687 
1688   // copy forces and velocities back
1689 
1690   for (int j=0; j<nlocal; j++) {
1691     for (int k=0; k<3; k++) {
1692       f[j][k] = ftmp[j][k];
1693       v[j][k] = vtmp[j][k];
1694     }
1695   }
1696 
1697   if (!rattle) dtfsq = update->dt * update->dt * force->ftm2v;
1698 
1699   // communicate changes
1700   // NOTE: for compatibility xshake is temporarily set to x, such that pack/unpack_forward
1701   //       can be used for communicating the coordinates.
1702 
1703   double **xtmp = xshake;
1704   xshake = x;
1705   if (nprocs > 1) {
1706     forward_comm_device = 0;
1707     comm->forward_comm_fix(this);
1708     forward_comm_device = 1;
1709   }
1710   xshake = xtmp;
1711 
1712   atomKK->modified(Host,X_MASK|V_MASK|F_MASK);
1713 }
1714 
1715 /* ----------------------------------------------------------------------
1716    tally virial into global and per-atom accumulators
1717    n = # of local owned atoms involved, with local indices in list
1718    v = total virial for the interaction involving total atoms
1719    increment global virial by n/total fraction
1720    increment per-atom virial of each atom in list by 1/total fraction
1721    this method can be used when fix computes forces in post_force()
1722      e.g. fix shake, fix rigid: compute virial only on owned atoms
1723        whether newton_bond is on or off
1724      other procs will tally left-over fractions for atoms they own
1725 ------------------------------------------------------------------------- */
1726 template<class DeviceType>
1727 template<int NEIGHFLAG>
1728 KOKKOS_INLINE_FUNCTION
v_tally(EV_FLOAT & ev,int n,int * list,double total,double * v) const1729 void FixShakeKokkos<DeviceType>::v_tally(EV_FLOAT &ev, int n, int *list, double total,
1730      double *v) const
1731 {
1732   int m;
1733 
1734   if (vflag_global) {
1735     double fraction = n/total;
1736     ev.v[0] += fraction*v[0];
1737     ev.v[1] += fraction*v[1];
1738     ev.v[2] += fraction*v[2];
1739     ev.v[3] += fraction*v[3];
1740     ev.v[4] += fraction*v[4];
1741     ev.v[5] += fraction*v[5];
1742   }
1743 
1744   if (vflag_atom) {
1745     double fraction = 1.0/total;
1746     for (int i = 0; i < n; i++) {
1747       auto v_vatom = ScatterViewHelper<typename NeedDup<NEIGHFLAG,DeviceType>::value,decltype(dup_vatom),decltype(ndup_vatom)>::get(dup_vatom,ndup_vatom);
1748       auto a_vatom = v_vatom.template access<typename AtomicDup<NEIGHFLAG,DeviceType>::value>();
1749       m = list[i];
1750       a_vatom(m,0) += fraction*v[0];
1751       a_vatom(m,1) += fraction*v[1];
1752       a_vatom(m,2) += fraction*v[2];
1753       a_vatom(m,3) += fraction*v[3];
1754       a_vatom(m,4) += fraction*v[4];
1755       a_vatom(m,5) += fraction*v[5];
1756     }
1757   }
1758 }
1759 
1760 /* ---------------------------------------------------------------------- */
1761 
1762 template<class DeviceType>
update_domain_variables()1763 void FixShakeKokkos<DeviceType>::update_domain_variables()
1764 {
1765   triclinic = domain->triclinic;
1766   xperiodic = domain->xperiodic;
1767   xprd_half = domain->xprd_half;
1768   xprd = domain->xprd;
1769   yperiodic = domain->yperiodic;
1770   yprd_half = domain->yprd_half;
1771   yprd = domain->yprd;
1772   zperiodic = domain->zperiodic;
1773   zprd_half = domain->zprd_half;
1774   zprd = domain->zprd;
1775   xy = domain->xy;
1776   xz = domain->xz;
1777   yz = domain->yz;
1778 }
1779 
1780 /* ----------------------------------------------------------------------
1781    minimum image convention in periodic dimensions
1782    use 1/2 of box size as test
1783    for triclinic, also add/subtract tilt factors in other dims as needed
1784    changed "if" to "while" to enable distance to
1785      far-away ghost atom returned by atom->map() to be wrapped back into box
1786      could be problem for looking up atom IDs when cutoff > boxsize
1787    this should not be used if atom has moved infinitely far outside box
1788      b/c while could iterate forever
1789      e.g. fix shake prediction of new position with highly overlapped atoms
1790      use minimum_image_once() instead
1791    copied from domain.cpp
1792 ------------------------------------------------------------------------- */
1793 
1794 template<class DeviceType>
1795 KOKKOS_INLINE_FUNCTION
minimum_image(double * delta) const1796 void FixShakeKokkos<DeviceType>::minimum_image(double *delta) const
1797 {
1798   if (triclinic == 0) {
1799     if (xperiodic) {
1800       while (fabs(delta[0]) > xprd_half) {
1801         if (delta[0] < 0.0) delta[0] += xprd;
1802         else delta[0] -= xprd;
1803       }
1804     }
1805     if (yperiodic) {
1806       while (fabs(delta[1]) > yprd_half) {
1807         if (delta[1] < 0.0) delta[1] += yprd;
1808         else delta[1] -= yprd;
1809       }
1810     }
1811     if (zperiodic) {
1812       while (fabs(delta[2]) > zprd_half) {
1813         if (delta[2] < 0.0) delta[2] += zprd;
1814         else delta[2] -= zprd;
1815       }
1816     }
1817 
1818   } else {
1819     if (zperiodic) {
1820       while (fabs(delta[2]) > zprd_half) {
1821         if (delta[2] < 0.0) {
1822           delta[2] += zprd;
1823           delta[1] += yz;
1824           delta[0] += xz;
1825         } else {
1826           delta[2] -= zprd;
1827           delta[1] -= yz;
1828           delta[0] -= xz;
1829         }
1830       }
1831     }
1832     if (yperiodic) {
1833       while (fabs(delta[1]) > yprd_half) {
1834         if (delta[1] < 0.0) {
1835           delta[1] += yprd;
1836           delta[0] += xy;
1837         } else {
1838           delta[1] -= yprd;
1839           delta[0] -= xy;
1840         }
1841       }
1842     }
1843     if (xperiodic) {
1844       while (fabs(delta[0]) > xprd_half) {
1845         if (delta[0] < 0.0) delta[0] += xprd;
1846         else delta[0] -= xprd;
1847       }
1848     }
1849   }
1850 }
1851 
1852 /* ----------------------------------------------------------------------
1853    minimum image convention in periodic dimensions
1854    use 1/2 of box size as test
1855    for triclinic, also add/subtract tilt factors in other dims as needed
1856    only shift by one box length in each direction
1857    this should not be used if multiple box shifts are required
1858    copied from domain.cpp
1859 ------------------------------------------------------------------------- */
1860 
1861 template<class DeviceType>
1862 KOKKOS_INLINE_FUNCTION
minimum_image_once(double * delta) const1863 void FixShakeKokkos<DeviceType>::minimum_image_once(double *delta) const
1864 {
1865   if (triclinic == 0) {
1866     if (xperiodic) {
1867       if (fabs(delta[0]) > xprd_half) {
1868         if (delta[0] < 0.0) delta[0] += xprd;
1869         else delta[0] -= xprd;
1870       }
1871     }
1872     if (yperiodic) {
1873       if (fabs(delta[1]) > yprd_half) {
1874         if (delta[1] < 0.0) delta[1] += yprd;
1875         else delta[1] -= yprd;
1876       }
1877     }
1878     if (zperiodic) {
1879       if (fabs(delta[2]) > zprd_half) {
1880         if (delta[2] < 0.0) delta[2] += zprd;
1881         else delta[2] -= zprd;
1882       }
1883     }
1884 
1885   } else {
1886     if (zperiodic) {
1887       if (fabs(delta[2]) > zprd_half) {
1888         if (delta[2] < 0.0) {
1889           delta[2] += zprd;
1890           delta[1] += yz;
1891           delta[0] += xz;
1892         } else {
1893           delta[2] -= zprd;
1894           delta[1] -= yz;
1895           delta[0] -= xz;
1896         }
1897       }
1898     }
1899     if (yperiodic) {
1900       if (fabs(delta[1]) > yprd_half) {
1901         if (delta[1] < 0.0) {
1902           delta[1] += yprd;
1903           delta[0] += xy;
1904         } else {
1905           delta[1] -= yprd;
1906           delta[0] -= xy;
1907         }
1908       }
1909     }
1910     if (xperiodic) {
1911       if (fabs(delta[0]) > xprd_half) {
1912         if (delta[0] < 0.0) delta[0] += xprd;
1913         else delta[0] -= xprd;
1914       }
1915     }
1916   }
1917 }
1918 
1919 /* ---------------------------------------------------------------------- */
1920 
1921 namespace LAMMPS_NS {
1922 template class FixShakeKokkos<LMPDeviceType>;
1923 #ifdef LMP_KOKKOS_GPU
1924 template class FixShakeKokkos<LMPHostType>;
1925 #endif
1926 }
1927 
1928