/*******************************************************************************
 * Hand-tuned kernel
 *
 * OpenCL C source of the single-precision GEMM kernel
 * sgemm_Col_TN_B1_MX064_NX064_KX16, embedded as a C string through STRINGIFY,
 * together with constants describing its work-group shape (16 x 16), its
 * per-work-item micro-tile (4 x 4) and its unroll depth along K (16).
 *
 * Notes on the STRINGIFY embedding:
 *  - Comments inside the macro argument are removed by the preprocessor
 *    before stringification, so they never reach the OpenCL compiler.
 *  - The kernel text must not contain a comma outside parentheses, because
 *    STRINGIFY takes exactly one argument.
 *  - Inside the multi-line M4x4 definition only block comments are safe: the
 *    continuation lines are spliced into one line, so a line comment would
 *    swallow the remainder of the macro.
 *  - The backslash-n that ends M4x4 is stringified into a newline character.
 *    NOTE(review): presumably this terminates the #define line when the
 *    directive is kept as text in the string; a directive inside macro
 *    arguments is undefined behavior in ISO C and preprocessors differ in
 *    how they treat it - confirm with the toolchains in use.
 ******************************************************************************/

#ifndef KERNEL_SGEMM_COL_TN_B1_MX064_NX064_KX16_SRC_H
#define KERNEL_SGEMM_COL_TN_B1_MX064_NX064_KX16_SRC_H
#pragma message("AutoGemm's sgemm_Col_TN_B1_MX064_NX064_KX16_src overridden by user.")

#ifndef STRINGIFY
#define STRINGIFY(S) STRINGIFY2(S)
#define STRINGIFY2(S) #S
#endif

/* Work-group is 16 x 16 work-items (matches reqd_work_group_size below). */
const unsigned int sgemm_Col_TN_B1_MX064_NX064_KX16_workGroupNumRows = 16;
const unsigned int sgemm_Col_TN_B1_MX064_NX064_KX16_workGroupNumCols = 16;
/* Each work-item accumulates a 4 x 4 block of C (rC[4][4] in the kernel). */
const unsigned int sgemm_Col_TN_B1_MX064_NX064_KX16_microTileNumRows = 4;
const unsigned int sgemm_Col_TN_B1_MX064_NX064_KX16_microTileNumCols = 4;
/* Number of K steps consumed per main-loop iteration (16 M4x4 expansions). */
const unsigned int sgemm_Col_TN_B1_MX064_NX064_KX16_unroll = 16;

const char * const sgemm_Col_TN_B1_MX064_NX064_KX16_src = STRINGIFY(

/*
 * M4x4: one K step of the 4 x 4 micro-tile update.
 * Reads four lA values and four lB values spaced 16 elements apart starting
 * at offA / offB, advances both offsets by 65 (the row pitch used when the
 * tiles were stored to local memory) and then issues the 16 mad operations
 * rC[i][j] += rA[0][i] * rB[0][j].
 */
#define M4x4 \
    rA[0][0] = lA[offA + 0]; \
    rA[0][1] = lA[offA + 16]; \
    rA[0][2] = lA[offA + 32]; \
    rA[0][3] = lA[offA + 48]; \
    rB[0][0] = lB[offB + 0]; \
    rB[0][1] = lB[offB + 16]; \
    rB[0][2] = lB[offB + 32]; \
    rB[0][3] = lB[offB + 48]; \
    offA += 65; \
    offB += 65; \
    rC[0][0]=mad(rA[0][0],rB[0][0],rC[0][0]); \
    rC[1][0]=mad(rA[0][1],rB[0][0],rC[1][0]); \
    rC[2][0]=mad(rA[0][2],rB[0][0],rC[2][0]); \
    rC[3][0]=mad(rA[0][3],rB[0][0],rC[3][0]); \
    rC[0][1]=mad(rA[0][0],rB[0][1],rC[0][1]); \
    rC[1][1]=mad(rA[0][1],rB[0][1],rC[1][1]); \
    rC[2][1]=mad(rA[0][2],rB[0][1],rC[2][1]); \
    rC[3][1]=mad(rA[0][3],rB[0][1],rC[3][1]); \
    rC[0][2]=mad(rA[0][0],rB[0][2],rC[0][2]); \
    rC[1][2]=mad(rA[0][1],rB[0][2],rC[1][2]); \
    rC[2][2]=mad(rA[0][2],rB[0][2],rC[2][2]); \
    rC[3][2]=mad(rA[0][3],rB[0][2],rC[3][2]); \
    rC[0][3]=mad(rA[0][0],rB[0][3],rC[0][3]); \
    rC[1][3]=mad(rA[0][1],rB[0][3],rC[1][3]); \
    rC[2][3]=mad(rA[0][2],rB[0][3],rC[2][3]); \
    rC[3][3]=mad(rA[0][3],rB[0][3],rC[3][3]); \
    mem_fence(CLK_LOCAL_MEM_FENCE);\n

/*
 * Each 16 x 16 work-group updates one 64 x 64 tile of C:
 *
 *   C[r + c*ldc] = alpha * sum_k( A[k + r*lda] * B[k + c*ldb] )
 *                + beta  * C[r + c*ldc]
 *
 * where the work-item (idx, idy) in group (gidx, gidy) owns the 16 elements
 * r = gidx*64 + idx + 16*i and c = gidy*64 + idy + 16*j for i, j in 0..3.
 *
 * A, B      input matrices, read with the indexing shown above
 * C         matrix updated in place
 * alpha     scale applied to the accumulated product
 * beta      scale applied to the existing C value (C is always read)
 * M, N      not referenced by this kernel; no bounds checks are performed.
 *           NOTE(review): presumably the dispatcher only selects this kernel
 *           when M and N are multiples of 64 - confirm against the caller.
 * K         reduction length; consumed in chunks of 16 (K >> 4 iterations),
 *           any remainder K % 16 is ignored. K < 16 performs no iterations.
 * lda, ldb, ldc        leading dimensions, in elements
 * offsetA/B/C          element offsets added to the base pointers
 */
__attribute__((reqd_work_group_size(16,16,1)))
__kernel void sgemm_Col_TN_B1_MX064_NX064_KX16 (
    __global float const * restrict A,
    __global float const * restrict B,
    __global float * C,
    float const alpha,
    float const beta,
    uint const M,
    uint const N,
    uint const K,
    uint lda,
    uint ldb,
    uint ldc,
    uint offsetA,
    uint offsetB,
    uint offsetC)
{
    /* Per-work-item accumulators and the operand values of one K step. */
    float rC[4][4] = { {(float)0} };
    float rA[1][4];
    float rB[1][4];

    A += offsetA;
    B += offsetB;
    C += offsetC;

    /* 16 K-rows with a pitch of 65 floats; the highest index touched is
       15*65 + 63 = 1038. NOTE(review): the pitch of 65 rather than 64 looks
       like padding against local-memory bank conflicts - confirm. */
    __local float lA[1056];
    __local float lB[1056];

    uint gidx = get_group_id(0);
    uint gidy = get_group_id(1);
    uint idx = get_local_id(0);
    uint idy = get_local_id(1);

    /* Linear work-item id split back into (idxT, idyT); with the required
       16 x 16 work-group these equal (idx, idy). */
    uint idt = 16*idy + idx;
    uint idxT = idt % 16;
    uint idyT = idt / 16;

    /* idxT walks along K, idyT selects the row of the 64-wide tile. */
    A += gidx*64*lda + idxT + idyT*lda;
    B += gidy*64*ldb + idxT + idyT*ldb;

    /* Pre-tested loop: the previous do/while form executed the body once even
       for K < 16 and then wrapped the unsigned counter past zero. K is the
       same for every work-item, so the barriers stay in uniform control
       flow. */
    for (uint block_k = K >> 4; block_k > 0; --block_k)
    {
        __local float* plA = lA + idxT*65 + idyT;
        __local float* plB = lB + idxT*65 + idyT;

        /* Wait until every work-item has finished reading the previous
           tiles before they are overwritten. */
        barrier(CLK_LOCAL_MEM_FENCE);
        plB[0] = B[0];
        plB[16] = B[16*ldb];
        plB[32] = B[32*ldb];
        plB[48] = B[48*ldb];

        plA[0] = A[0];
        plA[16] = A[16*lda];
        plA[32] = A[32*lda];
        plA[48] = A[48*lda];

        /* Make the freshly stored tiles visible to the whole work-group. */
        barrier(CLK_LOCAL_MEM_FENCE);
        uint offA = idx;
        uint offB = idy;

        /* 16 K steps, one per row of the local tiles. */
        M4x4
        M4x4
        M4x4
        M4x4
        M4x4
        M4x4
        M4x4
        M4x4
        M4x4
        M4x4
        M4x4
        M4x4
        M4x4
        M4x4
        M4x4
        M4x4

        /* Advance to the next 16 K elements. */
        A += 16;
        B += 16;
    }

    /* Point C at this work-item's first output element. */
    C += gidx*64 + idx;
    C += gidy*64*ldc;
    C += idy*ldc;

    /* rC[i][j] goes to row offset 16*i (via C += 16) and column offset 16*j
       (via the 16*j*ldc index). */
    C[0*ldc] = alpha*rC[0][0] + beta*C[0*ldc];
    C[16*ldc] = alpha*rC[0][1] + beta*C[16*ldc];
    C[32*ldc] = alpha*rC[0][2] + beta*C[32*ldc];
    C[48*ldc] = alpha*rC[0][3] + beta*C[48*ldc];

    C += 16;
    C[0*ldc] = alpha*rC[1][0] + beta*C[0*ldc];
    C[16*ldc] = alpha*rC[1][1] + beta*C[16*ldc];
    C[32*ldc] = alpha*rC[1][2] + beta*C[32*ldc];
    C[48*ldc] = alpha*rC[1][3] + beta*C[48*ldc];

    C += 16;
    C[0*ldc] = alpha*rC[2][0] + beta*C[0*ldc];
    C[16*ldc] = alpha*rC[2][1] + beta*C[16*ldc];
    C[32*ldc] = alpha*rC[2][2] + beta*C[32*ldc];
    C[48*ldc] = alpha*rC[2][3] + beta*C[48*ldc];

    C += 16;
    C[0*ldc] = alpha*rC[3][0] + beta*C[0*ldc];
    C[16*ldc] = alpha*rC[3][1] + beta*C[16*ldc];
    C[32*ldc] = alpha*rC[3][2] + beta*C[32*ldc];
    C[48*ldc] = alpha*rC[3][3] + beta*C[48*ldc];
}
);
#endif