1
2// =================================================================================================
3// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
4// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
5// width of 100 characters per line.
6//
7// Author(s):
8//   Cedric Nugteren <www.cedricnugteren.nl>
9//
10// This file contains an optimized matrix-multiplication kernel inspired by the paper by Matsumoto
11// et al. and the tutorial on http://www.cedricnugteren.nl/tutorial.php. It is fully configurable
12// (and tunable!) using more or less the same parameters/naming conventions as in the paper. It
13// supports different data-types (SGEMM/DGEMM/CGEMM/ZGEMM/HGEMM) through a pre-processor define.
14//
15// Matrices are accessed as follows:
16// A: [k*M + m], with 'k' ranging from 0:K and 'm' from 0:M (m,k,m)
17// B: [k*N + n], with 'k' ranging from 0:K and 'n' from 0:N (n,k,n)
18// C: [n*M + m], with 'n' ranging from 0:N and 'm' from 0:M (m,n,m)
19//
20// Or as an image (assuming column-major)
21//       K
22//    o-------o
23//    |       |
24//  N | [B^T] |
25//    |       |
26//    o-------o
27//        K               N
28//    o-------o        o-----o
29//  M |  [A]  |      M | [C] |
30//    |       |        |     |
31//    o-------o        o-----o
32//
33//
// This kernel is separated into four files. This is part 1 out of 4.
35//
36// =================================================================================================
37
38// Enables loading of this file using the C++ pre-processor's #include (C++11 standard raw string
39// literal). Comment-out this line for syntax-highlighting when developing.
40R"(
41
42// =================================================================================================
43
// Tuning parameters. They are normally set by the tuner or by the database; the basic default
// values below only apply when this kernel file is used outside of the CLBlast library.
#if !defined(MWG)
  #define MWG 8      // Tile-size in dimension M (e.g. 64, 128)
#endif
#if !defined(NWG)
  #define NWG 8      // Tile-size in dimension N (e.g. 64, 128)
#endif
#if !defined(KWG)
  #define KWG 8      // Tile-size in dimension K (e.g. 8, 16)
#endif
#if !defined(MDIMC)
  #define MDIMC 8    // Threads per workgroup in M-dimension (e.g. 8, 16, 32)
#endif
#if !defined(NDIMC)
  #define NDIMC 8    // Threads per workgroup in N-dimension (e.g. 8, 16, 32)
#endif
#if !defined(MDIMA)
  #define MDIMA 8    // Re-shaped tile dimension of matrix A: KDIMA * MDIMA
#endif
#if !defined(NDIMB)
  #define NDIMB 8    // Re-shaped tile dimension of matrix B: KDIMB * NDIMB
#endif
#if !defined(KWI)
  #define KWI 1      // Unroll factor of the KWG loop (smaller than or equal to KWG)
#endif
#if !defined(VWM)
  #define VWM 1      // Vector width of matrices A and C
#endif
#if !defined(VWN)
  #define VWN 1      // Vector width of matrix B
#endif
#if !defined(STRM)
  #define STRM 0     // Strided access within a thread in the M-dimension: yes (1) or no (0)
#endif
#if !defined(STRN)
  #define STRN 0     // Strided access within a thread in the N-dimension: yes (1) or no (0)
#endif
#if !defined(SA)
  #define SA 0       // Cache matrix A in local/shared memory: yes (1) or no (0)
#endif
#if !defined(SB)
  #define SB 0       // Cache matrix B in local/shared memory: yes (1) or no (0)
#endif
88
// Helper parameters derived from the above tuning parameters. Every operand is parenthesized so
// that a tuning parameter defined as an expression (e.g. -DMWG=32+32) is still used as one value
// instead of being split up by operator precedence.
#define MWI ((MWG)/(MDIMC))               // Work per work-item (M-dimension)
#define NWI ((NWG)/(NDIMC))               // Work per work-item (N-dimension)
#define KDIMA (((MDIMC)*(NDIMC))/(MDIMA)) // Re-shaped tile dimension of matrix A: KDIMA * MDIMA
#define KDIMB (((MDIMC)*(NDIMC))/(NDIMB)) // Re-shaped tile dimension of matrix B: KDIMB * NDIMB
#define MWA ((MWG)/(MDIMA))               // Amount of loads-per-thread for matrix A (M-dimension)
#define KWA ((KWG)/(KDIMA))               // Amount of loads-per-thread for matrix A (K-dimension)
#define KWB ((KWG)/(KDIMB))               // Amount of loads-per-thread for matrix B (K-dimension)
#define NWB ((NWG)/(NDIMB))               // Amount of loads-per-thread for matrix B (N-dimension)
98
// Settings (defaults apply only when not already defined by the including code)
#if !defined(USE_VECTOR_MAD)
  // Manually unroll the vector MAD (0) or don't unroll it manually (1)
  #define USE_VECTOR_MAD 0
#endif
#if !defined(GLOBAL_MEM_FENCE)
  // Global synchronisation barrier for potentially better performance
  #define GLOBAL_MEM_FENCE 0
#endif
106
107// =================================================================================================
108
// Data-widths in dimension M. For each supported vector width VWM, 'realM' is the matching 'real'
// vector type and 'memM' is the element type of the buffers: a 'short' vector of the same width
// when FP16_STORAGE is defined, and otherwise the same type as 'realM'.
#if VWM == 1
  typedef real realM;
  #ifdef FP16_STORAGE
    typedef short memM;
  #else
    typedef real memM;
  #endif
#elif VWM == 2
  typedef real2 realM;
  #ifdef FP16_STORAGE
    typedef short2 memM;
  #else
    typedef real2 memM;
  #endif
#elif VWM == 4
  typedef real4 realM;
  #ifdef FP16_STORAGE
    typedef short4 memM;
  #else
    typedef real4 memM;
  #endif
#elif VWM == 8
  typedef real8 realM;
  #ifdef FP16_STORAGE
    typedef short8 memM;
  #else
    typedef real8 memM;
  #endif
#elif VWM == 16
  typedef real16 realM;
  #ifdef FP16_STORAGE
    typedef short16 memM;
  #else
    typedef real16 memM;
  #endif
#endif
145
// Data-widths in dimension N. For each supported vector width VWN, 'realN' is the matching 'real'
// vector type and 'memN' is the element type of the buffers: a 'short' vector of the same width
// when FP16_STORAGE is defined, and otherwise the same type as 'realN'.
#if VWN == 1
  typedef real realN;
  #ifdef FP16_STORAGE
    typedef short memN;
  #else
    typedef real memN;
  #endif
#elif VWN == 2
  typedef real2 realN;
  #ifdef FP16_STORAGE
    typedef short2 memN;
  #else
    typedef real2 memN;
  #endif
#elif VWN == 4
  typedef real4 realN;
  #ifdef FP16_STORAGE
    typedef short4 memN;
  #else
    typedef real4 memN;
  #endif
#elif VWN == 8
  typedef real8 realN;
  #ifdef FP16_STORAGE
    typedef short8 memN;
  #else
    typedef real8 memN;
  #endif
#elif VWN == 16
  typedef real16 realN;
  #ifdef FP16_STORAGE
    typedef short16 memN;
  #else
    typedef real16 memN;
  #endif
#endif
182
183// =================================================================================================
184
// Returns a 'realM' value (VWM components wide) in which every component has been passed through
// SetToZero, for use as the initial value of the accumulation registers.
INLINE_FUNC realM InitAccRegisters() {
  realM acc;
  #if VWM == 1
    SetToZero(acc);
  #elif VWM == 2
    SetToZero(acc.x);
    SetToZero(acc.y);
  #elif VWM == 4
    SetToZero(acc.x);
    SetToZero(acc.y);
    SetToZero(acc.z);
    SetToZero(acc.w);
  #elif VWM == 8 || VWM == 16

    // The lower eight components are common to both the 8-wide and the 16-wide case
    SetToZero(acc.s0);
    SetToZero(acc.s1);
    SetToZero(acc.s2);
    SetToZero(acc.s3);
    SetToZero(acc.s4);
    SetToZero(acc.s5);
    SetToZero(acc.s6);
    SetToZero(acc.s7);

    // The upper eight components only exist in the 16-wide case
    #if VWM == 16
      SetToZero(acc.s8);
      SetToZero(acc.s9);
      SetToZero(acc.sA);
      SetToZero(acc.sB);
      SetToZero(acc.sC);
      SetToZero(acc.sD);
      SetToZero(acc.sE);
      SetToZero(acc.sF);
    #endif
  #endif
  return acc;
}
227
228// =================================================================================================
229
// Copies part of the A input matrix from global off-chip memory ('agm') into local (shared)
// on-chip memory ('alm'). A single call performs (MWA/VWM) * KWA element copies; which elements
// are copied is derived from 'tid', the group index in dimension 0, and the K-offset 'kwg'.
#if SA == 1
INLINE_FUNC void GlobalToLocalA(const __global memM* restrict agm, LOCAL_PTR memM* alm,
                                const int kSizeM, const int tid, const int kwg) {

  // Splits 'tid' into an M-part and a K-part of the re-shaped KDIMA * MDIMA tile
  const int tid_m = tid % MDIMA;
  const int tid_k = tid / MDIMA;

  // Offset of this group's tile along M in global memory (in units of 'memM' elements)
  const int group_m = GetGroupID0() * (MWG/VWM);

  #pragma unroll
  for (int _mia = 0; _mia < MWA/VWM; ++_mia) {

    // Position along M within the tile, based on strided/non-strided access
    #if STRM == 0
      const int mg = _mia + tid_m*(MWA/VWM);
    #elif STRM == 1
      const int mg = tid_m + _mia*MDIMA;
    #endif
    const int idm = mg + group_m;

    #pragma unroll
    for (int _kia = 0; _kia < KWA; ++_kia) {

      // Position along K: within the tile ('kg') and in global memory ('idk')
      const int kg = _kia + tid_k*KWA;
      const int idk = kg + kwg;

      // Copies one element from global memory (not transposed) into the local memory
      alm[kg*(MWG/VWM) + mg] = agm[idk*(kSizeM/VWM) + idm];
    }
  }
}
#endif
260
// Copies part of the B input matrix from global off-chip memory ('bgm') into local (shared)
// on-chip memory ('blm'). A single call performs KWB * (NWB/VWN) element copies; which elements
// are copied is derived from 'tid', the group index in dimension 1, and the K-offset 'kwg'.
#if SB == 1
INLINE_FUNC void GlobalToLocalB(const __global memN* restrict bgm, LOCAL_PTR memN* blm,
                                const int kSizeN, const int tid, const int kwg) {

  // Splits 'tid' into an N-part and a K-part of the re-shaped KDIMB * NDIMB tile
  const int tid_n = tid % NDIMB;
  const int tid_k = tid / NDIMB;

  // Offset of this group's tile along N in global memory (in units of 'memN' elements)
  const int group_n = GetGroupID1() * (NWG/VWN);

  #pragma unroll
  for (int _kib = 0; _kib < KWB; ++_kib) {

    // Position along K: within the tile ('kg') and in global memory ('idk')
    const int kg = _kib + tid_k*KWB;
    const int idk = kg + kwg;

    #pragma unroll
    for (int _nib = 0; _nib < NWB/VWN; ++_nib) {

      // Position along N within the tile, based on strided/non-strided access
      #if STRN == 0
        const int ng = _nib + tid_n*(NWB/VWN);
      #elif STRN == 1
        const int ng = tid_n + _nib*NDIMB;
      #endif
      const int idn = ng + group_n;

      // Copies one element from global memory (transposed) into the local memory
      blm[kg*(NWG/VWN) + ng] = bgm[idk*(kSizeN/VWN) + idn];
    }
  }
}
#endif
290
291// =================================================================================================
292
// Loads one 'realM' value of the A input matrix from global off-chip memory ('agm') directly into
// per-thread private memory (registers) and returns it. The element is selected by '_mi', the
// local and group indices in dimension 0, and the K-index 'idk'. The 'kwg' argument is not used
// by this function.
#if SA == 0
INLINE_FUNC realM GlobalToPrivateA(const __global memM* restrict agm, const int _mi,
                                   const int kSizeM, const int idk, const int kwg) {

  // Position along M within the tile, based on strided/non-strided access
  #if STRM == 0
    const int mg = _mi + get_local_id(0)*(MWI/VWM);
  #elif STRM == 1
    const int mg = get_local_id(0) + _mi*MDIMC;
  #endif

  // Position in global memory (not transposed), in units of VWM-wide elements
  const int idm = mg + GetGroupID0() * (MWG/VWM);
  const int offset = idk*(kSizeM/VWM) + idm;

  // With FP16 storage the buffer is read as 'half' data through the vloada_half functions;
  // otherwise the element is read from the buffer as-is
#ifdef FP16_STORAGE
  #if VWM == 1
    return vloada_half(offset, (const __global half*)agm);
  #elif VWM == 2
    return vloada_half2(offset, (const __global half*)agm);
  #elif VWM == 4
    return vloada_half4(offset, (const __global half*)agm);
  #elif VWM == 8
    return vloada_half8(offset, (const __global half*)agm);
  #elif VWM == 16
    return vloada_half16(offset, (const __global half*)agm);
  #endif
#else
  return agm[offset];
#endif
}
#endif
326
// Loads one 'realN' value of the B input matrix from global off-chip memory ('bgm') directly into
// per-thread private memory (registers) and returns it. The element is selected by '_ni', the
// local and group indices in dimension 1, and the K-index 'idk'.
#if SB == 0
INLINE_FUNC realN GlobalToPrivateB(const __global memN* restrict bgm, const int _ni,
                                   const int kSizeN, const int idk) {

  // Position along N within the tile, based on strided/non-strided access
  #if STRN == 0
    const int ng = _ni + get_local_id(1)*(NWI/VWN);
  #elif STRN == 1
    const int ng = get_local_id(1) + _ni*NDIMC;
  #endif

  // Position in global memory (transposed), in units of VWN-wide elements
  const int idn = ng + GetGroupID1() * (NWG/VWN);
  const int offset = idk*(kSizeN/VWN) + idn;

  // With FP16 storage the buffer is read as 'half' data through the vloada_half functions;
  // otherwise the element is read from the buffer as-is
#ifdef FP16_STORAGE
  #if VWN == 1
    return vloada_half(offset, (const __global half*)bgm);
  #elif VWN == 2
    return vloada_half2(offset, (const __global half*)bgm);
  #elif VWN == 4
    return vloada_half4(offset, (const __global half*)bgm);
  #elif VWN == 8
    return vloada_half8(offset, (const __global half*)bgm);
  #elif VWN == 16
    return vloada_half16(offset, (const __global half*)bgm);
  #endif
#else
  return bgm[offset];
#endif
}
#endif
359
360// =================================================================================================
361
// Loads one 'realM' value of the A input matrix from on-chip local memory ('alm') into per-thread
// private memory (registers) and returns it. The element is selected by '_mi', the local index in
// dimension 0, and the K-position 'kg' within the tile.
#if SA == 1
INLINE_FUNC realM LocalToPrivateA(LOCAL_PTR memM* alm, const int _mi, const int kg) {

  // Position along M within the tile, based on strided/non-strided access
  #if STRM == 0
    const int mg = _mi + get_local_id(0)*(MWI/VWM);
  #elif STRM == 1
    const int mg = get_local_id(0) + _mi*MDIMC;
  #endif

  // Position in local memory, in units of VWM-wide elements
  const int offset = kg*(MWG/VWM) + mg;

  // With FP16 storage the buffer is read as 'half' data through the vloada_half functions;
  // otherwise the element is read from the buffer as-is
#ifdef FP16_STORAGE
  #if VWM == 1
    return vloada_half(offset, (LOCAL_PTR half*)alm);
  #elif VWM == 2
    return vloada_half2(offset, (LOCAL_PTR half*)alm);
  #elif VWM == 4
    return vloada_half4(offset, (LOCAL_PTR half*)alm);
  #elif VWM == 8
    return vloada_half8(offset, (LOCAL_PTR half*)alm);
  #elif VWM == 16
    return vloada_half16(offset, (LOCAL_PTR half*)alm);
  #endif
#else
  return alm[offset];
#endif
}
#endif
388
// Loads one 'realN' value of the B input matrix from on-chip local memory ('blm') into per-thread
// private memory (registers) and returns it. The element is selected by '_ni', the local index in
// dimension 1, and the K-position 'kg' within the tile.
#if SB == 1
INLINE_FUNC realN LocalToPrivateB(LOCAL_PTR memN* blm, const int _ni, const int kg) {

  // Position along N within the tile, based on strided/non-strided access
  #if STRN == 0
    const int ng = _ni + get_local_id(1)*(NWI/VWN);
  #elif STRN == 1
    const int ng = get_local_id(1) + _ni*NDIMC;
  #endif

  // Position in local memory, in units of VWN-wide elements
  const int offset = kg*(NWG/VWN) + ng;

  // With FP16 storage the buffer is read as 'half' data through the vloada_half functions;
  // otherwise the element is read from the buffer as-is
#ifdef FP16_STORAGE
  #if VWN == 1
    return vloada_half(offset, (LOCAL_PTR half*)blm);
  #elif VWN == 2
    return vloada_half2(offset, (LOCAL_PTR half*)blm);
  #elif VWN == 4
    return vloada_half4(offset, (LOCAL_PTR half*)blm);
  #elif VWN == 8
    return vloada_half8(offset, (LOCAL_PTR half*)blm);
  #elif VWN == 16
    return vloada_half16(offset, (LOCAL_PTR half*)blm);
  #endif
#else
  return blm[offset];
#endif
}
#endif
415
416// =================================================================================================
417
418// End of the C++11 raw string literal
419)"
420
421// =================================================================================================
422