1
2// =================================================================================================
3// This file is part of the CLTune project, which loosely follows the Google C++ styleguide and uses
4// a tab-size of two spaces and a max-width of 100 characters per line.
5//
6// Author: cedric.nugteren@surfsara.nl (Cedric Nugteren)
7//
8// This file contains an example OpenCL kernel as part of the gemm.cc example. It is an optimized
9// matrix-multiplication kernel according to the paper by Matsumoto et al. and the tutorial on
10// http://www.cedricnugteren.nl/tutorial.php. It is fully configurable (and tunable!) using more or
11// less the same parameters/naming conventions as in the paper. It supports single and double
12// precision (SGEMM/DGEMM) through a pre-processor define.
13//
// Note: this kernel requires a compiler compliant with OpenCL 1.1 or higher.
15//
16// -------------------------------------------------------------------------------------------------
17//
18// Copyright 2014 SURFsara
19//
20// Licensed under the Apache License, Version 2.0 (the "License");
21// you may not use this file except in compliance with the License.
22// You may obtain a copy of the License at
23//
24//  http://www.apache.org/licenses/LICENSE-2.0
25//
26// Unless required by applicable law or agreed to in writing, software
27// distributed under the License is distributed on an "AS IS" BASIS,
28// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29// See the License for the specific language governing permissions and
30// limitations under the License.
31//
32// =================================================================================================
33//
34// Matrices are accessed as follows:
35// A: [k*M + m], with 'k' ranging from 0:K and 'm' from 0:M (m,k,m)
36// B: [k*N + n], with 'k' ranging from 0:K and 'n' from 0:N (n,k,n)
37// C: [n*M + m], with 'n' ranging from 0:N and 'm' from 0:M (m,n,m)
38//
39// Or as an image (assuming column-major)
40//       K
41//    o-------o
42//    |       |
43//  N | [B^T] |
44//    |       |
45//    o-------o
46//        K               N
47//    o-------o        o-----o
48//  M |  [A]  |      M | [C] |
49//    |       |        |     |
50//    o-------o        o-----o
51//
52//
53// Parameters determined by the tuner
54// MWG       : Tile-size in dimension M (e.g. 64, 128)
55// NWG       : Tile-size in dimension N (e.g. 64, 128)
56// KWG       : Tile-size in dimension K (e.g. 8, 16)
57// MDIMC     : Threads per workgroup in M-dimension (e.g. 8, 16, 32)
58// NDIMC     : Threads per workgroup in N-dimension (e.g. 8, 16, 32)
59// MDIMA     : Re-shaped tile dimension of matrix A: KDIMA * MDIMA
60// NDIMB     : Re-shaped tile dimension of matrix B: KDIMB * NDIMB
// KWI       : Unroll factor of the KWG loop (smaller than or equal to KWG)
// VWM       : Vector width of matrices A and C (supported 1, 2, 4, 8, and 16)
// VWN       : Vector width of matrix B (supported 1, 2, 4, 8, and 16)
64// STRM      : Use strided access within a thread in the M-dimension (1) or not (0)
65// STRN      : Use strided access within a thread in the N-dimension (1) or not (0)
66// SA        : Use local/shared memory to cache matrix A (1) or not (0)
67// SB        : Use local/shared memory to cache matrix B (1) or not (0)
68// PRECISION : Whether to use single (32) or double (64) precision data-types
69//
70// =================================================================================================
71
// Helper parameters derived from the tuning parameters above
#define MWI (MWG/MDIMC)               // Work per work-item (M-dimension)
#define NWI (NWG/NDIMC)               // Work per work-item (N-dimension)
#define KDIMA ((MDIMC*NDIMC)/(MDIMA)) // Re-shaped tile dimension of matrix A: KDIMA * MDIMA
#define KDIMB ((MDIMC*NDIMC)/(NDIMB)) // Re-shaped tile dimension of matrix B: KDIMB * NDIMB
#define MWA (MWG/MDIMA)               // Amount of loads-per-thread for matrix A (M-dimension)
#define KWA (KWG/KDIMA)               // Amount of loads-per-thread for matrix A (K-dimension)
#define KWB (KWG/KDIMB)               // Amount of loads-per-thread for matrix B (K-dimension)
#define NWB (NWG/NDIMB)               // Amount of loads-per-thread for matrix B (N-dimension)

// Settings
#define USE_VECTOR_MAD 1              // Use a single vector MAD (1) instead of per-lane MADs (0)
#define USE_CL_MAD 0                  // Uses the non-IEEE754 compliant OpenCL mad() (if above is 0)
85
86// =================================================================================================
87
// Data-type: single or double precision, selected through the PRECISION pre-processor define.
// Defines the scalar type 'real', its vector variants, and the matching zero literal.
#if PRECISION == 32
  typedef float real;
  typedef float2 real2;
  typedef float4 real4;
  typedef float8 real8;
  typedef float16 real16;
  #define ZERO 0.0f
#elif PRECISION == 64
  #if __OPENCL_VERSION__ <= CL_VERSION_1_1 // This is the default on OpenCL 1.2 or higher
     #pragma OPENCL EXTENSION cl_khr_fp64: enable
  #endif
  typedef double real;
  typedef double2 real2;
  typedef double4 real4;
  typedef double8 real8;
  typedef double16 real16;
  #define ZERO 0.0
#else
  // Fail early with a clear message instead of producing cryptic "unknown type 'real'"
  // errors further down the file.
  #error "Unsupported PRECISION value: expected 32 or 64"
#endif
107
108// =================================================================================================
109
// Data-widths in dimension M: 'realM' is the vector type used for loads/stores of
// matrices A and C, selected by the tuning parameter VWM.
#if VWM == 1
    typedef real realM;
#elif VWM == 2
    typedef real2 realM;
#elif VWM == 4
    typedef real4 realM;
#elif VWM == 8
    typedef real8 realM;
#elif VWM == 16
    typedef real16 realM;
#else
  // Fail early: an unsupported width would otherwise leave 'realM' undefined.
  #error "Unsupported VWM value: expected 1, 2, 4, 8, or 16"
#endif

// Data-widths in dimension N: 'realN' is the vector type used for loads of matrix B,
// selected by the tuning parameter VWN.
#if VWN == 1
    typedef real realN;
#elif VWN == 2
    typedef real2 realN;
#elif VWN == 4
    typedef real4 realN;
#elif VWN == 8
    typedef real8 realN;
#elif VWN == 16
    typedef real16 realN;
#else
  // Fail early: an unsupported width would otherwise leave 'realN' undefined.
  #error "Unsupported VWN value: expected 1, 2, 4, 8, or 16"
#endif
135
136// =================================================================================================
137
// Caches global off-chip memory into local (shared) memory on-chip. This function is specific for
// caching the A input matrix: the whole workgroup cooperatively loads a KWG x MWG tile
// (in units of realM vectors) starting at K-offset 'kwg'. 'tid' is the linearised thread
// id within the workgroup; the MDIMC*NDIMC threads are re-shaped into a KDIMA x MDIMA grid
// for this load.
#if SA == 1
inline void GlobalToLocalA(const __global realM* restrict agm, __local realM* alm,
                           const int kSizeM, const int tid, const int kwg) {
  // Position of this thread within the re-shaped MDIMA x KDIMA loading grid
  const int la0 = tid % MDIMA;
  const int la1 = tid / MDIMA;
  #pragma unroll
  for (int mia=0; mia<MWA/VWM; ++mia) {
    #pragma unroll
    for (int kia=0; kia<KWA; ++kia) {

      // Computes the indices based on strided/non-strided access
      #if STRM == 0
        int mg = mia + la0*(MWA/VWM);    // contiguous: each thread loads a consecutive chunk
      #elif STRM == 1
        int mg = la0 + mia*MDIMA;        // strided: threads interleave with stride MDIMA
      #endif

      // Computes the indices for the global memory
      int kg = kia + la1*KWA;
      int idm = mg + get_group_id(0)*(MWG/VWM);
      int idk = kg + kwg;

      // Loads the data from global memory (not transposed) into the local memory.
      // A is column-major in M, so a row step in K strides by kSizeM/VWM vectors.
      alm[kg*(MWG/VWM) + mg] = agm[idk*(kSizeM/VWM) + idm];
    }
  }
}
#endif
168
// Same as above, but now for the B input matrix: the workgroup cooperatively loads a
// KWG x NWG tile (in units of realN vectors) starting at K-offset 'kwg'. The threads are
// re-shaped into a KDIMB x NDIMB grid for this load.
#if SB == 1
inline void GlobalToLocalB(const __global realN* restrict bgm, __local realN* blm,
                           const int kSizeN, const int tid, const int kwg) {
  // Position of this thread within the re-shaped NDIMB x KDIMB loading grid
  const int lb0 = tid % NDIMB;
  const int lb1 = tid / NDIMB;
  #pragma unroll
  for (int kib=0; kib<KWB; ++kib) {
    #pragma unroll
    for (int nib=0; nib<NWB/VWN; ++nib) {

      // Computes the indices based on strided/non-strided access
      #if STRN == 0
        int ng = nib + lb0*(NWB/VWN);    // contiguous: each thread loads a consecutive chunk
      #elif STRN == 1
        int ng = lb0 + nib*NDIMB;        // strided: threads interleave with stride NDIMB
      #endif

      // Computes the indices for the global memory
      int kg = kib + lb1*KWB;
      int idn = ng + get_group_id(1)*(NWG/VWN);
      int idk = kg + kwg;

      // Loads the data from global memory (transposed) into the local memory.
      // B is stored as [k*N + n], so a row step in K strides by kSizeN/VWN vectors.
      blm[kg*(NWG/VWN) + ng] = bgm[idk*(kSizeN/VWN) + idn];
    }
  }
}
#endif
198
199// =================================================================================================
200
// Caches global off-chip memory directly into per-thread private memory (registers). This function
// is specific for caching the A input matrix. 'idk' is the absolute K-index to load from.
// NOTE(review): the 'kwg' parameter is unused here ('idk' already includes it); kept for
// signature symmetry with GlobalToLocalA — confirm before removing.
#if SA == 0
inline void GlobalToPrivateA(const __global realM* restrict agm, realM apm[MWI/VWM],
                             const int kSizeM, const int idk, const int kwg) {
  #pragma unroll
  for (int mi=0; mi<MWI/VWM; ++mi) {

    // Computes the indices based on strided/non-strided access
    #if STRM == 0
      int mg = mi + get_local_id(0)*(MWI/VWM);  // contiguous per-thread chunk
    #elif STRM == 1
      int mg = get_local_id(0) + mi*MDIMC;      // interleaved with stride MDIMC
    #endif

    // Computes the indices for the global memory
    int idm = mg + get_group_id(0)*(MWG/VWM);

    // Loads the data from global memory (not transposed) and stores into registers
    apm[mi] = agm[idk*(kSizeM/VWM) + idm];
  }
}
#endif
224
// Same as above, but now for the B input matrix: loads NWI/VWN vectors of row 'idk'
// directly from global memory into per-thread registers.
#if SB == 0
inline void GlobalToPrivateB(const __global realN* restrict bgm, realN bpm[NWI/VWN],
                             const int kSizeN, const int idk) {
  #pragma unroll
  for (int ni=0; ni<NWI/VWN; ++ni) {

    // Computes the indices based on strided/non-strided access
    #if STRN == 0
      int ng = ni + get_local_id(1)*(NWI/VWN);  // contiguous per-thread chunk
    #elif STRN == 1
      int ng = get_local_id(1) + ni*NDIMC;      // interleaved with stride NDIMC
    #endif

    // Computes the indices for the global memory
    int idn = ng + get_group_id(1)*(NWG/VWN);

    // Loads the data from global memory (transposed) and stores into registers
    bpm[ni] = bgm[idk*(kSizeN/VWN) + idn];
  }
}
#endif
247
248// =================================================================================================
249
// Caches on-chip local memory into per-thread private memory (registers). This function is
// specific for the A input matrix: it copies this thread's MWI/VWM vectors of local-memory
// row 'kg' into the register array 'apm'.
#if SA == 1
inline void LocalToPrivateA(__local realM* alm, realM apm[MWI/VWM], const int kg) {
  // Start of row 'kg' inside the local tile (in realM units)
  const int row_offset = kg*(MWG/VWM);
  #pragma unroll
  for (int w=0; w<MWI/VWM; ++w) {
    #if STRM == 0
      const int mg = get_local_id(0)*(MWI/VWM) + w;  // contiguous per-thread chunk
    #elif STRM == 1
      const int mg = w*MDIMC + get_local_id(0);      // interleaved with stride MDIMC
    #endif
    apm[w] = alm[row_offset + mg];
  }
}
#endif
265
// Same as above, but now for the B input matrix: copies this thread's NWI/VWN vectors of
// local-memory row 'kg' into the register array 'bpm'.
#if SB == 1
inline void LocalToPrivateB(__local realN* blm, realN bpm[NWI/VWN], const int kg) {
  // Start of row 'kg' inside the local tile (in realN units)
  const int row_offset = kg*(NWG/VWN);
  #pragma unroll
  for (int w=0; w<NWI/VWN; ++w) {
    #if STRN == 0
      const int ng = get_local_id(1)*(NWI/VWN) + w;  // contiguous per-thread chunk
    #elif STRN == 1
      const int ng = w*NDIMC + get_local_id(1);      // interleaved with stride NDIMC
    #endif
    bpm[w] = blm[row_offset + ng];
  }
}
#endif
280
281// =================================================================================================
282
// Merges the results in Cpm with the global array in Cgm: each thread writes its
// NWI x MWI/VWM register tile back to the output matrix C (stored as [n*M + m]).
inline void StoreResults(__global realM* cgm, realM cpm[NWI][MWI/VWM], const int kSizeM) {
  #pragma unroll
  for (int ni=0; ni<NWI; ++ni) {
    #pragma unroll
    for (int mi=0; mi<MWI/VWM; ++mi) {
      // Local M-index, mirroring the strided/non-strided layout used when loading A
      #if STRM == 0
        int mg = mi + get_local_id(0)*(MWI/VWM);
      #elif STRM == 1
        int mg = get_local_id(0) + mi*MDIMC;
      #endif
      // Local N-index. For STRN == 1 the register rows are interleaved in groups of VWN
      // (see MultiplyAccumulate), so the mapping undoes that interleaving:
      // lane within the vector, thread offset, then the per-group stride.
      #if STRN == 0
        int ng = ni + get_local_id(1)*NWI;
      #elif STRN == 1
        int ng = ni%VWN + get_local_id(1)*VWN + (ni/VWN)*VWN*NDIMC;
      #endif
      // Global output position of this element (in realM units for the M-dimension)
      int idm = mg + get_group_id(0)*(MWG/VWM);
      int idn = ng + get_group_id(1)*NWG;
      int index = idn*(kSizeM/VWM) + idm;
      cgm[index] = cpm[ni][mi];
    }
  }
}
306
307// =================================================================================================
308
// The basic scalar multiply-add function: cval += aval * bval
#if USE_CL_MAD == 1
  // Uses OpenCL's built-in mad(), which may map to faster hardware instructions but is
  // not guaranteed to be IEEE-754 compliant
  #define MultiplyAdd(cval, aval, bval) (cval = mad(aval, bval, cval))
#else
  #define MultiplyAdd(cval, aval, bval) (cval += (aval) * (bval))
#endif
315
// The vectorised multiply-add function: returns cvec + avec * bval, where 'bval' is a
// scalar broadcast across all VWM lanes. Either performed as a single vector operation
// (USE_VECTOR_MAD == 1) or expanded into per-lane MultiplyAdd calls.
inline realM MultiplyAddVector(realM cvec, const realM avec, const real bval) {
  #if USE_VECTOR_MAD == 1
    cvec += avec * bval;
  #else
    // Per-lane expansion for the configured vector width VWM
    #if VWM == 1
      MultiplyAdd(cvec,    avec,    bval);
    #elif VWM == 2
      MultiplyAdd(cvec.x , avec.x,  bval);
      MultiplyAdd(cvec.y , avec.y,  bval);
    #elif VWM == 4
      MultiplyAdd(cvec.x , avec.x,  bval);
      MultiplyAdd(cvec.y , avec.y,  bval);
      MultiplyAdd(cvec.z , avec.z,  bval);
      MultiplyAdd(cvec.w , avec.w,  bval);
    #elif VWM == 8
      MultiplyAdd(cvec.s0, avec.s0, bval);
      MultiplyAdd(cvec.s1, avec.s1, bval);
      MultiplyAdd(cvec.s2, avec.s2, bval);
      MultiplyAdd(cvec.s3, avec.s3, bval);
      MultiplyAdd(cvec.s4, avec.s4, bval);
      MultiplyAdd(cvec.s5, avec.s5, bval);
      MultiplyAdd(cvec.s6, avec.s6, bval);
      MultiplyAdd(cvec.s7, avec.s7, bval);
    #elif VWM == 16
      MultiplyAdd(cvec.s0, avec.s0, bval);
      MultiplyAdd(cvec.s1, avec.s1, bval);
      MultiplyAdd(cvec.s2, avec.s2, bval);
      MultiplyAdd(cvec.s3, avec.s3, bval);
      MultiplyAdd(cvec.s4, avec.s4, bval);
      MultiplyAdd(cvec.s5, avec.s5, bval);
      MultiplyAdd(cvec.s6, avec.s6, bval);
      MultiplyAdd(cvec.s7, avec.s7, bval);
      MultiplyAdd(cvec.s8, avec.s8, bval);
      MultiplyAdd(cvec.s9, avec.s9, bval);
      MultiplyAdd(cvec.sA, avec.sA, bval);
      MultiplyAdd(cvec.sB, avec.sB, bval);
      MultiplyAdd(cvec.sC, avec.sC, bval);
      MultiplyAdd(cvec.sD, avec.sD, bval);
      MultiplyAdd(cvec.sE, avec.sE, bval);
      MultiplyAdd(cvec.sF, avec.sF, bval);
    #endif
  #endif
  return cvec;
}
361
// Performs the actual computation: Cpm += Apm * Bpm. The outer loop walks the B register
// array in realN vectors; each scalar lane of a B vector multiplies a full A vector and is
// accumulated into a separate row of the Cpm register tile.
inline void MultiplyAccumulate(realM cpm[NWI][MWI/VWM], realM apm[MWI/VWM], realN bpm[NWI/VWN]) {
  #pragma unroll
  for (int ni=0; ni<NWI/VWN; ++ni) {
    #pragma unroll
    for (int mi=0; mi<MWI/VWM; ++mi) {
      // Expand the VWN lanes of bpm[ni] into consecutive Cpm rows
      #if VWN == 1
        cpm[ni*VWN + 0][mi] = MultiplyAddVector(cpm[ni*VWN + 0][mi], apm[mi], bpm[ni]);
      #elif VWN == 2
        cpm[ni*VWN + 0][mi] = MultiplyAddVector(cpm[ni*VWN + 0][mi], apm[mi], bpm[ni].x);
        cpm[ni*VWN + 1][mi] = MultiplyAddVector(cpm[ni*VWN + 1][mi], apm[mi], bpm[ni].y);
      #elif VWN == 4
        cpm[ni*VWN + 0][mi] = MultiplyAddVector(cpm[ni*VWN + 0][mi], apm[mi], bpm[ni].x);
        cpm[ni*VWN + 1][mi] = MultiplyAddVector(cpm[ni*VWN + 1][mi], apm[mi], bpm[ni].y);
        cpm[ni*VWN + 2][mi] = MultiplyAddVector(cpm[ni*VWN + 2][mi], apm[mi], bpm[ni].z);
        cpm[ni*VWN + 3][mi] = MultiplyAddVector(cpm[ni*VWN + 3][mi], apm[mi], bpm[ni].w);
      #elif VWN == 8
        cpm[ni*VWN + 0][mi] = MultiplyAddVector(cpm[ni*VWN + 0][mi], apm[mi], bpm[ni].s0);
        cpm[ni*VWN + 1][mi] = MultiplyAddVector(cpm[ni*VWN + 1][mi], apm[mi], bpm[ni].s1);
        cpm[ni*VWN + 2][mi] = MultiplyAddVector(cpm[ni*VWN + 2][mi], apm[mi], bpm[ni].s2);
        cpm[ni*VWN + 3][mi] = MultiplyAddVector(cpm[ni*VWN + 3][mi], apm[mi], bpm[ni].s3);
        cpm[ni*VWN + 4][mi] = MultiplyAddVector(cpm[ni*VWN + 4][mi], apm[mi], bpm[ni].s4);
        cpm[ni*VWN + 5][mi] = MultiplyAddVector(cpm[ni*VWN + 5][mi], apm[mi], bpm[ni].s5);
        cpm[ni*VWN + 6][mi] = MultiplyAddVector(cpm[ni*VWN + 6][mi], apm[mi], bpm[ni].s6);
        cpm[ni*VWN + 7][mi] = MultiplyAddVector(cpm[ni*VWN + 7][mi], apm[mi], bpm[ni].s7);
      #elif VWN == 16
        cpm[ni*VWN + 0 ][mi] = MultiplyAddVector(cpm[ni*VWN + 0 ][mi], apm[mi], bpm[ni].s0);
        cpm[ni*VWN + 1 ][mi] = MultiplyAddVector(cpm[ni*VWN + 1 ][mi], apm[mi], bpm[ni].s1);
        cpm[ni*VWN + 2 ][mi] = MultiplyAddVector(cpm[ni*VWN + 2 ][mi], apm[mi], bpm[ni].s2);
        cpm[ni*VWN + 3 ][mi] = MultiplyAddVector(cpm[ni*VWN + 3 ][mi], apm[mi], bpm[ni].s3);
        cpm[ni*VWN + 4 ][mi] = MultiplyAddVector(cpm[ni*VWN + 4 ][mi], apm[mi], bpm[ni].s4);
        cpm[ni*VWN + 5 ][mi] = MultiplyAddVector(cpm[ni*VWN + 5 ][mi], apm[mi], bpm[ni].s5);
        cpm[ni*VWN + 6 ][mi] = MultiplyAddVector(cpm[ni*VWN + 6 ][mi], apm[mi], bpm[ni].s6);
        cpm[ni*VWN + 7 ][mi] = MultiplyAddVector(cpm[ni*VWN + 7 ][mi], apm[mi], bpm[ni].s7);
        cpm[ni*VWN + 8 ][mi] = MultiplyAddVector(cpm[ni*VWN + 8 ][mi], apm[mi], bpm[ni].s8);
        cpm[ni*VWN + 9 ][mi] = MultiplyAddVector(cpm[ni*VWN + 9 ][mi], apm[mi], bpm[ni].s9);
        cpm[ni*VWN + 10][mi] = MultiplyAddVector(cpm[ni*VWN + 10][mi], apm[mi], bpm[ni].sA);
        cpm[ni*VWN + 11][mi] = MultiplyAddVector(cpm[ni*VWN + 11][mi], apm[mi], bpm[ni].sB);
        cpm[ni*VWN + 12][mi] = MultiplyAddVector(cpm[ni*VWN + 12][mi], apm[mi], bpm[ni].sC);
        cpm[ni*VWN + 13][mi] = MultiplyAddVector(cpm[ni*VWN + 13][mi], apm[mi], bpm[ni].sD);
        cpm[ni*VWN + 14][mi] = MultiplyAddVector(cpm[ni*VWN + 14][mi], apm[mi], bpm[ni].sE);
        cpm[ni*VWN + 15][mi] = MultiplyAddVector(cpm[ni*VWN + 15][mi], apm[mi], bpm[ni].sF);
      #endif
    }
  }
}
408
409// =================================================================================================
410
// Main entry of the kernel. This function contains the basic skeleton, the functionality is
// provided by the inlined functions above. Computes C = A * B for column-major matrices of
// sizes M x K (A), K x N (B as [k*N + n]) and M x N (C), tiled per workgroup (MWG x NWG)
// and per workitem (MWI x NWI). Assumes kSizeK is a multiple of KWG (the K-loop steps by
// KWG without a remainder tile) — TODO confirm against the host-side tuner setup.
__attribute__((reqd_work_group_size(MDIMC, NDIMC, 1)))
__kernel void gemm_fast(const int kSizeM, const int kSizeN, const int kSizeK,
                        const __global realM* restrict agm,
                        const __global realN* restrict bgm,
                        __global realM* cgm) {

  // Combined thread identifier (only needed for the cooperative local-memory loads)
  // NOTE(review): 'volatile' here presumably works around a compiler issue — confirm
  // before removing.
  #if SA == 1 || SB == 1
    volatile int tid = get_local_id(0) + MDIMC*get_local_id(1);
  #endif

  // Allocates workgroup-private memory (local memory) for the cached A and/or B tiles
  #if SA == 1
    __local realM alm[KWG * MWG/VWM];
  #endif
  #if SB == 1
    __local realN blm[KWG * NWG/VWN];
  #endif

  // Allocates workitem-private memory (registers)
  realM apm[MWI/VWM];       // strip of A for the current K-index
  realN bpm[NWI/VWN];       // strip of B for the current K-index
  realM cpm[NWI][MWI/VWM];  // accumulated MWI x NWI output tile

  // Initializes the accumulation registers
  #pragma unroll
  for (int mi=0; mi<MWI/VWM; ++mi) {
    #pragma unroll
    for (int ni=0; ni<NWI; ++ni) {
      cpm[ni][mi] = (realM)ZERO;
    }
  }

  // Loops over all workgroup tiles along the K-dimension
  for (int kwg=0; kwg<kSizeK; kwg+=KWG) {

    // Loads data: off-chip --> local (matrix A)
    #if SA == 1
      GlobalToLocalA(agm, alm, kSizeM, tid, kwg);
    #endif
    // Loads data: off-chip --> local (matrix B)
    #if SB == 1
      GlobalToLocalB(bgm, blm, kSizeN, tid, kwg);
    #endif

    // Synchronizes all threads in a workgroup: local tiles must be fully populated
    // before any thread starts reading them
    #if SA == 1 || SB == 1
      barrier(CLK_LOCAL_MEM_FENCE);
    #endif

    // Loops over all workitem tiles, unrolled by a factor KWI
    for (int pwi=0; pwi<KWG; pwi+=KWI) {
      #pragma unroll
      for (int pit=0; pit<KWI; ++pit) {
        // Absolute K-index (for direct global loads)
        #if SA == 0 || SB == 0
          int idk = kwg + pwi + pit;
        #endif
        // K-index relative to the local tile (for local-memory loads)
        #if SA == 1 || SB == 1
          int kg = pwi+pit;
        #endif

        // Loads data: local --> private (matrix A)
        #if SA == 1
          LocalToPrivateA(alm, apm, kg);
        // Loads data: off-chip --> private (matrix A)
        #else
          GlobalToPrivateA(agm, apm, kSizeM, idk, kwg);
        #endif

        // Loads data: local --> private (matrix B)
        #if SB == 1
          LocalToPrivateB(blm, bpm, kg);
        // Loads data: off-chip --> private (matrix B)
        #else
          GlobalToPrivateB(bgm, bpm, kSizeN, idk);
        #endif

        // Performs the accumulation (Cpm += Apm * Bpm)
        MultiplyAccumulate(cpm, apm, bpm);
      }
    }

    // Synchronizes all threads in a workgroup: all reads of the local tiles must finish
    // before the next iteration overwrites them
    #if SA == 1 || SB == 1
      barrier(CLK_LOCAL_MEM_FENCE);
    #endif
  }

  // Stores an MWG * NWG tile of results back to global memory
  StoreResults(cgm, cpm, kSizeM);
}
504
505// =================================================================================================
506