1
2// =================================================================================================
3// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
4// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
5// width of 100 characters per line.
6//
7// Author(s):
8//   Cedric Nugteren <www.cedricnugteren.nl>
9//
10// This file contains the Xgemv kernel (fast versions) for matrix-vector multiplication.
11//
12// =================================================================================================
13
14// Enables loading of this file using the C++ pre-processor's #include (C++11 standard raw string
15// literal). Comment-out this line for syntax-highlighting when developing.
16R"(
17
18// =================================================================================================
19
20// Parameters set by the tuner or by the database. Here they are given a basic default value in case
21// this kernel file is used outside of the CLBlast library.
22
23// 1: For the full version, see 'xgemv.opencl'
24
25// 2: For the fast version
#if !defined(WGS2)
  #define WGS2 64     // Work-group size of XgemvFast
#endif
#if !defined(WPT2)
  #define WPT2 1      // Number of consecutive output elements computed per XgemvFast work-item
#endif
#if !defined(VW2)
  #define VW2 1       // Vector width used by XgemvFast when reading matrix A
#endif

// 3: For the fast rotated version
#if !defined(WGS3)
  #define WGS3 64     // Work-group size of XgemvFastRot
#endif
#if !defined(WPT3)
  #define WPT3 1      // Tile size: columns of A / elements of x consumed per XgemvFastRot iteration
#endif
#if !defined(VW3)
  #define VW3 1       // Vector width used by XgemvFastRot when reading matrix A
#endif
46
47// =================================================================================================
48
// Data-widths for the 'fast' kernel: 'realVF' is the type used to read VW2 elements of matrix A
// at once. Only widths with a matching vector type (1, 2, 4, 8, 16) are accepted; any other value
// is rejected here with a clear message instead of failing later on an undefined 'realVF'.
#if VW2 == 1
  typedef real realVF;
#elif VW2 == 2
  typedef real2 realVF;
#elif VW2 == 4
  typedef real4 realVF;
#elif VW2 == 8
  typedef real8 realVF;
#elif VW2 == 16
  typedef real16 realVF;
#else
  #error "Unsupported value for VW2: must be 1, 2, 4, 8 or 16"
#endif

// Data-widths for the 'fast' kernel with rotated matrix: 'realVFR' is the type used to read VW3
// elements of matrix A at once. Same set of supported widths as above.
#if VW3 == 1
  typedef real realVFR;
#elif VW3 == 2
  typedef real2 realVFR;
#elif VW3 == 4
  typedef real4 realVFR;
#elif VW3 == 8
  typedef real8 realVFR;
#elif VW3 == 16
  typedef real16 realVFR;
#else
  #error "Unsupported value for VW3: must be 1, 2, 4, 8 or 16"
#endif
74
75// =================================================================================================
76
77// Loads a vector input value
INLINE_FUNC realVF LoadMatrixAVF(const __global realVF* restrict agm, const int x, const int y,
                                 const int a_ld) {
  // Returns the vector at offset a_ld*y + x of 'agm'. The offset is counted in realVF elements
  // (VW2 scalars each), so 'a_ld' must already be in vector units here. NOTE(review): XgemvFast
  // below divides its scalar a_ld by VW2 before indexing, and does not call this helper.
  const int vec_offset = y*a_ld + x;
  return agm[vec_offset];
}
82
83// =================================================================================================
84
// Fast matrix-vector kernel for a non-rotated matrix A. The caller must guarantee that:
// --> 'm' and 'n' are multiples of WGS2
// --> 'a_offset' is 0
// --> 'a_ld' is a multiple of VW2
// --> 'a_rotated' is 0
// --> 'do_conjugate' is 0
// The arguments 'm', 'a_rotated', 'a_offset', 'do_conjugate', 'parameter', 'kl_unused' and
// 'ku_unused' are accepted but never read; the argument list is identical to XgemvFastRot's.
//
// Work-item g produces the WPT2 consecutive outputs i = WPT2*g + 0 .. WPT2*g + WPT2-1. For every
// k < n it accumulates (A at scalar offset a_ld*k + i) * x[k], reading A as realVF vectors of VW2
// elements. The loop over k proceeds in slices of WGS2, each slice of x being staged in local
// memory first. The result is written back through AXPBY(y[i], alpha, sum, beta, old y[i]);
// AXPBY is defined elsewhere (presumably y = alpha*sum + beta*y).
__kernel __attribute__((reqd_work_group_size(WGS2, 1, 1)))
void XgemvFast(const int m, const int n,
               const real_arg arg_alpha,
               const real_arg arg_beta,
               const int a_rotated,
               const __global realVF* restrict agm, const int a_offset, const int a_ld,
               const __global real* restrict xgm, const int x_offset, const int x_inc,
               __global real* ygm, const int y_offset, const int y_inc,
               const int do_conjugate, const int parameter,
               const int kl_unused, const int ku_unused) {
  const real alpha = GetRealArg(arg_alpha);
  const real beta = GetRealArg(arg_beta);

  // Staging buffer for the current WGS2-wide slice of vector x
  __local real xlm[WGS2];

  // Indices that do not change during the kernel
  const int local_id = get_local_id(0);
  const int first_vec = (WPT2/VW2)*get_global_id(0);  // first realVF column owned by this item
  const int a_ld_vec = a_ld/VW2;                       // leading dimension in realVF units

  // One private accumulator per output element of this work-item
  real sums[WPT2];
  #pragma unroll
  for (int i=0; i<WPT2; ++i) {
    SetToZero(sums[i]);
  }

  for (int kwg=0; kwg<n; kwg+=WGS2) {

    // Every work-item copies one element of x; wait until the whole slice is present
    xlm[local_id] = xgm[(kwg + local_id)*x_inc + x_offset];
    barrier(CLK_LOCAL_MEM_FENCE);

    // Accumulates the contributions of all k in this slice (matrix A is not rotated)
    #pragma unroll
    for (int kk=0; kk<WGS2; ++kk) {
      const real xval = xlm[kk];
      const int k_offset = a_ld_vec*(kwg + kk);
      #pragma unroll
      for (int v=0; v<WPT2/VW2; ++v) {
        const realVF avec = agm[k_offset + first_vec + v];
        #if VW2 == 1
          MultiplyAdd(sums[VW2*v+0], xval, avec);
        #elif VW2 == 2
          MultiplyAdd(sums[VW2*v+0], xval, avec.x);
          MultiplyAdd(sums[VW2*v+1], xval, avec.y);
        #elif VW2 == 4
          MultiplyAdd(sums[VW2*v+0], xval, avec.x);
          MultiplyAdd(sums[VW2*v+1], xval, avec.y);
          MultiplyAdd(sums[VW2*v+2], xval, avec.z);
          MultiplyAdd(sums[VW2*v+3], xval, avec.w);
        #elif VW2 == 8
          MultiplyAdd(sums[VW2*v+0], xval, avec.s0);
          MultiplyAdd(sums[VW2*v+1], xval, avec.s1);
          MultiplyAdd(sums[VW2*v+2], xval, avec.s2);
          MultiplyAdd(sums[VW2*v+3], xval, avec.s3);
          MultiplyAdd(sums[VW2*v+4], xval, avec.s4);
          MultiplyAdd(sums[VW2*v+5], xval, avec.s5);
          MultiplyAdd(sums[VW2*v+6], xval, avec.s6);
          MultiplyAdd(sums[VW2*v+7], xval, avec.s7);
        #elif VW2 == 16
          MultiplyAdd(sums[VW2*v+0], xval, avec.s0);
          MultiplyAdd(sums[VW2*v+1], xval, avec.s1);
          MultiplyAdd(sums[VW2*v+2], xval, avec.s2);
          MultiplyAdd(sums[VW2*v+3], xval, avec.s3);
          MultiplyAdd(sums[VW2*v+4], xval, avec.s4);
          MultiplyAdd(sums[VW2*v+5], xval, avec.s5);
          MultiplyAdd(sums[VW2*v+6], xval, avec.s6);
          MultiplyAdd(sums[VW2*v+7], xval, avec.s7);
          MultiplyAdd(sums[VW2*v+8], xval, avec.s8);
          MultiplyAdd(sums[VW2*v+9], xval, avec.s9);
          MultiplyAdd(sums[VW2*v+10], xval, avec.sA);
          MultiplyAdd(sums[VW2*v+11], xval, avec.sB);
          MultiplyAdd(sums[VW2*v+12], xval, avec.sC);
          MultiplyAdd(sums[VW2*v+13], xval, avec.sD);
          MultiplyAdd(sums[VW2*v+14], xval, avec.sE);
          MultiplyAdd(sums[VW2*v+15], xval, avec.sF);
        #endif
      }
    }

    // All work-items must be done with xlm before the next slice overwrites it
    barrier(CLK_LOCAL_MEM_FENCE);
  }

  // Combines the accumulated sums with the existing values of y
  #pragma unroll
  for (int i=0; i<WPT2; ++i) {
    const int out = WPT2*get_global_id(0) + i;
    const int y_index = out*y_inc + y_offset;
    real yval = ygm[y_index];
    AXPBY(ygm[y_index], alpha, sums[i], beta, yval);
  }
}
184
185// =================================================================================================
186
// Fast matrix-vector kernel for a rotated matrix A. The caller must guarantee that:
// --> 'm' and 'n' are multiples of WGS3
// --> 'a_offset' is 0
// --> 'a_ld' is a multiple of VW3
// --> 'a_rotated' is 1
// --> 'do_conjugate' is 0
// The arguments 'm', 'a_rotated', 'a_offset', 'do_conjugate', 'parameter', 'kl_unused' and
// 'ku_unused' are accepted but never read; the argument list is identical to XgemvFast's.
//
// Work-item g produces output g: the sum over c < n of (A at scalar offset a_ld*g + c) * x[c].
// The loop over c proceeds in tiles of WPT3 columns. Per tile, with P = WPT3/VW3:
//  - loading: work-item L reads, for r in [0, P), the realVFR at column kwg + (L % P)*VW3 of row
//    group*WGS3 + (L / P)*P + r, and scatters lane v into tile[r*VW3 + v][L] (for coalescing);
//  - computing: work-item L reads tile[(L % P)*VW3 + v][(L / P)*P + r], which holds row
//    group*WGS3 + L, column kwg + r*VW3 + v, and pairs it with xlm[r*VW3 + v].
// NOTE(review): this relies on WPT3 <= WGS3, WPT3 % VW3 == 0 and WGS3 % P == 0 (not checked here).
// The result is written back through AXPBY(y[g], alpha, sum, beta, old y[g]); AXPBY is defined
// elsewhere (presumably y = alpha*sum + beta*y).
__kernel __attribute__((reqd_work_group_size(WGS3, 1, 1)))
void XgemvFastRot(const int m, const int n,
                  const real_arg arg_alpha,
                  const real_arg arg_beta,
                  const int a_rotated,
                  const __global realVFR* restrict agm, const int a_offset, const int a_ld,
                  const __global real* restrict xgm, const int x_offset, const int x_inc,
                  __global real* ygm, const int y_offset, const int y_inc,
                  const int do_conjugate, const int parameter,
                  const int kl_unused, const int ku_unused) {
  const real alpha = GetRealArg(arg_alpha);
  const real beta = GetRealArg(arg_beta);

  // Local buffers: a transposed WPT3 x WGS3 tile of A, and the matching WPT3 elements of x
  __local real tile[WPT3][WGS3];
  __local real xlm[WPT3];

  // Position of this work-item within the cooperative tile load (see the header comment)
  const int local_id = get_local_id(0);
  const int col_vec = local_id % (WPT3/VW3);   // which realVFR of a tile row this item reads
  const int row_base = (local_id / (WPT3/VW3)) * (WPT3/VW3);  // first tile row this item reads
  const int first_row = get_group_id(0) * WGS3 + row_base;     // that row's index within A

  // Single private accumulator for this work-item's output element
  real acc;
  SetToZero(acc);

  for (int kwg=0; kwg<n; kwg+=WPT3) {

    // The first WPT3 work-items copy the current slice of x
    if (local_id < WPT3) {
      xlm[local_id] = xgm[(kwg + local_id) * x_inc + x_offset];
    }

    // Cooperatively loads a tile of A, storing it transposed in local memory
    const int x_vec = (kwg/VW3) + col_vec;
    #pragma unroll
    for (int r=0; r<WPT3/VW3; ++r) {
      const realVFR avec = agm[(a_ld/VW3) * (first_row + r) + x_vec];
      #if VW3 == 1
        tile[r*VW3 + 0][local_id] = avec;
      #elif VW3 == 2
        tile[r*VW3 + 0][local_id] = avec.x;
        tile[r*VW3 + 1][local_id] = avec.y;
      #elif VW3 == 4
        tile[r*VW3 + 0][local_id] = avec.x;
        tile[r*VW3 + 1][local_id] = avec.y;
        tile[r*VW3 + 2][local_id] = avec.z;
        tile[r*VW3 + 3][local_id] = avec.w;
      #elif VW3 == 8
        tile[r*VW3 + 0][local_id] = avec.s0;
        tile[r*VW3 + 1][local_id] = avec.s1;
        tile[r*VW3 + 2][local_id] = avec.s2;
        tile[r*VW3 + 3][local_id] = avec.s3;
        tile[r*VW3 + 4][local_id] = avec.s4;
        tile[r*VW3 + 5][local_id] = avec.s5;
        tile[r*VW3 + 6][local_id] = avec.s6;
        tile[r*VW3 + 7][local_id] = avec.s7;
      #elif VW3 == 16
        tile[r*VW3 + 0][local_id] = avec.s0;
        tile[r*VW3 + 1][local_id] = avec.s1;
        tile[r*VW3 + 2][local_id] = avec.s2;
        tile[r*VW3 + 3][local_id] = avec.s3;
        tile[r*VW3 + 4][local_id] = avec.s4;
        tile[r*VW3 + 5][local_id] = avec.s5;
        tile[r*VW3 + 6][local_id] = avec.s6;
        tile[r*VW3 + 7][local_id] = avec.s7;
        tile[r*VW3 + 8][local_id] = avec.s8;
        tile[r*VW3 + 9][local_id] = avec.s9;
        tile[r*VW3 + 10][local_id] = avec.sA;
        tile[r*VW3 + 11][local_id] = avec.sB;
        tile[r*VW3 + 12][local_id] = avec.sC;
        tile[r*VW3 + 13][local_id] = avec.sD;
        tile[r*VW3 + 14][local_id] = avec.sE;
        tile[r*VW3 + 15][local_id] = avec.sF;
      #endif
    }

    // Both x and the tile must be complete before anyone reads them
    barrier(CLK_LOCAL_MEM_FENCE);

    // Accumulates this work-item's row of the tile against the slice of x (rotated access)
    #pragma unroll
    for (int r=0; r<WPT3/VW3; ++r) {
      #pragma unroll
      for (int v=0; v<VW3; ++v) {
        const real aval = tile[col_vec*VW3 + v][row_base + r];
        const real xval = xlm[r*VW3 + v];
        MultiplyAdd(acc, xval, aval);
      }
    }

    // All reads must finish before the next iteration overwrites the local buffers
    barrier(CLK_LOCAL_MEM_FENCE);
  }

  // Combines the accumulated sum with the existing value of y
  const int gid = get_global_id(0);
  const int y_index = gid * y_inc + y_offset;
  real yval = ygm[y_index];
  AXPBY(ygm[y_index], alpha, acc, beta, yval);
}
295
296// =================================================================================================
297
298// End of the C++11 raw string literal
299)"
300
301// =================================================================================================
302