1 /*******************************************************************************
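 * Hand-tuned OpenCL SGEMM kernel source, embedded as a C string: column-major
 * C = alpha * A * B^T with beta = 0, 64x64 tile per work-group, K unrolled by 16.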
3  ******************************************************************************/
4 
5 #ifndef KERNEL_SGEMM_COL_NT_B0_MX064_NX064_KX16_SRC_H
6 #define KERNEL_SGEMM_COL_NT_B0_MX064_NX064_KX16_SRC_H
7 #pragma message("AutoGemm's sgemm_Col_NT_B0_MX064_NX064_KX16_src overriden by user.")
8 
#ifndef STRINGIFY
/*
 * Turn a token sequence into a string literal.
 * The indirection through STRINGIFY2 matters: STRINGIFY's argument is fully
 * macro-expanded first, and only then does STRINGIFY2 apply the # operator,
 * so any macros used inside the argument appear expanded in the result.
 */
#define STRINGIFY(tokens) STRINGIFY2(tokens)
#define STRINGIFY2(tokens) #tokens
#endif
13 
/*
 * Launch geometry of the kernel in sgemm_Col_NT_B0_MX064_NX064_KX16_src:
 * a 16 x 16 work-group (matching its reqd_work_group_size(16,16,1)) in which
 * each work-item accumulates a 4 x 4 micro tile, so one work-group covers a
 * 64 x 64 tile of C; the K dimension is consumed 16 at a time.
 */
unsigned int const sgemm_Col_NT_B0_MX064_NX064_KX16_workGroupNumRows = 16U;
unsigned int const sgemm_Col_NT_B0_MX064_NX064_KX16_workGroupNumCols = 16U;
unsigned int const sgemm_Col_NT_B0_MX064_NX064_KX16_microTileNumRows = 4U;
unsigned int const sgemm_Col_NT_B0_MX064_NX064_KX16_microTileNumCols = 4U;
unsigned int const sgemm_Col_NT_B0_MX064_NX064_KX16_unroll = 16U;
19 
20 const char * const sgemm_Col_NT_B0_MX064_NX064_KX16_src = STRINGIFY(
21 
22 #define  M4x4 \
23             rA[0][0] = lA[offA + 0];				  \
24             rA[0][1] = lA[offA + 16];				  \
25             rA[0][2] = lA[offA + 32];				  \
26             rA[0][3] = lA[offA + 48];				  \
27             rB[0][0] = lB[offB + 0];				  \
28             rB[0][1] = lB[offB + 16];				  \
29             rB[0][2] = lB[offB + 32];				  \
30             rB[0][3] = lB[offB + 48];				  \
31             offA += 65;								  \
32             offB += 65;								  \
33             rC[0][0]=mad(rA[0][0],rB[0][0],rC[0][0]); \
34             rC[1][0]=mad(rA[0][1],rB[0][0],rC[1][0]); \
35             rC[2][0]=mad(rA[0][2],rB[0][0],rC[2][0]); \
36             rC[3][0]=mad(rA[0][3],rB[0][0],rC[3][0]); \
37             rC[0][1]=mad(rA[0][0],rB[0][1],rC[0][1]); \
38             rC[1][1]=mad(rA[0][1],rB[0][1],rC[1][1]); \
39             rC[2][1]=mad(rA[0][2],rB[0][1],rC[2][1]); \
40             rC[3][1]=mad(rA[0][3],rB[0][1],rC[3][1]); \
41             rC[0][2]=mad(rA[0][0],rB[0][2],rC[0][2]); \
42             rC[1][2]=mad(rA[0][1],rB[0][2],rC[1][2]); \
43             rC[2][2]=mad(rA[0][2],rB[0][2],rC[2][2]); \
44             rC[3][2]=mad(rA[0][3],rB[0][2],rC[3][2]); \
45             rC[0][3]=mad(rA[0][0],rB[0][3],rC[0][3]); \
46             rC[1][3]=mad(rA[0][1],rB[0][3],rC[1][3]); \
47             rC[2][3]=mad(rA[0][2],rB[0][3],rC[2][3]); \
48             rC[3][3]=mad(rA[0][3],rB[0][3],rC[3][3]); \
49 			      mem_fence(CLK_LOCAL_MEM_FENCE);\n
50 
51 __attribute__((reqd_work_group_size(16,16,1)))
52 
53 __kernel void sgemm_Col_NT_B0_MX064_NX064_KX16 (
54   __global float const * restrict A,
55   __global float const * restrict B,
56   __global float * C,
57   float const alpha,
58   float const beta,
59   uint const M,
60   uint const N,
61   uint const K,
62   uint lda,
63   uint ldb,
64   uint ldc,
65   uint offsetA,
66   uint offsetB,
67   uint offsetC)
68 {
69     float rC[4][4]  = { {(float)0} };
70     float rA[1][4];
71     float rB[1][4];
72 
73     A += offsetA;
74     B += offsetB;
75     C+=offsetC;
76 
77     __local float lA[1056];
78     __local float lB[1056];
79 
80     uint gidx = get_group_id(0);
81     uint gidy = get_group_id(1);
82     uint idx = get_local_id(0);
83     uint idy = get_local_id(1);
84 
85     uint idt = 16*idy + idx;
86     uint idxT = idt % 16;
87     uint idyT = idt / 16;
88 
89     A +=  gidx*64+ idxT + idyT*lda;
90     B +=  gidy*64+ idxT + idyT*ldb;
91 
92 
93     uint block_k = K >> 4;
94     do
95 	{
96         __local float* plA = lA + idyT*65+idxT;
97         __local float* plB = lB + idyT*65+idxT;
98         barrier(CLK_LOCAL_MEM_FENCE);
99         plB[0] = B[0+0*ldb];
100         plB[16] = B[16+0*ldb];
101         plB[32] = B[32+0*ldb];
102         plB[48] = B[48+0*ldb];
103 
104 	      plA[0] = A[0+0*lda];
105         plA[16] = A[16+0*lda];
106         plA[32] = A[32+0*lda];
107         plA[48] = A[48+0*lda];
108 
109 
110         barrier(CLK_LOCAL_MEM_FENCE);
111         uint offA = idx;
112         uint offB = idy;
113 
114 
115         M4x4
116 		M4x4
117 		M4x4
118 		M4x4
119 		M4x4
120 		M4x4
121 		M4x4
122 		M4x4
123 		M4x4
124 		M4x4
125 		M4x4
126 		M4x4
127 		M4x4
128 		M4x4
129 		M4x4
130 		M4x4
131 
132         A += lda<<4;
133         B += ldb<<4;
134 	} while (--block_k > 0);
135 
136     C+= gidx*64+idx;
137     C+= gidy*64*ldc;
138     C+= idy*ldc;
139 
140 	C[0*ldc] = alpha*rC[0][0] ;
141     C[16*ldc] = alpha*rC[0][1];
142     C[32*ldc] = alpha*rC[0][2];
143     C[48*ldc] = alpha*rC[0][3];
144 
145     C+=16;
146     C[0*ldc] = alpha*rC[1][0] ;
147     C[16*ldc] = alpha*rC[1][1];
148     C[32*ldc] = alpha*rC[1][2];
149     C[48*ldc] = alpha*rC[1][3];
150 
151     C+=16;
152     C[0*ldc] = alpha*rC[2][0] ;
153     C[16*ldc] = alpha*rC[2][1];
154     C[32*ldc] = alpha*rC[2][2];
155     C[48*ldc] = alpha*rC[2][3];
156 
157     C+=16;
158     C[0*ldc] = alpha*rC[3][0] ;
159     C[16*ldc] = alpha*rC[3][1];
160     C[32*ldc] = alpha*rC[3][2];
161     C[48*ldc] = alpha*rC[3][3];
162 
163 }
164 );
165 #endif
166