1 /*******************************************************************************
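 * Hand-tuned OpenCL source for sgemm_Col_NN_B1_MX096_NX096_KX16, overriding
 * the AutoGemm-generated kernel of the same name (see the #pragma message).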
3  ******************************************************************************/
4 
5 #ifndef KERNEL_SGEMM_COL_NN_B1_MX096_NX096_KX16_SRC_H
6 #define KERNEL_SGEMM_COL_NN_B1_MX096_NX096_KX16_SRC_H
7 #pragma message("AutoGemm's sgemm_Col_NN_B1_MX096_NX096_KX16_src overriden by user.")
8 
#ifndef STRINGIFY
/* STRINGIFY2 turns its argument tokens into a string literal verbatim
 * (runs of whitespace collapse to a single space). */
#define STRINGIFY2(S) #S
/* STRINGIFY adds one level of indirection: because S is not an operand of #
 * here, macros inside the argument are expanded before STRINGIFY2 sees it. */
#define STRINGIFY(S) STRINGIFY2(S)
#endif
13 
/* Launch geometry baked into the kernel source below:
 *  - a 16 x 16 work-group (matches reqd_work_group_size(16,16,1)),
 *  - each work-item accumulates a 6 x 6 micro-tile, so one work-group
 *    covers the 96 x 96 tile named in the kernel (MX096 / NX096),
 *  - the main loop consumes K in steps of 16 (KX16). */
unsigned const sgemm_Col_NN_B1_MX096_NX096_KX16_workGroupNumRows = 16u;
unsigned const sgemm_Col_NN_B1_MX096_NX096_KX16_workGroupNumCols = 16u;
unsigned const sgemm_Col_NN_B1_MX096_NX096_KX16_microTileNumRows = 6u;
unsigned const sgemm_Col_NN_B1_MX096_NX096_KX16_microTileNumCols = 6u;
unsigned const sgemm_Col_NN_B1_MX096_NX096_KX16_unroll = 16u;
19 
20 const char * const sgemm_Col_NN_B1_MX096_NX096_KX16_src = STRINGIFY(
21 
22 #define  M6x6 \
23             rA[0][0] = lA[offA + 0];				  \
24             rA[0][1] = lA[offA + 16];				  \
25             rA[0][2] = lA[offA + 32];				  \
26             rA[0][3] = lA[offA + 48];				  \
27             rA[0][4] = lA[offA + 64];				  \
28             rA[0][5] = lA[offA + 80];				  \
29             rB[0][0] = lB[offB + 0];				  \
30             rB[0][1] = lB[offB + 16];				  \
31             rB[0][2] = lB[offB + 32];				  \
32             rB[0][3] = lB[offB + 48];				  \
33             rB[0][4] = lB[offB + 64];				  \
34             rB[0][5] = lB[offB + 80];				  \
35             offA += 97;								  \
36             offB += 97;								  \
37             rC[0][0]=mad(rA[0][0],rB[0][0],rC[0][0]); \
38             rC[1][0]=mad(rA[0][1],rB[0][0],rC[1][0]); \
39             rC[2][0]=mad(rA[0][2],rB[0][0],rC[2][0]); \
40             rC[3][0]=mad(rA[0][3],rB[0][0],rC[3][0]); \
41             rC[4][0]=mad(rA[0][4],rB[0][0],rC[4][0]); \
42             rC[5][0]=mad(rA[0][5],rB[0][0],rC[5][0]); \
43             rC[0][1]=mad(rA[0][0],rB[0][1],rC[0][1]); \
44             rC[1][1]=mad(rA[0][1],rB[0][1],rC[1][1]); \
45             rC[2][1]=mad(rA[0][2],rB[0][1],rC[2][1]); \
46             rC[3][1]=mad(rA[0][3],rB[0][1],rC[3][1]); \
47             rC[4][1]=mad(rA[0][4],rB[0][1],rC[4][1]); \
48             rC[5][1]=mad(rA[0][5],rB[0][1],rC[5][1]); \
49             rC[0][2]=mad(rA[0][0],rB[0][2],rC[0][2]); \
50             rC[1][2]=mad(rA[0][1],rB[0][2],rC[1][2]); \
51             rC[2][2]=mad(rA[0][2],rB[0][2],rC[2][2]); \
52             rC[3][2]=mad(rA[0][3],rB[0][2],rC[3][2]); \
53             rC[4][2]=mad(rA[0][4],rB[0][2],rC[4][2]); \
54             rC[5][2]=mad(rA[0][5],rB[0][2],rC[5][2]); \
55             rC[0][3]=mad(rA[0][0],rB[0][3],rC[0][3]); \
56             rC[1][3]=mad(rA[0][1],rB[0][3],rC[1][3]); \
57             rC[2][3]=mad(rA[0][2],rB[0][3],rC[2][3]); \
58             rC[3][3]=mad(rA[0][3],rB[0][3],rC[3][3]); \
59             rC[4][3]=mad(rA[0][4],rB[0][3],rC[4][3]); \
60             rC[5][3]=mad(rA[0][5],rB[0][3],rC[5][3]); \
61             rC[0][4]=mad(rA[0][0],rB[0][4],rC[0][4]); \
62             rC[1][4]=mad(rA[0][1],rB[0][4],rC[1][4]); \
63             rC[2][4]=mad(rA[0][2],rB[0][4],rC[2][4]); \
64             rC[3][4]=mad(rA[0][3],rB[0][4],rC[3][4]); \
65             rC[4][4]=mad(rA[0][4],rB[0][4],rC[4][4]); \
66             rC[5][4]=mad(rA[0][5],rB[0][4],rC[5][4]); \
67             rC[0][5]=mad(rA[0][0],rB[0][5],rC[0][5]); \
68             rC[1][5]=mad(rA[0][1],rB[0][5],rC[1][5]); \
69             rC[2][5]=mad(rA[0][2],rB[0][5],rC[2][5]); \
70             rC[3][5]=mad(rA[0][3],rB[0][5],rC[3][5]); \
71             rC[4][5]=mad(rA[0][4],rB[0][5],rC[4][5]); \
72             rC[5][5]=mad(rA[0][5],rB[0][5],rC[5][5]); \
73 			      mem_fence(CLK_LOCAL_MEM_FENCE);\n
74 
75 __attribute__((reqd_work_group_size(16,16,1)))
76 __kernel void sgemm_Col_NN_B1_MX096_NX096_KX16 (
77   __global float const * restrict A,
78   __global float const * restrict B,
79   __global float * C,
80   float const alpha,
81   float const beta,
82   uint const M,
83   uint const N,
84   uint const K,
85   uint lda,
86   uint ldb,
87   uint ldc,
88   uint offsetA,
89   uint offsetB,
90   uint offsetC)
91 {
92     float rC[6][6]  = { {(float)0} };
93     float rA[1][6];
94     float rB[1][6];
95 
96 
97 
98     A += offsetA;
99     B += offsetB;
100     C+=offsetC;
101 
102     __local float lA[1552];
103     __local float lB[1552];
104 
105     uint gidx = get_group_id(0);
106     uint gidy = get_group_id(1);
107     uint idx = get_local_id(0);
108     uint idy = get_local_id(1);
109 
110     A +=  gidx*96+ idx + idy*lda;
111     B +=  gidy*96*ldb+ idx + idy*ldb;
112 
113 
114     uint block_k = K >> 4;
115     do {
116         __local float* plA = lA + idy*97+idx;
117         __local float* plB = lB + idx*97+idy;
118 		    barrier(CLK_LOCAL_MEM_FENCE);
119         plB[0] = B[0];
120         plB[16] = B[16*ldb];
121         plB[32] = B[32*ldb];
122         plB[48] = B[48*ldb];
123         plB[64] = B[64*ldb];
124         plB[80] = B[80*ldb];
125 
126 	      plA[0] = A[0+0*lda];
127         plA[16] = A[16+0*lda];
128         plA[32] = A[32+0*lda];
129         plA[48] = A[48+0*lda];
130         plA[64] = A[64+0*lda];
131         plA[80] = A[80+0*lda];
132 
133 
134         barrier(CLK_LOCAL_MEM_FENCE);
135         uint offA = idx;
136         uint offB = idy;
137 
138         M6x6
139 	      M6x6
140 	      M6x6
141 	      M6x6
142 	      M6x6
143 	      M6x6
144 	      M6x6
145 	      M6x6
146 	      M6x6
147 	      M6x6
148 	      M6x6
149 	      M6x6
150 	      M6x6
151 	      M6x6
152 	      M6x6
153 	      M6x6
154 
155         A += lda<<4;
156         B += 16;
157 	} while (--block_k > 0);
158 
159     C+= gidx*96+idx;
160     C+= gidy*96*ldc;
161     C+= idy*ldc;
162 
163     C[0*ldc] = alpha*rC[0][0] + beta*C[0*ldc];
164     C[16*ldc] = alpha*rC[0][1] + beta*C[16*ldc];
165     C[32*ldc] = alpha*rC[0][2] + beta*C[32*ldc];
166     C[48*ldc] = alpha*rC[0][3] + beta*C[48*ldc];
167     C[64*ldc] = alpha*rC[0][4] + beta*C[64*ldc];
168     C[80*ldc] = alpha*rC[0][5] + beta*C[80*ldc];
169     C+=16;
170     C[0*ldc] = alpha*rC[1][0] + beta*C[0*ldc];
171     C[16*ldc] = alpha*rC[1][1] + beta*C[16*ldc];
172     C[32*ldc] = alpha*rC[1][2] + beta*C[32*ldc];
173     C[48*ldc] = alpha*rC[1][3] + beta*C[48*ldc];
174     C[64*ldc] = alpha*rC[1][4] + beta*C[64*ldc];
175     C[80*ldc] = alpha*rC[1][5] + beta*C[80*ldc];
176     C+=16;
177     C[0*ldc] = alpha*rC[2][0] + beta*C[0*ldc];
178     C[16*ldc] = alpha*rC[2][1] + beta*C[16*ldc];
179     C[32*ldc] = alpha*rC[2][2] + beta*C[32*ldc];
180     C[48*ldc] = alpha*rC[2][3] + beta*C[48*ldc];
181     C[64*ldc] = alpha*rC[2][4] + beta*C[64*ldc];
182     C[80*ldc] = alpha*rC[2][5] + beta*C[80*ldc];
183     C+=16;
184     C[0*ldc] = alpha*rC[3][0] + beta*C[0*ldc];
185     C[16*ldc] = alpha*rC[3][1] + beta*C[16*ldc];
186     C[32*ldc] = alpha*rC[3][2] + beta*C[32*ldc];
187     C[48*ldc] = alpha*rC[3][3] + beta*C[48*ldc];
188     C[64*ldc] = alpha*rC[3][4] + beta*C[64*ldc];
189     C[80*ldc] = alpha*rC[3][5] + beta*C[80*ldc];
190     C+=16;
191     C[0*ldc] = alpha*rC[4][0] + beta*C[0*ldc];
192     C[16*ldc] = alpha*rC[4][1] + beta*C[16*ldc];
193     C[32*ldc] = alpha*rC[4][2] + beta*C[32*ldc];
194     C[48*ldc] = alpha*rC[4][3] + beta*C[48*ldc];
195     C[64*ldc] = alpha*rC[4][4] + beta*C[64*ldc];
196     C[80*ldc] = alpha*rC[4][5] + beta*C[80*ldc];
197     C+=16;
198     C[0*ldc] = alpha*rC[5][0] + beta*C[0*ldc];
199     C[16*ldc] = alpha*rC[5][1] + beta*C[16*ldc];
200     C[32*ldc] = alpha*rC[5][2] + beta*C[32*ldc];
201     C[48*ldc] = alpha*rC[5][3] + beta*C[48*ldc];
202     C[64*ldc] = alpha*rC[5][4] + beta*C[64*ldc];
203     C[80*ldc] = alpha*rC[5][5] + beta*C[80*ldc];
204 
205 }
206 );
207 #endif
208