1 /*******************************************************************************
2  * Hand-tuned kernel
3  *
4  * B21 = -inv(A11)*A12*inv(A22)
5  *
6  ******************************************************************************/
7 
8 #ifndef KERNEL_TRIPLE_DGEMM_UPDATE_128_32_PART1_R_SRC_CPP
9 #define KERNEL_TRIPLE_DGEMM_UPDATE_128_32_PART1_R_SRC_CPP
10 #pragma message("#define KERNEL_TRIPLE_DGEMM_UPDATE_128_32_PART1_R_SRC_CPP.")
11 
// Two-level stringification helper.
// STRINGIFY2 applies the # operator to its arguments; STRINGIFY forwards to it
// so that any macros in the arguments are expanded before being stringified.
// The #ifndef guard reuses an existing STRINGIFY definition if one is already
// in effect when this file is included.
12 #ifndef STRINGIFY
13 #define STRINGIFY2(...) #__VA_ARGS__
14 #define STRINGIFY(...) STRINGIFY2(__VA_ARGS__)
15 #endif
16 
// Byte pointer and byte count that accompany the kernel source below.
// Both are defined empty (0) here.
// NOTE(review): presumably these are filled in elsewhere with a precompiled
// binary of the kernel; nothing in the visible part of this file assigns
// them. Confirm at the use site.
17 unsigned char *triple_dgemm_update_128_32_PART1_R_bin = 0;
18 size_t triple_dgemm_update_128_32_PART1_R_binSize = 0;
19 
// OpenCL C source of the TRIPLE_DGEMM_UPDATE_128_32_PART1_R kernel, stored as a
// C string: STRINGIFY turns every token between its parentheses into text.
// Comments inside this region are removed by the host preprocessor before the
// stringification, so they never reach the OpenCL compiler.
// NOTE(review): the literal \n tokens at the line ends presumably come out as
// newline escapes in the resulting string (so that the #define lines below end
// where intended). Confirm for the host compilers in use.
20 const char * const triple_dgemm_update_128_32_PART1_R_src = STRINGIFY(
// daxpy: c[i] += alpha * b[i] for i = 0..15, fully unrolled.
//   alpha - scalar multiplier
//   b     - 16 consecutive doubles in __local memory, read only
//   c     - 16-element accumulator, updated in place
21 static void daxpy(\n
22 double alpha, \n
23 __local const double * __restrict__ b, \n
24 double * __restrict__ c)\n
25 { \n
26 c[0] += alpha * b[0]; \n
27 c[1] += alpha * b[1]; \n
28 c[2] += alpha * b[2]; \n
29 c[3] += alpha * b[3]; \n
30 c[4] += alpha * b[4]; \n
31 c[5] += alpha * b[5]; \n
32 c[6] += alpha * b[6]; \n
33 c[7] += alpha * b[7]; \n
34 c[8] += alpha * b[8]; \n
35 c[9] += alpha * b[9]; \n
36 c[10] += alpha * b[10]; \n
37 c[11] += alpha * b[11]; \n
38 c[12] += alpha * b[12]; \n
39 c[13] += alpha * b[13]; \n
40 c[14] += alpha * b[14]; \n
41 c[15] += alpha * b[15]; \n
42 }\n
// NB is the edge length and leading dimension of the blocks stored in d_dinvA.
// __mul and qmod are plain multiply and modulo wrapped as macros.
43 #define NB 128\n
44 #define __mul(i,j) ((i)*(j))\n
45 #define qmod(a, b) ((a)%(b))\n
// TRIPLE_DGEMM_UPDATE_128_32_PART1_R
// For one diagonal sub-block ("page") of edge 2*blk, multiplies a blk-wide
// panel of Ain (called A12 in the notes below) by the blk x blk matrix read
// from d_dinvA (called inv(A22)) and writes the product back into d_dinvA.
//
//   Ain     - input matrix, indexed as row + column*lda
//   offAin  - element offset added to Ain before any access
//   d_dinvA - storage made of NB x NB blocks; inv(A22) is read from it and the
//             product is written into it
//   blk     - half the edge length of the sub-block handled per page
//   lda     - leading dimension of Ain
//   npages  - number of pages multiplexed onto group id 1
//   na      - bound for the overflow guard: rows >= na, and flat indices
//             >= lda*na, are read as 0 instead of touching Ain
//
// Each work-item owns one row (ibx + id) and 16 consecutive columns
// (iby .. iby+15) of the product, accumulated in the private array c[16].
// NOTE(review): the bs[] indexing (inx and inx+8; iny, iny+4, iny+8, iny+12)
// looks like it expects an 8x4 work-group. Confirm against the host launch code.
46 	__kernel void TRIPLE_DGEMM_UPDATE_128_32_PART1_R(__global const double *Ain, uint offAin, __global double *d_dinvA, int blk, uint lda, int npages, int na)\n
47 { \n
// Group id 1 carries two indices: the 16-column strip (bIdy) and the page.
48 const int bIdy = get_group_id(1) / npages; \n
49 const int page = qmod(get_group_id(1), npages); \n
50 const int inx = get_local_id(0); \n
51 const int iny = get_local_id(1); \n
// ibx: first row handled by this work-group (one row per work-item).
52 const int ibx = get_group_id(0) * (get_local_size(0)*get_local_size(1)); \n
// iby: first of the 16 result columns handled by this work-group.
53 const int iby = bIdy * 16; \n
// id: linear index of this work-item inside its work-group.
54 const int id = inx + iny*get_local_size(0); \n
// 16x16 tile of B shared by the work-group. The second dimension is 17, one
// more than the 16 entries daxpy reads per row; presumably padding against
// local-memory bank conflicts (TODO confirm).
55 __local double bs[16][17]; \n
56 
57 Ain = Ain + offAin; \n
58 
// Number of 2*blk-wide pages that fit along one NB x NB block.
59 int PagesPerNB = NB / (blk * 2); \n
60 //--------------------------part one---------------------------//
61 {
62 	// A12*inv(A22) -> written to C, which points at row 0, column blk of the sub-block
63 	// A=A12, B=inv(A22), C=A12(d_dinvA)
	// NOTE(review): the earlier wording of the note above said "-> A21"; the
	// pointer arithmetic for C below targets the A12 position. Confirm naming.
64 	__global const double *A; \n
65 	__global double *B, *C; \n
66 	int ldb = NB; \n
67 	int ldc = NB; \n
68 
	// Move d_dinvA to this page: first select the NB x NB block, then step
	// along its diagonal by (page mod PagesPerNB) sub-blocks of edge 2*blk.
69 	d_dinvA += NB*NB*(page / PagesPerNB)
70 		+ (qmod(page, PagesPerNB))*(blk * 2)*NB
71 		+ (qmod(page, PagesPerNB))*(blk * 2); \n
72 
	// (xa, ya): row of Ain owned by this work-item and first column of the
	// panel; incA is the matching flat index and advances by lda per column.
73 	int xa = page*blk * 2 + ibx + id; \n
74 	int ya = page*blk * 2 + blk; \n
75 	int incA = ya * lda + xa; \n
76 
77 	// maxA will be used to detect overflow on all subsequent accesses on A(xa, ya:ya+???)
78 
79 	int maxA; \n
80 	if (xa < na)\n
81 		maxA = lda*na; \n  // macro READA will detect overflow on y dimension
82 	else\n
83 	    maxA = 0; \n // there is already an overflow on xa
84 
	// Guarded read of Ain: yields 0 once incA reaches maxA. The loop below
	// spells this expression out inline instead of using the macro name.
85 #define READA ( (incA < maxA ) ? Ain[incA] : 0 )  \n
86 
	// B: sub-block at row blk, column blk; C: sub-block at row 0, column blk.
87 	B = d_dinvA + blk*NB + blk; \n
88 	C = d_dinvA + blk*NB; \n
89 
	// Per-work-item offsets: B points at the tile element this work-item
	// loads, C at the first of the 16 outputs this work-item stores.
90 	B += inx + __mul(iby + iny, ldb); \n
91 	C += ibx + id + __mul(iby, ldc); \n
92 
	// The loop advances B by 16 rows per pass and stops after blk rows.
93 	__global double *Blast = B + blk; \n
94 
95 	double c[16] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; \n
96 
	// Each pass consumes 16 columns of A and a 16x16 tile of B. The body
	// runs at least once because the condition is tested at the end.
97 	do {\n
		// Prefetch the first four A values of this pass.
98 		double a[4]; \n
99 		a[0] = ((incA < maxA) ? Ain[incA] : 0); incA += lda; \n
100 		a[1] = ((incA < maxA) ? Ain[incA] : 0); incA += lda; \n
101 		a[2] = ((incA < maxA) ? Ain[incA] : 0); incA += lda; \n
102 		a[3] = ((incA < maxA) ? Ain[incA] : 0); incA += lda; \n
103 
		// Cooperative load of the B tile: bs[k][j] holds row k, column j.
		// Each work-item fills eight entries (two rows by four columns).
104 		bs[inx][iny] = B[0 * ldb]; \n
105 		bs[inx][iny + 4] = B[4 * ldb]; \n
106 		bs[inx][iny + 8] = B[8 * ldb]; \n
107 		bs[inx][iny + 12] = B[12 * ldb]; \n
108 		bs[inx + 8][iny] = B[8 + 0 * ldb]; \n
109 		bs[inx + 8][iny + 4] = B[8 + 4 * ldb]; \n
110 		bs[inx + 8][iny + 8] = B[8 + 8 * ldb]; \n
111 		bs[inx + 8][iny + 12] = B[8 + 12 * ldb]; \n
112 		//__syncthreads();
		// The tile must be completely written before any work-item reads it.
113 		barrier(CLK_LOCAL_MEM_FENCE); \n
114 
		// c += a[k] * bs[k][0..15] for k = 0..15; the next four A values are
		// fetched while the current four are being consumed.
115 		daxpy(a[0], &bs[0][0], c);  a[0] = ((incA < maxA) ? Ain[incA] : 0); incA += lda; \n
116 		daxpy(a[1], &bs[1][0], c);  a[1] = ((incA < maxA) ? Ain[incA] : 0); incA += lda; \n
117 		daxpy(a[2], &bs[2][0], c);  a[2] = ((incA < maxA) ? Ain[incA] : 0); incA += lda; \n
118 		daxpy(a[3], &bs[3][0], c);  a[3] = ((incA < maxA) ? Ain[incA] : 0); incA += lda; \n
119 															   \n
120 		daxpy(a[0], &bs[4][0], c);  a[0] = ( (incA < maxA ) ? Ain[incA] : 0 ) ; incA += lda; \n
121 		daxpy(a[1], &bs[5][0], c);  a[1] = ((incA < maxA) ? Ain[incA] : 0); incA += lda; \n
122 		daxpy(a[2], &bs[6][0], c);  a[2] = ((incA < maxA) ? Ain[incA] : 0); incA += lda; \n
123 		daxpy(a[3], &bs[7][0], c);  a[3] = ((incA < maxA) ? Ain[incA] : 0); incA += lda; \n
124 															   \n
125 		daxpy(a[0], &bs[8][0], c);  a[0] = ( (incA < maxA ) ? Ain[incA] : 0 ) ; incA += lda; \n
126 		daxpy(a[1], &bs[9][0], c);  a[1] = ((incA < maxA) ? Ain[incA] : 0); incA += lda; \n
127 		daxpy(a[2], &bs[10][0], c);  a[2] = ((incA < maxA) ? Ain[incA] : 0); incA += lda; \n
128 		daxpy(a[3], &bs[11][0], c);  a[3] = ((incA < maxA) ? Ain[incA] : 0); incA += lda; \n
129 
		// Last four rows of the tile: nothing left to prefetch for this pass.
130 		daxpy(a[0], &bs[12][0], c);\n
131 		daxpy(a[1], &bs[13][0], c);\n
132 		daxpy(a[2], &bs[14][0], c);\n
133 		daxpy(a[3], &bs[15][0], c);\n
134 
		// Advance to the next 16 rows of B.
135 		B += 16; \n
136 		//__syncthreads();
		// All reads of bs must finish before the next pass overwrites it.
137 		barrier(CLK_LOCAL_MEM_FENCE); \n
138 	} while (B < Blast); \n
139 
	// Store the 16 accumulated values: one row, 16 consecutive columns of C.
140 	for (int i = 0; i < 16; i++) {\n
141 		C[0] = c[i]; \n
142 		C += ldc; \n
143 	}\n
144 }\n
145 
146 //__syncthreads();
// NOTE(review): no local memory is accessed after this barrier inside this
// kernel, so it may be redundant. Confirm before removing.
147 barrier(CLK_LOCAL_MEM_FENCE); \n
148 }\n
149 // end of kernel
150 );
151 #endif
152