1 // **************************************************************************
2 //                                   lal_table.cu
3 //                             -------------------
4 //                           Trung Dac Nguyen (ORNL)
5 //
6 //  Device code for acceleration of the table pair style
7 //
8 // __________________________________________________________________________
9 //    This file is part of the LAMMPS Accelerator Library (LAMMPS_AL)
10 // __________________________________________________________________________
11 //
12 //    begin                :
13 //    email                : nguyentd@ornl.gov
14 // ***************************************************************************
15 
// pos_tex is the source handed to fetch4() when atom positions are read.
#if !defined(NV_KERNEL) && !defined(USE_HIP)
// Neither NV_KERNEL nor USE_HIP: pos_tex is just the x_ kernel argument.
#define pos_tex x_
#else
#include "lal_aux_fun1.h"
// _texture / _texture_2d come from the library headers; presumably they
// declare the position texture for the selected precision -- confirm there.
#ifdef _DOUBLE_DOUBLE
_texture_2d( pos_tex,int4);
#else
_texture( pos_tex,float4);
#endif
#endif
26 
// Integer codes for the four table styles; the names match the kernel
// sections below (presumably mirrored by the host-side style selection).
#define LOOKUP 0
#define LINEAR 1
#define SPLINE 2
#define BITMAP 3

// Overlay of an int and a float, used by the BITMAP kernels to read the bit
// pattern of a float as an integer. Guarded so that it is defined only once.
#ifndef __UNION_INT_FLOAT
#define __UNION_INT_FLOAT
typedef union {
  int i;    // integer view of the stored bits
  float f;  // floating-point view of the stored bits
} union_int_float;
#endif
39 
40 /// ---------------- LOOKUP -------------------------------------------------
41 
/// LOOKUP table kernel; the type-pair index is itype*lj_types + jtype.
///
/// For every neighbor with rsq < cutsq[mtype] the bin
///   itable = (int)((rsq - coeff2[mtype].x) * coeff2[mtype].y)
/// selects entry idx = itable + tabindex[mtype]*tablength, and the stored
/// force factor / energy of that entry are used without interpolation.
/// Pairs whose bin lies outside [0, tablength-2] contribute nothing.
///
/// \param x_          per-atom numtyp4; .x/.y/.z are used as coordinates and
///                    .w is converted to int as the atom type
/// \param tabindex    table number per type pair
/// \param coeff2      per type pair: .x is subtracted from rsq, .y scales the
///                    difference to a bin index (presumably inner cutoff^2
///                    and inverse bin width -- confirm against host setup)
/// \param coeff3      table entries; .y is the energy, .z the force factor
/// \param coeff4      not referenced by this kernel
/// \param lj_types    row stride of the type-pair arrays
/// \param cutsq       squared cutoff per type pair
/// \param sp_lj_in    four scaling factors, selected per neighbor by sbmask()
/// \param dev_nbor    neighbor data forwarded to nbor_info()
/// \param dev_packed  neighbor entries; sbmask() picks the scaling factor and
///                    NEIGHMASK extracts the atom index
/// \param ans         output buffer handed to store_answers()
/// \param engv        output buffer handed to store_answers()
/// \param eflag       enables energy accumulation (together with EVFLAG)
/// \param vflag       enables virial accumulation (together with EVFLAG)
/// \param inum        threads with ii >= inum skip the neighbor loop
/// \param nbor_pitch  forwarded to nbor_info()
/// \param t_per_atom  forwarded to atom_info(), nbor_info(), store_answers()
/// \param tablength   number of entries in each table
__kernel void k_table(const __global numtyp4 *restrict x_,
                      const __global int *restrict tabindex,
                      const __global numtyp4 *restrict coeff2,
                      const __global numtyp4 *restrict coeff3,
                      const __global numtyp4 *restrict coeff4,
                      const int lj_types,
                      const __global numtyp *restrict cutsq,
                      const __global numtyp *restrict sp_lj_in,
                      const __global int *dev_nbor,
                      const __global int *dev_packed,
                      __global acctyp4 *restrict ans,
                      __global acctyp *restrict engv,
                      const int eflag, const int vflag, const int inum,
                      const int nbor_pitch, const int t_per_atom,
                      int tablength) {
  int tid, ii, offset;
  atom_info(t_per_atom,ii,tid,offset);

  __local numtyp sp_lj[4];
  int n_stride;
  local_allocate_store_pair();

  // Every thread stores the same four scaling factors.
  sp_lj[0]=sp_lj_in[0];
  sp_lj[1]=sp_lj_in[1];
  sp_lj[2]=sp_lj_in[2];
  sp_lj[3]=sp_lj_in[3];

  acctyp4 f;
  f.x=(acctyp)0; f.y=(acctyp)0; f.z=(acctyp)0;
  acctyp energy, virial[6];
  if (EVFLAG) {
    energy=(acctyp)0;
    for (int i=0; i<6; i++) virial[i]=(acctyp)0;
  }

  // Bins are accepted only while itable < tablength-1.
  const int tlm1 = tablength - 1;

  if (ii<inum) {
    int nbor, nbor_end;
    int i, numj;
    nbor_info(dev_nbor,dev_packed,nbor_pitch,t_per_atom,ii,offset,i,numj,
              n_stride,nbor_end,nbor);

    numtyp4 ix; fetch4(ix,i,pos_tex); //x_[i];
    int itype=ix.w;

    for ( ; nbor<nbor_end; nbor+=n_stride) {

      int j=dev_packed[nbor];
      const numtyp factor_lj = sp_lj[sbmask(j)];
      j &= NEIGHMASK;

      numtyp4 jx; fetch4(jx,j,pos_tex); //x_[j];
      int mtype=itype*lj_types+jx.w;
      int tbindex = tabindex[mtype];

      // Squared separation of the pair
      numtyp delx = ix.x-jx.x;
      numtyp dely = ix.y-jx.y;
      numtyp delz = ix.z-jx.z;
      numtyp rsq = delx*delx+dely*dely+delz*delz;

      if (rsq<cutsq[mtype]) {
        int idx = 0;
        numtyp force = (numtyp)0;
        const int itable = (rsq - coeff2[mtype].x) * coeff2[mtype].y;
        // A negative bin (rsq well below the table start) must not be used
        // as an index: it would read in front of this pair's table.
        const int in_table = (itable >= 0 && itable < tlm1);
        if (in_table) {
          idx = itable + tbindex*tablength;
          force = factor_lj * coeff3[idx].z;
        }

        f.x+=delx*force;
        f.y+=dely*force;
        f.z+=delz*force;

        if (EVFLAG && eflag) {
          numtyp e = (numtyp)0;
          if (in_table)
            e = coeff3[idx].y;
          energy+=factor_lj*e;
        }
        if (EVFLAG && vflag) {
          virial[0] += delx*delx*force;
          virial[1] += dely*dely*force;
          virial[2] += delz*delz*force;
          virial[3] += delx*dely*force;
          virial[4] += delx*delz*force;
          virial[5] += dely*delz*force;
        }
      }

    } // for nbor
  } // if ii
  store_answers(f,energy,virial,ii,inum,tid,t_per_atom,offset,eflag,vflag,
                ans,engv);
}
139 
/// LOOKUP table kernel; the type-pair index is MAX_SHARED_TYPES*itype + jtype
/// and the squared cutoffs are staged in __local memory.
///
/// For every neighbor with rsq < cutsq[mtype] the bin
///   itable = (int)((rsq - coeff2[mtype].x) * coeff2[mtype].y)
/// selects entry idx = itable + tabindex[mtype]*tablength, and the stored
/// force factor / energy of that entry are used without interpolation.
/// Pairs whose bin lies outside [0, tablength-2] contribute nothing.
///
/// \param x_          per-atom numtyp4; .x/.y/.z are used as coordinates and
///                    .w is converted to int as the atom type
/// \param tabindex    table number per type pair
/// \param coeff2      per type pair: .x is subtracted from rsq, .y scales the
///                    difference to a bin index (presumably inner cutoff^2
///                    and inverse bin width -- confirm against host setup)
/// \param coeff3      table entries; .y is the energy, .z the force factor
/// \param coeff4      not referenced by this kernel
/// \param cutsq_in    squared cutoffs; the first MAX_SHARED_TYPES^2 entries
///                    are copied to __local memory
/// \param sp_lj_in    four scaling factors, selected per neighbor by sbmask()
/// \param dev_nbor    neighbor data forwarded to nbor_info()
/// \param dev_packed  neighbor entries; sbmask() picks the scaling factor and
///                    NEIGHMASK extracts the atom index
/// \param ans         output buffer handed to store_answers()
/// \param engv        output buffer handed to store_answers()
/// \param eflag       enables energy accumulation (together with EVFLAG)
/// \param vflag       enables virial accumulation (together with EVFLAG)
/// \param inum        threads with ii >= inum skip the neighbor loop
/// \param nbor_pitch  forwarded to nbor_info()
/// \param t_per_atom  forwarded to atom_info(), nbor_info(), store_answers()
/// \param tablength   number of entries in each table
__kernel void k_table_fast(const __global numtyp4 *restrict x_,
                           const __global int *restrict tabindex,
                           const __global numtyp4 *restrict coeff2,
                           const __global numtyp4 *restrict coeff3,
                           const __global numtyp4 *restrict coeff4,
                           const __global numtyp *restrict cutsq_in,
                           const __global numtyp *restrict sp_lj_in,
                           const __global int *dev_nbor,
                           const __global int *dev_packed,
                           __global acctyp4 *restrict ans,
                           __global acctyp *restrict engv,
                           const int eflag, const int vflag, const int inum,
                           const int nbor_pitch, const int t_per_atom,
                           int tablength) {
  int tid, ii, offset;
  atom_info(t_per_atom,ii,tid,offset);

  __local numtyp cutsq[MAX_SHARED_TYPES*MAX_SHARED_TYPES];
  __local numtyp sp_lj[4];
  int n_stride;
  local_allocate_store_pair();

  // One element per thread; assumes the work-group provides at least
  // MAX_SHARED_TYPES^2 threads -- TODO confirm at the launch site.
  if (tid<4)
    sp_lj[tid]=sp_lj_in[tid];
  if (tid<MAX_SHARED_TYPES*MAX_SHARED_TYPES) {
    cutsq[tid]=cutsq_in[tid];
  }

  acctyp4 f;
  f.x=(acctyp)0; f.y=(acctyp)0; f.z=(acctyp)0;
  acctyp energy, virial[6];
  if (EVFLAG) {
    energy=(acctyp)0;
    for (int i=0; i<6; i++) virial[i]=(acctyp)0;
  }

  // The staged values must be complete before any thread reads them.
  __syncthreads();

  // Bins are accepted only while itable < tablength-1.
  const int tlm1 = tablength - 1;

  if (ii<inum) {
    int nbor, nbor_end;
    int i, numj;
    nbor_info(dev_nbor,dev_packed,nbor_pitch,t_per_atom,ii,offset,i,numj,
              n_stride,nbor_end,nbor);

    numtyp4 ix; fetch4(ix,i,pos_tex); //x_[i];
    int iw=ix.w;
    int itype=fast_mul((int)MAX_SHARED_TYPES,iw);

    for ( ; nbor<nbor_end; nbor+=n_stride) {

      int j=dev_packed[nbor];
      const numtyp factor_lj = sp_lj[sbmask(j)];
      j &= NEIGHMASK;

      numtyp4 jx; fetch4(jx,j,pos_tex); //x_[j];
      int mtype=itype+jx.w;
      int tbindex = tabindex[mtype];

      // Squared separation of the pair
      numtyp delx = ix.x-jx.x;
      numtyp dely = ix.y-jx.y;
      numtyp delz = ix.z-jx.z;
      numtyp rsq = delx*delx+dely*dely+delz*delz;

      if (rsq<cutsq[mtype]) {
        int idx = 0;
        numtyp force = (numtyp)0;
        const int itable = (rsq - coeff2[mtype].x) * coeff2[mtype].y;
        // A negative bin (rsq well below the table start) must not be used
        // as an index: it would read in front of this pair's table.
        const int in_table = (itable >= 0 && itable < tlm1);
        if (in_table) {
          idx = itable + tbindex*tablength;
          force = factor_lj * coeff3[idx].z;
        }

        f.x+=delx*force;
        f.y+=dely*force;
        f.z+=delz*force;

        if (EVFLAG && eflag) {
          numtyp e = (numtyp)0;
          if (in_table)
            e = coeff3[idx].y;
          energy+=factor_lj*e;
        }
        if (EVFLAG && vflag) {
          virial[0] += delx*delx*force;
          virial[1] += dely*dely*force;
          virial[2] += delz*delz*force;
          virial[3] += delx*dely*force;
          virial[4] += delx*delz*force;
          virial[5] += dely*delz*force;
        }
      }

    } // for nbor
  } // if ii
  store_answers(f,energy,virial,ii,inum,tid,t_per_atom,offset,eflag,vflag,
                ans,engv);
}
241 
242 /// ---------------- LINEAR -------------------------------------------------
243 
/// LINEAR table kernel; the type-pair index is itype*lj_types + jtype.
///
/// For every neighbor with rsq < cutsq[mtype] the bin
///   itable = (int)((rsq - coeff2[mtype].x) * coeff2[mtype].y)
/// selects entry idx = itable + tabindex[mtype]*tablength. Force factor and
/// energy are coeff3[idx] + fraction*coeff4[idx], with
///   fraction = (rsq - coeff3[idx].x) * coeff2[mtype].y.
/// Pairs whose bin lies outside [0, tablength-2] contribute nothing.
///
/// \param x_          per-atom numtyp4; .x/.y/.z are used as coordinates and
///                    .w is converted to int as the atom type
/// \param tabindex    table number per type pair
/// \param coeff2      per type pair: .x is subtracted from rsq, .y scales the
///                    difference to a bin index (presumably inner cutoff^2
///                    and inverse bin width -- confirm against host setup)
/// \param coeff3      table entries; .x is the rsq origin of the bin, .y the
///                    energy, .z the force factor
/// \param coeff4      per-entry increments; .y for energy, .z for force
/// \param lj_types    row stride of the type-pair arrays
/// \param cutsq       squared cutoff per type pair
/// \param sp_lj_in    four scaling factors, selected per neighbor by sbmask()
/// \param dev_nbor    neighbor data forwarded to nbor_info()
/// \param dev_packed  neighbor entries; sbmask() picks the scaling factor and
///                    NEIGHMASK extracts the atom index
/// \param ans         output buffer handed to store_answers()
/// \param engv        output buffer handed to store_answers()
/// \param eflag       enables energy accumulation (together with EVFLAG)
/// \param vflag       enables virial accumulation (together with EVFLAG)
/// \param inum        threads with ii >= inum skip the neighbor loop
/// \param nbor_pitch  forwarded to nbor_info()
/// \param t_per_atom  forwarded to atom_info(), nbor_info(), store_answers()
/// \param tablength   number of entries in each table
__kernel void k_table_linear(const __global numtyp4 *restrict x_,
                             const __global int *restrict tabindex,
                             const __global numtyp4 *restrict coeff2,
                             const __global numtyp4 *restrict coeff3,
                             const __global numtyp4 *restrict coeff4,
                             const int lj_types,
                             const __global numtyp *restrict cutsq,
                             const __global numtyp *restrict sp_lj_in,
                             const __global int *dev_nbor,
                             const __global int *dev_packed,
                             __global acctyp4 *restrict ans,
                             __global acctyp *restrict engv,
                             const int eflag, const int vflag, const int inum,
                             const int nbor_pitch, const int t_per_atom,
                             int tablength) {
  int tid, ii, offset;
  atom_info(t_per_atom,ii,tid,offset);

  __local numtyp sp_lj[4];
  int n_stride;
  local_allocate_store_pair();

  // Every thread stores the same four scaling factors.
  sp_lj[0]=sp_lj_in[0];
  sp_lj[1]=sp_lj_in[1];
  sp_lj[2]=sp_lj_in[2];
  sp_lj[3]=sp_lj_in[3];

  acctyp4 f;
  f.x=(acctyp)0; f.y=(acctyp)0; f.z=(acctyp)0;
  acctyp energy, virial[6];
  if (EVFLAG) {
    energy=(acctyp)0;
    for (int i=0; i<6; i++) virial[i]=(acctyp)0;
  }

  // Bins are accepted only while itable < tablength-1.
  const int tlm1 = tablength - 1;

  if (ii<inum) {
    int nbor, nbor_end;
    int i, numj;
    nbor_info(dev_nbor,dev_packed,nbor_pitch,t_per_atom,ii,offset,i,numj,
              n_stride,nbor_end,nbor);

    numtyp4 ix; fetch4(ix,i,pos_tex); //x_[i];
    int itype=ix.w;

    for ( ; nbor<nbor_end; nbor+=n_stride) {

      int j=dev_packed[nbor];
      const numtyp factor_lj = sp_lj[sbmask(j)];
      j &= NEIGHMASK;

      numtyp4 jx; fetch4(jx,j,pos_tex); //x_[j];
      int mtype=itype*lj_types+jx.w;
      int tbindex = tabindex[mtype];

      // Squared separation of the pair
      numtyp delx = ix.x-jx.x;
      numtyp dely = ix.y-jx.y;
      numtyp delz = ix.z-jx.z;
      numtyp rsq = delx*delx+dely*dely+delz*delz;

      if (rsq<cutsq[mtype]) {
        int idx = 0;
        numtyp fraction = (numtyp)0;
        numtyp force = (numtyp)0;
        const int itable = (rsq - coeff2[mtype].x) * coeff2[mtype].y;
        // A negative bin (rsq well below the table start) must not be used
        // as an index: it would read in front of this pair's table.
        const int in_table = (itable >= 0 && itable < tlm1);
        if (in_table) {
          idx = itable + tbindex*tablength;
          fraction = (rsq - coeff3[idx].x) * coeff2[mtype].y;
          const numtyp value = coeff3[idx].z + fraction*coeff4[idx].z;
          force = factor_lj * value;
        }

        f.x+=delx*force;
        f.y+=dely*force;
        f.z+=delz*force;

        if (EVFLAG && eflag) {
          numtyp e = (numtyp)0;
          if (in_table)
            e = coeff3[idx].y + fraction*coeff4[idx].y;
          energy+=factor_lj*e;
        }
        if (EVFLAG && vflag) {
          virial[0] += delx*delx*force;
          virial[1] += dely*dely*force;
          virial[2] += delz*delz*force;
          virial[3] += delx*dely*force;
          virial[4] += delx*delz*force;
          virial[5] += dely*delz*force;
        }
      }

    } // for nbor
  } // if ii
  store_answers(f,energy,virial,ii,inum,tid,t_per_atom,offset,eflag,vflag,
                ans,engv);
}
345 
/// LINEAR table kernel; the type-pair index is MAX_SHARED_TYPES*itype + jtype
/// and the squared cutoffs are staged in __local memory.
///
/// For every neighbor with rsq < cutsq[mtype] the bin
///   itable = (int)((rsq - coeff2[mtype].x) * coeff2[mtype].y)
/// selects entry idx = itable + tabindex[mtype]*tablength. Force factor and
/// energy are coeff3[idx] + fraction*coeff4[idx], with
///   fraction = (rsq - coeff3[idx].x) * coeff2[mtype].y.
/// Pairs whose bin lies outside [0, tablength-2] contribute nothing.
///
/// \param x_          per-atom numtyp4; .x/.y/.z are used as coordinates and
///                    .w is converted to int as the atom type
/// \param tabindex    table number per type pair
/// \param coeff2      per type pair: .x is subtracted from rsq, .y scales the
///                    difference to a bin index (presumably inner cutoff^2
///                    and inverse bin width -- confirm against host setup)
/// \param coeff3      table entries; .x is the rsq origin of the bin, .y the
///                    energy, .z the force factor
/// \param coeff4      per-entry increments; .y for energy, .z for force
/// \param cutsq_in    squared cutoffs; the first MAX_SHARED_TYPES^2 entries
///                    are copied to __local memory
/// \param sp_lj_in    four scaling factors, selected per neighbor by sbmask()
/// \param dev_nbor    neighbor data forwarded to nbor_info()
/// \param dev_packed  neighbor entries; sbmask() picks the scaling factor and
///                    NEIGHMASK extracts the atom index
/// \param ans         output buffer handed to store_answers()
/// \param engv        output buffer handed to store_answers()
/// \param eflag       enables energy accumulation (together with EVFLAG)
/// \param vflag       enables virial accumulation (together with EVFLAG)
/// \param inum        threads with ii >= inum skip the neighbor loop
/// \param nbor_pitch  forwarded to nbor_info()
/// \param t_per_atom  forwarded to atom_info(), nbor_info(), store_answers()
/// \param tablength   number of entries in each table
__kernel void k_table_linear_fast(const __global numtyp4 *restrict x_,
                                  const __global int *restrict tabindex,
                                  const __global numtyp4 *restrict coeff2,
                                  const __global numtyp4 *restrict coeff3,
                                  const __global numtyp4 *restrict coeff4,
                                  const __global numtyp *restrict cutsq_in,
                                  const __global numtyp *restrict sp_lj_in,
                                  const __global int *dev_nbor,
                                  const __global int *dev_packed,
                                  __global acctyp4 *restrict ans,
                                  __global acctyp *restrict engv,
                                  const int eflag, const int vflag,
                                  const int inum, const int nbor_pitch,
                                  const int t_per_atom, int tablength) {
  int tid, ii, offset;
  atom_info(t_per_atom,ii,tid,offset);

  __local numtyp cutsq[MAX_SHARED_TYPES*MAX_SHARED_TYPES];
  __local numtyp sp_lj[4];
  int n_stride;
  local_allocate_store_pair();

  // One element per thread; assumes the work-group provides at least
  // MAX_SHARED_TYPES^2 threads -- TODO confirm at the launch site.
  if (tid<4)
    sp_lj[tid]=sp_lj_in[tid];
  if (tid<MAX_SHARED_TYPES*MAX_SHARED_TYPES) {
    cutsq[tid]=cutsq_in[tid];
  }

  acctyp4 f;
  f.x=(acctyp)0; f.y=(acctyp)0; f.z=(acctyp)0;
  acctyp energy, virial[6];
  if (EVFLAG) {
    energy=(acctyp)0;
    for (int i=0; i<6; i++) virial[i]=(acctyp)0;
  }

  // The staged values must be complete before any thread reads them.
  __syncthreads();

  // Bins are accepted only while itable < tablength-1.
  const int tlm1 = tablength - 1;

  if (ii<inum) {
    int nbor, nbor_end;
    int i, numj;
    nbor_info(dev_nbor,dev_packed,nbor_pitch,t_per_atom,ii,offset,i,numj,
              n_stride,nbor_end,nbor);

    numtyp4 ix; fetch4(ix,i,pos_tex); //x_[i];
    int iw=ix.w;
    int itype=fast_mul((int)MAX_SHARED_TYPES,iw);

    for ( ; nbor<nbor_end; nbor+=n_stride) {

      int j=dev_packed[nbor];
      const numtyp factor_lj = sp_lj[sbmask(j)];
      j &= NEIGHMASK;

      numtyp4 jx; fetch4(jx,j,pos_tex); //x_[j];
      int mtype=itype+jx.w;
      int tbindex = tabindex[mtype];

      // Squared separation of the pair
      numtyp delx = ix.x-jx.x;
      numtyp dely = ix.y-jx.y;
      numtyp delz = ix.z-jx.z;
      numtyp rsq = delx*delx+dely*dely+delz*delz;

      if (rsq<cutsq[mtype]) {
        int idx = 0;
        numtyp fraction = (numtyp)0;
        numtyp force = (numtyp)0;
        const int itable = (rsq - coeff2[mtype].x) * coeff2[mtype].y;
        // A negative bin (rsq well below the table start) must not be used
        // as an index: it would read in front of this pair's table.
        const int in_table = (itable >= 0 && itable < tlm1);
        if (in_table) {
          idx = itable + tbindex*tablength;
          fraction = (rsq - coeff3[idx].x) * coeff2[mtype].y;
          const numtyp value = coeff3[idx].z + fraction*coeff4[idx].z;
          force = factor_lj * value;
        }

        f.x+=delx*force;
        f.y+=dely*force;
        f.z+=delz*force;

        if (EVFLAG && eflag) {
          numtyp e = (numtyp)0;
          if (in_table)
            e = coeff3[idx].y + fraction*coeff4[idx].y;
          energy+=factor_lj*e;
        }
        if (EVFLAG && vflag) {
          virial[0] += delx*delx*force;
          virial[1] += dely*dely*force;
          virial[2] += delz*delz*force;
          virial[3] += delx*dely*force;
          virial[4] += delx*delz*force;
          virial[5] += dely*delz*force;
        }
      }

    } // for nbor
  } // if ii
  store_answers(f,energy,virial,ii,inum,tid,t_per_atom,offset,eflag,vflag,
                ans,engv);
}
451 
452 /// ---------------- SPLINE -------------------------------------------------
453 
/// SPLINE table kernel; the type-pair index is itype*lj_types + jtype.
///
/// For every neighbor with rsq < cutsq[mtype] the bin
///   itable = (int)((rsq - coeff2[mtype].x) * coeff2[mtype].y)
/// selects entry idx = itable + tabindex[mtype]*tablength. With
///   b = (rsq - coeff3[idx].x) * coeff2[mtype].y  and  a = 1 - b
/// force factor and energy are evaluated from entries idx and idx+1 as
///   a*c3[idx] + b*c3[idx+1]
///     + ((a^3-a)*c4[idx] + (b^3-b)*c4[idx+1]) * coeff2[mtype].z.
/// Pairs whose bin lies outside [0, tablength-2] contribute nothing.
///
/// \param x_          per-atom numtyp4; .x/.y/.z are used as coordinates and
///                    .w is converted to int as the atom type
/// \param tabindex    table number per type pair
/// \param coeff2      per type pair: .x is subtracted from rsq, .y scales the
///                    difference to a bin index, .z scales the cubic term
///                    (exact meaning set by the host -- confirm there)
/// \param coeff3      table entries; .x is the rsq origin of the bin, .y the
///                    energy, .z the force factor
/// \param coeff4      per-entry cubic-term coefficients; .y for energy,
///                    .z for force
/// \param lj_types    row stride of the type-pair arrays
/// \param cutsq       squared cutoff per type pair
/// \param sp_lj_in    four scaling factors, selected per neighbor by sbmask()
/// \param dev_nbor    neighbor data forwarded to nbor_info()
/// \param dev_packed  neighbor entries; sbmask() picks the scaling factor and
///                    NEIGHMASK extracts the atom index
/// \param ans         output buffer handed to store_answers()
/// \param engv        output buffer handed to store_answers()
/// \param eflag       enables energy accumulation (together with EVFLAG)
/// \param vflag       enables virial accumulation (together with EVFLAG)
/// \param inum        threads with ii >= inum skip the neighbor loop
/// \param nbor_pitch  forwarded to nbor_info()
/// \param t_per_atom  forwarded to atom_info(), nbor_info(), store_answers()
/// \param tablength   number of entries in each table
__kernel void k_table_spline(const __global numtyp4 *restrict x_,
                             const __global int *restrict tabindex,
                             const __global numtyp4 *restrict coeff2,
                             const __global numtyp4 *restrict coeff3,
                             const __global numtyp4 *restrict coeff4,
                             const int lj_types,
                             const __global numtyp *restrict cutsq,
                             const __global numtyp *restrict sp_lj_in,
                             const __global int *dev_nbor,
                             const __global int *dev_packed,
                             __global acctyp4 *restrict ans,
                             __global acctyp *restrict engv,
                             const int eflag, const int vflag, const int inum,
                             const int nbor_pitch, const int t_per_atom,
                             int tablength) {
  int tid, ii, offset;
  atom_info(t_per_atom,ii,tid,offset);

  __local numtyp sp_lj[4];
  int n_stride;
  local_allocate_store_pair();

  // Every thread stores the same four scaling factors.
  sp_lj[0]=sp_lj_in[0];
  sp_lj[1]=sp_lj_in[1];
  sp_lj[2]=sp_lj_in[2];
  sp_lj[3]=sp_lj_in[3];

  acctyp4 f;
  f.x=(acctyp)0; f.y=(acctyp)0; f.z=(acctyp)0;
  acctyp energy, virial[6];
  if (EVFLAG) {
    energy=(acctyp)0;
    for (int i=0; i<6; i++) virial[i]=(acctyp)0;
  }

  // Bins are accepted only while itable < tablength-1, which also keeps
  // entry idx+1 inside the same table.
  const int tlm1 = tablength - 1;

  if (ii<inum) {
    int nbor, nbor_end;
    int i, numj;
    nbor_info(dev_nbor,dev_packed,nbor_pitch,t_per_atom,ii,offset,i,numj,
              n_stride,nbor_end,nbor);

    numtyp4 ix; fetch4(ix,i,pos_tex); //x_[i];
    int itype=ix.w;

    for ( ; nbor<nbor_end; nbor+=n_stride) {

      int j=dev_packed[nbor];
      const numtyp factor_lj = sp_lj[sbmask(j)];
      j &= NEIGHMASK;

      numtyp4 jx; fetch4(jx,j,pos_tex); //x_[j];
      int mtype=itype*lj_types+jx.w;
      int tbindex = tabindex[mtype];

      // Squared separation of the pair
      numtyp delx = ix.x-jx.x;
      numtyp dely = ix.y-jx.y;
      numtyp delz = ix.z-jx.z;
      numtyp rsq = delx*delx+dely*dely+delz*delz;

      if (rsq<cutsq[mtype]) {
        int idx = 0;
        numtyp a = (numtyp)0;
        numtyp b = (numtyp)0;
        numtyp force = (numtyp)0;
        const int itable = (rsq - coeff2[mtype].x) * coeff2[mtype].y;
        // A negative bin (rsq well below the table start) must not be used
        // as an index: it would read in front of this pair's table.
        const int in_table = (itable >= 0 && itable < tlm1);
        if (in_table) {
          idx = itable + tbindex*tablength;
          b = (rsq - coeff3[idx].x) * coeff2[mtype].y;
          a = (numtyp)1.0 - b;
          const numtyp value = a * coeff3[idx].z + b * coeff3[idx+1].z +
            ((a*a*a-a)*coeff4[idx].z + (b*b*b-b)*coeff4[idx+1].z) *
                  coeff2[mtype].z;
          force = factor_lj * value;
        }

        f.x+=delx*force;
        f.y+=dely*force;
        f.z+=delz*force;

        if (EVFLAG && eflag) {
          numtyp e = (numtyp)0;
          if (in_table) {
            e = a * coeff3[idx].y + b * coeff3[idx+1].y +
                ((a*a*a-a)*coeff4[idx].y + (b*b*b-b)*coeff4[idx+1].y) *
                  coeff2[mtype].z;
          }
          energy+=factor_lj*e;
        }
        if (EVFLAG && vflag) {
          virial[0] += delx*delx*force;
          virial[1] += dely*dely*force;
          virial[2] += delz*delz*force;
          virial[3] += delx*dely*force;
          virial[4] += delx*delz*force;
          virial[5] += dely*delz*force;
        }
      }

    } // for nbor
  } // if ii
  store_answers(f,energy,virial,ii,inum,tid,t_per_atom,offset,eflag,vflag,
                ans,engv);
}
562 
/// SPLINE table kernel; the type-pair index is MAX_SHARED_TYPES*itype + jtype
/// and the squared cutoffs are staged in __local memory.
///
/// For every neighbor with rsq < cutsq[mtype] the bin
///   itable = (int)((rsq - coeff2[mtype].x) * coeff2[mtype].y)
/// selects entry idx = itable + tabindex[mtype]*tablength. With
///   b = (rsq - coeff3[idx].x) * coeff2[mtype].y  and  a = 1 - b
/// force factor and energy are evaluated from entries idx and idx+1 as
///   a*c3[idx] + b*c3[idx+1]
///     + ((a^3-a)*c4[idx] + (b^3-b)*c4[idx+1]) * coeff2[mtype].z.
/// Pairs whose bin lies outside [0, tablength-2] contribute nothing.
///
/// Pointer arguments are restrict-qualified in the same way as in
/// k_table_spline; dev_nbor and dev_packed are left unqualified there too.
///
/// \param x_          per-atom numtyp4; .x/.y/.z are used as coordinates and
///                    .w is converted to int as the atom type
/// \param tabindex    table number per type pair
/// \param coeff2      per type pair: .x is subtracted from rsq, .y scales the
///                    difference to a bin index, .z scales the cubic term
///                    (exact meaning set by the host -- confirm there)
/// \param coeff3      table entries; .x is the rsq origin of the bin, .y the
///                    energy, .z the force factor
/// \param coeff4      per-entry cubic-term coefficients; .y for energy,
///                    .z for force
/// \param cutsq_in    squared cutoffs; the first MAX_SHARED_TYPES^2 entries
///                    are copied to __local memory
/// \param sp_lj_in    four scaling factors, selected per neighbor by sbmask()
/// \param dev_nbor    neighbor data forwarded to nbor_info()
/// \param dev_packed  neighbor entries; sbmask() picks the scaling factor and
///                    NEIGHMASK extracts the atom index
/// \param ans         output buffer handed to store_answers()
/// \param engv        output buffer handed to store_answers()
/// \param eflag       enables energy accumulation (together with EVFLAG)
/// \param vflag       enables virial accumulation (together with EVFLAG)
/// \param inum        threads with ii >= inum skip the neighbor loop
/// \param nbor_pitch  forwarded to nbor_info()
/// \param t_per_atom  forwarded to atom_info(), nbor_info(), store_answers()
/// \param tablength   number of entries in each table
__kernel void k_table_spline_fast(const __global numtyp4 *restrict x_,
                                  const __global int *restrict tabindex,
                                  const __global numtyp4 *restrict coeff2,
                                  const __global numtyp4 *restrict coeff3,
                                  const __global numtyp4 *restrict coeff4,
                                  const __global numtyp *restrict cutsq_in,
                                  const __global numtyp *restrict sp_lj_in,
                                  const __global int *dev_nbor,
                                  const __global int *dev_packed,
                                  __global acctyp4 *restrict ans,
                                  __global acctyp *restrict engv,
                                  const int eflag, const int vflag,
                                  const int inum, const int nbor_pitch,
                                  const int t_per_atom, int tablength) {
  int tid, ii, offset;
  atom_info(t_per_atom,ii,tid,offset);

  __local numtyp cutsq[MAX_SHARED_TYPES*MAX_SHARED_TYPES];
  __local numtyp sp_lj[4];
  int n_stride;
  local_allocate_store_pair();

  // One element per thread; assumes the work-group provides at least
  // MAX_SHARED_TYPES^2 threads -- TODO confirm at the launch site.
  if (tid<4)
    sp_lj[tid]=sp_lj_in[tid];
  if (tid<MAX_SHARED_TYPES*MAX_SHARED_TYPES) {
    cutsq[tid]=cutsq_in[tid];
  }

  acctyp4 f;
  f.x=(acctyp)0; f.y=(acctyp)0; f.z=(acctyp)0;
  acctyp energy, virial[6];
  if (EVFLAG) {
    energy=(acctyp)0;
    for (int i=0; i<6; i++) virial[i]=(acctyp)0;
  }

  // The staged values must be complete before any thread reads them.
  __syncthreads();

  // Bins are accepted only while itable < tablength-1, which also keeps
  // entry idx+1 inside the same table.
  const int tlm1 = tablength - 1;

  if (ii<inum) {
    int nbor, nbor_end;
    int i, numj;
    nbor_info(dev_nbor,dev_packed,nbor_pitch,t_per_atom,ii,offset,i,numj,
              n_stride,nbor_end,nbor);

    numtyp4 ix; fetch4(ix,i,pos_tex); //x_[i];
    int iw=ix.w;
    int itype=fast_mul((int)MAX_SHARED_TYPES,iw);

    for ( ; nbor<nbor_end; nbor+=n_stride) {

      int j=dev_packed[nbor];
      const numtyp factor_lj = sp_lj[sbmask(j)];
      j &= NEIGHMASK;

      numtyp4 jx; fetch4(jx,j,pos_tex); //x_[j];
      int mtype=itype+jx.w;
      int tbindex = tabindex[mtype];

      // Squared separation of the pair
      numtyp delx = ix.x-jx.x;
      numtyp dely = ix.y-jx.y;
      numtyp delz = ix.z-jx.z;
      numtyp rsq = delx*delx+dely*dely+delz*delz;

      if (rsq<cutsq[mtype]) {
        int idx = 0;
        numtyp a = (numtyp)0;
        numtyp b = (numtyp)0;
        numtyp force = (numtyp)0;
        const int itable = (rsq - coeff2[mtype].x) * coeff2[mtype].y;
        // A negative bin (rsq well below the table start) must not be used
        // as an index: it would read in front of this pair's table.
        const int in_table = (itable >= 0 && itable < tlm1);
        if (in_table) {
          idx = itable + tbindex*tablength;
          b = (rsq - coeff3[idx].x) * coeff2[mtype].y;
          a = (numtyp)1.0 - b;
          const numtyp value = a * coeff3[idx].z + b * coeff3[idx+1].z +
            ((a*a*a-a)*coeff4[idx].z + (b*b*b-b)*coeff4[idx+1].z) *
                  coeff2[mtype].z;
          force = factor_lj * value;
        }

        f.x+=delx*force;
        f.y+=dely*force;
        f.z+=delz*force;

        if (EVFLAG && eflag) {
          numtyp e = (numtyp)0;
          if (in_table) {
            e = a * coeff3[idx].y + b * coeff3[idx+1].y +
                ((a*a*a-a)*coeff4[idx].y + (b*b*b-b)*coeff4[idx+1].y) *
                  coeff2[mtype].z;
          }
          energy+=factor_lj*e;
        }
        if (EVFLAG && vflag) {
          virial[0] += delx*delx*force;
          virial[1] += dely*dely*force;
          virial[2] += delz*delz*force;
          virial[3] += delx*dely*force;
          virial[4] += delx*delz*force;
          virial[5] += dely*delz*force;
        }
      }

    } // for nbor
  } // if ii
  store_answers(f,energy,virial,ii,inum,tid,t_per_atom,offset,eflag,vflag,
                ans,engv);
}
674 
675 /// ---------------- BITMAP -------------------------------------------------
676 
/// BITMAP table kernel; the type-pair index is itype*lj_types + jtype.
///
/// For every neighbor with rsq < cutsq[mtype], rsq is stored as a float and
/// its bit pattern is turned into a bin:
///   itable = (bits(rsq) & nmask[mtype]) >> nshiftbits[mtype]
/// Entry idx = itable + tabindex[mtype]*tablength is then interpolated as
/// coeff3[idx] + fraction*coeff4[idx], with
///   fraction = (rsq - coeff3[idx].x) * coeff4[idx].w.
/// Pairs with itable > tablength-1 contribute nothing.
///
/// Pointer arguments are restrict-qualified in the same way as in the
/// LOOKUP/LINEAR/SPLINE kernels; dev_nbor and dev_packed are left
/// unqualified there too.
///
/// \param x_          per-atom numtyp4; .x/.y/.z are used as coordinates and
///                    .w is converted to int as the atom type
/// \param tabindex    table number per type pair
/// \param nshiftbits  right shift applied to the masked bits, per type pair
/// \param nmask       bit mask applied to the float bits, per type pair
/// \param coeff2      not referenced by this kernel
/// \param coeff3      table entries; .x is the rsq origin of the bin, .y the
///                    energy, .z the force factor
/// \param coeff4      per-entry increments; .y for energy, .z for force,
///                    .w scales (rsq - origin) to the fraction
/// \param lj_types    row stride of the type-pair arrays
/// \param cutsq       squared cutoff per type pair
/// \param sp_lj_in    four scaling factors, selected per neighbor by sbmask()
/// \param dev_nbor    neighbor data forwarded to nbor_info()
/// \param dev_packed  neighbor entries; sbmask() picks the scaling factor and
///                    NEIGHMASK extracts the atom index
/// \param ans         output buffer handed to store_answers()
/// \param engv        output buffer handed to store_answers()
/// \param eflag       enables energy accumulation (together with EVFLAG)
/// \param vflag       enables virial accumulation (together with EVFLAG)
/// \param inum        threads with ii >= inum skip the neighbor loop
/// \param nbor_pitch  forwarded to nbor_info()
/// \param t_per_atom  forwarded to atom_info(), nbor_info(), store_answers()
/// \param tablength   number of entries in each table
__kernel void k_table_bitmap(const __global numtyp4 *restrict x_,
                             const __global int *restrict tabindex,
                             const __global int *restrict nshiftbits,
                             const __global int *restrict nmask,
                             const __global numtyp4 *restrict coeff2,
                             const __global numtyp4 *restrict coeff3,
                             const __global numtyp4 *restrict coeff4,
                             const int lj_types,
                             const __global numtyp *restrict cutsq,
                             const __global numtyp *restrict sp_lj_in,
                             const __global int *dev_nbor,
                             const __global int *dev_packed,
                             __global acctyp4 *restrict ans,
                             __global acctyp *restrict engv,
                             const int eflag, const int vflag, const int inum,
                             const int nbor_pitch, const int t_per_atom,
                             int tablength) {
  int tid, ii, offset;
  atom_info(t_per_atom,ii,tid,offset);

  __local numtyp sp_lj[4];
  int n_stride;
  local_allocate_store_pair();

  // Every thread stores the same four scaling factors.
  sp_lj[0]=sp_lj_in[0];
  sp_lj[1]=sp_lj_in[1];
  sp_lj[2]=sp_lj_in[2];
  sp_lj[3]=sp_lj_in[3];

  acctyp4 f;
  f.x=(acctyp)0; f.y=(acctyp)0; f.z=(acctyp)0;
  acctyp energy, virial[6];
  if (EVFLAG) {
    energy=(acctyp)0;
    for (int i=0; i<6; i++) virial[i]=(acctyp)0;
  }

  // Unlike the other styles, the last bin (itable == tablength-1) is used.
  const int tlm1 = tablength - 1;

  if (ii<inum) {
    int nbor, nbor_end;
    int i, numj;
    nbor_info(dev_nbor,dev_packed,nbor_pitch,t_per_atom,ii,offset,i,numj,
              n_stride,nbor_end,nbor);

    numtyp4 ix; fetch4(ix,i,pos_tex); //x_[i];
    int itype=ix.w;

    for ( ; nbor<nbor_end; nbor+=n_stride) {

      int j=dev_packed[nbor];
      const numtyp factor_lj = sp_lj[sbmask(j)];
      j &= NEIGHMASK;

      numtyp4 jx; fetch4(jx,j,pos_tex); //x_[j];
      int mtype=itype*lj_types+jx.w;
      int tbindex = tabindex[mtype];

      // Squared separation of the pair
      numtyp delx = ix.x-jx.x;
      numtyp dely = ix.y-jx.y;
      numtyp delz = ix.z-jx.z;
      numtyp rsq = delx*delx+dely*dely+delz*delz;

      if (rsq<cutsq[mtype]) {
        int idx = 0;
        numtyp fraction = (numtyp)0;
        numtyp force = (numtyp)0;

        // Store rsq as a float (narrowing if numtyp is wider) and read the
        // same bits back as an int to derive the bin.
        union_int_float rsq_lookup;
        rsq_lookup.f = rsq;
        int itable = rsq_lookup.i & nmask[mtype];
        itable >>= nshiftbits[mtype];

        const int in_table = (itable <= tlm1);
        if (in_table) {
          idx = itable + tbindex*tablength;
          fraction = (rsq_lookup.f - coeff3[idx].x) * coeff4[idx].w;
          const numtyp value = coeff3[idx].z + fraction*coeff4[idx].z;
          force = factor_lj * value;
        }

        f.x+=delx*force;
        f.y+=dely*force;
        f.z+=delz*force;

        if (EVFLAG && eflag) {
          numtyp e = (numtyp)0;
          if (in_table)
            e = coeff3[idx].y + fraction*coeff4[idx].y;
          energy+=factor_lj*e;
        }
        if (EVFLAG && vflag) {
          virial[0] += delx*delx*force;
          virial[1] += dely*dely*force;
          virial[2] += delz*delz*force;
          virial[3] += delx*dely*force;
          virial[4] += delx*delz*force;
          virial[5] += dely*delz*force;
        }
      }

    } // for nbor
  } // if ii
  store_answers(f,energy,virial,ii,inum,tid,t_per_atom,offset,eflag,vflag,
                ans,engv);
}
783 
// Pair kernel for the BITMAP table style that first copies the cutoff and
// special-bond factor tables into __local arrays and then reads them from
// there inside the neighbor loop.
//
//   x_          positions; the .w component is used as the atom type
//   tabindex    table number for each (itype,jtype) slot
//   nshiftbits  per type-pair right shift applied to the masked rsq bits
//   nmask       per type-pair mask applied to the float bit pattern of rsq
//   coeff2      not referenced in this kernel body
//   coeff3      table entries: .x is subtracted from rsq, .y / .z are the
//               energy / force base values
//   coeff4      table entries: .w scales the rsq offset, .y / .z are the
//               energy / force increments
//   cutsq_in    squared cutoffs, MAX_SHARED_TYPES*MAX_SHARED_TYPES values read
//   sp_lj_in    four scaling factors selected by sbmask(j)
//   dev_nbor, dev_packed   neighbor list storage (decoded by nbor_info)
//   ans, engv   outputs written by store_answers
//   eflag, vflag           enable energy / virial accumulation (with EVFLAG)
//   inum        threads with ii >= inum skip the neighbor loop
//   tablength   number of entries per table
//
// NOTE(review): the fraction/base/increment arithmetic looks like linear
// interpolation between tabulated points -- confirm against the host-side
// table construction.
__kernel void k_table_bitmap_fast(const __global numtyp4 *x_,
                                  const __global int *tabindex,
                                  const __global int *nshiftbits,
                                  const __global int *nmask,
                                  const __global numtyp4* coeff2,
                                  const __global numtyp4 *coeff3,
                                  const __global numtyp4 *coeff4,
                                  const __global numtyp *cutsq_in,
                                  const __global numtyp* sp_lj_in,
                                  const __global int *dev_nbor,
                                  const __global int *dev_packed,
                                  __global acctyp4 *ans,
                                  __global acctyp *engv,
                                  const int eflag, const int vflag,
                                  const int inum, const int nbor_pitch,
                                  const int t_per_atom, int tablength) {
  int tid, ii, offset;
  atom_info(t_per_atom,ii,tid,offset);

  __local numtyp cutsq[MAX_SHARED_TYPES*MAX_SHARED_TYPES];
  __local numtyp sp_lj[4];
  int n_stride;
  local_allocate_store_pair();

  // Each thread with a small enough tid copies one entry of each table.
  if (tid<MAX_SHARED_TYPES*MAX_SHARED_TYPES)
    cutsq[tid]=cutsq_in[tid];
  if (tid<4)
    sp_lj[tid]=sp_lj_in[tid];

  // Per-thread accumulators; energy/virial are only zeroed when EVFLAG is set.
  acctyp4 f;
  f.x=(acctyp)0; f.y=(acctyp)0; f.z=(acctyp)0;
  acctyp energy, virial[6];
  if (EVFLAG) {
    energy=(acctyp)0;
    for (int k=0; k<6; k++) virial[k]=(acctyp)0;
  }

  // The __local tables must be completely written before anyone reads them.
  __syncthreads();

  // Largest table slot that is accepted below.
  const int last_slot = tablength - 1;

  if (ii<inum) {
    int nbor, nbor_end;
    int i, numj;
    nbor_info(dev_nbor,dev_packed,nbor_pitch,t_per_atom,ii,offset,i,numj,
              n_stride,nbor_end,nbor);

    numtyp4 ix; fetch4(ix,i,pos_tex); //x_[i];
    int iw=ix.w;
    const int type_row=fast_mul((int)MAX_SHARED_TYPES,iw);

    for ( ; nbor<nbor_end; nbor+=n_stride) {

      // The packed entry carries both the sp_lj selector and the atom index.
      int j=dev_packed[nbor];
      const numtyp factor_lj = sp_lj[sbmask(j)];
      j &= NEIGHMASK;

      numtyp4 jx; fetch4(jx,j,pos_tex); //x_[j];
      const int mtype=type_row+jx.w;
      const int tbindex = tabindex[mtype];

      // Separation vector and its squared length
      const numtyp dx = ix.x-jx.x;
      const numtyp dy = ix.y-jx.y;
      const numtyp dz = ix.z-jx.z;
      const numtyp r2 = dx*dx+dy*dy+dz*dz;

      // Only pairs with r2 strictly below the cutoff contribute.
      if (!(r2<cutsq[mtype])) continue;

      // Reinterpret r2 (as a float) as an integer, then mask and shift the
      // bit pattern to obtain the table slot.
      union_int_float r2_bits;
      r2_bits.f = r2;
      const int slot = (r2_bits.i & nmask[mtype]) >> nshiftbits[mtype];
      const bool in_table = (slot <= last_slot);

      int idx = 0;
      numtyp frac = (numtyp)0;
      numtyp fpair = (numtyp)0;
      if (in_table) {
        idx = slot + tbindex*tablength;
        frac = (r2_bits.f - coeff3[idx].x) * coeff4[idx].w;
        fpair = factor_lj * (coeff3[idx].z + frac*coeff4[idx].z);
      }
      // Slots past the end of the table leave fpair at zero.

      f.x+=dx*fpair;
      f.y+=dy*fpair;
      f.z+=dz*fpair;

      if (EVFLAG && eflag) {
        const numtyp e = in_table ? (coeff3[idx].y + frac*coeff4[idx].y)
                                  : (numtyp)0.0;
        energy+=factor_lj*e;
      }
      if (EVFLAG && vflag) {
        virial[0] += dx*dx*fpair;
        virial[1] += dy*dy*fpair;
        virial[2] += dz*dz*fpair;
        virial[3] += dx*dy*fpair;
        virial[4] += dx*dz*fpair;
        virial[5] += dy*dz*fpair;
      }

    } // for nbor
  } // if ii
  store_answers(f,energy,virial,ii,inum,tid,t_per_atom,offset,eflag,vflag,
                ans,engv);
}
894