/*******************************************************************************
 * Hand-tuned kernel
 *
 * B21 = -inv(A11)*A12*inv(A22)
 *
 * This file carries PART1 only: the OpenCL kernel below forms the product
 * A12 * inv(A22) and stores it in the d_dinvA workspace.  No negation and no
 * multiplication by inv(A11) happens in the code visible here (presumably a
 * companion kernel finishes the formula -- TODO confirm).
 *
 * NOTE(review): the kernel writes its result at offset blk*NB of the current
 * d_dinvA page, the position its own inline comment calls A12.  The "B21"
 * label above and the "-> A21" remark inside the kernel therefore look like
 * typos for B12 / A12 -- confirm before relying on them.
 ******************************************************************************/

#ifndef KERNEL_TRIPLE_DGEMM_UPDATE_128_32_PART1_R_SRC_CPP
#define KERNEL_TRIPLE_DGEMM_UPDATE_128_32_PART1_R_SRC_CPP
/* Emits a build-time message whenever the guarded part of this file is compiled. */
#pragma message("#define KERNEL_TRIPLE_DGEMM_UPDATE_128_32_PART1_R_SRC_CPP.")

/*
 * Two-level stringification: STRINGIFY lets its arguments be macro-expanded
 * first, then STRINGIFY2 turns the resulting tokens into one string literal.
 * Defined here only if nothing earlier already provided STRINGIFY.
 */
#ifndef STRINGIFY
#define STRINGIFY2(...) #__VA_ARGS__
#define STRINGIFY(...) STRINGIFY2(__VA_ARGS__)
#endif

/*
 * Pointer to, and size of, a binary form of the kernel.  Both start out empty
 * in this file; presumably they are filled in elsewhere -- TODO confirm
 * against the code that consumes them.
 */
unsigned char *triple_dgemm_update_128_32_PART1_R_bin = 0;
size_t triple_dgemm_update_128_32_PART1_R_binSize = 0;

/*
 * OpenCL C source of the kernel, built by stringifying everything up to the
 * closing parenthesis.  The stray backslash-n token pairs at the ends of the
 * lines become newline escapes inside the resulting literal.  The comments
 * are removed by the host preprocessor before stringification, so none of
 * them reach the OpenCL compiler.
 *
 * NOTE(review): the #define lines (NB, __mul, qmod, READA) sit inside a macro
 * argument list.  ISO C leaves directives in that position undefined, so
 * whether these macros are expanded by the host preprocessor or by the OpenCL
 * compiler depends on the host toolchain -- worth confirming on every
 * supported compiler.
 */
const char * const triple_dgemm_update_128_32_PART1_R_src = STRINGIFY(
// c[0..15] += alpha * b[0..15], written out without a loop.
//   alpha : scalar multiplier
//   b     : 16 consecutive doubles in __local memory (one row of bs below)
//   c     : 16-element accumulator owned by the calling work-item
static void daxpy(\n
    double alpha, \n
    __local const double * __restrict__ b, \n
    double * __restrict__ c)\n
{ \n
    c[0] += alpha * b[0]; \n
    c[1] += alpha * b[1]; \n
    c[2] += alpha * b[2]; \n
    c[3] += alpha * b[3]; \n
    c[4] += alpha * b[4]; \n
    c[5] += alpha * b[5]; \n
    c[6] += alpha * b[6]; \n
    c[7] += alpha * b[7]; \n
    c[8] += alpha * b[8]; \n
    c[9] += alpha * b[9]; \n
    c[10] += alpha * b[10]; \n
    c[11] += alpha * b[11]; \n
    c[12] += alpha * b[12]; \n
    c[13] += alpha * b[13]; \n
    c[14] += alpha * b[14]; \n
    c[15] += alpha * b[15]; \n
}\n
// NB    : extent of one square slab of d_dinvA; also used as ldb and ldc.
// __mul : plain integer product.
// qmod  : plain integer remainder.
#define NB 128\n
#define __mul(i,j) ((i)*(j))\n
#define qmod(a, b) ((a)%(b))\n
// TRIPLE_DGEMM_UPDATE_128_32_PART1_R
//
// For one page (a square tile of extent 2*blk on the diagonal) every
// work-item produces 16 entries of one row of A12 * inv(A22):
//     C[j*ldc] = sum over k in [0, blk) of A(xa, ya + k) * B(k, iby + j)
// for j = 0..15, with A read from Ain and both B and C located in d_dinvA.
// (Row means the unit-stride index here, i.e. column-major storage is
// assumed in these comments.)
//
//   Ain      input matrix; element (x, y) is read as Ain[offAin + y*lda + x]
//   offAin   element offset applied to Ain before any access
//   d_dinvA  workspace: read through B (inv(A22) per the inline comment) and
//            written through C
//   blk      block size; the k loop consumes 16 entries per pass until blk
//            are done, so blk is presumably a multiple of 16 (the kernel
//            name suggests 32 -- TODO confirm)
//   lda      stride between consecutive k entries of Ain
//   npages   number of pages; splits get_group_id(1) into (bIdy, page)
//   na       bound for Ain: reads with xa >= na, or at a linear index of
//            lda*na or more, yield 0 instead of touching memory
//
// NOTE(review): the bs indexing (inx + 8 and iny + 12 into a 16 x 17 array)
// stays in range only for a work-group of 8 x 4 items -- confirm against the
// host launch code.
__kernel void TRIPLE_DGEMM_UPDATE_128_32_PART1_R(__global const double *Ain, uint offAin, __global double *d_dinvA, int blk, uint lda, int npages, int na)\n
{ \n
    // Group id 1 encodes both the 16-column slab (bIdy) and the page.
    const int bIdy = get_group_id(1) / npages; \n
    const int page = qmod(get_group_id(1), npages); \n
    const int inx = get_local_id(0); \n
    const int iny = get_local_id(1); \n
    // First row handled by this work-group: one row per work-item.
    const int ibx = get_group_id(0) * (get_local_size(0)*get_local_size(1)); \n
    // First of the 16 output columns handled by this work-group.
    const int iby = bIdy * 16; \n
    // Flattened local index of this work-item.
    const int id = inx + iny*get_local_size(0); \n
    // Staging tile for a 16 x 16 piece of B.  NOTE(review): the second extent
    // is 17, not 16 -- presumably padding against local-memory bank conflicts.
    __local double bs[16][17]; \n

    Ain = Ain + offAin; \n

    // Number of 2*blk tiles along the diagonal of one NB x NB slab.
    int PagesPerNB = NB / (blk * 2); \n
    //--------------------------part one---------------------------//
    {
        // A12*inv(A22) -> A21
        // A=A12, B=inv(A22), C=A12(d_dinvA)
        // NOTE(review): C below points at the A12 position, so "-> A21" on the
        // first line above is presumably a typo for "-> A12".
        // NOTE(review): A is declared but never used in the code visible here;
        // reads go through Ain directly.
        __global const double *A; \n
        __global double *B, *C; \n
        int ldb = NB; \n
        int ldc = NB; \n

        // Move d_dinvA to this page: select the NB x NB slab, then step along
        // its diagonal to the tile of extent 2*blk.
        d_dinvA += NB*NB*(page / PagesPerNB)
            + (qmod(page, PagesPerNB))*(blk * 2)*NB
            + (qmod(page, PagesPerNB))*(blk * 2); \n

        // (xa, ya) is the first element of Ain this work-item reads: its own
        // row inside the page, first column of the upper-right block.
        int xa = page*blk * 2 + ibx + id; \n
        int ya = page*blk * 2 + blk; \n
        int incA = ya * lda + xa; \n

        // maxA will be used to detect overflow on all subsequent accesses on A(xa, ya : ya + ...)

        int maxA; \n
        if (xa < na)\n
            maxA = lda*na; \n // macro READA will detect overflow on y dimension
        else\n
            maxA = 0; \n // there is already an overflow on xa

        // READA spells out the guarded read of Ain.  The loop below writes the
        // same expression inline instead of using the macro.
#define READA ( (incA < maxA ) ? Ain[incA] : 0 ) \n

        // B: block at offset (blk, blk) of the tile; C: block at offset (0, blk).
        B = d_dinvA + blk*NB + blk; \n
        C = d_dinvA + blk*NB; \n

        // B: the element of the 16 x 16 piece this work-item loads first.
        // C: this work-item's output row, first column of the 16-column slab.
        B += inx + __mul(iby + iny, ldb); \n
        C += ibx + id + __mul(iby, ldc); \n

        // B advances by 16 per pass; the loop ends after blk entries.
        __global double *Blast = B + blk; \n

        // Per-work-item accumulators for the 16 outputs.
        double c[16] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; \n

        do {\n
            // Fetch the first four of the 16 Ain values used in this pass;
            // every fetch is bounds-guarded and steps one stride (lda) forward.
            double a[4]; \n
            a[0] = ((incA < maxA) ? Ain[incA] : 0); incA += lda; \n
            a[1] = ((incA < maxA) ? Ain[incA] : 0); incA += lda; \n
            a[2] = ((incA < maxA) ? Ain[incA] : 0); incA += lda; \n
            a[3] = ((incA < maxA) ? Ain[incA] : 0); incA += lda; \n

            // Cooperative load of the 16 x 16 piece of B into bs: every
            // work-item stores 8 elements, so that bs[k][j] = B(k, iby + j).
            bs[inx][iny] = B[0 * ldb]; \n
            bs[inx][iny + 4] = B[4 * ldb]; \n
            bs[inx][iny + 8] = B[8 * ldb]; \n
            bs[inx][iny + 12] = B[12 * ldb]; \n
            bs[inx + 8][iny] = B[8 + 0 * ldb]; \n
            bs[inx + 8][iny + 4] = B[8 + 4 * ldb]; \n
            bs[inx + 8][iny + 8] = B[8 + 8 * ldb]; \n
            bs[inx + 8][iny + 12] = B[8 + 12 * ldb]; \n
            //__syncthreads();
            // All stores to bs must be complete before any work-item reads it.
            barrier(CLK_LOCAL_MEM_FENCE); \n

            // Rank-1 updates c += a[k] * bs[k][0..15] for k = 0..15.  After
            // each use the a[] slot is refilled with the value needed four
            // steps later, so the statement order here matters.
            daxpy(a[0], &bs[0][0], c); a[0] = ((incA < maxA) ? Ain[incA] : 0); incA += lda; \n
            daxpy(a[1], &bs[1][0], c); a[1] = ((incA < maxA) ? Ain[incA] : 0); incA += lda; \n
            daxpy(a[2], &bs[2][0], c); a[2] = ((incA < maxA) ? Ain[incA] : 0); incA += lda; \n
            daxpy(a[3], &bs[3][0], c); a[3] = ((incA < maxA) ? Ain[incA] : 0); incA += lda; \n
            \n
            daxpy(a[0], &bs[4][0], c); a[0] = ( (incA < maxA ) ? Ain[incA] : 0 ) ; incA += lda; \n
            daxpy(a[1], &bs[5][0], c); a[1] = ((incA < maxA) ? Ain[incA] : 0); incA += lda; \n
            daxpy(a[2], &bs[6][0], c); a[2] = ((incA < maxA) ? Ain[incA] : 0); incA += lda; \n
            daxpy(a[3], &bs[7][0], c); a[3] = ((incA < maxA) ? Ain[incA] : 0); incA += lda; \n
            \n
            daxpy(a[0], &bs[8][0], c); a[0] = ( (incA < maxA ) ? Ain[incA] : 0 ) ; incA += lda; \n
            daxpy(a[1], &bs[9][0], c); a[1] = ((incA < maxA) ? Ain[incA] : 0); incA += lda; \n
            daxpy(a[2], &bs[10][0], c); a[2] = ((incA < maxA) ? Ain[incA] : 0); incA += lda; \n
            daxpy(a[3], &bs[11][0], c); a[3] = ((incA < maxA) ? Ain[incA] : 0); incA += lda; \n

            // Last four updates of the pass: no refill, the next pass starts
            // with fresh fetches.
            daxpy(a[0], &bs[12][0], c);\n
            daxpy(a[1], &bs[13][0], c);\n
            daxpy(a[2], &bs[14][0], c);\n
            daxpy(a[3], &bs[15][0], c);\n

            B += 16; \n
            //__syncthreads();
            // Nobody may overwrite bs in the next pass while others still read it.
            barrier(CLK_LOCAL_MEM_FENCE); \n
        } while (B < Blast); \n

        // Store the 16 results, one per column, ldc elements apart.
        for (int i = 0; i < 16; i++) {\n
            C[0] = c[i]; \n
            C += ldc; \n
        }\n
    }\n

    //__syncthreads();
    // NOTE(review): nothing touches __local memory after this barrier in the
    // visible kernel; it looks redundant but is left untouched.
    barrier(CLK_LOCAL_MEM_FENCE); \n
}\n
// end of kernel
);
#endif
