// =================================================================================================
// This file is part of the CLTune project, which loosely follows the Google C++ styleguide and uses
// a tab-size of two spaces and a max-width of 100 characters per line.
//
// Author: cedric.nugteren@surfsara.nl (Cedric Nugteren)
//
// This file contains an example OpenCL kernel as part of the gemm.cc example. It is an optimized
// matrix-multiplication kernel based on the paper by Matsumoto et al. and the tutorial on
// http://www.cedricnugteren.nl/tutorial.php. It is fully configurable (and tunable!) using more or
// less the same parameters/naming conventions as in the paper. It supports single and double
// precision (SGEMM/DGEMM) through a pre-processor define.
//
// Note: this kernel requires a compiler compliant with OpenCL 1.1 or higher.
//
// -------------------------------------------------------------------------------------------------
//
// Copyright 2014 SURFsara
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// =================================================================================================
//
// Matrices are accessed as follows:
// A: [k*M + m], with 'k' ranging from 0:K and 'm' from 0:M (m,k,m)
// B: [k*N + n], with 'k' ranging from 0:K and 'n' from 0:N (n,k,n)
// C: [n*M + m], with 'n' ranging from 0:N and 'm' from 0:M (m,n,m)
//
// Or as an image (assuming column-major)
//        K
//     o-------o
//     |       |
//   N | [B^T] |
//     |       |
//     o-------o
//        K               N
//     o-------o       o-----o
//   M |  [A]  |     M | [C] |
//     |       |       |     |
//     o-------o       o-----o
//
//
// Parameters determined by the tuner:
// MWG       : Tile-size in dimension M (e.g. 64, 128)
// NWG       : Tile-size in dimension N (e.g. 64, 128)
// KWG       : Tile-size in dimension K (e.g. 8, 16)
// MDIMC     : Threads per workgroup in M-dimension (e.g. 8, 16, 32)
// NDIMC     : Threads per workgroup in N-dimension (e.g. 8, 16, 32)
// MDIMA     : Re-shaped tile dimension of matrix A: KDIMA * MDIMA
// NDIMB     : Re-shaped tile dimension of matrix B: KDIMB * NDIMB
// KWI       : Unroll factor of the KWG loop (smaller than or equal to KWG)
// VWM       : Vector width of matrices A and C (supported: 1, 2, 4, 8, and 16)
// VWN       : Vector width of matrix B (supported: 1, 2, 4, 8, and 16)
// STRM      : Use strided access within a thread in the M-dimension (1) or not (0)
// STRN      : Use strided access within a thread in the N-dimension (1) or not (0)
// SA        : Use local/shared memory to cache matrix A (1) or not (0)
// SB        : Use local/shared memory to cache matrix B (1) or not (0)
// PRECISION : Whether to use single (32) or double (64) precision data-types
//
// =================================================================================================
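
// Note: the parameters above cannot be chosen independently. For the pre-processor arithmetic
// below to yield whole numbers, the following divisibility constraints (a sketch, not necessarily
// an exhaustive list) have to hold:
// - MWG has to be a multiple of MDIMC*VWM and of MDIMA*VWM
// - NWG has to be a multiple of NDIMC*VWN and of NDIMB*VWN
// - KWG has to be a multiple of KWI, and of KDIMA and KDIMB (defined below)
// - MDIMC*NDIMC has to be a multiple of MDIMA and of NDIMB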

// Helper parameters based on the above tuning parameters
#define MWI (MWG/MDIMC)                // Work per work-item (M-dimension)
#define NWI (NWG/NDIMC)                // Work per work-item (N-dimension)
#define KDIMA ((MDIMC*NDIMC)/(MDIMA))  // Re-shaped tile dimension of matrix A: KDIMA * MDIMA
#define KDIMB ((MDIMC*NDIMC)/(NDIMB))  // Re-shaped tile dimension of matrix B: KDIMB * NDIMB
#define MWA (MWG/MDIMA)                // Number of loads per work-item for matrix A (M-dimension)
#define KWA (KWG/KDIMA)                // Number of loads per work-item for matrix A (K-dimension)
#define KWB (KWG/KDIMB)                // Number of loads per work-item for matrix B (K-dimension)
#define NWB (NWG/NDIMB)                // Number of loads per work-item for matrix B (N-dimension)
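
// As a worked example, consider a hypothetical (untuned) configuration with MWG=64, NWG=64,
// KWG=16, MDIMC=16, NDIMC=16, MDIMA=16, NDIMB=16, VWM=4, and VWN=4. The helpers then evaluate to:
//   MWI   = 64/16      = 4   (each work-item computes a 4x4 block of C)
//   NWI   = 64/16      = 4
//   KDIMA = (16*16)/16 = 16  (the 16x16 work-items re-shape into a 16x16 grid to load A)
//   KDIMB = (16*16)/16 = 16
//   MWA   = 64/16      = 4
//   KWA   = 16/16      = 1   (so each work-item issues (MWA/VWM)*KWA = 1 vector load of A)
//   KWB   = 16/16      = 1
//   NWB   = 64/16      = 4   (and (NWB/VWN)*KWB = 1 vector load of B per KWG-tile)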

// Settings
#define USE_VECTOR_MAD 1  // Don't unroll the vector MAD computation into scalar MADs (0 does)
#define USE_CL_MAD 0      // Uses the non-IEEE-754 compliant OpenCL mad() (only if the above is 0)

// =================================================================================================

// Data-type: single or double precision
#if PRECISION == 32
  typedef float real;
  typedef float2 real2;
  typedef float4 real4;
  typedef float8 real8;
  typedef float16 real16;
  #define ZERO 0.0f
#elif PRECISION == 64
  #if __OPENCL_VERSION__ <= CL_VERSION_1_1  // Enabled by default on OpenCL 1.2 or higher
    #pragma OPENCL EXTENSION cl_khr_fp64: enable
  #endif
  typedef double real;
  typedef double2 real2;
  typedef double4 real4;
  typedef double8 real8;
  typedef double16 real16;
  #define ZERO 0.0
#endif

// =================================================================================================

// Data-widths in dimension M
#if VWM == 1
  typedef real realM;
#elif VWM == 2
  typedef real2 realM;
#elif VWM == 4
  typedef real4 realM;
#elif VWM == 8
  typedef real8 realM;
#elif VWM == 16
  typedef real16 realM;
#endif

// Data-widths in dimension N
#if VWN == 1
  typedef real realN;
#elif VWN == 2
  typedef real2 realN;
#elif VWN == 4
  typedef real4 realN;
#elif VWN == 8
  typedef real8 realN;
#elif VWN == 16
  typedef real16 realN;
#endif

// =================================================================================================

// Caches global off-chip memory into local (shared) memory on-chip. This function is specific for
// caching the A input matrix.
#if SA == 1
inline void GlobalToLocalA(const __global realM* restrict agm, __local realM* alm,
                           const int kSizeM, const int tid, const int kwg) {
  const int la0 = tid % MDIMA;
  const int la1 = tid / MDIMA;
  #pragma unroll
  for (int mia=0; mia<MWA/VWM; ++mia) {
    #pragma unroll
    for (int kia=0; kia<KWA; ++kia) {

      // Computes the indices based on strided/non-strided access
      #if STRM == 0
        int mg = mia + la0*(MWA/VWM);
      #elif STRM == 1
        int mg = la0 + mia*MDIMA;
      #endif

      // Computes the indices for the global memory
      int kg = kia + la1*KWA;
      int idm = mg + get_group_id(0)*(MWG/VWM);
      int idk = kg + kwg;

      // Loads the data from global memory (not transposed) into the local memory
      alm[kg*(MWG/VWM) + mg] = agm[idk*(kSizeM/VWM) + idm];
    }
  }
}
#endif

// Same as above, but now for the B input matrix
#if SB == 1
inline void GlobalToLocalB(const __global realN* restrict bgm, __local realN* blm,
                           const int kSizeN, const int tid, const int kwg) {
  const int lb0 = tid % NDIMB;
  const int lb1 = tid / NDIMB;
  #pragma unroll
  for (int kib=0; kib<KWB; ++kib) {
    #pragma unroll
    for (int nib=0; nib<NWB/VWN; ++nib) {

      // Computes the indices based on strided/non-strided access
      #if STRN == 0
        int ng = nib + lb0*(NWB/VWN);
      #elif STRN == 1
        int ng = lb0 + nib*NDIMB;
      #endif

      // Computes the indices for the global memory
      int kg = kib + lb1*KWB;
      int idn = ng + get_group_id(1)*(NWG/VWN);
      int idk = kg + kwg;

      // Loads the data from global memory (transposed) into the local memory
      blm[kg*(NWG/VWN) + ng] = bgm[idk*(kSizeN/VWN) + idn];
    }
  }
}
#endif
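
// Note on the two loaders above: the MDIMC*NDIMC work-items of a workgroup are re-interpreted as
// an MDIMA x KDIMA grid for loading A and an NDIMB x KDIMB grid for loading B, so that the whole
// workgroup cooperatively fills the KWG-deep tiles in local memory. The M/N-indices additionally
// depend on STRM/STRN: with strided access each work-item's loads are interleaved with those of
// the other work-items instead of forming one contiguous block.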

// =================================================================================================

// Caches global off-chip memory directly into per-thread private memory (registers). This function
// is specific for caching the A input matrix.
#if SA == 0
inline void GlobalToPrivateA(const __global realM* restrict agm, realM apm[MWI/VWM],
                             const int kSizeM, const int idk, const int kwg) {
  #pragma unroll
  for (int mi=0; mi<MWI/VWM; ++mi) {

    // Computes the indices based on strided/non-strided access
    #if STRM == 0
      int mg = mi + get_local_id(0)*(MWI/VWM);
    #elif STRM == 1
      int mg = get_local_id(0) + mi*MDIMC;
    #endif

    // Computes the indices for the global memory
    int idm = mg + get_group_id(0)*(MWG/VWM);

    // Loads the data from global memory (not transposed) and stores it into registers
    apm[mi] = agm[idk*(kSizeM/VWM) + idm];
  }
}
#endif

// Same as above, but now for the B input matrix
#if SB == 0
inline void GlobalToPrivateB(const __global realN* restrict bgm, realN bpm[NWI/VWN],
                             const int kSizeN, const int idk) {
  #pragma unroll
  for (int ni=0; ni<NWI/VWN; ++ni) {

    // Computes the indices based on strided/non-strided access
    #if STRN == 0
      int ng = ni + get_local_id(1)*(NWI/VWN);
    #elif STRN == 1
      int ng = get_local_id(1) + ni*NDIMC;
    #endif

    // Computes the indices for the global memory
    int idn = ng + get_group_id(1)*(NWG/VWN);

    // Loads the data from global memory (transposed) and stores it into registers
    bpm[ni] = bgm[idk*(kSizeN/VWN) + idn];
  }
}
#endif

// =================================================================================================

// Caches on-chip local memory into per-thread private memory (registers). This function is
// specific for caching the A input matrix.
#if SA == 1
inline void LocalToPrivateA(__local realM* alm, realM apm[MWI/VWM], const int kg) {
  #pragma unroll
  for (int mi=0; mi<MWI/VWM; ++mi) {
    #if STRM == 0
      int mg = mi + get_local_id(0)*(MWI/VWM);
    #elif STRM == 1
      int mg = get_local_id(0) + mi*MDIMC;
    #endif
    apm[mi] = alm[kg*(MWG/VWM) + mg];
  }
}
#endif

// Same as above, but now for the B input matrix
#if SB == 1
inline void LocalToPrivateB(__local realN* blm, realN bpm[NWI/VWN], const int kg) {
  #pragma unroll
  for (int ni=0; ni<NWI/VWN; ++ni) {
    #if STRN == 0
      int ng = ni + get_local_id(1)*(NWI/VWN);
    #elif STRN == 1
      int ng = get_local_id(1) + ni*NDIMC;
    #endif
    bpm[ni] = blm[kg*(NWG/VWN) + ng];
  }
}
#endif

// =================================================================================================

// Stores the results in Cpm into the global array Cgm
inline void StoreResults(__global realM* cgm, realM cpm[NWI][MWI/VWM], const int kSizeM) {
  #pragma unroll
  for (int ni=0; ni<NWI; ++ni) {
    #pragma unroll
    for (int mi=0; mi<MWI/VWM; ++mi) {
      #if STRM == 0
        int mg = mi + get_local_id(0)*(MWI/VWM);
      #elif STRM == 1
        int mg = get_local_id(0) + mi*MDIMC;
      #endif
      #if STRN == 0
        int ng = ni + get_local_id(1)*NWI;
      #elif STRN == 1
        int ng = ni%VWN + get_local_id(1)*VWN + (ni/VWN)*VWN*NDIMC;
      #endif
      int idm = mg + get_group_id(0)*(MWG/VWM);
      int idn = ng + get_group_id(1)*NWG;
      int index = idn*(kSizeM/VWM) + idm;
      cgm[index] = cpm[ni][mi];
    }
  }
}
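
// Note on the STRN == 1 index above: with strided N-access, a work-item 'j' holds its NWI results
// in VWN-sized groups originating from B-vectors that were fetched at stride NDIMC (see
// GlobalToPrivateB and LocalToPrivateB). The expression ni%VWN + j*VWN + (ni/VWN)*VWN*NDIMC undoes
// that layout: the position within a vector, plus the work-item's VWN-wide slot, plus the distance
// between successive vectors. E.g. with (hypothetical) VWN=2 and NDIMC=4, work-item j=1 writes
// columns 2, 3, 10, and 11 for ni=0..3, matching the columns it loaded from B.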

// =================================================================================================

// The basic scalar multiply-add function
#if USE_CL_MAD == 1
  #define MultiplyAdd(cval, aval, bval) (cval = mad(aval, bval, cval))
#else
  #define MultiplyAdd(cval, aval, bval) (cval += (aval) * (bval))
#endif

// The vectorised multiply-add function
inline realM MultiplyAddVector(realM cvec, const realM avec, const real bval) {
  #if USE_VECTOR_MAD == 1
    cvec += avec * bval;
  #else
    #if VWM == 1
      MultiplyAdd(cvec,    avec,    bval);
    #elif VWM == 2
      MultiplyAdd(cvec.x,  avec.x,  bval);
      MultiplyAdd(cvec.y,  avec.y,  bval);
    #elif VWM == 4
      MultiplyAdd(cvec.x,  avec.x,  bval);
      MultiplyAdd(cvec.y,  avec.y,  bval);
      MultiplyAdd(cvec.z,  avec.z,  bval);
      MultiplyAdd(cvec.w,  avec.w,  bval);
    #elif VWM == 8
      MultiplyAdd(cvec.s0, avec.s0, bval);
      MultiplyAdd(cvec.s1, avec.s1, bval);
      MultiplyAdd(cvec.s2, avec.s2, bval);
      MultiplyAdd(cvec.s3, avec.s3, bval);
      MultiplyAdd(cvec.s4, avec.s4, bval);
      MultiplyAdd(cvec.s5, avec.s5, bval);
      MultiplyAdd(cvec.s6, avec.s6, bval);
      MultiplyAdd(cvec.s7, avec.s7, bval);
    #elif VWM == 16
      MultiplyAdd(cvec.s0, avec.s0, bval);
      MultiplyAdd(cvec.s1, avec.s1, bval);
      MultiplyAdd(cvec.s2, avec.s2, bval);
      MultiplyAdd(cvec.s3, avec.s3, bval);
      MultiplyAdd(cvec.s4, avec.s4, bval);
      MultiplyAdd(cvec.s5, avec.s5, bval);
      MultiplyAdd(cvec.s6, avec.s6, bval);
      MultiplyAdd(cvec.s7, avec.s7, bval);
      MultiplyAdd(cvec.s8, avec.s8, bval);
      MultiplyAdd(cvec.s9, avec.s9, bval);
      MultiplyAdd(cvec.sA, avec.sA, bval);
      MultiplyAdd(cvec.sB, avec.sB, bval);
      MultiplyAdd(cvec.sC, avec.sC, bval);
      MultiplyAdd(cvec.sD, avec.sD, bval);
      MultiplyAdd(cvec.sE, avec.sE, bval);
      MultiplyAdd(cvec.sF, avec.sF, bval);
    #endif
  #endif
  return cvec;
}
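
// For reference, the MultiplyAccumulate function below is equivalent to the following scalar
// sketch (the case VWM == VWN == 1); the vector variants merely unpack the components of the
// B-vector so that each column of the per-thread register tile is updated with one scalar value:
//
//   for (int ni=0; ni<NWI; ++ni) {
//     for (int mi=0; mi<MWI; ++mi) {
//       cpm[ni][mi] += apm[mi] * bpm[ni];
//     }
//   }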

// Performs the actual computation: Cpm += Apm * Bpm
inline void MultiplyAccumulate(realM cpm[NWI][MWI/VWM], realM apm[MWI/VWM], realN bpm[NWI/VWN]) {
  #pragma unroll
  for (int ni=0; ni<NWI/VWN; ++ni) {
    #pragma unroll
    for (int mi=0; mi<MWI/VWM; ++mi) {
      #if VWN == 1
        cpm[ni*VWN + 0][mi] = MultiplyAddVector(cpm[ni*VWN + 0][mi], apm[mi], bpm[ni]);
      #elif VWN == 2
        cpm[ni*VWN + 0][mi] = MultiplyAddVector(cpm[ni*VWN + 0][mi], apm[mi], bpm[ni].x);
        cpm[ni*VWN + 1][mi] = MultiplyAddVector(cpm[ni*VWN + 1][mi], apm[mi], bpm[ni].y);
      #elif VWN == 4
        cpm[ni*VWN + 0][mi] = MultiplyAddVector(cpm[ni*VWN + 0][mi], apm[mi], bpm[ni].x);
        cpm[ni*VWN + 1][mi] = MultiplyAddVector(cpm[ni*VWN + 1][mi], apm[mi], bpm[ni].y);
        cpm[ni*VWN + 2][mi] = MultiplyAddVector(cpm[ni*VWN + 2][mi], apm[mi], bpm[ni].z);
        cpm[ni*VWN + 3][mi] = MultiplyAddVector(cpm[ni*VWN + 3][mi], apm[mi], bpm[ni].w);
      #elif VWN == 8
        cpm[ni*VWN + 0][mi] = MultiplyAddVector(cpm[ni*VWN + 0][mi], apm[mi], bpm[ni].s0);
        cpm[ni*VWN + 1][mi] = MultiplyAddVector(cpm[ni*VWN + 1][mi], apm[mi], bpm[ni].s1);
        cpm[ni*VWN + 2][mi] = MultiplyAddVector(cpm[ni*VWN + 2][mi], apm[mi], bpm[ni].s2);
        cpm[ni*VWN + 3][mi] = MultiplyAddVector(cpm[ni*VWN + 3][mi], apm[mi], bpm[ni].s3);
        cpm[ni*VWN + 4][mi] = MultiplyAddVector(cpm[ni*VWN + 4][mi], apm[mi], bpm[ni].s4);
        cpm[ni*VWN + 5][mi] = MultiplyAddVector(cpm[ni*VWN + 5][mi], apm[mi], bpm[ni].s5);
        cpm[ni*VWN + 6][mi] = MultiplyAddVector(cpm[ni*VWN + 6][mi], apm[mi], bpm[ni].s6);
        cpm[ni*VWN + 7][mi] = MultiplyAddVector(cpm[ni*VWN + 7][mi], apm[mi], bpm[ni].s7);
      #elif VWN == 16
        cpm[ni*VWN + 0 ][mi] = MultiplyAddVector(cpm[ni*VWN + 0 ][mi], apm[mi], bpm[ni].s0);
        cpm[ni*VWN + 1 ][mi] = MultiplyAddVector(cpm[ni*VWN + 1 ][mi], apm[mi], bpm[ni].s1);
        cpm[ni*VWN + 2 ][mi] = MultiplyAddVector(cpm[ni*VWN + 2 ][mi], apm[mi], bpm[ni].s2);
        cpm[ni*VWN + 3 ][mi] = MultiplyAddVector(cpm[ni*VWN + 3 ][mi], apm[mi], bpm[ni].s3);
        cpm[ni*VWN + 4 ][mi] = MultiplyAddVector(cpm[ni*VWN + 4 ][mi], apm[mi], bpm[ni].s4);
        cpm[ni*VWN + 5 ][mi] = MultiplyAddVector(cpm[ni*VWN + 5 ][mi], apm[mi], bpm[ni].s5);
        cpm[ni*VWN + 6 ][mi] = MultiplyAddVector(cpm[ni*VWN + 6 ][mi], apm[mi], bpm[ni].s6);
        cpm[ni*VWN + 7 ][mi] = MultiplyAddVector(cpm[ni*VWN + 7 ][mi], apm[mi], bpm[ni].s7);
        cpm[ni*VWN + 8 ][mi] = MultiplyAddVector(cpm[ni*VWN + 8 ][mi], apm[mi], bpm[ni].s8);
        cpm[ni*VWN + 9 ][mi] = MultiplyAddVector(cpm[ni*VWN + 9 ][mi], apm[mi], bpm[ni].s9);
        cpm[ni*VWN + 10][mi] = MultiplyAddVector(cpm[ni*VWN + 10][mi], apm[mi], bpm[ni].sA);
        cpm[ni*VWN + 11][mi] = MultiplyAddVector(cpm[ni*VWN + 11][mi], apm[mi], bpm[ni].sB);
        cpm[ni*VWN + 12][mi] = MultiplyAddVector(cpm[ni*VWN + 12][mi], apm[mi], bpm[ni].sC);
        cpm[ni*VWN + 13][mi] = MultiplyAddVector(cpm[ni*VWN + 13][mi], apm[mi], bpm[ni].sD);
        cpm[ni*VWN + 14][mi] = MultiplyAddVector(cpm[ni*VWN + 14][mi], apm[mi], bpm[ni].sE);
        cpm[ni*VWN + 15][mi] = MultiplyAddVector(cpm[ni*VWN + 15][mi], apm[mi], bpm[ni].sF);
      #endif
    }
  }
}

// =================================================================================================
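
// In outline, the kernel below proceeds as follows (a sketch of the code path, not extra
// functionality):
//   initialise the accumulation registers cpm to zero
//   for each KWG-wide tile of the K-dimension:
//     (if SA/SB) cooperatively load the A/B tiles into local memory, then barrier
//     for each k within the tile (unrolled by a factor KWI):
//       fetch a column-slice of A and a row-slice of B into registers
//       cpm += apm * bpm
//     (if SA/SB) barrier again before the next iteration overwrites the local memory
//   write the MWG * NWG tile of results back to C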

// Main entry of the kernel. This function contains the basic skeleton; the functionality is
// provided by the inlined functions above.
__attribute__((reqd_work_group_size(MDIMC, NDIMC, 1)))
__kernel void gemm_fast(const int kSizeM, const int kSizeN, const int kSizeK,
                        const __global realM* restrict agm,
                        const __global realN* restrict bgm,
                        __global realM* cgm) {

  // Combined thread identifier
  #if SA == 1 || SB == 1
    volatile int tid = get_local_id(0) + MDIMC*get_local_id(1);
  #endif

  // Allocates workgroup-private memory (local memory)
  #if SA == 1
    __local realM alm[KWG * MWG/VWM];
  #endif
  #if SB == 1
    __local realN blm[KWG * NWG/VWN];
  #endif

  // Allocates workitem-private memory (registers)
  realM apm[MWI/VWM];
  realN bpm[NWI/VWN];
  realM cpm[NWI][MWI/VWM];

  // Initializes the accumulation registers
  #pragma unroll
  for (int mi=0; mi<MWI/VWM; ++mi) {
    #pragma unroll
    for (int ni=0; ni<NWI; ++ni) {
      cpm[ni][mi] = (realM)ZERO;
    }
  }

  // Loops over all workgroup tiles
  for (int kwg=0; kwg<kSizeK; kwg+=KWG) {

    // Loads data: off-chip --> local (matrix A)
    #if SA == 1
      GlobalToLocalA(agm, alm, kSizeM, tid, kwg);
    #endif
    // Loads data: off-chip --> local (matrix B)
    #if SB == 1
      GlobalToLocalB(bgm, blm, kSizeN, tid, kwg);
    #endif

    // Synchronizes all threads in a workgroup
    #if SA == 1 || SB == 1
      barrier(CLK_LOCAL_MEM_FENCE);
    #endif

    // Loops over all workitem tiles, unrolled by a factor KWI
    for (int pwi=0; pwi<KWG; pwi+=KWI) {
      #pragma unroll
      for (int pit=0; pit<KWI; ++pit) {
        #if SA == 0 || SB == 0
          int idk = kwg + pwi + pit;
        #endif
        #if SA == 1 || SB == 1
          int kg = pwi + pit;
        #endif

        // Loads data: local --> private (matrix A)
        #if SA == 1
          LocalToPrivateA(alm, apm, kg);
        // Loads data: off-chip --> private (matrix A)
        #else
          GlobalToPrivateA(agm, apm, kSizeM, idk, kwg);
        #endif

        // Loads data: local --> private (matrix B)
        #if SB == 1
          LocalToPrivateB(blm, bpm, kg);
        // Loads data: off-chip --> private (matrix B)
        #else
          GlobalToPrivateB(bgm, bpm, kSizeN, idk);
        #endif

        // Performs the accumulation (Cpm += Apm * Bpm)
        MultiplyAccumulate(cpm, apm, bpm);
      }
    }

    // Synchronizes all threads in a workgroup
    #if SA == 1 || SB == 1
      barrier(CLK_LOCAL_MEM_FENCE);
    #endif
  }

  // Stores an MWG * NWG tile of results
  StoreResults(cgm, cpm, kSizeM);
}

// =================================================================================================
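
// Example host-side usage (a sketch with hypothetical values; the actual values are chosen by the
// tuner in the gemm.cc example): the tuning parameters are supplied as pre-processor definitions
// when building the program, e.g.:
//
//   const char* options =
//       "-DMWG=64 -DNWG=64 -DKWG=16 -DMDIMC=16 -DNDIMC=16 -DMDIMA=16 -DNDIMB=16 "
//       "-DKWI=2 -DVWM=4 -DVWN=4 -DSTRM=0 -DSTRN=0 -DSA=1 -DSB=1 -DPRECISION=32";
//   clBuildProgram(program, 1, &device, options, NULL, NULL);
//
// Since each workgroup of MDIMC*NDIMC work-items computes an MWG*NWG tile of C, a matching launch
// uses a local size of (MDIMC, NDIMC) and a global size of (kSizeM*MDIMC/MWG, kSizeN*NDIMC/NWG).

// =================================================================================================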