1/*******************************************************************************
2 * Notes:
3 * for column major, id(0) is row so C data is coalesced
4 * for row major, id(0) is col
5 ******************************************************************************/
6
7static const char * zgemm_NT_64_32_8_16x16_2x4__ALPHABETA = "
8
9
// Convert the type preprocessor symbols to integers so that DATA_TYPE can be
// compared with #if/#elif below (the preprocessor cannot otherwise compare
// identifiers).
#define _S_ 1
#define _D_ 2
#define _C_ 3
#define _Z_ 4

/*******************************************************************************
 * Pre-Processor "Strings"
 ******************************************************************************/
#define COLUMN_MAJOR_STR      ColMajor
#define ROW_MAJOR_STR         RowMajor

/*******************************************************************************
 * Kernel PreProcessor Definitions
 ******************************************************************************/
// work-group shape: 16x16 = 256 work-items
#define WG_NUM_ROWS           16
#define WG_NUM_COLS           16
// each work-item accumulates a 2x4 micro tile of C in registers
#define MICRO_TILE_NUM_ROWS   2
#define MICRO_TILE_NUM_COLS   4
// the k loop is blocked/unrolled by this factor
#define NUM_UNROLL_ITER       8
#define ORDER                 ColMajor
#define TRANSPOSE_A           N
#define TRANSPOSE_B           T
// _Z_: double-precision complex (zgemm)
#define DATA_TYPE             _Z_

// macro tile per work-group:
//   rows = WG_NUM_ROWS*MICRO_TILE_NUM_ROWS = 32
//   cols = WG_NUM_COLS*MICRO_TILE_NUM_COLS = 64
#define MACRO_TILE_NUM_ROWS   32
#define MACRO_TILE_NUM_COLS   64
// each local-memory row lengthened by this amount
// (padding, presumably to stagger local-memory banks -- TODO confirm)
#define LOCAL_ROW_PAD         1
// each local-memory col lengthened by this amount
#define LOCAL_COL_PAD         1
41
42
43/*******************************************************************************
44 * Global Memory Indices
45 * Note: (a==b)==(c==d) means if both are true or neither is true
46 ******************************************************************************/
47
48/* col-major non-transposed
49 * row-major transposed */
50#define GET_GLOBAL_INDEX_N(ROW,COL,STRIDE) ((COL)*(STRIDE)+(ROW))
51
52/* col-major transposed
53 * row-major non-transposed */
54#define GET_GLOBAL_INDEX_T(ROW,COL,STRIDE) ((ROW)*(STRIDE)+(COL))
55
56// global A
57#if (ORDER==COLUMN_MAJOR_STR) == (TRANSPOSE_A==N)
58#define GET_GLOBAL_INDEX_A(ROW,COL) GET_GLOBAL_INDEX_N((ROW),(COL),(lda))
59#else
60#define GET_GLOBAL_INDEX_A(ROW,COL) GET_GLOBAL_INDEX_T((ROW),(COL),(lda))
61#endif
62
63// global B
64#if (ORDER==COLUMN_MAJOR_STR) == (TRANSPOSE_B==N)
65#define GET_GLOBAL_INDEX_B(ROW,COL) GET_GLOBAL_INDEX_T((ROW),(COL),(ldb))
66#else
67#define GET_GLOBAL_INDEX_B(ROW,COL) GET_GLOBAL_INDEX_N((ROW),(COL),(ldb))
68#endif
69
70// global C
71#if (ORDER==COLUMN_MAJOR_STR)
72#define GET_GLOBAL_INDEX_C(ROW,COL) GET_GLOBAL_INDEX_N((ROW),(COL),(ldc))
73#else
74#define GET_GLOBAL_INDEX_C(ROW,COL) GET_GLOBAL_INDEX_T((ROW),(COL),(ldc))
75#endif
76
77/*******************************************************************************
78 * Local Memory Indices
79 ******************************************************************************/
80
81// localA - rotated 90 degrees from B but use same accessor unless slow
82#define GET_LOCAL_INDEX_A(ROW,COL) (ROW + COL*(MACRO_TILE_NUM_ROWS+LOCAL_COL_PAD) )
83#define GET_LOCAL_STEP_A ( ((MACRO_TILE_NUM_COLS)+(LOCAL_ROW_PAD)) \
84    * ((WG_NUM_ROWS)*(WG_NUM_COLS)/(MACRO_TILE_NUM_COLS))
85
86// localB
87#define GET_LOCAL_INDEX_B(ROW,COL) ((COL) + (ROW)*((MACRO_TILE_NUM_COLS)+(LOCAL_ROW_PAD)) )
88#define GET_LOCAL_STEP_B ( ((MACRO_TILE_NUM_COLS)+(LOCAL_ROW_PAD)) \
89    * ((WG_NUM_ROWS)*(WG_NUM_COLS)/(MACRO_TILE_NUM_COLS))
90
91/*******************************************************************************
92 * Data Types
93 ******************************************************************************/
94
95// single precision
96#if DATA_TYPE==_S_
97#define DATA_TYPE_STR         float
98#define DATA_TYPE_CHAR        s
99#define TYPE_MAD(MUL0,MUL1,DST) DST = mad(MUL0,MUL1,DST);
100#define TYPE_MAD2( DST, ALPHA, REG, BETA ) DST = (ALPHA)*(REG) + (BETA)*(DST);
101
102// double precision
103#elif DATA_TYPE==_D_
104#define DATA_TYPE_STR         double
105#define DATA_TYPE_CHAR        d
106#define TYPE_MAD(MUL0,MUL1,DST) DST = mad(MUL0,MUL1,DST);
107#define TYPE_MAD2( DST, ALPHA, REG, BETA ) DST = (ALPHA)*(REG) + (BETA)*(DST);
108
109// complex single precision
110#elif DATA_TYPE==_C_
111#define DATA_TYPE_STR         float2
112#define DATA_TYPE_CHAR        c
113#define TYPE_MAD(MUL0,MUL1,DST) \
114  DST.s0 = mad(  MUL0.s0, MUL1.s0, DST.s0 ); \
115  DST.s0 = mad( -MUL0.s1, MUL1.s1, DST.s0 ); \
116  DST.s1 = mad(  MUL0.s0, MUL1.s1, DST.s1 ); \
117  DST.s1 = mad(  MUL0.s1, MUL1.s0, DST.s1 );
118#define TYPE_MAD2( DST, ALPHA, REG, BETA ) \
119  /* (1) */ \
120  type_mad2_tmp = REG.s0; \
121  REG.s0 *= ALPHA.s0; \
122  REG.s0 = mad( -ALPHA.s1, REG.s1, REG.s0 ); \
123  REG.s1 *= ALPHA.s0; \
124  REG.s1 = mad(  ALPHA.s1, type_mad2_tmp, REG.s1 ); \
125  /* (2) */ \
126  REG.s0 = mad(  BETA.s0, DST.s0, REG.s0 ); \
127  REG.s0 = mad( -BETA.s1, DST.s1, REG.s0 ); \
128  REG.s1 = mad(  BETA.s1, DST.s0, REG.s1 ); \
129  REG.s1 = mad(  BETA.s0, DST.s1, REG.s1 ); \
130  /* (3) */ \
131  DST = REG;
132
133// complex double precision
134#else
135#define DATA_TYPE_STR         double2
136#define DATA_TYPE_CHAR        z
137#define TYPE_MAD(MUL0,MUL1,DST) \
138  DST.s0 = mad(  MUL0.s0, MUL1.s0, DST.s0 ); \
139  DST.s0 = mad( -MUL0.s1, MUL1.s1, DST.s0 ); \
140  DST.s1 = mad(  MUL0.s0, MUL1.s1, DST.s1 ); \
141  DST.s1 = mad(  MUL0.s1, MUL1.s0, DST.s1 );
142#define TYPE_MAD2( DST, ALPHA, REG, BETA ) \
143  /* (1) */ \
144  type_mad2_tmp = REG.s0; \
145  REG.s0 *= ALPHA.s0; \
146  REG.s0 = mad( -ALPHA.s1, REG.s1, REG.s0 ); \
147  REG.s1 *= ALPHA.s0; \
148  REG.s1 = mad(  ALPHA.s1, type_mad2_tmp, REG.s1 ); \
149  /* (2) */ \
150  REG.s0 = mad(  BETA.s0, DST.s0, REG.s0 ); \
151  REG.s0 = mad( -BETA.s1, DST.s1, REG.s0 ); \
152  REG.s1 = mad(  BETA.s1, DST.s0, REG.s1 ); \
153  REG.s1 = mad(  BETA.s0, DST.s1, REG.s1 ); \
154  /* (3) */ \
155  DST = REG;
156
157#endif
158
159/*******************************************************************************
160 * 2x4 micro tile
161 ******************************************************************************/
162#define MAD2x4 \
163  rA[0] = localA[offA + 0*WG_NUM_ROWS]; \
164  rA[1] = localA[offA + 1*WG_NUM_ROWS]; \
165  rB[0] = localB[offB + 0*WG_NUM_COLS]; \
166  rB[1] = localB[offB + 1*WG_NUM_COLS]; \
167  rB[2] = localB[offB + 2*WG_NUM_COLS]; \
168  rB[3] = localB[offB + 3*WG_NUM_COLS]; \
169  offA += (MACRO_TILE_NUM_ROWS+LOCAL_COL_PAD); \
170  offB += (MACRO_TILE_NUM_COLS+LOCAL_ROW_PAD); \
171  TYPE_MAD(rA[0],rB[0],rC[0][0]); \
172  TYPE_MAD(rA[1],rB[0],rC[1][0]); \
173  TYPE_MAD(rA[0],rB[1],rC[0][1]); \
174  TYPE_MAD(rA[1],rB[1],rC[1][1]); \
175  TYPE_MAD(rA[0],rB[2],rC[0][2]); \
176  TYPE_MAD(rA[1],rB[2],rC[1][2]); \
177  TYPE_MAD(rA[0],rB[3],rC[0][3]); \
178  TYPE_MAD(rA[1],rB[3],rC[1][3]); \
179  mem_fence(CLK_LOCAL_MEM_FENCE);
180
181// concatenate kernel name
182// zgemm_NT_64_32_8_16x16_2x4__ALPHABETA
183#define CONCAT_NAME(DT,TA,TB,TILE_COLS,TILE_ROWS,NUI,WGR,WGC,MTR,MTC) \
184  DT ## gemm_ ## TA ## TB ## _ ## TILE_COLS ## _ ## TILE_ROWS ## _ ## NUI ## _ ## WGR ## x ## WGC ## _ ## MTR ## x ## MTC ## __ALPHABETA
185#define KERNEL_NAME(DT,TA,TB,TILE_COLS,TILE_ROWS,NUI,WGR,WGC,MTR,MTC) CONCAT_NAME(DT,TA,TB,TILE_COLS,TILE_ROWS,NUI,WGR,WGC,MTR,MTC)
186
187/*******************************************************************************
188 * Kernel
189 ******************************************************************************/
190__attribute__((reqd_work_group_size(WG_NUM_COLS,WG_NUM_ROWS,1)))
191__kernel void KERNEL_NAME(DATA_TYPE_CHAR,TRANSPOSE_A,TRANSPOSE_B,MACRO_TILE_NUM_COLS,MACRO_TILE_NUM_ROWS,NUM_UNROLL_ITER,WG_NUM_ROWS,WG_NUM_COLS,MICRO_TILE_NUM_ROWS,MICRO_TILE_NUM_COLS) (
192  uint const M,
193  uint const N,
194  uint const K,
195  DATA_TYPE_STR const alpha,
196  DATA_TYPE_STR const beta,
197  __global DATA_TYPE_STR const * restrict A,
198  __global DATA_TYPE_STR const * restrict B,
199  __global DATA_TYPE_STR       *          C,
200  uint const lda,
201  uint const ldb,
202  uint const ldc,
203  uint const offsetA,
204  uint const offsetB,
205  uint const offsetC )
206{
207  // apply offsets
208  A += offsetA;
209  B += offsetB;
210  C += offsetC;
211
212  // registers
213  DATA_TYPE_STR rC[MICRO_TILE_NUM_ROWS][MICRO_TILE_NUM_COLS]  = { {0} };
214  DATA_TYPE_STR rA[MICRO_TILE_NUM_ROWS];
215  DATA_TYPE_STR rB[MICRO_TILE_NUM_COLS];
216
217  // local memory
218  __local DATA_TYPE_STR localA[NUM_UNROLL_ITER*(MACRO_TILE_NUM_ROWS+LOCAL_COL_PAD)];
219  __local DATA_TYPE_STR localB[NUM_UNROLL_ITER*(MACRO_TILE_NUM_COLS+LOCAL_ROW_PAD)];
220
221/*
222 * for coalesced C writing
223 * if column major, id(0) is row
224 * if row major, id(0) is col
225 */
226  uint groupRow = get_group_id(0);
227  uint groupCol = get_group_id(1);
228  uint localRow = get_local_id(0);
229  uint localCol = get_local_id(1);
230  uint localSerial = localRow + localCol*WG_NUM_ROWS;
231
232  /*****************************************************************************
233   * global indices being loaded
234   ****************************************************************************/
235  // which gAij is this thread responsible for loading?
236#define globalARow (groupRow*MACRO_TILE_NUM_ROWS + localSerial%MACRO_TILE_NUM_ROWS)
237#define globalACol (localSerial/MACRO_TILE_NUM_ROWS)
238#define globalAIdx (GET_GLOBAL_INDEX_A( globalARow, globalACol ) )
239  A += globalAIdx;
240  // which gBij is this thread responsible for loading?
241#define globalBRow (localSerial/MACRO_TILE_NUM_COLS)
242#define globalBCol (groupCol*MACRO_TILE_NUM_COLS + localSerial%MACRO_TILE_NUM_COLS)
243#define globalBIdx (GET_GLOBAL_INDEX_B( globalBRow, globalBCol ) )
244  B += globalBIdx;
245
246  uint block_k = K / NUM_UNROLL_ITER;
247#pragma nounroll
248  do {
249
250    /***************************************************************************
251     * local indices being written
252     **************************************************************************/
253    // which lAij is this thread responsible for writing?
254#define localARow (localSerial % MACRO_TILE_NUM_ROWS)
255#define localACol (localSerial / MACRO_TILE_NUM_ROWS)
256#define localAStride ( (MACRO_TILE_NUM_ROWS+LOCAL_COL_PAD) * (WG_NUM_ROWS*WG_NUM_COLS/MACRO_TILE_NUM_ROWS) )
257#define globalAStride ( GET_GLOBAL_INDEX_A(0, (WG_NUM_ROWS*WG_NUM_COLS/MACRO_TILE_NUM_ROWS) ) )
258#define localAIdx ( GET_LOCAL_INDEX_A(localARow, localACol) )
259    __local DATA_TYPE_STR *lA = localA + localAIdx;
260    // which lBij is this thread responsible for writing?
261#define localBRow ( localSerial / MACRO_TILE_NUM_COLS )
262#define localBCol ( localSerial % MACRO_TILE_NUM_COLS )
263#define localBIdx ( GET_LOCAL_INDEX_B(localBRow, localBCol) )
264#define localBStride  ( (MACRO_TILE_NUM_COLS+LOCAL_ROW_PAD) * (WG_NUM_ROWS*WG_NUM_COLS/MACRO_TILE_NUM_COLS) )
265#define globalBStride ( GET_GLOBAL_INDEX_B( (WG_NUM_ROWS*WG_NUM_COLS/MACRO_TILE_NUM_COLS), 0 ) )
266    __local DATA_TYPE_STR *lB = localB + localBIdx;
267    barrier(CLK_LOCAL_MEM_FENCE);
268
269    /***************************************************************************
270     * Load global -> local
271     * num loads = num threads / total loads
272     **************************************************************************/
273    // 2x4 uTile x 8unroll
274    lA[ 0*localAStride ] = A[ 0*globalAStride ];
275    lB[ 0*localBStride ] = B[ 0*globalBStride ];
276    lB[ 1*localBStride ] = B[ 1*globalBStride ];
277    barrier(CLK_LOCAL_MEM_FENCE);
278
279    uint offA = localRow;
280    uint offB = localCol;
281
282    /***************************************************************************
283     * do mads in registers
284     **************************************************************************/
285    MAD2x4
286    MAD2x4
287    MAD2x4
288    MAD2x4
289    MAD2x4
290    MAD2x4
291    MAD2x4
292    MAD2x4
293
294    // fully shift
295    A += lda*NUM_UNROLL_ITER; // b/c N
296    B += ldb*NUM_UNROLL_ITER; // b/c T
297
298  } while (--block_k > 0);
299
300  // which global Cij is this thread responsible for computing?
301  uint globalCRow = groupRow * MACRO_TILE_NUM_ROWS + localRow;
302  uint globalCCol = groupCol * MACRO_TILE_NUM_COLS + localCol;
303
304  /***************************************************************************
305   * write data
306   **************************************************************************/
307  double type_mad2_tmp; // used in TYPE_MAD2
308  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+0*WG_NUM_ROWS, globalCCol+0*WG_NUM_COLS) ], alpha, rC[0][0], beta )
309  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+0*WG_NUM_ROWS, globalCCol+1*WG_NUM_COLS) ], alpha, rC[0][1], beta )
310  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+0*WG_NUM_ROWS, globalCCol+2*WG_NUM_COLS) ], alpha, rC[0][2], beta )
311  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+0*WG_NUM_ROWS, globalCCol+3*WG_NUM_COLS) ], alpha, rC[0][3], beta )
312  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+1*WG_NUM_ROWS, globalCCol+0*WG_NUM_COLS) ], alpha, rC[1][0], beta )
313  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+1*WG_NUM_ROWS, globalCCol+1*WG_NUM_COLS) ], alpha, rC[1][1], beta )
314  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+1*WG_NUM_ROWS, globalCCol+2*WG_NUM_COLS) ], alpha, rC[1][2], beta )
315  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+1*WG_NUM_ROWS, globalCCol+3*WG_NUM_COLS) ], alpha, rC[1][3], beta )
316
317}
318
319";
320