1 /* ************************************************************************
2 * Copyright 2013 Advanced Micro Devices, Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * ************************************************************************/
16
17
18 /*
19 * TRSM generator with support for cached reads from global memory
20 */
21
22 #include <string.h>
23 #include <stdio.h>
24 #include <assert.h>
25 #include <stdlib.h>
26 #include <clblas_stddef.h>
27 #include <clBLAS.h>
28 #include <blas_mempat.h>
29 #include <clkern.h>
30 #include <clblas-internal.h>
31 #include <matrix_props.h>
32 #include <matrix_dims.h>
33
34 #include "dblock_kgen.h"
35 #include "kerngen.h"
36 #include "blas_kgen.h"
37 #include "gen_helper.h"
38 #include "trxm_common.h"
39 #include "trsm_kgen.h"
40 #include "legacy/blas_kgen_legacy.h"
41
42 typedef enum LdsUseFlags {
43 LDS_NO_USE = 0,
44 LDS_USE_LARGE = 0x1,
45 LDS_USE_DIAGONAL = 0x2
46 } LdsUseFlags;
47
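/*
 * Private parameters of this solver (filled in by fixupArgs() below):
 * 'unrollingFactor' is the number of subdims[1].bwidth-wide steps preloaded
 * into the local memory per stage 1 iteration, 'unrolledTail' is the number
 * of such steps remaining in the tail along the update dimension, and
 * 'ldsUse' tells whether the local memory is used for the large stage 1
 * blocks and/or the diagonal tiles.
 */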
48 typedef struct TrsmExtraParams {
49 int unrollingFactor;
50 unsigned int unrolledTail;
51 LdsUseFlags ldsUse;
52 } TrsmExtraParams;
53
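/*
 * Roughly, the generated kernel works in two stages repeated for every panel
 * of the result. At the BLOCK_UPDATE stage the work group accumulates the
 * product of the already solved part of B with the respective off-diagonal
 * panel of A. At the TILE_UPDATE stage the accumulated result is combined
 * with the original B, multiplied by an inverted diagonal tile of A and
 * written back; the remaining work items then use the freshly solved tile to
 * continue updating their own intermediate results. See also the stage 2
 * comment emitted in generator() below.
 */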
54 enum TrsmStage {
55 BLOCK_UPDATE,
56 TILE_UPDATE
57 };
58
59 static CLBLASMpatExtra mpatExtra;
60
61 static ssize_t
62 generator(
63 char *buf,
64 size_t buflen,
65 const struct SubproblemDim *subdims,
66 const struct PGranularity *pgran,
67 void *extra);
68
69 static bool
70 isFitToLDS(
71 SubproblemDim *dim,
72 DataType dtype,
73 cl_ulong ldsSize,
74 const void *kernelArgs);
75
76 static SolverFlags
77 solverFlags(void);
78
79 static void
80 assignKargs(KernelArg *args, const void *params, const void *extra);
81
82 static void
83 fixupArgs(void *args, SubproblemDim *subdims, void *extra);
84
85 static bool
86 checkCalcDecompDedicated(
87 PGranularity *pgran,
88 SubproblemDim *subdims,
89 unsigned int subdimsNum,
90 DataType dtype,
91 int check);
92
93 #if 0
94 static int
95 getDefaultDecomp(
96 PGranularity *pgran,
97 SubproblemDim *subdims,
98 unsigned int subdimsNum,
99 void * pArgs);
100 #endif
101
102 static SolverOps trsmSops = {
103 generator,
104 assignKargs,
105 isFitToLDS,
106 NULL,
107 NULL,
108 NULL,
109 NULL,
110 solverFlags,
111 fixupArgs,
112 NULL,//getDefaultDecomp
113 checkCalcDecompDedicated,
114 NULL,
115 NULL
116 };
117
118 // The struct for storing the tiles used by the generator
119 typedef struct TileSet
120 {
121 Tile rectA; // The rectangular tile A for the update loop at stage 1
122 Tile squareA; // The square tile for the stage 2
123 Tile origB; // The rectangular tile B for the update loop at the stage 1
124 Tile bStage2; // The rectangular tile B for the update loop at the stage 2
125 Tile bAsSqA; // Descriptor for holding square tile A in the storage of B
126 Tile bAsC; // Descriptor for holding tile C in the storage of B
127 // the entire tile A matching the storage declared in the kernel
128 Tile A;
129 // the entire tile B matching the storage declared in the kernel
130 Tile B;
131 } TileSet;
132
133
134 static bool
135 useSkewedFetchB(const BlasGenSettings *gset)
136 {
137 KernelExtraFlags kflags = gset->kextra->flags;
138 TrsmExtraParams *extraParams = (TrsmExtraParams*)gset->kextra->solverPriv;
139 bool ret = false;
140
141 if (extraParams->ldsUse & LDS_USE_LARGE) {
142 ret = !isMatrixAccessColMaj(CLBLAS_TRSM, kflags, MATRIX_B);
143 }
144
145 return ret;
146 }
147
148 static void
149 restoreTile(Tile* dst, const Tile* src)
150 {
151 dst->baseName = src->baseName;
152 dst->vecLen = src->vecLen;
153 dst->storType = src->storType;
154 }
155
156 static Tile
157 substituteTile(Tile* dst, const Tile* src)
158 {
159 Tile tmp;
160
161 restoreTile(&tmp, dst);
162 restoreTile(dst, src);
163
164 return tmp;
165 }
166
167 static void
168 sprintfInvertedElement(
169 Kstring *elem,
170 const Tile *tile,
171 unsigned int row,
172 unsigned int col,
173 unsigned int len,
174 bool isU)
175 {
176 if (isU) {
177 row = tile->nrRows - row - 1;
178 col = tile->nrCols - col - len;
179 }
180
181 sprintfTileElement(elem, tile, row, col, len);
182 }
183
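/*
 * Generate code inverting a triangular diagonal tile by substitution
 * (forward for the lower triangular case, backward for the upper one):
 * the destination tile is initialized to the identity and then updated row
 * by row using the rows already computed, each row being finally divided by
 * its diagonal element. As a rough illustration (the real emitted statements
 * depend on the tile storage and vectorization), for a 2x2 lower triangular
 * tile 'b' the emitted scalar code is equivalent to:
 *
 *     a00 = 1 / b00;            a01 = 0 / b00;
 *     a10 = (0 - a00 * b10);    a11 = (1 - a01 * b10);
 *     a10 = a10 / b11;          a11 = a11 / b11;
 *
 * which yields the inverse of the lower triangular 2x2 matrix.
 */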
184 static void
185 genTileInverting(
186 struct KgenContext *ctx,
187 const BlasGenSettings *gset,
188 const TileSet *tileSet)
189 {
190 char tmp[1024];
191 const CLBLASKernExtra *kextra = gset->kextra;
192 KernelExtraFlags kflags = kextra->flags;
193 DataType dtype = kextra->dtype;
194 const SubproblemDim *dim = &gset->subdims[1];
195 unsigned int accLen;
196 unsigned int i, j, k;
197 Tile srcTile;
198 Tile dstTile;
199 bool isU, isComplex;
200 bool isInlined = gset->flags & BGF_EXPLICIT_INLINE;
201 const char* typeNameA;
202 const char* typeNameB;
203
204 memcpy(&srcTile, &tileSet->bAsSqA, sizeof(srcTile));
205 memcpy(&dstTile, &tileSet->squareA, sizeof(dstTile));
206
207 getVectorTypeName(kextra->dtype, dstTile.vecLen, &typeNameA, NULL);
208 getVectorTypeName(kextra->dtype, srcTile.vecLen, &typeNameB, NULL);
209 isU = isMatrixUpper(kflags);
210 isComplex = isComplexType(dtype);
211
212 if (isComplex || dstTile.trans) {
213 accLen = 1;
214 }
215 else {
216 accLen = umin(srcTile.vecLen, dstTile.vecLen);
217 accLen = umin(accLen, srcTile.nrCols);
218 }
219
220 if (!isInlined) {
221 dstTile.baseName = "a";
222 srcTile.baseName = "b";
223 sprintf(tmp, "void\n"
224 "invertTile(%s *a, %s *b)\n",
225 typeNameA, typeNameB);
226 kgenDeclareFunction(ctx, tmp);
227 kgenBeginFuncBody(ctx);
228 }
229 else {
230 kgenAddStmt(ctx, "// Invert tile\n");
231 }
232
233 // make the destination block an identity (unit) block
234 genZeroTile(ctx, &dstTile);
235 for (i = 0; i < dim->y; i++) {
236 genSetUnitInTile(ctx, &dstTile, i, i);
237 }
238 kgenAddBlankLine(ctx);
239
240 for (i = 0; i < dim->y; i++) {
241 Kstring src, srcDiag, dst, dstLast;
242
243 // current source diagonal element
244 sprintfInvertedElement(&srcDiag, &srcTile, i, i, 1, isU);
245 for (j = i; j < dim->y; j++) {
246 // current source non-diagonal element
247 if (i) {
248 sprintfInvertedElement(&src, &srcTile, j, i - 1, 1, isU);
249 }
250
251 for (k = 0; k < dim->y; k += accLen) {
252 // current updated vectorized element
253 sprintfInvertedElement(&dst, &dstTile, j, k, accLen, isU);
254
255 // update
256 if (i) {
257 // last updated vectorized element
258 sprintfInvertedElement(&dstLast, &dstTile, i - 1, k,
259 accLen, isU);
260 if (isComplex) {
261 sprintf(tmp, "%s -= mul(%s, %s);\n",
262 dst.buf, dstLast.buf, src.buf);
263 }
264 else {
265 sprintf(tmp, "%s -= %s * %s;\n",
266 dst.buf, dstLast.buf, src.buf);
267 }
268 kgenAddStmt(ctx, tmp);
269 }
270
271 // divide by the diagonal element
272 if (j == i) {
273 if (isComplex) {
274 sprintf(tmp, "%s = div(%s, %s);\n",
275 dst.buf, dst.buf, srcDiag.buf);
276 }
277 else {
278 sprintf(tmp, "%s /= %s;\n", dst.buf, srcDiag.buf);
279 }
280 kgenAddStmt(ctx, tmp);
281 }
282 }
283 }
284 if (i != dim->y - 1) {
285 kgenAddBlankLine(ctx);
286 }
287 }
288
289 if (!isInlined) {
290 kgenEndFuncBody(ctx);
291 }
292 kgenAddBlankLine(ctx);
293
294 }
295
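/*
 * Emit the declarations used by the kernel body: the work item/group IDs, the
 * global pointers 'uA'/'uB', the coordinates 'coordA'/'coordB', the loop
 * counters 'm0'/'k0'/'m1' (plus 'currM' for the upper triangular case), the
 * private tile storages, and, when the local memory is used, the
 * '__local tmpB[]' area together with the 'lB'/'lBMain' pointers (plus
 * 'skewX' for skewed fetches of B). 'parTile' is filled with the shape of the
 * tile kept in 'tmpB'.
 */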
296 static void
297 declareLocalVariables(
298 struct KgenContext *ctx,
299 const BlasGenSettings *gset,
300 Tile* parTile,
301 TrsmExtraParams * extraParams)
302 {
303 char tmp[1024];
304 const SubproblemDim *dims = gset->subdims;
305 const char* parTileTypeName = NULL;
306 bool trb = isMatrixAccessColMaj(CLBLAS_TRSM, gset->kextra->flags,
307 MATRIX_B);
308 unsigned int locWidth;
309 unsigned int tsize;
310 unsigned int parTileSize;
311 unsigned int l1Pans;
312 unsigned int step;
313
314 kgenAddStmt(ctx,
315 "const int lid = get_local_id(0);\n"
316 "const int gid = get_group_id(0);\n"
317 "GPtr uA, uB;\n"
318 "uint coordA, coordB;\n"
319 "uint m0 = 0, k0, m1;\n");
320
321 if (isMatrixUpper(gset->kextra->flags)) {
322 sprintf(tmp, "uint currM = (M - 1) / %lu * %lu;\n",
323 dims[0].y, dims[0].y);
324 kgenAddStmt(ctx, tmp);
325 }
326
327 /*
328 * Declare private blocks.
329 * The region 'b' stores, at different times, tiles of both
330 * the input matrices and the result
331 */
332
333 declareTileStorages(ctx, gset);
334
335 *parTile = gset->tileBX;
336
337 if (extraParams->ldsUse) {
338 tsize = dtypeSize(gset->kextra->dtype);
339 l1Pans = (unsigned int)(dims[0].x / dims[1].x);
340
341 parTile->vecLen = (trb) ? (unsigned int)dims[1].x
342 : (unsigned int)dims[1].bwidth;
343 parTile->vecLen = umin(parTile->vecLen, sizeof(cl_float4) / tsize);
344 parTile->trans = trb;
345
346 /*
347 * Allocate enough space in the local area to fit several tiles
348 * at stage 1 (according to the unrolling factor) and one tile
349 * at stage 2
350 */
351
352 locWidth = (unsigned int)dims[1].bwidth * extraParams->unrollingFactor;
353 if (extraParams->ldsUse & LDS_USE_DIAGONAL) {
354 locWidth = umax(locWidth, (unsigned int)dims[1].y);
355 }
356 if (trb) {
357 parTile->nrRows = locWidth;
358 parTile->nrCols = (unsigned int)dims[0].x;
359 step = (unsigned int)dims[1].x / parTile->vecLen;
360 }
361 else {
362 parTile->nrRows = (unsigned int)dims[0].x;
363 parTile->nrCols = locWidth;
364 step = (unsigned int)dims[1].x * locWidth / parTile->vecLen;
365 }
366
367 parTileSize = tileVectorsNum(parTile);
368
369 getVectorTypeName(gset->kextra->dtype, parTile->vecLen,
370 &parTileTypeName, NULL);
371
372 sprintf(tmp, "__local %s tmpB[%i];\n"
373 "LPtr lB;\n"
374 "LPtr lBMain = {(__local float*)(tmpB + lid %% %u * %u)};\n",
375 parTileTypeName, parTileSize, l1Pans, step);
376 kgenAddStmt(ctx, tmp);
377
378 if (useSkewedFetchB(gset)) {
379 kgenPrintf(ctx, "const uint skewX = lid %% %u %% %lu;\n",
380 l1Pans, gset->subdims[1].x);
381 }
382 }
383
384 kgenAddBlankLine(ctx);
385 }
386
387 /*
388 * Generate cyclical tile shifting so as to convert the skewed
389 * storing to "one-to-one", i.e. the first element in the tile
390 * matches the first element of the respective tile in the
391 * output matrix.
392 */
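/*
 * For instance, assuming a tile row (c0 c1 c2 c3) and skewX == 1, one loop
 * iteration emits code rotating the row to (c3 c0 c1 c2); the rotation is
 * repeated skewX times, undoing the skew applied while fetching B.
 */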
393 static void
394 genTileCyclicalShift(struct KgenContext *ctx, BlasGenSettings *gset)
395 {
396 const char *tname;
397 Kstring k1, k2, *src, *dst, *ktmp;
398 unsigned int row, col;
399 unsigned int seglen;
400 Tile *tileC = &gset->tileCY;
401
402 seglen = tileLineSegmentLen(tileC);
403 getVectorTypeName(gset->kextra->dtype, seglen, &tname, NULL);
404
405 kgenAddStmt(ctx, "\n// remove the skew from the result\n");
406 kgenBeginBranch(ctx, "for (uint i = 0; i < skewX; i++)");
407 kgenPrintf(ctx, "%s tmp;\n\n", tname);
408
409 src = &k1;
410 dst = &k2;
411
412 // Skewing may be used only in case of transposed C
413 for (row = 0; row < tileC->nrRows; row += seglen) {
414 sprintfTileElement(dst, tileC, row, tileC->nrCols - 1, seglen);
415 kgenPrintf(ctx, "tmp = %s;\n", dst->buf);
416 for (col = tileC->nrCols - 1; col > 0; col--) {
417 sprintfTileElement(src, tileC, row, col - 1, seglen);
418 kgenPrintf(ctx, "%s = %s;\n", dst->buf, src->buf);
419 // swap the pointers
420 ktmp = src;
421 src = dst;
422 dst = ktmp;
423 }
424 kgenPrintf(ctx, "%s = tmp;\n", dst->buf);
425 }
426
427 kgenEndBranch(ctx, NULL);
428 kgenAddBlankLine(ctx);
429 }
430
431 /*
432 * Set up coordinates before beginning a TRSM stage.
433 * A caller must ensure the strict stage sequence:
434 * BLOCK_UPDATE -> TILE_UPDATE
435 */
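/*
 * For example, for the lower triangular case the BLOCK_UPDATE stage emits
 * roughly the following (with the L1 panel count and tile sizes substituted):
 *
 *     coordA = m0 + (lid / l1Pans * dy1);
 *     k0 = 0;
 *     coordB = gid * dx0 + (lid % l1Pans * dx1);
 *
 * where dy1 = subdims[1].y, dx0 = subdims[0].x and dx1 = subdims[1].x appear
 * as compile-time constants in the generated kernel.
 */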
436 static void
437 genSetupCoords(
438 struct KgenContext *ctx,
439 const BlasGenSettings *gset,
440 enum TrsmStage stage)
441 {
442 char tmp[1024];
443 KernelExtraFlags kflags = gset->kextra->flags;
444 const SubproblemDim *dims = gset->subdims;
445 unsigned int l1Pans = (unsigned int)(dims[0].x / dims[1].x);
446 const char *s;
447
448 s = isMatrixUpper(kflags) ? "currM" : "m0";
449 sprintf(tmp, "coordA = %s + (lid / %u * %lu);\n",
450 s, l1Pans, dims[1].y);
451 kgenAddStmt(ctx, tmp);
452
453 switch (stage) {
454 case BLOCK_UPDATE:
455 if (isMatrixUpper(kflags)) {
456 sprintf(tmp, "k0 = currM + %lu;\n", dims[0].y);
457 }
458 else {
459 sprintf(tmp, "k0 = 0;\n");
460 }
461 break;
462 case TILE_UPDATE:
463 if (isMatrixUpper(kflags)) {
464 sprintf(tmp, "k0 = currM + %lu - m1 * %lu;\n",
465 dims[0].y - dims[1].y, dims[1].y);
466 }
467 else {
468 sprintf(tmp, "k0 = m0 + m1 * %lu;\n", dims[1].y);
469 }
470 break;
471 }
472
473 kgenAddStmt(ctx, tmp);
474
475 sprintf(tmp, "coordB = gid * %lu + (lid %% %u * %lu);\n",
476 dims[0].x, l1Pans, dims[1].x);
477
478 kgenAddStmt(ctx, tmp);
479 kgenAddBlankLine(ctx);
480 }
481
482 // Generate control block of the loop over K
483 static void
484 genInternalLoopCtl(
485 struct KgenContext *ctx,
486 const SubproblemDim *dim,
487 KernelExtraFlags kflags,
488 size_t stepK,
489 size_t boundAlign)
490 {
491 char tmp[1024];
492
493 if (isMatrixUpper(kflags)) {
494 if (kflags & KEXTRA_TAILS_M) {
495 sprintf(tmp, "for (k0 = currM + %lu; k0 < M / %lu * %lu; "
496 "k0 += %lu)",
497 dim[0].y, boundAlign, boundAlign, stepK);
498 }
499 else {
500 sprintf(tmp, "for (k0 = currM + %lu; k0 < M; k0 += %lu)",
501 dim[0].y, stepK);
502 }
503 }
504 else {
505 sprintf(tmp, "for (k0 = 0; k0 < m0; k0 += %lu)",
506 stepK);
507 }
508
509 kgenBeginBranch(ctx, tmp);
510 }
511
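/*
 * Note that for TRSM the result is written back into B, so 'C' and 'ldc'
 * alias 'B' and 'ldb', and the update dimension runs along M, so 'sizeK'
 * aliases 'M'.
 */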
512 static void
513 initKernelVarNames(KernelVarNames *kvars)
514 {
515 kvars->A = "uA";
516 kvars->B = "uB";
517 kvars->C = "B";
518 kvars->coordA = "coordA";
519 kvars->coordB = "coordB";
520 kvars->k = "k0";
521 kvars->sizeM = "M";
522 kvars->sizeN = "N";
523 kvars->sizeK = "M";
524 kvars->lda = "lda";
525 kvars->ldb = "ldb";
526 kvars->ldc = "ldb";
527 kvars->alpha = "alpha";
528 kvars->beta = "beta";
529 }
530
531 static void
532 setFetchHandler(
533 TileMulOpts *mulOpts,
534 const BlasGenSettings *gset,
535 int handler(struct KgenContext *ctx, MatrixRole mrole, void *priv),
536 TilePostFetchPrivate *priv)
537 {
538 int i, nrPrivs;
539 const char *regName = NULL;
540
541 if (handler == defaultTilePostFetch) {
542 nrPrivs = 1;
543 }
544 else {
545 nrPrivs = 2;
546 regName = "b";
547 }
548
549 for (i = 0; i < nrPrivs; i++) {
550 priv[i].fetchNumA = 0;
551 priv[i].wholeA = 1;
552 priv[i].funcID = CLBLAS_TRSM;
553 priv[i].gset = gset;
554 priv[i].regName = regName;
555 mulOpts->postFetch = handler;
556 mulOpts->postFetchPriv = priv;
557 }
558 }
559
560 static void
561 genCheckShiftTailB(
562 struct KgenContext *ctx,
563 const BlasGenSettings *gset,
564 int adjustRestore,
565 TailStatus *tailStatus)
566 {
567 BlasGenSettings gsetNew;
568 CLBLASKernExtra kextraNew;
569
570 memcpy(&gsetNew, gset, sizeof(gsetNew));
571 memcpy(&kextraNew, gset->kextra, sizeof(kextraNew));
572 // avoid tail shift for the matrix A
573 kextraNew.flags &= ~(KEXTRA_TAILS_M | KEXTRA_TAILS_M_LOWER);
574 gsetNew.kextra = &kextraNew;
575
576 if (adjustRestore) {
577 checkGenRestoreTailCoords(ctx, &gsetNew, *tailStatus);
578 }
579 else {
580 *tailStatus = checkGenAdjustTailCoords(ctx, CLBLAS_TRSM, &gsetNew,
581 NULL);
582 }
583 }
584
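/*
 * Produces a bound check like "coordB < N" wrapped with the given prefix and
 * suffix; e.g. sprintfHitMatrixCond(buf, MATRIX_B, "if (", ")") yields
 * "if (coordB < N)".
 */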
585 static void
586 sprintfHitMatrixCond(
587 char *buf,
588 MatrixRole mrole,
589 const char *prefix,
590 const char *suffix)
591 {
592 const char *coordName;
593 char bound;
594
595 coordName = (mrole == MATRIX_A) ? "coordA" : "coordB";
596 bound = (mrole == MATRIX_A) ? 'M' : 'N';
597 if (suffix == NULL) {
598 suffix = "";
599 }
600 sprintf(buf, "%s%s < %c%s", prefix, coordName, bound, suffix);
601 }
602
603 /*
604 * The 'mulUpd' argument indicates which action is being generated:
605 * multiplication by an inverted tile or the subsequent update
606 */
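/*
 * E.g. for the lower triangular case without tails and xPans == 4 this yields
 * "if (lid / 4 == m1)" for the multiplication step (mulUpd == 0) and
 * "if (lid / 4 > m1)" for the subsequent update step (mulUpd == 1).
 */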
607 static void
608 sprintfStage2Condition(
609 char *buf,
610 const BlasGenSettings *gset,
611 int mulUpd)
612 {
613 KernelExtraFlags kflags = gset->kextra->flags;
614 char hitCond[1024];
615 char *p;
616 unsigned int xPans, yPans;
617
618
619 hitCond[0] = '\0';
620 xPans = (unsigned int)(gset->subdims[0].x / gset->subdims[1].x);
621 yPans = (unsigned int)(gset->subdims[0].y / gset->subdims[1].y);
622 if (kflags & KEXTRA_TAILS_M) {
623 sprintfHitMatrixCond(hitCond, MATRIX_A, " && ", NULL);
624 }
625 p = hitCond + strlen(hitCond);
626 if (kflags & KEXTRA_TAILS_N) {
627 sprintfHitMatrixCond(p, MATRIX_B, " && ", NULL);
628 }
629
630 if (!mulUpd) {
631 if (isMatrixUpper(kflags)) {
632 sprintf(buf, "if (lid / %u + m1 == %u%s)",
633 xPans, yPans - 1, hitCond);
634 }
635 else {
636 sprintf(buf, "if (lid / %u == m1%s)", xPans, hitCond);
637 }
638 }
639 else {
640 if (isMatrixUpper(kflags)) {
641 sprintf(buf, "if (lid / %u + m1 < %u%s)",
642 xPans, yPans - 1, hitCond);
643 }
644 else {
645 sprintf(buf, "if (lid / %u > m1%s)", xPans, hitCond);
646 }
647 }
648 }
649
650 static void
651 genZeroTileTrash(
652 struct KgenContext *ctx,
653 const BlasGenSettings *gset,
654 MatrixRole mrole,
655 Tile* tile)
656 {
657 char tmp[1024];
658 const SubproblemDim *dim = &gset->subdims[1];
659 const CLBLASKernExtra *kextra = gset->kextra;
660 unsigned int i, j;
661 unsigned int step;
662 Kstring elem;
663
664 if (mrole == MATRIX_A) {
665 kgenAddBlankLine(ctx);
666 }
667 else {
668 kgenBeginBranch(ctx, NULL);
669 }
670
671 sprintf(tmp, "const int bound = (coordA + %lu > M) ? (M - coordA) : %lu;\n",
672 dim->y, dim->y);
673 kgenAddStmt(ctx, tmp);
674
675 step = tileLineSegmentLen(tile);
676 step = (tile->trans) ? 1 : step;
677
678 for (j = 0; j < tile->nrRows; ++j) {
679 for (i = 0; i < tile->nrCols; i+=step) {
680 sprintfTileElement(&elem, tile, j, i, step);
681 sprintf(tmp, "%s = (bound <= %u) ? 0 : %s;\n", elem.buf, j, elem.buf);
682 kgenAddStmt(ctx, tmp);
683 }
684 }
685
686 // Put ones on the diagonal elements falling into the trash area of a tile of A
687 if (mrole == MATRIX_A) {
688 for (i = 0; i < (unsigned int)dim->y; i++) {
689 sprintfTileElement(&elem, tile, i, i, 1);
690 sprintf(tmp, "%s = (bound <= %d) ? %s : %s;\n",
691 elem.buf, (int)i, strOne(kextra->dtype), elem.buf);
692 kgenAddStmt(ctx, tmp);
693 }
694 }
695
696 if (mrole == MATRIX_A) {
697 kgenAddBlankLine(ctx);
698 }
699 else {
700 kgenEndBranch(ctx, NULL);
701 }
702 }
703
704 /*
705 * NOTE: Before invoking this function, 'tileA' must be initialized
706 * so that it stores a square tile of the matrix A.
707 */
708 static void
709 genMulOnDiagonalTile(
710 struct KgenContext *ctx,
711 BlasGenSettings *gset,
712 TileSet *tileSet,
713 const TileMulOpts *mulOpts)
714 {
715 char tmp[1024];
716 FetchOpts fetchOpts;
717 const SubproblemDim *dim = &gset->subdims[1];
718 TilePostFetchPrivate pfPriv[2];
719 TileMulOpts optsNew;
720 const CLBLASKernExtra *extra = gset->kextra;
721 CLBLASKernExtra extraNew;
722 KernelExtraFlags kflags = extra->flags;
723 Tile t;
724 bool isTail;
725
726 memset(&fetchOpts, 0, sizeof(fetchOpts));
727 fetchOpts.regName = "b";
728 fetchOpts.mrole = MATRIX_A;
729 fetchOpts.lineOffset = 0;
730 fetchOpts.linesNum = (unsigned int)dim->y;
731
732 // set up options to multiply by the inverted tile
733 memcpy(&optsNew, mulOpts, sizeof(TileMulOpts));
734 optsNew.flags &= ~TILEMUL_TRB;
735
736 kgenAddStmt(ctx, "// Fetch and invert the square tile located on the "
737 "diagonal\n");
738
739 // The matrix B plays the role of A
740 t = substituteTile(&gset->tileA, &tileSet->bAsSqA);
741
742 isTail = ((kflags & KEXTRA_TAILS_M) != 0);
743 genFetchInputTile(ctx, mulOpts->fctx, gset, &fetchOpts);
744 setFetchHandler(&optsNew, gset, genTrxmPostFetchZero, pfPriv);
745
746 /*
747 * There is no need to zero the tail along K in the case of the lower
748 * triangular matrix because it lies in the "other" triangle which is
749 * never accessed
750 */
751 if (isTail && !isMatrixUpper(kflags)) {
752 memcpy(&extraNew, extra, sizeof(extraNew));
753 extraNew.flags &= ~KEXTRA_TAILS_K_LOWER;
754 gset->kextra = &extraNew;
755 }
756 genTrxmPostFetchZero(ctx, MATRIX_A, pfPriv);
757
758 /*
759 * One must zero the tail part of a fetched square tile
760 * in order to avoid the influence of the trailing trash on the resulting
761 * inverted tile (the evaluation proceeds from the bottom towards the top
762 * of the tile)
763 */
764 if (isTail) {
765 genZeroTileTrash(ctx, gset, MATRIX_A, &gset->tileA);
766 }
767
768 restoreTile(&gset->tileA, &t);
769
770 if(gset->flags & BGF_EXPLICIT_INLINE) {
771 genTileInverting(ctx, gset, tileSet);
772 }
773 else {
774 sprintf(tmp, "invertTile(%s, %s);\n\n",
775 tileSet->squareA.baseName, tileSet->bAsSqA.baseName);
776 kgenAddStmt(ctx, tmp);
777 }
778
779 gset->tileBX = tileSet->bAsC;
780 genTileCopy(ctx, &gset->tileBX, &gset->tileCY, TILECOPY_ASSIGN);
781
782 /*
783 * For a lower triangular matrix A that is not decomposed integrally,
784 * it is enough to zero the tail part of the result in order to
785 * clear the trash accumulated over the update loop
786 */
787 if (isTail && !isMatrixUpper(kflags)) {
788 genZeroTileTrash(ctx, gset, MATRIX_B, &gset->tileBX);
789 }
790
791 genZeroTile(ctx, &gset->tileCY);
792
793 genMulTiles(ctx, gset, &optsNew);
794 kgenAddBlankLine(ctx);
795
796 // restore original extra
797 gset->kextra = extra;
798 }
799
800 static void
801 genUpdateIntermResult(
802 struct KgenContext *ctx,
803 const BlasGenSettings *gset,
804 bool withMhitCond,
805 UpdateResultFlags flags)
806 {
807 char tmp[1024];
808 const char *coordY, *coordX;
809 char *revAlp, *alp;
810 DataType dtype = gset->kextra->dtype;
811 KernelExtraFlags kflags = gset->kextra->flags;
812 const SubproblemDim *dim = &gset->subdims[1];
813 const KernelVarNames *kvarNames = &gset->varNames;
814 UpdateResultOp op;
815 UpresVarNames uvars;
816 const char* ctype;
817
818 memset(&uvars, 0, sizeof(uvars));
819
820 op = (flags & UPRES_WITH_BETA) ? UPRES_SUM : UPRES_SET;
821
822 uvars.startRow = kvarNames->coordA;
823 uvars.startCol = kvarNames->coordB;
824 uvars.nrRows = "y";
825 uvars.nrCols = "x";
826 uvars.result = "B";
827 uvars.ld = "ldb";
828
829 ctype = dtypeBuiltinType(dtype);
830 if (isComplexType(dtype)) {
831 if (dtype == TYPE_COMPLEX_FLOAT) {
832 revAlp = "div((float2)(-1.f, 0), alpha)";
833 alp = "(float2)(1.f, 0)";
834 }
835 else {
836 revAlp = "div((double2)(-1., 0), alpha)";
837 alp = "(double2)(1., 0)";
838 }
839 }
840 else {
841 revAlp = "-1. / alpha";
842 alp = "1.";
843 }
844
845 // inline result update
846 flags |= UPRES_INLINE;
847
848 coordY = kvarNames->coordA;
849 coordX = kvarNames->coordB;
850
851 /*
852 * We should be careful here.
853 *
854 * The non-tailed case of updateResult() has been rewritten.
855 * Result updating for the tailed and non-tailed cases now has slightly
856 * different semantics.
857 *
858 * The first one produces expressions like
859 * 'dst = dst * beta + src * alpha'.
860 *
861 * Here 'dst' and 'src' may be the private result stored in registers or
862 * the result to be updated in the global memory. Let the first one be
863 * designated as tileC and the second one as matC.
864 *
865 * The non-tailed case produces expressions like
866 * 'dst = matC * beta + tileC * alpha'.
867 *
868 * The second variant is clearer and more natural for the new implementation,
869 * but since the difference has not been eliminated, both variants are
870 * maintained here.
871 */
872
873 if (!(kflags & (KEXTRA_TAILS_M | KEXTRA_TAILS_N))) {
874 kgenBeginBranch(ctx, "");
875
876 sprintf(tmp, "%s %s = %s;\n"
877 "%s alpha = beta;\n",
878 ctype, "beta", revAlp, ctype);
879 kgenAddStmt(ctx, tmp);
880
881 updateResultGen(ctx,
882 gset,
883 CLBLAS_TRSM,
884 op,
885 flags & ~UPRES_WITH_BETA,
886 &uvars);
887
888 kgenEndBranch(ctx, NULL);
889 }
890 else {
891 if (withMhitCond) {
892 sprintf(tmp, "if ((%s < %s) && (%s < %s))",
893 coordY, kvarNames->sizeM, coordX, kvarNames->sizeN);
894 kgenBeginBranch(ctx, tmp);
895 }
896 else {
897 /* scope for the x, y variables */
898 kgenBeginBranch(ctx, NULL);
899 }
900
901 sprintf(tmp, "uint y = min(%luu, %s - (uint)%s);\n"
902 "uint x = min(%luu, %s - (uint)%s);\n",
903 dim->y, kvarNames->sizeM, coordY,
904 dim->x, kvarNames->sizeN, coordX);
905 kgenAddStmt(ctx, tmp);
906
907 sprintf(tmp, "if ((y == %lu) && (x == %lu))",
908 dim->y, dim->x);
909 kgenBeginBranch(ctx, tmp);
910
911 sprintf(tmp, "%s %s = %s;\n"
912 "%s alpha = beta;\n",
913 ctype, "beta", revAlp, ctype);
914 kgenAddStmt(ctx, tmp);
915
916 // optimized update
917 updateResultGen(ctx,
918 gset,
919 CLBLAS_TRSM,
920 op,
921 flags & ~UPRES_WITH_BETA,
922 &uvars);
923
924 kgenEndBranch(ctx, NULL);
925
926 flags |= UPRES_GENERIC;
927 kgenBeginBranch(ctx, "else ");
928
929 sprintf(tmp, "%s %s = %s;\n"
930 "%s %s = %s;\n",
931 ctype, "beta", revAlp,
932 ctype, "alpha", alp);
933 kgenAddStmt(ctx, tmp);
934
935 // not optimized update
936 updateResultGen(ctx,
937 gset,
938 CLBLAS_TRSM,
939 op,
940 flags,
941 &uvars);
942
943 kgenEndBranch(ctx, NULL);
944 kgenEndBranch(ctx, NULL);
945 }
946 }
947
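/*
 * Multiply with the B panel preloaded into the local memory: the emitted code
 * points 'lB' at 'tmpB', copies the current panel of B from the global memory
 * with the previously generated copy function (between two local barriers),
 * rewinds 'lB' to the per-item position 'lBMain', and then runs the tile
 * multiplication with B taken from the local memory and the block width
 * widened to the whole preloaded panel.
 */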
948 static void
949 genPreloadedTileMul(
950 struct KgenContext *ctx,
951 BlasGenSettings *gset,
952 TileMulOpts *mulOpts,
953 const Tile *parTile,
954 const char* copy2LDSFuncName)
955 {
956 char tmp[1024];
957 KernelExtraFlags kflags = gset->kextra->flags;
958 unsigned int bwidthOld;
959 const char *oldNameB;
960 const char *ptrName;
961
962 getVectorTypeName(gset->kextra->dtype, parTile->vecLen, NULL, &ptrName);
963 kgenPrintf(ctx, "lB.%s = tmpB;\n", ptrName);
964 kgenAddBarrier(ctx, CLK_LOCAL_MEM_FENCE);
965
966 if (!isMatrixAccessColMaj(CLBLAS_TRSM, kflags, MATRIX_B)) {
967 sprintf(tmp, "%s(lB, uB, gid * %lu, k0, ldb);\n",
968 copy2LDSFuncName, gset->subdims[0].x);
969 }
970 else {
971 sprintf(tmp, "%s(lB, uB, k0, gid * %lu, ldb);\n",
972 copy2LDSFuncName, gset->subdims[0].x);
973 }
974 kgenAddStmt(ctx, tmp);
975
976 kgenAddBarrier(ctx, CLK_LOCAL_MEM_FENCE);
977 kgenAddBlankLine(ctx);
978
979 kgenAddStmt(ctx, "lB = lBMain;\n\n");
980
981 mulOpts->memB = CLMEM_LOCAL_MEMORY;
982 oldNameB = gset->varNames.B;
983 bwidthOld = (unsigned int)gset->subdims[0].bwidth;
984 gset->varNames.B = "lB";
985 gset->subdims[0].bwidth = (parTile->trans) ? parTile->nrRows :
986 parTile->nrCols;
987
988 tileMulGen(ctx, gset, mulOpts);
989
990 gset->varNames.B = oldNameB;
991 gset->subdims[0].bwidth = bwidthOld;
992 mulOpts->memB = CLMEM_GLOBAL_MEMORY;
993 }
994
995 static void
996 initTiles(
997 BlasGenSettings* gset,
998 TileSet* tileSet,
999 const struct SubproblemDim *subdims,
1000 KernelExtraFlags kflags,
1001 DataType dtype,
1002 PrivateStorageType storType)
1003 {
1004 unsigned int rowsA;
1005 unsigned int rowsB;
1006 unsigned int rowsC;
1007 unsigned int colsA;
1008 unsigned int colsB;
1009 unsigned int colsC;
1010 bool transA;
1011 bool transB;
1012 unsigned int vecLenA;
1013 unsigned int vecLenB;
1014 unsigned int vecLenC;
1015
1016 rowsA = (unsigned int)subdims[1].y;
1017 colsA = (unsigned int)szmax(subdims[1].y, subdims[1].bwidth);
1018
1019 rowsB = (unsigned int)szmax(subdims[1].y, subdims[1].bwidth);
1020 colsB = (unsigned int)szmax(subdims[1].x, subdims[1].y);
1021
1022 rowsC = (unsigned int)subdims[1].y;
1023 colsC = (unsigned int)subdims[1].x;
1024
1025 transA = isMatrixAccessColMaj(CLBLAS_TRSM, kflags, MATRIX_A);
1026 transB = isMatrixAccessColMaj(CLBLAS_TRSM, kflags, MATRIX_B);
1027
1028 vecLenA = (unsigned int)((transA) ? subdims[1].y : subdims[1].bwidth);
1029 vecLenA = umin(vecLenA, MAX_TILE_VECLEN);
1030 vecLenB = (unsigned int)((transB) ? subdims[1].x : subdims[1].bwidth);
1031 vecLenB = umin(vecLenB, MAX_TILE_VECLEN);
1032 vecLenC = (transB) ? vecLenB : vecLenA;
1033
1034 initTile(&tileSet->rectA, "a", (unsigned int)subdims[1].y,
1035 (unsigned int)subdims[1].bwidth, vecLenA, dtype,
1036 storType, transA, false);
1037
1038 initTile(&tileSet->squareA, "a", (unsigned int)subdims[1].y,
1039 (unsigned int)subdims[1].y, vecLenA, dtype, storType,
1040 transA, false);
1041
1042 initTile(&tileSet->origB, "b", (unsigned int)subdims[1].bwidth,
1043 (unsigned int)subdims[1].x, vecLenB, dtype, storType,
1044 !transB, false);
1045
1046 initTile(&tileSet->bStage2, "b", (unsigned int)subdims[1].y,
1047 (unsigned int)subdims[1].x, vecLenB, dtype, storType,
1048 !transB, false);
1049
1050 initTile(&tileSet->bAsSqA, "b", (unsigned int)subdims[1].y,
1051 (unsigned int)subdims[1].y, vecLenB, dtype, storType,
1052 transA, false);
1053
1054 initTile(&tileSet->bAsC, "b", (unsigned int)subdims[1].y,
1055 (unsigned int)subdims[1].x, vecLenB, dtype, storType,
1056 gset->tileCY.trans, false);
1057
1058 initTile(&gset->tileA, "a", rowsA, colsA,
1059 vecLenA, dtype, storType, transA, false);
1060
1061 initTile(&gset->tileBX, "b", rowsB, colsB,
1062 vecLenB, dtype, storType, !transB, false);
1063
1064 initTile(&gset->tileCY, "c", rowsC, colsC,
1065 vecLenC, dtype, storType, !transB, false);
1066
1067 tileSet->A = gset->tileA;
1068 tileSet->B = gset->tileBX;
1069 }
1070
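/*
 * Top level kernel generator. The emitted source roughly consists of: the
 * unified pointer and complex math helpers, an optional out-of-line
 * invertTile() helper (when explicit inlining is off), an optional
 * global-to-local copy function for the B panels (when the local memory is
 * used), and the kernel itself, i.e. an outer loop over M whose body performs
 * the stage 1 block update followed by the stage 2 per-tile solve loop.
 */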
1071 static ssize_t
1072 generator(
1073 char *buf,
1074 size_t buflen,
1075 const struct SubproblemDim *subdims,
1076 const struct PGranularity *pgran,
1077 void *extra)
1078 {
1079 char tmp[1024];
1080 struct KgenContext *ctx;
1081 ssize_t ret;
1082 CLBLASKernExtra *kextra = (CLBLASKernExtra*)extra;
1083 DataType dtype = kextra->dtype;
1084 KernelExtraFlags kflags = kextra->flags;
1085 CLBLASKernExtra extraNew;
1086 BlasGenSettings gset;
1087 TileMulOpts mulOpts;
1088 const char *ptrName;
1089 UpdateResultFlags upFlags = 0;
1090 TilePostFetchPrivate pfPriv;
1091 unsigned int l1Pans;
1092 bool b;
1093 Tile parTile;
1094 TrsmExtraParams *extraParams = (TrsmExtraParams *)kextra->solverPriv;
1095 int ldsLarge, lds_diagonal;
1096 bool isInline;
1097 TileSet tileSet;
1098 char copy2LDSFuncName[FUNC_NAME_MAXLEN];
1099 TailStatus tailStatus = 0;
1100 FetchAddrMode addrMode = 0;
1101 bool tailM = ((kflags & KEXTRA_TAILS_M) != 0);
1102 bool tailN = ((kflags & KEXTRA_TAILS_N) != 0);
1103 size_t alignK;
1104
1105 if (pgran->wgDim != 1) {
1106 return -EINVAL;
1107 }
1108
1109 l1Pans = (unsigned int)(subdims[0].x / subdims[1].x);
1110
1111 memset(&gset, 0, sizeof(gset));
1112 gset.flags = BGF_WHOLE_A | BGF_EXPLICIT_INLINE | BGF_UPTRS;
1113 memcpy(gset.subdims, subdims, sizeof(SubproblemDim) * 2);
1114 // there is no need for a block structure along K
1115 gset.subdims[0].bwidth = gset.subdims[1].bwidth;
1116 subdims = gset.subdims;
1117
1118 /*
1119 * Since tiles change dynamically (in the main tilemul
1120 * loop they are rectangular, but at the second stage both the A and B
1121 * tile storages are used for square tiles), one must adjust the physical
1122 * vectorization accordingly, so that the vector length is not
1123 * greater than the linear size of any tile
1124 */
1125 memcpy(&extraNew, kextra, sizeof(extraNew));
1126 extraNew.vecLenA = umin(kextra->vecLenA, (unsigned int)subdims[1].y);
1127 extraNew.vecLenB = umin(kextra->vecLenB, (unsigned int)subdims[1].y);
1128
1129 gset.pgran = pgran;
1130 gset.kextra = &extraNew;
1131 initKernelVarNames(&gset.varNames);
1132
1133 // multiplication options
1134 mulOpts.memA = CLMEM_GLOBAL_MEMORY;
1135 mulOpts.memB = CLMEM_GLOBAL_MEMORY;
1136 mulOpts.core = (kextra->flags & KEXTRA_ENABLE_MAD) ? TILEMUL_MAD :
1137 TILEMUL_MULADD;
1138 mulOpts.postFetch = NULL;
1139 mulOpts.flags = kextraToTilemulFlags(CLBLAS_TRSM, kflags);
1140 mulOpts.flags |= TILEMUL_EXTERN_RDECL | TILEMUL_NOT_INC_K;
1141 mulOpts.fctx = createFetchContext();
1142 if (mulOpts.fctx == NULL) {
1143 return -ENOMEM;
1144 }
1145
1146 disableFetchOptLevels(mulOpts.fctx, FOPTLEV_TMP_COORD_PRECOMPUTING);
1147
1148 isInline = (gset.flags & BGF_EXPLICIT_INLINE);
1149
1150 initTiles(&gset, &tileSet, subdims, kflags, dtype,
1151 PRIV_STORAGE_VARIABLE_SET);
1152
1153 ctx = createKgenContext(buf, buflen, true);
1154 if (ctx == NULL) {
1155 destroyFetchContext(mulOpts.fctx);
1156 return -ENOMEM;
1157 }
1158
1159 kgenAddStmt(ctx, "#pragma OPENCL EXTENSION cl_amd_printf : enable\n\n");
1160
1161 b = isDoubleBasedType(dtype);
1162 kgenDeclareUptrs(ctx, b);
1163 if (isComplexType(dtype)) {
1164 genComplexMathOperators(ctx, dtype);
1165 }
1166 if(!isInline) {
1167 genTileInverting(ctx, &gset, &tileSet);
1168 }
1169
1170 if ( extraParams->ldsUse != LDS_NO_USE ) {
1171 SubproblemDim sdims;
1172 DBlockCopyFlags flags;
1173 unsigned int vecLen;
1174
1175 if (!isMatrixAccessColMaj(CLBLAS_TRSM, kflags, MATRIX_B)) {
1176 sdims.x = gset.subdims[1].bwidth * extraParams->unrollingFactor;
1177 sdims.y = gset.subdims[0].x;
1178 }
1179 else {
1180 sdims.x = gset.subdims[0].x;
1181 sdims.y = gset.subdims[1].bwidth * extraParams->unrollingFactor;
1182 }
1183
1184 vecLen = getVecLen(&gset, CLBLAS_TRSM, MATRIX_B);
1185 flags = (vecLen < 4) ? DBLOCK_COPY_NOT_VECTORIZE : 0;
1186 copyDataBlockGen(ctx, &sdims, gset.pgran, dtype,
1187 DBLOCK_GLOBAL_TO_LOCAL, flags);
1188 kgenAddBlankLine(ctx);
1189 kgenGetLastFuncName(copy2LDSFuncName, FUNC_NAME_MAXLEN, ctx);
1190 }
1191
1192 declareTrxmKernel(ctx, dtype, pgran, kflags, CLBLAS_TRSM, "Cached", false,
1193 true);
1194 kgenBeginFuncBody(ctx);
1195
1196 declareLocalVariables(ctx, &gset, &parTile, extraParams);
1197 if (kflags & KEXTRA_A_OFF_NOT_ZERO) {
1198 kgenAddStmt(ctx, "A += offA;\n");
1199 }
1200 genTrxmBMatrShift(ctx, kflags, false);
1201
1202 ptrName = dtypeUPtrField(dtype);
1203
1204 sprintf(tmp, "uB.%s = B;\n\n", ptrName);
1205 kgenAddStmt(ctx, tmp);
1206
1207 // external loop
1208 sprintf(tmp, "for (m0 = 0; m0 < M; m0 += %lu)", subdims[0].y);
1209 kgenBeginBranch(ctx, tmp);
1210 genZeroTile(ctx, &gset.tileCY);
1211 genSetupCoords(ctx, &gset, BLOCK_UPDATE);
1212
1213 kgenAddStmt(ctx, "// Stage 1. Multiply and update with large blocks\n");
1214
1215 gset.tileA = tileSet.rectA;
1216 gset.tileBX = tileSet.origB;
1217
1218 if (!isMatrixUpper(kflags) && tailM) {
1219 addrMode |= FETCH_ADDR_A_CYCLICAL;
1220 setFetchAddrMode(mulOpts.fctx, addrMode);
1221 }
1222
1223 ldsLarge = ((extraParams->ldsUse & LDS_USE_LARGE) != 0);
1224 alignK = subdims[1].bwidth;
1225 if (ldsLarge) {
1226 alignK *= extraParams->unrollingFactor;
1227 }
1228
1229 if (ldsLarge) {
1230 const char *oldCoordB;
1231 FetchAddrMode bamode = addrMode | FETCH_ADDR_K_RELATIVE;
1232 bool withSkew;
1233
1234 withSkew = useSkewedFetchB(&gset);
1235 if (!withSkew) {
1236 bamode |= FETCH_ADDR_B_RELATIVE;
1237 }
1238 else {
1239 bamode |= FETCH_ADDR_B_CYCLICAL;
1240 }
1241
1242 setFetchAddrMode(mulOpts.fctx, bamode);
1243
1244 if (tailN) {
1245 /*
1246 * Conditional branch for those work items whose matrix
1247 * coordinates hit inside matrix B
1248 */
1249 sprintf(tmp, "if ((gid + 1) * %lu < N)", subdims[0].x);
1250 kgenBeginBranch(ctx, tmp);
1251 }
1252
1253 if (isMatrixAccessColMaj(CLBLAS_TRSM, kflags, MATRIX_A)) {
1254 kgenPrintf(ctx, "uA.%s = A + k0 * lda;\n", ptrName);
1255 }
1256 else {
1257 kgenPrintf(ctx, "uA.%s = A + k0;\n", ptrName);
1258 }
1259
1260 if (withSkew) {
1261 unsigned int bwidthOld;
1262
1263 oldCoordB = gset.varNames.coordB;
1264 gset.varNames.coordB = "skewX";
1265 bwidthOld = gset.subdims[0].bwidth;
1266 gset.subdims[0].bwidth = (parTile.trans) ? parTile.nrRows :
1267 parTile.nrCols;
1268 gset.subdims[0].bwidth = bwidthOld;
1269 }
1270
1271 genInternalLoopCtl(ctx, subdims, kflags, alignK, alignK);
1272 genPreloadedTileMul(ctx, &gset, &mulOpts, &parTile, copy2LDSFuncName);
1273 genInternalLoopEnd(ctx); // loop over K
1274
1275 if (withSkew) {
1276 gset.varNames.coordB = oldCoordB;
1277 setFetchAddrMode(mulOpts.fctx, bamode & ~FETCH_ADDR_B_CYCLICAL);
1278 // remove the skew from the result before proceeding to the next stage
1279 genTileCyclicalShift(ctx, &gset);
1280 }
1281
1282 if (tailN) {
1283 kgenEndBranch(ctx, NULL);
1284 kgenBeginBranch(ctx, "else");
1285 }
1286
1287 setFetchAddrMode(mulOpts.fctx, addrMode);
1288 }
1289
1290 if (!ldsLarge || tailN) {
1291 genCheckShiftTailB(ctx, &gset, 0, &tailStatus);
1292 if ((kflags & KEXTRA_TAILS_N_LOWER) && !tailStatus) {
1293 addrMode |= FETCH_ADDR_B_CYCLICAL;
1294 setFetchAddrMode(mulOpts.fctx, addrMode);
1295 }
1296
1297 if (tailN) {
1298 sprintfHitMatrixCond(tmp, MATRIX_B, "if (", ")");
1299 kgenBeginBranch(ctx, tmp);
1300 }
1301
1302 genInternalLoopCtl(ctx, subdims, kflags, subdims[1].bwidth, alignK);
1303 tileMulGen(ctx, &gset, &mulOpts);
1304 genInternalLoopEnd(ctx); // loop over K
1305
1306 if (tailN) {
1307 kgenEndBranch(ctx, NULL);
1308 }
1309
1310 if (extraParams->ldsUse & LDS_USE_LARGE) {
1311 kgenEndBranch(ctx, NULL);
1312 }
1313 }
1314
1315 sprintf(tmp, "uA.%s = A;\n\n", ptrName);
1316 kgenAddStmt(ctx, tmp);
1317
1318 // processing tails along update dimension
1319 if (isMatrixUpper(kflags) &&
1320 ((kflags & KEXTRA_TAILS_K_LOWER) ||
1321 (ldsLarge && extraParams->unrolledTail))) {
1322
1323 unsigned int tailChunks;
1324
1325 tailChunks = (extraParams->ldsUse & LDS_USE_LARGE) ?
1326 extraParams->unrolledTail : 1;
1327
1328 if (tailN) {
1329 char hitCond[1024];
1330
1331 sprintfHitMatrixCond(hitCond, MATRIX_B, "(", ")");
1332 sprintf(tmp, "if ((currM + %lu < M) && %s)",
1333 subdims[0].y, hitCond);
1334 }
1335 else {
1336 sprintf(tmp, "if (currM + %lu < M)", subdims[0].y);
1337 }
1338 kgenBeginBranch(ctx, tmp);
1339
1340 if (kflags & KEXTRA_TAILS_K_LOWER) {
1341 setFetchAddrMode(mulOpts.fctx, addrMode | FETCH_ADDR_K_CYCLICAL);
1342 setFetchHandler(&mulOpts, &gset, defaultTilePostFetch, &pfPriv);
1343 }
1344 if (tailChunks > 1) {
1345 mulOpts.flags &= ~TILEMUL_NOT_INC_K;
1346 sprintf(tmp, "for (uint k1 = 0; k1 < %u; k1++)", tailChunks);
1347 kgenBeginBranch(ctx, tmp);
1348 }
1349
1350 addrMode |= FETCH_ADDR_B_CYCLICAL;
1351 setFetchAddrMode(mulOpts.fctx, addrMode);
1352 tileMulGen(ctx, &gset, &mulOpts);
1353 if (tailChunks > 1) {
1354 kgenEndBranch(ctx, NULL);
1355 mulOpts.flags |= TILEMUL_NOT_INC_K;
1356 }
1357
1358 kgenEndBranch(ctx, NULL);
1359 }
1360
1361 gset.tileA = tileSet.squareA;
1362
1363 kgenAddStmt(ctx, "\n/*\n"
1364 " * Stage 2. A part of the work items multiplies the accumulated "
1365 "result by a respective\n"
1366 " * inverted diagonal block, and the remaining ones wait. "
1367 "Then they perform\n"
1368 " * one more step of intermediate result evaluation by "
1369 "multiplying tile by tile.\n"
1370 " * This continues until the whole panel of the "
1371 "matrix A is processed\n"
1372 " */\n");
1373
1374 // from here on, one must deal strictly with square blocks
1375 gset.subdims[0].bwidth = gset.subdims[1].bwidth = gset.subdims[1].y;
1376
1377 sprintf(tmp, "for (m1 = 0; m1 < %lu; m1++)", subdims[0].y / subdims[1].y);
1378 kgenBeginBranch(ctx, tmp);
1379
1380 if (extraParams->ldsUse & LDS_USE_DIAGONAL) {
1381 sprintf(tmp, "const int bid = lid %% %u;\n\n",
1382 l1Pans);
1383 kgenAddStmt(ctx, tmp);
1384 }
1385
1386 /*
1387 * Update the intermediate result, multiply it by the inverted diagonal tile,
1388 * and write it back
1389 */
1390 genSetupCoords(ctx, &gset, TILE_UPDATE);
1391
1392 sprintfStage2Condition(tmp, &gset, 0);
1393 ret = kgenBeginBranch(ctx, tmp);
1394
1395 upFlags = kextraToUpresFlags(CLBLAS_TRSM, kflags);
1396 upFlags |= tailStatusToUpresFlags(tailStatus);
1397 upFlags |= UPRES_PRIV_DEST | UPRES_WITH_BETA;
1398 genUpdateIntermResult(ctx, &gset, false, upFlags);
1399
1400 kgenAddBlankLine(ctx);
1401
1402 lds_diagonal = ((extraParams->ldsUse & LDS_USE_DIAGONAL) &&
1403 (kflags & (KEXTRA_COLUMN_MAJOR)) == 0 &&
1404 !(tailM || tailN) &&
1405 !(upFlags & UPRES_NO_VECTORIZATION) &&
1406 !isComplexType(kextra->dtype));
1407
1408 /*
1409 * The addressing mode of A now needs to be adjusted so as not to
1410 * exceed the bounds of A
1411 */
1412 if (tailM) {
1413 setFetchAddrMode(mulOpts.fctx,
1414 addrMode | FETCH_ADDR_A_CYCLICAL |
1415 FETCH_ADDR_K_CYCLICAL);
1416 extraNew.flags |= KEXTRA_TAILS_K_LOWER;
1417 }
1418
1419 genMulOnDiagonalTile(ctx, &gset, &tileSet, &mulOpts);
1420 gset.tileBX = tileSet.bStage2;
1421 if (tailM) {
1422 setFetchHandler(&mulOpts, &gset, defaultTilePostFetch, &pfPriv);
1423 }
1424
1425 kgenAddStmt(ctx, "// Write back the given result\n");
1426
1427 upFlags = kextraToUpresFlags(CLBLAS_TRSM, kflags);
1428 upFlags |= tailStatusToUpresFlags(tailStatus);
1429
1430 if (lds_diagonal) {
1431 sprintf(tmp, "tmpB[%%u * %u + bid]", l1Pans);
1432 }
1433
1434 genResultUpdateWithFlags(ctx, CLBLAS_TRSM, &gset, upFlags,
1435 NULL, NULL, lds_diagonal ? tmp : NULL);
1436
1437 kgenEndBranch(ctx, NULL); // multiply on the inverted tile path
1438 kgenAddBarrier(ctx, CLK_GLOBAL_MEM_FENCE);
1439
1440 // continue the tile update
1441 kgenAddBlankLine(ctx);
1442 sprintfStage2Condition(tmp, &gset, 1);
1443 kgenBeginBranch(ctx, tmp);
1444 genCheckShiftTailB(ctx, &gset, 0, &tailStatus);
1445 if (lds_diagonal) {
1446 // TODO: add here storing to LDS as well
1447 }
1448 else {
1449 addrMode |= FETCH_ADDR_B_CYCLICAL;
1450 setFetchAddrMode(mulOpts.fctx, addrMode);
1451 tileMulGen(ctx, &gset, &mulOpts);
1452 }
1453 kgenEndBranch(ctx, NULL); // tile update path
1454 kgenAddBarrier(ctx, CLK_GLOBAL_MEM_FENCE);
1455
1456 kgenEndBranch(ctx, NULL); // second stage loop
1457
1458 if (isMatrixUpper(kflags)) {
1459 sprintf(tmp, "currM -= %lu;\n", subdims[0].y);
1460 kgenAddStmt(ctx, tmp);
1461 }
1462
1463 kgenEndBranch(ctx, NULL); // loop over M
1464
1465 ret = kgenEndFuncBody(ctx);
1466
1467 if (!ret) {
1468 ret = (ssize_t)kgenSourceSize(ctx) + 1;
1469 }
1470
1471 destroyFetchContext(mulOpts.fctx);
1472 destroyKgenContext(ctx);
1473
1474 return (ret < 0) ? -EOVERFLOW : ret;
1475 }
1476
1477 static bool
1478 isFitToLDS(
1479 SubproblemDim *dim,
1480 DataType dtype,
1481 cl_ulong ldsSize,
1482 const void *kernelArgs)
1483 {
1484 (void)dim;
1485 (void)dtype;
1486 (void)ldsSize;
1487 (void)kernelArgs;
1488
1489 return true;
1490 }
1491
1492 static SolverFlags
1493 solverFlags(void)
1494 {
1495 return (SF_WSPACE_1D | SF_TOP_INPUT_SQUARE_BLOCKS);
1496 }
1497
1498 static void
1499 assignKargs(KernelArg *args, const void *params, const void *extra)
1500 {
1501 const CLBlasKargs *blasArgs = (const CLBlasKargs*)params;
1502 KernelExtraFlags kflags = ((const CLBLASKernExtra*)extra)->flags;
1503 int idx = 7;
1504
1505 initSizeKarg(&args[0], blasArgs->M);
1506 initSizeKarg(&args[1], blasArgs->N);
1507 assignScalarKarg(&args[2], &(blasArgs->alpha), blasArgs->dtype);
1508 initMemobjKarg(&args[3], blasArgs->A, NULL, 0, 0);
1509 initSizeKarg(&args[4], blasArgs->lda.matrix);
1510 initMemobjKarg(&args[5], blasArgs->B, NULL, 0, 0);
1511 initSizeKarg(&args[6], blasArgs->ldb.matrix);
1512 if (kflags & KEXTRA_A_OFF_NOT_ZERO) {
1513 initSizeKarg(&args[idx++], blasArgs->offA);
1514 }
1515 if (kflags & KEXTRA_BX_OFF_NOT_ZERO) {
1516 initSizeKarg(&args[idx], blasArgs->offBX);
1517 }
1518 }
1519
1520 static void
1521 fixupArgs(void *args, SubproblemDim *subdims, void *extra)
1522 {
1523 CLBLASKernExtra *kextra = (CLBLASKernExtra*)extra;
1524 CLBlasKargs *kargs = (CLBlasKargs*)args;
1525 TrsmExtraParams *extraParams = (TrsmExtraParams *)kextra->solverPriv;
1526 size_t loadBatch;
1527 unsigned int wgSize;
1528 unsigned int workRatio;
1529 unsigned int ldsUse = LDS_NO_USE;
1530 KernelExtraFlags kflags = kextra->flags;
1531 SubproblemDim globDim;
1532 bool isAmdGPU;
1533
1534 /*
1535 * Calculate the size of the batch loaded from global to local memory
1536 * at each iteration of stage 1. Choose an unrolling factor
1537 * that allows each work item to load at least 16 bytes, which provides
1538 * efficient global memory access
1539 */
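/*
 * For example, assuming the decomposition subdims[0].x = 32,
 * subdims[1].bwidth = 4 and a 4-byte float element, loadBatch is
 * 32 * 4 * 4 = 512 bytes; with a 64-item work group each item loads
 * 512 / 64 = 8 bytes per step, so workRatio becomes 16 / 8 = 2 and each item
 * loads 16 bytes per unrolled step.
 */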
1540 loadBatch = subdims[0].x * subdims[1].bwidth * dtypeSize(kargs->dtype);
1541 wgSize = (unsigned int)((subdims[0].x / subdims[1].itemX) *
1542 (subdims[0].y / subdims[1].itemY));
1543 if (loadBatch < wgSize) {
1544 workRatio = 1;
1545 }
1546 else {
1547 workRatio = 16 / ((unsigned int)loadBatch / wgSize);
1548 if (!workRatio) {
1549 workRatio = 1;
1550 }
1551 }
1552
1553 #ifndef NDEBUG
1554 {
1555 const char *envImpl = getenv("AMD_CLBLAS_TRSM_LDSUSE");
1556
1557 if (envImpl != NULL) {
1558 unsigned int w = atoi(envImpl);
1559 ldsUse = w % 10;
1560 w = w / 10;
1561 workRatio = w > 0 ? w : workRatio;
1562 }
1563 }
1564 #endif
1565
1566 ldsUse = LDS_NO_USE;
1567 isAmdGPU = ((kflags & KEXTRA_VENDOR_AMD) != 0);
1568 if ((isAmdGPU && !(kflags & (KEXTRA_TAILS_K_LOWER | KEXTRA_TAILS_M_LOWER)))
1569 || (!isAmdGPU && !(kflags & KEXTRA_TAILS_M))) {
1570
1571 ldsUse = LDS_USE_LARGE;
1572 }
1573
1574 kargsToProbDims(&globDim, CLBLAS_TRSM, args, false);
1575 extraParams->ldsUse = ldsUse;
1576 extraParams->unrollingFactor = workRatio;
1577 extraParams->unrolledTail = (unsigned int)(((globDim.bwidth %
1578 (subdims[1].bwidth * workRatio)) + subdims[1].bwidth - 1) /
1579 subdims[1].bwidth);
1580
1581 fixupTrxmKargs(kargs);
1582 }
1583
1584 static bool
1585 checkCalcDecompDedicated(
1586 PGranularity *pgran,
1587 SubproblemDim *subdims,
1588 unsigned int subdimsNum,
1589 DataType dtype,
1590 int check)
1591 {
1592 bool ret = true;
1593
1594 DUMMY_ARG_USAGE(subdimsNum);
1595
1596 if (check == PGRAN_CHECK) {
1597 unsigned int minSize, maxSize;
1598
1599 maxSize = (dtype == TYPE_COMPLEX_DOUBLE) ? 4 : 8;
1600 minSize = (dtype == TYPE_COMPLEX_DOUBLE) ? 1 : 2;
1601 ret = decompSanityCheck(subdims, minSize, maxSize, 24, dtype, true);
1602 ret = ret && (subdims[0].bwidth == subdims[1].bwidth);
1603 ret = ret && (pgran->wgSize[0] == 64);
1604 }
1605 else {
1606 calcPgranDedicated(pgran, subdims, -1, 3);
1607 }
1608
1609 return ret;
1610 }
1611
1612 void
1613 initTrsmLdsLessCachedPattern(MemoryPattern *mempat)
1614 {
1615 mempat->name = "2-staged cached global memory based block trsm";
1616 mempat->nrLevels = 2;
1617 mempat->cuLevel = 0;
1618 mempat->thLevel = 0;
1619 mempat->sops = &trsmSops;
1620
1621 mpatExtra.aMset = CLMEM_LEVEL_L1;
1622 mpatExtra.bMset = CLMEM_LEVEL_L1;
1623 mpatExtra.mobjA = CLMEM_BUFFER;
1624 mpatExtra.mobjB = CLMEM_BUFFER;
1625 mempat->extra = &mpatExtra;
1626 }
1627
1628 #if 0
1629
1630 static int
1631 getDefaultDecomp(
1632 PGranularity *pgran,
1633 SubproblemDim *subdims,
1634 unsigned int subdimsNum,
1635 void * pArgs)
1636 {
1637 pgran->wgDim = 1;
1638 pgran->wgSize[0] = 64;
1639 pgran->wgSize[1] = 1;
1640
1641 subdims[0].x = subdims[0].itemX = 32;
1642 subdims[0].y = 64;
1643 subdims[0].itemY = SUBDIM_UNUSED;
1644 subdims[0].bwidth = subdims[1].bwidth = 4;
1645 subdims[1].x = subdims[1].itemX = 8;
1646 subdims[1].y = subdims[1].itemY = 4;
1647 }
1648
1649 #endif
1650