/*******************************************************************************
 * zgemm_NT_64_32_8_16x16_2x4__ALPHABETA
 *
 * OpenCL double-complex GEMM kernel source (A non-transposed, B transposed,
 * 64x32 macro tile, 8-deep unroll, 16x16 work-group, 2x4 micro tile, with
 * alpha and beta), embedded as a C string literal and compiled at runtime by
 * the OpenCL driver.
 *
 * Notes:
 * for column major, id(0) is row so C data is coalesced
 * for row major, id(0) is col
 ******************************************************************************/

static const char * zgemm_NT_64_32_8_16x16_2x4__ALPHABETA = "\n"
"\n"
"// convert preprocs to ints for comparison\n"
"#define _S_ 1\n"
"#define _D_ 2\n"
"#define _C_ 3\n"
"#define _Z_ 4\n"
"\n"
"/*******************************************************************************\n"
" * Pre-Processor \"Strings\"\n"
" ******************************************************************************/\n"
"#define COLUMN_MAJOR_STR ColMajor\n"
"#define ROW_MAJOR_STR RowMajor\n"
"\n"
"/*******************************************************************************\n"
" * Kernel PreProcessor Definitions\n"
" ******************************************************************************/\n"
"#define WG_NUM_ROWS 16\n"
"#define WG_NUM_COLS 16\n"
"#define MICRO_TILE_NUM_ROWS 2\n"
"#define MICRO_TILE_NUM_COLS 4\n"
"#define NUM_UNROLL_ITER 8\n"
"#define ORDER ColMajor\n"
"#define TRANSPOSE_A N\n"
"#define TRANSPOSE_B T\n"
"#define DATA_TYPE _Z_\n"
"\n"
"#define MACRO_TILE_NUM_ROWS 32\n"
"#define MACRO_TILE_NUM_COLS 64\n"
"// each row lengthened by this amount\n"
"#define LOCAL_ROW_PAD 1\n"
"// each col lengthened by this amount\n"
"#define LOCAL_COL_PAD 1\n"
"\n"
"\n"
"/*******************************************************************************\n"
" * Global Memory Indices\n"
" * Note: (a==b)==(c==d) means if both are true or neither is true\n"
" ******************************************************************************/\n"
"\n"
"/* col-major non-transposed\n"
" * row-major transposed */\n"
"#define GET_GLOBAL_INDEX_N(ROW,COL,STRIDE) ((COL)*(STRIDE)+(ROW))\n"
"\n"
"/* col-major transposed\n"
" * row-major non-transposed */\n"
"#define GET_GLOBAL_INDEX_T(ROW,COL,STRIDE) ((ROW)*(STRIDE)+(COL))\n"
"\n"
"// NOTE(review): ORDER, TRANSPOSE_*, and the *_STR tokens expand to bare\n"
"// identifiers; in #if arithmetic every undefined identifier evaluates to 0,\n"
"// so each comparison below is 0==0 and the #if branch is ALWAYS taken. The\n"
"// generator places the indexing intended for this instantiation in the taken\n"
"// branch, so this works -- but do not rely on the #else branches.\n"
"\n"
"// global A\n"
"#if (ORDER==COLUMN_MAJOR_STR) == (TRANSPOSE_A==N)\n"
"#define GET_GLOBAL_INDEX_A(ROW,COL) GET_GLOBAL_INDEX_N((ROW),(COL),(lda))\n"
"#else\n"
"#define GET_GLOBAL_INDEX_A(ROW,COL) GET_GLOBAL_INDEX_T((ROW),(COL),(lda))\n"
"#endif\n"
"\n"
"// global B\n"
"#if (ORDER==COLUMN_MAJOR_STR) == (TRANSPOSE_B==N)\n"
"#define GET_GLOBAL_INDEX_B(ROW,COL) GET_GLOBAL_INDEX_T((ROW),(COL),(ldb))\n"
"#else\n"
"#define GET_GLOBAL_INDEX_B(ROW,COL) GET_GLOBAL_INDEX_N((ROW),(COL),(ldb))\n"
"#endif\n"
"\n"
"// global C\n"
"#if (ORDER==COLUMN_MAJOR_STR)\n"
"#define GET_GLOBAL_INDEX_C(ROW,COL) GET_GLOBAL_INDEX_N((ROW),(COL),(ldc))\n"
"#else\n"
"#define GET_GLOBAL_INDEX_C(ROW,COL) GET_GLOBAL_INDEX_T((ROW),(COL),(ldc))\n"
"#endif\n"
"\n"
"/*******************************************************************************\n"
" * Local Memory Indices\n"
" ******************************************************************************/\n"
"\n"
"// localA - rotated 90 degrees from B but use same accessor unless slow\n"
"#define GET_LOCAL_INDEX_A(ROW,COL) ((ROW) + (COL)*(MACRO_TILE_NUM_ROWS+LOCAL_COL_PAD) )\n"
"// NOTE(review): GET_LOCAL_STEP_A/B are unused in this kernel (localAStride/\n"
"// localBStride below are used instead). GET_LOCAL_STEP_A appears to copy B's\n"
"// geometry (COLS/ROW_PAD) -- confirm against localAStride before using it.\n"
"#define GET_LOCAL_STEP_A ( ((MACRO_TILE_NUM_COLS)+(LOCAL_ROW_PAD)) \\\n"
"  * ((WG_NUM_ROWS)*(WG_NUM_COLS)/(MACRO_TILE_NUM_COLS)) )\n"
"\n"
"// localB\n"
"#define GET_LOCAL_INDEX_B(ROW,COL) ((COL) + (ROW)*((MACRO_TILE_NUM_COLS)+(LOCAL_ROW_PAD)) )\n"
"#define GET_LOCAL_STEP_B ( ((MACRO_TILE_NUM_COLS)+(LOCAL_ROW_PAD)) \\\n"
"  * ((WG_NUM_ROWS)*(WG_NUM_COLS)/(MACRO_TILE_NUM_COLS)) )\n"
"\n"
"/*******************************************************************************\n"
" * Data Types\n"
" ******************************************************************************/\n"
"\n"
"// single precision\n"
"#if DATA_TYPE==_S_\n"
"#define DATA_TYPE_STR float\n"
"#define DATA_TYPE_CHAR s\n"
"#define TYPE_MAD(MUL0,MUL1,DST) DST = mad(MUL0,MUL1,DST);\n"
"#define TYPE_MAD2( DST, ALPHA, REG, BETA ) DST = (ALPHA)*(REG) + (BETA)*(DST);\n"
"\n"
"// double precision\n"
"#elif DATA_TYPE==_D_\n"
"#define DATA_TYPE_STR double\n"
"#define DATA_TYPE_CHAR d\n"
"#define TYPE_MAD(MUL0,MUL1,DST) DST = mad(MUL0,MUL1,DST);\n"
"#define TYPE_MAD2( DST, ALPHA, REG, BETA ) DST = (ALPHA)*(REG) + (BETA)*(DST);\n"
"\n"
"// complex single precision\n"
"#elif DATA_TYPE==_C_\n"
"#define DATA_TYPE_STR float2\n"
"#define DATA_TYPE_CHAR c\n"
"#define TYPE_MAD(MUL0,MUL1,DST) \\\n"
"  DST.s0 = mad(  MUL0.s0, MUL1.s0, DST.s0 ); \\\n"
"  DST.s0 = mad( -MUL0.s1, MUL1.s1, DST.s0 ); \\\n"
"  DST.s1 = mad(  MUL0.s0, MUL1.s1, DST.s1 ); \\\n"
"  DST.s1 = mad(  MUL0.s1, MUL1.s0, DST.s1 );\n"
"#define TYPE_MAD2( DST, ALPHA, REG, BETA ) \\\n"
"  /* (1) REG = ALPHA*REG (complex multiply) */ \\\n"
"  type_mad2_tmp = REG.s0; \\\n"
"  REG.s0 *= ALPHA.s0; \\\n"
"  REG.s0 = mad( -ALPHA.s1, REG.s1, REG.s0 ); \\\n"
"  REG.s1 *= ALPHA.s0; \\\n"
"  REG.s1 = mad(  ALPHA.s1, type_mad2_tmp, REG.s1 ); \\\n"
"  /* (2) REG += BETA*DST (complex multiply-accumulate) */ \\\n"
"  REG.s0 = mad(  BETA.s0, DST.s0, REG.s0 ); \\\n"
"  REG.s0 = mad( -BETA.s1, DST.s1, REG.s0 ); \\\n"
"  REG.s1 = mad(  BETA.s1, DST.s0, REG.s1 ); \\\n"
"  REG.s1 = mad(  BETA.s0, DST.s1, REG.s1 ); \\\n"
"  /* (3) write back */ \\\n"
"  DST = REG;\n"
"\n"
"// complex double precision\n"
"#else\n"
"#define DATA_TYPE_STR double2\n"
"#define DATA_TYPE_CHAR z\n"
"#define TYPE_MAD(MUL0,MUL1,DST) \\\n"
"  DST.s0 = mad(  MUL0.s0, MUL1.s0, DST.s0 ); \\\n"
"  DST.s0 = mad( -MUL0.s1, MUL1.s1, DST.s0 ); \\\n"
"  DST.s1 = mad(  MUL0.s0, MUL1.s1, DST.s1 ); \\\n"
"  DST.s1 = mad(  MUL0.s1, MUL1.s0, DST.s1 );\n"
"#define TYPE_MAD2( DST, ALPHA, REG, BETA ) \\\n"
"  /* (1) REG = ALPHA*REG (complex multiply) */ \\\n"
"  type_mad2_tmp = REG.s0; \\\n"
"  REG.s0 *= ALPHA.s0; \\\n"
"  REG.s0 = mad( -ALPHA.s1, REG.s1, REG.s0 ); \\\n"
"  REG.s1 *= ALPHA.s0; \\\n"
"  REG.s1 = mad(  ALPHA.s1, type_mad2_tmp, REG.s1 ); \\\n"
"  /* (2) REG += BETA*DST (complex multiply-accumulate) */ \\\n"
"  REG.s0 = mad(  BETA.s0, DST.s0, REG.s0 ); \\\n"
"  REG.s0 = mad( -BETA.s1, DST.s1, REG.s0 ); \\\n"
"  REG.s1 = mad(  BETA.s1, DST.s0, REG.s1 ); \\\n"
"  REG.s1 = mad(  BETA.s0, DST.s1, REG.s1 ); \\\n"
"  /* (3) write back */ \\\n"
"  DST = REG;\n"
"\n"
"#endif\n"
"\n"
"/*******************************************************************************\n"
" * 2x4 micro tile: load 2 A-values and 4 B-values from local memory, advance\n"
" * the local offsets one unroll step, and accumulate all 8 products into rC.\n"
" ******************************************************************************/\n"
"#define MAD2x4 \\\n"
"  rA[0] = localA[offA + 0*WG_NUM_ROWS]; \\\n"
"  rA[1] = localA[offA + 1*WG_NUM_ROWS]; \\\n"
"  rB[0] = localB[offB + 0*WG_NUM_COLS]; \\\n"
"  rB[1] = localB[offB + 1*WG_NUM_COLS]; \\\n"
"  rB[2] = localB[offB + 2*WG_NUM_COLS]; \\\n"
"  rB[3] = localB[offB + 3*WG_NUM_COLS]; \\\n"
"  offA += (MACRO_TILE_NUM_ROWS+LOCAL_COL_PAD); \\\n"
"  offB += (MACRO_TILE_NUM_COLS+LOCAL_ROW_PAD); \\\n"
"  TYPE_MAD(rA[0],rB[0],rC[0][0]); \\\n"
"  TYPE_MAD(rA[1],rB[0],rC[1][0]); \\\n"
"  TYPE_MAD(rA[0],rB[1],rC[0][1]); \\\n"
"  TYPE_MAD(rA[1],rB[1],rC[1][1]); \\\n"
"  TYPE_MAD(rA[0],rB[2],rC[0][2]); \\\n"
"  TYPE_MAD(rA[1],rB[2],rC[1][2]); \\\n"
"  TYPE_MAD(rA[0],rB[3],rC[0][3]); \\\n"
"  TYPE_MAD(rA[1],rB[3],rC[1][3]); \\\n"
"  mem_fence(CLK_LOCAL_MEM_FENCE);\n"
"\n"
"// concatenate kernel name\n"
"// zgemm_NT_64_32_8_16x16_2x4__ALPHABETA\n"
"#define CONCAT_NAME(DT,TA,TB,TILE_COLS,TILE_ROWS,NUI,WGR,WGC,MTR,MTC) \\\n"
"  DT ## gemm_ ## TA ## TB ## _ ## TILE_COLS ## _ ## TILE_ROWS ## _ ## NUI ## _ ## WGR ## x ## WGC ## _ ## MTR ## x ## MTC ## __ALPHABETA\n"
"// extra indirection so the arguments are macro-expanded before pasting\n"
"#define KERNEL_NAME(DT,TA,TB,TILE_COLS,TILE_ROWS,NUI,WGR,WGC,MTR,MTC) CONCAT_NAME(DT,TA,TB,TILE_COLS,TILE_ROWS,NUI,WGR,WGC,MTR,MTC)\n"
"\n"
"/*******************************************************************************\n"
" * Kernel\n"
" ******************************************************************************/\n"
"__attribute__((reqd_work_group_size(WG_NUM_COLS,WG_NUM_ROWS,1)))\n"
"__kernel void KERNEL_NAME(DATA_TYPE_CHAR,TRANSPOSE_A,TRANSPOSE_B,MACRO_TILE_NUM_COLS,MACRO_TILE_NUM_ROWS,NUM_UNROLL_ITER,WG_NUM_ROWS,WG_NUM_COLS,MICRO_TILE_NUM_ROWS,MICRO_TILE_NUM_COLS) (\n"
"  uint const M,\n"
"  uint const N,\n"
"  uint const K,\n"
"  DATA_TYPE_STR const alpha,\n"
"  DATA_TYPE_STR const beta,\n"
"  __global DATA_TYPE_STR const * restrict A,\n"
"  __global DATA_TYPE_STR const * restrict B,\n"
"  __global DATA_TYPE_STR * C,\n"
"  uint const lda,\n"
"  uint const ldb,\n"
"  uint const ldc,\n"
"  uint const offsetA,\n"
"  uint const offsetB,\n"
"  uint const offsetC )\n"
"{\n"
"  // apply offsets\n"
"  A += offsetA;\n"
"  B += offsetB;\n"
"  C += offsetC;\n"
"\n"
"  // registers\n"
"  DATA_TYPE_STR rC[MICRO_TILE_NUM_ROWS][MICRO_TILE_NUM_COLS] = { {0} };\n"
"  DATA_TYPE_STR rA[MICRO_TILE_NUM_ROWS];\n"
"  DATA_TYPE_STR rB[MICRO_TILE_NUM_COLS];\n"
"\n"
"  // local memory\n"
"  __local DATA_TYPE_STR localA[NUM_UNROLL_ITER*(MACRO_TILE_NUM_ROWS+LOCAL_COL_PAD)];\n"
"  __local DATA_TYPE_STR localB[NUM_UNROLL_ITER*(MACRO_TILE_NUM_COLS+LOCAL_ROW_PAD)];\n"
"\n"
"/*\n"
" * for coalesced C writing\n"
" * if column major, id(0) is row\n"
" * if row major, id(0) is col\n"
" */\n"
"  uint groupRow = get_group_id(0);\n"
"  uint groupCol = get_group_id(1);\n"
"  uint localRow = get_local_id(0);\n"
"  uint localCol = get_local_id(1);\n"
"  uint localSerial = localRow + localCol*WG_NUM_ROWS;\n"
"\n"
"  /*****************************************************************************\n"
"   * global indices being loaded\n"
"   ****************************************************************************/\n"
"  // which gAij is this thread responsible for loading?\n"
"#define globalARow (groupRow*MACRO_TILE_NUM_ROWS + localSerial%MACRO_TILE_NUM_ROWS)\n"
"#define globalACol (localSerial/MACRO_TILE_NUM_ROWS)\n"
"#define globalAIdx (GET_GLOBAL_INDEX_A( globalARow, globalACol ) )\n"
"  A += globalAIdx;\n"
"  // which gBij is this thread responsible for loading?\n"
"#define globalBRow (localSerial/MACRO_TILE_NUM_COLS)\n"
"#define globalBCol (groupCol*MACRO_TILE_NUM_COLS + localSerial%MACRO_TILE_NUM_COLS)\n"
"#define globalBIdx (GET_GLOBAL_INDEX_B( globalBRow, globalBCol ) )\n"
"  B += globalBIdx;\n"
"\n"
"  // NOTE(review): do-while body runs at least once and block_k is K/8 --\n"
"  // assumes K is a positive multiple of NUM_UNROLL_ITER; confirm at call site\n"
"  uint block_k = K / NUM_UNROLL_ITER;\n"
"#pragma nounroll\n"
"  do {\n"
"\n"
"    /***************************************************************************\n"
"     * local indices being written\n"
"     **************************************************************************/\n"
"    // which lAij is this thread responsible for writing?\n"
"#define localARow (localSerial % MACRO_TILE_NUM_ROWS)\n"
"#define localACol (localSerial / MACRO_TILE_NUM_ROWS)\n"
"#define localAStride ( (MACRO_TILE_NUM_ROWS+LOCAL_COL_PAD) * (WG_NUM_ROWS*WG_NUM_COLS/MACRO_TILE_NUM_ROWS) )\n"
"#define globalAStride ( GET_GLOBAL_INDEX_A(0, (WG_NUM_ROWS*WG_NUM_COLS/MACRO_TILE_NUM_ROWS) ) )\n"
"#define localAIdx ( GET_LOCAL_INDEX_A(localARow, localACol) )\n"
"    __local DATA_TYPE_STR *lA = localA + localAIdx;\n"
"    // which lBij is this thread responsible for writing?\n"
"#define localBRow ( localSerial / MACRO_TILE_NUM_COLS )\n"
"#define localBCol ( localSerial % MACRO_TILE_NUM_COLS )\n"
"#define localBIdx ( GET_LOCAL_INDEX_B(localBRow, localBCol) )\n"
"#define localBStride ( (MACRO_TILE_NUM_COLS+LOCAL_ROW_PAD) * (WG_NUM_ROWS*WG_NUM_COLS/MACRO_TILE_NUM_COLS) )\n"
"#define globalBStride ( GET_GLOBAL_INDEX_B( (WG_NUM_ROWS*WG_NUM_COLS/MACRO_TILE_NUM_COLS), 0 ) )\n"
"    __local DATA_TYPE_STR *lB = localB + localBIdx;\n"
"    barrier(CLK_LOCAL_MEM_FENCE);\n"
"\n"
"    /***************************************************************************\n"
"     * Load global -> local\n"
"     * num loads = num threads / total loads\n"
"     **************************************************************************/\n"
"    // 2x4 uTile x 8unroll\n"
"    lA[ 0*localAStride ] = A[ 0*globalAStride ];\n"
"    lB[ 0*localBStride ] = B[ 0*globalBStride ];\n"
"    lB[ 1*localBStride ] = B[ 1*globalBStride ];\n"
"    barrier(CLK_LOCAL_MEM_FENCE);\n"
"\n"
"    uint offA = localRow;\n"
"    uint offB = localCol;\n"
"\n"
"    /***************************************************************************\n"
"     * do mads in registers -- one MAD2x4 per unroll iteration (NUM_UNROLL_ITER)\n"
"     **************************************************************************/\n"
"    MAD2x4\n"
"    MAD2x4\n"
"    MAD2x4\n"
"    MAD2x4\n"
"    MAD2x4\n"
"    MAD2x4\n"
"    MAD2x4\n"
"    MAD2x4\n"
"\n"
"    // fully shift\n"
"    A += lda*NUM_UNROLL_ITER; // b/c N\n"
"    B += ldb*NUM_UNROLL_ITER; // b/c T\n"
"\n"
"  } while (--block_k > 0);\n"
"\n"
"  // which global Cij is this thread responsible for computing?\n"
"  uint globalCRow = groupRow * MACRO_TILE_NUM_ROWS + localRow;\n"
"  uint globalCCol = groupCol * MACRO_TILE_NUM_COLS + localCol;\n"
"\n"
"  /***************************************************************************\n"
"   * write data: C = alpha*rC + beta*C for each element of the 2x4 micro tile\n"
"   **************************************************************************/\n"
"  double type_mad2_tmp; // used in TYPE_MAD2\n"
"  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+0*WG_NUM_ROWS, globalCCol+0*WG_NUM_COLS) ], alpha, rC[0][0], beta )\n"
"  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+0*WG_NUM_ROWS, globalCCol+1*WG_NUM_COLS) ], alpha, rC[0][1], beta )\n"
"  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+0*WG_NUM_ROWS, globalCCol+2*WG_NUM_COLS) ], alpha, rC[0][2], beta )\n"
"  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+0*WG_NUM_ROWS, globalCCol+3*WG_NUM_COLS) ], alpha, rC[0][3], beta )\n"
"  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+1*WG_NUM_ROWS, globalCCol+0*WG_NUM_COLS) ], alpha, rC[1][0], beta )\n"
"  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+1*WG_NUM_ROWS, globalCCol+1*WG_NUM_COLS) ], alpha, rC[1][1], beta )\n"
"  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+1*WG_NUM_ROWS, globalCCol+2*WG_NUM_COLS) ], alpha, rC[1][2], beta )\n"
"  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+1*WG_NUM_ROWS, globalCCol+3*WG_NUM_COLS) ], alpha, rC[1][3], beta )\n"
"\n"
"}\n"
"\n";