1 /* ************************************************************************
2 * Copyright 2013 Advanced Micro Devices, Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * ************************************************************************/
16
17
18 /*
19 * TRSM generator with support of cached reads from the global memory
20 */
21
22 #include <string.h>
23 #include <stdio.h>
24 #include <assert.h>
25 #include <clblas_stddef.h>
26 #include <clBLAS.h>
27 #include <blas_mempat.h>
28 #include <clkern.h>
29 #include <clblas-internal.h>
30 #include <matrix_props.h>
31 #include <matrix_dims.h>
32
33 #include "../blas_kgen.h"
34 #include "../trxm_common.h"
35 #include "trsm_kgen_legacy.h"
36 #include "gen_helper_legacy.h"
37 #include "../trsm_kgen.h"
38
/*
 * Code templates for fetching the square diagonal block of matrix A
 * into local memory. They are instantiated with sprintf() from
 * genReadDiagBlock(); the '%lu' slots receive block extents and the
 * LDS pitch, the '%s' slots receive generated copy/zeroing function
 * names, and '%c' selects the destination buffer suffix ('A' or 'C').
 */

/* Generic read: fast optimized copy when the whole block fits within
 * the matrix, otherwise zero the destination and use the slow generic
 * copy restricted to the valid y-by-x part. */
static const char *readSquareBlock =
    "y = (currM + %lu <= M) ? %lu : M - currM;\n"
    "x = (k0 + %lu <= M) ? %lu : M - k0;\n"
    "if ((y == %lu) && (x == %lu)) {\n"
        // just read with an optimized function
    " %s((LPtr)temp%c, (GPtr)A, currM, k0, lda);\n"
    "}\n"
    "else {\n"
    " %s((__local float4*)temp%c);\n" // zeroing
    " barrier(CLK_LOCAL_MEM_FENCE);\n"
    " %s((LPtr)temp%c, (GPtr)A, currM, k0, y, x, %lu, lda);\n"
    "}\n\n";

/* Read used when M has no tails: the block is always complete. */
static const char *readSquareBlockOpt =
    // just read with an optimized function
    "%s((LPtr)temp%c, (GPtr)A, currM, k0, lda);\n";

/* Transposing variant of readSquareBlock (row/column coordinates of the
 * source access are swapped). */
static const char *readSquareBlockTrans =
    "y = (currM + %lu <= M) ? %lu : M - currM;\n"
    "x = (k0 + %lu <= M) ? %lu : M - k0;\n"
    "if ((y == %lu) && (x == %lu)) {\n"
        // read and transpose with an optimized function
    " %s((LPtr)temp%c, (GPtr)A, k0, currM, lda);\n"
    "}\n"
    "else {\n"
    " %s((__local float4*)temp%c);\n" // zeroing
    " barrier(CLK_LOCAL_MEM_FENCE);\n"
        // read and transpose with slow function
    " %s((LPtr)temp%c, (GPtr)A, k0, currM, x, y, %lu, lda);\n"
    "}\n\n";

/* Transposing variant of readSquareBlockOpt. */
static const char *readSquareBlockTransOpt =
    // read and transpose with an optimized function
    "%s((LPtr)temp%c, (GPtr)A, k0, currM, lda);\n";
/* Memory pattern extra info; filled in by initTrsmCachedPattern(). */
static CLBLASMpatExtra mpatExtra;

/* Forward declarations of the solver operations implemented below. */

static ssize_t
generator(
    char *buf,
    size_t buflen,
    const struct SubproblemDim *subdims,
    const struct PGranularity *pgran,
    void *extra);

static bool
isFitToLDS(
    SubproblemDim *dim,
    DataType dtype,
    cl_ulong ldsSize,
    const void *kernelArgs);

static SolverFlags
solverFlags(void);

static void
assignKargs(KernelArg *args, const void *params, const void *extra);

static void
fixupArgs(void *args, SubproblemDim *subdims, void *extra);

/*
 * Solver operations table for the cached TRSM pattern; unsupported
 * optional operations are left NULL.
 */
static SolverOps trsmSops = {
    generator,
    assignKargs,
    isFitToLDS,
    NULL,
    NULL,
    NULL,
    NULL,
    solverFlags,
    fixupArgs,
    NULL, //getDefaultDecomp
    NULL, // getDecompList
    NULL,
    NULL
};
115
116 static TileMulFlags
getCyclicFlags(const SubproblemDim * dim,KernelExtraFlags kflags,bool tailPass,unsigned int vecLen)117 getCyclicFlags(
118 const SubproblemDim *dim,
119 KernelExtraFlags kflags,
120 bool tailPass,
121 unsigned int vecLen)
122 {
123 TileMulFlags mflags = TILEMUL_NO_FLAGS;
124
125 if (tailPass && !isMatrixUpper(kflags)) {
126 mflags |= TILEMUL_GLOBAL_CYCLIC_A;
127 }
128
129 if (isMatrixAccessColMaj(CLBLAS_TRSM, kflags, MATRIX_B) &&
130 (kflags & KEXTRA_TAILS_N) && (dim->x > vecLen)) {
131
132 mflags |= TILEMUL_GLOBAL_CYCLIC_B;
133 }
134
135 return mflags;
136 }
137
/*
 * Initialize the private tile descriptors for A, B and the result.
 * The tiles are sized to hold both the rectangular shapes used in the
 * main update loop and the square shapes used around the diagonal.
 */
static void
initTiles(BlasGenSettings *gset)
{
    unsigned int nrRows, nrCols;
    unsigned int vecLen;
    const SubproblemDim *dim = &gset->subdims[1];
    const CLBLASKernExtra *kextra = gset->kextra;
    DataType dtype = kextra->dtype;
    bool tra;

    // the tile A should be able to fit rectangular and square tiles
    nrCols = (unsigned int)szmax(dim->y, dim->bwidth);
    tra = isMatrixAccessColMaj(CLBLAS_TRSM, kextra->flags, MATRIX_A);
    vecLen = getVecLen(gset, CLBLAS_TRSM, MATRIX_A);
    initTile(&gset->tileA, "a", (unsigned int)dim->y, nrCols, vecLen,
             dtype, PRIV_STORAGE_ARRAY, tra, false);

    /*
     * tile B should be able to fit tiles of the matrix B and of the
     * intermediate result. That result will be always transposed
     * from the point of view of tile multiplication
     */
    tra = !isMatrixAccessColMaj(CLBLAS_TRSM, kextra->flags, MATRIX_B);
    if (tra) {
        nrRows = (unsigned int)szmax(dim->bwidth, dim->y);
        nrCols = (unsigned int)dim->x;
    }
    else {
        nrRows = (unsigned int)szmax(dim->bwidth, dim->x);
        nrCols = (unsigned int)szmax(dim->x, dim->y);
    }
    vecLen = getVecLen(gset, CLBLAS_TRSM, MATRIX_B);
    initTile(&gset->tileBX, "b", nrRows, nrCols, vecLen, dtype,
             PRIV_STORAGE_ARRAY, tra, false);

    /* NOTE(review): tileCY reuses the vectorization length of matrix B
       computed just above — looks intentional, but confirm it should
       not be getVecLen(..., MATRIX_C) */
    initTile(&gset->tileCY, "c", (unsigned int)dim->y, (unsigned int)dim->x,
             vecLen, dtype, PRIV_STORAGE_ARRAY, false, false);
}
176
177 static void
prepareTilesForMainLoop(BlasGenSettings * gset)178 prepareTilesForMainLoop(BlasGenSettings *gset)
179 {
180 const SubproblemDim *dim = &gset->subdims[1];
181
182 gset->tileA.nrCols = (unsigned int)dim->bwidth;
183 gset->tileBX.nrRows = (unsigned int)dim->bwidth;
184 gset->tileBX.nrCols = (unsigned int)dim->x;
185 }
186
/*
 * Emit declarations of all kernel-scope variables: work-item IDs,
 * coordinate counters, the LDS buffers tempA/tempC and the untyped
 * pointers used by the generic codegen helpers.
 */
static void
declareLocalVariables(
    struct KgenContext *ctx,
    const BlasGenSettings *gset)
{
    char tmp[1024];
    const char *elemType;
    const SubproblemDim *dims = gset->subdims;
    DataType dtype = gset->kextra->dtype;
    size_t pitchAC, heightC;

    elemType = dtypeBuiltinType(dtype);
    /* tempC must also be able to hold a square diagonal block,
       hence the max of the two block extents */
    pitchAC = matrBlockPitch(dims, MATRIX_C, dtype, clblasRight);
    heightC = szmax(dims[0].y, dims[0].x);

    declareTileStorages(ctx, gset);
    sprintf(tmp, "const int lid = get_local_id(0);\n"
                 "const int gid = get_group_id(0);\n"
                 "const uint2 skewRow = 0, skewCol = 0;\n\n"
                 "GPtr uA, uB;\n"
                 "uint coordA, coordB, k;\n"
                 "uint x, y;\n"
                 "__local %s tempA[%lu], tempC[%lu];\n"
                 "LPtr utmpA, utmpC;\n"
                 "uint m0 = 0, k0, currM, currN;\n",
            elemType, pitchAC * dims[0].y, pitchAC * heightC);
    kgenAddStmt(ctx, tmp);
}
215
/*
 * Emit code reading the square diagonal block of matrix A into the LDS
 * buffer selected by 'c' ('A' -> tempA, 'C' -> tempC), transposing it
 * when A is accessed in column-major order. With tails along M the
 * slower bounds-checked template is used.
 */
static void
genReadDiagBlock(
    struct KgenContext *ctx,
    const SubproblemDim *dim,
    DataType dtype,
    const CopyBufFuncs *copyFuncs,
    const ZeroFuncs *zeroFuncs,
    KernelExtraFlags kflags,
    char c)
{
    char tmp[1024];
    size_t pitch;
    const char *readBlock;
    bool tra;

    tra = isMatrixAccessColMaj(CLBLAS_TRSM, kflags, MATRIX_A);
    pitch = matrBlockPitch(dim, MATRIX_A, dtype, clblasLeft);

    if (!(kflags & KEXTRA_TAILS_M)) {
        /* no tails: the block is always complete, use the fast path */
        readBlock = (tra) ? readSquareBlockTransOpt : readSquareBlockOpt;
        sprintf(tmp, readBlock, copyFuncs->read[MATRIX_A], c);
    }
    else {
        /* the argument order must match the template declarations above */
        readBlock = (tra) ? readSquareBlockTrans : readSquareBlock;
        sprintf(tmp, readBlock, dim->y, dim->y, dim->bwidth, dim->bwidth,
                dim->y, dim->bwidth, copyFuncs->read[MATRIX_A], c,
                zeroFuncs->names[MATRIX_A], c,
                copyFuncs->readGeneric[MATRIX_A], c, pitch);
    }
    kgenAddStmt(ctx, tmp);
}
247
248 static void
genZeroResult(struct KgenContext * ctx,DataType dtype,const SubproblemDim * dims,unsigned int vecLen)249 genZeroResult(
250 struct KgenContext *ctx,
251 DataType dtype,
252 const SubproblemDim *dims,
253 unsigned int vecLen)
254 {
255 unsigned int n;
256 char tmp[1024];
257
258 getResultGPRsInfo(dtype, &dims[1], vecLen, &n, NULL);
259
260 sprintf(tmp, "for (x = 0; x < %u; x++) {\n"
261 " c[x] = 0;\n"
262 "}\n\n", n);
263
264 kgenAddStmt(ctx, tmp);
265 }
266
/*
 * Open the inner loop over K (the block-update loop). For an upper
 * triangular matrix it runs from the block just past the diagonal to
 * the end of M (stopping at the last whole block when M has tails);
 * for a lower triangular matrix it runs from 0 up to the diagonal.
 * The matching kgenEndBranch() is emitted by the caller.
 */
static void
genInternalLoopCtl(
    struct KgenContext *ctx,
    const SubproblemDim *dim,
    KernelExtraFlags kflags)
{
    char tmp[1024];

    if (isMatrixUpper(kflags)) {
        if (kflags & KEXTRA_TAILS_M) {
            /* stop at the last whole block; the K tail is handled by an
               extra tileMulGen() invocation in genOneTrsmPass() */
            sprintf(tmp, "for (k0 = currM + %lu; k0 < M / %lu * %lu; "
                         "k0 += %lu)",
                    dim[0].bwidth, dim[1].bwidth, dim[1].bwidth, dim[1].bwidth);
        }
        else {
            sprintf(tmp, "for (k0 = currM + %lu; k0 < M; k0 += %lu)",
                    dim[0].bwidth, dim[1].bwidth);
        }
    }
    else {
        sprintf(tmp, "for (k0 = 0; k0 < currM; k0 += %lu)",
                dim[1].bwidth);
    }

    kgenBeginBranch(ctx, tmp);
}
293
294 static void
genInitCurrM(struct KgenContext * ctx,const SubproblemDim * dim,KernelExtraFlags kflags)295 genInitCurrM(
296 struct KgenContext *ctx,
297 const SubproblemDim *dim,
298 KernelExtraFlags kflags)
299 {
300 char tmp[1024];
301
302 if (isMatrixUpper(kflags)) {
303 /* start from the last block */
304 sprintf(tmp, "currM = ((M - 1) / %lu) * %lu;\n", dim->y, dim->y);
305 kgenAddStmt(ctx, tmp);
306 }
307 else {
308 kgenAddStmt(ctx, "currM = 0;\n");
309 }
310 }
311
312 static void
initKernelVarNames(KernelVarNames * kvars)313 initKernelVarNames(KernelVarNames *kvars)
314 {
315 kvars->A = "uA";
316 kvars->B = "uB";
317 kvars->coordA = "coordA";
318 kvars->coordB = "coordB";
319 kvars->k = "k";
320 kvars->sizeM = "M";
321 kvars->sizeN = "N";
322 kvars->sizeK = "M";
323 kvars->lda = "lda";
324 kvars->ldb = "ldb";
325 }
326
/*
 * Generate a code copying tile between LDS and private location.
 * The private result tile of each work item is stored to its panel of
 * the 'tempC' LDS buffer via the update-result machinery (UPRES_SET,
 * no alpha scaling, column-major layout).
 */
static void
genLdsCopy(
    struct KgenContext *ctx,
    const BlasGenSettings *gset)
{
    char pitchStr[16];
    char coordY[128], coordX[128];
    size_t pitch;
    UpresVarNames uvars;
    UpdateResultFlags upFlags = UPRES_INLINE | UPRES_USE_LDS |
                                UPRES_WITHOUT_ALPHA | UPRES_COLUMN_MAJOR;
    const SubproblemDim *dims = gset->subdims;
    unsigned int l1Pans = (unsigned int)(dims[0].x / dims[1].x);

    memset(&uvars, 0, sizeof(uvars));

    pitch = matrBlockPitch(dims, MATRIX_C, gset->kextra->dtype, clblasRight);
    sprintf(pitchStr, "%lu", pitch);
    /* per-work-item panel coordinates within the work group tile */
    sprintf(coordY, "lid / %u * %lu", l1Pans, dims[1].y);
    sprintf(coordX, "lid %% %u * %lu", l1Pans, dims[1].x);
    uvars.result = "tempC";
    uvars.ld = pitchStr;
    uvars.startRow = coordY;
    uvars.startCol = coordX;
    uvars.nrRows = NULL;
    uvars.nrCols = NULL;

    /* scope the helper variables declared by updateResultGen() */
    kgenBeginBranch(ctx, NULL);

    updateResultGen(ctx,
        gset,
        CLBLAS_TRSM,
        UPRES_SET,
        upFlags,
        &uvars);

    kgenEndBranch(ctx, NULL);

    kgenAddBlankLine(ctx);
}
370
/*
 * Emit code zeroing the rows of the private result tile 'c' that fall
 * past the matrix bound M (the last 'i' rows of the tile). Needed
 * because such trash values would otherwise contribute to the product
 * with the inverted diagonal block.
 */
static void
genZeroResultTrash(
    struct KgenContext *ctx,
    const SubproblemDim *dim,
    const CLBLASKernExtra *kextra)
{
    char tmp[1024];
    unsigned int vecLen, pitch;
    unsigned int i;

    /* complex types are stored unvectorized in the result tile */
    vecLen = (isComplexType(kextra->dtype)) ? 1 : kextra->vecLen;
    pitch = (unsigned int)roundUp(dim->x, vecLen);
    sprintf(tmp, "if (coordA + %lu > M)", dim->y);
    kgenBeginBranch(ctx, tmp);
    /* i = number of rows of the tile lying beyond the bound */
    sprintf(tmp, "int i = (coordA >= M) ? %lu : (%lu - M %% %lu);\n\n",
            dim->y, dim->y, dim->y);
    kgenAddStmt(ctx, tmp);
    sprintf(tmp, "for (; i > 0; i--)");
    kgenBeginBranch(ctx, tmp);

    /* one assignment per vectorized register of a row */
    for (i = 0; i < pitch / vecLen; i++) {
        sprintf(tmp, "c[(%lu - i) * %u + %u] = 0;\n",
                dim->y, pitch / vecLen, i);
        kgenAddStmt(ctx, tmp);
    }

    kgenEndBranch(ctx, NULL);
    kgenEndBranch(ctx, NULL);
}
400
/*
 * Vendor-dependent selection of result-update flags.
 *
 * UPRES_INDEXING_WITH_CONSTANTS is an optimization that is known to
 * miscompile (hang) on some configurations, so it is enabled only when
 * it is known to be safe: on AMD whenever N has no tails, and on other
 * vendors only for the exact flag combinations listed below
 * (bugFlag1..bugFlag12 are full kflags values, compared with '!='
 * rather than tested as bit masks).
 */
static void
setupVdepUpresFlags(KernelExtraFlags kflags, UpdateResultFlags* upFlags)
{
    bool forceBug = false;

    unsigned int bugFlag1 = KEXTRA_NO_COPY_VEC_A
                          | KEXTRA_TAILS_K
                          | KEXTRA_TAILS_M;
    unsigned int bugFlag2 = bugFlag1
                          | KEXTRA_UPPER_TRIANG
                          | KEXTRA_TRANS_A;
    unsigned int bugFlag3 = bugFlag1
                          | KEXTRA_SIDE_RIGHT
                          | KEXTRA_COLUMN_MAJOR;
    unsigned int bugFlag4 = bugFlag3
                          | KEXTRA_TRANS_A;
    unsigned int bugFlag5 = bugFlag3
                          | KEXTRA_UPPER_TRIANG;
    unsigned int bugFlag6 = KEXTRA_NO_COPY_VEC_A
                          | KEXTRA_NO_COPY_VEC_B
                          | KEXTRA_NO_COPY_VEC_C
                          | KEXTRA_TAILS_K
                          | KEXTRA_TAILS_M;
    unsigned int bugFlag7 = bugFlag6
                          | KEXTRA_COLUMN_MAJOR;
    unsigned int bugFlag8 = bugFlag6
                          | KEXTRA_SIDE_RIGHT
                          | KEXTRA_UPPER_TRIANG;
    unsigned int bugFlag9 = bugFlag6
                          | KEXTRA_UPPER_TRIANG
                          | KEXTRA_TRANS_A
                          | KEXTRA_TAILS_N;
    unsigned int bugFlag10 = bugFlag7
                          | KEXTRA_SIDE_RIGHT
                          | KEXTRA_TRANS_A
                          | KEXTRA_TAILS_N;
    unsigned int bugFlag11 = bugFlag9
                          | KEXTRA_UNIT_DIAGONAL;
    unsigned int bugFlag12 = bugFlag6
                          | KEXTRA_TAILS_N
                          | KEXTRA_SIDE_RIGHT
                          | KEXTRA_UNIT_DIAGONAL
                          | KEXTRA_COLUMN_MAJOR
                          | KEXTRA_TRANS_A;

    /*
     * WORKAROUND for AMD GPU: Now, we avoid optimizing the case when
     *                         matrix B is not divided on block size and
     *                         since it leads to a hang up at code seeming
     *                         correct.
     */
    if (kflags & KEXTRA_VENDOR_AMD) {
        forceBug = (kflags & KEXTRA_TAILS_N) != 0;
    }
    else {
        /* apply the optimization only for the whitelisted combinations */
        forceBug = (kflags != bugFlag1
            && kflags != bugFlag2 && kflags != bugFlag4 && kflags != bugFlag5
            && kflags != bugFlag7 && kflags != bugFlag8 && kflags != bugFlag9
            && kflags != bugFlag10 && kflags != bugFlag11
            && kflags != bugFlag12);
    }

    if (!forceBug) {
        *upFlags |= UPRES_INDEXING_WITH_CONSTANTS;
    }
}
467
468 static void
genSetupCoordinates(struct KgenContext * ctx,const SubproblemDim * dims,KernelExtraFlags kflags)469 genSetupCoordinates(
470 struct KgenContext *ctx,
471 const SubproblemDim *dims,
472 KernelExtraFlags kflags)
473 {
474 char tmp[1024];
475 unsigned int l1Pans = (unsigned int)(dims[0].x / dims[1].x);
476
477 sprintf(tmp, "coordA = currM + lid / %u * %lu;\n", l1Pans, dims[1].y);
478 kgenAddStmt(ctx, tmp);
479 if (isMatrixUpper(kflags)) {
480 sprintf(tmp, "k = currM + %lu;\n", dims[0].y);
481 }
482 else {
483 strcpy(tmp, "k = 0;\n");
484 }
485 kgenAddStmt(ctx, tmp);
486 }
487
/*
 * Emit code inverting the diagonal block. The block is expected to be
 * already loaded into 'tempC'; the inverse is built in 'tempA' (zeroed
 * first) by the generated 'invert' function produced by
 * genInvertingBlockFunc(). For the unit-diagonal case the diagonal
 * elements are explicitly set to one beforehand.
 */
static void
genInvertDiagBlock(
    struct KgenContext *ctx,
    const BlasGenSettings *gset,
    const ZeroFuncs *zeroFuncs)
{
    char tmp[1024];
    const CLBLASKernExtra *kextra = gset->kextra;
    const SubproblemDim *subdims = gset->subdims;
    size_t pitchA;

    pitchA = matrBlockPitch(subdims, MATRIX_A, kextra->dtype, clblasLeft);

    sprintf(tmp, "%s((__local float4*)tempA);\n", zeroFuncs->names[MATRIX_A]);
    kgenAddStmt(ctx, tmp);
    kgenAddBarrier(ctx, CLK_LOCAL_MEM_FENCE);

    if (kextra->flags & KEXTRA_UNIT_DIAGONAL) {
        /* the unit diagonal is not stored in A; materialize it */
        sprintf(tmp, "if (lid < %lu) {\n"
                     " tempC[lid * %lu + lid] = %s;\n"
                     "}\n",
                subdims[0].bwidth, pitchA, strOne(kextra->dtype));
        kgenAddStmt(ctx, tmp);
        kgenAddBarrier(ctx, CLK_LOCAL_MEM_FENCE);
        kgenAddBlankLine(ctx);
    }

    /* one work item per row of the block; clamp the size at the tail */
    sprintf(tmp, "if (lid < %lu)", subdims[0].y);
    kgenBeginBranch(ctx, tmp);
    sprintf(tmp, "invert(tempC, tempA, lid, (currM + %lu > M) ? "
                 "M - currM : %lu);\n",
            subdims[0].y, subdims[0].y);
    kgenAddStmt(ctx, tmp);
    kgenEndBranch(ctx, NULL);
    kgenAddBarrier(ctx, CLK_LOCAL_MEM_FENCE);
    kgenAddBlankLine(ctx);
}
525
526 static void
genMulOnDiagBlock(struct KgenContext * ctx,BlasGenSettings * gset,const TileMulOpts * mulOpts)527 genMulOnDiagBlock(
528 struct KgenContext *ctx,
529 BlasGenSettings *gset,
530 const TileMulOpts *mulOpts)
531 {
532 char tmp[1024];
533 const SubproblemDim *dims = gset->subdims;
534 const CLBLASKernExtra *kextra = gset->kextra;
535 unsigned int l1Pans = (unsigned int)(dims[0].x / dims[1].x);
536 TileMulOpts optsNew;
537 size_t pitchAC;
538 const char *ptrName;
539 Tile *tile;
540 BlasGenSettings gsetNew;
541
542 pitchAC = matrBlockPitch(dims, MATRIX_C, kextra->dtype, clblasRight);
543 ptrName = dtypeUPtrField(kextra->dtype);
544
545 memcpy(&optsNew, mulOpts, sizeof(optsNew));
546 optsNew.memA = CLMEM_LOCAL_MEMORY;
547 optsNew.memB = CLMEM_LOCAL_MEMORY;
548 optsNew.flags &= ~(TILEMUL_TRA | TILEMUL_GLOBAL_CYCLIC | TILEMUL_CONJA);
549 optsNew.flags |= TILEMUL_TRB;
550 optsNew.memA = CLMEM_LOCAL_MEMORY;
551 optsNew.memB = CLMEM_LOCAL_MEMORY;
552 gset->varNames.A = "utmpA";
553 gset->varNames.B = "utmpC";
554
555 sprintf(tmp, "utmpA.%s = tempA + lid / %u * %lu;\n"
556 "utmpC.%s = tempC + lid %% %u * %lu;\n\n",
557 ptrName, l1Pans, pitchAC * dims[1].y,
558 ptrName, l1Pans, pitchAC * dims[1].x);
559 kgenAddStmt(ctx, tmp);
560
561 memcpy(&gsetNew, gset, sizeof(gsetNew));
562 gsetNew.subdims[1].bwidth = dims[1].y;
563
564 // Configure the tile descriptors to deal with tile of needed sizes.
565 tile = &gsetNew.tileA;
566 tile->nrRows = (unsigned int)dims[1].y;
567 tile->nrCols = (unsigned int)dims[1].y;
568 tile->trans = false;
569 tile = &gsetNew.tileBX;
570 tile->nrRows = (unsigned int)dims[1].y;
571 tile->nrCols = (unsigned int)dims[1].x;
572 tile->trans = true;
573 tileMulGen(ctx, &gsetNew, &optsNew);
574
575 gset->varNames.A = "uA";
576 gset->varNames.B = "uB";
577 }
578
/*
 * Generate one complete TRSM pass: the block-update loop accumulating
 * A*X products, reading and inverting the diagonal block, updating the
 * intermediate result, multiplying on the inverted block, and writing
 * the solved tile back. When 'isTailPass' is set the pass handles the
 * tail block row of M instead of the whole-block loop.
 */
static void
genOneTrsmPass(
    struct KgenContext *ctx,
    BlasGenSettings *gset,
    const char *updateResFnRev,
    const char *updateResGenericFnRev,
    CopyBufFuncs *copyFuncs,
    ZeroFuncs *zeroFuncs,
    bool isTailPass)
{
    const CLBLASKernExtra *kextra = gset->kextra;
    CLBLASKernExtra kextraTmp;
    KernelExtraFlags kflags = kextra->flags;
    char tmp[1024];
    DataType dtype = kextra->dtype;
    unsigned int vecLen = gset->kextra->vecLen;
    SubproblemDim *subdims = gset->subdims;
    int tra, trb;
    UpdateResultFlags upFlags;
    TilePostFetchPrivate pfpriv;
    TileMulOpts mulOpts;
    TailFetch tf;
    TailStatus tailStatus = 0;

    memset(&pfpriv, 0, sizeof(pfpriv));

    // multiply options
    mulOpts.memA = CLMEM_GLOBAL_MEMORY;
    mulOpts.memB = CLMEM_GLOBAL_MEMORY;
    mulOpts.core = TILEMUL_MAD;//TILEMUL_MULADD;
    mulOpts.postFetch = NULL;
    mulOpts.flags = kextraToTilemulFlags(CLBLAS_TRSM, kflags);
    mulOpts.flags |= TILEMUL_EXTERN_RDECL;
    mulOpts.flags |= getCyclicFlags(subdims, kflags, isTailPass, vecLen);

    tra = isMatrixAccessColMaj(CLBLAS_TRSM, kflags, MATRIX_A);
    trb = isMatrixAccessColMaj(CLBLAS_TRSM, kflags, MATRIX_B);

    /* which fetch tails must be handled inside the multiply */
    tf = checkForTailFetches(CLBLAS_TRSM, &subdims[1], kextra, MATRIX_B,
                             false, false);
    if (trb) {
        tf &= ~FETCH_TAIL_COL;
    }

    /*
     * For lower triangular matrix we proceed upto the diagonal, so we
     * can't exceed matrix bound and zeroing is not needed
     */
    if (isMatrixUpper(kflags)) {
        tf |= checkForTailFetches(CLBLAS_TRSM, &subdims[1], kextra,
                                  MATRIX_A, false, false);
        if (tra && trb) {
            tf &= ~FETCH_TAIL_COL;
        }
    }

    if (tf != FETCH_NO_TAILS) {
        memset(&pfpriv, 0, sizeof(pfpriv));
        pfpriv.funcID = CLBLAS_TRSM;
        pfpriv.gset = gset;
    }

    // loop over M
    if (!isTailPass) {
        sprintf(tmp, "for (m0 = 0; m0 < M / %lu * %lu; m0 += %lu)",
                subdims->y, subdims->y, subdims->y);
        kgenBeginBranch(ctx, tmp);
    }

    genSetupCoordinates(ctx, subdims, kflags);
    genZeroResult(ctx, dtype, subdims, vecLen);

    if (!isMatrixUpper(kflags) && isTailPass) {
        // skip update loop is the matrix consist of the single block
        sprintf(tmp, "if (M > %lu)", subdims->y);
        kgenBeginBranch(ctx, tmp);
    }

    // Avoid tail adjusting along M.

    memcpy(&kextraTmp, kextra, sizeof(kextraTmp));
    kextraTmp.flags &= ~(KEXTRA_TAILS_M | KEXTRA_TAILS_M_LOWER);

    // update loop is not needed for tail of an upper triangular matrix
    if (!(isTailPass && isMatrixUpper(kflags))) {
        if (isTailPass || (kflags & KEXTRA_TAILS_N)) {
            kgenBeginBranch(ctx, "if (coordB < N)");
        }

        /* temporarily mask M tails so the coordinate adjustment only
           deals with the N direction */
        gset->kextra = &kextraTmp;
        tailStatus = checkGenAdjustTailCoords(ctx, CLBLAS_TRSM, gset, NULL);
        gset->kextra = kextra;

        genInternalLoopCtl(ctx, subdims, kflags); // loop over K

        // multiplication for the step-by-step block updating
        subdims[0].bwidth = subdims[1].bwidth;
        tileMulGen(ctx, gset, &mulOpts);
        subdims[0].bwidth = subdims[0].y;

        genInternalLoopEnd(ctx); // loop over K
        kgenAddBlankLine(ctx);

        // invoke once again, in order to process tails along K
        if (isMatrixUpper(kflags) && (tf != FETCH_NO_TAILS)) {
            subdims[0].bwidth = subdims[1].bwidth;

            if (!(tra && trb)) {
                mulOpts.flags |= TILEMUL_WRAP_AROUND_TAIL;
            }
            mulOpts.flags |= TILEMUL_GLOBAL_CYCLIC_K;

            mulOpts.postFetchPriv = &pfpriv;
            mulOpts.postFetch = defaultTilePostFetch;

            subdims[0].bwidth = subdims[1].bwidth;
            tileMulGen(ctx, gset, &mulOpts);
            subdims[0].bwidth = subdims[0].y;

            mulOpts.postFetch = NULL;
            mulOpts.postFetchPriv = NULL;
        }

        gset->kextra = &kextraTmp;
        checkGenRestoreTailCoords(ctx, gset, tailStatus);
        gset->kextra = kextra;

        if (isTailPass || (kflags & KEXTRA_TAILS_N)) {
            kgenEndBranch(ctx, NULL);
        }
    }
    else if (!trb && (kflags & KEXTRA_TAILS_N)) {
        tailStatus |= TAIL_B_RAISED;
    }

    mulOpts.flags &= ~(TILEMUL_WRAP_AROUND_TAIL | TILEMUL_GLOBAL_CYCLIC_A |
                       TILEMUL_GLOBAL_CYCLIC_K);

    if (!isMatrixUpper(kflags) && isTailPass) {
        /*
         * end of branch for non single block tail processing of
         * the lower triangular matrix
         */
        kgenEndBranch(ctx, NULL);
    }

    /*
     * Final phase: update the accumulated result, multiply on an inverted
     * block and write back the result
     */
    if (isMatrixUpper(kflags) || ((kflags & KEXTRA_VENDOR_AMD) != 0)) {
        kgenAddStmt(ctx, "k0 = currM;\n");
    }
    else {
        kgenAddStmt(ctx, "k0 = m0;\n");
    }

    /* fetch the diagonal block into tempC and build its inverse in tempA */
    genReadDiagBlock(ctx, subdims, dtype, copyFuncs, zeroFuncs,
                     kflags, 'C');
    genInvertDiagBlock(ctx, gset, zeroFuncs);

    // Avoid generating not executed non optimal path
    gset->kextra = &kextraTmp;
    if (isTailPass) {
        kextraTmp.flags |= (KEXTRA_TAILS_M | KEXTRA_TAILS_M_LOWER);
    }
    genUpdateIntermTrsmResult(ctx, gset, updateResFnRev,
                              updateResGenericFnRev, true);
    gset->kextra = kextra;

    /*
     * Heap to LDS.
     * Zero unuseful part along columns since it will have an influence
     * on the result at multiplication on an inverted block
     */
    if (isTailPass) {
        genZeroResultTrash(ctx, &subdims[1], kextra);
    }
    genLdsCopy(ctx, gset);
    kgenAddBarrier(ctx, CLK_LOCAL_MEM_FENCE);
    genZeroResult(ctx, dtype, subdims, vecLen);

    genMulOnDiagBlock(ctx, gset, &mulOpts);

    // write back the tile evaluated
    upFlags = kextraToUpresFlags(CLBLAS_TRSM, kflags);
    upFlags |= tailStatusToUpresFlags(tailStatus);
    upFlags |= UPRES_EXCEED_PROBLEM_CONDITION;
    setupVdepUpresFlags(kflags, &upFlags);

    gset->kextra = &kextraTmp;

    genResultUpdateWithFlags(ctx, CLBLAS_TRSM, gset, upFlags,
                             NULL, NULL, NULL);
    gset->kextra = kextra;

    kgenAddBarrier(ctx, CLK_GLOBAL_MEM_FENCE);

    /* advance to the next block row for the following iteration */
    if (isMatrixUpper(kflags)) {
        sprintf(tmp, "currM -= %lu;\n", subdims[0].y);
    }
    else {
        sprintf(tmp, "currM += %lu;\n", subdims[0].y);
    }
    kgenAddStmt(ctx, tmp);

    if (!isTailPass) {
        kgenEndBranch(ctx, NULL); // loop over M
    }
}
789
/*
 * Top-level kernel generator (SolverOps::genKernel). Emits helper
 * functions (pointer unions, complex math, result update, buffer copy,
 * zeroing, block inversion), the kernel declaration, and then one or
 * two TRSM passes: the main whole-block pass and, when M has a tail,
 * a dedicated tail pass. Returns the generated source size (plus the
 * terminating zero) or a negative error code.
 */
static ssize_t
generator(
    char *buf,
    size_t buflen,
    const struct SubproblemDim *subdims,
    const struct PGranularity *pgran,
    void *extra)
{
    char tmp[1024];
    struct KgenContext *ctx;
    CLBLASKernExtra *kextra = (CLBLASKernExtra*)extra;
    KernelExtraFlags kflags = kextra->flags;
    DataType dtype = kextra->dtype;
    BlasGenSettings gset;
    char updateResFnRev[FUNC_NAME_MAXLEN];
    char updateResGenericFnRev[FUNC_NAME_MAXLEN];
    CopyBufFuncs copyFuncs;
    ZeroFuncs zeroFuncs;
    UpdateResultFlags upFlags;
    const char *ptrName;
    bool b;
    ssize_t ret;
    unsigned int l1Pans = (unsigned int)(subdims[0].x / subdims[1].x);
    bool tailMarker[2] = {false, true};
    int triang;
    int i;

    /* this pattern supports only 1D work groups */
    if (pgran->wgDim != 1) {
        return -EINVAL;
    }

    /* propagate the upper-level tail flags to the lower level */
    if (kflags & KEXTRA_TAILS_M) {
        kflags |= KEXTRA_TAILS_M_LOWER;
    }
    if (kflags & KEXTRA_TAILS_N) {
        kflags |= KEXTRA_TAILS_N_LOWER;
    }
    if (kflags & KEXTRA_TAILS_K) {
        kflags |= KEXTRA_TAILS_K_LOWER;
    }
    kextra->flags = kflags;

    ctx = createKgenContext(buf, buflen, true);
    if (ctx == NULL) {
        return -ENOMEM;
    }

    triang = isMatrixUpper(kflags);

    memset(&gset, 0, sizeof(gset));
    memcpy(gset.subdims, subdims, sizeof(gset.subdims));
    gset.kextra = kextra;
    gset.pgran = pgran;

    initKernelVarNames(&gset.varNames);

    b = isDoubleBasedType(dtype);
    kgenDeclareUptrs(ctx, b);
    if (isComplexType(dtype)) {
        genComplexMathOperators(ctx, dtype);
    }

    /*
     * For intermediate result after blocks modification.
     * Take into account tails adjusting
     */
    upFlags = kextraToUpresFlags(CLBLAS_TRSM, kflags);
    upFlags |= UPRES_WITH_BETA | UPRES_PRIV_DEST;

    if (!isMatrixAccessColMaj(CLBLAS_TRSM, kflags, MATRIX_B) &&
        (kflags & KEXTRA_TAILS_N)) {

        upFlags |= UPRES_TAIL_COL;
    }

    setupVdepUpresFlags(kflags, &upFlags);
    initTiles(&gset);
    genUpresFuncsWithFlags(ctx, &gset, upFlags, updateResFnRev,
                           updateResGenericFnRev);

    generateBufCopyFuncs(&copyFuncs, ctx, CLBLAS_TRSM, &gset, BCHF_MATRIX_A);
    generateZeroingFuncs(&zeroFuncs, ctx, &subdims[0], pgran, dtype,
                         ZF_MATRIX_A);

    //matrix inversion function
    genInvertingBlockFunc(ctx, subdims[0].bwidth, dtype, kflags);
    kgenAddBlankLine(ctx);

    // now, generate the kernel
    declareTrxmKernel(ctx, dtype, pgran, kflags, CLBLAS_TRSM, "Cached", false,
                      true);
    ret = kgenBeginFuncBody(ctx);

    declareLocalVariables(ctx, &gset);
    prepareTilesForMainLoop(&gset);

    sprintf(tmp, "currN = gid * %lu;\n", subdims[0].x);
    kgenAddStmt(ctx, tmp);
    genInitCurrM(ctx, subdims, kflags);

    if (kflags & KEXTRA_A_OFF_NOT_ZERO) {
        kgenAddStmt(ctx, "A += offA;\n");
    }
    genTrxmBMatrShift(ctx, kflags, false);

    ptrName = dtypeUPtrField(dtype);
    sprintf(tmp, "uA.%s = A;\n"
                 "uB.%s = B;\n\n",
            ptrName, ptrName);
    kgenAddStmt(ctx, tmp);

    /*
     * B matrix is divided on panels, each work group
     * multiply such a panel on the whole matrix A.
     */

    sprintf(tmp, "coordB = gid * %lu + lid %% %u * %lu;\n",
            subdims[0].x, l1Pans, subdims[1].x);
    kgenAddStmt(ctx, tmp);

    /*
     * Emit the main and (when M has tails) the tail pass; 'triang'
     * selects the order, since an upper triangular solve starts from
     * the bottom block row.
     */
    for (i = 0; i < 2; i++) {
        b = (i) ? tailMarker[1 - triang] : tailMarker[triang];
        if (!b || (kflags & KEXTRA_TAILS_M)) {
            genOneTrsmPass(ctx, &gset, updateResFnRev, updateResGenericFnRev,
                           &copyFuncs, &zeroFuncs, b);
        }
    }

    kgenEndFuncBody(ctx);
    ret = kgenAddBlankLine(ctx);

    if (!ret) {
        ret = (ssize_t)kgenSourceSize(ctx) + 1;
    }

    destroyKgenContext(ctx);

    return (ret < 0) ? -EOVERFLOW : ret;
}
929
930 static bool
isFitToLDS(SubproblemDim * dim,DataType dtype,cl_ulong ldsSize,const void * kernelArgs)931 isFitToLDS(
932 SubproblemDim *dim,
933 DataType dtype,
934 cl_ulong ldsSize,
935 const void *kernelArgs)
936 {
937 cl_ulong sizeA, sizeC;
938 const CLBlasKargs *kargs = (const CLBlasKargs*)kernelArgs;
939
940 /*
941 * It's needed one block for matrix A,
942 * and one block of size maximal of this one for
943 * matrix A and matrix C
944 */
945
946 sizeA = matrBlockSize(dim, MATRIX_A, dtype, kargs->side);
947 sizeC = matrBlockSize(dim, MATRIX_B, dtype, kargs->side);
948 if (sizeA > sizeC) {
949 sizeC = sizeA;
950 }
951
952 return ((sizeA + sizeC) * dtypeSize(dtype) <= ldsSize);
953 }
954
955 static SolverFlags
solverFlags(void)956 solverFlags(void)
957 {
958 return (SF_WSPACE_1D | SF_TOP_INPUT_SQUARE_BLOCKS);
959 }
960
961 static void
assignKargs(KernelArg * args,const void * params,const void * extra)962 assignKargs(KernelArg *args, const void *params, const void *extra)
963 {
964 const CLBlasKargs *blasArgs = (CLBlasKargs*)params;
965 KernelExtraFlags kflags = ((const CLBLASKernExtra*)extra)->flags;
966 int idx = 7;
967
968 initSizeKarg(&args[0], blasArgs->M);
969 initSizeKarg(&args[1], blasArgs->N);
970 assignScalarKarg(&args[2], &(blasArgs->alpha), blasArgs->dtype);
971 initMemobjKarg(&args[3], blasArgs->A, NULL, 0, 0);
972 initSizeKarg(&args[4], blasArgs->lda.matrix);
973 initMemobjKarg(&args[5], blasArgs->B, NULL, 0, 0);
974 initSizeKarg(&args[6], blasArgs->ldb.matrix);
975 if (kflags & KEXTRA_A_OFF_NOT_ZERO) {
976 initSizeKarg(&args[idx++], blasArgs->offA);
977 }
978 if (kflags & KEXTRA_BX_OFF_NOT_ZERO) {
979 initSizeKarg(&args[idx++], blasArgs->offBX);
980 }
981 }
982
983 static void
fixupArgs(void * args,SubproblemDim * subdims,void * extra)984 fixupArgs(void *args, SubproblemDim *subdims, void *extra)
985 {
986 (void)extra;
987 (void)subdims;
988
989 fixupTrxmKargs((CLBlasKargs*)args);
990 }
991
992 void
initTrsmCachedPattern(MemoryPattern * mempat)993 initTrsmCachedPattern(MemoryPattern *mempat)
994 {
995 mempat->name = "Cached global memory based block trsm";
996 mempat->nrLevels = 2;
997 mempat->cuLevel = 0;
998 mempat->thLevel = 0;
999 mempat->sops = &trsmSops;
1000
1001 mpatExtra.aMset = CLMEM_LEVEL_L1;
1002 mpatExtra.mobjA = CLMEM_BUFFER;
1003 mpatExtra.mobjB = CLMEM_BUFFER;
1004 mempat->extra = &mpatExtra;
1005 }
1006