1 // clang-format off
2 /* ----------------------------------------------------------------------
3    LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
4    https://www.lammps.org/, Sandia National Laboratories
5    Steve Plimpton, sjplimp@sandia.gov
6 
7    Copyright (2003) Sandia Corporation.  Under the terms of Contract
8    DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains
9    certain rights in this software.  This software is distributed under
10    the GNU General Public License.
11 
12    See the README file in the top-level LAMMPS directory.
13 ------------------------------------------------------------------------- */
14 
15 #include "fix_setforce_kokkos.h"
16 
17 #include "atom_kokkos.h"
18 #include "update.h"
19 #include "modify.h"
20 #include "domain.h"
21 #include "region.h"
22 #include "input.h"
23 #include "variable.h"
24 #include "memory_kokkos.h"
25 #include "error.h"
26 #include "atom_masks.h"
27 #include "kokkos_base.h"
28 
29 #include <cstring>
30 
31 using namespace LAMMPS_NS;
32 using namespace FixConst;
33 
// per-component force style codes, compared against xstyle/ystyle/zstyle and varflag below
enum {
  NONE,        // 0: component is left untouched (tested via "if (xstyle)")
  CONSTANT,    // component is set to a fixed value
  EQUAL,       // component is evaluated via Variable::compute_equal()
  ATOM         // component is evaluated per atom via Variable::compute_atom()
};
35 
36 /* ---------------------------------------------------------------------- */
37 
38 template<class DeviceType>
FixSetForceKokkos(LAMMPS * lmp,int narg,char ** arg)39 FixSetForceKokkos<DeviceType>::FixSetForceKokkos(LAMMPS *lmp, int narg, char **arg) :
40   FixSetForce(lmp, narg, arg)
41 {
42   kokkosable = 1;
43   atomKK = (AtomKokkos *) atom;
44   execution_space = ExecutionSpaceFromDevice<DeviceType>::space;
45   datamask_read = EMPTY_MASK;
46   datamask_modify = EMPTY_MASK;
47 
48   memory->destroy(sforce);
49   memoryKK->create_kokkos(k_sforce,sforce,maxatom,3,"setforce:sforce");
50   d_sforce = k_sforce.view<DeviceType>();
51 }
52 
53 /* ---------------------------------------------------------------------- */
54 
55 template<class DeviceType>
~FixSetForceKokkos()56 FixSetForceKokkos<DeviceType>::~FixSetForceKokkos()
57 {
58   if (copymode) return;
59 
60   memoryKK->destroy_kokkos(k_sforce,sforce);
61   sforce = nullptr;
62 }
63 
64 /* ---------------------------------------------------------------------- */
65 
66 template<class DeviceType>
init()67 void FixSetForceKokkos<DeviceType>::init()
68 {
69   FixSetForce::init();
70 
71   if (utils::strmatch(update->integrate_style,"^respa"))
72     error->all(FLERR,"Cannot (yet) use respa with Kokkos");
73 }
74 
75 /* ---------------------------------------------------------------------- */
76 
77 template<class DeviceType>
post_force(int)78 void FixSetForceKokkos<DeviceType>::post_force(int /*vflag*/)
79 {
80   atomKK->sync(execution_space, X_MASK | F_MASK | MASK_MASK);
81 
82   x = atomKK->k_x.view<DeviceType>();
83   f = atomKK->k_f.view<DeviceType>();
84   mask = atomKK->k_mask.view<DeviceType>();
85 
86   int nlocal = atom->nlocal;
87 
88   // update region if necessary
89 
90   region = nullptr;
91   if (iregion >= 0) {
92     region = domain->regions[iregion];
93     region->prematch();
94     DAT::tdual_int_1d k_match = DAT::tdual_int_1d("setforce:k_match",nlocal);
95     KokkosBase* regionKKBase = dynamic_cast<KokkosBase*>(region);
96     regionKKBase->match_all_kokkos(groupbit,k_match);
97     k_match.template sync<DeviceType>();
98     d_match = k_match.template view<DeviceType>();
99   }
100 
101   // reallocate sforce array if necessary
102 
103   if (varflag == ATOM && atom->nmax > maxatom) {
104     maxatom = atom->nmax;
105     memoryKK->destroy_kokkos(k_sforce,sforce);
106     memoryKK->create_kokkos(k_sforce,sforce,maxatom,3,"setforce:sforce");
107     d_sforce = k_sforce.view<DeviceType>();
108   }
109 
110   foriginal[0] = foriginal[1] = foriginal[2] = 0.0;
111   double_3 foriginal_kk;
112   force_flag = 0;
113 
114   if (varflag == CONSTANT) {
115     copymode = 1;
116     Kokkos::parallel_reduce(Kokkos::RangePolicy<DeviceType, TagFixSetForceConstant>(0,nlocal),*this,foriginal_kk);
117     copymode = 0;
118 
119   // variable force, wrap with clear/add
120 
121   } else {
122 
123     atomKK->sync(Host,ALL_MASK); // this can be removed when variable class is ported to Kokkos
124 
125     modify->clearstep_compute();
126 
127     if (xstyle == EQUAL) xvalue = input->variable->compute_equal(xvar);
128     else if (xstyle == ATOM)
129       input->variable->compute_atom(xvar,igroup,&sforce[0][0],3,0);
130     if (ystyle == EQUAL) yvalue = input->variable->compute_equal(yvar);
131     else if (ystyle == ATOM)
132       input->variable->compute_atom(yvar,igroup,&sforce[0][1],3,0);
133     if (zstyle == EQUAL) zvalue = input->variable->compute_equal(zvar);
134     else if (zstyle == ATOM)
135       input->variable->compute_atom(zvar,igroup,&sforce[0][2],3,0);
136 
137     modify->addstep_compute(update->ntimestep + 1);
138 
139     if (varflag == ATOM) {  // this can be removed when variable class is ported to Kokkos
140       k_sforce.modify<LMPHostType>();
141       k_sforce.sync<DeviceType>();
142     }
143 
144     copymode = 1;
145     Kokkos::parallel_reduce(Kokkos::RangePolicy<DeviceType, TagFixSetForceNonConstant>(0,nlocal),*this,foriginal_kk);
146     copymode = 0;
147   }
148 
149   atomKK->modified(execution_space, F_MASK);
150 
151   foriginal[0] = foriginal_kk.d0;
152   foriginal[1] = foriginal_kk.d1;
153   foriginal[2] = foriginal_kk.d2;
154 }
155 
156 template<class DeviceType>
157 KOKKOS_INLINE_FUNCTION
operator ()(TagFixSetForceConstant,const int & i,double_3 & foriginal_kk) const158 void FixSetForceKokkos<DeviceType>::operator()(TagFixSetForceConstant, const int &i, double_3& foriginal_kk) const {
159   if (mask[i] & groupbit) {
160     if (region && !d_match[i]) return;
161     foriginal_kk.d0 += f(i,0);
162     foriginal_kk.d1 += f(i,1);
163     foriginal_kk.d2 += f(i,2);
164     if (xstyle) f(i,0) = xvalue;
165     if (ystyle) f(i,1) = yvalue;
166     if (zstyle) f(i,2) = zvalue;
167   }
168 }
169 
170 template<class DeviceType>
171 KOKKOS_INLINE_FUNCTION
operator ()(TagFixSetForceNonConstant,const int & i,double_3 & foriginal_kk) const172 void FixSetForceKokkos<DeviceType>::operator()(TagFixSetForceNonConstant, const int &i, double_3& foriginal_kk) const {
173   if (mask[i] & groupbit) {
174     if (region && !d_match[i]) return;
175     foriginal_kk.d0 += f(i,0);
176     foriginal_kk.d1 += f(i,1);
177     foriginal_kk.d2 += f(i,2);
178     if (xstyle == ATOM) f(i,0) = d_sforce(i,0);
179     else if (xstyle) f(i,0) = xvalue;
180     if (ystyle == ATOM) f(i,1) = d_sforce(i,1);
181     else if (ystyle) f(i,1) = yvalue;
182     if (zstyle == ATOM) f(i,2) = d_sforce(i,2);
183     else if (zstyle) f(i,2) = zvalue;
184   }
185 }
186 
187 namespace LAMMPS_NS {
188 template class FixSetForceKokkos<LMPDeviceType>;
189 #ifdef LMP_KOKKOS_GPU
190 template class FixSetForceKokkos<LMPHostType>;
191 #endif
192 }
193 
194