1 /*******************************************************************************
2  * Hand-tuned kernel
3  * The kernels below rely on an assumption: after the main part of the matrix has been computed by kernels with a 64x64 micro tile size, the remaining boundary is of size 32.
4  * Thus, M and N are multiples of 32 but not necessarily multiples of 64.
5  ******************************************************************************/
6 
7 #ifndef KERNEL_SGEMM_COL_NT_B1_MX032_NX064_KX16_ROW_SRC_H
8 #define KERNEL_SGEMM_COL_NT_B1_MX032_NX064_KX16_ROW_SRC_H
9 #pragma message("AutoGemm's sgemm_Col_NT_B1_MX032_NX064_KX16_ROW_src (if exists) overriden by user.")
10 
11 #include "UserGemmKernelSourceIncludes.h"
12 
13 #ifndef STRINGIFY
14 #define STRINGIFY(S) STRINGIFY2(S)
15 #define STRINGIFY2(S) #S
16 #endif
17 
18 const unsigned int sgemm_Col_NT_B1_MX032_NX064_KX16_ROW_workGroupNumRows = 16;
19 const unsigned int sgemm_Col_NT_B1_MX032_NX064_KX16_ROW_workGroupNumCols = 16;
20 const unsigned int sgemm_Col_NT_B1_MX032_NX064_KX16_ROW_microTileNumRows = 2;
21 const unsigned int sgemm_Col_NT_B1_MX032_NX064_KX16_ROW_microTileNumCols = 4;
22 const unsigned int sgemm_Col_NT_B1_MX032_NX064_KX16_ROW_unroll = 16;
23 
24 //if precompiled is enabled. All hand tuned kerenls should be precompiled.
25 #ifndef AUTOGEMM_USE_PRE_COMPILED_KERNELS
26 unsigned char *sgemm_Col_NT_B1_MX032_NX064_KX16_ROW_bin = 0;
27 size_t sgemm_Col_NT_B1_MX032_NX064_KX16_ROW_binSize = 0;
28 #endif
29 
30 const char * const sgemm_Col_NT_B1_MX032_NX064_KX16_ROW_src = STRINGIFY(
31 
32 #define  M2x4 \
33             rA[0][0] = lA[offA + 0];				  \
34             rA[0][1] = lA[offA + 16];				  \
35             rB[0][0] = lB[offB + 0];				  \
36             rB[0][1] = lB[offB + 16];				  \
37             rB[0][2] = lB[offB + 32];				  \
38             rB[0][3] = lB[offB + 48];				  \
39             offA += 33;								  \
40             offB += 65;								  \
41             rC[0][0]=mad(rA[0][0],rB[0][0],rC[0][0]); \
42             rC[1][0]=mad(rA[0][1],rB[0][0],rC[1][0]); \
43             rC[0][1]=mad(rA[0][0],rB[0][1],rC[0][1]); \
44             rC[1][1]=mad(rA[0][1],rB[0][1],rC[1][1]); \
45             rC[0][2]=mad(rA[0][0],rB[0][2],rC[0][2]); \
46             rC[1][2]=mad(rA[0][1],rB[0][2],rC[1][2]); \
47             rC[0][3]=mad(rA[0][0],rB[0][3],rC[0][3]); \
48             rC[1][3]=mad(rA[0][1],rB[0][3],rC[1][3]); \
49             mem_fence(CLK_LOCAL_MEM_FENCE);\n
50 
51 __attribute__((reqd_work_group_size(16,16,1)))
52 __kernel void sgemm_Col_NT_B1_MX032_NX064_KX16_ROW (
53   __global float const * restrict A,
54   __global float const * restrict B,
55   __global float * C,
56   float const alpha,
57   float const beta,
58   uint const M,
59   uint const N,
60   uint const K,
61   uint lda,
62   uint ldb,
63   uint ldc,
64   uint offsetA,
65   uint offsetB,
66   uint offsetC)
67 {
68 	float rC[2][4] = { (float)0 };
69 	float rA[1][2];
70 	float rB[1][4];
71 
72 
73 	A += offsetA;
74 	B += offsetB;
75 	C += offsetC;
76 
77 	__local float lA[528];//16*32+16
78 	__local float lB[1040];//16*64+16
79 
80 	uint gidx = M / 64;//get_group_id(0);
81 	uint gidy = get_group_id(1);
82 	uint idx = get_local_id(0);
83 	uint idy = get_local_id(1);
84 
85 
86 	int CurrentOffSetA = gidx * 64 + idx;
87 
88 	A += gidx * 64 + idx + idy*lda;
89 	B += gidy * 64 + idx + idy*ldb;
90 
91 
92 	uint block_k = K >> 4;
93 	do
94 	{
95 		__local float* plA = lA + idy * 33 + idx;
96 		__local float* plB = lB + idy * 65 + idx;
97 		barrier(CLK_LOCAL_MEM_FENCE);
98 
99 		plB[0] = B[0 + 0 * ldb];
100 		plB[16] = B[16 + 0 * ldb];
101 		plB[32] = B[32 + 0 * ldb];
102 		plB[48] = B[48 + 0 * ldb];
103 
104 		//plA[0]  = CurrentOffSetA>=M?0.0:A[0];
105 		//plA[16] = CurrentOffSetA+16>=M?0.0:A[16];
106 		//plA[32] = CurrentOffSetA+32>=M?0.0:A[32];
107 		//plA[48] = CurrentOffSetA+48>=M?0.0:A[48];
108 		plA[0] = A[0];
109 		plA[16] = A[16];
110 
111 
112 		barrier(CLK_LOCAL_MEM_FENCE);
113 		uint offA = idx;
114 		uint offB = idy;
115 
116 
117 		    M2x4
118 			M2x4
119 			M2x4
120 			M2x4
121 			M2x4
122 			M2x4
123 			M2x4
124 			M2x4
125 			M2x4
126 			M2x4
127 			M2x4
128 			M2x4
129 			M2x4
130 			M2x4
131 			M2x4
132 			M2x4
133 
134 			A += lda << 4;
135 		    B += ldb << 4;
136 	} while (--block_k > 0);
137 
138 
139 	int offset_x = gidx * 64 + idx;
140 	int offset_y = gidy * 64 + idy;
141 
142 	//if(offset_x>=M )
143 	//  return;
144 
145 	C += offset_x + offset_y*ldc;
146 
147 	int i = 0;
148 	do
149 	{
150 		C[0] = mad(alpha, rC[i][0], beta*C[0]);
151 		C[16 * ldc] = mad(alpha, rC[i][1], beta*C[16 * ldc]);
152 		C[32 * ldc] = mad(alpha, rC[i][2], beta*C[32 * ldc]);
153 		C[48 * ldc] = mad(alpha, rC[i][3], beta*C[48 * ldc]);
154 		C += 16;
155 		offset_x += 16;
156 		//if(offset_x>=M )
157 		//  return;
158 	} while (++i < 2);
159 }
160 );
161 #endif
162