/*******************************************************************************
 * Hand-tuned kernel
 *
 * Single-precision GEMM kernel source (OpenCL C) stored as a host string.
 * Judging by the name and the indexing, the matrices are column-major and
 * neither A nor B is transposed: C = alpha * A * B + beta * C.
 *
 * Each 16x16 work-group produces one 96x96 tile of C; each work-item keeps a
 * 6x6 block of accumulators whose elements are 16 apart in both directions.
 * The kernel never reads M or N and performs no bounds checks, so the caller
 * must launch it only over complete 96x96 tiles. K is consumed in steps of
 * 16; any remainder (K & 15) is ignored.
 ******************************************************************************/

#ifndef KERNEL_SGEMM_COL_NN_B1_MX096_NX096_KX16_SRC_H
#define KERNEL_SGEMM_COL_NN_B1_MX096_NX096_KX16_SRC_H
#pragma message("AutoGemm's sgemm_Col_NN_B1_MX096_NX096_KX16_src overriden by user.")

#ifndef STRINGIFY
#define STRINGIFY(S) STRINGIFY2(S)
#define STRINGIFY2(S) #S
#endif

/* Tile geometry exported to the host; the values mirror the constants that are
 * hard-coded in the kernel text below (16x16 work-group, 6x6 micro-tile,
 * 16 k-steps per main-loop iteration). */
const unsigned int sgemm_Col_NN_B1_MX096_NX096_KX16_workGroupNumRows = 16;
const unsigned int sgemm_Col_NN_B1_MX096_NX096_KX16_workGroupNumCols = 16;
const unsigned int sgemm_Col_NN_B1_MX096_NX096_KX16_microTileNumRows = 6;
const unsigned int sgemm_Col_NN_B1_MX096_NX096_KX16_microTileNumCols = 6;
const unsigned int sgemm_Col_NN_B1_MX096_NX096_KX16_unroll = 16;

/*
 * Kernel source text.
 *
 * M6x6 performs one k-step of the rank-1 update: it loads six values of A and
 * six values of B from local memory (16 elements apart), advances both
 * offsets by 97 (one k-slice of the local tile) and accumulates the 36
 * products into rC with mad().
 *
 * NOTE: the #define sits inside the STRINGIFY argument and the trailing "\n"
 * token is what terminates the directive in the generated text. Keep the
 * macro on continued lines and keep every comma inside parentheses, otherwise
 * STRINGIFY receives more than one argument. Comments are removed by the host
 * preprocessor and never reach the OpenCL compiler.
 *
 * Kernel arguments:
 *   A, B             read-only input matrices
 *   C                input/output matrix
 *   alpha, beta      scalars in C = alpha*A*B + beta*C
 *   M, N             not read by the kernel
 *   K                inner dimension; K >> 4 blocks of 16 are processed
 *   lda, ldb, ldc    leading dimensions, in elements
 *   offsetA/B/C      element offsets added to A, B and C before use
 */
const char * const sgemm_Col_NN_B1_MX096_NX096_KX16_src = STRINGIFY(

#define M6x6 \
    /* fetch one k-slice: 6 values of A and 6 values of B */ \
    rA[0][0] = lA[offA + 0]; \
    rA[0][1] = lA[offA + 16]; \
    rA[0][2] = lA[offA + 32]; \
    rA[0][3] = lA[offA + 48]; \
    rA[0][4] = lA[offA + 64]; \
    rA[0][5] = lA[offA + 80]; \
    rB[0][0] = lB[offB + 0]; \
    rB[0][1] = lB[offB + 16]; \
    rB[0][2] = lB[offB + 32]; \
    rB[0][3] = lB[offB + 48]; \
    rB[0][4] = lB[offB + 64]; \
    rB[0][5] = lB[offB + 80]; \
    offA += 97; \
    offB += 97; \
    /* rank-1 update of the 6x6 accumulator block */ \
    rC[0][0]=mad(rA[0][0],rB[0][0],rC[0][0]); \
    rC[1][0]=mad(rA[0][1],rB[0][0],rC[1][0]); \
    rC[2][0]=mad(rA[0][2],rB[0][0],rC[2][0]); \
    rC[3][0]=mad(rA[0][3],rB[0][0],rC[3][0]); \
    rC[4][0]=mad(rA[0][4],rB[0][0],rC[4][0]); \
    rC[5][0]=mad(rA[0][5],rB[0][0],rC[5][0]); \
    rC[0][1]=mad(rA[0][0],rB[0][1],rC[0][1]); \
    rC[1][1]=mad(rA[0][1],rB[0][1],rC[1][1]); \
    rC[2][1]=mad(rA[0][2],rB[0][1],rC[2][1]); \
    rC[3][1]=mad(rA[0][3],rB[0][1],rC[3][1]); \
    rC[4][1]=mad(rA[0][4],rB[0][1],rC[4][1]); \
    rC[5][1]=mad(rA[0][5],rB[0][1],rC[5][1]); \
    rC[0][2]=mad(rA[0][0],rB[0][2],rC[0][2]); \
    rC[1][2]=mad(rA[0][1],rB[0][2],rC[1][2]); \
    rC[2][2]=mad(rA[0][2],rB[0][2],rC[2][2]); \
    rC[3][2]=mad(rA[0][3],rB[0][2],rC[3][2]); \
    rC[4][2]=mad(rA[0][4],rB[0][2],rC[4][2]); \
    rC[5][2]=mad(rA[0][5],rB[0][2],rC[5][2]); \
    rC[0][3]=mad(rA[0][0],rB[0][3],rC[0][3]); \
    rC[1][3]=mad(rA[0][1],rB[0][3],rC[1][3]); \
    rC[2][3]=mad(rA[0][2],rB[0][3],rC[2][3]); \
    rC[3][3]=mad(rA[0][3],rB[0][3],rC[3][3]); \
    rC[4][3]=mad(rA[0][4],rB[0][3],rC[4][3]); \
    rC[5][3]=mad(rA[0][5],rB[0][3],rC[5][3]); \
    rC[0][4]=mad(rA[0][0],rB[0][4],rC[0][4]); \
    rC[1][4]=mad(rA[0][1],rB[0][4],rC[1][4]); \
    rC[2][4]=mad(rA[0][2],rB[0][4],rC[2][4]); \
    rC[3][4]=mad(rA[0][3],rB[0][4],rC[3][4]); \
    rC[4][4]=mad(rA[0][4],rB[0][4],rC[4][4]); \
    rC[5][4]=mad(rA[0][5],rB[0][4],rC[5][4]); \
    rC[0][5]=mad(rA[0][0],rB[0][5],rC[0][5]); \
    rC[1][5]=mad(rA[0][1],rB[0][5],rC[1][5]); \
    rC[2][5]=mad(rA[0][2],rB[0][5],rC[2][5]); \
    rC[3][5]=mad(rA[0][3],rB[0][5],rC[3][5]); \
    rC[4][5]=mad(rA[0][4],rB[0][5],rC[4][5]); \
    rC[5][5]=mad(rA[0][5],rB[0][5],rC[5][5]); \
    mem_fence(CLK_LOCAL_MEM_FENCE);\n

  __attribute__((reqd_work_group_size(16,16,1)))
  __kernel void sgemm_Col_NN_B1_MX096_NX096_KX16 (
    __global float const * restrict A,
    __global float const * restrict B,
    __global float * C,
    float const alpha,
    float const beta,
    uint const M,
    uint const N,
    uint const K,
    uint lda,
    uint ldb,
    uint ldc,
    uint offsetA,
    uint offsetB,
    uint offsetC)
  {
    /* Per-work-item accumulators and the operands of the current k-step. */
    float rC[6][6] = { {(float)0} };
    float rA[1][6];
    float rB[1][6];

    A += offsetA;
    B += offsetB;
    C+=offsetC;

    /* 16 k-slices of 97 floats each: 96 tile values plus one pad element
       (presumably to stagger local-memory accesses - confirm on target HW). */
    __local float lA[1552];
    __local float lB[1552];

    uint gidx = get_group_id(0);
    uint gidy = get_group_id(1);
    uint idx = get_local_id(0);
    uint idy = get_local_id(1);

    /* Position A at (gidx*96 + idx) along the unit stride and idy along lda;
       position B at idx along the unit stride and (gidy*96 + idy) along ldb. */
    A += gidx*96+ idx + idy*lda;
    B += gidy*96*ldb+ idx + idy*ldb;

    uint block_k = K >> 4;
    /* Guard the do/while: without it K < 16 gives block_k == 0, the body runs
       once and --block_k wraps to UINT_MAX, walking A and B far out of bounds.
       K is uniform across the work-group, so every work-item takes the same
       path and the barriers inside remain valid. */
    if (block_k > 0) {
      do {
        __local float* plA = lA + idy*97+idx;
        __local float* plB = lB + idx*97+idy;
        /* Wait until every work-item has finished reading the previous tile. */
        barrier(CLK_LOCAL_MEM_FENCE);
        /* Stage B into lB: slice idx, position idy + 16*j. */
        plB[0] = B[0];
        plB[16] = B[16*ldb];
        plB[32] = B[32*ldb];
        plB[48] = B[48*ldb];
        plB[64] = B[64*ldb];
        plB[80] = B[80*ldb];

        /* Stage A into lA: slice idy, position idx + 16*i. */
        plA[0] = A[0+0*lda];
        plA[16] = A[16+0*lda];
        plA[32] = A[32+0*lda];
        plA[48] = A[48+0*lda];
        plA[64] = A[64+0*lda];
        plA[80] = A[80+0*lda];

        /* Make the freshly staged tiles visible to the whole work-group. */
        barrier(CLK_LOCAL_MEM_FENCE);
        uint offA = idx;
        uint offB = idy;

        /* 16 k-steps, fully unrolled. */
        M6x6
        M6x6
        M6x6
        M6x6
        M6x6
        M6x6
        M6x6
        M6x6
        M6x6
        M6x6
        M6x6
        M6x6
        M6x6
        M6x6
        M6x6
        M6x6

        /* Advance both operands by 16 along the k dimension. */
        A += lda<<4;
        B += 16;
      } while (--block_k > 0);
    }

    /* Move C to this work-item's first output element. */
    C+= gidx*96+idx;
    C+= gidy*96*ldc;
    C+= idy*ldc;

    /* Write back rC[i][j] at unit-stride offset 16*i and ldc offset 16*j. */
    C[0*ldc] = alpha*rC[0][0] + beta*C[0*ldc];
    C[16*ldc] = alpha*rC[0][1] + beta*C[16*ldc];
    C[32*ldc] = alpha*rC[0][2] + beta*C[32*ldc];
    C[48*ldc] = alpha*rC[0][3] + beta*C[48*ldc];
    C[64*ldc] = alpha*rC[0][4] + beta*C[64*ldc];
    C[80*ldc] = alpha*rC[0][5] + beta*C[80*ldc];
    C+=16;
    C[0*ldc] = alpha*rC[1][0] + beta*C[0*ldc];
    C[16*ldc] = alpha*rC[1][1] + beta*C[16*ldc];
    C[32*ldc] = alpha*rC[1][2] + beta*C[32*ldc];
    C[48*ldc] = alpha*rC[1][3] + beta*C[48*ldc];
    C[64*ldc] = alpha*rC[1][4] + beta*C[64*ldc];
    C[80*ldc] = alpha*rC[1][5] + beta*C[80*ldc];
    C+=16;
    C[0*ldc] = alpha*rC[2][0] + beta*C[0*ldc];
    C[16*ldc] = alpha*rC[2][1] + beta*C[16*ldc];
    C[32*ldc] = alpha*rC[2][2] + beta*C[32*ldc];
    C[48*ldc] = alpha*rC[2][3] + beta*C[48*ldc];
    C[64*ldc] = alpha*rC[2][4] + beta*C[64*ldc];
    C[80*ldc] = alpha*rC[2][5] + beta*C[80*ldc];
    C+=16;
    C[0*ldc] = alpha*rC[3][0] + beta*C[0*ldc];
    C[16*ldc] = alpha*rC[3][1] + beta*C[16*ldc];
    C[32*ldc] = alpha*rC[3][2] + beta*C[32*ldc];
    C[48*ldc] = alpha*rC[3][3] + beta*C[48*ldc];
    C[64*ldc] = alpha*rC[3][4] + beta*C[64*ldc];
    C[80*ldc] = alpha*rC[3][5] + beta*C[80*ldc];
    C+=16;
    C[0*ldc] = alpha*rC[4][0] + beta*C[0*ldc];
    C[16*ldc] = alpha*rC[4][1] + beta*C[16*ldc];
    C[32*ldc] = alpha*rC[4][2] + beta*C[32*ldc];
    C[48*ldc] = alpha*rC[4][3] + beta*C[48*ldc];
    C[64*ldc] = alpha*rC[4][4] + beta*C[64*ldc];
    C[80*ldc] = alpha*rC[4][5] + beta*C[80*ldc];
    C+=16;
    C[0*ldc] = alpha*rC[5][0] + beta*C[0*ldc];
    C[16*ldc] = alpha*rC[5][1] + beta*C[16*ldc];
    C[32*ldc] = alpha*rC[5][2] + beta*C[32*ldc];
    C[48*ldc] = alpha*rC[5][3] + beta*C[48*ldc];
    C[64*ldc] = alpha*rC[5][4] + beta*C[64*ldc];
    C[80*ldc] = alpha*rC[5][5] + beta*C[80*ldc];

  }
);
#endif