1 /* ************************************************************************
2  * Copyright 2013 Advanced Micro Devices, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  * ************************************************************************/
16 
17 
18 /*
19  * TRSM generator with support of cached reads from the global memory
20  */
21 
22 #include <string.h>
23 #include <stdio.h>
24 #include <assert.h>
25 #include <stdlib.h>
26 #include <clblas_stddef.h>
27 #include <clBLAS.h>
28 #include <blas_mempat.h>
29 #include <clkern.h>
30 #include <clblas-internal.h>
31 #include <matrix_props.h>
32 #include <matrix_dims.h>
33 
34 #include "dblock_kgen.h"
35 #include "kerngen.h"
36 #include "blas_kgen.h"
37 #include "gen_helper.h"
38 #include "trxm_common.h"
39 #include "trsm_kgen.h"
40 #include "legacy/blas_kgen_legacy.h"
41 
/* Bit flags controlling how local memory (LDS) is used by the generated kernel */
typedef enum LdsUseFlags {
    LDS_NO_USE = 0,          /* do not stage anything through local memory */
    LDS_USE_LARGE = 0x1,     /* stage blocks of B through local memory in the update loop */
    LDS_USE_DIAGONAL = 0x2   /* also reserve local space for the diagonal-tile stage */
} LdsUseFlags;
47 
/* Solver-private TRSM parameters, carried in kextra->solverPriv */
typedef struct TrsmExtraParams {
    int unrollingFactor;        /* how many K-iterations of the update loop are unrolled */
    unsigned int unrolledTail;  /* iterations left over after unrolling — presumably handled
                                   by a tail loop; not referenced in this chunk */
    LdsUseFlags ldsUse;         /* local memory usage mode (see LdsUseFlags) */
} TrsmExtraParams;
53 
/*
 * Stages of the TRSM solution; callers must follow the strict order
 * BLOCK_UPDATE -> TILE_UPDATE (see genSetupCoords)
 */
enum TrsmStage {
    BLOCK_UPDATE,   /* stage 1: update with already-solved panels */
    TILE_UPDATE     /* stage 2: solve against the diagonal tile */
};
58 
/* Memory-pattern extra info for this generator; filled in elsewhere */
static CLBLASMpatExtra mpatExtra;

/* Kernel source generator entry point (SolverOps callback) */
static ssize_t
generator(
    char *buf,
    size_t buflen,
    const struct SubproblemDim *subdims,
    const struct PGranularity *pgran,
    void *extra);

/* Check that the chosen decomposition fits into the device local memory */
static bool
isFitToLDS(
    SubproblemDim *dim,
    DataType dtype,
    cl_ulong ldsSize,
    const void *kernelArgs);

static SolverFlags
solverFlags(void);

/* Map solution parameters onto the kernel argument list */
static void
assignKargs(KernelArg *args, const void *params, const void *extra);

/* Post-process kernel arguments once subdimensions are known */
static void
fixupArgs(void *args, SubproblemDim *subdims, void *extra);

static bool
checkCalcDecompDedicated(
    PGranularity *pgran,
    SubproblemDim *subdims,
    unsigned int subdimsNum,
    DataType dtype,
    int check);

#if 0
static int
getDefaultDecomp(
    PGranularity *pgran,
    SubproblemDim *subdims,
    unsigned int subdimsNum,
    void * pArgs);
#endif

/*
 * Solver operations table for this generator.
 * NULL entries are optional SolverOps callbacks this generator
 * does not provide.
 */
static SolverOps trsmSops = {
    generator,
    assignKargs,
    isFitToLDS,
    NULL,
    NULL,
    NULL,
    NULL,
    solverFlags,
    fixupArgs,
    NULL,//getDefaultDecomp
    checkCalcDecompDedicated,
    NULL,
    NULL
};
117 
118 // The struct for storage tails
// The struct for storage tails
typedef struct TileSet
{
    Tile rectA;     // The rectangular tile A for the update loop at stage 1
    Tile squareA;   // The square tile for the stage 2
    Tile origB;     // The rectangular tile B for the update loop at the stage 1
    Tile bStage2;   // The rectangular tile B for the update loop at the stage 2
    Tile bAsSqA;    // Descriptor for holding square tile A in the storage of B
    Tile bAsC;      // Descriptor for holding tile C in the storage of B
    // the entire tile A matching the storage declared in the kernel
    Tile A;
    // the entire tile B matching the storage declared in the kernel
    Tile B;
} TileSet;
132 
133 
134 static bool
useSkewedFetchB(const BlasGenSettings * gset)135 useSkewedFetchB(const BlasGenSettings *gset)
136 {
137     KernelExtraFlags kflags = gset->kextra->flags;
138     TrsmExtraParams *extraParams = (TrsmExtraParams*)gset->kextra->solverPriv;
139     bool ret = false;
140 
141     if (extraParams->ldsUse & LDS_USE_LARGE) {
142         ret = !isMatrixAccessColMaj(CLBLAS_TRSM, kflags, MATRIX_B);
143     }
144 
145     return ret;
146 }
147 
148 static void
restoreTile(Tile * dst,const Tile * src)149 restoreTile(Tile* dst, const Tile* src)
150 {
151     dst->baseName = src->baseName;
152     dst->vecLen = src->vecLen;
153     dst->storType = src->storType;
154 }
155 
156 static Tile
substituteTile(Tile * dst,const Tile * src)157 substituteTile(Tile* dst, const Tile* src)
158 {
159     Tile tmp;
160 
161     restoreTile(&tmp, dst);
162     restoreTile(dst, src);
163 
164     return tmp;
165 }
166 
167 static void
sprintfInvertedElement(Kstring * elem,const Tile * tile,unsigned int row,unsigned int col,unsigned int len,bool isU)168 sprintfInvertedElement(
169     Kstring *elem,
170     const Tile *tile,
171     unsigned int row,
172     unsigned int col,
173     unsigned int len,
174     bool isU)
175 {
176     if (isU) {
177         row = tile->nrRows - row - 1;
178         col = tile->nrCols - col - len;
179     }
180 
181     sprintfTileElement(elem, tile, row, col, len);
182 }
183 
/*
 * Generate code inverting the square diagonal tile of A by forward /
 * backward substitution. The source is the square tile fetched into the
 * storage of B (tileSet->bAsSqA); the result goes to tileSet->squareA.
 * The traversal is mirrored for the upper-triangular case via
 * sprintfInvertedElement, so one loop nest serves both orientations.
 * Depending on BGF_EXPLICIT_INLINE either a separate 'invertTile'
 * function or inlined code is emitted.
 */
static void
genTileInverting(
    struct KgenContext *ctx,
    const BlasGenSettings *gset,
    const TileSet *tileSet)
{
    char tmp[1024];
    const CLBLASKernExtra *kextra = gset->kextra;
    KernelExtraFlags kflags = kextra->flags;
    DataType dtype = kextra->dtype;
    const SubproblemDim *dim = &gset->subdims[1];
    unsigned int accLen;
    unsigned int i, j, k;
    Tile srcTile;
    Tile dstTile;
    bool isU, isComplex;
    bool isInlined = gset->flags & BGF_EXPLICIT_INLINE;
    const char* typeNameA;
    const char* typeNameB;

    memcpy(&srcTile, &tileSet->bAsSqA, sizeof(srcTile));
    memcpy(&dstTile, &tileSet->squareA, sizeof(dstTile));

    getVectorTypeName(kextra->dtype, dstTile.vecLen, &typeNameA, NULL);
    getVectorTypeName(kextra->dtype, srcTile.vecLen, &typeNameB, NULL);
    isU = isMatrixUpper(kflags);
    isComplex = isComplexType(dtype);

    // vectorized updates are possible only for real, non-transposed tiles
    if (isComplex || dstTile.trans) {
        accLen = 1;
    }
    else {
        accLen = umin(srcTile.vecLen, dstTile.vecLen);
        accLen = umin(accLen, srcTile.nrCols);
    }

    if (!isInlined) {
        // emit a standalone function; rename tiles to match its parameters
        dstTile.baseName = "a";
        srcTile.baseName = "b";
        sprintf(tmp, "void\n"
                     "invertTile(%s *a, %s *b)\n",
                typeNameA, typeNameB);
        kgenDeclareFunction(ctx, tmp);
        kgenBeginFuncBody(ctx);
    }
    else {
        kgenAddStmt(ctx, "// Invert tile\n");
    }

    // make the destination block an identity tile to start from
    genZeroTile(ctx, &dstTile);
    for (i = 0; i < dim->y; i++) {
        genSetUnitInTile(ctx, &dstTile, i, i);
    }
    kgenAddBlankLine(ctx);

    for (i = 0; i < dim->y; i++) {
        Kstring src, srcDiag, dst, dstLast;

        // current source diagonal element
        sprintfInvertedElement(&srcDiag, &srcTile, i, i, 1, isU);
        for (j = i; j < dim->y; j++) {
            // current source non diagonal element
            if (i) {
                sprintfInvertedElement(&src, &srcTile, j, i - 1, 1, isU);
            }

            for (k = 0; k < dim->y; k += accLen) {
                // current updated vectorized element
                sprintfInvertedElement(&dst, &dstTile, j, k, accLen, isU);

                // update with the row eliminated at the previous step
                if (i) {
                    // last updated vectorized element
                    sprintfInvertedElement(&dstLast, &dstTile, i - 1, k,
                                           accLen, isU);
                    if (isComplex) {
                        sprintf(tmp, "%s -= mul(%s, %s);\n",
                                dst.buf, dstLast.buf, src.buf);
                    }
                    else {
                        sprintf(tmp, "%s -= %s * %s;\n",
                                dst.buf, dstLast.buf, src.buf);
                    }
                    kgenAddStmt(ctx, tmp);
                }

                // divide on the diagonal element
                if (j == i) {
                    if (isComplex) {
                        sprintf(tmp, "%s = div(%s, %s);\n",
                                dst.buf, dst.buf, srcDiag.buf);
                    }
                    else {
                        sprintf(tmp, "%s /= %s;\n", dst.buf, srcDiag.buf);
                    }
                    kgenAddStmt(ctx, tmp);
                }
            }
        }
        if (i != dim->y - 1) {
            kgenAddBlankLine(ctx);
        }
    }

    if (!isInlined) {
        kgenEndFuncBody(ctx);
    }
    kgenAddBlankLine(ctx);

}
295 
/*
 * Emit declarations of all kernel-scope variables: work-item ids,
 * coordinates, private tile storages, and — when LDS is in use —
 * the local buffer 'tmpB' with its access pointers. On return
 * '*parTile' describes the layout of the local block of B.
 */
static void
declareLocalVariables(
    struct KgenContext *ctx,
    const BlasGenSettings *gset,
    Tile* parTile,
    TrsmExtraParams * extraParams)
{
    char tmp[1024];
    const SubproblemDim *dims = gset->subdims;
    const char* parTileTypeName = NULL;
    bool trb = isMatrixAccessColMaj(CLBLAS_TRSM, gset->kextra->flags,
                                   MATRIX_B);
    unsigned int locWidth;
    unsigned int tsize;
    unsigned int parTileSize;
    unsigned int l1Pans;
    unsigned int step;

    kgenAddStmt(ctx,
                 "const int lid = get_local_id(0);\n"
                 "const int gid = get_group_id(0);\n"
                 "GPtr uA, uB;\n"
                 "uint coordA, coordB;\n"
                 "uint m0 = 0, k0, m1;\n");

    // for the upper-triangular case iteration starts from the last panel
    if (isMatrixUpper(gset->kextra->flags)) {
        sprintf(tmp, "uint currM = (M - 1) / %lu * %lu;\n",
                dims[0].y, dims[0].y);
        kgenAddStmt(ctx, tmp);
    }

    /*
     * Declare private blocks.
     * The region 'b' stores in different time tiles of both
     * the input matrices and the result
     */

    declareTileStorages(ctx, gset);

    *parTile = gset->tileBX;

    if (extraParams->ldsUse) {
        tsize = dtypeSize(gset->kextra->dtype);
        l1Pans = (unsigned int)(dims[0].x / dims[1].x);

        // vector length capped at 16 bytes (a float4)
        parTile->vecLen = (trb) ? (unsigned int)dims[1].x
                                : (unsigned int)dims[1].bwidth;
        parTile->vecLen = umin(parTile->vecLen, sizeof(cl_float4) / tsize);
        parTile->trans = trb;

       /*
        * Allocate enough space in the local area to fit several tiles
        * at the stage1 (according to the unrolled factor) and one tile
        * at the stage2
        */

        locWidth = (unsigned int)dims[1].bwidth * extraParams->unrollingFactor;
        if (extraParams->ldsUse & LDS_USE_DIAGONAL) {
            locWidth = umax(locWidth, (unsigned int)dims[1].y);
        }
        if (trb) {
            parTile->nrRows = locWidth;
            parTile->nrCols = (unsigned int)dims[0].x;
            step = (unsigned int)dims[1].x / parTile->vecLen;
        }
        else {
            parTile->nrRows = (unsigned int)dims[0].x;
            parTile->nrCols = locWidth;
            step = (unsigned int)dims[1].x * locWidth / parTile->vecLen;
        }

        parTileSize = tileVectorsNum(parTile);

        getVectorTypeName(gset->kextra->dtype, parTile->vecLen,
                          &parTileTypeName, NULL);

        // 'lBMain' points at this work-item's own panel inside tmpB
        sprintf(tmp, "__local %s tmpB[%i];\n"
                     "LPtr lB;\n"
                     "LPtr lBMain = {(__local float*)(tmpB + lid %% %u * %u)};\n",
                parTileTypeName, parTileSize, l1Pans, step);
        kgenAddStmt(ctx, tmp);

        if (useSkewedFetchB(gset)) {
            kgenPrintf(ctx, "const uint skewX = lid %% %u %% %lu;\n",
                       l1Pans, gset->subdims[1].x);
        }
    }

    kgenAddBlankLine(ctx);
}
386 
387 /*
388  * Generate cyclical tile shifting so as to convert the skewed
389  * storing to "one-to-one", i. e. the first element in the tile
390  * matches to the first element of the respective tile in the
391  * output matrix.
392  */
static void
genTileCyclicalShift(struct KgenContext *ctx, BlasGenSettings *gset)
{
    const char *tname;
    Kstring k1, k2, *src, *dst, *ktmp;
    unsigned int row, col;
    unsigned int seglen;
    Tile *tileC = &gset->tileCY;

    seglen = tileLineSegmentLen(tileC);
    getVectorTypeName(gset->kextra->dtype, seglen, &tname, NULL);

    kgenAddStmt(ctx, "\n// deliver from skewing in the result\n");
    kgenBeginBranch(ctx, "for (uint i = 0; i < skewX; i++)");
    kgenPrintf(ctx, "%s tmp;\n\n", tname);

    src = &k1;
    dst = &k2;

    // Skewing may be used only in case of transposed C
    for (row = 0; row < tileC->nrRows; row += seglen) {
        // rotate the row right by one position, 'skewX' times in total:
        // save the last element, shift the rest, put the saved one first
        sprintfTileElement(dst, tileC, row, tileC->nrCols - 1, seglen);
        kgenPrintf(ctx, "tmp = %s;\n", dst->buf);
        for (col = tileC->nrCols - 1; col > 0; col--) {
            sprintfTileElement(src, tileC, row, col - 1, seglen);
            kgenPrintf(ctx, "%s = %s;\n", dst->buf, src->buf);
            // swap pointers so the previous source becomes the next target
            ktmp = src;
            src = dst;
            dst = ktmp;
        }
        kgenPrintf(ctx, "%s = tmp;\n", dst->buf);
    }

    kgenEndBranch(ctx, NULL);
    kgenAddBlankLine(ctx);
}
430 
431 /*
432  * Setup coordinates before beginning a trsm stage
433  * A caller must ensure the strict stage sequence:
434  * BLOCK_UPDATE -> TILE_UPDATE
435  */
436 static void
genSetupCoords(struct KgenContext * ctx,const BlasGenSettings * gset,enum TrsmStage stage)437 genSetupCoords(
438     struct KgenContext *ctx,
439     const BlasGenSettings *gset,
440     enum TrsmStage stage)
441 {
442     char tmp[1024];
443     KernelExtraFlags kflags = gset->kextra->flags;
444     const SubproblemDim *dims = gset->subdims;
445     unsigned int l1Pans = (unsigned int)(dims[0].x / dims[1].x);
446     const char *s;
447 
448     s = isMatrixUpper(kflags) ? "currM" : "m0";
449     sprintf(tmp, "coordA = %s + (lid / %u * %lu);\n",
450             s, l1Pans, dims[1].y);
451     kgenAddStmt(ctx, tmp);
452 
453     switch (stage) {
454     case BLOCK_UPDATE:
455         if (isMatrixUpper(kflags)) {
456             sprintf(tmp, "k0 = currM + %lu;\n", dims[0].y);
457         }
458         else {
459             sprintf(tmp, "k0 = 0;\n");
460         }
461         break;
462     case TILE_UPDATE:
463         if (isMatrixUpper(kflags)) {
464             sprintf(tmp, "k0 = currM + %lu - m1 * %lu;\n",
465                     dims[0].y - dims[1].y, dims[1].y);
466         }
467         else {
468             sprintf(tmp, "k0 = m0 + m1 * %lu;\n", dims[1].y);
469         }
470         break;
471     }
472 
473     kgenAddStmt(ctx, tmp);
474 
475     sprintf(tmp, "coordB = gid * %lu + (lid %% %u * %lu);\n",
476             dims[0].x, l1Pans, dims[1].x);
477 
478     kgenAddStmt(ctx, tmp);
479     kgenAddBlankLine(ctx);
480 }
481 
482 // Generate control block of the loop over K
483 static void
genInternalLoopCtl(struct KgenContext * ctx,const SubproblemDim * dim,KernelExtraFlags kflags,size_t stepK,size_t boundAlign)484 genInternalLoopCtl(
485     struct KgenContext *ctx,
486     const SubproblemDim *dim,
487     KernelExtraFlags kflags,
488     size_t stepK,
489     size_t boundAlign)
490 {
491     char tmp[1024];
492 
493     if (isMatrixUpper(kflags)) {
494         if (kflags & KEXTRA_TAILS_M) {
495             sprintf(tmp, "for (k0 = currM + %lu; k0 < M / %lu * %lu; "
496                                "k0 += %lu)",
497                     dim[0].y, boundAlign, boundAlign, stepK);
498         }
499         else {
500             sprintf(tmp, "for (k0 = currM + %lu; k0 < M; k0 += %lu)",
501                     dim[0].y, stepK);
502         }
503     }
504     else {
505         sprintf(tmp, "for (k0 = 0; k0 < m0; k0 += %lu)",
506                 stepK);
507     }
508 
509     kgenBeginBranch(ctx, tmp);
510 }
511 
512 static void
initKernelVarNames(KernelVarNames * kvars)513 initKernelVarNames(KernelVarNames *kvars)
514 {
515     kvars->A = "uA";
516     kvars->B = "uB";
517     kvars->C = "B";
518     kvars->coordA = "coordA";
519     kvars->coordB = "coordB";
520     kvars->k = "k0";
521     kvars->sizeM = "M";
522     kvars->sizeN = "N";
523     kvars->sizeK = "M";
524     kvars->lda = "lda";
525     kvars->ldb = "ldb";
526     kvars->ldc = "ldb";
527     kvars->alpha = "alpha";
528     kvars->beta = "beta";
529 }
530 
531 static void
setFetchHandler(TileMulOpts * mulOpts,const BlasGenSettings * gset,int handler (struct KgenContext * ctx,MatrixRole mrole,void * priv),TilePostFetchPrivate * priv)532 setFetchHandler(
533     TileMulOpts *mulOpts,
534     const BlasGenSettings *gset,
535     int handler(struct KgenContext *ctx, MatrixRole mrole, void *priv),
536     TilePostFetchPrivate *priv)
537 {
538     int i, nrPrivs;
539     const char *regName = NULL;
540 
541     if (handler == defaultTilePostFetch) {
542         nrPrivs = 1;
543     }
544     else {
545         nrPrivs = 2;
546         regName = "b";
547     }
548 
549     for (i = 0; i < nrPrivs; i++) {
550         priv[i].fetchNumA = 0;
551         priv[i].wholeA = 1;
552         priv[i].funcID = CLBLAS_TRSM;
553         priv[i].gset = gset;
554         priv[i].regName = regName;
555         mulOpts->postFetch = handler;
556         mulOpts->postFetchPriv = priv;
557     }
558 }
559 
560 static void
genCheckShiftTailB(struct KgenContext * ctx,const BlasGenSettings * gset,int adjustRestore,TailStatus * tailStatus)561 genCheckShiftTailB(
562     struct KgenContext *ctx,
563     const BlasGenSettings *gset,
564     int adjustRestore,
565     TailStatus *tailStatus)
566 {
567     BlasGenSettings gsetNew;
568     CLBLASKernExtra kextraNew;
569 
570     memcpy(&gsetNew, gset, sizeof(gsetNew));
571     memcpy(&kextraNew, gset->kextra, sizeof(kextraNew));
572     // avoid tail shift for the matrix A
573     kextraNew.flags &= ~(KEXTRA_TAILS_M | KEXTRA_TAILS_M_LOWER);
574     gsetNew.kextra = &kextraNew;
575 
576     if (adjustRestore) {
577         checkGenRestoreTailCoords(ctx, &gsetNew, *tailStatus);
578     }
579     else {
580         *tailStatus = checkGenAdjustTailCoords(ctx, CLBLAS_TRSM, &gsetNew,
581                                                NULL);
582     }
583 }
584 
585 static void
sprintfHitMatrixCond(char * buf,MatrixRole mrole,const char * prefix,const char * suffix)586 sprintfHitMatrixCond(
587     char *buf,
588     MatrixRole mrole,
589     const char *prefix,
590     const char *suffix)
591 {
592     const char *coordName;
593     char bound;
594 
595     coordName = (mrole == MATRIX_A) ? "coordA" : "coordB";
596     bound = (mrole == MATRIX_A) ? 'M' : 'N';
597     if (suffix == NULL) {
598         suffix = "";
599     }
600     sprintf(buf, "%s%s < %c%s", prefix, coordName, bound, suffix);
601 }
602 
/*
 * 'mulUpd' arguments mean what action is being done: multiplication on
 * an inverted tile or subsequent update
 */
static void
sprintfStage2Condition(
    char *buf,
    const BlasGenSettings *gset,
    int mulUpd)
{
    KernelExtraFlags kflags = gset->kextra->flags;
    char hitCond[1024];
    char *p;
    unsigned int xPans, yPans;


    hitCond[0] = '\0';
    xPans = (unsigned int)(gset->subdims[0].x / gset->subdims[1].x);
    yPans = (unsigned int)(gset->subdims[0].y / gset->subdims[1].y);
    // append bound guards for tails along M and/or N, when needed
    if (kflags & KEXTRA_TAILS_M) {
        sprintfHitMatrixCond(hitCond, MATRIX_A, " && ", NULL);
    }
    p = hitCond + strlen(hitCond);
    if (kflags & KEXTRA_TAILS_N) {
        sprintfHitMatrixCond(p, MATRIX_B, " && ", NULL);
    }

    if (!mulUpd) {
        // select work-items that own the diagonal tile at step m1
        if (isMatrixUpper(kflags)) {
            sprintf(buf, "if (lid / %u + m1 == %u%s)",
                    xPans, yPans - 1, hitCond);
        }
        else {
            sprintf(buf, "if (lid / %u == m1%s)", xPans, hitCond);
        }
    }
    else {
        // select work-items that perform the subsequent update
        if (isMatrixUpper(kflags)) {
            sprintf(buf, "if (lid / %u + m1 < %u%s)",
                    xPans, yPans - 1, hitCond);
        }
        else {
            sprintf(buf, "if (lid / %u > m1%s)", xPans, hitCond);
        }
    }
}
649 
/*
 * Emit code zeroing the "trash" rows of a tile that lie beyond the
 * matrix bound M. For a tile of A the trash diagonal elements are set
 * to one so that the subsequent inversion stays well defined; for a
 * tile of B the code is wrapped in its own scope (for the 'bound'
 * variable).
 */
static void
genZeroTileTrash(
    struct KgenContext *ctx,
    const BlasGenSettings *gset,
    MatrixRole mrole,
    Tile* tile)
{
    char tmp[1024];
    const SubproblemDim *dim = &gset->subdims[1];
    const CLBLASKernExtra *kextra = gset->kextra;
    unsigned int i, j;
    unsigned int step;
    Kstring elem;

    if (mrole == MATRIX_A) {
        kgenAddBlankLine(ctx);
    }
    else {
        kgenBeginBranch(ctx, NULL);
    }

    // number of valid rows remaining before the bound M
    sprintf(tmp, "const int bound = (coordA + %lu > M) ? (M - coordA) : %lu;\n",
            dim->y, dim->y);
    kgenAddStmt(ctx, tmp);

    step = tileLineSegmentLen(tile);
    step = (tile->trans) ? 1 : step;

    // zero every element belonging to a row beyond the bound
    for (j = 0; j < tile->nrRows; ++j) {
        for (i = 0; i < tile->nrCols; i+=step) {
            sprintfTileElement(&elem, tile, j, i, step);
            sprintf(tmp, "%s = (bound <= %u) ? 0 : %s;\n", elem.buf, j, elem.buf);
            kgenAddStmt(ctx, tmp);
        }
    }

    // Set units in the trash diagonal elements for a tile of A
    if (mrole == MATRIX_A) {
        for (i = 0; i < (unsigned int)dim->y; i++) {
            sprintfTileElement(&elem, tile, i, i, 1);
            sprintf(tmp, "%s = (bound <= %d) ? %s : %s;\n",
                    elem.buf, (int)i, strOne(kextra->dtype), elem.buf);
            kgenAddStmt(ctx, tmp);
        }
    }

    if (mrole == MATRIX_A) {
        kgenAddBlankLine(ctx);
    }
    else {
        kgenEndBranch(ctx, NULL);
    }
}
703 
704 /*
705  * NOTE: Before invoking this function 'tileA' must be initialized accordingly
706  *       so as it stores a square tile of the matrix A.
707  */
/*
 * Emit the stage-2 step: fetch the square diagonal tile of A (stored in
 * the registers of B), invert it, and multiply the intermediate result
 * by the inverted tile. See the NOTE above: 'tileA' must already be
 * initialized to describe a square tile of A.
 */
static void
genMulOnDiagonalTile(
    struct KgenContext *ctx,
    BlasGenSettings *gset,
    TileSet *tileSet,
    const TileMulOpts *mulOpts)
{
    char tmp[1024];
    FetchOpts fetchOpts;
    const SubproblemDim *dim = &gset->subdims[1];
    TilePostFetchPrivate pfPriv[2];
    TileMulOpts optsNew;
    const CLBLASKernExtra *extra = gset->kextra;
    CLBLASKernExtra extraNew;
    KernelExtraFlags kflags = extra->flags;
    Tile t;
    bool isTail;

    memset(&fetchOpts, 0, sizeof(fetchOpts));
    fetchOpts.regName = "b";
    fetchOpts.mrole = MATRIX_A;
    fetchOpts.lineOffset = 0;
    fetchOpts.linesNum = (unsigned int)dim->y;

    // setup options to multiply on the inverted tile
    memcpy(&optsNew, mulOpts, sizeof(TileMulOpts));
    optsNew.flags &= ~TILEMUL_TRB;

    kgenAddStmt(ctx, "// Fetch and invert the square tile located on the "
                     "diagonal\n");

    // The matrix B play the role of A; remember the old fields to restore
    t = substituteTile(&gset->tileA, &tileSet->bAsSqA);

    isTail = ((kflags & KEXTRA_TAILS_M) != 0);
    genFetchInputTile(ctx, mulOpts->fctx, gset, &fetchOpts);
    setFetchHandler(&optsNew, gset, genTrxmPostFetchZero, pfPriv);

    /*
     * There is no needs in zeroing tail along K in case of the lower
     * triangular matrix because it is in the "other" triangle which is
     * never accessed
     */
    if (isTail && !isMatrixUpper(kflags)) {
        memcpy(&extraNew, extra, sizeof(extraNew));
        extraNew.flags &= ~KEXTRA_TAILS_K_LOWER;
        gset->kextra = &extraNew;
    }
    genTrxmPostFetchZero(ctx, MATRIX_A, pfPriv);

    /*
     * One must zero the tail part of a fetched square tile
     * in order to avoid influence of the trailing trash on the resulting
     * inverted tile (evaluating proceeds from the bottom towards the top
     *                of the tile)
     */
    if (isTail) {
        genZeroTileTrash(ctx, gset, MATRIX_A, &gset->tileA);
    }

    // put back the original storage fields of tileA
    restoreTile(&gset->tileA, &t);

    if(gset->flags & BGF_EXPLICIT_INLINE) {
        genTileInverting(ctx, gset, tileSet);
    }
    else {
        sprintf(tmp, "invertTile(%s, %s);\n\n",
                tileSet->squareA.baseName, tileSet->bAsSqA.baseName);
        kgenAddStmt(ctx, tmp);
    }

    // copy the accumulated result into the storage of B before multiplying
    gset->tileBX = tileSet->bAsC;
    genTileCopy(ctx, &gset->tileBX, &gset->tileCY, TILECOPY_ASSIGN);

    /*
     * For the lower diagonal not integrally decomposed matrix A
     * it's enough to zero the tail part of the result in order to
     * clear trash accumulated over the update loop
     */
    if (isTail && !isMatrixUpper(kflags)) {
        genZeroTileTrash(ctx, gset, MATRIX_B, &gset->tileBX);
    }

    genZeroTile(ctx, &gset->tileCY);

    genMulTiles(ctx, gset, &optsNew);
    kgenAddBlankLine(ctx);

    // restore original extra
    gset->kextra = extra;
}
799 
/*
 * Emit code storing the intermediate result back to B, scaled through
 * the "reversed alpha" trick: the generated update uses -1/alpha as the
 * effective beta. Three variants are emitted: a single fast path when
 * no tails exist, and an optimized/generic pair guarded by runtime size
 * checks otherwise. 'withMhitCond' adds an explicit coordinate bound
 * check around the tailed variants.
 */
static void
genUpdateIntermResult(
    struct KgenContext *ctx,
    const BlasGenSettings *gset,
    bool withMhitCond,
    UpdateResultFlags flags)
{
    char tmp[1024];
    const char *coordY, *coordX;
    char *revAlp, *alp;
    DataType dtype = gset->kextra->dtype;
    KernelExtraFlags kflags = gset->kextra->flags;
    const SubproblemDim *dim = &gset->subdims[1];
    const KernelVarNames *kvarNames = &gset->varNames;
    UpdateResultOp op;
    UpresVarNames uvars;
    const char* ctype;

    memset(&uvars, 0, sizeof(uvars));

    op = (flags & UPRES_WITH_BETA) ? UPRES_SUM : UPRES_SET;

    uvars.startRow = kvarNames->coordA;
    uvars.startCol = kvarNames->coordB;
    uvars.nrRows = "y";
    uvars.nrCols = "x";
    uvars.result = "B";
    uvars.ld = "ldb";

    // literals for -1/alpha and for unit alpha, per element type
    ctype = dtypeBuiltinType(dtype);
    if (isComplexType(dtype)) {
        if (dtype == TYPE_COMPLEX_FLOAT) {
            revAlp = "div((float2)(-1.f, 0), alpha)";
            alp = "(float2)(1.f, 0)";
        }
        else {
            revAlp = "div((double2)(-1., 0), alpha)";
            alp = "(double2)(1., 0)";
        }
    }
    else {
        revAlp = "-1. / alpha";
        alp = "1.";
    }

    // inline result update
    flags |= UPRES_INLINE;

    coordY = kvarNames->coordA;
    coordX = kvarNames->coordB;

    /*
     * We should be careful here.
     *
     * The non tailed case of updateResult() is rewritten.
     * Now update result for tailed and non tailed cases have a bit
     * different semantics.
     *
     * The first one produces expressions like
     * 'dst = dst * beta + src * alpha'.
     *
     * Here 'dst' and 'src' may be private result stored in registers or
     * result to be updated in the global memory. Let the first one to be
     * designated as tileC and the second one as matC.
     *
     * The non tailed case produces expressions like
     * 'dst = matC * beta + tileC * alpha'.
     *
     * The second variant is more clear and native for the new implementation.
     * But as the difference is not eliminated, both the variants are
     * maintained here.
     */

    if (!(kflags & (KEXTRA_TAILS_M | KEXTRA_TAILS_N))) {
        // no tails: single fast path in its own scope
        kgenBeginBranch(ctx, "");

        sprintf(tmp, "%s %s = %s;\n"
                     "%s alpha = beta;\n",
                ctype, "beta", revAlp, ctype);
        kgenAddStmt(ctx, tmp);

        updateResultGen(ctx,
            gset,
            CLBLAS_TRSM,
            op,
            flags & ~UPRES_WITH_BETA,
            &uvars);

        kgenEndBranch(ctx, NULL);
    }
    else {
        if (withMhitCond) {
            sprintf(tmp, "if ((%s < %s) && (%s < %s))",
                    coordY, kvarNames->sizeM, coordX, kvarNames->sizeN);
            kgenBeginBranch(ctx, tmp);
        }
        else {
            /* for x, y variables scope */
            kgenBeginBranch(ctx, NULL);
        }

        // clamp the tile extents against the matrix bounds
        sprintf(tmp, "uint y = min(%luu, %s - (uint)%s);\n"
                     "uint x = min(%luu, %s - (uint)%s);\n",
                dim->y, kvarNames->sizeM, coordY,
                dim->x, kvarNames->sizeN, coordX);
        kgenAddStmt(ctx, tmp);

        sprintf(tmp, "if ((y == %lu) && (x == %lu))",
                dim->y, dim->x);
        kgenBeginBranch(ctx, tmp);

        sprintf(tmp, "%s %s = %s;\n"
                     "%s alpha = beta;\n",
                ctype, "beta", revAlp, ctype);
        kgenAddStmt(ctx, tmp);

        // optimized update
        updateResultGen(ctx,
            gset,
            CLBLAS_TRSM,
            op,
            flags & ~UPRES_WITH_BETA,
            &uvars);

        kgenEndBranch(ctx, NULL);

        flags |= UPRES_GENERIC;
        kgenBeginBranch(ctx, "else ");

        sprintf(tmp, "%s %s = %s;\n"
                     "%s %s = %s;\n",
                ctype, "beta", revAlp,
                ctype, "alpha", alp);
        kgenAddStmt(ctx, tmp);

        // not optimized update
        updateResultGen(ctx,
            gset,
            CLBLAS_TRSM,
            op,
            flags,
            &uvars);

        kgenEndBranch(ctx, NULL);
        kgenEndBranch(ctx, NULL);
    }
}
947 
/*
 * Emit a tile multiplication that reads B from local memory: copy the
 * block of B into 'tmpB' with the named copy function (barriers around
 * the copy), then run tileMulGen with B temporarily rebound to the
 * local pointer 'lB'. The original variable name and bwidth are
 * restored before returning.
 */
static void
genPreloadedTileMul(
    struct KgenContext *ctx,
    BlasGenSettings *gset,
    TileMulOpts *mulOpts,
    const Tile *parTile,
    const char* copy2LDSFuncName)
{
    char tmp[1024];
    KernelExtraFlags kflags = gset->kextra->flags;
    unsigned int bwidthOld;
    const char *oldNameB;
    const char *ptrName;

    getVectorTypeName(gset->kextra->dtype, parTile->vecLen, NULL, &ptrName);
    kgenPrintf(ctx, "lB.%s = tmpB;\n", ptrName);
    kgenAddBarrier(ctx, CLK_LOCAL_MEM_FENCE);

    // row/column arguments are swapped depending on B's access order
    if (!isMatrixAccessColMaj(CLBLAS_TRSM, kflags, MATRIX_B)) {
        sprintf(tmp, "%s(lB, uB, gid * %lu, k0, ldb);\n",
            copy2LDSFuncName, gset->subdims[0].x);
    }
    else {
        sprintf(tmp, "%s(lB, uB, k0, gid * %lu, ldb);\n",
            copy2LDSFuncName, gset->subdims[0].x);
    }
    kgenAddStmt(ctx, tmp);

    kgenAddBarrier(ctx, CLK_LOCAL_MEM_FENCE);
    kgenAddBlankLine(ctx);

    kgenAddStmt(ctx, "lB = lBMain;\n\n");

    // temporarily rebind B to the local buffer for the multiplication
    mulOpts->memB = CLMEM_LOCAL_MEMORY;
    oldNameB = gset->varNames.B;
    bwidthOld = (unsigned int)gset->subdims[0].bwidth;
    gset->varNames.B = "lB";
    gset->subdims[0].bwidth = (parTile->trans) ? parTile->nrRows :
                                                 parTile->nrCols;

    tileMulGen(ctx, gset, mulOpts);

    // restore the global-memory settings
    gset->varNames.B = oldNameB;
    gset->subdims[0].bwidth = bwidthOld;
    mulOpts->memB = CLMEM_GLOBAL_MEMORY;
}
994 
995 static void
initTiles(BlasGenSettings * gset,TileSet * tileSet,const struct SubproblemDim * subdims,KernelExtraFlags kflags,DataType dtype,PrivateStorageType storType)996 initTiles(
997     BlasGenSettings* gset,
998     TileSet* tileSet,
999     const struct SubproblemDim *subdims,
1000     KernelExtraFlags kflags,
1001     DataType dtype,
1002     PrivateStorageType storType)
1003 {
1004     unsigned int rowsA;
1005     unsigned int rowsB;
1006     unsigned int rowsC;
1007     unsigned int colsA;
1008     unsigned int colsB;
1009     unsigned int colsC;
1010     bool transA;
1011     bool transB;
1012     unsigned int vecLenA;
1013     unsigned int vecLenB;
1014     unsigned int vecLenC;
1015 
1016     rowsA = (unsigned int)subdims[1].y;
1017     colsA = (unsigned int)szmax(subdims[1].y, subdims[1].bwidth);
1018 
1019     rowsB = (unsigned int)szmax(subdims[1].y, subdims[1].bwidth);
1020     colsB = (unsigned int)szmax(subdims[1].x, subdims[1].y);
1021 
1022     rowsC = (unsigned int)subdims[1].y;
1023     colsC = (unsigned int)subdims[1].x;
1024 
1025     transA = isMatrixAccessColMaj(CLBLAS_TRSM, kflags, MATRIX_A);
1026     transB = isMatrixAccessColMaj(CLBLAS_TRSM, kflags, MATRIX_B);
1027 
1028     vecLenA = (unsigned int)((transA) ? subdims[1].y : subdims[1].bwidth);
1029     vecLenA = umin(vecLenA, MAX_TILE_VECLEN);
1030     vecLenB = (unsigned int)((transB) ? subdims[1].x : subdims[1].bwidth);
1031     vecLenB = umin(vecLenB, MAX_TILE_VECLEN);
1032     vecLenC = (transB) ? vecLenB : vecLenA;
1033 
1034     initTile(&tileSet->rectA, "a", (unsigned int)subdims[1].y,
1035              (unsigned int)subdims[1].bwidth, vecLenA, dtype,
1036              storType, transA, false);
1037 
1038     initTile(&tileSet->squareA, "a", (unsigned int)subdims[1].y,
1039              (unsigned int)subdims[1].y, vecLenA, dtype, storType,
1040              transA, false);
1041 
1042     initTile(&tileSet->origB, "b", (unsigned int)subdims[1].bwidth,
1043              (unsigned int)subdims[1].x, vecLenB, dtype, storType,
1044              !transB, false);
1045 
1046     initTile(&tileSet->bStage2, "b", (unsigned int)subdims[1].y,
1047              (unsigned int)subdims[1].x, vecLenB, dtype, storType,
1048              !transB, false);
1049 
1050     initTile(&tileSet->bAsSqA, "b", (unsigned int)subdims[1].y,
1051              (unsigned int)subdims[1].y, vecLenB, dtype, storType,
1052              transA, false);
1053 
1054     initTile(&tileSet->bAsC, "b", (unsigned int)subdims[1].y,
1055              (unsigned int)subdims[1].x, vecLenB, dtype, storType,
1056              gset->tileCY.trans, false);
1057 
1058     initTile(&gset->tileA, "a", rowsA, colsA,
1059              vecLenA, dtype, storType, transA, false);
1060 
1061     initTile(&gset->tileBX, "b", rowsB, colsB,
1062              vecLenB, dtype, storType, !transB, false);
1063 
1064     initTile(&gset->tileCY, "c", rowsC, colsC,
1065              vecLenC, dtype, storType, !transB, false);
1066 
1067     tileSet->A = gset->tileA;
1068     tileSet->B = gset->tileBX;
1069 }
1070 
/*
 * Main entry point of the pattern: generate the complete 2-staged cached
 * TRSM kernel into 'buf'.
 *
 * Stage 1 multiplies-and-accumulates over large blocks of the update panel
 * (optionally staging B through LDS), stage 2 walks the diagonal: a subset
 * of work items multiplies the intermediate result by an inverted diagonal
 * tile and writes it back, the rest then use that result to continue the
 * panel update.
 *
 * Returns the generated source size (+1) on success, -EINVAL for an
 * unsupported work-group dimensionality, -ENOMEM on allocation failure,
 * or -EOVERFLOW when the output did not fit into 'buflen'.
 */
static ssize_t
generator(
   char *buf,
   size_t buflen,
   const struct SubproblemDim *subdims,
   const struct PGranularity *pgran,
   void *extra)
{
    char tmp[1024];
    struct KgenContext *ctx;
    ssize_t ret;
    CLBLASKernExtra *kextra = (CLBLASKernExtra*)extra;
    DataType dtype = kextra->dtype;
    KernelExtraFlags kflags = kextra->flags;
    CLBLASKernExtra extraNew;
    BlasGenSettings gset;
    TileMulOpts mulOpts;
    const char *ptrName;
    UpdateResultFlags upFlags = 0;
    TilePostFetchPrivate pfPriv;
    unsigned int l1Pans;          // number of level-1 panels along X
    bool b;
    Tile parTile;                 // shape of the LDS-staged B panel
    TrsmExtraParams *extraParams = (TrsmExtraParams *)kextra->solverPriv;
    int ldsLarge, lds_diagonal;
    bool isInline;
    TileSet tileSet;
    char copy2LDSFuncName[FUNC_NAME_MAXLEN];
    TailStatus tailStatus = 0;
    FetchAddrMode addrMode = 0;
    bool tailM = ((kflags & KEXTRA_TAILS_M) != 0);
    bool tailN = ((kflags & KEXTRA_TAILS_N) != 0);
    size_t alignK;

    // this pattern only supports a 1D work space
    if (pgran->wgDim != 1) {
        return -EINVAL;
    }

    l1Pans = (unsigned int)(subdims[0].x / subdims[1].x);

    memset(&gset, 0, sizeof(gset));
    gset.flags = BGF_WHOLE_A | BGF_EXPLICIT_INLINE | BGF_UPTRS;
    memcpy(gset.subdims, subdims, sizeof(SubproblemDim) * 2);
    // there is no need for a block structure along K
    gset.subdims[0].bwidth = gset.subdims[1].bwidth;
    subdims = gset.subdims;

    /*
     * Since tiles are changed dynamically, e. g. in the main tilemul
     * loop they are rectangular, but at the second stage both A and B
     * tile storages are used for square tiles. One must adjust physical
     * vectorization accordingly, so as vector length might not be
     * greater than linear size of any tile
     */
    memcpy(&extraNew, kextra, sizeof(extraNew));
    extraNew.vecLenA = umin(kextra->vecLenA, (unsigned int)subdims[1].y);
    extraNew.vecLenB = umin(kextra->vecLenB, (unsigned int)subdims[1].y);

    gset.pgran = pgran;
    gset.kextra = &extraNew;
    initKernelVarNames(&gset.varNames);

    // multiplication options
    mulOpts.memA = CLMEM_GLOBAL_MEMORY;
    mulOpts.memB = CLMEM_GLOBAL_MEMORY;
    mulOpts.core = (kextra->flags & KEXTRA_ENABLE_MAD) ? TILEMUL_MAD :
                                                         TILEMUL_MULADD;
    mulOpts.postFetch = NULL;
    mulOpts.flags = kextraToTilemulFlags(CLBLAS_TRSM, kflags);
    mulOpts.flags |= TILEMUL_EXTERN_RDECL | TILEMUL_NOT_INC_K;
    mulOpts.fctx = createFetchContext();
    if (mulOpts.fctx == NULL) {
        return -ENOMEM;
    }

    disableFetchOptLevels(mulOpts.fctx, FOPTLEV_TMP_COORD_PRECOMPUTING);

    isInline = (gset.flags & BGF_EXPLICIT_INLINE);

    initTiles(&gset, &tileSet, subdims, kflags, dtype,
              PRIV_STORAGE_VARIABLE_SET);

    ctx = createKgenContext(buf, buflen, true);
    if (ctx == NULL) {
        destroyFetchContext(mulOpts.fctx);
        return -ENOMEM;
    }

    // NOTE(review): cl_amd_printf is enabled unconditionally here,
    // presumably a debugging leftover — confirm whether it can be removed
    kgenAddStmt(ctx, "#pragma OPENCL EXTENSION cl_amd_printf : enable\n\n");

    b = isDoubleBasedType(dtype);
    kgenDeclareUptrs(ctx, b);
    if (isComplexType(dtype)) {
        genComplexMathOperators(ctx, dtype);
    }
    // tile inverting is emitted as a separate function only when
    // the generator is not in explicit-inline mode
    if(!isInline) {
        genTileInverting(ctx, &gset, &tileSet);
    }

    // generate the global -> local copy function used for staging B
    if ( extraParams->ldsUse != LDS_NO_USE ) {
        SubproblemDim sdims;
        DBlockCopyFlags flags;
        unsigned int vecLen;

        if (!isMatrixAccessColMaj(CLBLAS_TRSM, kflags, MATRIX_B)) {
            sdims.x = gset.subdims[1].bwidth * extraParams->unrollingFactor;
            sdims.y = gset.subdims[0].x;
        }
        else {
            sdims.x = gset.subdims[0].x;
            sdims.y = gset.subdims[1].bwidth * extraParams->unrollingFactor;
        }

        vecLen = getVecLen(&gset, CLBLAS_TRSM, MATRIX_B);
        flags = (vecLen < 4) ? DBLOCK_COPY_NOT_VECTORIZE : 0;
        copyDataBlockGen(ctx, &sdims, gset.pgran, dtype,
                         DBLOCK_GLOBAL_TO_LOCAL, flags);
        kgenAddBlankLine(ctx);
        kgenGetLastFuncName(copy2LDSFuncName, FUNC_NAME_MAXLEN, ctx);
    }

    declareTrxmKernel(ctx, dtype, pgran, kflags, CLBLAS_TRSM, "Cached", false,
                      true);
    kgenBeginFuncBody(ctx);

    declareLocalVariables(ctx, &gset, &parTile, extraParams);
    if (kflags & KEXTRA_A_OFF_NOT_ZERO) {
        kgenAddStmt(ctx, "A += offA;\n");
    }
    genTrxmBMatrShift(ctx, kflags, false);

    ptrName = dtypeUPtrField(dtype);

    sprintf(tmp, "uB.%s = B;\n\n", ptrName);
    kgenAddStmt(ctx, tmp);

    // external loop over the result rows
    sprintf(tmp, "for (m0 = 0; m0 < M; m0 += %lu)", subdims[0].y);
    kgenBeginBranch(ctx, tmp);
    genZeroTile(ctx, &gset.tileCY);
    genSetupCoords(ctx, &gset, BLOCK_UPDATE);

    kgenAddStmt(ctx, "// Stage 1. Multiply and update with large blocks\n");

    gset.tileA = tileSet.rectA;
    gset.tileBX = tileSet.origB;

    // for the lower triangle with M tails, A fetches must wrap around
    if (!isMatrixUpper(kflags) && tailM) {
        addrMode |= FETCH_ADDR_A_CYCLICAL;
        setFetchAddrMode(mulOpts.fctx, addrMode);
    }

    ldsLarge = ((extraParams->ldsUse & LDS_USE_LARGE) != 0);
    alignK = subdims[1].bwidth;
    if (ldsLarge) {
        alignK *= extraParams->unrollingFactor;
    }

    // stage-1 path with LDS staging of B
    if (ldsLarge) {
        const char *oldCoordB;
        FetchAddrMode bamode = addrMode | FETCH_ADDR_K_RELATIVE;
        bool withSkew;

        withSkew = useSkewedFetchB(&gset);
        if (!withSkew) {
            bamode |= FETCH_ADDR_B_RELATIVE;
        }
        else {
            bamode |= FETCH_ADDR_B_CYCLICAL;
        }

        setFetchAddrMode(mulOpts.fctx, bamode);

        if (tailN) {
            /*
             * Conditional branch for those items which hit into
             * matrix B with their matrix coordinates
             */
            sprintf(tmp, "if ((gid + 1) * %lu < N)", subdims[0].x);
            kgenBeginBranch(ctx, tmp);
        }

        if (isMatrixAccessColMaj(CLBLAS_TRSM, kflags, MATRIX_A)) {
            kgenPrintf(ctx, "uA.%s = A + k0 * lda;\n", ptrName);
        }
        else {
            kgenPrintf(ctx, "uA.%s = A + k0;\n", ptrName);
        }

        if (withSkew) {
            unsigned int bwidthOld;

            oldCoordB = gset.varNames.coordB;
            gset.varNames.coordB = "skewX";
            /*
             * NOTE(review): bwidth is saved, overwritten and immediately
             * restored before anything uses the new value — this looks
             * like dead code; confirm the intended adjustment
             */
            bwidthOld = gset.subdims[0].bwidth;
            gset.subdims[0].bwidth = (parTile.trans) ? parTile.nrRows :
                                                       parTile.nrCols;
            gset.subdims[0].bwidth = bwidthOld;
        }

        genInternalLoopCtl(ctx, subdims, kflags, alignK, alignK);
        genPreloadedTileMul(ctx, &gset, &mulOpts, &parTile, copy2LDSFuncName);
        genInternalLoopEnd(ctx);                             // loop over K

        if (withSkew) {
            gset.varNames.coordB = oldCoordB;
            setFetchAddrMode(mulOpts.fctx, bamode & ~FETCH_ADDR_B_CYCLICAL);
            // deliver from skew in the result before proceed to the next stage
            genTileCyclicalShift(ctx, &gset);
        }

        if (tailN) {
            kgenEndBranch(ctx, NULL);
            kgenBeginBranch(ctx, "else");
        }

        setFetchAddrMode(mulOpts.fctx, addrMode);
    }

    // stage-1 path without LDS, or the tail-N fallback of the LDS path
    if (!ldsLarge || tailN) {
        genCheckShiftTailB(ctx, &gset, 0, &tailStatus);
        if ((kflags & KEXTRA_TAILS_N_LOWER) && !tailStatus) {
            addrMode |= FETCH_ADDR_B_CYCLICAL;
            setFetchAddrMode(mulOpts.fctx, addrMode);
        }

        if (tailN) {
            sprintfHitMatrixCond(tmp, MATRIX_B, "if (", ")");
            kgenBeginBranch(ctx, tmp);
        }

        genInternalLoopCtl(ctx, subdims, kflags, subdims[1].bwidth, alignK);
        tileMulGen(ctx, &gset, &mulOpts);
        genInternalLoopEnd(ctx);                             // loop over K

        if (tailN) {
            kgenEndBranch(ctx, NULL);
        }

        if (extraParams->ldsUse & LDS_USE_LARGE) {
            kgenEndBranch(ctx, NULL);
        }
    }

    sprintf(tmp, "uA.%s = A;\n\n", ptrName);
    kgenAddStmt(ctx, tmp);

    // processing tails along update dimension
    if (isMatrixUpper(kflags) &&
        ((kflags & KEXTRA_TAILS_K_LOWER) ||
          (ldsLarge && extraParams->unrolledTail))) {

        unsigned int tailChunks;

        tailChunks = (extraParams->ldsUse & LDS_USE_LARGE) ?
            extraParams->unrolledTail : 1;

        if (tailN) {
            char hitCond[1024];

            sprintfHitMatrixCond(hitCond, MATRIX_B, "(", ")");
            sprintf(tmp, "if ((currM + %lu < M) && %s)",
                    subdims[0].y, hitCond);
        }
        else {
            sprintf(tmp, "if (currM + %lu < M)", subdims[0].y);
        }
        kgenBeginBranch(ctx, tmp);

        if (kflags & KEXTRA_TAILS_K_LOWER) {
            setFetchAddrMode(mulOpts.fctx, addrMode | FETCH_ADDR_K_CYCLICAL);
            setFetchHandler(&mulOpts, &gset, defaultTilePostFetch, &pfPriv);
        }
        if (tailChunks > 1) {
            // several unrolled chunks remain: let tilemul advance K itself
            mulOpts.flags &= ~TILEMUL_NOT_INC_K;
            sprintf(tmp, "for (uint k1 = 0; k1 < %u; k1++)", tailChunks);
            kgenBeginBranch(ctx, tmp);
        }

        addrMode |= FETCH_ADDR_B_CYCLICAL;
        setFetchAddrMode(mulOpts.fctx, addrMode);
        tileMulGen(ctx, &gset, &mulOpts);
        if (tailChunks > 1) {
            kgenEndBranch(ctx, NULL);
            mulOpts.flags |= TILEMUL_NOT_INC_K;
        }

        kgenEndBranch(ctx, NULL);
    }

    gset.tileA = tileSet.squareA;

    kgenAddStmt(ctx, "\n/*\n"
                     " * Stage 2. A part of work items multiply got result on "
                     "a respective\n"
                     " * inverted diagonal block, and the remaining ones wait. "
                     "Then they perform\n"
                     " * one step of further intermediate result evaluation as "
                     "multiplying tile by tile.\n"
                     " * It continues until the whole panel of the "
                     "matrix A is processed\n"
                     " */\n");

    // one must deal further with square blocks strictly
    gset.subdims[0].bwidth = gset.subdims[1].bwidth = gset.subdims[1].y;

    sprintf(tmp, "for (m1 = 0; m1 < %lu; m1++)", subdims[0].y / subdims[1].y);
    kgenBeginBranch(ctx, tmp);

    if (extraParams->ldsUse & LDS_USE_DIAGONAL) {
        sprintf(tmp, "const int bid = lid %% %u;\n\n",
                l1Pans);
        kgenAddStmt(ctx, tmp);
    }

    /*
     * Update the intermediate result multiply on the inverted diagonal tile,
     * and write back
     */
    genSetupCoords(ctx, &gset, TILE_UPDATE);

    sprintfStage2Condition(tmp, &gset, 0);
    ret = kgenBeginBranch(ctx, tmp);

    upFlags = kextraToUpresFlags(CLBLAS_TRSM, kflags);
    upFlags |= tailStatusToUpresFlags(tailStatus);
    upFlags |= UPRES_PRIV_DEST | UPRES_WITH_BETA;
    genUpdateIntermResult(ctx, &gset, false, upFlags);

    kgenAddBlankLine(ctx);

    // LDS-based diagonal path is restricted to the simple row-major,
    // tail-free, vectorizable, non-complex case
    lds_diagonal = ((extraParams->ldsUse & LDS_USE_DIAGONAL) &&
                    (kflags & (KEXTRA_COLUMN_MAJOR)) == 0 &&
                    !(tailM || tailN) &&
                    !(upFlags & UPRES_NO_VECTORIZATION) &&
                    !isComplexType(kextra->dtype));

    /*
     * it's needed now to adjust addressing mode of A so as to don't
     * exceed the bound of A
     */
    if (tailM) {
        setFetchAddrMode(mulOpts.fctx,
                         addrMode | FETCH_ADDR_A_CYCLICAL |
                         FETCH_ADDR_K_CYCLICAL);
        extraNew.flags |= KEXTRA_TAILS_K_LOWER;
    }

    genMulOnDiagonalTile(ctx, &gset, &tileSet, &mulOpts);
    gset.tileBX = tileSet.bStage2;
    if (tailM) {
        setFetchHandler(&mulOpts, &gset, defaultTilePostFetch, &pfPriv);
    }

    kgenAddStmt(ctx, "// Write back the given result\n");

    upFlags = kextraToUpresFlags(CLBLAS_TRSM, kflags);
    upFlags |= tailStatusToUpresFlags(tailStatus);

    if (lds_diagonal) {
       sprintf(tmp, "tmpB[%%u * %u + bid]", l1Pans);
    }

    genResultUpdateWithFlags(ctx, CLBLAS_TRSM, &gset, upFlags,
                                 NULL, NULL, lds_diagonal ? tmp : NULL);

    kgenEndBranch(ctx, NULL);   // multiply on the inverted tile path
    kgenAddBarrier(ctx, CLK_GLOBAL_MEM_FENCE);

    // continue the tile update
    kgenAddBlankLine(ctx);
    sprintfStage2Condition(tmp, &gset, 1);
    kgenBeginBranch(ctx, tmp);
    genCheckShiftTailB(ctx, &gset, 0, &tailStatus);
    if (lds_diagonal) {
        // TODO: add here storing to LDS as well
    }
    else {
        addrMode |= FETCH_ADDR_B_CYCLICAL;
        setFetchAddrMode(mulOpts.fctx, addrMode);
        tileMulGen(ctx, &gset, &mulOpts);
    }
    kgenEndBranch(ctx, NULL);           // tile update path
    kgenAddBarrier(ctx, CLK_GLOBAL_MEM_FENCE);

    kgenEndBranch(ctx, NULL);           // second stage loop

    // for the upper triangle the panel walks the rows downwards
    if (isMatrixUpper(kflags)) {
        sprintf(tmp, "currM -= %lu;\n", subdims[0].y);
        kgenAddStmt(ctx, tmp);
    }

    kgenEndBranch(ctx, NULL);           // loop over M

    ret = kgenEndFuncBody(ctx);

    if (!ret) {
        ret = (ssize_t)kgenSourceSize(ctx) + 1;
    }

    destroyFetchContext(mulOpts.fctx);
    destroyKgenContext(ctx);

    return (ret < 0) ? -EOVERFLOW : ret;
}
1476 
1477 static bool
isFitToLDS(SubproblemDim * dim,DataType dtype,cl_ulong ldsSize,const void * kernelArgs)1478 isFitToLDS(
1479     SubproblemDim *dim,
1480     DataType dtype,
1481     cl_ulong ldsSize,
1482     const void *kernelArgs)
1483 {
1484     (void)dim;
1485     (void)dtype;
1486     (void)ldsSize;
1487     (void)kernelArgs;
1488 
1489     return true;
1490 }
1491 
1492 static SolverFlags
solverFlags(void)1493 solverFlags(void)
1494 {
1495     return (SF_WSPACE_1D | SF_TOP_INPUT_SQUARE_BLOCKS);
1496 }
1497 
1498 static void
assignKargs(KernelArg * args,const void * params,const void * extra)1499 assignKargs(KernelArg *args, const void *params, const void *extra)
1500 {
1501     const CLBlasKargs *blasArgs = (const CLBlasKargs*)params;
1502     KernelExtraFlags kflags = ((const CLBLASKernExtra*)extra)->flags;
1503     int idx = 7;
1504 
1505     initSizeKarg(&args[0], blasArgs->M);
1506     initSizeKarg(&args[1], blasArgs->N);
1507     assignScalarKarg(&args[2], &(blasArgs->alpha), blasArgs->dtype);
1508     initMemobjKarg(&args[3], blasArgs->A, NULL, 0, 0);
1509     initSizeKarg(&args[4], blasArgs->lda.matrix);
1510     initMemobjKarg(&args[5], blasArgs->B, NULL, 0, 0);
1511     initSizeKarg(&args[6], blasArgs->ldb.matrix);
1512     if (kflags & KEXTRA_A_OFF_NOT_ZERO) {
1513         initSizeKarg(&args[idx++], blasArgs->offA);
1514     }
1515     if (kflags & KEXTRA_BX_OFF_NOT_ZERO) {
1516         initSizeKarg(&args[idx], blasArgs->offBX);
1517     }
1518 }
1519 
1520 static void
fixupArgs(void * args,SubproblemDim * subdims,void * extra)1521 fixupArgs(void *args, SubproblemDim *subdims, void *extra)
1522 {
1523     CLBLASKernExtra *kextra = (CLBLASKernExtra*)extra;
1524     CLBlasKargs *kargs = (CLBlasKargs*)args;
1525     TrsmExtraParams *extraParams = (TrsmExtraParams *)kextra->solverPriv;
1526     size_t loadBatch;
1527     unsigned int wgSize;
1528     unsigned int workRatio;
1529     unsigned int ldsUse = LDS_NO_USE;
1530     KernelExtraFlags kflags = kextra->flags;
1531     SubproblemDim globDim;
1532     bool isAmdGPU;
1533 
1534     /*
1535      * Calculate size of the batch loaded from global to local memory
1536      * at each iteration of the stage 1. Choose such unrolling factor
1537      * that allow each work item to load at least 16 bytes that provides
1538      * efficient global memory access
1539      */
1540     loadBatch = subdims[0].x * subdims[1].bwidth * dtypeSize(kargs->dtype);
1541     wgSize = (unsigned int)((subdims[0].x / subdims[1].itemX) *
1542                             (subdims[0].y / subdims[1].itemY));
1543     if (loadBatch < wgSize) {
1544         workRatio = 1;
1545     }
1546     else {
1547         workRatio = 16 / ((unsigned int)loadBatch / wgSize);
1548         if (!workRatio) {
1549             workRatio = 1;
1550         }
1551     }
1552 
1553 #ifndef NDEBUG
1554     {
1555         const char *envImpl = getenv("AMD_CLBLAS_TRSM_LDSUSE");
1556 
1557         if (envImpl != NULL) {
1558             unsigned int w = atoi(envImpl);
1559             ldsUse = w % 10;
1560             w = w / 10;
1561             workRatio = w > 0 ? w : workRatio;
1562         }
1563     }
1564 #endif
1565 
1566     ldsUse = LDS_NO_USE;
1567     isAmdGPU = ((kflags & KEXTRA_VENDOR_AMD) != 0);
1568     if ((isAmdGPU && !(kflags & (KEXTRA_TAILS_K_LOWER | KEXTRA_TAILS_M_LOWER)))
1569         || (!isAmdGPU && !(kflags & KEXTRA_TAILS_M))) {
1570 
1571         ldsUse = LDS_USE_LARGE;
1572     }
1573 
1574     kargsToProbDims(&globDim, CLBLAS_TRSM, args, false);
1575     extraParams->ldsUse = ldsUse;
1576     extraParams->unrollingFactor = workRatio;
1577     extraParams->unrolledTail = (unsigned int)(((globDim.bwidth %
1578              (subdims[1].bwidth * workRatio)) + subdims[1].bwidth - 1) /
1579                                                 subdims[1].bwidth);
1580 
1581     fixupTrxmKargs(kargs);
1582 }
1583 
1584 static bool
checkCalcDecompDedicated(PGranularity * pgran,SubproblemDim * subdims,unsigned int subdimsNum,DataType dtype,int check)1585 checkCalcDecompDedicated(
1586     PGranularity *pgran,
1587     SubproblemDim *subdims,
1588     unsigned int subdimsNum,
1589     DataType dtype,
1590     int check)
1591 {
1592     bool ret = true;
1593 
1594     DUMMY_ARG_USAGE(subdimsNum);
1595 
1596     if (check == PGRAN_CHECK) {
1597         unsigned int minSize, maxSize;
1598 
1599         maxSize = (dtype == TYPE_COMPLEX_DOUBLE) ? 4 : 8;
1600         minSize = (dtype == TYPE_COMPLEX_DOUBLE) ? 1 : 2;
1601         ret = decompSanityCheck(subdims, minSize, maxSize, 24, dtype, true);
1602         ret = ret && (subdims[0].bwidth == subdims[1].bwidth);
1603         ret = ret && (pgran->wgSize[0] == 64);
1604     }
1605     else {
1606         calcPgranDedicated(pgran, subdims, -1, 3);
1607     }
1608 
1609     return ret;
1610 }
1611 
1612 void
initTrsmLdsLessCachedPattern(MemoryPattern * mempat)1613 initTrsmLdsLessCachedPattern(MemoryPattern *mempat)
1614 {
1615     mempat->name = "2-staged cached global memory based block trsm";
1616     mempat->nrLevels = 2;
1617     mempat->cuLevel = 0;
1618     mempat->thLevel = 0;
1619     mempat->sops = &trsmSops;
1620 
1621     mpatExtra.aMset = CLMEM_LEVEL_L1;
1622     mpatExtra.bMset = CLMEM_LEVEL_L1;
1623     mpatExtra.mobjA = CLMEM_BUFFER;
1624     mpatExtra.mobjB = CLMEM_BUFFER;
1625     mempat->extra = &mpatExtra;
1626 }
1627 
#if 0

/*
 * Disabled draft of a default-decomposition provider.
 * NOTE(review): declared as returning int but contains no return
 * statement — must be fixed before this block is ever re-enabled.
 */
static int
getDefaultDecomp(
    PGranularity *pgran,
    SubproblemDim *subdims,
    unsigned int subdimsNum,
    void * pArgs)
{
    pgran->wgDim = 1;
    pgran->wgSize[0] = 64;
    pgran->wgSize[1] = 1;

    subdims[0].x = subdims[0].itemX = 32;
    subdims[0].y = 64;
    subdims[0].itemY = SUBDIM_UNUSED;
    subdims[0].bwidth = subdims[1].bwidth = 4;
    subdims[1].x = subdims[1].itemX = 8;
    subdims[1].y = subdims[1].itemY = 4;
}

#endif
1650