1 // clang-format off
2 /* -*- c++ -*- ----------------------------------------------------------
3    LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
4    https://www.lammps.org/, Sandia National Laboratories
5    Steve Plimpton, sjplimp@sandia.gov
6 
7    Copyright (2003) Sandia Corporation.  Under the terms of Contract
8    DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains
9    certain rights in this software.  This software is distributed under
10    the GNU General Public License.
11 
12    See the README file in the top-level LAMMPS directory.
13 ------------------------------------------------------------------------- */
14 
15 /* ----------------------------------------------------------------------
16    Contributing authors: Christian Trott (SNL), Stan Moore (SNL)
17 ------------------------------------------------------------------------- */
18 
19 #include "sna_kokkos.h"
20 #include <cmath>
21 #include <cstring>
22 #include <cstdlib>
23 #include <type_traits>
24 
25 namespace LAMMPS_NS {
26 
// pi in double precision; converted to real_type where it is used
static constexpr double MY_PI = 3.14159265358979323846;
28 
29 template<class DeviceType, typename real_type, int vector_length>
30 inline
SNAKokkos(real_type rfac0_in,int twojmax_in,real_type rmin0_in,int switch_flag_in,int bzero_flag_in,int chem_flag_in,int bnorm_flag_in,int wselfall_flag_in,int nelements_in)31 SNAKokkos<DeviceType, real_type, vector_length>::SNAKokkos(real_type rfac0_in,
32          int twojmax_in, real_type rmin0_in, int switch_flag_in, int bzero_flag_in,
33          int chem_flag_in, int bnorm_flag_in, int wselfall_flag_in, int nelements_in)
34 {
35   LAMMPS_NS::ExecutionSpace execution_space = ExecutionSpaceFromDevice<DeviceType>::space;
36   host_flag = (execution_space == LAMMPS_NS::Host);
37 
38   wself = static_cast<real_type>(1.0);
39 
40   rfac0 = rfac0_in;
41   rmin0 = rmin0_in;
42   switch_flag = switch_flag_in;
43   bzero_flag = bzero_flag_in;
44 
45   chem_flag = chem_flag_in;
46   if (chem_flag)
47     nelements = nelements_in;
48   else
49     nelements = 1;
50   bnorm_flag = bnorm_flag_in;
51   wselfall_flag = wselfall_flag_in;
52 
53   twojmax = twojmax_in;
54 
55   ncoeff = compute_ncoeff();
56 
57   nmax = 0;
58   natom = 0;
59 
60   build_indexlist();
61 
62   int jdimpq = twojmax + 2;
63   rootpqarray = t_sna_2d("SNAKokkos::rootpqarray",jdimpq,jdimpq);
64 
65   cglist = t_sna_1d("SNAKokkos::cglist",idxcg_max);
66 
67   if (bzero_flag) {
68     bzero = Kokkos::View<real_type*, Kokkos::LayoutRight, DeviceType>("sna:bzero",twojmax+1);
69     auto h_bzero = Kokkos::create_mirror_view(bzero);
70 
71     double www = wself*wself*wself;
72     for (int j = 0; j <= twojmax; j++)
73       if (bnorm_flag)
74         h_bzero[j] = www;
75       else
76         h_bzero[j] = www*(j+1);
77     Kokkos::deep_copy(bzero,h_bzero);
78   }
79 }
80 
81 /* ---------------------------------------------------------------------- */
82 
83 template<class DeviceType, typename real_type, int vector_length>
84 KOKKOS_INLINE_FUNCTION
~SNAKokkos()85 SNAKokkos<DeviceType, real_type, vector_length>::~SNAKokkos()
86 {
87 }
88 
89 template<class DeviceType, typename real_type, int vector_length>
90 inline
build_indexlist()91 void SNAKokkos<DeviceType, real_type, vector_length>::build_indexlist()
92 {
93   // index list for cglist
94 
95   int jdim = twojmax + 1;
96   idxcg_block = Kokkos::View<int***, DeviceType>(Kokkos::NoInit("SNAKokkos::idxcg_block"),jdim,jdim,jdim);
97   auto h_idxcg_block = Kokkos::create_mirror_view(idxcg_block);
98 
99   int idxcg_count = 0;
100   for (int j1 = 0; j1 <= twojmax; j1++)
101     for (int j2 = 0; j2 <= j1; j2++)
102       for (int j = j1 - j2; j <= MIN(twojmax, j1 + j2); j += 2) {
103         h_idxcg_block(j1,j2,j) = idxcg_count;
104         for (int m1 = 0; m1 <= j1; m1++)
105           for (int m2 = 0; m2 <= j2; m2++)
106             idxcg_count++;
107       }
108   idxcg_max = idxcg_count;
109   Kokkos::deep_copy(idxcg_block,h_idxcg_block);
110 
111   // index list for uarray
112   // need to include both halves
113 
114   idxu_block = Kokkos::View<int*, DeviceType>(Kokkos::NoInit("SNAKokkos::idxu_block"),jdim);
115   auto h_idxu_block = Kokkos::create_mirror_view(idxu_block);
116 
117   int idxu_count = 0;
118 
119   for (int j = 0; j <= twojmax; j++) {
120     h_idxu_block[j] = idxu_count;
121     for (int mb = 0; mb <= j; mb++)
122       for (int ma = 0; ma <= j; ma++)
123         idxu_count++;
124   }
125   idxu_max = idxu_count;
126   Kokkos::deep_copy(idxu_block,h_idxu_block);
127 
128   // index list for half uarray
129   idxu_half_block = Kokkos::View<int*, DeviceType>(Kokkos::NoInit("SNAKokkos::idxu_half_block"),jdim);
130   auto h_idxu_half_block = Kokkos::create_mirror_view(idxu_half_block);
131 
132   int idxu_half_count = 0;
133   for (int j = 0; j <= twojmax; j++) {
134     h_idxu_half_block[j] = idxu_half_count;
135     for (int mb = 0; 2*mb <= j; mb++)
136       for (int ma = 0; ma <= j; ma++)
137         idxu_half_count++;
138   }
139   idxu_half_max = idxu_half_count;
140   Kokkos::deep_copy(idxu_half_block, h_idxu_half_block);
141 
142   // mapping between full and half indexing, encoding flipping
143   idxu_full_half = Kokkos::View<FullHalfMapper*, DeviceType>(Kokkos::NoInit("SNAKokkos::idxu_full_half"),idxu_max);
144   auto h_idxu_full_half = Kokkos::create_mirror_view(idxu_full_half);
145 
146   idxu_count = 0;
147   for (int j = 0; j <= twojmax; j++) {
148     int jju_half = h_idxu_half_block[j];
149     for (int mb = 0; mb <= j; mb++) {
150       for (int ma = 0; ma <= j; ma++) {
151         FullHalfMapper mapper;
152         if (2*mb <= j) {
153           mapper.idxu_half = jju_half + mb * (j + 1) + ma;
154           mapper.flip_sign = 0;
155         } else {
156           mapper.idxu_half = jju_half + (j + 1 - mb) * (j + 1) - (ma + 1);
157           mapper.flip_sign = (((ma+mb)%2==0)?1:-1);
158         }
159         h_idxu_full_half[idxu_count] = mapper;
160         idxu_count++;
161       }
162     }
163   }
164 
165   Kokkos::deep_copy(idxu_full_half, h_idxu_full_half);
166 
167   // index list for "cache" uarray
168   // this is the GPU scratch memory requirements
169   // applied the CPU structures
170   idxu_cache_block = Kokkos::View<int*, DeviceType>(Kokkos::NoInit("SNAKokkos::idxu_cache_block"),jdim);
171   auto h_idxu_cache_block = Kokkos::create_mirror_view(idxu_cache_block);
172 
173   int idxu_cache_count = 0;
174   for (int j = 0; j <= twojmax; j++) {
175     h_idxu_cache_block[j] = idxu_cache_count;
176     for (int mb = 0; mb < ((j+3)/2); mb++)
177       for (int ma = 0; ma <= j; ma++)
178         idxu_cache_count++;
179   }
180   idxu_cache_max = idxu_cache_count;
181   Kokkos::deep_copy(idxu_cache_block, h_idxu_cache_block);
182 
183   // index list for beta and B
184 
185   int idxb_count = 0;
186   for (int j1 = 0; j1 <= twojmax; j1++)
187     for (int j2 = 0; j2 <= j1; j2++)
188       for (int j = j1 - j2; j <= MIN(twojmax, j1 + j2); j += 2)
189         if (j >= j1) idxb_count++;
190 
191   idxb_max = idxb_count;
192   idxb = Kokkos::View<int*[3], DeviceType>(Kokkos::NoInit("SNAKokkos::idxb"),idxb_max);
193   auto h_idxb = Kokkos::create_mirror_view(idxb);
194 
195   idxb_count = 0;
196   for (int j1 = 0; j1 <= twojmax; j1++)
197     for (int j2 = 0; j2 <= j1; j2++)
198       for (int j = j1 - j2; j <= MIN(twojmax, j1 + j2); j += 2)
199         if (j >= j1) {
200           h_idxb(idxb_count,0) = j1;
201           h_idxb(idxb_count,1) = j2;
202           h_idxb(idxb_count,2) = j;
203           idxb_count++;
204         }
205   Kokkos::deep_copy(idxb,h_idxb);
206 
207   // reverse index list for beta and b
208 
209   idxb_block = Kokkos::View<int***, DeviceType>(Kokkos::NoInit("SNAKokkos::idxb_block"),jdim,jdim,jdim);
210   auto h_idxb_block = Kokkos::create_mirror_view(idxb_block);
211 
212   idxb_count = 0;
213   for (int j1 = 0; j1 <= twojmax; j1++)
214     for (int j2 = 0; j2 <= j1; j2++)
215       for (int j = j1 - j2; j <= MIN(twojmax, j1 + j2); j += 2) {
216         if (j >= j1) {
217           h_idxb_block(j1,j2,j) = idxb_count;
218           idxb_count++;
219         }
220       }
221   Kokkos::deep_copy(idxb_block,h_idxb_block);
222 
223   // index list for zlist
224 
225   int idxz_count = 0;
226 
227   for (int j1 = 0; j1 <= twojmax; j1++)
228     for (int j2 = 0; j2 <= j1; j2++)
229       for (int j = j1 - j2; j <= MIN(twojmax, j1 + j2); j += 2)
230         for (int mb = 0; 2*mb <= j; mb++)
231           for (int ma = 0; ma <= j; ma++)
232             idxz_count++;
233 
234   idxz_max = idxz_count;
235   idxz = Kokkos::View<int*[10], DeviceType>(Kokkos::NoInit("SNAKokkos::idxz"),idxz_max);
236   auto h_idxz = Kokkos::create_mirror_view(idxz);
237 
238   idxz_block = Kokkos::View<int***, DeviceType>(Kokkos::NoInit("SNAKokkos::idxz_block"), jdim,jdim,jdim);
239   auto h_idxz_block = Kokkos::create_mirror_view(idxz_block);
240 
241   idxz_count = 0;
242   for (int j1 = 0; j1 <= twojmax; j1++)
243     for (int j2 = 0; j2 <= j1; j2++)
244       for (int j = j1 - j2; j <= MIN(twojmax, j1 + j2); j += 2) {
245         h_idxz_block(j1,j2,j) = idxz_count;
246 
247         // find right beta(ii,jjb) entry
248         // multiply and divide by j+1 factors
249         // account for multiplicity of 1, 2, or 3
250 
251         for (int mb = 0; 2*mb <= j; mb++)
252           for (int ma = 0; ma <= j; ma++) {
253             h_idxz(idxz_count,0) = j1;
254             h_idxz(idxz_count,1) = j2;
255             h_idxz(idxz_count,2) = j;
256             h_idxz(idxz_count,3) = MAX(0, (2 * ma - j - j2 + j1) / 2);
257             h_idxz(idxz_count,4) = (2 * ma - j - (2 * h_idxz(idxz_count,3) - j1) + j2) / 2;
258             h_idxz(idxz_count,5) = MAX(0, (2 * mb - j - j2 + j1) / 2);
259             h_idxz(idxz_count,6) = (2 * mb - j - (2 * h_idxz(idxz_count,5) - j1) + j2) / 2;
260             h_idxz(idxz_count,7) = MIN(j1, (2 * ma - j + j2 + j1) / 2) - h_idxz(idxz_count,3) + 1;
261             h_idxz(idxz_count,8) = MIN(j1, (2 * mb - j + j2 + j1) / 2) - h_idxz(idxz_count,5) + 1;
262 
263             // apply to z(j1,j2,j,ma,mb) to unique element of y(j)
264             // ylist is "compressed" via symmetry in its
265             // contraction with dulist
266             const int jju_half = h_idxu_half_block[j] + (j+1)*mb + ma;
267             h_idxz(idxz_count,9) = jju_half;
268 
269             idxz_count++;
270           }
271       }
272   Kokkos::deep_copy(idxz,h_idxz);
273   Kokkos::deep_copy(idxz_block,h_idxz_block);
274 
275 }
276 
277 /* ---------------------------------------------------------------------- */
278 
279 template<class DeviceType, typename real_type, int vector_length>
280 inline
init()281 void SNAKokkos<DeviceType, real_type, vector_length>::init()
282 {
283   init_clebsch_gordan();
284   init_rootpqarray();
285 }
286 
287 template<class DeviceType, typename real_type, int vector_length>
288 inline
grow_rij(int newnatom,int newnmax)289 void SNAKokkos<DeviceType, real_type, vector_length>::grow_rij(int newnatom, int newnmax)
290 {
291   if (newnatom <= natom && newnmax <= nmax) return;
292   natom = newnatom;
293   nmax = newnmax;
294 
295   rij = t_sna_3d(Kokkos::NoInit("sna:rij"),natom,nmax,3);
296   wj = t_sna_2d(Kokkos::NoInit("sna:wj"),natom,nmax);
297   rcutij = t_sna_2d(Kokkos::NoInit("sna:rcutij"),natom,nmax);
298   inside = t_sna_2i(Kokkos::NoInit("sna:inside"),natom,nmax);
299   element = t_sna_2i(Kokkos::NoInit("sna:element"),natom,nmax);
300   dedr = t_sna_3d(Kokkos::NoInit("sna:dedr"),natom,nmax,3);
301 
302 #ifdef LMP_KOKKOS_GPU
303   if (!host_flag) {
304     const int natom_div = (natom + vector_length - 1) / vector_length;
305 
306     a_pack = t_sna_3c_ll(Kokkos::NoInit("sna:a_pack"),vector_length,nmax,natom_div);
307     b_pack = t_sna_3c_ll(Kokkos::NoInit("sna:b_pack"),vector_length,nmax,natom_div);
308     da_pack = t_sna_4c_ll(Kokkos::NoInit("sna:da_pack"),vector_length,nmax,natom_div,3);
309     db_pack = t_sna_4c_ll(Kokkos::NoInit("sna:db_pack"),vector_length,nmax,natom_div,3);
310     sfac_pack = t_sna_4d_ll(Kokkos::NoInit("sna:sfac_pack"),vector_length,nmax,natom_div,4);
311     ulisttot = t_sna_3c_ll(Kokkos::NoInit("sna:ulisttot"),1,1,1); // dummy allocation
312     ulisttot_full = t_sna_3c_ll(Kokkos::NoInit("sna:ulisttot"),1,1,1);
313     ulisttot_re_pack = t_sna_4d_ll(Kokkos::NoInit("sna:ulisttot_re_pack"),vector_length,idxu_half_max,nelements,natom_div);
314     ulisttot_im_pack = t_sna_4d_ll(Kokkos::NoInit("sna:ulisttot_im_pack"),vector_length,idxu_half_max,nelements,natom_div);
315     ulisttot_pack = t_sna_4c_ll(Kokkos::NoInit("sna:ulisttot_pack"),vector_length,idxu_max,nelements,natom_div);
316     ulist = t_sna_3c_ll(Kokkos::NoInit("sna:ulist"),1,1,1);
317     zlist = t_sna_3c_ll(Kokkos::NoInit("sna:zlist"),1,1,1);
318     zlist_pack = t_sna_4c_ll(Kokkos::NoInit("sna:zlist_pack"),vector_length,idxz_max,ndoubles,natom_div);
319     blist = t_sna_3d(Kokkos::NoInit("sna:blist"),natom,ntriples,idxb_max);
320     blist_pack = t_sna_4d_ll(Kokkos::NoInit("sna:blist_pack"),vector_length,idxb_max,ntriples,natom_div);
321     ylist = t_sna_3c_ll(Kokkos::NoInit("sna:ylist"),1,1,1);
322     ylist_pack_re = t_sna_4d_ll(Kokkos::NoInit("sna:ylist_pack_re"),vector_length,idxu_half_max,nelements,natom_div);
323     ylist_pack_im = t_sna_4d_ll(Kokkos::NoInit("sna:ylist_pack_im"),vector_length,idxu_half_max,nelements,natom_div);
324     dulist = t_sna_4c3_ll(Kokkos::NoInit("sna:dulist"),1,1,1);
325   } else {
326 #endif
327     a_pack = t_sna_3c_ll(Kokkos::NoInit("sna:a_pack"),1,1,1);
328     b_pack = t_sna_3c_ll(Kokkos::NoInit("sna:b_pack"),1,1,1);
329     da_pack = t_sna_4c_ll(Kokkos::NoInit("sna:da_pack"),1,1,1,1);
330     db_pack = t_sna_4c_ll(Kokkos::NoInit("sna:db_pack"),1,1,1,1);
331     sfac_pack = t_sna_4d_ll(Kokkos::NoInit("sna:sfac_pack"),1,1,1,1);
332     ulisttot = t_sna_3c_ll(Kokkos::NoInit("sna:ulisttot"),idxu_half_max,nelements,natom);
333     ulisttot_full = t_sna_3c_ll(Kokkos::NoInit("sna:ulisttot_full"),idxu_max,nelements,natom);
334     ulisttot_re_pack = t_sna_4d_ll(Kokkos::NoInit("sna:ulisttot_re"),1,1,1,1);
335     ulisttot_im_pack = t_sna_4d_ll(Kokkos::NoInit("sna:ulisttot_im"),1,1,1,1);
336     ulisttot_pack = t_sna_4c_ll(Kokkos::NoInit("sna:ulisttot_pack"),1,1,1,1);
337     ulist = t_sna_3c_ll(Kokkos::NoInit("sna:ulist"),idxu_cache_max,natom,nmax);
338     zlist = t_sna_3c_ll(Kokkos::NoInit("sna:zlist"),idxz_max,ndoubles,natom);
339     zlist_pack = t_sna_4c_ll(Kokkos::NoInit("sna:zlist_pack"),1,1,1,1);
340     blist = t_sna_3d(Kokkos::NoInit("sna:blist"),natom,ntriples,idxb_max);
341     blist_pack = t_sna_4d_ll(Kokkos::NoInit("sna:blist_pack"),1,1,1,1);
342     ylist = t_sna_3c_ll(Kokkos::NoInit("sna:ylist"),idxu_half_max,nelements,natom);
343     ylist_pack_re = t_sna_4d_ll(Kokkos::NoInit("sna:ylist_pack_re"),1,1,1,1);
344     ylist_pack_im = t_sna_4d_ll(Kokkos::NoInit("sna:ylist_pack_im"),1,1,1,1);
345     dulist = t_sna_4c3_ll(Kokkos::NoInit("sna:dulist"),idxu_cache_max,natom,nmax);
346 
347 #ifdef LMP_KOKKOS_GPU
348   }
349 #endif
350 }
351 
352 /* ----------------------------------------------------------------------
353  * GPU routines
354  * ----------------------------------------------------------------------*/
355 
356 
357 /* ----------------------------------------------------------------------
358    Precompute the Cayley-Klein parameters and the derivatives thereof.
359    This routine better exploits parallelism than the GPU ComputeUi and
360    ComputeFusedDeidrj, which are one warp per atom-neighbor pair.
361 ------------------------------------------------------------------------- */
362 
363 template<class DeviceType, typename real_type, int vector_length>
364 KOKKOS_INLINE_FUNCTION
compute_cayley_klein(const int & iatom_mod,const int & jnbor,const int & iatom_div)365 void SNAKokkos<DeviceType, real_type, vector_length>::compute_cayley_klein(const int& iatom_mod, const int& jnbor, const int& iatom_div)
366 {
367   const int iatom = iatom_mod + vector_length * iatom_div;
368   const real_type x = rij(iatom,jnbor,0);
369   const real_type y = rij(iatom,jnbor,1);
370   const real_type z = rij(iatom,jnbor,2);
371   const real_type rsq = x * x + y * y + z * z;
372   const real_type r = sqrt(rsq);
373   const real_type rcut = rcutij(iatom, jnbor);
374   const real_type rscale0 = rfac0 * static_cast<real_type>(MY_PI) / (rcut - rmin0);
375   const real_type theta0 = (r - rmin0) * rscale0;
376   const real_type sn = sin(theta0);
377   const real_type cs = cos(theta0);
378   const real_type z0 = r * cs / sn;
379   const real_type dz0dr = z0 / r - (r*rscale0) * (rsq + z0 * z0) / rsq;
380 
381   const real_type wj_local = wj(iatom, jnbor);
382   real_type sfac, dsfac;
383   compute_s_dsfac(r, rcut, sfac, dsfac);
384   sfac *= wj_local;
385   dsfac *= wj_local;
386 
387   const real_type rinv = static_cast<real_type>(1.0) / r;
388   const real_type ux = x * rinv;
389   const real_type uy = y * rinv;
390   const real_type uz = z * rinv;
391 
392   const real_type r0inv = static_cast<real_type>(1.0) / sqrt(r * r + z0 * z0);
393 
394   const complex a = { z0 * r0inv, -z * r0inv };
395   const complex b = { r0inv * y, -r0inv * x };
396 
397   const real_type dr0invdr = -r0inv * r0inv * r0inv * (r + z0 * dz0dr);
398 
399   const real_type dr0invx = dr0invdr * ux;
400   const real_type dr0invy = dr0invdr * uy;
401   const real_type dr0invz = dr0invdr * uz;
402 
403   const real_type dz0x = dz0dr * ux;
404   const real_type dz0y = dz0dr * uy;
405   const real_type dz0z = dz0dr * uz;
406 
407   const complex dax = { dz0x * r0inv + z0 * dr0invx, -z * dr0invx };
408   const complex day = { dz0y * r0inv + z0 * dr0invy, -z * dr0invy };
409   const complex daz = { dz0z * r0inv + z0 * dr0invz, -z * dr0invz - r0inv };
410 
411   const complex dbx = { y * dr0invx, -x * dr0invx - r0inv };
412   const complex dby = { y * dr0invy + r0inv, -x * dr0invy };
413   const complex dbz = { y * dr0invz, -x * dr0invz };
414 
415   const real_type dsfacux = dsfac * ux;
416   const real_type dsfacuy = dsfac * uy;
417   const real_type dsfacuz = dsfac * uz;
418 
419   a_pack(iatom_mod,jnbor,iatom_div) = a;
420   b_pack(iatom_mod,jnbor,iatom_div) = b;
421 
422   da_pack(iatom_mod,jnbor,iatom_div,0) = dax;
423   db_pack(iatom_mod,jnbor,iatom_div,0) = dbx;
424 
425   da_pack(iatom_mod,jnbor,iatom_div,1) = day;
426   db_pack(iatom_mod,jnbor,iatom_div,1) = dby;
427 
428   da_pack(iatom_mod,jnbor,iatom_div,2) = daz;
429   db_pack(iatom_mod,jnbor,iatom_div,2) = dbz;
430 
431   sfac_pack(iatom_mod,jnbor,iatom_div,0) = sfac;
432   sfac_pack(iatom_mod,jnbor,iatom_div,1) = dsfacux;
433   sfac_pack(iatom_mod,jnbor,iatom_div,2) = dsfacuy;
434   sfac_pack(iatom_mod,jnbor,iatom_div,3) = dsfacuz;
435 
436   // we need to explicitly zero `dedr` somewhere before hitting
437   // ComputeFusedDeidrj --- this is just a convenient place to do it.
438   dedr(iatom_mod + vector_length * iatom_div, jnbor, 0) = static_cast<real_type>(0.);
439   dedr(iatom_mod + vector_length * iatom_div, jnbor, 1) = static_cast<real_type>(0.);
440   dedr(iatom_mod + vector_length * iatom_div, jnbor, 2) = static_cast<real_type>(0.);
441 
442 }
443 
444 /* ----------------------------------------------------------------------
445    Initialize ulisttot with self-energy terms.
446    Ulisttot uses a "half" data layout which takes
447    advantage of the symmetry of the Wigner U matrices.
448 ------------------------------------------------------------------------- */
449 
450 template<class DeviceType, typename real_type, int vector_length>
451 KOKKOS_INLINE_FUNCTION
pre_ui(const int & iatom_mod,const int & j,const int & ielem,const int & iatom_div)452 void SNAKokkos<DeviceType, real_type, vector_length>::pre_ui(const int& iatom_mod, const int& j, const int& ielem, const int& iatom_div)
453 {
454 
455   for (int jelem = 0; jelem < nelements; jelem++) {
456     int jju_half = idxu_half_block(j);
457 
458     // Only diagonal elements get initialized
459     // Top half only: gets symmetrized by TransformUi
460 
461     for (int mb = 0; 2*mb <= j; mb++) {
462       for (int ma = 0; ma <= j; ma++) {
463 
464         real_type re_part = static_cast<real_type>(0.);
465         if (ma == mb && (!chem_flag || ielem == jelem || wselfall_flag)) { re_part = wself; }
466 
467         ulisttot_re_pack(iatom_mod, jju_half, jelem, iatom_div) = re_part;
468         ulisttot_im_pack(iatom_mod, jju_half, jelem, iatom_div) = static_cast<real_type>(0.);
469 
470         jju_half++;
471       }
472     }
473   }
474 
475 }
476 
477 /* ----------------------------------------------------------------------
478    compute Ui by computing Wigner U-functions for one neighbor and
479    accumulating to the total. GPU only.
480 ------------------------------------------------------------------------- */
481 
482 // Version of the code that exposes additional parallelism by threading over `j_bend` values
483 
484 template<class DeviceType, typename real_type, int vector_length>
485 KOKKOS_INLINE_FUNCTION
compute_ui_small(const typename Kokkos::TeamPolicy<DeviceType>::member_type & team,const int iatom_mod,const int j_bend,const int jnbor,const int iatom_div)486 void SNAKokkos<DeviceType, real_type, vector_length>::compute_ui_small(const typename Kokkos::TeamPolicy<DeviceType>::member_type& team, const int iatom_mod, const int j_bend, const int jnbor, const int iatom_div)
487 {
488 
489   // get shared memory offset
490   // scratch size: 32 atoms * (twojmax+1) cached values, no double buffer
491   const int tile_size = vector_length * (twojmax + 1);
492 
493   const int team_rank = team.team_rank();
494   const int scratch_shift = team_rank * tile_size;
495 
496   // extract and wrap
497   const WignerWrapper<real_type, vector_length> ulist_wrapper((complex*)team.team_shmem().get_shmem(team.team_size() * tile_size * sizeof(complex), 0) + scratch_shift, iatom_mod);
498 
499   // load parameters
500   const complex a = a_pack(iatom_mod, jnbor, iatom_div);
501   const complex b = b_pack(iatom_mod, jnbor, iatom_div);
502   const real_type sfac = sfac_pack(iatom_mod, jnbor, iatom_div, 0);
503 
504   const int jelem = element(iatom_mod + vector_length * iatom_div, jnbor);
505 
506   // we need to "choose" when to bend
507   // this for loop is here for context --- we expose additional
508   // parallelism over this loop instead
509   //for (int j_bend = 0; j_bend <= twojmax; j_bend++) {
510   evaluate_ui_jbend(ulist_wrapper, a, b, sfac, jelem, iatom_mod, j_bend, iatom_div);
511 }
512 
513 // Version of the code that loops over all `j_bend` values which reduces integer arithmetic
514 // and some amount of load imbalance, at the expense of reducing parallelism
515 template<class DeviceType, typename real_type, int vector_length>
516 KOKKOS_INLINE_FUNCTION
compute_ui_large(const typename Kokkos::TeamPolicy<DeviceType>::member_type & team,const int iatom_mod,const int jnbor,const int iatom_div)517 void SNAKokkos<DeviceType, real_type, vector_length>::compute_ui_large(const typename Kokkos::TeamPolicy<DeviceType>::member_type& team, const int iatom_mod, const int jnbor, const int iatom_div)
518 {
519   // get shared memory offset
520   // scratch size: 32 atoms * (twojmax+1) cached values, no double buffer
521   const int tile_size = vector_length * (twojmax + 1);
522 
523   const int team_rank = team.team_rank();
524   const int scratch_shift = team_rank * tile_size;
525 
526   // extract and wrap
527   const WignerWrapper<real_type, vector_length> ulist_wrapper((complex*)team.team_shmem().get_shmem(team.team_size() * tile_size * sizeof(complex), 0) + scratch_shift, iatom_mod);
528 
529   // load parameters
530   const complex a = a_pack(iatom_mod, jnbor, iatom_div);
531   const complex b = b_pack(iatom_mod, jnbor, iatom_div);
532   const real_type sfac = sfac_pack(iatom_mod, jnbor, iatom_div, 0);
533 
534   const int jelem = element(iatom_mod + vector_length * iatom_div, jnbor);
535 
536   // we need to "choose" when to bend
537   #ifdef LMP_KK_DEVICE_COMPILE
538   #pragma unroll
539   #endif
540   for (int j_bend = 0; j_bend <= twojmax; j_bend++) {
541     evaluate_ui_jbend(ulist_wrapper, a, b, sfac, jelem, iatom_mod, j_bend, iatom_div);
542   }
543 }
544 
545 // Core "evaluation" kernel that gets reused in `compute_ui_small` and `compute_ui_large`
546 template<class DeviceType, typename real_type, int vector_length>
547 KOKKOS_FORCEINLINE_FUNCTION
evaluate_ui_jbend(const WignerWrapper<real_type,vector_length> & ulist_wrapper,const complex & a,const complex & b,const real_type & sfac,const int & jelem,const int & iatom_mod,const int & j_bend,const int & iatom_div)548 void SNAKokkos<DeviceType, real_type, vector_length>::evaluate_ui_jbend(const WignerWrapper<real_type, vector_length>& ulist_wrapper,
549           const complex& a, const complex& b, const real_type& sfac, const int& jelem,
550           const int& iatom_mod, const int& j_bend, const int& iatom_div)
551 {
552 
553   // utot(j,ma,mb) = 0 for all j,ma,ma
554   // utot(j,ma,ma) = 1 for all j,ma
555   // for j in neighbors of i:
556   //   compute r0 = (x,y,z,z0)
557   //   utot(j,ma,mb) += u(r0;j,ma,mb) for all j,ma,mb
558 
559   // level 0 is just 1.
560   ulist_wrapper.set(0, complex::one());
561 
562   // j from before the bend, don't store, mb == 0
563   for (int j = 1; j <= j_bend; j++) {
564 
565     constexpr int mb = 0; // intentional for readability, compiler should optimize this out
566 
567     complex ulist_accum = complex::zero();
568 
569     int ma;
570     for (ma = 0; ma < j; ma++) {
571 
572       // grab the cached value
573       const complex ulist_prev = ulist_wrapper.get(ma);
574 
575       // ulist_accum += rootpq * a.conj() * ulist_prev;
576       real_type rootpq = rootpqarray(j - ma, j - mb);
577       ulist_accum.re += rootpq * (a.re * ulist_prev.re + a.im * ulist_prev.im);
578       ulist_accum.im += rootpq * (a.re * ulist_prev.im - a.im * ulist_prev.re);
579 
580       // store ulist_accum, we atomic accumulate values after the bend, so no atomic add here
581       ulist_wrapper.set(ma, ulist_accum);
582 
583       // next value
584       // ulist_accum = -rootpq * b.conj() * ulist_prev;
585       rootpq = rootpqarray(ma + 1, j - mb);
586       ulist_accum.re = -rootpq * (b.re * ulist_prev.re + b.im * ulist_prev.im);
587       ulist_accum.im = -rootpq * (b.re * ulist_prev.im - b.im * ulist_prev.re);
588 
589     }
590 
591     ulist_wrapper.set(ma, ulist_accum);
592   }
593 
594   // now we're after the bend, start storing but only up to the "half way point"
595   const int j_half_way = MIN(2 * j_bend, twojmax);
596 
597   int mb = 1;
598   int j; //= j_bend + 1; // need this value below
599   for (j = j_bend + 1; j <= j_half_way; j++) {
600 
601     const int jjup = idxu_half_block[j-1] + (mb - 1) * j;
602 
603     complex ulist_accum = complex::zero();
604 
605     int ma;
606     for (ma = 0; ma < j; ma++) {
607 
608       // grab the cached value
609       const complex ulist_prev = ulist_wrapper.get(ma);
610 
611       // atomic add the previous level here
612       Kokkos::atomic_add(&(ulisttot_re_pack(iatom_mod, jjup + ma, jelem, iatom_div)), ulist_prev.re * sfac);
613       Kokkos::atomic_add(&(ulisttot_im_pack(iatom_mod, jjup + ma, jelem, iatom_div)), ulist_prev.im * sfac);
614 
615       // ulist_accum += rootpq * b * ulist_prev;
616       real_type rootpq = rootpqarray(j - ma, mb);
617       ulist_accum.re += rootpq * (b.re * ulist_prev.re - b.im * ulist_prev.im);
618       ulist_accum.im += rootpq * (b.re * ulist_prev.im + b.im * ulist_prev.re);
619 
620       // store ulist_accum
621       ulist_wrapper.set(ma, ulist_accum);
622 
623       // next value
624       // ulist_accum = rootpq * a * ulist_prev;
625       rootpq = rootpqarray(ma + 1, mb);
626       ulist_accum.re = rootpq * (a.re * ulist_prev.re - a.im * ulist_prev.im);
627       ulist_accum.im = rootpq * (a.re * ulist_prev.im + a.im * ulist_prev.re);
628     }
629 
630     ulist_wrapper.set(ma, ulist_accum);
631 
632     mb++;
633   }
634 
635   // atomic add the last level
636   const int jjup = idxu_half_block[j-1] + (mb - 1) * j;
637 
638   for (int ma = 0; ma < j; ma++) {
639     const complex ulist_prev = ulist_wrapper.get(ma);
640 
641     // atomic add the previous level here
642     Kokkos::atomic_add(&(ulisttot_re_pack(iatom_mod, jjup + ma, jelem, iatom_div)), ulist_prev.re * sfac);
643     Kokkos::atomic_add(&(ulisttot_im_pack(iatom_mod, jjup + ma, jelem, iatom_div)), ulist_prev.im * sfac);
644   }
645 
646 }
647 
648 /* ----------------------------------------------------------------------
649    compute Zi by summing over products of Ui,
650    AoSoA data layout to take advantage of coalescing, avoiding warp
651    divergence. GPU version
652 ------------------------------------------------------------------------- */
653 
654 template<class DeviceType, typename real_type, int vector_length>
655 KOKKOS_INLINE_FUNCTION
compute_zi(const int & iatom_mod,const int & jjz,const int & iatom_div)656 void SNAKokkos<DeviceType, real_type, vector_length>::compute_zi(const int& iatom_mod, const int& jjz, const int& iatom_div)
657 {
658 
659   const int j1 = idxz(jjz, 0);
660   const int j2 = idxz(jjz, 1);
661   const int j = idxz(jjz, 2);
662   const int ma1min = idxz(jjz, 3);
663   const int ma2max = idxz(jjz, 4);
664   const int mb1min = idxz(jjz, 5);
665   const int mb2max = idxz(jjz, 6);
666   const int na = idxz(jjz, 7);
667   const int nb = idxz(jjz, 8);
668 
669   const real_type* cgblock = cglist.data() + idxcg_block(j1, j2, j);
670 
671   int idouble = 0;
672 
673   for (int elem1 = 0; elem1 < nelements; elem1++) {
674     for (int elem2 = 0; elem2 < nelements; elem2++) {
675 
676       zlist_pack(iatom_mod,jjz,idouble,iatom_div) = evaluate_zi(j1, j2, j, ma1min, ma2max, mb1min, mb2max, na, nb, iatom_mod, elem1, elem2, iatom_div, cgblock);
677 
678       idouble++;
679     }
680   }
681 }
682 
683 /* ----------------------------------------------------------------------
684    compute Bi by summing conj(Ui)*Zi
685    AoSoA data layout to take advantage of coalescing, avoiding warp
686    divergence.
687 ------------------------------------------------------------------------- */
688 
689 template<class DeviceType, typename real_type, int vector_length>
690 KOKKOS_INLINE_FUNCTION
compute_bi(const int & iatom_mod,const int & jjb,const int & iatom_div)691 void SNAKokkos<DeviceType, real_type, vector_length>::compute_bi(const int& iatom_mod, const int& jjb, const int& iatom_div)
692 {
693   // for j1 = 0,...,twojmax
694   //   for j2 = 0,twojmax
695   //     for j = |j1-j2|,Min(twojmax,j1+j2),2
696   //        b(j1,j2,j) = 0
697   //        for mb = 0,...,jmid
698   //          for ma = 0,...,j
699   //            b(j1,j2,j) +=
700   //              2*Conj(u(j,ma,mb))*z(j1,j2,j,ma,mb)
701 
702   const int j1 = idxb(jjb,0);
703   const int j2 = idxb(jjb,1);
704   const int j = idxb(jjb,2);
705 
706   const int jjz = idxz_block(j1,j2,j);
707   const int jju = idxu_block[j];
708 
709   int itriple = 0;
710   int idouble = 0;
711   for (int elem1 = 0; elem1 < nelements; elem1++) {
712     for (int elem2 = 0; elem2 < nelements; elem2++) {
713       for (int elem3 = 0; elem3 < nelements; elem3++) {
714 
715         double sumzu = 0.0;
716         double sumzu_temp = 0.0;
717 
718         for (int mb = 0; 2*mb < j; mb++) {
719           for (int ma = 0; ma <= j; ma++) {
720             const int jju_index = jju+mb*(j+1)+ma;
721             const int jjz_index = jjz+mb*(j+1)+ma;
722             if (2*mb == j) return; // I think we can remove this?
723             const complex utot = ulisttot_pack(iatom_mod, jju_index, elem3, iatom_div);
724             const complex zloc = zlist_pack(iatom_mod, jjz_index, idouble, iatom_div);
725             sumzu_temp += utot.re * zloc.re + utot.im * zloc.im;
726           }
727         }
728         sumzu += sumzu_temp;
729 
730         // For j even, special treatment for middle column
731         if (j%2 == 0) {
732           sumzu_temp = 0.;
733 
734           const int mb = j/2;
735           for (int ma = 0; ma < mb; ma++) {
736             const int jju_index = jju+(mb-1)*(j+1)+(j+1)+ma;
737             const int jjz_index = jjz+(mb-1)*(j+1)+(j+1)+ma;
738 
739             const complex utot = ulisttot_pack(iatom_mod, jju_index, elem3, iatom_div);
740             const complex zloc = zlist_pack(iatom_mod, jjz_index, idouble, iatom_div);
741             sumzu_temp += utot.re * zloc.re + utot.im * zloc.im;
742 
743           }
744           sumzu += sumzu_temp;
745 
746           const int ma = mb;
747           const int jju_index = jju+(mb-1)*(j+1)+(j+1)+ma;
748           const int jjz_index = jjz+(mb-1)*(j+1)+(j+1)+ma;
749 
750           const complex utot = ulisttot_pack(iatom_mod, jju_index, elem3, iatom_div);
751           const complex zloc = zlist_pack(iatom_mod, jjz_index, idouble, iatom_div);
752           sumzu += static_cast<real_type>(0.5) * (utot.re * zloc.re + utot.im * zloc.im);
753         } // end if jeven
754 
755         sumzu *= static_cast<real_type>(2.0);
756         if (bzero_flag) {
757           if (!wselfall_flag) {
758             if (elem1 == elem2 && elem1 == elem3) {
759               sumzu -= bzero[j];
760             }
761           } else {
762             sumzu -= bzero[j];
763           }
764         }
765         blist_pack(iatom_mod, jjb, itriple, iatom_div) = sumzu;
766             //} // end loop over j
767           //} // end loop over j1, j2
768         itriple++;
769       } // end loop over elem3
770       idouble++;
771     } // end loop over elem2
772   } // end loop over elem1
773 }
774 
775 
776 /* ----------------------------------------------------------------------
777    compute Yi from Ui without storing Zi, looping over zlist indices.
778    AoSoA data layout to take advantage of coalescing, avoiding warp
779    divergence. GPU version.
780 ------------------------------------------------------------------------- */
781 
782 template<class DeviceType, typename real_type, int vector_length>
783 KOKKOS_INLINE_FUNCTION
compute_yi(int iatom_mod,int jjz,int iatom_div,const Kokkos::View<real_type ***,Kokkos::LayoutLeft,DeviceType> & beta_pack)784 void SNAKokkos<DeviceType, real_type, vector_length>::compute_yi(int iatom_mod, int jjz, int iatom_div,
785  const Kokkos::View<real_type***, Kokkos::LayoutLeft, DeviceType> &beta_pack)
786 {
787 
788   const int j1 = idxz(jjz, 0);
789   const int j2 = idxz(jjz, 1);
790   const int j = idxz(jjz, 2);
791   const int ma1min = idxz(jjz, 3);
792   const int ma2max = idxz(jjz, 4);
793   const int mb1min = idxz(jjz, 5);
794   const int mb2max = idxz(jjz, 6);
795   const int na = idxz(jjz, 7);
796   const int nb = idxz(jjz, 8);
797   const int jju_half = idxz(jjz, 9);
798 
799   const real_type *cgblock = cglist.data() + idxcg_block(j1,j2,j);
800   //int mb = (2 * (mb1min+mb2max) - j1 - j2 + j) / 2;
801   //int ma = (2 * (ma1min+ma2max) - j1 - j2 + j) / 2;
802 
803   for (int elem1 = 0; elem1 < nelements; elem1++) {
804     for (int elem2 = 0; elem2 < nelements; elem2++) {
805 
806       const complex ztmp = evaluate_zi(j1, j2, j, ma1min, ma2max, mb1min, mb2max, na, nb, iatom_mod, elem1, elem2, iatom_div, cgblock);
807 
808       // apply to z(j1,j2,j,ma,mb) to unique element of y(j)
809       // find right y_list[jju] and beta(iatom,jjb) entries
810       // multiply and divide by j+1 factors
811       // account for multiplicity of 1, 2, or 3
812 
813       // pick out right beta value
814       for (int elem3 = 0; elem3 < nelements; elem3++) {
815 
816         const real_type betaj = evaluate_beta_scaled(j1, j2, j, iatom_mod, elem1, elem2, elem3, iatom_div, beta_pack);
817 
818         Kokkos::atomic_add(&(ylist_pack_re(iatom_mod, jju_half, elem3, iatom_div)), betaj * ztmp.re);
819         Kokkos::atomic_add(&(ylist_pack_im(iatom_mod, jju_half, elem3, iatom_div)), betaj * ztmp.im);
820       } // end loop over elem3
821     } // end loop over elem2
822   } // end loop over elem1
823 }
824 
825 /* ----------------------------------------------------------------------
826    compute Yi from Ui without storing Zi, looping over zlist indices.
827    AoSoA data layout to take advantage of coalescing, avoiding warp
828    divergence. GPU version.
829 ------------------------------------------------------------------------- */
830 
831 template<class DeviceType, typename real_type, int vector_length>
832 KOKKOS_INLINE_FUNCTION
compute_yi_with_zlist(int iatom_mod,int jjz,int iatom_div,const Kokkos::View<real_type ***,Kokkos::LayoutLeft,DeviceType> & beta_pack)833 void SNAKokkos<DeviceType, real_type, vector_length>::compute_yi_with_zlist(int iatom_mod, int jjz, int iatom_div,
834  const Kokkos::View<real_type***, Kokkos::LayoutLeft, DeviceType> &beta_pack)
835 {
836   const int j1 = idxz(jjz, 0);
837   const int j2 = idxz(jjz, 1);
838   const int j = idxz(jjz, 2);
839   const int jju_half = idxz(jjz, 9);
840   int idouble = 0;
841   for (int elem1 = 0; elem1 < nelements; elem1++) {
842     for (int elem2 = 0; elem2 < nelements; elem2++) {
843       const complex ztmp = zlist_pack(iatom_mod,jjz,idouble,iatom_div);
844       // apply to z(j1,j2,j,ma,mb) to unique element of y(j)
845       // find right y_list[jju] and beta(iatom,jjb) entries
846       // multiply and divide by j+1 factors
847       // account for multiplicity of 1, 2, or 3
848       // pick out right beta value
849       for (int elem3 = 0; elem3 < nelements; elem3++) {
850 
851         const real_type betaj = evaluate_beta_scaled(j1, j2, j, iatom_mod, elem1, elem2, elem3, iatom_div, beta_pack);
852 
853         Kokkos::atomic_add(&(ylist_pack_re(iatom_mod, jju_half, elem3, iatom_div)), betaj * ztmp.re);
854         Kokkos::atomic_add(&(ylist_pack_im(iatom_mod, jju_half, elem3, iatom_div)), betaj * ztmp.im);
855       } // end loop over elem3
856       idouble++;
857     } // end loop over elem2
858   } // end loop over elem1
859 }
860 
861 // Core "evaluation" kernel that computes a single zlist value
862 // which gets used in both `compute_zi` and `compute_yi`
863 template<class DeviceType, typename real_type, int vector_length>
864 KOKKOS_FORCEINLINE_FUNCTION
evaluate_zi(const int & j1,const int & j2,const int & j,const int & ma1min,const int & ma2max,const int & mb1min,const int & mb2max,const int & na,const int & nb,const int & iatom_mod,const int & elem1,const int & elem2,const int & iatom_div,const real_type * cgblock)865 typename SNAKokkos<DeviceType, real_type, vector_length>::complex SNAKokkos<DeviceType, real_type, vector_length>::evaluate_zi(const int& j1, const int& j2, const int& j,
866         const int& ma1min, const int& ma2max, const int& mb1min, const int& mb2max, const int& na, const int& nb,
867         const int& iatom_mod, const int& elem1, const int& elem2, const int& iatom_div, const real_type* cgblock) {
868 
869   complex ztmp = complex::zero();
870 
871   int jju1 = idxu_block[j1] + (j1+1)*mb1min;
872   int jju2 = idxu_block[j2] + (j2+1)*mb2max;
873   int icgb = mb1min*(j2+1) + mb2max;
874 
875   #ifdef LMP_KK_DEVICE_COMPILE
876   #pragma unroll
877   #endif
878   for (int ib = 0; ib < nb; ib++) {
879 
880     int ma1 = ma1min;
881     int ma2 = ma2max;
882     int icga = ma1min*(j2+1) + ma2max;
883 
884     #ifdef LMP_KK_DEVICE_COMPILE
885     #pragma unroll
886     #endif
887     for (int ia = 0; ia < na; ia++) {
888       const complex utot1 = ulisttot_pack(iatom_mod, jju1+ma1, elem1, iatom_div);
889       const complex utot2 = ulisttot_pack(iatom_mod, jju2+ma2, elem2, iatom_div);
890       const real_type cgcoeff_a = cgblock[icga];
891       const real_type cgcoeff_b = cgblock[icgb];
892       ztmp.re += cgcoeff_a * cgcoeff_b * (utot1.re * utot2.re - utot1.im * utot2.im);
893       ztmp.im += cgcoeff_a * cgcoeff_b * (utot1.re * utot2.im + utot1.im * utot2.re);
894       ma1++;
895       ma2--;
896       icga += j2;
897     } // end loop over ia
898 
899     jju1 += j1 + 1;
900     jju2 -= j2 + 1;
901     icgb += j2;
902   } // end loop over ib
903 
904   if (bnorm_flag) {
905     const real_type scale = static_cast<real_type>(1) / static_cast<real_type>(j + 1);
906     ztmp.re *= scale;
907     ztmp.im *= scale;
908   }
909 
910   return ztmp;
911 }
912 
913 // Core "evaluation" kernel that extracts and rescales the appropriate `beta` value,
914 // which gets used in both `compute_yi` and `compute_yi_from_zlist
915 template<class DeviceType, typename real_type, int vector_length>
916 KOKKOS_FORCEINLINE_FUNCTION
evaluate_beta_scaled(const int & j1,const int & j2,const int & j,const int & iatom_mod,const int & elem1,const int & elem2,const int & elem3,const int & iatom_div,const Kokkos::View<real_type ***,Kokkos::LayoutLeft,DeviceType> & beta_pack)917 typename SNAKokkos<DeviceType, real_type, vector_length>::real_type SNAKokkos<DeviceType, real_type, vector_length>::evaluate_beta_scaled(const int& j1, const int& j2, const int& j,
918           const int& iatom_mod, const int& elem1, const int& elem2, const int& elem3, const int& iatom_div,
919           const Kokkos::View<real_type***, Kokkos::LayoutLeft, DeviceType> &beta_pack) {
920 
921   real_type betaj = 0;
922 
923   if (j >= j1) {
924     const int jjb = idxb_block(j1, j2, j);
925     const int itriple = ((elem1 * nelements + elem2) * nelements + elem3) * idxb_max + jjb;
926     if (j1 == j) {
927       if (j2 == j) betaj = static_cast<real_type>(3) * beta_pack(iatom_mod, itriple, iatom_div);
928       else betaj = static_cast<real_type>(2) * beta_pack(iatom_mod, itriple, iatom_div);
929     } else betaj = beta_pack(iatom_mod, itriple, iatom_div);
930   } else if (j >= j2) {
931     const int jjb = idxb_block(j, j2, j1);
932     const int itriple = ((elem3 * nelements + elem2) * nelements + elem1) * idxb_max + jjb;
933     if (j2 == j) betaj = static_cast<real_type>(2) * beta_pack(iatom_mod, itriple, iatom_div);
934     else betaj = beta_pack(iatom_mod, itriple, iatom_div);
935   } else {
936     const int jjb = idxb_block(j2, j, j1);
937     const int itriple = ((elem2 * nelements + elem3) * nelements + elem1) * idxb_max + jjb;
938     betaj = beta_pack(iatom_mod, itriple, iatom_div);
939   }
940 
941   if (!bnorm_flag && j1 > j) {
942     const real_type scale = static_cast<real_type>(j1 + 1) / static_cast<real_type>(j + 1);
943     betaj *= scale;
944   }
945 
946   return betaj;
947 
948 }
949 
950 /* ----------------------------------------------------------------------
951    Fused calculation of the derivative of Ui w.r.t. atom j
952    and accumulation into dEidRj. GPU only.
953 ------------------------------------------------------------------------- */
954 
955 // Version of the code that exposes additional parallelism by threading over `j_bend` values
956 template<class DeviceType, typename real_type, int vector_length>
957 template<int dir>
958 KOKKOS_INLINE_FUNCTION
compute_fused_deidrj_small(const typename Kokkos::TeamPolicy<DeviceType>::member_type & team,const int iatom_mod,const int j_bend,const int jnbor,const int iatom_div)959 void SNAKokkos<DeviceType, real_type, vector_length>::compute_fused_deidrj_small(const typename Kokkos::TeamPolicy<DeviceType>::member_type& team, const int iatom_mod, const int j_bend, const int jnbor, const int iatom_div)
960 {
961   // get shared memory offset
962   // scratch size: 32 atoms * (twojmax+1) cached values, no double buffer
963   const int tile_size = vector_length * (twojmax + 1);
964 
965   const int team_rank = team.team_rank();
966   const int scratch_shift = team_rank * tile_size;
967 
968   // extract, wrap shared memory buffer
969   WignerWrapper<real_type, vector_length> ulist_wrapper((complex*)team.team_shmem().get_shmem(team.team_size() * tile_size * sizeof(complex), 0) + scratch_shift, iatom_mod);
970   WignerWrapper<real_type, vector_length> dulist_wrapper((complex*)team.team_shmem().get_shmem(team.team_size() * tile_size * sizeof(complex), 0) + scratch_shift, iatom_mod);
971 
972   // load parameters
973   const complex a = a_pack(iatom_mod, jnbor, iatom_div);
974   const complex b = b_pack(iatom_mod, jnbor, iatom_div);
975   const complex da = da_pack(iatom_mod, jnbor, iatom_div, dir);
976   const complex db = db_pack(iatom_mod, jnbor, iatom_div, dir);
977   const real_type sfac = sfac_pack(iatom_mod, jnbor, iatom_div, 0);
978   const real_type dsfacu = sfac_pack(iatom_mod, jnbor, iatom_div, dir + 1); // dsfac * u
979 
980   const int jelem = element(iatom_mod + vector_length * iatom_div, jnbor);
981 
982   // compute the contribution to dedr_full_sum for one "bend" location
983   const real_type dedr_full_sum = evaluate_duidrj_jbend(ulist_wrapper, a, b, sfac, dulist_wrapper, da, db, dsfacu,
984                                                        jelem, iatom_mod, j_bend, iatom_div);
985 
986   // dedr gets zeroed out at the start of each iteration in compute_cayley_klein
987   Kokkos::atomic_add(&(dedr(iatom_mod + vector_length * iatom_div, jnbor, dir)), static_cast<real_type>(2.0) * dedr_full_sum);
988 
989 }
990 
991 // Version of the code that loops over all `j_bend` values which reduces integer arithmetic
992 // and some amount of load imbalance, at the expense of reducing parallelism
993 template<class DeviceType, typename real_type, int vector_length>
994 template<int dir>
995 KOKKOS_INLINE_FUNCTION
compute_fused_deidrj_large(const typename Kokkos::TeamPolicy<DeviceType>::member_type & team,const int iatom_mod,const int jnbor,const int iatom_div)996 void SNAKokkos<DeviceType, real_type, vector_length>::compute_fused_deidrj_large(const typename Kokkos::TeamPolicy<DeviceType>::member_type& team, const int iatom_mod, const int jnbor, const int iatom_div)
997 {
998   // get shared memory offset
999   // scratch size: 32 atoms * (twojmax+1) cached values, no double buffer
1000   const int tile_size = vector_length * (twojmax + 1);
1001 
1002   const int team_rank = team.team_rank();
1003   const int scratch_shift = team_rank * tile_size;
1004 
1005   // extract, wrap shared memory buffer
1006   WignerWrapper<real_type, vector_length> ulist_wrapper((complex*)team.team_shmem().get_shmem(team.team_size() * tile_size * sizeof(complex), 0) + scratch_shift, iatom_mod);
1007   WignerWrapper<real_type, vector_length> dulist_wrapper((complex*)team.team_shmem().get_shmem(team.team_size() * tile_size * sizeof(complex), 0) + scratch_shift, iatom_mod);
1008 
1009   // load parameters
1010   const complex a = a_pack(iatom_mod, jnbor, iatom_div);
1011   const complex b = b_pack(iatom_mod, jnbor, iatom_div);
1012   const complex da = da_pack(iatom_mod, jnbor, iatom_div, dir);
1013   const complex db = db_pack(iatom_mod, jnbor, iatom_div, dir);
1014   const real_type sfac = sfac_pack(iatom_mod, jnbor, iatom_div, 0);
1015   const real_type dsfacu = sfac_pack(iatom_mod, jnbor, iatom_div, dir + 1); // dsfac * u
1016 
1017   const int jelem = element(iatom_mod + vector_length * iatom_div, jnbor);
1018 
1019   // compute the contributions to dedr_full_sum for all "bend" locations
1020   real_type dedr_full_sum = static_cast<real_type>(0);
1021   #ifdef LMP_KK_DEVICE_COMPILE
1022   #pragma unroll
1023   #endif
1024   for (int j_bend = 0; j_bend <= twojmax; j_bend++) {
1025     dedr_full_sum += evaluate_duidrj_jbend(ulist_wrapper, a, b, sfac, dulist_wrapper, da, db, dsfacu,
1026                                           jelem, iatom_mod, j_bend, iatom_div);
1027   }
1028 
1029   // there's one thread per atom, neighbor pair, so no need to make this atomic
1030   dedr(iatom_mod + vector_length * iatom_div, jnbor, dir) = static_cast<real_type>(2.0) * dedr_full_sum;
1031 
1032 }
1033 
1034 // Core "evaluation" kernel that gets reused in `compute_fused_deidrj_small` and
1035 // `compute_fused_deidrj_large`
1036 template<class DeviceType, typename real_type, int vector_length>
1037 KOKKOS_FORCEINLINE_FUNCTION
evaluate_duidrj_jbend(const WignerWrapper<real_type,vector_length> & ulist_wrapper,const complex & a,const complex & b,const real_type & sfac,const WignerWrapper<real_type,vector_length> & dulist_wrapper,const complex & da,const complex & db,const real_type & dsfacu,const int & jelem,const int & iatom_mod,const int & j_bend,const int & iatom_div)1038 typename SNAKokkos<DeviceType, real_type, vector_length>::real_type SNAKokkos<DeviceType, real_type, vector_length>::evaluate_duidrj_jbend(const WignerWrapper<real_type, vector_length>& ulist_wrapper, const complex& a, const complex& b, const real_type& sfac,
1039                       const WignerWrapper<real_type, vector_length>& dulist_wrapper, const complex& da, const complex& db, const real_type& dsfacu,
1040                       const int& jelem, const int& iatom_mod, const int& j_bend, const int& iatom_div) {
1041 
1042   real_type dedr_full_sum = static_cast<real_type>(0);
1043 
1044   // level 0 is just 1, 0
1045   ulist_wrapper.set(0, complex::one());
1046   dulist_wrapper.set(0, complex::zero());
1047 
1048   // j from before the bend, don't store, mb == 0
1049   // this is "creeping up the side"
1050   for (int j = 1; j <= j_bend; j++) {
1051 
1052     constexpr int mb = 0; // intentional for readability, compiler should optimize this out
1053 
1054     complex ulist_accum = complex::zero();
1055     complex dulist_accum = complex::zero();
1056 
1057     int ma;
1058     for (ma = 0; ma < j; ma++) {
1059 
1060       // grab the cached value
1061       const complex ulist_prev = ulist_wrapper.get(ma);
1062       const complex dulist_prev = dulist_wrapper.get(ma);
1063 
1064       // ulist_accum += rootpq * a.conj() * ulist_prev;
1065       real_type rootpq = rootpqarray(j - ma, j - mb);
1066       ulist_accum.re += rootpq * (a.re * ulist_prev.re + a.im * ulist_prev.im);
1067       ulist_accum.im += rootpq * (a.re * ulist_prev.im - a.im * ulist_prev.re);
1068 
1069       // product rule of above
1070       dulist_accum.re += rootpq * (da.re * ulist_prev.re + da.im * ulist_prev.im + a.re * dulist_prev.re + a.im * dulist_prev.im);
1071       dulist_accum.im += rootpq * (da.re * ulist_prev.im - da.im * ulist_prev.re + a.re * dulist_prev.im - a.im * dulist_prev.re);
1072 
1073       // store ulist_accum, we atomic accumulate values after the bend, so no atomic add here
1074       ulist_wrapper.set(ma, ulist_accum);
1075       dulist_wrapper.set(ma, dulist_accum);
1076 
1077       // next value
1078       // ulist_accum = -rootpq * b.conj() * ulist_prev;
1079       rootpq = rootpqarray(ma + 1, j - mb);
1080       ulist_accum.re = -rootpq * (b.re * ulist_prev.re + b.im * ulist_prev.im);
1081       ulist_accum.im = -rootpq * (b.re * ulist_prev.im - b.im * ulist_prev.re);
1082 
1083       // product rule of above
1084       dulist_accum.re = -rootpq * (db.re * ulist_prev.re + db.im * ulist_prev.im + b.re * dulist_prev.re + b.im * dulist_prev.im);
1085       dulist_accum.im = -rootpq * (db.re * ulist_prev.im - db.im * ulist_prev.re + b.re * dulist_prev.im - b.im * dulist_prev.re);
1086 
1087     }
1088 
1089     ulist_wrapper.set(ma, ulist_accum);
1090     dulist_wrapper.set(ma, dulist_accum);
1091   }
1092 
1093   // now we're after the bend, start storing but only up to the "half way point"
1094   const int j_half_way = MIN(2 * j_bend, twojmax);
1095 
1096   int mb = 1;
1097   int j; //= j_bend + 1; // need this value below
1098   for (j = j_bend + 1; j <= j_half_way; j++) {
1099 
1100     const int jjup = idxu_half_block[j-1] + (mb - 1) * j;
1101 
1102     complex ulist_accum = complex::zero();
1103     complex dulist_accum = complex::zero();
1104 
1105     int ma;
1106     for (ma = 0; ma < j; ma++) {
1107 
1108       // grab y_local early
1109       // this will never be the last element of a row, no need to rescale.
1110       complex y_local = complex(ylist_pack_re(iatom_mod, jjup + ma, jelem, iatom_div), ylist_pack_im(iatom_mod, jjup+ma, jelem, iatom_div));
1111 
1112       // grab the cached value
1113       const complex ulist_prev = ulist_wrapper.get(ma);
1114       const complex dulist_prev = dulist_wrapper.get(ma);
1115 
1116       // ulist_accum += rootpq * b * ulist_prev;
1117       real_type rootpq = rootpqarray(j - ma, mb);
1118       ulist_accum.re += rootpq * (b.re * ulist_prev.re - b.im * ulist_prev.im);
1119       ulist_accum.im += rootpq * (b.re * ulist_prev.im + b.im * ulist_prev.re);
1120 
1121       // product rule of above
1122       dulist_accum.re += rootpq * (db.re * ulist_prev.re - db.im * ulist_prev.im + b.re * dulist_prev.re - b.im * dulist_prev.im);
1123       dulist_accum.im += rootpq * (db.re * ulist_prev.im + db.im * ulist_prev.re + b.re * dulist_prev.im + b.im * dulist_prev.re);
1124 
1125       // store ulist_accum
1126       ulist_wrapper.set(ma, ulist_accum);
1127       dulist_wrapper.set(ma, dulist_accum);
1128 
1129       // Directly accumulate deidrj into sum_tmp
1130       const complex du_prod = (dsfacu * ulist_prev) + (sfac * dulist_prev);
1131       dedr_full_sum += du_prod.re * y_local.re + du_prod.im * y_local.im;
1132 
1133       // next value
1134       // ulist_accum = rootpq * a * ulist_prev;
1135       rootpq = rootpqarray(ma + 1, mb);
1136       ulist_accum.re = rootpq * (a.re * ulist_prev.re - a.im * ulist_prev.im);
1137       ulist_accum.im = rootpq * (a.re * ulist_prev.im + a.im * ulist_prev.re);
1138 
1139       // product rule of above
1140       dulist_accum.re = rootpq * (da.re * ulist_prev.re - da.im * ulist_prev.im + a.re * dulist_prev.re - a.im * dulist_prev.im);
1141       dulist_accum.im = rootpq * (da.re * ulist_prev.im + da.im * ulist_prev.re + a.re * dulist_prev.im + a.im * dulist_prev.re);
1142 
1143     }
1144 
1145     ulist_wrapper.set(ma, ulist_accum);
1146     dulist_wrapper.set(ma, dulist_accum);
1147 
1148     mb++;
1149   }
1150 
1151   // accumulate the last level
1152   const int jjup = idxu_half_block[j-1] + (mb - 1) * j;
1153 
1154   for (int ma = 0; ma < j; ma++) {
1155     // grab y_local early
1156     complex y_local = complex(ylist_pack_re(iatom_mod, jjup + ma, jelem, iatom_div), ylist_pack_im(iatom_mod, jjup+ma, jelem, iatom_div));
1157     if (j % 2 == 1 && 2*(mb-1) == j-1) { // double check me...
1158       if (ma == (mb-1)) { y_local = static_cast<real_type>(0.5)*y_local; }
1159       else if (ma > (mb-1)) { y_local.re = static_cast<real_type>(0.); y_local.im = static_cast<real_type>(0.); } // can probably avoid this outright
1160       // else the ma < mb gets "double counted", cancelling the 0.5.
1161     }
1162 
1163     const complex ulist_prev = ulist_wrapper.get(ma);
1164     const complex dulist_prev = dulist_wrapper.get(ma);
1165 
1166     // Directly accumulate deidrj into sum_tmp
1167     const complex du_prod = (dsfacu * ulist_prev) + (sfac * dulist_prev);
1168     dedr_full_sum += du_prod.re * y_local.re + du_prod.im * y_local.im;
1169 
1170   }
1171 
1172   return dedr_full_sum;
1173 }
1174 
1175 /* ----------------------------------------------------------------------
1176  * CPU routines
1177  * ----------------------------------------------------------------------*/
1178 
1179 /* ----------------------------------------------------------------------
1180    Ulisttot uses a "half" data layout which takes
1181    advantage of the symmetry of the Wigner U matrices.
1182  * ------------------------------------------------------------------------- */
1183 
1184 template<class DeviceType, typename real_type, int vector_length>
1185 KOKKOS_INLINE_FUNCTION
pre_ui_cpu(const typename Kokkos::TeamPolicy<DeviceType>::member_type & team,const int & iatom,const int & ielem)1186 void SNAKokkos<DeviceType, real_type, vector_length>::pre_ui_cpu(const typename Kokkos::TeamPolicy<DeviceType>::member_type& team, const int& iatom, const int& ielem)
1187 {
1188   for (int jelem = 0; jelem < nelements; jelem++) {
1189     for (int j = 0; j <= twojmax; j++) {
1190       int jju = idxu_half_block(j); // removed "const" to work around GCC 7 bug
1191 
1192       // Only diagonal elements get initialized
1193       // for (int m = 0; m < (j+1)*(j/2+1); m++)
1194       Kokkos::parallel_for(Kokkos::ThreadVectorRange(team, (j+1)*(j/2+1)),
1195         [&] (const int m) {
1196 
1197         const int jjup = jju + m;
1198 
1199         // if m is on the "diagonal", initialize it with the self energy.
1200         // Otherwise zero it out
1201         complex init(static_cast<real_type>(0.),static_cast<real_type>(0.));
1202         if (m % (j+2) == 0 && (!chem_flag || ielem == jelem || wselfall_flag)) { init.re = wself; } //need to map iatom to element
1203 
1204         ulisttot(jjup, jelem, iatom) = init;
1205       });
1206     }
1207   }
1208 
1209 }
1210 
1211 
1212 /* ----------------------------------------------------------------------
1213    compute Ui by summing over bispectrum components. CPU only.
1214    See comments above compute_uarray_cpu and add_uarraytot for
1215    data layout comments.
1216 ------------------------------------------------------------------------- */
1217 
1218 template<class DeviceType, typename real_type, int vector_length>
1219 KOKKOS_INLINE_FUNCTION
compute_ui_cpu(const typename Kokkos::TeamPolicy<DeviceType>::member_type & team,int iatom,int jnbor)1220 void SNAKokkos<DeviceType, real_type, vector_length>::compute_ui_cpu(const typename Kokkos::TeamPolicy<DeviceType>::member_type& team, int iatom, int jnbor)
1221 {
1222   real_type rsq, r, x, y, z, z0, theta0;
1223 
1224   // utot(j,ma,mb) = 0 for all j,ma,ma
1225   // utot(j,ma,ma) = 1 for all j,ma
1226   // for j in neighbors of i:
1227   //   compute r0 = (x,y,z,z0)
1228   //   utot(j,ma,mb) += u(r0;j,ma,mb) for all j,ma,mb
1229 
1230   x = rij(iatom,jnbor,0);
1231   y = rij(iatom,jnbor,1);
1232   z = rij(iatom,jnbor,2);
1233   rsq = x * x + y * y + z * z;
1234   r = sqrt(rsq);
1235 
1236   theta0 = (r - rmin0) * rfac0 * MY_PI / (rcutij(iatom,jnbor) - rmin0);
1237   //    theta0 = (r - rmin0) * rscale0;
1238   z0 = r / tan(theta0);
1239 
1240   compute_uarray_cpu(team, iatom, jnbor, x, y, z, z0, r);
1241   add_uarraytot(team, iatom, jnbor, r, wj(iatom,jnbor), rcutij(iatom,jnbor), element(iatom, jnbor));
1242 
1243 }
1244 /* ----------------------------------------------------------------------
1245    compute Zi by summing over products of Ui, CPU version
1246 ------------------------------------------------------------------------- */
1247 
1248 template<class DeviceType, typename real_type, int vector_length>
1249 KOKKOS_INLINE_FUNCTION
compute_zi_cpu(const int & iter)1250 void SNAKokkos<DeviceType, real_type, vector_length>::compute_zi_cpu(const int& iter)
1251 {
1252   const int iatom = iter / idxz_max;
1253   const int jjz = iter % idxz_max;
1254 
1255   const int j1 = idxz(jjz, 0);
1256   const int j2 = idxz(jjz, 1);
1257   const int j = idxz(jjz, 2);
1258   const int ma1min = idxz(jjz, 3);
1259   const int ma2max = idxz(jjz, 4);
1260   const int mb1min = idxz(jjz, 5);
1261   const int mb2max = idxz(jjz, 6);
1262   const int na = idxz(jjz, 7);
1263   const int nb = idxz(jjz, 8);
1264 
1265   const real_type *cgblock = cglist.data() + idxcg_block(j1,j2,j);
1266 
1267   int idouble = 0;
1268 
1269   for (int elem1 = 0; elem1 < nelements; elem1++) {
1270     for (int elem2 = 0; elem2 < nelements; elem2++) {
1271       zlist(jjz, idouble, iatom).re = static_cast<real_type>(0.0);
1272       zlist(jjz, idouble, iatom).im = static_cast<real_type>(0.0);
1273 
1274       int jju1 = idxu_block[j1] + (j1+1)*mb1min;
1275       int jju2 = idxu_block[j2] + (j2+1)*mb2max;
1276       int icgb = mb1min*(j2+1) + mb2max;
1277       for (int ib = 0; ib < nb; ib++) {
1278 
1279         real_type suma1_r = static_cast<real_type>(0.0);
1280         real_type suma1_i = static_cast<real_type>(0.0);
1281 
1282         int ma1 = ma1min;
1283         int ma2 = ma2max;
1284         int icga = ma1min * (j2 + 1) + ma2max;
1285         for (int ia = 0; ia < na; ia++) {
1286           suma1_r += cgblock[icga] * (ulisttot_full(jju1+ma1, elem1, iatom).re * ulisttot_full(jju2+ma2, elem2, iatom).re -
1287                                       ulisttot_full(jju1+ma1, elem1, iatom).im * ulisttot_full(jju2+ma2, elem2, iatom).im);
1288           suma1_i += cgblock[icga] * (ulisttot_full(jju1+ma1, elem1, iatom).re * ulisttot_full(jju2+ma2, elem2, iatom).im +
1289                                       ulisttot_full(jju1+ma1, elem1, iatom).im * ulisttot_full(jju2+ma2, elem2, iatom).re);
1290           ma1++;
1291           ma2--;
1292           icga += j2;
1293         } // end loop over ia
1294 
1295         zlist(jjz, idouble, iatom).re += cgblock[icgb] * suma1_r;
1296         zlist(jjz, idouble, iatom).im += cgblock[icgb] * suma1_i;
1297 
1298         jju1 += j1 + 1;
1299         jju2 -= j2 + 1;
1300         icgb += j2;
1301       } // end loop over ib
1302 
1303       if (bnorm_flag) {
1304         const real_type scale = static_cast<real_type>(1) / static_cast<real_type>(j + 1);
1305         zlist(jjz, idouble, iatom).re *= scale;
1306         zlist(jjz, idouble, iatom).im *= scale;
1307       }
1308       idouble++;
1309     } // end loop over elem2
1310   } // end loop over elem1
1311 }
1312 
1313 
1314 /* ----------------------------------------------------------------------
1315    compute Bi by summing conj(Ui)*Zi, CPU version
1316 ------------------------------------------------------------------------- */
1317 
1318 template<class DeviceType, typename real_type, int vector_length>
1319 KOKKOS_INLINE_FUNCTION
compute_bi_cpu(const typename Kokkos::TeamPolicy<DeviceType>::member_type & team,int iatom)1320 void SNAKokkos<DeviceType, real_type, vector_length>::compute_bi_cpu(const typename Kokkos::TeamPolicy<DeviceType>::member_type& team, int iatom)
1321 {
1322   // for j1 = 0,...,twojmax
1323   //   for j2 = 0,twojmax
1324   //     for j = |j1-j2|,Min(twojmax,j1+j2),2
1325   //        b(j1,j2,j) = 0
1326   //        for mb = 0,...,jmid
1327   //          for ma = 0,...,j
1328   //            b(j1,j2,j) +=
1329   //              2*Conj(u(j,ma,mb))*z(j1,j2,j,ma,mb)
1330 
1331   int itriple = 0;
1332   int idouble = 0;
1333   for (int elem1 = 0; elem1 < nelements; elem1++) {
1334     for (int elem2 = 0; elem2 < nelements; elem2++) {
1335       int jalloy = idouble; // must be non-const to work around gcc compiler bug
1336       for (int elem3 = 0; elem3 < nelements; elem3++) {
1337         Kokkos::parallel_for(Kokkos::TeamThreadRange(team,idxb_max),
1338           [&] (const int& jjb) {
1339           const int j1 = idxb(jjb, 0);
1340           const int j2 = idxb(jjb, 1);
1341           int j = idxb(jjb, 2); // removed "const" to work around GCC 7 bug
1342 
1343           int jjz = idxz_block(j1, j2, j);
1344           int jju = idxu_block[j];
1345           real_type sumzu = static_cast<real_type>(0.0);
1346           real_type sumzu_temp = static_cast<real_type>(0.0);
1347           const int bound = (j+2)/2;
1348           Kokkos::parallel_reduce(Kokkos::ThreadVectorRange(team,(j+1)*bound),
1349               [&] (const int mbma, real_type& sum) {
1350               const int ma = mbma % (j + 1);
1351               const int mb = mbma / (j + 1);
1352               const int jju_index = jju + mb * (j + 1) + ma;
1353               const int jjz_index = jjz + mb * (j + 1) + ma;
1354               if (2*mb == j) return;
1355               sum +=
1356                 ulisttot_full(jju_index, elem3, iatom).re * zlist(jjz_index, jalloy, iatom).re +
1357                 ulisttot_full(jju_index, elem3, iatom).im * zlist(jjz_index, jalloy, iatom).im;
1358             },sumzu_temp); // end loop over ma, mb
1359             sumzu += sumzu_temp;
1360 
1361           // For j even, special treatment for middle column
1362 
1363           if (j%2 == 0) {
1364             int mb = j/2; // removed "const" to work around GCC 7 bug
1365             Kokkos::parallel_reduce(Kokkos::ThreadVectorRange(team, mb),
1366                 [&] (const int ma, real_type& sum) {
1367               const int jju_index = jju+(mb-1)*(j+1)+(j+1)+ma;
1368               const int jjz_index = jjz+(mb-1)*(j+1)+(j+1)+ma;
1369               sum +=
1370                 ulisttot_full(jju_index, elem3, iatom).re * zlist(jjz_index, jalloy, iatom).re +
1371                 ulisttot_full(jju_index, elem3, iatom).im * zlist(jjz_index, jalloy, iatom).im;
1372             },sumzu_temp); // end loop over ma
1373             sumzu += sumzu_temp;
1374 
1375             const int ma = mb;
1376             const int jju_index = jju+(mb-1)*(j+1)+(j+1)+ma;
1377             const int jjz_index = jjz+(mb-1)*(j+1)+(j+1)+ma;
1378             sumzu += static_cast<real_type>(0.5)*
1379               (ulisttot_full(jju_index, elem3, iatom).re * zlist(jjz_index, jalloy, iatom).re +
1380                ulisttot_full(jju_index, elem3, iatom).im * zlist(jjz_index, jalloy, iatom).im);
1381           } // end if jeven
1382 
1383           Kokkos::single(Kokkos::PerThread(team), [&] () {
1384             sumzu *= static_cast<real_type>(2.0);
1385 
1386             // apply bzero shift
1387 
1388             if (bzero_flag) {
1389               if (!wselfall_flag) {
1390                 if (elem1 == elem2 && elem1 == elem3) {
1391                   sumzu -= bzero[j];
1392                 }
1393               } else {
1394                 sumzu -= bzero[j];
1395               }
1396             }
1397 
1398             blist(iatom, itriple, jjb) = sumzu;
1399           });
1400         });
1401           //} // end loop over j
1402         //} // end loop over j1, j2
1403         itriple++;
1404       }
1405       idouble++;
1406     } // end loop over elem2
1407   } // end loop over elem1
1408 
1409 }
1410 
1411 /* ----------------------------------------------------------------------
1412    compute Yi from Ui without storing Zi, looping over zlist indices,
1413    CPU version
1414 ------------------------------------------------------------------------- */
1415 
1416 template<class DeviceType, typename real_type, int vector_length>
1417 KOKKOS_INLINE_FUNCTION
compute_yi_cpu(int iter,const Kokkos::View<real_type **,DeviceType> & beta)1418 void SNAKokkos<DeviceType, real_type, vector_length>::compute_yi_cpu(int iter,
1419  const Kokkos::View<real_type**, DeviceType> &beta)
1420 {
1421   real_type betaj;
1422   const int iatom = iter / idxz_max;
1423   const int jjz = iter % idxz_max;
1424 
1425   const int j1 = idxz(jjz, 0);
1426   const int j2 = idxz(jjz, 1);
1427   const int j = idxz(jjz, 2);
1428   const int ma1min = idxz(jjz, 3);
1429   const int ma2max = idxz(jjz, 4);
1430   const int mb1min = idxz(jjz, 5);
1431   const int mb2max = idxz(jjz, 6);
1432   const int na = idxz(jjz, 7);
1433   const int nb = idxz(jjz, 8);
1434   const int jju_half = idxz(jjz, 9);
1435 
1436   const real_type *cgblock = cglist.data() + idxcg_block(j1,j2,j);
1437   //int mb = (2 * (mb1min+mb2max) - j1 - j2 + j) / 2;
1438   //int ma = (2 * (ma1min+ma2max) - j1 - j2 + j) / 2;
1439 
1440   for (int elem1 = 0; elem1 < nelements; elem1++) {
1441     for (int elem2 = 0; elem2 < nelements; elem2++) {
1442 
1443       real_type ztmp_r = 0.0;
1444       real_type ztmp_i = 0.0;
1445 
1446       int jju1 = idxu_block[j1] + (j1 + 1) * mb1min;
1447       int jju2 = idxu_block[j2] + (j2 + 1) * mb2max;
1448       int icgb = mb1min * (j2 +1) + mb2max;
1449 
1450       for (int ib = 0; ib < nb; ib++) {
1451 
1452         real_type suma1_r = 0.0;
1453         real_type suma1_i = 0.0;
1454 
1455         int ma1 = ma1min;
1456         int ma2 = ma2max;
1457         int icga = ma1min*(j2+1) + ma2max;
1458 
1459         for (int ia = 0; ia < na; ia++) {
1460           suma1_r += cgblock[icga] * (ulisttot_full(jju1+ma1, elem1, iatom).re * ulisttot_full(jju2+ma2, elem2, iatom).re -
1461                                       ulisttot_full(jju1+ma1, elem1, iatom).im * ulisttot_full(jju2+ma2, elem2, iatom).im);
1462           suma1_i += cgblock[icga] * (ulisttot_full(jju1+ma1, elem1, iatom).re * ulisttot_full(jju2+ma2, elem2, iatom).im +
1463                                       ulisttot_full(jju1+ma1, elem1, iatom).im * ulisttot_full(jju2+ma2, elem2, iatom).re);
1464           ma1++;
1465           ma2--;
1466           icga += j2;
1467         } // end loop over ia
1468 
1469         ztmp_r += cgblock[icgb] * suma1_r;
1470         ztmp_i += cgblock[icgb] * suma1_i;
1471         jju1 += j1 + 1;
1472         jju2 -= j2 + 1;
1473         icgb += j2;
1474       } // end loop over ib
1475 
1476       if (bnorm_flag) {
1477         const real_type scale = static_cast<real_type>(1) / static_cast<real_type>(j + 1);
1478         ztmp_i *= scale;
1479         ztmp_r *= scale;
1480       }
1481 
1482       // apply to z(j1,j2,j,ma,mb) to unique element of y(j)
1483       // find right y_list[jju] and beta(iatom,jjb) entries
1484       // multiply and divide by j+1 factors
1485       // account for multiplicity of 1, 2, or 3
1486 
1487       // pick out right beta value
1488       for (int elem3 = 0; elem3 < nelements; elem3++) {
1489 
1490         if (j >= j1) {
1491           const int jjb = idxb_block(j1, j2, j);
1492           const int itriple = ((elem1 * nelements + elem2) * nelements + elem3) * idxb_max + jjb;
1493           if (j1 == j) {
1494             if (j2 == j) betaj = 3 * beta(itriple, iatom);
1495             else betaj = 2 * beta(itriple, iatom);
1496           } else betaj = beta(itriple, iatom);
1497         } else if (j >= j2) {
1498           const int jjb = idxb_block(j, j2, j1);
1499           const int itriple = ((elem3 * nelements + elem2) * nelements + elem1) * idxb_max + jjb;
1500           if (j2 == j) betaj = 2 * beta(itriple, iatom);
1501           else betaj = beta(itriple, iatom);
1502         } else {
1503           const int jjb = idxb_block(j2, j, j1);
1504           const int itriple = ((elem2 * nelements + elem3) * nelements + elem1) * idxb_max + jjb;
1505           betaj = beta(itriple, iatom);
1506         }
1507 
1508         if (!bnorm_flag && j1 > j)
1509           betaj *= static_cast<real_type>(j1 + 1) / static_cast<real_type>(j + 1);
1510 
1511         Kokkos::atomic_add(&(ylist(jju_half, elem3, iatom).re), betaj*ztmp_r);
1512         Kokkos::atomic_add(&(ylist(jju_half, elem3, iatom).im), betaj*ztmp_i);
1513       } // end loop over elem3
1514     } // end loop over elem2
1515   } // end loop over elem1
1516 }
1517 
1518 
1519 /* ----------------------------------------------------------------------
1520    calculate derivative of Ui w.r.t. atom j
1521    see comments above compute_duarray_cpu for comments on the
1522    data layout
1523 ------------------------------------------------------------------------- */
1524 
1525 template<class DeviceType, typename real_type, int vector_length>
1526 KOKKOS_INLINE_FUNCTION
compute_duidrj_cpu(const typename Kokkos::TeamPolicy<DeviceType>::member_type & team,int iatom,int jnbor)1527 void SNAKokkos<DeviceType, real_type, vector_length>::compute_duidrj_cpu(const typename Kokkos::TeamPolicy<DeviceType>::member_type& team, int iatom, int jnbor)
1528 {
1529   real_type rsq, r, x, y, z, z0, theta0, cs, sn;
1530   real_type dz0dr;
1531 
1532   x = rij(iatom,jnbor,0);
1533   y = rij(iatom,jnbor,1);
1534   z = rij(iatom,jnbor,2);
1535   rsq = x * x + y * y + z * z;
1536   r = sqrt(rsq);
1537   real_type rscale0 = rfac0 * static_cast<real_type>(MY_PI) / (rcutij(iatom,jnbor) - rmin0);
1538   theta0 = (r - rmin0) * rscale0;
1539   sn = sin(theta0);
1540   cs = cos(theta0);
1541   z0 = r * cs / sn;
1542   dz0dr = z0 / r - (r*rscale0) * (rsq + z0 * z0) / rsq;
1543 
1544   compute_duarray_cpu(team, iatom, jnbor, x, y, z, z0, r, dz0dr, wj(iatom,jnbor), rcutij(iatom,jnbor));
1545 }
1546 
1547 
1548 /* ----------------------------------------------------------------------
1549    compute dEidRj, CPU path only.
1550    dulist takes advantage of a `cached` data layout, similar to the
1551    shared memory layout for the GPU routines, which is efficient for
1552    compressing the calculation in compute_duarray_cpu. That said,
1553    dulist only uses the "half" data layout part of that structure.
1554 ------------------------------------------------------------------------- */
1555 
1556 
1557 template<class DeviceType, typename real_type, int vector_length>
1558 KOKKOS_INLINE_FUNCTION
compute_deidrj_cpu(const typename Kokkos::TeamPolicy<DeviceType>::member_type & team,int iatom,int jnbor)1559 void SNAKokkos<DeviceType, real_type, vector_length>::compute_deidrj_cpu(const typename Kokkos::TeamPolicy<DeviceType>::member_type& team, int iatom, int jnbor)
1560 {
1561   t_scalar3<real_type> final_sum;
1562   const int jelem = element(iatom, jnbor);
1563 
1564   Kokkos::parallel_reduce(Kokkos::ThreadVectorRange(team,twojmax+1),
1565       [&] (const int& j, t_scalar3<real_type>& sum_tmp) {
1566     int jju_half = idxu_half_block[j];
1567     int jju_cache = idxu_cache_block[j];
1568 
1569     for (int mb = 0; 2*mb < j; mb++)
1570       for (int ma = 0; ma <= j; ma++) {
1571         sum_tmp.x += dulist(jju_cache,iatom,jnbor,0).re * ylist(jju_half,jelem,iatom).re +
1572                      dulist(jju_cache,iatom,jnbor,0).im * ylist(jju_half,jelem,iatom).im;
1573         sum_tmp.y += dulist(jju_cache,iatom,jnbor,1).re * ylist(jju_half,jelem,iatom).re +
1574                      dulist(jju_cache,iatom,jnbor,1).im * ylist(jju_half,jelem,iatom).im;
1575         sum_tmp.z += dulist(jju_cache,iatom,jnbor,2).re * ylist(jju_half,jelem,iatom).re +
1576                      dulist(jju_cache,iatom,jnbor,2).im * ylist(jju_half,jelem,iatom).im;
1577         jju_half++; jju_cache++;
1578       } //end loop over ma mb
1579 
1580     // For j even, handle middle column
1581 
1582     if (j%2 == 0) {
1583 
1584       int mb = j/2;
1585       for (int ma = 0; ma < mb; ma++) {
1586         sum_tmp.x += dulist(jju_cache,iatom,jnbor,0).re * ylist(jju_half,jelem,iatom).re +
1587                      dulist(jju_cache,iatom,jnbor,0).im * ylist(jju_half,jelem,iatom).im;
1588         sum_tmp.y += dulist(jju_cache,iatom,jnbor,1).re * ylist(jju_half,jelem,iatom).re +
1589                      dulist(jju_cache,iatom,jnbor,1).im * ylist(jju_half,jelem,iatom).im;
1590         sum_tmp.z += dulist(jju_cache,iatom,jnbor,2).re * ylist(jju_half,jelem,iatom).re +
1591                      dulist(jju_cache,iatom,jnbor,2).im * ylist(jju_half,jelem,iatom).im;
1592         jju_half++; jju_cache++;
1593       }
1594 
1595       //int ma = mb;
1596       sum_tmp.x += (dulist(jju_cache,iatom,jnbor,0).re * ylist(jju_half,jelem,iatom).re +
1597                     dulist(jju_cache,iatom,jnbor,0).im * ylist(jju_half,jelem,iatom).im)*0.5;
1598       sum_tmp.y += (dulist(jju_cache,iatom,jnbor,1).re * ylist(jju_half,jelem,iatom).re +
1599                     dulist(jju_cache,iatom,jnbor,1).im * ylist(jju_half,jelem,iatom).im)*0.5;
1600       sum_tmp.z += (dulist(jju_cache,iatom,jnbor,2).re * ylist(jju_half,jelem,iatom).re +
1601                     dulist(jju_cache,iatom,jnbor,2).im * ylist(jju_half,jelem,iatom).im)*0.5;
1602     } // end if jeven
1603 
1604   },final_sum); // end loop over j
1605 
1606   Kokkos::single(Kokkos::PerThread(team), [&] () {
1607     dedr(iatom,jnbor,0) = final_sum.x*2.0;
1608     dedr(iatom,jnbor,1) = final_sum.y*2.0;
1609     dedr(iatom,jnbor,2) = final_sum.z*2.0;
1610   });
1611 
1612 }
1613 
1614 
1615 /* ----------------------------------------------------------------------
1616    add Wigner U-functions for one neighbor to the total
1617    ulist is in a "cached" data layout, which is a compressed layout
1618    which still keeps the recursive calculation simple. On the other hand
1619    `ulisttot` uses a "half" data layout, which fully takes advantage
1620    of the symmetry of the Wigner U matrices.
1621 ------------------------------------------------------------------------- */
1622 
1623 template<class DeviceType, typename real_type, int vector_length>
1624 KOKKOS_INLINE_FUNCTION
add_uarraytot(const typename Kokkos::TeamPolicy<DeviceType>::member_type & team,int iatom,int jnbor,const real_type & r,const real_type & wj,const real_type & rcut,int jelem)1625 void SNAKokkos<DeviceType, real_type, vector_length>::add_uarraytot(const typename Kokkos::TeamPolicy<DeviceType>::member_type& team, int iatom, int jnbor,
1626                                           const real_type& r, const real_type& wj, const real_type& rcut, int jelem)
1627 {
1628   const real_type sfac = compute_sfac(r, rcut) * wj;
1629 
1630   Kokkos::parallel_for(Kokkos::ThreadVectorRange(team,twojmax+1),
1631       [&] (const int& j) {
1632 
1633     int jju_half = idxu_half_block[j]; // index into ulisttot
1634     int jju_cache = idxu_cache_block[j]; // index into ulist
1635 
1636     int count = 0;
1637     for (int mb = 0; 2*mb <= j; mb++) {
1638       for (int ma = 0; ma <= j; ma++) {
1639         Kokkos::atomic_add(&(ulisttot(jju_half+count, jelem, iatom).re), sfac * ulist(jju_cache+count, iatom, jnbor).re);
1640         Kokkos::atomic_add(&(ulisttot(jju_half+count, jelem, iatom).im), sfac * ulist(jju_cache+count, iatom, jnbor).im);
1641         count++;
1642       }
1643     }
1644   });
1645 }
1646 
1647 /* ----------------------------------------------------------------------
1648    compute Wigner U-functions for one neighbor.
   `ulist` uses a "cached" data layout, matching the amount of
1650    information stored between layers via scratch memory on the GPU path
1651 ------------------------------------------------------------------------- */
1652 
1653 template<class DeviceType, typename real_type, int vector_length>
1654 KOKKOS_INLINE_FUNCTION
compute_uarray_cpu(const typename Kokkos::TeamPolicy<DeviceType>::member_type & team,int iatom,int jnbor,const real_type & x,const real_type & y,const real_type & z,const real_type & z0,const real_type & r)1655 void SNAKokkos<DeviceType, real_type, vector_length>::compute_uarray_cpu(const typename Kokkos::TeamPolicy<DeviceType>::member_type& team, int iatom, int jnbor,
1656                          const real_type& x, const real_type& y, const real_type& z, const real_type& z0, const real_type& r)
1657 {
1658   real_type r0inv;
1659   real_type a_r, b_r, a_i, b_i;
1660   real_type rootpq;
1661 
1662   // compute Cayley-Klein parameters for unit quaternion
1663 
1664   r0inv = static_cast<real_type>(1.0) / sqrt(r * r + z0 * z0);
1665   a_r = r0inv * z0;
1666   a_i = -r0inv * z;
1667   b_r = r0inv * y;
1668   b_i = -r0inv * x;
1669 
1670   // VMK Section 4.8.2
1671 
1672   ulist(0,iatom,jnbor).re = 1.0;
1673   ulist(0,iatom,jnbor).im = 0.0;
1674 
1675   for (int j = 1; j <= twojmax; j++) {
1676     int jju = idxu_cache_block[j]; // removed "const" to work around GCC 7 bug
1677     int jjup = idxu_cache_block[j-1]; // removed "const" to work around GCC 7 bug
1678 
1679     // fill in left side of matrix layer from previous layer
1680 
1681     Kokkos::parallel_for(Kokkos::ThreadVectorRange(team,(j+2)/2),
1682         [&] (const int& mb) {
1683     //for (int mb = 0; 2*mb <= j; mb++) {
1684       const int jju_index = jju+mb+mb*j;
1685       ulist(jju_index,iatom,jnbor).re = 0.0;
1686       ulist(jju_index,iatom,jnbor).im = 0.0;
1687 
1688       for (int ma = 0; ma < j; ma++) {
1689         const int jju_index = jju+mb+mb*j+ma;
1690         const int jjup_index = jjup+mb*j+ma;
1691         rootpq = rootpqarray(j - ma,j - mb);
1692         ulist(jju_index,iatom,jnbor).re +=
1693           rootpq *
1694           (a_r * ulist(jjup_index,iatom,jnbor).re +
1695            a_i * ulist(jjup_index,iatom,jnbor).im);
1696         ulist(jju_index,iatom,jnbor).im +=
1697           rootpq *
1698           (a_r * ulist(jjup_index,iatom,jnbor).im -
1699            a_i * ulist(jjup_index,iatom,jnbor).re);
1700 
1701         rootpq = rootpqarray(ma + 1,j - mb);
1702         ulist(jju_index+1,iatom,jnbor).re =
1703           -rootpq *
1704           (b_r * ulist(jjup_index,iatom,jnbor).re +
1705            b_i * ulist(jjup_index,iatom,jnbor).im);
1706         ulist(jju_index+1,iatom,jnbor).im =
1707           -rootpq *
1708           (b_r * ulist(jjup_index,iatom,jnbor).im -
1709            b_i * ulist(jjup_index,iatom,jnbor).re);
1710       }
1711 
1712       // copy left side to right side with inversion symmetry VMK 4.4(2)
1713       // u[ma-j,mb-j] = (-1)^(ma-mb)*Conj([u[ma,mb))
1714 
1715       // Only need to add one symmetrized row for convenience
1716       // Symmetry gets "unfolded" in accumulating ulisttot
1717       if (j%2==1 && mb==(j/2)) {
1718         const int mbpar = (mb)%2==0?1:-1;
1719         int mapar = mbpar;
1720         for (int ma = 0; ma <= j; ma++) {
1721           const int jju_index = jju + mb*(j+1) + ma;
1722           const int jjup_index = jju + (j+1-mb)*(j+1)-(ma+1);
1723           if (mapar == 1) {
1724             ulist(jjup_index,iatom,jnbor).re = ulist(jju_index,iatom,jnbor).re;
1725             ulist(jjup_index,iatom,jnbor).im = -ulist(jju_index,iatom,jnbor).im;
1726           } else {
1727             ulist(jjup_index,iatom,jnbor).re = -ulist(jju_index,iatom,jnbor).re;
1728             ulist(jjup_index,iatom,jnbor).im = ulist(jju_index,iatom,jnbor).im;
1729           }
1730           mapar = -mapar;
1731         }
1732       }
1733     });
1734 
1735   }
1736 }
1737 
1738 /* ----------------------------------------------------------------------
1739    compute derivatives of Wigner U-functions for one neighbor
1740    see comments in compute_uarray_cpu()
1741    Uses same cached data layout of ulist
1742 ------------------------------------------------------------------------- */
1743 
1744 template<class DeviceType, typename real_type, int vector_length>
1745 KOKKOS_INLINE_FUNCTION
compute_duarray_cpu(const typename Kokkos::TeamPolicy<DeviceType>::member_type & team,int iatom,int jnbor,const real_type & x,const real_type & y,const real_type & z,const real_type & z0,const real_type & r,const real_type & dz0dr,const real_type & wj,const real_type & rcut)1746 void SNAKokkos<DeviceType, real_type, vector_length>::compute_duarray_cpu(const typename Kokkos::TeamPolicy<DeviceType>::member_type& team, int iatom, int jnbor,
1747                           const real_type& x, const real_type& y, const real_type& z,
1748                           const real_type& z0, const real_type& r, const real_type& dz0dr,
1749                           const real_type& wj, const real_type& rcut)
1750 {
1751   real_type r0inv;
1752   real_type a_r, a_i, b_r, b_i;
1753   real_type da_r[3], da_i[3], db_r[3], db_i[3];
1754   real_type dz0[3], dr0inv[3], dr0invdr;
1755   real_type rootpq;
1756 
1757   real_type rinv = 1.0 / r;
1758   real_type ux = x * rinv;
1759   real_type uy = y * rinv;
1760   real_type uz = z * rinv;
1761 
1762   r0inv = 1.0 / sqrt(r * r + z0 * z0);
1763   a_r = z0 * r0inv;
1764   a_i = -z * r0inv;
1765   b_r = y * r0inv;
1766   b_i = -x * r0inv;
1767 
1768   dr0invdr = -r0inv * r0inv * r0inv * (r + z0 * dz0dr);
1769 
1770   dr0inv[0] = dr0invdr * ux;
1771   dr0inv[1] = dr0invdr * uy;
1772   dr0inv[2] = dr0invdr * uz;
1773 
1774   dz0[0] = dz0dr * ux;
1775   dz0[1] = dz0dr * uy;
1776   dz0[2] = dz0dr * uz;
1777 
1778   for (int k = 0; k < 3; k++) {
1779     da_r[k] = dz0[k] * r0inv + z0 * dr0inv[k];
1780     da_i[k] = -z * dr0inv[k];
1781   }
1782 
1783   da_i[2] += -r0inv;
1784 
1785   for (int k = 0; k < 3; k++) {
1786     db_r[k] = y * dr0inv[k];
1787     db_i[k] = -x * dr0inv[k];
1788   }
1789 
1790   db_i[0] += -r0inv;
1791   db_r[1] += r0inv;
1792 
1793   dulist(0,iatom,jnbor,0).re = 0.0;
1794   dulist(0,iatom,jnbor,1).re = 0.0;
1795   dulist(0,iatom,jnbor,2).re = 0.0;
1796   dulist(0,iatom,jnbor,0).im = 0.0;
1797   dulist(0,iatom,jnbor,1).im = 0.0;
1798   dulist(0,iatom,jnbor,2).im = 0.0;
1799 
1800   for (int j = 1; j <= twojmax; j++) {
1801     int jju = idxu_cache_block[j];
1802     int jjup = idxu_cache_block[j-1];
1803     Kokkos::parallel_for(Kokkos::ThreadVectorRange(team,(j+2)/2),
1804         [&] (const int& mb) {
1805     //for (int mb = 0; 2*mb <= j; mb++) {
1806       const int jju_index = jju+mb+mb*j;
1807       dulist(jju_index,iatom,jnbor,0).re = 0.0;
1808       dulist(jju_index,iatom,jnbor,1).re = 0.0;
1809       dulist(jju_index,iatom,jnbor,2).re = 0.0;
1810       dulist(jju_index,iatom,jnbor,0).im = 0.0;
1811       dulist(jju_index,iatom,jnbor,1).im = 0.0;
1812       dulist(jju_index,iatom,jnbor,2).im = 0.0;
1813 
1814       for (int ma = 0; ma < j; ma++) {
1815         const int jju_index = jju+mb+mb*j+ma;
1816         const int jjup_index = jjup+mb*j+ma;
1817         rootpq = rootpqarray(j - ma,j - mb);
1818         for (int k = 0; k < 3; k++) {
1819           dulist(jju_index,iatom,jnbor,k).re +=
1820             rootpq * (da_r[k] * ulist(jjup_index,iatom,jnbor).re +
1821                       da_i[k] * ulist(jjup_index,iatom,jnbor).im +
1822                       a_r * dulist(jjup_index,iatom,jnbor,k).re +
1823                       a_i * dulist(jjup_index,iatom,jnbor,k).im);
1824           dulist(jju_index,iatom,jnbor,k).im +=
1825             rootpq * (da_r[k] * ulist(jjup_index,iatom,jnbor).im -
1826                       da_i[k] * ulist(jjup_index,iatom,jnbor).re +
1827                       a_r * dulist(jjup_index,iatom,jnbor,k).im -
1828                       a_i * dulist(jjup_index,iatom,jnbor,k).re);
1829         }
1830 
1831         rootpq = rootpqarray(ma + 1,j - mb);
1832         for (int k = 0; k < 3; k++) {
1833           dulist(jju_index+1,iatom,jnbor,k).re =
1834             -rootpq * (db_r[k] * ulist(jjup_index,iatom,jnbor).re +
1835                        db_i[k] * ulist(jjup_index,iatom,jnbor).im +
1836                        b_r * dulist(jjup_index,iatom,jnbor,k).re +
1837                        b_i * dulist(jjup_index,iatom,jnbor,k).im);
1838           dulist(jju_index+1,iatom,jnbor,k).im =
1839             -rootpq * (db_r[k] * ulist(jjup_index,iatom,jnbor).im -
1840                        db_i[k] * ulist(jjup_index,iatom,jnbor).re +
1841                        b_r * dulist(jjup_index,iatom,jnbor,k).im -
1842                        b_i * dulist(jjup_index,iatom,jnbor,k).re);
1843         }
1844       }
1845 
1846       // Only need to add one symmetrized row for convenience
1847       // Symmetry gets "unfolded" during the dedr accumulation
1848 
1849       // copy left side to right side with inversion symmetry VMK 4.4(2)
1850       // u[ma-j][mb-j] = (-1)^(ma-mb)*Conj([u[ma][mb])
1851 
1852       if (j%2==1 && mb==(j/2)) {
1853         const int mbpar = (mb)%2==0?1:-1;
1854         int mapar = mbpar;
1855         for (int ma = 0; ma <= j; ma++) {
1856           const int jju_index = jju+mb*(j+1)+ma;
1857           const int jjup_index = jju+(mb+2)*(j+1)-(ma+1);
1858           if (mapar == 1) {
1859             for (int k = 0; k < 3; k++) {
1860               dulist(jjup_index,iatom,jnbor,k).re = dulist(jju_index,iatom,jnbor,k).re;
1861               dulist(jjup_index,iatom,jnbor,k).im = -dulist(jju_index,iatom,jnbor,k).im;
1862             }
1863           } else {
1864             for (int k = 0; k < 3; k++) {
1865               dulist(jjup_index,iatom,jnbor,k).re = -dulist(jju_index,iatom,jnbor,k).re;
1866               dulist(jjup_index,iatom,jnbor,k).im = dulist(jju_index,iatom,jnbor,k).im;
1867             }
1868           }
1869           mapar = -mapar;
1870         }
1871       }
1872     });
1873   }
1874 
1875   real_type sfac = compute_sfac(r, rcut);
1876   real_type dsfac = compute_dsfac(r, rcut);
1877 
1878   sfac *= wj;
1879   dsfac *= wj;
1880 
1881   // Even though we fill out a full "cached" data layout above,
1882   // we only need the "half" data for the accumulation into dedr.
1883   // Thus we skip updating any unnecessary data.
1884   for (int j = 0; j <= twojmax; j++) {
1885     int jju = idxu_cache_block[j];
1886     for (int mb = 0; 2*mb <= j; mb++)
1887       for (int ma = 0; ma <= j; ma++) {
1888         dulist(jju,iatom,jnbor,0).re = dsfac * ulist(jju,iatom,jnbor).re * ux +
1889                                   sfac * dulist(jju,iatom,jnbor,0).re;
1890         dulist(jju,iatom,jnbor,0).im = dsfac * ulist(jju,iatom,jnbor).im * ux +
1891                                   sfac * dulist(jju,iatom,jnbor,0).im;
1892         dulist(jju,iatom,jnbor,1).re = dsfac * ulist(jju,iatom,jnbor).re * uy +
1893                                   sfac * dulist(jju,iatom,jnbor,1).re;
1894         dulist(jju,iatom,jnbor,1).im = dsfac * ulist(jju,iatom,jnbor).im * uy +
1895                                   sfac * dulist(jju,iatom,jnbor,1).im;
1896         dulist(jju,iatom,jnbor,2).re = dsfac * ulist(jju,iatom,jnbor).re * uz +
1897                                   sfac * dulist(jju,iatom,jnbor,2).re;
1898         dulist(jju,iatom,jnbor,2).im = dsfac * ulist(jju,iatom,jnbor).im * uz +
1899                                   sfac * dulist(jju,iatom,jnbor,2).im;
1900 
1901         jju++;
1902       }
1903   }
1904 }
1905 
1906 /* ----------------------------------------------------------------------
1907    factorial n, wrapper for precomputed table
1908 ------------------------------------------------------------------------- */
1909 
1910 template<class DeviceType, typename real_type, int vector_length>
1911 inline
factorial(int n)1912 double SNAKokkos<DeviceType, real_type, vector_length>::factorial(int n)
1913 {
1914   //if (n < 0 || n > nmaxfactorial) {
1915   //  char str[128];
1916   //  sprintf(str, "Invalid argument to factorial %d", n);
1917   //  error->all(FLERR, str);
1918   //}
1919 
1920   return nfac_table[n];
1921 }
1922 
1923 /* ----------------------------------------------------------------------
1924    factorial n table, size SNA::nmaxfactorial+1
1925 ------------------------------------------------------------------------- */
1926 
1927 template<class DeviceType, typename real_type, int vector_length>
1928 const double SNAKokkos<DeviceType, real_type, vector_length>::nfac_table[] = {
1929   1,
1930   1,
1931   2,
1932   6,
1933   24,
1934   120,
1935   720,
1936   5040,
1937   40320,
1938   362880,
1939   3628800,
1940   39916800,
1941   479001600,
1942   6227020800,
1943   87178291200,
1944   1307674368000,
1945   20922789888000,
1946   355687428096000,
1947   6.402373705728e+15,
1948   1.21645100408832e+17,
1949   2.43290200817664e+18,
1950   5.10909421717094e+19,
1951   1.12400072777761e+21,
1952   2.5852016738885e+22,
1953   6.20448401733239e+23,
1954   1.5511210043331e+25,
1955   4.03291461126606e+26,
1956   1.08888694504184e+28,
1957   3.04888344611714e+29,
1958   8.8417619937397e+30,
1959   2.65252859812191e+32,
1960   8.22283865417792e+33,
1961   2.63130836933694e+35,
1962   8.68331761881189e+36,
1963   2.95232799039604e+38,
1964   1.03331479663861e+40,
1965   3.71993326789901e+41,
1966   1.37637530912263e+43,
1967   5.23022617466601e+44,
1968   2.03978820811974e+46,
1969   8.15915283247898e+47,
1970   3.34525266131638e+49,
1971   1.40500611775288e+51,
1972   6.04152630633738e+52,
1973   2.65827157478845e+54,
1974   1.1962222086548e+56,
1975   5.50262215981209e+57,
1976   2.58623241511168e+59,
1977   1.24139155925361e+61,
1978   6.08281864034268e+62,
1979   3.04140932017134e+64,
1980   1.55111875328738e+66,
1981   8.06581751709439e+67,
1982   4.27488328406003e+69,
1983   2.30843697339241e+71,
1984   1.26964033536583e+73,
1985   7.10998587804863e+74,
1986   4.05269195048772e+76,
1987   2.35056133128288e+78,
1988   1.3868311854569e+80,
1989   8.32098711274139e+81,
1990   5.07580213877225e+83,
1991   3.14699732603879e+85,
1992   1.98260831540444e+87,
1993   1.26886932185884e+89,
1994   8.24765059208247e+90,
1995   5.44344939077443e+92,
1996   3.64711109181887e+94,
1997   2.48003554243683e+96,
1998   1.71122452428141e+98,
1999   1.19785716699699e+100,
2000   8.50478588567862e+101,
2001   6.12344583768861e+103,
2002   4.47011546151268e+105,
2003   3.30788544151939e+107,
2004   2.48091408113954e+109,
2005   1.88549470166605e+111,
2006   1.45183092028286e+113,
2007   1.13242811782063e+115,
2008   8.94618213078297e+116,
2009   7.15694570462638e+118,
2010   5.79712602074737e+120,
2011   4.75364333701284e+122,
2012   3.94552396972066e+124,
2013   3.31424013456535e+126,
2014   2.81710411438055e+128,
2015   2.42270953836727e+130,
2016   2.10775729837953e+132,
2017   1.85482642257398e+134,
2018   1.65079551609085e+136,
2019   1.48571596448176e+138,
2020   1.3520015276784e+140,
2021   1.24384140546413e+142,
2022   1.15677250708164e+144,
2023   1.08736615665674e+146,
2024   1.03299784882391e+148,
2025   9.91677934870949e+149,
2026   9.61927596824821e+151,
2027   9.42689044888324e+153,
2028   9.33262154439441e+155,
2029   9.33262154439441e+157,
2030   9.42594775983835e+159,
2031   9.61446671503512e+161,
2032   9.90290071648618e+163,
2033   1.02990167451456e+166,
2034   1.08139675824029e+168,
2035   1.14628056373471e+170,
2036   1.22652020319614e+172,
2037   1.32464181945183e+174,
2038   1.44385958320249e+176,
2039   1.58824554152274e+178,
2040   1.76295255109024e+180,
2041   1.97450685722107e+182,
2042   2.23119274865981e+184,
2043   2.54355973347219e+186,
2044   2.92509369349301e+188,
2045   3.3931086844519e+190,
2046   3.96993716080872e+192,
2047   4.68452584975429e+194,
2048   5.5745857612076e+196,
2049   6.68950291344912e+198,
2050   8.09429852527344e+200,
2051   9.8750442008336e+202,
2052   1.21463043670253e+205,
2053   1.50614174151114e+207,
2054   1.88267717688893e+209,
2055   2.37217324288005e+211,
2056   3.01266001845766e+213,
2057   3.8562048236258e+215,
2058   4.97450422247729e+217,
2059   6.46685548922047e+219,
2060   8.47158069087882e+221,
2061   1.118248651196e+224,
2062   1.48727070609069e+226,
2063   1.99294274616152e+228,
2064   2.69047270731805e+230,
2065   3.65904288195255e+232,
2066   5.01288874827499e+234,
2067   6.91778647261949e+236,
2068   9.61572319694109e+238,
2069   1.34620124757175e+241,
2070   1.89814375907617e+243,
2071   2.69536413788816e+245,
2072   3.85437071718007e+247,
2073   5.5502938327393e+249,
2074   8.04792605747199e+251,
2075   1.17499720439091e+254,
2076   1.72724589045464e+256,
2077   2.55632391787286e+258,
2078   3.80892263763057e+260,
2079   5.71338395644585e+262,
2080   8.62720977423323e+264,
2081   1.31133588568345e+267,
2082   2.00634390509568e+269,
2083   3.08976961384735e+271,
2084   4.78914290146339e+273,
2085   7.47106292628289e+275,
2086   1.17295687942641e+278,
2087   1.85327186949373e+280,
2088   2.94670227249504e+282,
2089   4.71472363599206e+284,
2090   7.59070505394721e+286,
2091   1.22969421873945e+289,
2092   2.0044015765453e+291,
2093   3.28721858553429e+293,
2094   5.42391066613159e+295,
2095   9.00369170577843e+297,
2096   1.503616514865e+300, // nmaxfactorial = 167
2097 };
2098 
2099 /* ----------------------------------------------------------------------
2100    the function delta given by VMK Eq. 8.2(1)
2101 ------------------------------------------------------------------------- */
2102 
2103 template<class DeviceType, typename real_type, int vector_length>
2104 inline
deltacg(int j1,int j2,int j)2105 double SNAKokkos<DeviceType, real_type, vector_length>::deltacg(int j1, int j2, int j)
2106 {
2107   double sfaccg = factorial((j1 + j2 + j) / 2 + 1);
2108   return sqrt(factorial((j1 + j2 - j) / 2) *
2109               factorial((j1 - j2 + j) / 2) *
2110               factorial((-j1 + j2 + j) / 2) / sfaccg);
2111 }
2112 
2113 /* ----------------------------------------------------------------------
2114    assign Clebsch-Gordan coefficients using
2115    the quasi-binomial formula VMK 8.2.1(3)
2116 ------------------------------------------------------------------------- */
2117 
2118 template<class DeviceType, typename real_type, int vector_length>
2119 inline
init_clebsch_gordan()2120 void SNAKokkos<DeviceType, real_type, vector_length>::init_clebsch_gordan()
2121 {
2122   auto h_cglist = Kokkos::create_mirror_view(cglist);
2123 
2124   double sum,dcg,sfaccg;
2125   int m, aa2, bb2, cc2;
2126   int ifac;
2127 
2128   int idxcg_count = 0;
2129   for (int j1 = 0; j1 <= twojmax; j1++)
2130     for (int j2 = 0; j2 <= j1; j2++)
2131       for (int j = j1 - j2; j <= MIN(twojmax, j1 + j2); j += 2) {
2132         for (int m1 = 0; m1 <= j1; m1++) {
2133           aa2 = 2 * m1 - j1;
2134 
2135           for (int m2 = 0; m2 <= j2; m2++) {
2136 
2137             // -c <= cc <= c
2138 
2139             bb2 = 2 * m2 - j2;
2140             m = (aa2 + bb2 + j) / 2;
2141 
2142             if (m < 0 || m > j) {
2143               h_cglist[idxcg_count] = 0.0;
2144               idxcg_count++;
2145               continue;
2146             }
2147 
2148             sum = 0.0;
2149 
2150             for (int z = MAX(0, MAX(-(j - j2 + aa2)
2151                                     / 2, -(j - j1 - bb2) / 2));
2152                  z <= MIN((j1 + j2 - j) / 2,
2153                           MIN((j1 - aa2) / 2, (j2 + bb2) / 2));
2154                  z++) {
2155               ifac = z % 2 ? -1 : 1;
2156               sum += ifac /
2157                 (factorial(z) *
2158                  factorial((j1 + j2 - j) / 2 - z) *
2159                  factorial((j1 - aa2) / 2 - z) *
2160                  factorial((j2 + bb2) / 2 - z) *
2161                  factorial((j - j2 + aa2) / 2 + z) *
2162                  factorial((j - j1 - bb2) / 2 + z));
2163             }
2164 
2165             cc2 = 2 * m - j;
2166             dcg = deltacg(j1, j2, j);
2167             sfaccg = sqrt(factorial((j1 + aa2) / 2) *
2168                           factorial((j1 - aa2) / 2) *
2169                           factorial((j2 + bb2) / 2) *
2170                           factorial((j2 - bb2) / 2) *
2171                           factorial((j  + cc2) / 2) *
2172                           factorial((j  - cc2) / 2) *
2173                           (j + 1));
2174 
2175             h_cglist[idxcg_count] = sum * dcg * sfaccg;
2176             idxcg_count++;
2177           }
2178         }
2179       }
2180   Kokkos::deep_copy(cglist,h_cglist);
2181 }
2182 
2183 /* ----------------------------------------------------------------------
2184    pre-compute table of sqrt[p/m2], p, q = 1,twojmax
2185    the p = 0, q = 0 entries are allocated and skipped for convenience.
2186 ------------------------------------------------------------------------- */
2187 
2188 template<class DeviceType, typename real_type, int vector_length>
2189 inline
init_rootpqarray()2190 void SNAKokkos<DeviceType, real_type, vector_length>::init_rootpqarray()
2191 {
2192   auto h_rootpqarray = Kokkos::create_mirror_view(rootpqarray);
2193   for (int p = 1; p <= twojmax; p++)
2194     for (int q = 1; q <= twojmax; q++)
2195       h_rootpqarray(p,q) = static_cast<real_type>(sqrt(static_cast<double>(p)/q));
2196   Kokkos::deep_copy(rootpqarray,h_rootpqarray);
2197 }
2198 
2199 
2200 /* ---------------------------------------------------------------------- */
2201 
2202 template<class DeviceType, typename real_type, int vector_length>
2203 inline
compute_ncoeff()2204 int SNAKokkos<DeviceType, real_type, vector_length>::compute_ncoeff()
2205 {
2206   int ncount;
2207 
2208   ncount = 0;
2209 
2210   for (int j1 = 0; j1 <= twojmax; j1++)
2211     for (int j2 = 0; j2 <= j1; j2++)
2212       for (int j = j1 - j2;
2213            j <= MIN(twojmax, j1 + j2); j += 2)
2214         if (j >= j1) ncount++;
2215 
2216   ndoubles = nelements*nelements;
2217   ntriples = nelements*nelements*nelements;
2218   if (chem_flag) ncount *= ntriples;
2219 
2220   return ncount;
2221 }
2222 
2223 /* ---------------------------------------------------------------------- */
2224 
2225 template<class DeviceType, typename real_type, int vector_length>
2226 KOKKOS_INLINE_FUNCTION
compute_sfac(real_type r,real_type rcut)2227 real_type SNAKokkos<DeviceType, real_type, vector_length>::compute_sfac(real_type r, real_type rcut)
2228 {
2229   constexpr real_type one = static_cast<real_type>(1.0);
2230   constexpr real_type zero = static_cast<real_type>(0.0);
2231   constexpr real_type onehalf = static_cast<real_type>(0.5);
2232   if (switch_flag == 0) return one;
2233   if (switch_flag == 1) {
2234     if (r <= rmin0) return one;
2235     else if (r > rcut) return zero;
2236     else {
2237       real_type rcutfac = static_cast<real_type>(MY_PI) / (rcut - rmin0);
2238       return onehalf * (cos((r - rmin0) * rcutfac) + one);
2239     }
2240   }
2241   return zero;
2242 }
2243 
2244 /* ---------------------------------------------------------------------- */
2245 
2246 template<class DeviceType, typename real_type, int vector_length>
2247 KOKKOS_INLINE_FUNCTION
compute_dsfac(real_type r,real_type rcut)2248 real_type SNAKokkos<DeviceType, real_type, vector_length>::compute_dsfac(real_type r, real_type rcut)
2249 {
2250   constexpr real_type zero = static_cast<real_type>(0.0);
2251   constexpr real_type onehalf = static_cast<real_type>(0.5);
2252   if (switch_flag == 0) return zero;
2253   if (switch_flag == 1) {
2254     if (r <= rmin0) return zero;
2255     else if (r > rcut) return zero;
2256     else {
2257       real_type rcutfac = static_cast<real_type>(MY_PI) / (rcut - rmin0);
2258       return -onehalf * sin((r - rmin0) * rcutfac) * rcutfac;
2259     }
2260   }
2261   return zero;
2262 }
2263 
2264 template<class DeviceType, typename real_type, int vector_length>
2265 KOKKOS_INLINE_FUNCTION
compute_s_dsfac(const real_type r,const real_type rcut,real_type & sfac,real_type & dsfac)2266 void SNAKokkos<DeviceType, real_type, vector_length>::compute_s_dsfac(const real_type r, const real_type rcut, real_type& sfac, real_type& dsfac) {
2267   constexpr real_type one = static_cast<real_type>(1.0);
2268   constexpr real_type zero = static_cast<real_type>(0.0);
2269   constexpr real_type onehalf = static_cast<real_type>(0.5);
2270   if (switch_flag == 0) { sfac = zero; dsfac = zero; }
2271   else if (switch_flag == 1) {
2272     if (r <= rmin0) { sfac = one; dsfac = zero; }
2273     else if (r > rcut) { sfac = zero; dsfac = zero; }
2274     else {
2275       const real_type rcutfac = static_cast<real_type>(MY_PI) / (rcut - rmin0);
2276       const real_type theta0 = (r - rmin0) * rcutfac;
2277       const real_type sn = sin(theta0);
2278       const real_type cs = cos(theta0);
2279       sfac = onehalf * (cs + one);
2280       dsfac = -onehalf * sn * rcutfac;
2281 
2282     }
2283   } else { sfac = zero; dsfac = zero; }
2284 }
2285 
2286 /* ----------------------------------------------------------------------
2287    memory usage of arrays
2288 ------------------------------------------------------------------------- */
2289 
2290 template<class DeviceType, typename real_type, int vector_length>
memory_usage()2291 double SNAKokkos<DeviceType, real_type, vector_length>::memory_usage()
2292 {
2293   int jdimpq = twojmax + 2;
2294   int jdim = twojmax + 1;
2295   double bytes;
2296 
2297   bytes = 0;
2298 
2299   bytes += jdimpq*jdimpq * sizeof(real_type);               // pqarray
2300   bytes += idxcg_max * sizeof(real_type);                   // cglist
2301 
2302 #ifdef LMP_KOKKOS_GPU
2303   if (!host_flag) {
2304 
2305     auto natom_pad = (natom+vector_length-1)/vector_length;
2306 
2307     bytes += natom_pad * nmax * sizeof(real_type) * 2;     // a_pack
2308     bytes += natom_pad * nmax * sizeof(real_type) * 2;     // b_pack
2309     bytes += natom_pad * nmax * 3 * sizeof(real_type) * 2; // da_pack
2310     bytes += natom_pad * nmax * 3 * sizeof(real_type) * 2; // db_pack
2311     bytes += natom_pad * nmax * 4 * sizeof(real_type);     // sfac_pack
2312 
2313 
2314     bytes += natom_pad * idxu_half_max * nelements * sizeof(real_type);     // ulisttot_re_pack
2315     bytes += natom_pad * idxu_half_max * nelements * sizeof(real_type);     // ulisttot_im_pack
2316     bytes += natom_pad * idxu_max * nelements * sizeof(real_type) * 2;      // ulisttot_pack
2317 
2318     bytes += natom_pad * idxz_max * ndoubles * sizeof(real_type) * 2;   // zlist_pack
2319     bytes += natom_pad * idxb_max * ntriples * sizeof(real_type);       // blist_pack
2320 
2321     bytes += natom_pad * idxu_half_max * nelements * sizeof(real_type); // ylist_pack_re
2322     bytes += natom_pad * idxu_half_max * nelements * sizeof(real_type); // ylist_pack_im
2323   } else {
2324 #endif
2325 
2326     bytes += natom * nmax * idxu_cache_max * sizeof(real_type) * 2;     // ulist
2327     bytes += natom * idxu_half_max * nelements * sizeof(real_type) * 2; // ulisttot
2328     bytes += natom * idxu_max * nelements * sizeof(real_type) * 2;      // ulisttot_full
2329 
2330     bytes += natom * idxz_max * ndoubles * sizeof(real_type) * 2;       // zlist
2331     bytes += natom * idxb_max * ntriples * sizeof(real_type);           // blist
2332 
2333     bytes += natom * idxu_half_max * nelements * sizeof(real_type) * 2; // ylist
2334 
2335     bytes += natom * nmax * idxu_cache_max * 3 * sizeof(real_type) * 2; // dulist
2336 #ifdef LMP_KOKKOS_GPU
2337   }
2338 #endif
2339 
2340   bytes += natom * nmax * 3 * sizeof(real_type);            // dedr
2341 
2342   bytes += jdim * jdim * jdim * sizeof(int);             // idxcg_block
2343   bytes += jdim * sizeof(int);                           // idxu_block
2344   bytes += jdim * sizeof(int);                           // idxu_half_block
2345   bytes += idxu_max * sizeof(FullHalfMapper);            // idxu_full_half
2346   bytes += jdim * sizeof(int);                           // idxu_cache_block
2347   bytes += jdim * jdim * jdim * sizeof(int);             // idxz_block
2348   bytes += jdim * jdim * jdim * sizeof(int);             // idxb_block
2349 
2350   bytes += idxz_max * 10 * sizeof(int);                  // idxz
2351   bytes += idxb_max * 3 * sizeof(int);                   // idxb
2352 
2353   bytes += jdim * sizeof(real_type);                        // bzero
2354 
2355   bytes += natom * nmax * 3 * sizeof(real_type);            // rij
2356   bytes += natom * nmax * sizeof(real_type);                // inside
2357   bytes += natom * nmax * sizeof(real_type);                // wj
2358   bytes += natom * nmax * sizeof(real_type);                // rcutij
2359 
2360   return bytes;
2361 }
2362 
2363 } // namespace LAMMPS_NS
2364