// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
//   Cedric Nugteren <www.cedricnugteren.nl>
//
// This file contains an optimized matrix-multiplication kernel inspired by the paper by Matsumoto
// et al. and the tutorial on http://www.cedricnugteren.nl/tutorial.php. It is fully configurable
// (and tunable!) using more or less the same parameters/naming conventions as in the paper. It
// supports different data-types (SGEMM/DGEMM/CGEMM/ZGEMM/HGEMM) through a pre-processor define.
//
// Matrices are accessed as follows:
// A: [k*M + m], with 'k' ranging from 0:K and 'm' from 0:M (m,k,m)
// B: [k*N + n], with 'k' ranging from 0:K and 'n' from 0:N (n,k,n)
// C: [n*M + m], with 'n' ranging from 0:N and 'm' from 0:M (m,n,m)
//
// Or as an image (assuming column-major)
//       K
//    o-------o
//    |       |
//  N | [B^T] |
//    |       |
//    o-------o
//        K               N
//    o-------o        o-----o
//  M |  [A]  |      M | [C] |
//    |       |        |     |
//    o-------o        o-----o
//
// This kernel is separated into four files. This is part 1 out of 4.
//
// =================================================================================================

// Enables loading of this file using the C++ pre-processor's #include (C++11 standard raw string
// literal). Comment-out this line for syntax-highlighting when developing.
R"(

// =================================================================================================

// Parameters set by the tuner or by the database. Here they are given a basic default value in case
// this kernel file is used outside of the CLBlast library.
#ifndef MWG
  #define MWG 8      // Tile-size in dimension M (e.g. 64, 128)
#endif
#ifndef NWG
  #define NWG 8      // Tile-size in dimension N (e.g. 64, 128)
#endif
#ifndef KWG
  #define KWG 8      // Tile-size in dimension K (e.g. 8, 16)
#endif
#ifndef MDIMC
  #define MDIMC 8    // Threads per workgroup in M-dimension (e.g. 8, 16, 32)
#endif
#ifndef NDIMC
  #define NDIMC 8    // Threads per workgroup in N-dimension (e.g. 8, 16, 32)
#endif
#ifndef MDIMA
  #define MDIMA 8    // Re-shaped tile dimension of matrix A: KDIMA * MDIMA
#endif
#ifndef NDIMB
  #define NDIMB 8    // Re-shaped tile dimension of matrix B: KDIMB * NDIMB
#endif
#ifndef KWI
  #define KWI 1      // Unroll factor of the KWG loop (smaller or equal than KWG)
#endif
#ifndef VWM
  #define VWM 1      // Vector width of matrices A and C
#endif
#ifndef VWN
  #define VWN 1      // Vector width of matrix B
#endif
#ifndef STRM
  #define STRM 0     // Use strided access within a thread in the M-dimension (1) or not (0)
#endif
#ifndef STRN
  #define STRN 0     // Use strided access within a thread in the N-dimension (1) or not (0)
#endif
#ifndef SA
  #define SA 0       // Use local/shared memory to cache matrix A (1) or not (0)
#endif
#ifndef SB
  #define SB 0       // Use local/shared memory to cache matrix B (1) or not (0)
#endif

// Helper parameters based on the above tuning parameters
#define MWI (MWG/MDIMC)                // Work per work-item (M-dimension)
#define NWI (NWG/NDIMC)                // Work per work-item (N-dimension)
#define KDIMA ((MDIMC*NDIMC)/(MDIMA))  // Re-shaped tile dimension of matrix A: KDIMA * MDIMA
#define KDIMB ((MDIMC*NDIMC)/(NDIMB))  // Re-shaped tile dimension of matrix B: KDIMB * NDIMB
#define MWA (MWG/MDIMA)                // Amount of loads-per-thread for matrix A (M-dimension)
#define KWA (KWG/KDIMA)                // Amount of loads-per-thread for matrix A (K-dimension)
#define KWB (KWG/KDIMB)                // Amount of loads-per-thread for matrix B (K-dimension)
#define NWB (NWG/NDIMB)                // Amount of loads-per-thread for matrix B (N-dimension)

// Settings
#ifndef USE_VECTOR_MAD
  #define USE_VECTOR_MAD 0      // Unroll (0) or don't (1) unroll the vector MAD manually
#endif
#ifndef GLOBAL_MEM_FENCE
  #define GLOBAL_MEM_FENCE 0    // Global synchronisation barrier for potential better performance
#endif

// =================================================================================================

// Data-widths in dimension M: 'realM' is the compute type and 'memM' the in-memory type. When
// FP16_STORAGE is defined, data is stored as 16-bit values (declared 'short' here, reinterpreted
// as 'half' at load time) while arithmetic still happens in 'real'. An explicit #error makes
// unsupported vector widths fail with a clear message instead of an undefined-type error.
#ifdef FP16_STORAGE
  #if VWM == 1
    typedef real realM;
    typedef short memM;
  #elif VWM == 2
    typedef real2 realM;
    typedef short2 memM;
  #elif VWM == 4
    typedef real4 realM;
    typedef short4 memM;
  #elif VWM == 8
    typedef real8 realM;
    typedef short8 memM;
  #elif VWM == 16
    typedef real16 realM;
    typedef short16 memM;
  #else
    #error Unsupported VWM value: must be 1, 2, 4, 8 or 16
  #endif
#else
  #if VWM == 1
    typedef real realM;
    typedef real memM;
  #elif VWM == 2
    typedef real2 realM;
    typedef real2 memM;
  #elif VWM == 4
    typedef real4 realM;
    typedef real4 memM;
  #elif VWM == 8
    typedef real8 realM;
    typedef real8 memM;
  #elif VWM == 16
    typedef real16 realM;
    typedef real16 memM;
  #else
    #error Unsupported VWM value: must be 1, 2, 4, 8 or 16
  #endif
#endif

// Data-widths in dimension N (same scheme as above, but for matrix B)
#ifdef FP16_STORAGE
  #if VWN == 1
    typedef real realN;
    typedef short memN;
  #elif VWN == 2
    typedef real2 realN;
    typedef short2 memN;
  #elif VWN == 4
    typedef real4 realN;
    typedef short4 memN;
  #elif VWN == 8
    typedef real8 realN;
    typedef short8 memN;
  #elif VWN == 16
    typedef real16 realN;
    typedef short16 memN;
  #else
    #error Unsupported VWN value: must be 1, 2, 4, 8 or 16
  #endif
#else
  #if VWN == 1
    typedef real realN;
    typedef real memN;
  #elif VWN == 2
    typedef real2 realN;
    typedef real2 memN;
  #elif VWN == 4
    typedef real4 realN;
    typedef real4 memN;
  #elif VWN == 8
    typedef real8 realN;
    typedef real8 memN;
  #elif VWN == 16
    typedef real16 realN;
    typedef real16 memN;
  #else
    #error Unsupported VWN value: must be 1, 2, 4, 8 or 16
  #endif
#endif

// =================================================================================================
// Initializes the accumulation registers to zero. Each component of the VWM-wide vector is
// cleared individually via the project's SetToZero macro (which also handles complex types).
INLINE_FUNC realM InitAccRegisters() {
  realM acc;
  #if VWM == 1
    SetToZero(acc);
  #elif VWM == 2
    SetToZero(acc.x);
    SetToZero(acc.y);
  #elif VWM == 4
    SetToZero(acc.x);
    SetToZero(acc.y);
    SetToZero(acc.z);
    SetToZero(acc.w);
  #elif VWM == 8
    SetToZero(acc.s0);
    SetToZero(acc.s1);
    SetToZero(acc.s2);
    SetToZero(acc.s3);
    SetToZero(acc.s4);
    SetToZero(acc.s5);
    SetToZero(acc.s6);
    SetToZero(acc.s7);
  #elif VWM == 16
    SetToZero(acc.s0);
    SetToZero(acc.s1);
    SetToZero(acc.s2);
    SetToZero(acc.s3);
    SetToZero(acc.s4);
    SetToZero(acc.s5);
    SetToZero(acc.s6);
    SetToZero(acc.s7);
    SetToZero(acc.s8);
    SetToZero(acc.s9);
    SetToZero(acc.sA);
    SetToZero(acc.sB);
    SetToZero(acc.sC);
    SetToZero(acc.sD);
    SetToZero(acc.sE);
    SetToZero(acc.sF);
  #endif
  return acc;
}

// =================================================================================================

// Caches global off-chip memory into local (shared) memory on-chip. This function is specific for
// caching the A input matrix.
#if SA == 1
INLINE_FUNC void GlobalToLocalA(const __global memM* restrict agm, LOCAL_PTR memM* alm,
                                const int kSizeM, const int tid, const int kwg) {
  // Splits the flat thread id over the re-shaped KDIMA * MDIMA loading tile of matrix A
  const int tid_m = tid % MDIMA;
  const int tid_k = tid / MDIMA;
  #pragma unroll
  for (int _mia = 0; _mia < MWA/VWM; _mia += 1) {
    #pragma unroll
    for (int _kia = 0; _kia < KWA; _kia += 1) {

      // Position within the workgroup tile: contiguous (STRM == 0) or strided (STRM == 1)
      // per-thread access in the M-dimension
      #if STRM == 0
        const int m_local = _mia + tid_m*(MWA/VWM);
      #elif STRM == 1
        const int m_local = tid_m + _mia*MDIMA;
      #endif
      const int k_local = _kia + tid_k*KWA;

      // Position in the global matrix; M-indices count vectors of width VWM
      const int m_global = m_local + GetGroupID0() * (MWG/VWM);
      const int k_global = k_local + kwg;

      // Copies one vector of A (not transposed) from global into local memory
      alm[k_local*(MWG/VWM) + m_local] = agm[k_global*(kSizeM/VWM) + m_global];
    }
  }
}
#endif

// Same as above, but now for the B input matrix
#if SB == 1
INLINE_FUNC void GlobalToLocalB(const __global memN* restrict bgm, LOCAL_PTR memN* blm,
                                const int kSizeN, const int tid, const int kwg) {
  // Splits the flat thread id over the re-shaped KDIMB * NDIMB loading tile of matrix B
  const int tid_n = tid % NDIMB;
  const int tid_k = tid / NDIMB;
  #pragma unroll
  for (int _kib = 0; _kib < KWB; _kib += 1) {
    #pragma unroll
    for (int _nib = 0; _nib < NWB/VWN; _nib += 1) {

      // Position within the workgroup tile: contiguous (STRN == 0) or strided (STRN == 1)
      // per-thread access in the N-dimension
      #if STRN == 0
        const int n_local = _nib + tid_n*(NWB/VWN);
      #elif STRN == 1
        const int n_local = tid_n + _nib*NDIMB;
      #endif
      const int k_local = _kib + tid_k*KWB;

      // Position in the global matrix; N-indices count vectors of width VWN
      const int n_global = n_local + GetGroupID1() * (NWG/VWN);
      const int k_global = k_local + kwg;

      // Copies one vector of B (transposed) from global into local memory
      blm[k_local*(NWG/VWN) + n_local] = bgm[k_global*(kSizeN/VWN) + n_global];
    }
  }
}
#endif

// =================================================================================================
// Caches global off-chip memory directly into per-thread private memory (registers). This function
// is specific for caching the A input matrix.
#if SA == 0
INLINE_FUNC realM GlobalToPrivateA(const __global memM* restrict agm, const int _mi,
                                   const int kSizeM, const int idk, const int kwg) {
  // Position within the workgroup tile: contiguous (STRM == 0) or strided (STRM == 1).
  // NOTE(review): 'kwg' is accepted for interface symmetry but is not used in this body.
  #if STRM == 0
    const int mg = _mi + get_local_id(0)*(MWI/VWM);
  #elif STRM == 1
    const int mg = get_local_id(0) + _mi*MDIMC;
  #endif

  // Position in the global matrix (in units of vectors of width VWM)
  const int idm = mg + GetGroupID0() * (MWG/VWM);
  const int index = idk*(kSizeM/VWM) + idm;

  // Loads one vector of A (not transposed) from global memory into a register. With FP16 storage
  // the 16-bit values are converted on load via the vloada_half* built-ins.
#ifdef FP16_STORAGE
  #if VWM == 1
    return vloada_half(index, (const __global half*)agm);
  #elif VWM == 2
    return vloada_half2(index, (const __global half*)agm);
  #elif VWM == 4
    return vloada_half4(index, (const __global half*)agm);
  #elif VWM == 8
    return vloada_half8(index, (const __global half*)agm);
  #elif VWM == 16
    return vloada_half16(index, (const __global half*)agm);
  #endif
#else
  return agm[index];
#endif
}
#endif

// Same as above, but now for the B input matrix
#if SB == 0
INLINE_FUNC realN GlobalToPrivateB(const __global memN* restrict bgm, const int _ni,
                                   const int kSizeN, const int idk) {
  // Position within the workgroup tile: contiguous (STRN == 0) or strided (STRN == 1)
  #if STRN == 0
    const int ng = _ni + get_local_id(1)*(NWI/VWN);
  #elif STRN == 1
    const int ng = get_local_id(1) + _ni*NDIMC;
  #endif

  // Position in the global matrix (in units of vectors of width VWN)
  const int idn = ng + GetGroupID1() * (NWG/VWN);
  const int index = idk*(kSizeN/VWN) + idn;

  // Loads one vector of B (transposed) from global memory into a register
#ifdef FP16_STORAGE
  #if VWN == 1
    return vloada_half(index, (const __global half*)bgm);
  #elif VWN == 2
    return vloada_half2(index, (const __global half*)bgm);
  #elif VWN == 4
    return vloada_half4(index, (const __global half*)bgm);
  #elif VWN == 8
    return vloada_half8(index, (const __global half*)bgm);
  #elif VWN == 16
    return vloada_half16(index, (const __global half*)bgm);
  #endif
#else
  return bgm[index];
#endif
}
#endif

// =================================================================================================

// Caches on-chip local memory into per-thread private memory (registers). This function is specific
// for caching the A input matrix.
#if SA == 1
INLINE_FUNC realM LocalToPrivateA(LOCAL_PTR memM* alm, const int _mi, const int kg) {
  // Position within the workgroup tile: contiguous (STRM == 0) or strided (STRM == 1)
  #if STRM == 0
    const int mg = _mi + get_local_id(0)*(MWI/VWM);
  #elif STRM == 1
    const int mg = get_local_id(0) + _mi*MDIMC;
  #endif
  const int index = kg*(MWG/VWM) + mg;

  // Loads one vector of A from local memory into a register
#ifdef FP16_STORAGE
  #if VWM == 1
    return vloada_half(index, (LOCAL_PTR half*)alm);
  #elif VWM == 2
    return vloada_half2(index, (LOCAL_PTR half*)alm);
  #elif VWM == 4
    return vloada_half4(index, (LOCAL_PTR half*)alm);
  #elif VWM == 8
    return vloada_half8(index, (LOCAL_PTR half*)alm);
  #elif VWM == 16
    return vloada_half16(index, (LOCAL_PTR half*)alm);
  #endif
#else
  return alm[index];
#endif
}
#endif

// Same as above, but now for the B input matrix
#if SB == 1
INLINE_FUNC realN LocalToPrivateB(LOCAL_PTR memN* blm, const int _ni, const int kg) {
  // Position within the workgroup tile: contiguous (STRN == 0) or strided (STRN == 1)
  #if STRN == 0
    const int ng = _ni + get_local_id(1)*(NWI/VWN);
  #elif STRN == 1
    const int ng = get_local_id(1) + _ni*NDIMC;
  #endif
  const int index = kg*(NWG/VWN) + ng;

  // Loads one vector of B from local memory into a register
#ifdef FP16_STORAGE
  #if VWN == 1
    return vloada_half(index, (LOCAL_PTR half*)blm);
  #elif VWN == 2
    return vloada_half2(index, (LOCAL_PTR half*)blm);
  #elif VWN == 4
    return vloada_half4(index, (LOCAL_PTR half*)blm);
  #elif VWN == 8
    return vloada_half8(index, (LOCAL_PTR half*)blm);
  #elif VWN == 16
    return vloada_half16(index, (LOCAL_PTR half*)blm);
  #endif
#else
  return blm[index];
#endif
}
#endif

// =================================================================================================

// End of the C++11 raw string literal
)"

// =================================================================================================