1 // **************************************************************************
2 //                                   gauss.cu
3 //                             -------------------
4 //                           Trung Dac Nguyen (ORNL)
5 //
6 //  Device code for acceleration of the gauss pair style
7 //
8 // __________________________________________________________________________
9 //    This file is part of the LAMMPS Accelerator Library (LAMMPS_AL)
10 // __________________________________________________________________________
11 //
12 //    begin                :
13 //    email                : nguyentd@ornl.gov
14 // ***************************************************************************/
15 
#if defined(NV_KERNEL)
// CUDA build: bring in the shared device helpers used by the kernels below
// (atom_info, nbor_info, fetch4, store_answers, ... -- presumably declared in
// this header; confirm) and declare the texture that positions are read from.
// NOTE(review): texture *references* were removed in CUDA 12.0; building this
// path with a newer toolkit would require migrating to texture objects.
#include "lal_aux_fun1.h"
#if defined(_DOUBLE_DOUBLE)
// _DOUBLE_DOUBLE build: positions are bound as int4 texels.
// NOTE(review): presumably fetch4 reassembles doubles from these -- confirm
// against lal_aux_fun1.h.
texture<int4,1> pos_tex;
#else
// Default build: positions are bound as float4 texels (x, y, z, type).
texture<float4> pos_tex;
#endif
#else
// Non-CUDA builds (presumably OpenCL): pos_tex is just an alias for the x_
// kernel argument, so fetch4(..., pos_tex) resolves against x_ directly.
#define pos_tex x_
#endif
26 
// ---------------------------------------------------------------------------
// k_gauss: per-atom force / energy / virial kernel for the gauss pair style,
// reading per-type-pair coefficients directly from global memory.
//
// For each neighbor j of atom i with rsq = |x_i - x_j|^2 < gauss1[m].z:
//   fpair  = -2 * A * B * exp(-B*rsq) * factor_lj   (force on i += fpair*del)
//   energy += factor_lj * -(A*exp(-B*rsq) - gauss1[m].w)
// where A = gauss1[m].x, B = gauss1[m].y and m = itype*lj_types + jtype.
//
// Parameters:
//   x_          positions; .w holds the atom type (read via fetch4/pos_tex)
//   gauss1      per type-pair table, row stride lj_types:
//               .x = prefactor A, .y = exponent coefficient B,
//               .z = squared cutoff, .w = energy shift
//   lj_types    row stride of gauss1
//   sp_lj_in    4 special-bond scale factors, chosen by sbmask(j)
//   dev_nbor, dev_packed, nbor_pitch, t_per_atom
//               neighbor-list layout consumed by nbor_info()
//   ans, engv   output buffers written by store_answers()
//   eflag/vflag accumulate energy / virial when > 0
//   inum        number of atoms to process (threads with ii >= inum idle)
// ---------------------------------------------------------------------------
__kernel void k_gauss(const __global numtyp4 *restrict x_,
                      const __global numtyp4 *restrict gauss1,
                      const int lj_types,
                      const __global numtyp *restrict sp_lj_in,
                      const __global int *dev_nbor,
                      const __global int *dev_packed,
                      __global acctyp4 *restrict ans,
                      __global acctyp *restrict engv,
                      const int eflag, const int vflag, const int inum,
                      const int nbor_pitch, const int t_per_atom) {
  int tid, ii, offset;
  atom_info(t_per_atom,ii,tid,offset);

  // Every thread stores the same four values, so each thread reads correct
  // factors after its own stores without needing a barrier.
  __local numtyp sp_lj[4];
  sp_lj[0]=sp_lj_in[0];
  sp_lj[1]=sp_lj_in[1];
  sp_lj[2]=sp_lj_in[2];
  sp_lj[3]=sp_lj_in[3];

  acctyp energy=(acctyp)0;
  acctyp4 f;
  f.x=(acctyp)0; f.y=(acctyp)0; f.z=(acctyp)0;
  acctyp virial[6];
  for (int i=0; i<6; i++)
    virial[i]=(acctyp)0;

  if (ii<inum) {
    const __global int *nbor, *list_end;
    int i, numj;
    __local int n_stride;
    nbor_info(dev_nbor,dev_packed,nbor_pitch,t_per_atom,ii,offset,i,numj,
              n_stride,list_end,nbor);

    numtyp4 ix; fetch4(ix,i,pos_tex); //x_[i];
    int itype=ix.w;

    for ( ; nbor<list_end; nbor+=n_stride) {
      int j=*nbor;
      // High bits of j select the special-bond factor; strip them after use.
      numtyp factor_lj = sp_lj[sbmask(j)];
      j &= NEIGHMASK;

      numtyp4 jx; fetch4(jx,j,pos_tex); //x_[j];
      int jtype=jx.w;

      // Compute r12
      numtyp delx = ix.x-jx.x;
      numtyp dely = ix.y-jx.y;
      numtyp delz = ix.z-jx.z;
      numtyp rsq = delx*delx+dely*dely+delz*delz;

      // One load of this pair's coefficients instead of re-indexing gauss1
      // for every component.
      int mtype=itype*lj_types+jtype;
      const numtyp4 coeff = gauss1[mtype];
      if (rsq<coeff.z) {
        numtyp r2inv = ucl_recip(rsq);
        // The Gaussian term is shared by force and energy: evaluate it once.
        numtyp expt = ucl_exp(-coeff.y*rsq);
        // Operand order (including rsq*...*r2inv) matches the previous
        // formulation so results are unchanged bit-for-bit.
        numtyp force = (numtyp)-2.0*coeff.x*coeff.y*rsq*expt*r2inv*factor_lj;

        f.x+=delx*force;
        f.y+=dely*force;
        f.z+=delz*force;

        if (eflag>0) {
          numtyp e=-(coeff.x*expt - coeff.w);
          energy+=factor_lj*e;
        }
        if (vflag>0) {
          virial[0] += delx*delx*force;
          virial[1] += dely*dely*force;
          virial[2] += delz*delz*force;
          virial[3] += delx*dely*force;
          virial[4] += delx*delz*force;
          virial[5] += dely*delz*force;
        }
      }

    } // for nbor
    store_answers(f,energy,virial,ii,inum,tid,t_per_atom,offset,eflag,vflag,
                  ans,engv);
  } // if ii
}
110 
// ---------------------------------------------------------------------------
// k_gauss_fast: same physics as k_gauss, but the per-type-pair coefficient
// table is first staged into __local memory.
//
// For each neighbor j of atom i with rsq < gauss1[m].z:
//   fpair  = -2 * A * B * exp(-B*rsq) * factor_lj   (force on i += fpair*del)
//   energy += factor_lj * -(A*exp(-B*rsq) - gauss1[m].w)
// where A = gauss1[m].x, B = gauss1[m].y and
// m = MAX_SHARED_TYPES*itype + jtype (table row stride is MAX_SHARED_TYPES).
//
// Preconditions (from the staging code below):
//   - gauss1_in holds MAX_SHARED_TYPES*MAX_SHARED_TYPES entries laid out with
//     row stride MAX_SHARED_TYPES.
//   - NOTE(review): the table is staged one entry per thread (tid < table
//     size), so the block presumably needs at least
//     MAX_SHARED_TYPES*MAX_SHARED_TYPES threads -- confirm in the host code.
//
// Parameters otherwise match k_gauss (minus lj_types).
// ---------------------------------------------------------------------------
__kernel void k_gauss_fast(const __global numtyp4 *restrict x_,
                           const __global numtyp4 *restrict gauss1_in,
                           const __global numtyp *restrict sp_lj_in,
                           const __global int *dev_nbor,
                           const __global int *dev_packed,
                           __global acctyp4 *restrict ans,
                           __global acctyp *restrict engv,
                           const int eflag, const int vflag, const int inum,
                           const int nbor_pitch, const int t_per_atom) {
  int tid, ii, offset;
  atom_info(t_per_atom,ii,tid,offset);

  // Stage special-bond factors and the coefficient table cooperatively.
  __local numtyp4 gauss1[MAX_SHARED_TYPES*MAX_SHARED_TYPES];
  __local numtyp sp_lj[4];
  if (tid<4)
    sp_lj[tid]=sp_lj_in[tid];
  if (tid<MAX_SHARED_TYPES*MAX_SHARED_TYPES) {
    gauss1[tid]=gauss1_in[tid];
  }

  acctyp energy=(acctyp)0;
  acctyp4 f;
  f.x=(acctyp)0; f.y=(acctyp)0; f.z=(acctyp)0;
  acctyp virial[6];
  for (int i=0; i<6; i++)
    virial[i]=(acctyp)0;

  // Reached unconditionally by every thread: makes the staged __local data
  // visible before any thread reads it.
  __syncthreads();

  if (ii<inum) {
    const __global int *nbor, *list_end;
    int i, numj;
    __local int n_stride;
    nbor_info(dev_nbor,dev_packed,nbor_pitch,t_per_atom,ii,offset,i,numj,
              n_stride,list_end,nbor);

    numtyp4 ix; fetch4(ix,i,pos_tex); //x_[i];
    int iw=ix.w;
    // Row offset of atom i's type in the staged table.
    int itype=fast_mul((int)MAX_SHARED_TYPES,iw);

    for ( ; nbor<list_end; nbor+=n_stride) {
      int j=*nbor;
      // High bits of j select the special-bond factor; strip them after use.
      numtyp factor_lj = sp_lj[sbmask(j)];
      j &= NEIGHMASK;

      numtyp4 jx; fetch4(jx,j,pos_tex); //x_[j];
      // Explicit conversion of the type id instead of int + numtyp -> int.
      int jtype=jx.w;
      int mtype=itype+jtype;

      // Compute r12
      numtyp delx = ix.x-jx.x;
      numtyp dely = ix.y-jx.y;
      numtyp delz = ix.z-jx.z;
      numtyp rsq = delx*delx+dely*dely+delz*delz;

      // One read of this pair's coefficients from __local memory.
      const numtyp4 coeff = gauss1[mtype];
      if (rsq<coeff.z) {
        numtyp r2inv = ucl_recip(rsq);
        // The Gaussian term is shared by force and energy: evaluate it once.
        numtyp expt = ucl_exp(-coeff.y*rsq);
        // Operand order (including rsq*...*r2inv) matches the previous
        // formulation so results are unchanged bit-for-bit.
        numtyp force = (numtyp)-2.0*coeff.x*coeff.y*rsq*expt*r2inv*factor_lj;

        f.x+=delx*force;
        f.y+=dely*force;
        f.z+=delz*force;

        if (eflag>0) {
          numtyp e=-(coeff.x*expt - coeff.w);
          energy+=factor_lj*e;
        }
        if (vflag>0) {
          virial[0] += delx*delx*force;
          virial[1] += dely*dely*force;
          virial[2] += delz*delz*force;
          virial[3] += delx*dely*force;
          virial[4] += delx*delz*force;
          virial[5] += dely*delz*force;
        }
      }

    } // for nbor
    store_answers(f,energy,virial,ii,inum,tid,t_per_atom,offset,eflag,vflag,
                  ans,engv);
  } // if ii
}
197 
198