1 // **************************************************************************
2 //                                gayberne_lj.cu
3 //                             -------------------
4 //                           W. Michael Brown (ORNL)
5 //
6 //  Device code for Gay-Berne - Lennard-Jones potential acceleration
7 //
8 // __________________________________________________________________________
9 //    This file is part of the LAMMPS Accelerator Library (LAMMPS_AL)
10 // __________________________________________________________________________
11 //
12 //    begin                :
13 //    email                : brownw@ornl.gov
14 // ***************************************************************************
15 
16 #if defined(NV_KERNEL) || defined(USE_HIP)
17 #include "lal_ellipsoid_extra.h"
18 #endif
19 
20 #if (SHUFFLE_AVAIL == 0)
21 #define local_allocate_store_ellipse_lj local_allocate_store_ellipse
22 #else
23 #define local_allocate_store_ellipse_lj()                                   \
24     __local acctyp red_acc[7][BLOCK_ELLIPSE / SIMD_SIZE];
25 #endif
26 
k_gayberne_sphere_ellipsoid(const __global numtyp4 * restrict x_,const __global numtyp4 * restrict q,const __global numtyp4 * restrict shape,const __global numtyp4 * restrict well,const __global numtyp * restrict gum,const __global numtyp2 * restrict sig_eps,const int ntypes,const __global numtyp * restrict lshape,const __global int * dev_nbor,const int stride,__global acctyp4 * restrict ans,__global acctyp * restrict engv,__global int * restrict err_flag,const int eflag,const int vflag,const int start,const int inum,const int t_per_atom)27 __kernel void k_gayberne_sphere_ellipsoid(const __global numtyp4 *restrict x_,
28                                           const __global numtyp4 *restrict q,
29                                           const __global numtyp4 *restrict shape,
30                                           const __global numtyp4 *restrict well,
31                                           const __global numtyp *restrict gum,
32                                           const __global numtyp2 *restrict sig_eps,
33                                           const int ntypes,
34                                           const __global numtyp *restrict lshape,
35                                           const __global int *dev_nbor,
36                                           const int stride,
37                                           __global acctyp4 *restrict ans,
38                                           __global acctyp *restrict engv,
39                                           __global int *restrict err_flag,
40                                           const int eflag, const int vflag,
41                                           const int start, const int inum,
42                                           const int t_per_atom) {
43   int tid, ii, offset;
44   atom_info(t_per_atom,ii,tid,offset);
45   ii+=start;
46 
47   __local numtyp sp_lj[4];
48   int n_stride;
49   local_allocate_store_ellipse_lj();
50 
51   sp_lj[0]=gum[3];
52   sp_lj[1]=gum[4];
53   sp_lj[2]=gum[5];
54   sp_lj[3]=gum[6];
55 
56   acctyp4 f;
57   f.x=(acctyp)0; f.y=(acctyp)0; f.z=(acctyp)0;
58   acctyp energy, virial[6];
59   if (EVFLAG) {
60     energy=(acctyp)0;
61     for (int i=0; i<6; i++) virial[i]=(acctyp)0;
62   }
63 
64   if (ii<inum) {
65     int nbor, nbor_end;
66     int i, numj;
67     nbor_info_p(dev_nbor,stride,t_per_atom,ii,offset,i,numj,
68                 n_stride,nbor_end,nbor);
69 
70     numtyp4 ix; fetch4(ix,i,pos_tex);
71     int itype=ix.w;
72 
73     numtyp oner=shape[itype].x;
74     numtyp one_well=well[itype].x;
75 
76     numtyp factor_lj;
77     for ( ; nbor<nbor_end; nbor+=n_stride) {
78 
79       int j=dev_nbor[nbor];
80       factor_lj = sp_lj[sbmask(j)];
81       j &= NEIGHMASK;
82 
83       numtyp4 jx; fetch4(jx,j,pos_tex);
84       int jtype=jx.w;
85 
86       // Compute r12
87       numtyp r12[3];
88       r12[0] = jx.x-ix.x;
89       r12[1] = jx.y-ix.y;
90       r12[2] = jx.z-ix.z;
91       numtyp ir = gpu_dot3(r12,r12);
92 
93       ir = ucl_rsqrt(ir);
94       numtyp r = ucl_recip(ir);
95 
96       numtyp r12hat[3];
97       r12hat[0]=r12[0]*ir;
98       r12hat[1]=r12[1]*ir;
99       r12hat[2]=r12[2]*ir;
100 
101       numtyp a2[9];
102       gpu_quat_to_mat_trans(q,j,a2);
103 
104       numtyp u_r, dUr[3], eta;
105       { // Compute U_r, dUr, eta, and teta
106         // Compute g12
107         numtyp g12[9];
108         {
109           {
110             numtyp g2[9];
111             gpu_diag_times3(shape[jtype],a2,g12);
112             gpu_transpose_times3(a2,g12,g2);
113             g12[0]=g2[0]+oner;
114             g12[4]=g2[4]+oner;
115             g12[8]=g2[8]+oner;
116             g12[1]=g2[1];
117             g12[2]=g2[2];
118             g12[3]=g2[3];
119             g12[5]=g2[5];
120             g12[6]=g2[6];
121             g12[7]=g2[7];
122           }
123 
124           { // Compute U_r and dUr
125 
126             // Compute kappa
127             numtyp kappa[3];
128             gpu_mldivide3(g12,r12,kappa,err_flag);
129 
130             // -- kappa is now / r
131             kappa[0]*=ir;
132             kappa[1]*=ir;
133             kappa[2]*=ir;
134 
135             // energy
136 
137             // compute u_r and dUr
138             numtyp uslj_rsq;
139             {
140               // Compute distance of closest approach
141               numtyp h12, sigma12;
142               sigma12 = gpu_dot3(r12hat,kappa);
143               sigma12 = ucl_rsqrt((numtyp)0.5*sigma12);
144               h12 = r-sigma12;
145 
146               // -- kappa is now ok
147               kappa[0]*=r;
148               kappa[1]*=r;
149               kappa[2]*=r;
150 
151               int mtype=fast_mul(ntypes,itype)+jtype;
152               numtyp sigma = sig_eps[mtype].x;
153               numtyp epsilon = sig_eps[mtype].y;
154               numtyp varrho = sigma/(h12+gum[0]*sigma);
155               numtyp varrho6 = varrho*varrho*varrho;
156               varrho6*=varrho6;
157               numtyp varrho12 = varrho6*varrho6;
158               u_r = (numtyp)4.0*epsilon*(varrho12-varrho6);
159 
160               numtyp temp1 = ((numtyp)2.0*varrho12*varrho-varrho6*varrho)/sigma;
161               temp1 = temp1*(numtyp)24.0*epsilon;
162               uslj_rsq = temp1*sigma12*sigma12*sigma12*(numtyp)0.5;
163               numtyp temp2 = gpu_dot3(kappa,r12hat);
164               uslj_rsq = uslj_rsq*ir*ir;
165 
166               dUr[0] = temp1*r12hat[0]+uslj_rsq*(kappa[0]-temp2*r12hat[0]);
167               dUr[1] = temp1*r12hat[1]+uslj_rsq*(kappa[1]-temp2*r12hat[1]);
168               dUr[2] = temp1*r12hat[2]+uslj_rsq*(kappa[2]-temp2*r12hat[2]);
169             }
170           }
171         }
172 
173         // Compute eta
174         {
175           eta = (numtyp)2.0*lshape[itype]*lshape[jtype];
176           numtyp det_g12 = gpu_det3(g12);
177           eta = ucl_powr(eta/det_g12,gum[1]);
178         }
179       }
180 
181       numtyp chi, dchi[3];
182       { // Compute chi and dchi
183 
184         // Compute b12
185         numtyp b12[9];
186         {
187           numtyp b2[9];
188           gpu_diag_times3(well[jtype],a2,b12);
189           gpu_transpose_times3(a2,b12,b2);
190           b12[0]=b2[0]+one_well;
191           b12[4]=b2[4]+one_well;
192           b12[8]=b2[8]+one_well;
193           b12[1]=b2[1];
194           b12[2]=b2[2];
195           b12[3]=b2[3];
196           b12[5]=b2[5];
197           b12[6]=b2[6];
198           b12[7]=b2[7];
199         }
200 
201         // compute chi_12
202         numtyp iota[3];
203         gpu_mldivide3(b12,r12,iota,err_flag);
204         // -- iota is now iota/r
205         iota[0]*=ir;
206         iota[1]*=ir;
207         iota[2]*=ir;
208         chi = gpu_dot3(r12hat,iota);
209         chi = ucl_powr(chi*(numtyp)2.0,gum[2]);
210 
211         // -- iota is now ok
212         iota[0]*=r;
213         iota[1]*=r;
214         iota[2]*=r;
215 
216         numtyp temp1 = gpu_dot3(iota,r12hat);
217         numtyp temp2 = (numtyp)-4.0*ir*ir*gum[2]*ucl_powr(chi,(gum[2]-(numtyp)1.0)/
218                                                      gum[2]);
219         dchi[0] = temp2*(iota[0]-temp1*r12hat[0]);
220         dchi[1] = temp2*(iota[1]-temp1*r12hat[1]);
221         dchi[2] = temp2*(iota[2]-temp1*r12hat[2]);
222       }
223 
224       numtyp temp2 = factor_lj*eta*chi;
225       if (EVFLAG && eflag)
226         energy+=u_r*temp2;
227       numtyp temp1 = -eta*u_r*factor_lj;
228       if (EVFLAG && vflag) {
229         r12[0]*=-1;
230         r12[1]*=-1;
231         r12[2]*=-1;
232         numtyp ft=temp1*dchi[0]-temp2*dUr[0];
233         f.x+=ft;
234         virial[0]+=r12[0]*ft;
235         ft=temp1*dchi[1]-temp2*dUr[1];
236         f.y+=ft;
237         virial[1]+=r12[1]*ft;
238         virial[3]+=r12[0]*ft;
239         ft=temp1*dchi[2]-temp2*dUr[2];
240         f.z+=ft;
241         virial[2]+=r12[2]*ft;
242         virial[4]+=r12[0]*ft;
243         virial[5]+=r12[1]*ft;
244       } else {
245         f.x+=temp1*dchi[0]-temp2*dUr[0];
246         f.y+=temp1*dchi[1]-temp2*dUr[1];
247         f.z+=temp1*dchi[2]-temp2*dUr[2];
248       }
249     } // for nbor
250   } // if ii
251   store_answers(f,energy,virial,ii,inum,tid,t_per_atom,offset,eflag,vflag,
252                 ans,engv);
253 }
254 
k_gayberne_lj(const __global numtyp4 * restrict x_,const __global numtyp4 * restrict lj1,const __global numtyp4 * restrict lj3,const int lj_types,const __global numtyp * restrict gum,const int stride,const __global int * dev_ij,__global acctyp4 * restrict ans,__global acctyp * restrict engv,__global int * restrict err_flag,const int eflag,const int vflag,const int start,const int inum,const int t_per_atom)255 __kernel void k_gayberne_lj(const __global numtyp4 *restrict x_,
256                             const __global numtyp4 *restrict lj1,
257                             const __global numtyp4 *restrict lj3,
258                             const int lj_types,
259                             const __global numtyp *restrict gum,
260                             const int stride,
261                             const __global int *dev_ij,
262                             __global acctyp4 *restrict ans,
263                             __global acctyp *restrict engv,
264                             __global int *restrict err_flag,
265                             const int eflag, const int vflag, const int start,
266                             const int inum, const int t_per_atom) {
267   int tid, ii, offset;
268   atom_info(t_per_atom,ii,tid,offset);
269   ii+=start;
270 
271   __local numtyp sp_lj[4];
272   int n_stride;
273   local_allocate_store_ellipse();
274 
275   sp_lj[0]=gum[3];
276   sp_lj[1]=gum[4];
277   sp_lj[2]=gum[5];
278   sp_lj[3]=gum[6];
279 
280   acctyp4 f;
281   f.x=(acctyp)0; f.y=(acctyp)0; f.z=(acctyp)0;
282   acctyp energy, virial[6];
283   if (EVFLAG) {
284     energy=(acctyp)0;
285     for (int i=0; i<6; i++) virial[i]=(acctyp)0;
286   }
287 
288   if (ii<inum) {
289     int nbor, nbor_end;
290     int i, numj;
291     nbor_info_e_ss(dev_ij,stride,t_per_atom,ii,offset,i,numj,
292                    n_stride,nbor_end,nbor);
293 
294     numtyp4 ix; fetch4(ix,i,pos_tex);
295     int itype=ix.w;
296 
297     numtyp factor_lj;
298     for ( ; nbor<nbor_end; nbor+=n_stride) {
299 
300       int j=dev_ij[nbor];
301       factor_lj = sp_lj[sbmask(j)];
302       j &= NEIGHMASK;
303 
304       numtyp4 jx; fetch4(jx,j,pos_tex);
305       int jtype=jx.w;
306 
307       // Compute r12
308       numtyp delx = ix.x-jx.x;
309       numtyp dely = ix.y-jx.y;
310       numtyp delz = ix.z-jx.z;
311       numtyp r2inv = delx*delx+dely*dely+delz*delz;
312 
313       int ii=itype*lj_types+jtype;
314       if (r2inv<lj1[ii].z && lj1[ii].w==SPHERE_SPHERE) {
315         r2inv=ucl_recip(r2inv);
316         numtyp r6inv = r2inv*r2inv*r2inv;
317         numtyp force = r2inv*r6inv*(lj1[ii].x*r6inv-lj1[ii].y);
318         force*=factor_lj;
319 
320         f.x+=delx*force;
321         f.y+=dely*force;
322         f.z+=delz*force;
323 
324         if (EVFLAG && eflag) {
325           numtyp e=r6inv*(lj3[ii].x*r6inv-lj3[ii].y);
326           energy+=factor_lj*(e-lj3[ii].z);
327         }
328         if (EVFLAG && vflag) {
329           virial[0] += delx*delx*force;
330           virial[1] += dely*dely*force;
331           virial[2] += delz*delz*force;
332           virial[3] += delx*dely*force;
333           virial[4] += delx*delz*force;
334           virial[5] += dely*delz*force;
335         }
336       }
337 
338     } // for nbor
339   } // if ii
340   acc_answers(f,energy,virial,ii,inum,tid,t_per_atom,offset,eflag,vflag,
341               ans,engv);
342 }
343 
k_gayberne_lj_fast(const __global numtyp4 * restrict x_,const __global numtyp4 * restrict lj1_in,const __global numtyp4 * restrict lj3_in,const __global numtyp * restrict gum,const int stride,const __global int * dev_ij,__global acctyp4 * restrict ans,__global acctyp * restrict engv,__global int * restrict err_flag,const int eflag,const int vflag,const int start,const int inum,const int t_per_atom)344 __kernel void k_gayberne_lj_fast(const __global numtyp4 *restrict x_,
345                                  const __global numtyp4 *restrict lj1_in,
346                                  const __global numtyp4 *restrict lj3_in,
347                                  const __global numtyp *restrict gum,
348                                  const int stride,
349                                  const __global int *dev_ij,
350                                  __global acctyp4 *restrict ans,
351                                  __global acctyp *restrict engv,
352                                  __global int *restrict err_flag,
353                                  const int eflag, const int vflag,
354                                  const int start, const int inum,
355                                  const int t_per_atom) {
356   int tid, ii, offset;
357   atom_info(t_per_atom,ii,tid,offset);
358   ii+=start;
359 
360   __local numtyp sp_lj[4];
361   __local numtyp4 lj1[MAX_SHARED_TYPES*MAX_SHARED_TYPES];
362   __local numtyp4 lj3[MAX_SHARED_TYPES*MAX_SHARED_TYPES];
363   int n_stride;
364   local_allocate_store_ellipse();
365 
366   if (tid<4)
367     sp_lj[tid]=gum[tid+3];
368   if (tid<MAX_SHARED_TYPES*MAX_SHARED_TYPES) {
369     lj1[tid]=lj1_in[tid];
370     if (EVFLAG && eflag)
371       lj3[tid]=lj3_in[tid];
372   }
373 
374   acctyp4 f;
375   f.x=(acctyp)0; f.y=(acctyp)0; f.z=(acctyp)0;
376   acctyp energy, virial[6];
377   if (EVFLAG) {
378     energy=(acctyp)0;
379     for (int i=0; i<6; i++) virial[i]=(acctyp)0;
380   }
381 
382   __syncthreads();
383 
384   if (ii<inum) {
385     int nbor, nbor_end;
386     int i, numj;
387     nbor_info_e_ss(dev_ij,stride,t_per_atom,ii,offset,i,numj,
388                    n_stride,nbor_end,nbor);
389 
390     numtyp4 ix; fetch4(ix,i,pos_tex);
391     int iw=ix.w;
392     int itype=fast_mul((int)MAX_SHARED_TYPES,iw);
393 
394     numtyp factor_lj;
395     for ( ; nbor<nbor_end; nbor+=n_stride) {
396 
397       int j=dev_ij[nbor];
398       factor_lj = sp_lj[sbmask(j)];
399       j &= NEIGHMASK;
400 
401       numtyp4 jx; fetch4(jx,j,pos_tex);
402       int mtype=itype+jx.w;
403 
404       // Compute r12
405       numtyp delx = ix.x-jx.x;
406       numtyp dely = ix.y-jx.y;
407       numtyp delz = ix.z-jx.z;
408       numtyp r2inv = delx*delx+dely*dely+delz*delz;
409 
410       if (r2inv<lj1[mtype].z && lj1[mtype].w==SPHERE_SPHERE) {
411         r2inv=ucl_recip(r2inv);
412         numtyp r6inv = r2inv*r2inv*r2inv;
413         numtyp force = factor_lj*r2inv*r6inv*(lj1[mtype].x*r6inv-lj1[mtype].y);
414 
415         f.x+=delx*force;
416         f.y+=dely*force;
417         f.z+=delz*force;
418 
419         if (EVFLAG && eflag) {
420           numtyp e=r6inv*(lj3[mtype].x*r6inv-lj3[mtype].y);
421           energy+=factor_lj*(e-lj3[mtype].z);
422         }
423         if (EVFLAG && vflag) {
424           virial[0] += delx*delx*force;
425           virial[1] += dely*dely*force;
426           virial[2] += delz*delz*force;
427           virial[3] += delx*dely*force;
428           virial[4] += delx*delz*force;
429           virial[5] += dely*delz*force;
430         }
431       }
432 
433     } // for nbor
434   } // if ii
435   acc_answers(f,energy,virial,ii,inum,tid,t_per_atom,offset,eflag,vflag,
436               ans,engv);
437 }
438 
439