1 /*******************************************************************************
2  * Hand-tuned kernel
3  * below kernels work with an assumption: after the main matrix being computed by kernels with 64x64 micro tile size, the boundary are of size 32.
4  * Thus, M and N are of mod32 and not necessarily of mod64.
5  ******************************************************************************/
6 
7 #ifndef KERNEL_SGEMM_COL_NT_B1_MX064_NX032_KX16_COLUMN_SRC_H
8 #define KERNEL_SGEMM_COL_NT_B1_MX064_NX032_KX16_COLUMN_SRC_H
9 #pragma message("AutoGemm's sgemm_Col_NT_B1_MX064_NX032_KX16_COLUMN_src (if exists) overriden by user.")
10 #include "UserGemmKernelSourceIncludes.h"
11 
12 #ifndef STRINGIFY
13 #define STRINGIFY(S) STRINGIFY2(S)
14 #define STRINGIFY2(S) #S
15 #endif
16 
17 const unsigned int sgemm_Col_NT_B1_MX064_NX032_KX16_COLUMN_workGroupNumRows = 16;
18 const unsigned int sgemm_Col_NT_B1_MX064_NX032_KX16_COLUMN_workGroupNumCols = 16;
19 const unsigned int sgemm_Col_NT_B1_MX064_NX032_KX16_COLUMN_microTileNumRows = 4;
20 const unsigned int sgemm_Col_NT_B1_MX064_NX032_KX16_COLUMN_microTileNumCols = 2;
21 const unsigned int sgemm_Col_NT_B1_MX064_NX032_KX16_COLUMN_unroll = 16;
22 
23 //if precompiled is enabled. All hand tuned kerenls should be precompiled.
24 #ifndef AUTOGEMM_USE_PRE_COMPILED_KERNELS
25 unsigned char *sgemm_Col_NT_B1_MX064_NX032_KX16_COLUMN_bin = 0;
26 size_t sgemm_Col_NT_B1_MX064_NX032_KX16_COLUMN_binSize = 0;
27 #endif
28 
29 const char * const sgemm_Col_NT_B1_MX064_NX032_KX16_COLUMN_src = STRINGIFY(
30 
31 #define  M4x2 \
32             rA[0][0] = lA[offA + 0];				  \
33             rA[0][1] = lA[offA + 16];				  \
34             rA[0][2] = lA[offA + 32];				  \
35             rA[0][3] = lA[offA + 48];				  \
36             rB[0][0] = lB[offB + 0];				  \
37             rB[0][1] = lB[offB + 16];				  \
38             offA += 65;								  \
39             offB += 33;								  \
40             rC[0][0]=mad(rA[0][0],rB[0][0],rC[0][0]); \
41             rC[1][0]=mad(rA[0][1],rB[0][0],rC[1][0]); \
42             rC[2][0]=mad(rA[0][2],rB[0][0],rC[2][0]); \
43             rC[3][0]=mad(rA[0][3],rB[0][0],rC[3][0]); \
44             rC[0][1]=mad(rA[0][0],rB[0][1],rC[0][1]); \
45             rC[1][1]=mad(rA[0][1],rB[0][1],rC[1][1]); \
46             rC[2][1]=mad(rA[0][2],rB[0][1],rC[2][1]); \
47             rC[3][1]=mad(rA[0][3],rB[0][1],rC[3][1]); \
48             mem_fence(CLK_LOCAL_MEM_FENCE);\n
49 
50 __attribute__((reqd_work_group_size(16,16,1)))
51 __kernel void sgemm_Col_NT_B1_MX064_NX032_KX16_COLUMN (
52   __global float const * restrict A,
53   __global float const * restrict B,
54   __global float * C,
55   float const alpha,
56   float const beta,
57   uint const M,
58   uint const N,
59   uint const K,
60   uint lda,
61   uint ldb,
62   uint ldc,
63   uint offsetA,
64   uint offsetB,
65   uint offsetC)
66 {
67 	float rC[4][2] = { (float)0 };
68 	float rA[1][4];
69 	float rB[1][2];
70 
71 
72 	A += offsetA;
73 	B += offsetB;
74 	C += offsetC;
75 
76 	__local float lA[1040];//16*64+16
77 	__local float lB[528];//16*32+16
78 
79 	uint gidx = get_group_id(0);
80 	uint gidy = N / 64;//get_group_id(1);
81 	uint idx = get_local_id(0);
82 	uint idy = get_local_id(1);
83 
84 	int CurrentOffSetB = gidy * 64 + idx;
85 
86 	A += gidx * 64 + idx + idy*lda;
87 	B += gidy * 64 + idx + idy*ldb;
88 
89 
90 	uint block_k = K >> 4;
91 	do
92 	{
93 		__local float* plA = lA + idy * 65 + idx;
94 		__local float* plB = lB + idy * 33 + idx;
95 		barrier(CLK_LOCAL_MEM_FENCE);
96 
97 		//plB[0]  = CurrentOffSetB>=N?0.0:B[0];
98 		//plB[16] = CurrentOffSetB+16>=N?0.0:B[16];
99 		//plB[32] = CurrentOffSetB+32>=N?0.0:B[32];
100 		//plB[48] = CurrentOffSetB+48>=N?0.0:B[48];
101 		plB[0] = B[0];
102 		plB[16] = B[16];
103 
104 		plA[0] = A[0];
105 		plA[16] = A[16];
106 		plA[32] = A[32];
107 		plA[48] = A[48];
108 
109 
110 		barrier(CLK_LOCAL_MEM_FENCE);
111 		uint offA = idx;
112 		uint offB = idy;
113 
114 
115 		    M4x2
116 			M4x2
117 			M4x2
118 			M4x2
119 			M4x2
120 			M4x2
121 			M4x2
122 			M4x2
123 			M4x2
124 			M4x2
125 			M4x2
126 			M4x2
127 			M4x2
128 			M4x2
129 			M4x2
130 			M4x2
131 
132 			A += lda << 4;
133 		B += ldb << 4;
134 	} while (--block_k > 0);
135 
136 
137 	int offset_x = gidx * 64 + idx;
138 	int offset_y = gidy * 64 + idy;
139 
140 	//if(offset_y>=N )
141 	// return;
142 
143 	C += offset_x + offset_y*ldc;
144 
145 	int i = 0;
146 	do
147 	{
148 		C[0] = mad(alpha, rC[i][0], beta*C[0]);
149 		C[16 * ldc] = mad(alpha, rC[i][1], beta*C[16 * ldc]);
150 
151 		C += 16;
152 
153 	} while (++i < 4);
154 
155 }
156 );
157 #endif
158