1 /*******************************************************************************
 * Hand-tuned OpenCL SGEMM kernel (column-major, C = alpha*A^T*B + beta*C), embedded as a C string.
3  ******************************************************************************/
4 
5 #ifndef KERNEL_SGEMM_COL_TN_B1_MX064_NX064_KX16_SRC_H
6 #define KERNEL_SGEMM_COL_TN_B1_MX064_NX064_KX16_SRC_H
7 #pragma message("AutoGemm's sgemm_Col_TN_B1_MX064_NX064_KX16_src overriden by user.")
8 
/*
 * Two-level stringification: STRINGIFY macro-expands its argument first and
 * STRINGIFY2 then applies the # operator to the expanded tokens.
 */
#ifndef STRINGIFY
#define STRINGIFY(S) STRINGIFY2(S)
#define STRINGIFY2(S) #S
#endif

/*
 * Launch/tiling parameters of the kernel below. They must agree with the
 * literals hard-coded in the kernel text: reqd_work_group_size(16,16,1),
 * the 4x4 per-work-item accumulator rC[4][4], and the 16-deep K step.
 * Each work-group therefore produces a 64x64 block of C.
 */
const unsigned int sgemm_Col_TN_B1_MX064_NX064_KX16_workGroupNumRows = 16;
const unsigned int sgemm_Col_TN_B1_MX064_NX064_KX16_workGroupNumCols = 16;
const unsigned int sgemm_Col_TN_B1_MX064_NX064_KX16_microTileNumRows = 4;
const unsigned int sgemm_Col_TN_B1_MX064_NX064_KX16_microTileNumCols = 4;
const unsigned int sgemm_Col_TN_B1_MX064_NX064_KX16_unroll = 16;

/*
 * OpenCL C source computing C = alpha * A^T * B + beta * C on column-major
 * data: element (k,m) of A is read from A[k + m*lda], element (k,n) of B
 * from B[k + n*ldb], and element (m,n) of C lives at C[m + n*ldc]
 * (after the offsetA/offsetB/offsetC shifts).
 *
 * Preconditions (the kernel performs no bounds checks and never reads M or
 * N): M and N are multiples of 64 and K is a multiple of 16; any K % 16
 * remainder is ignored. The K loop tests its counter before the first
 * iteration, so K < 16 performs no accumulation and writes C = beta * C.
 * A do/while on the unsigned counter would instead run once with
 * block_k == 0, wrap it to UINT_MAX and read far outside A and B.
 *
 * NOTE(review): "#define M4x4" sits inside the argument of STRINGIFY. C11
 * 6.10.3p11 leaves directives inside macro arguments undefined. GCC documents
 * that it executes such a directive on the host, after which M4x4 is expanded
 * during argument pre-expansion; a preprocessor that instead passes the line
 * through verbatim relies on the trailing \n to end the OpenCL #define.
 * Keep comments and edits out of that macro, and never introduce a
 * top-level comma in the kernel text: it would split the macro argument.
 */
const char * const sgemm_Col_TN_B1_MX064_NX064_KX16_src = STRINGIFY(

#define  M4x4 \
            rA[0][0] = lA[offA + 0];				  \
            rA[0][1] = lA[offA + 16];				  \
            rA[0][2] = lA[offA + 32];				  \
            rA[0][3] = lA[offA + 48];				  \
            rB[0][0] = lB[offB + 0];				  \
            rB[0][1] = lB[offB + 16];				  \
            rB[0][2] = lB[offB + 32];				  \
            rB[0][3] = lB[offB + 48];				  \
            offA += 65;								  \
            offB += 65;								  \
            rC[0][0]=mad(rA[0][0],rB[0][0],rC[0][0]); \
            rC[1][0]=mad(rA[0][1],rB[0][0],rC[1][0]); \
            rC[2][0]=mad(rA[0][2],rB[0][0],rC[2][0]); \
            rC[3][0]=mad(rA[0][3],rB[0][0],rC[3][0]); \
            rC[0][1]=mad(rA[0][0],rB[0][1],rC[0][1]); \
            rC[1][1]=mad(rA[0][1],rB[0][1],rC[1][1]); \
            rC[2][1]=mad(rA[0][2],rB[0][1],rC[2][1]); \
            rC[3][1]=mad(rA[0][3],rB[0][1],rC[3][1]); \
            rC[0][2]=mad(rA[0][0],rB[0][2],rC[0][2]); \
            rC[1][2]=mad(rA[0][1],rB[0][2],rC[1][2]); \
            rC[2][2]=mad(rA[0][2],rB[0][2],rC[2][2]); \
            rC[3][2]=mad(rA[0][3],rB[0][2],rC[3][2]); \
            rC[0][3]=mad(rA[0][0],rB[0][3],rC[0][3]); \
            rC[1][3]=mad(rA[0][1],rB[0][3],rC[1][3]); \
            rC[2][3]=mad(rA[0][2],rB[0][3],rC[2][3]); \
            rC[3][3]=mad(rA[0][3],rB[0][3],rC[3][3]); \
			      mem_fence(CLK_LOCAL_MEM_FENCE);\n

__attribute__((reqd_work_group_size(16,16,1)))
__kernel void sgemm_Col_TN_B1_MX064_NX064_KX16 (
  __global float const * restrict A,
  __global float const * restrict B,
  __global float * C,
  float const alpha,
  float const beta,
  uint const M,
  uint const N,
  uint const K,
  uint lda,
  uint ldb,
  uint ldc,
  uint offsetA,
  uint offsetB,
  uint offsetC)
{
  /* rC accumulates this work-item 4x4 micro-tile; rA/rB hold one k-slice. */
  float rC[4][4]  = { {(float)0} };
  float rA[1][4];
  float rB[1][4];

  A += offsetA;
  B += offsetB;
  C+=offsetC;

  /* k-major staging tiles: element (k r) is stored at k*65 + r. The row
     stride of 65 instead of 64 is presumably padding against local-memory
     bank conflicts. */
  __local float lA[1056];
  __local float lB[1056];

  uint gidx = get_group_id(0);
  uint gidy = get_group_id(1);
  uint idx = get_local_id(0);
  uint idy = get_local_id(1);

  uint idt = 16*idy + idx;
  uint idxT = idt % 16;
  uint idyT = idt / 16;

  /* First element this work-item stages: k = idxT and row gidx*64 + idyT
     of A^T (resp. column gidy*64 + idyT of B). */
  A +=  gidx*64*lda+ idxT + idyT*lda;
  B +=  gidy*64*ldb+ idxT + idyT*ldb;

  /* Pre-tested loop: K < 16 runs zero iterations instead of wrapping the
     unsigned counter. The trip count is uniform across the work-group so
     the barriers inside are reached by every work-item. */
  for (uint block_k = K >> 4; block_k > 0; --block_k)
  {
    __local float* plA = lA + idxT*65+idyT;
    __local float* plB = lB + idxT*65+idyT;
    /* Do not overwrite the tiles until all reads of the previous slab end. */
    barrier(CLK_LOCAL_MEM_FENCE);
    plB[0] = B[0];
    plB[16] = B[16*ldb];
    plB[32] = B[32*ldb];
    plB[48] = B[48*ldb];

    plA[0] = A[0];
    plA[16] = A[16*lda];
    plA[32] = A[32*lda];
    plA[48] = A[48*lda];

    /* Make the freshly staged slab visible to the whole work-group. */
    barrier(CLK_LOCAL_MEM_FENCE);
    uint offA = idx;
    uint offB = idy;

    /* 16 rank-1 updates; each M4x4 advances offA/offB by one k-row (65). */
    M4x4
    M4x4
    M4x4
    M4x4
    M4x4
    M4x4
    M4x4
    M4x4
    M4x4
    M4x4
    M4x4
    M4x4
    M4x4
    M4x4
    M4x4
    M4x4

    A += 16;
    B += 16;
  }

  /* Write-back: rC[i][j] goes to C(gidx*64 + idx + 16*i  gidy*64 + idy + 16*j). */
  C+= gidx*64+idx;
  C+= gidy*64*ldc;
  C+= idy*ldc;

  C[0*ldc] = alpha*rC[0][0] + beta*C[0*ldc];
  C[16*ldc] = alpha*rC[0][1] + beta*C[16*ldc];
  C[32*ldc] = alpha*rC[0][2] + beta*C[32*ldc];
  C[48*ldc] = alpha*rC[0][3] + beta*C[48*ldc];

  C+=16;
  C[0*ldc] = alpha*rC[1][0] + beta*C[0*ldc];
  C[16*ldc] = alpha*rC[1][1] + beta*C[16*ldc];
  C[32*ldc] = alpha*rC[1][2] + beta*C[32*ldc];
  C[48*ldc] = alpha*rC[1][3] + beta*C[48*ldc];

  C+=16;
  C[0*ldc] = alpha*rC[2][0] + beta*C[0*ldc];
  C[16*ldc] = alpha*rC[2][1] + beta*C[16*ldc];
  C[32*ldc] = alpha*rC[2][2] + beta*C[32*ldc];
  C[48*ldc] = alpha*rC[2][3] + beta*C[48*ldc];

  C+=16;
  C[0*ldc] = alpha*rC[3][0] + beta*C[0*ldc];
  C[16*ldc] = alpha*rC[3][1] + beta*C[16*ldc];
  C[32*ldc] = alpha*rC[3][2] + beta*C[32*ldc];
  C[48*ldc] = alpha*rC[3][3] + beta*C[48*ldc];
}
);
165 #endif
166