1 /*******************************************************************************
2  * Hand-tuned kernel
3  ******************************************************************************/
4 
5 #ifndef KERNEL_SGEMM_COL_NT_B1_MX064_NX064_KX16_SRC_H
6 #define KERNEL_SGEMM_COL_NT_B1_MX064_NX064_KX16_SRC_H
7 #pragma message("AutoGemm's sgemm_Col_NT_B1_MX064_NX064_KX16_src overriden by user.")
8 
/*
 * STRINGIFY(S) turns its argument into a string literal. It goes through
 * STRINGIFY2 so that any macros in S are expanded before the # operator is
 * applied. The #ifndef guard keeps an earlier definition of STRINGIFY if one
 * is already in effect; in that case STRINGIFY2 is not defined here either.
 */
9 #ifndef STRINGIFY
10 #define STRINGIFY(S) STRINGIFY2(S)
11 #define STRINGIFY2(S) #S
12 #endif
13 
/*
 * Tile-shape parameters exported next to the kernel source string. They
 * restate values that the kernel text hard-codes:
 *   workGroupNumRows x workGroupNumCols : the 16x16 reqd_work_group_size
 *   microTileNumRows x microTileNumCols : the 4x4 rC accumulator per work-item
 *   unroll                              : the 16 M4x4 steps per loop iteration
 * 16 * 4 = 64 matches the 64x64 tile of C that each work-group writes.
 * NOTE(review): the code that reads these constants is not in this file;
 * presumably host-side launch code - confirm before changing any value.
 */
14 const unsigned int sgemm_Col_NT_B1_MX064_NX064_KX16_workGroupNumRows = 16;
15 const unsigned int sgemm_Col_NT_B1_MX064_NX064_KX16_workGroupNumCols = 16;
16 const unsigned int sgemm_Col_NT_B1_MX064_NX064_KX16_microTileNumRows = 4;
17 const unsigned int sgemm_Col_NT_B1_MX064_NX064_KX16_microTileNumCols = 4;
18 const unsigned int sgemm_Col_NT_B1_MX064_NX064_KX16_unroll = 16;
19 
/*
 * OpenCL C source of the kernel sgemm_Col_NT_B1_MX064_NX064_KX16, turned into
 * a string literal by STRINGIFY. The comments inside the STRINGIFY argument
 * are stripped by the host preprocessor before stringification, so they are
 * not part of the resulting string.
 *
 * Computation, as read off the indexing in the kernel body:
 *     C[m + n*ldc] = alpha * sum_k( A[m + k*lda] * B[n + k*ldb] )
 *                  + beta  * C[m + n*ldc]
 * Each 16x16 work-group produces one 64x64 tile of C; each work-item owns a
 * 4x4 set of its elements, spaced 16 apart in both directions.
 *
 * NOTE(review): the body never reads M or N and performs no bounds checks,
 * and only K >> 4 is used. Presumably the host dispatches this kernel only
 * when M and N are multiples of 64 and K is a nonzero multiple of 16 -
 * confirm against the caller.
 */
20 const char * const sgemm_Col_NT_B1_MX064_NX064_KX16_src = STRINGIFY(
21 
/*
 * M4x4: one k step of the 4x4 register-tile update. It expects rA, rB, rC,
 * lA, lB, offA and offB to be in scope where it is used.
 *   - loads 4 floats from lA and 4 from lB, 16 apart, starting at offA / offB
 *   - advances both offsets by 65, the row stride of the local tiles
 *   - accumulates rC[i][j] += rA[0][i] * rB[0][j] through mad()
 *   - ends with a fence on local memory
 * NOTE(review): the trailing backslash-n on the last line is not a line
 * continuation; it is part of the text that gets stringified. Presumably it
 * makes the define end at a line break in the generated kernel source -
 * confirm against the compilers this header is built with.
 */
22 #define  M4x4 \
23             rA[0][0] = lA[offA + 0];				  \
24             rA[0][1] = lA[offA + 16];				  \
25             rA[0][2] = lA[offA + 32];				  \
26             rA[0][3] = lA[offA + 48];				  \
27             rB[0][0] = lB[offB + 0];				  \
28             rB[0][1] = lB[offB + 16];				  \
29             rB[0][2] = lB[offB + 32];				  \
30             rB[0][3] = lB[offB + 48];				  \
31             offA += 65;								  \
32             offB += 65;								  \
33             rC[0][0]=mad(rA[0][0],rB[0][0],rC[0][0]); \
34             rC[1][0]=mad(rA[0][1],rB[0][0],rC[1][0]); \
35             rC[2][0]=mad(rA[0][2],rB[0][0],rC[2][0]); \
36             rC[3][0]=mad(rA[0][3],rB[0][0],rC[3][0]); \
37             rC[0][1]=mad(rA[0][0],rB[0][1],rC[0][1]); \
38             rC[1][1]=mad(rA[0][1],rB[0][1],rC[1][1]); \
39             rC[2][1]=mad(rA[0][2],rB[0][1],rC[2][1]); \
40             rC[3][1]=mad(rA[0][3],rB[0][1],rC[3][1]); \
41             rC[0][2]=mad(rA[0][0],rB[0][2],rC[0][2]); \
42             rC[1][2]=mad(rA[0][1],rB[0][2],rC[1][2]); \
43             rC[2][2]=mad(rA[0][2],rB[0][2],rC[2][2]); \
44             rC[3][2]=mad(rA[0][3],rB[0][2],rC[3][2]); \
45             rC[0][3]=mad(rA[0][0],rB[0][3],rC[0][3]); \
46             rC[1][3]=mad(rA[0][1],rB[0][3],rC[1][3]); \
47             rC[2][3]=mad(rA[0][2],rB[0][3],rC[2][3]); \
48             rC[3][3]=mad(rA[0][3],rB[0][3],rC[3][3]); \
49 			      mem_fence(CLK_LOCAL_MEM_FENCE);\n
50 
/*
 * Kernel entry point; the work-group size is fixed at 16x16x1.
 *   A, B             read-only inputs
 *   C                read and written in place
 *   alpha, beta      scalars applied to the product and to the old C
 *   M, N             not referenced anywhere in the body
 *   K                only K >> 4 is used, as the outer loop count
 *   lda, ldb, ldc    distance, in floats, between consecutive k columns of
 *                    A and B and between consecutive n columns of C
 *   offsetA/B/C      element offsets added to the three base pointers
 */
51 __attribute__((reqd_work_group_size(16,16,1)))
52 __kernel void sgemm_Col_NT_B1_MX064_NX064_KX16 (
53   __global float const * restrict A,
54   __global float const * restrict B,
55   __global float * C,
56   float const alpha,
57   float const beta,
58   uint const M,
59   uint const N,
60   uint const K,
61   uint lda,
62   uint ldb,
63   uint ldc,
64   uint offsetA,
65   uint offsetB,
66   uint offsetC)
67 {
    /* 4x4 accumulators of this work-item; the initializer zeroes all 16. */
68     float rC[4][4]  = { {(float)0} };
    /* Staging values for one k step, filled by M4x4. */
69     float rA[1][4];
70     float rB[1][4];
71 
72     A += offsetA;
73     B += offsetB;
74     C+=offsetC;
75 
    /*
     * Local tiles: 16 rows of 65 floats each, 16 * 65 = 1040. Only 64 slots
     * per row are written below, so one slot per row stays unused.
     * NOTE(review): the extra slot is presumably padding against local-memory
     * bank conflicts - confirm.
     */
76     __local float lA[1040];
77 	__local float lB[1040];
78 
79     uint gidx = get_group_id(0);
80     uint gidy = get_group_id(1);
81     uint idx = get_local_id(0);
82     uint idy = get_local_id(1);
83 
    /*
     * Flattened local id and its split back into 16-wide coordinates. With
     * the required 16x16 work-group, idxT equals idx and idyT equals idy.
     */
84     uint idt = 16*idy + idx;
85     uint idxT = idt % 16;
86     uint idyT = idt / 16;
87 
    /*
     * First element this work-item loads: position gidx*64 + idxT of A, or
     * gidy*64 + idxT of B, within k column idyT of the current 16-wide slab.
     */
88     A +=  gidx*64+ idxT + idyT*lda;
89     B +=  gidy*64+ idxT + idyT*ldb;
90 
91 
    /* Number of 16-wide k slabs; any remainder of K modulo 16 is dropped. */
92     uint block_k = K >> 4;
93     do
94 	{
        /* This work-item writes into row idyT of each local tile. */
95         __local float* plA = lA + idyT*65+idxT;
96         __local float* plB = lB + idyT*65+idxT;
        /* Nobody may still be reading the previous slab when it is overwritten. */
97         barrier(CLK_LOCAL_MEM_FENCE);
        /* Copy 4 floats of B and 4 floats of A, 16 apart, from global to local. */
98         plB[0] = B[0+0*ldb];
99         plB[16] = B[16+0*ldb];
100         plB[32] = B[32+0*ldb];
101         plB[48] = B[48+0*ldb];
102 
103 	    plA[0] = A[0+0*lda];
104         plA[16] = A[16+0*lda];
105         plA[32] = A[32+0*lda];
106         plA[48] = A[48+0*lda];
107 
108 
        /* Both tiles must be completely stored before any work-item reads them. */
109         barrier(CLK_LOCAL_MEM_FENCE);
        /* Read offsets start in row 0 of the tiles; M4x4 moves them down one row per step. */
110         uint offA = idx;
111         uint offB = idy;
112 
        /* 16 k steps, one per row of the local tiles, written out by hand. */
113     M4x4
114 		M4x4
115 		M4x4
116 		M4x4
117 		M4x4
118 		M4x4
119 		M4x4
120 		M4x4
121 		M4x4
122 		M4x4
123 		M4x4
124 		M4x4
125 		M4x4
126 		M4x4
127 		M4x4
128 		M4x4
129 
        /* Step both inputs forward by 16 k columns. */
130         A += lda<<4;
131         B += ldb<<4;
        /*
         * NOTE(review): the body runs once before the count is tested, and
         * block_k is unsigned. With K < 16 block_k starts at 0 and the
         * pre-decrement wraps around instead of ending the loop.
         */
132 	} while (--block_k > 0);
133 
    /* Point C at element (gidx*64 + idx, gidy*64 + idy) of the output. */
134     C+= gidx*64+idx;
135     C+= gidy*64*ldc;
136     C+= idy*ldc;
137 
    /*
     * Write-back: rC[i][j] goes 16*i positions further along the contiguous
     * direction and 16*j columns further; C += 16 moves to the next i.
     * C is always read, whatever the value of beta.
     */
138 	C[0*ldc] = alpha*rC[0][0] + beta*C[0*ldc];
139     C[16*ldc] = alpha*rC[0][1] + beta*C[16*ldc];
140     C[32*ldc] = alpha*rC[0][2] + beta*C[32*ldc];
141     C[48*ldc] = alpha*rC[0][3] + beta*C[48*ldc];
142     C+=16;
143     C[0*ldc] = alpha*rC[1][0] + beta*C[0*ldc];
144     C[16*ldc] = alpha*rC[1][1] + beta*C[16*ldc];
145     C[32*ldc] = alpha*rC[1][2] + beta*C[32*ldc];
146     C[48*ldc] = alpha*rC[1][3] + beta*C[48*ldc];
147     C+=16;
148     C[0*ldc] = alpha*rC[2][0] + beta*C[0*ldc];
149     C[16*ldc] = alpha*rC[2][1] + beta*C[16*ldc];
150     C[32*ldc] = alpha*rC[2][2] + beta*C[32*ldc];
151     C[48*ldc] = alpha*rC[2][3] + beta*C[48*ldc];
152     C+=16;
153     C[0*ldc] = alpha*rC[3][0] + beta*C[0*ldc];
154     C[16*ldc] = alpha*rC[3][1] + beta*C[16*ldc];
155     C[32*ldc] = alpha*rC[3][2] + beta*C[32*ldc];
156     C[48*ldc] = alpha*rC[3][3] + beta*C[48*ldc];
157 
158 }
159 );
160 #endif
161