/*******************************************************************************
 * Hand-tuned kernel: sgemm_Col_NT_B1_MX064_NX064_KX16
 *
 * Provides the OpenCL C source of the kernel as a C string (built with
 * STRINGIFY) together with the launch-geometry constants used by the host.
 *
 * From the indexing in the kernel below, it computes (column-major)
 *   C[m + n*ldc] = alpha * sum_k A[m + k*lda] * B[n + k*ldb] + beta * C[m + n*ldc]
 * i.e. C = alpha * A * B^T + beta * C, after A, B and C have been advanced by
 * offsetA, offsetB and offsetC elements respectively.
 *
 * Tiling: a 16x16 work-group owns a 64x64 tile of C (tile row gidx, tile
 * column gidy); work-item (idx, idy) owns the 4x4 elements at tile rows
 * idx + {0,16,32,48} and tile columns idy + {0,16,32,48}.
 *
 * Not checked by the kernel: M and N are never read and there are no bounds
 * checks, and only the first 16*(K/16) terms of the k-sum are accumulated.
 * NOTE(review): presumably the host launches this kernel only for tiles that
 * lie fully inside C and handles K % 16 with another kernel -- confirm
 * against the caller.
 ******************************************************************************/

#ifndef KERNEL_SGEMM_COL_NT_B1_MX064_NX064_KX16_SRC_H
#define KERNEL_SGEMM_COL_NT_B1_MX064_NX064_KX16_SRC_H
#pragma message("AutoGemm's sgemm_Col_NT_B1_MX064_NX064_KX16_src overriden by user.")

#ifndef STRINGIFY
#define STRINGIFY(S) STRINGIFY2(S)
#define STRINGIFY2(S) #S
#endif

/* Launch geometry. These must agree with reqd_work_group_size(16,16,1) and the
   4x4-per-work-item / 16-wide-K tiling hard-coded in the kernel source below
   (16 * 4 = 64 rows and columns of C per work-group). */
const unsigned int sgemm_Col_NT_B1_MX064_NX064_KX16_workGroupNumRows = 16;
const unsigned int sgemm_Col_NT_B1_MX064_NX064_KX16_workGroupNumCols = 16;
const unsigned int sgemm_Col_NT_B1_MX064_NX064_KX16_microTileNumRows = 4;
const unsigned int sgemm_Col_NT_B1_MX064_NX064_KX16_microTileNumCols = 4;
const unsigned int sgemm_Col_NT_B1_MX064_NX064_KX16_unroll = 16;

const char * const sgemm_Col_NT_B1_MX064_NX064_KX16_src = STRINGIFY(

/* M4x4: one k-step of the inner product. It loads four A values (tile rows
   idx + 0/16/32/48) and four B values (tile columns idy + 0/16/32/48) from
   local memory, advances both offsets by one padded local row of 65 floats,
   and accumulates their 4x4 outer product into rC.
   C11 6.10.3p11 leaves directives inside macro arguments undefined. GCC
   executes this #define on the host, so the string receives the expanded
   text. Presumably the stray \n after mem_fence exists for preprocessors that
   instead pass the #define through: STRINGIFY puts everything on one line,
   and that newline ends the directive inside the OpenCL source. */
#define M4x4 \
    rA[0][0] = lA[offA + 0]; \
    rA[0][1] = lA[offA + 16]; \
    rA[0][2] = lA[offA + 32]; \
    rA[0][3] = lA[offA + 48]; \
    rB[0][0] = lB[offB + 0]; \
    rB[0][1] = lB[offB + 16]; \
    rB[0][2] = lB[offB + 32]; \
    rB[0][3] = lB[offB + 48]; \
    offA += 65; \
    offB += 65; \
    rC[0][0]=mad(rA[0][0],rB[0][0],rC[0][0]); \
    rC[1][0]=mad(rA[0][1],rB[0][0],rC[1][0]); \
    rC[2][0]=mad(rA[0][2],rB[0][0],rC[2][0]); \
    rC[3][0]=mad(rA[0][3],rB[0][0],rC[3][0]); \
    rC[0][1]=mad(rA[0][0],rB[0][1],rC[0][1]); \
    rC[1][1]=mad(rA[0][1],rB[0][1],rC[1][1]); \
    rC[2][1]=mad(rA[0][2],rB[0][1],rC[2][1]); \
    rC[3][1]=mad(rA[0][3],rB[0][1],rC[3][1]); \
    rC[0][2]=mad(rA[0][0],rB[0][2],rC[0][2]); \
    rC[1][2]=mad(rA[0][1],rB[0][2],rC[1][2]); \
    rC[2][2]=mad(rA[0][2],rB[0][2],rC[2][2]); \
    rC[3][2]=mad(rA[0][3],rB[0][2],rC[3][2]); \
    rC[0][3]=mad(rA[0][0],rB[0][3],rC[0][3]); \
    rC[1][3]=mad(rA[0][1],rB[0][3],rC[1][3]); \
    rC[2][3]=mad(rA[0][2],rB[0][3],rC[2][3]); \
    rC[3][3]=mad(rA[0][3],rB[0][3],rC[3][3]); \
    mem_fence(CLK_LOCAL_MEM_FENCE);\n

__attribute__((reqd_work_group_size(16,16,1)))
__kernel void sgemm_Col_NT_B1_MX064_NX064_KX16 (
    __global float const * restrict A,
    __global float const * restrict B,
    __global float * C,
    float const alpha,
    float const beta,
    uint const M,
    uint const N,
    uint const K,
    uint lda,
    uint ldb,
    uint ldc,
    uint offsetA,
    uint offsetB,
    uint offsetC)
{
    /* Per-work-item 4x4 accumulator, plus register copies of the current
       A column slice and B row slice. */
    float rC[4][4] = { {(float)0} };
    float rA[1][4];
    float rB[1][4];

    A += offsetA;
    B += offsetB;
    C += offsetC;

    /* One 16(k) x 64 slice of A and of B, stored as 16 rows of 65 floats
       (16 * 65 = 1040). The extra column is presumably padding against
       local-memory bank conflicts. */
    __local float lA[1040];
    __local float lB[1040];

    uint gidx = get_group_id(0);
    uint gidy = get_group_id(1);
    uint idx = get_local_id(0);
    uint idy = get_local_id(1);

    /* Flattened local id, re-split into the coordinates used for the
       global -> local copy. */
    uint idt = 16*idy + idx;
    uint idxT = idt % 16;
    uint idyT = idt / 16;

    /* Each work-item copies tile rows idxT + {0,16,32,48} at k = idyT. */
    A += gidx*64 + idxT + idyT*lda;
    B += gidy*64 + idxT + idyT*ldb;

    /* Walk K in blocks of 16. The loop must test before the first pass:
       with a post-tested --block_k, K < 16 would run the body once and then
       wrap block_k around to UINT_MAX. K is a kernel argument, so every
       work-item in the group reaches the barriers the same number of times. */
    for (uint block_k = K >> 4; block_k > 0; --block_k)
    {
        __local float* plA = lA + idyT*65 + idxT;
        __local float* plB = lB + idyT*65 + idxT;

        /* Do not overwrite lA/lB until every work-item has finished reading
           the previous block. */
        barrier(CLK_LOCAL_MEM_FENCE);
        plB[0] = B[0+0*ldb];
        plB[16] = B[16+0*ldb];
        plB[32] = B[32+0*ldb];
        plB[48] = B[48+0*ldb];

        plA[0] = A[0+0*lda];
        plA[16] = A[16+0*lda];
        plA[32] = A[32+0*lda];
        plA[48] = A[48+0*lda];

        /* Make the freshly loaded block visible to the whole work-group. */
        barrier(CLK_LOCAL_MEM_FENCE);
        uint offA = idx;
        uint offB = idy;

        /* 16 k-steps, one per local row. */
        M4x4
        M4x4
        M4x4
        M4x4
        M4x4
        M4x4
        M4x4
        M4x4
        M4x4
        M4x4
        M4x4
        M4x4
        M4x4
        M4x4
        M4x4
        M4x4

        /* Advance both inputs by 16 columns (16 k values). */
        A += lda<<4;
        B += ldb<<4;
    }

    /* Write back: element (idx + 16*i, idy + 16*j) of this work-group's tile
       receives alpha*rC[i][j] + beta*C. NOTE(review): C is read even when
       beta == 0, so NaN/Inf already in C would propagate. */
    C += gidx*64 + idx;
    C += gidy*64*ldc;
    C += idy*ldc;

    C[0*ldc] = alpha*rC[0][0] + beta*C[0*ldc];
    C[16*ldc] = alpha*rC[0][1] + beta*C[16*ldc];
    C[32*ldc] = alpha*rC[0][2] + beta*C[32*ldc];
    C[48*ldc] = alpha*rC[0][3] + beta*C[48*ldc];
    C += 16;
    C[0*ldc] = alpha*rC[1][0] + beta*C[0*ldc];
    C[16*ldc] = alpha*rC[1][1] + beta*C[16*ldc];
    C[32*ldc] = alpha*rC[1][2] + beta*C[32*ldc];
    C[48*ldc] = alpha*rC[1][3] + beta*C[48*ldc];
    C += 16;
    C[0*ldc] = alpha*rC[2][0] + beta*C[0*ldc];
    C[16*ldc] = alpha*rC[2][1] + beta*C[16*ldc];
    C[32*ldc] = alpha*rC[2][2] + beta*C[32*ldc];
    C[48*ldc] = alpha*rC[2][3] + beta*C[48*ldc];
    C += 16;
    C[0*ldc] = alpha*rC[3][0] + beta*C[0*ldc];
    C[16*ldc] = alpha*rC[3][1] + beta*C[16*ldc];
    C[32*ldc] = alpha*rC[3][2] + beta*C[32*ldc];
    C[48*ldc] = alpha*rC[3][3] + beta*C[48*ldc];
}
);

/* On preprocessors that execute the #define inside the STRINGIFY argument
   (GCC does), M4x4 would otherwise remain a host macro for the rest of the
   translation unit and clash with other kernel sources defining it. */
#undef M4x4

#endif