1 /* ************************************************************************
2  * Copyright 2013 Advanced Micro Devices, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  * ************************************************************************/
16 
17 
18 /*
19  * TRSM generator with support of cached reads from the global memory
20  */
21 
22 #include <string.h>
23 #include <stdio.h>
24 #include <assert.h>
25 #include <clblas_stddef.h>
26 #include <clBLAS.h>
27 #include <blas_mempat.h>
28 #include <clkern.h>
29 #include <clblas-internal.h>
30 #include <matrix_props.h>
31 #include <matrix_dims.h>
32 
33 #include "../blas_kgen.h"
34 #include "../trxm_common.h"
35 #include "trsm_kgen_legacy.h"
36 #include "gen_helper_legacy.h"
37 #include "../trsm_kgen.h"
38 
/*
 * Template for reading a square diagonal block of A into local memory.
 * Filled in with: clamped block height/width expressions, the optimized
 * copy function name, the destination buffer suffix character, the
 * zeroing function, the generic (slow) copy function and the local
 * pitch.  The slow path zeroes the destination first so out-of-range
 * elements of a partial block do not contribute garbage.
 */
static const char *readSquareBlock =
    "y = (currM + %lu <= M) ? %lu : M - currM;\n"
    "x = (k0 + %lu <= M) ? %lu : M - k0;\n"
    "if ((y == %lu) && (x == %lu)) {\n"
    // just read with an optimized function
    "    %s((LPtr)temp%c, (GPtr)A, currM, k0, lda);\n"
    "}\n"
    "else {\n"
    "    %s((__local float4*)temp%c);\n"           // zeroing
    "    barrier(CLK_LOCAL_MEM_FENCE);\n"
    "    %s((LPtr)temp%c, (GPtr)A, currM, k0, y, x, %lu, lda);\n"
    "}\n\n";
51 
/*
 * Variant of readSquareBlock for the no-tails case: the block is
 * always full, so only the optimized copy call is emitted.
 */
static const char *readSquareBlockOpt =
    // just read with an optimized function
    "%s((LPtr)temp%c, (GPtr)A, currM, k0, lda);\n";
55 
/*
 * Transposing counterpart of readSquareBlock, selected when matrix A
 * is accessed in column-major order: the (row, column) start
 * coordinates passed to the copy functions are swapped.
 */
static const char *readSquareBlockTrans =
    "y = (currM + %lu <= M) ? %lu : M - currM;\n"
    "x = (k0 + %lu <= M) ? %lu : M - k0;\n"
    "if ((y == %lu) && (x == %lu)) {\n"
    // read and transpose with an optimized function
    "    %s((LPtr)temp%c, (GPtr)A, k0, currM, lda);\n"
    "}\n"
    "else {\n"
    "    %s((__local float4*)temp%c);\n"           // zeroing
    "    barrier(CLK_LOCAL_MEM_FENCE);\n"
    // read and transpose with slow function
    "    %s((LPtr)temp%c, (GPtr)A, k0, currM, x, y, %lu, lda);\n"
    "}\n\n";
69 
/*
 * Transposing variant of readSquareBlockOpt for the no-tails,
 * column-major case.
 */
static const char *readSquareBlockTransOpt =
    // read and transpose with an optimized function
    "%s((LPtr)temp%c, (GPtr)A, k0, currM, lda);\n";
73 
74 static CLBLASMpatExtra mpatExtra;
75 
76 static ssize_t
77 generator(
78    char *buf,
79    size_t buflen,
80    const struct SubproblemDim *subdims,
81    const struct PGranularity *pgran,
82    void *extra);
83 
84 static bool
85 isFitToLDS(
86     SubproblemDim *dim,
87     DataType dtype,
88     cl_ulong ldsSize,
89     const void *kernelArgs);
90 
91 static SolverFlags
92 solverFlags(void);
93 
94 static void
95 assignKargs(KernelArg *args, const void *params, const void *extra);
96 
97 static void
98 fixupArgs(void *args, SubproblemDim *subdims, void *extra);
99 
/*
 * Solver operations table for the cached TRSM pattern.  Only the
 * source generator, kernel argument assignment, LDS fitting check,
 * solver flags and argument fixup hooks are provided; the remaining
 * slots are unused by this pattern.
 * NOTE(review): slot meanings follow the SolverOps declaration order;
 * only the two commented slots are identified here — confirm the rest
 * against the SolverOps definition.
 */
static SolverOps trsmSops = {
    generator,
    assignKargs,
    isFitToLDS,
    NULL,
    NULL,
    NULL,
    NULL,
    solverFlags,
    fixupArgs,
    NULL, //getDefaultDecomp
   	NULL, // getDecompList
   	NULL,
   	NULL
};
115 
116 static TileMulFlags
getCyclicFlags(const SubproblemDim * dim,KernelExtraFlags kflags,bool tailPass,unsigned int vecLen)117 getCyclicFlags(
118     const SubproblemDim *dim,
119     KernelExtraFlags kflags,
120     bool tailPass,
121     unsigned int vecLen)
122 {
123     TileMulFlags mflags = TILEMUL_NO_FLAGS;
124 
125     if (tailPass && !isMatrixUpper(kflags)) {
126         mflags |= TILEMUL_GLOBAL_CYCLIC_A;
127     }
128 
129     if (isMatrixAccessColMaj(CLBLAS_TRSM, kflags, MATRIX_B) &&
130         (kflags & KEXTRA_TAILS_N) && (dim->x > vecLen)) {
131 
132         mflags |= TILEMUL_GLOBAL_CYCLIC_B;
133     }
134 
135     return mflags;
136 }
137 
/*
 * Initialize the private tile descriptors (tileA, tileBX, tileCY) in
 * the generator settings.  The tiles are sized to the maximum of the
 * shapes they must hold over the whole pass, since the same private
 * storage is reused for different stages.
 */
static void
initTiles(BlasGenSettings *gset)
{
    unsigned int nrRows, nrCols;
    unsigned int vecLen;
    const SubproblemDim *dim = &gset->subdims[1];
    const CLBLASKernExtra *kextra = gset->kextra;
    DataType dtype = kextra->dtype;
    bool tra;

    // the tile A should be able to fit rectangular and square tiles
    nrCols = (unsigned int)szmax(dim->y, dim->bwidth);
    tra = isMatrixAccessColMaj(CLBLAS_TRSM, kextra->flags, MATRIX_A);
    vecLen = getVecLen(gset, CLBLAS_TRSM, MATRIX_A);
    initTile(&gset->tileA, "a", (unsigned int)dim->y, nrCols, vecLen,
             dtype, PRIV_STORAGE_ARRAY, tra, false);

    /*
     * tile B should be able to fit tiles of the matrix B and of the
     * intermediate result. That result will be always transposed
     * from the point of view of tile multiplication
     */
    tra = !isMatrixAccessColMaj(CLBLAS_TRSM, kextra->flags, MATRIX_B);
    if (tra) {
        nrRows = (unsigned int)szmax(dim->bwidth, dim->y);
        nrCols = (unsigned int)dim->x;
    }
    else {
        nrRows = (unsigned int)szmax(dim->bwidth, dim->x);
        nrCols = (unsigned int)szmax(dim->x, dim->y);
    }
    vecLen = getVecLen(gset, CLBLAS_TRSM, MATRIX_B);
    initTile(&gset->tileBX, "b", nrRows, nrCols, vecLen, dtype,
             PRIV_STORAGE_ARRAY, tra, false);

    // the result tile is never transposed
    initTile(&gset->tileCY, "c", (unsigned int)dim->y, (unsigned int)dim->x,
             vecLen, dtype, PRIV_STORAGE_ARRAY, false, false);
}
176 
177 static void
prepareTilesForMainLoop(BlasGenSettings * gset)178 prepareTilesForMainLoop(BlasGenSettings *gset)
179 {
180     const SubproblemDim *dim = &gset->subdims[1];
181 
182     gset->tileA.nrCols = (unsigned int)dim->bwidth;
183     gset->tileBX.nrRows = (unsigned int)dim->bwidth;
184     gset->tileBX.nrCols = (unsigned int)dim->x;
185 }
186 
/*
 * Emit the kernel-scope variable declarations: private tile storages,
 * work item identifiers, typed union pointers into global and local
 * buffers, the LDS areas for the diagonal block (tempA) and the
 * intermediate/inverted block (tempC), and the loop counters.
 *
 * NOTE(review): "%lu" is used for size_t values here and throughout
 * the file; this assumes a target where they have matching width
 * (LP64) — confirm for other ABIs.
 */
static void
declareLocalVariables(
    struct KgenContext *ctx,
    const BlasGenSettings *gset)
{
    char tmp[1024];
    const char *elemType;
    const SubproblemDim *dims = gset->subdims;
    DataType dtype = gset->kextra->dtype;
    size_t pitchAC, heightC;

    elemType = dtypeBuiltinType(dtype);
    pitchAC = matrBlockPitch(dims, MATRIX_C, dtype, clblasRight);
    // tempC must fit both the square block of A and the panel result
    heightC = szmax(dims[0].y, dims[0].x);

    declareTileStorages(ctx, gset);
    sprintf(tmp, "const int lid = get_local_id(0);\n"
                 "const int gid = get_group_id(0);\n"
                 "const uint2 skewRow = 0, skewCol = 0;\n\n"
                 "GPtr uA, uB;\n"
                 "uint coordA, coordB, k;\n"
                 "uint x, y;\n"
                 "__local %s tempA[%lu], tempC[%lu];\n"
                 "LPtr utmpA, utmpC;\n"
                 "uint m0 = 0, k0, currM, currN;\n",
            elemType, pitchAC * dims[0].y, pitchAC * heightC);
    kgenAddStmt(ctx, tmp);
}
215 
/*
 * Emit code reading the current square diagonal block of A into the
 * local buffer temp<c>.  When M has no tails the optimized copy
 * template is used unconditionally; otherwise the generated code
 * falls back to a zero-then-copy generic path for partial blocks.
 * The transposing templates are selected for column-major access
 * to A.
 */
static void
genReadDiagBlock(
    struct KgenContext *ctx,
    const SubproblemDim *dim,
    DataType dtype,
    const CopyBufFuncs *copyFuncs,
    const ZeroFuncs *zeroFuncs,
    KernelExtraFlags kflags,
    char c)
{
    char tmp[1024];
    size_t pitch;
    const char *readBlock;
    bool tra;

    tra = isMatrixAccessColMaj(CLBLAS_TRSM, kflags, MATRIX_A);
    pitch = matrBlockPitch(dim, MATRIX_A, dtype, clblasLeft);

    if (!(kflags & KEXTRA_TAILS_M)) {
        // full blocks only: the optimized template takes just the
        // copy function name and the destination suffix
        readBlock = (tra) ? readSquareBlockTransOpt : readSquareBlockOpt;
        sprintf(tmp, readBlock, copyFuncs->read[MATRIX_A], c);
    }
    else {
        // argument order must match the %-conversions of the template
        readBlock = (tra) ? readSquareBlockTrans : readSquareBlock;
        sprintf(tmp, readBlock, dim->y, dim->y, dim->bwidth, dim->bwidth,
                dim->y, dim->bwidth, copyFuncs->read[MATRIX_A], c,
                zeroFuncs->names[MATRIX_A], c,
                copyFuncs->readGeneric[MATRIX_A], c, pitch);
    }
    kgenAddStmt(ctx, tmp);
}
247 
248 static void
genZeroResult(struct KgenContext * ctx,DataType dtype,const SubproblemDim * dims,unsigned int vecLen)249 genZeroResult(
250     struct KgenContext *ctx,
251     DataType dtype,
252     const SubproblemDim *dims,
253     unsigned int vecLen)
254 {
255     unsigned int n;
256     char tmp[1024];
257 
258     getResultGPRsInfo(dtype, &dims[1], vecLen, &n, NULL);
259 
260     sprintf(tmp, "for (x = 0; x < %u; x++) {\n"
261                  "    c[x] = 0;\n"
262                  "}\n\n", n);
263 
264     kgenAddStmt(ctx, tmp);
265 }
266 
267 static void
genInternalLoopCtl(struct KgenContext * ctx,const SubproblemDim * dim,KernelExtraFlags kflags)268 genInternalLoopCtl(
269     struct KgenContext *ctx,
270     const SubproblemDim *dim,
271     KernelExtraFlags kflags)
272 {
273     char tmp[1024];
274 
275     if (isMatrixUpper(kflags)) {
276         if (kflags & KEXTRA_TAILS_M) {
277             sprintf(tmp, "for (k0 = currM + %lu; k0 < M / %lu * %lu; "
278                                "k0 += %lu)",
279                     dim[0].bwidth, dim[1].bwidth, dim[1].bwidth, dim[1].bwidth);
280         }
281         else {
282             sprintf(tmp, "for (k0 = currM + %lu; k0 < M; k0 += %lu)",
283                     dim[0].bwidth, dim[1].bwidth);
284         }
285     }
286     else {
287         sprintf(tmp, "for (k0 = 0; k0 < currM; k0 += %lu)",
288                 dim[1].bwidth);
289     }
290 
291     kgenBeginBranch(ctx, tmp);
292 }
293 
294 static void
genInitCurrM(struct KgenContext * ctx,const SubproblemDim * dim,KernelExtraFlags kflags)295 genInitCurrM(
296     struct KgenContext *ctx,
297     const SubproblemDim *dim,
298     KernelExtraFlags kflags)
299 {
300     char tmp[1024];
301 
302     if (isMatrixUpper(kflags)) {
303         /* start from the last block */
304         sprintf(tmp, "currM = ((M - 1) / %lu) * %lu;\n", dim->y, dim->y);
305         kgenAddStmt(ctx, tmp);
306     }
307     else {
308         kgenAddStmt(ctx, "currM = 0;\n");
309     }
310 }
311 
/*
 * Bind the variable names the generator uses in the emitted kernel
 * source.  For TRSM the square matrix size M serves both as the row
 * count and as the multiplication depth K.
 */
static void
initKernelVarNames(KernelVarNames *kvars)
{
    // typed union pointers to the matrices
    kvars->A = "uA";
    kvars->B = "uB";
    // per-work-item coordinates and depth counter
    kvars->coordA = "coordA";
    kvars->coordB = "coordB";
    kvars->k = "k";
    // problem sizes: K coincides with M for TRSM
    kvars->sizeM = "M";
    kvars->sizeN = "N";
    kvars->sizeK = "M";
    // leading dimensions
    kvars->lda = "lda";
    kvars->ldb = "ldb";
}
326 
/*
 * Generate a code copying tile between LDS and private location.
 *
 * Emits an inline, column-major store of the private result tile into
 * the tempC local buffer (without applying alpha), so the subsequent
 * multiplication on the inverted diagonal block can read it from
 * local memory.  Each work item writes to its own panel, addressed
 * by 'lid'.  The copy is wrapped in its own scope to keep the
 * temporaries local.
 */
static void
genLdsCopy(
    struct KgenContext *ctx,
    const BlasGenSettings *gset)
{
    char pitchStr[16];
    char coordY[128], coordX[128];
    size_t pitch;
    UpresVarNames uvars;
    UpdateResultFlags upFlags = UPRES_INLINE | UPRES_USE_LDS |
                                UPRES_WITHOUT_ALPHA | UPRES_COLUMN_MAJOR;
    const SubproblemDim *dims = gset->subdims;
    unsigned int l1Pans = (unsigned int)(dims[0].x / dims[1].x);

    memset(&uvars, 0, sizeof(uvars));

    pitch = matrBlockPitch(dims, MATRIX_C, gset->kextra->dtype, clblasRight);
    sprintf(pitchStr, "%lu", pitch);
    // per-work-item start position inside tempC
    sprintf(coordY, "lid / %u * %lu", l1Pans, dims[1].y);
    sprintf(coordX, "lid %% %u * %lu", l1Pans, dims[1].x);
    uvars.result = "tempC";
    uvars.ld = pitchStr;
    uvars.startRow = coordY;
    uvars.startCol = coordX;
    uvars.nrRows = NULL;
    uvars.nrCols = NULL;

    kgenBeginBranch(ctx, NULL);

    updateResultGen(ctx,
        gset,
        CLBLAS_TRSM,
        UPRES_SET,
        upFlags,
        &uvars);

    kgenEndBranch(ctx, NULL);

    kgenAddBlankLine(ctx);
}
370 
/*
 * Emit code zeroing the out-of-range rows of the private result tile
 * when the current block extends past M (tail pass).  The generated
 * code computes how many bottom rows fall beyond the matrix bound and
 * clears them register by register, so they do not pollute the
 * multiplication on the inverted block.  Complex types are handled
 * with vector length 1.
 */
static void
genZeroResultTrash(
    struct KgenContext *ctx,
    const SubproblemDim *dim,
    const CLBLASKernExtra *kextra)
{
    char tmp[1024];
    unsigned int vecLen, pitch;
    unsigned int i;

    vecLen = (isComplexType(kextra->dtype)) ? 1 : kextra->vecLen;
    // row pitch of the tile in vectorized elements
    pitch = (unsigned int)roundUp(dim->x, vecLen);
    sprintf(tmp, "if (coordA + %lu > M)", dim->y);
    kgenBeginBranch(ctx, tmp);
    // number of rows lying beyond the bound of M
    sprintf(tmp, "int i = (coordA >= M) ? %lu : (%lu - M %% %lu);\n\n",
            dim->y, dim->y, dim->y);
    kgenAddStmt(ctx, tmp);
    sprintf(tmp, "for (; i > 0; i--)");
    kgenBeginBranch(ctx, tmp);

    // clear one full tile row per generated loop iteration
    for (i = 0; i < pitch / vecLen; i++) {
        sprintf(tmp, "c[(%lu - i) * %u + %u] = 0;\n",
                dim->y, pitch / vecLen, i);
        kgenAddStmt(ctx, tmp);
    }

    kgenEndBranch(ctx, NULL);
    kgenEndBranch(ctx, NULL);
}
400 
/*
 * Vendor-dependent selection of the UPRES_INDEXING_WITH_CONSTANTS
 * optimization for the result update.  The bugFlagN values are exact
 * kernel flag combinations known to interact with the optimization.
 *
 * NOTE(review): in the non-AMD branch forceBug is true when kflags
 * matches NONE of the listed combinations, i.e. the optimization is
 * enabled only for the exact combinations named bugFlag* — the naming
 * suggests the inverse may have been intended; confirm against the
 * project history before touching this.
 */
static void
setupVdepUpresFlags(KernelExtraFlags kflags, UpdateResultFlags* upFlags)
{
    bool forceBug = false;

    unsigned int bugFlag1 = KEXTRA_NO_COPY_VEC_A
                          | KEXTRA_TAILS_K
                          | KEXTRA_TAILS_M;
    unsigned int bugFlag2 = bugFlag1
                          | KEXTRA_UPPER_TRIANG
                          | KEXTRA_TRANS_A;
    unsigned int bugFlag3 = bugFlag1
                          | KEXTRA_SIDE_RIGHT
                          | KEXTRA_COLUMN_MAJOR;
    unsigned int bugFlag4 = bugFlag3
                          | KEXTRA_TRANS_A;
    unsigned int bugFlag5 = bugFlag3
                          | KEXTRA_UPPER_TRIANG;
    unsigned int bugFlag6 = KEXTRA_NO_COPY_VEC_A
                          | KEXTRA_NO_COPY_VEC_B
                          | KEXTRA_NO_COPY_VEC_C
                          | KEXTRA_TAILS_K
                          | KEXTRA_TAILS_M;
    unsigned int bugFlag7 = bugFlag6
                          | KEXTRA_COLUMN_MAJOR;
    unsigned int bugFlag8 = bugFlag6
                          | KEXTRA_SIDE_RIGHT
                          | KEXTRA_UPPER_TRIANG;
    unsigned int bugFlag9 = bugFlag6
                          | KEXTRA_UPPER_TRIANG
                          | KEXTRA_TRANS_A
                          | KEXTRA_TAILS_N;
    unsigned int bugFlag10 = bugFlag7
                           | KEXTRA_SIDE_RIGHT
                           | KEXTRA_TRANS_A
                           | KEXTRA_TAILS_N;
    unsigned int bugFlag11 = bugFlag9
                           | KEXTRA_UNIT_DIAGONAL;
    unsigned int bugFlag12 = bugFlag6
                           | KEXTRA_TAILS_N
                           | KEXTRA_SIDE_RIGHT
                           | KEXTRA_UNIT_DIAGONAL
                           | KEXTRA_COLUMN_MAJOR
                           | KEXTRA_TRANS_A;

    /*
     * WORKAROUND for AMD GPU: Now, we avoid optimizing the case when
     *                         matrix B is not divided on block size and
     *                         since it leads to a hang up at code seeming
     *                         correct.
     */
    if (kflags & KEXTRA_VENDOR_AMD) {
        forceBug = (kflags & KEXTRA_TAILS_N) != 0;
    }
    else {
        forceBug = (kflags != bugFlag1
            && kflags != bugFlag2 && kflags != bugFlag4 &&  kflags != bugFlag5
            && kflags != bugFlag7 && kflags != bugFlag8 &&  kflags != bugFlag9
            && kflags != bugFlag10 && kflags != bugFlag11
            && kflags != bugFlag12);
    }

    if (!forceBug) {
        *upFlags |= UPRES_INDEXING_WITH_CONSTANTS;
    }
}
467 
468 static void
genSetupCoordinates(struct KgenContext * ctx,const SubproblemDim * dims,KernelExtraFlags kflags)469 genSetupCoordinates(
470     struct KgenContext *ctx,
471     const SubproblemDim *dims,
472     KernelExtraFlags kflags)
473 {
474     char tmp[1024];
475     unsigned int l1Pans = (unsigned int)(dims[0].x / dims[1].x);
476 
477     sprintf(tmp, "coordA = currM + lid / %u * %lu;\n", l1Pans, dims[1].y);
478     kgenAddStmt(ctx, tmp);
479     if (isMatrixUpper(kflags)) {
480         sprintf(tmp, "k = currM + %lu;\n", dims[0].y);
481     }
482     else {
483         strcpy(tmp, "k = 0;\n");
484     }
485     kgenAddStmt(ctx, tmp);
486 }
487 
/*
 * Emit code inverting the current diagonal block.  The source block
 * has already been read into tempC (see genReadDiagBlock); the
 * inverse is built in tempA, which is zeroed first.  For the unit
 * diagonal case the diagonal of the source block is forced to one
 * before the inversion.  Only the first subdims[0].y work items run
 * 'invert'; partial blocks pass the clamped size M - currM.
 */
static void
genInvertDiagBlock(
    struct KgenContext *ctx,
    const BlasGenSettings *gset,
    const ZeroFuncs *zeroFuncs)
{
    char tmp[1024];
    const CLBLASKernExtra *kextra = gset->kextra;
    const SubproblemDim *subdims = gset->subdims;
    size_t pitchA;

    pitchA = matrBlockPitch(subdims, MATRIX_A, kextra->dtype, clblasLeft);

    // clear the destination buffer for the inverse
    sprintf(tmp, "%s((__local float4*)tempA);\n", zeroFuncs->names[MATRIX_A]);
    kgenAddStmt(ctx, tmp);
    kgenAddBarrier(ctx, CLK_LOCAL_MEM_FENCE);

    if (kextra->flags & KEXTRA_UNIT_DIAGONAL) {
        // force ones on the diagonal of the source block
        sprintf(tmp, "if (lid < %lu) {\n"
                     "    tempC[lid * %lu + lid] = %s;\n"
                     "}\n",
                subdims[0].bwidth, pitchA, strOne(kextra->dtype));
        kgenAddStmt(ctx, tmp);
        kgenAddBarrier(ctx, CLK_LOCAL_MEM_FENCE);
        kgenAddBlankLine(ctx);
    }

    sprintf(tmp, "if (lid < %lu)", subdims[0].y);
    kgenBeginBranch(ctx, tmp);
    sprintf(tmp, "invert(tempC, tempA, lid, (currM + %lu > M) ? "
                         "M - currM : %lu);\n",
            subdims[0].y, subdims[0].y);
    kgenAddStmt(ctx, tmp);
    kgenEndBranch(ctx, NULL);
    kgenAddBarrier(ctx, CLK_LOCAL_MEM_FENCE);
    kgenAddBlankLine(ctx);
}
525 
526 static void
genMulOnDiagBlock(struct KgenContext * ctx,BlasGenSettings * gset,const TileMulOpts * mulOpts)527 genMulOnDiagBlock(
528     struct KgenContext *ctx,
529     BlasGenSettings *gset,
530     const TileMulOpts *mulOpts)
531 {
532     char tmp[1024];
533     const SubproblemDim *dims = gset->subdims;
534     const CLBLASKernExtra *kextra = gset->kextra;
535     unsigned int l1Pans = (unsigned int)(dims[0].x / dims[1].x);
536     TileMulOpts optsNew;
537     size_t pitchAC;
538     const char *ptrName;
539     Tile *tile;
540     BlasGenSettings gsetNew;
541 
542     pitchAC = matrBlockPitch(dims, MATRIX_C, kextra->dtype, clblasRight);
543     ptrName = dtypeUPtrField(kextra->dtype);
544 
545     memcpy(&optsNew, mulOpts, sizeof(optsNew));
546     optsNew.memA = CLMEM_LOCAL_MEMORY;
547     optsNew.memB = CLMEM_LOCAL_MEMORY;
548     optsNew.flags &= ~(TILEMUL_TRA | TILEMUL_GLOBAL_CYCLIC | TILEMUL_CONJA);
549     optsNew.flags |= TILEMUL_TRB;
550     optsNew.memA = CLMEM_LOCAL_MEMORY;
551     optsNew.memB = CLMEM_LOCAL_MEMORY;
552     gset->varNames.A = "utmpA";
553     gset->varNames.B = "utmpC";
554 
555     sprintf(tmp, "utmpA.%s = tempA + lid / %u * %lu;\n"
556                  "utmpC.%s = tempC + lid %% %u * %lu;\n\n",
557             ptrName, l1Pans, pitchAC * dims[1].y,
558             ptrName, l1Pans, pitchAC * dims[1].x);
559     kgenAddStmt(ctx, tmp);
560 
561     memcpy(&gsetNew, gset, sizeof(gsetNew));
562     gsetNew.subdims[1].bwidth = dims[1].y;
563 
564     // Configure the tile descriptors to deal with tile of needed sizes.
565     tile = &gsetNew.tileA;
566     tile->nrRows = (unsigned int)dims[1].y;
567     tile->nrCols = (unsigned int)dims[1].y;
568     tile->trans = false;
569     tile = &gsetNew.tileBX;
570     tile->nrRows = (unsigned int)dims[1].y;
571     tile->nrCols = (unsigned int)dims[1].x;
572     tile->trans = true;
573     tileMulGen(ctx, &gsetNew, &optsNew);
574 
575     gset->varNames.A = "uA";
576     gset->varNames.B = "uB";
577 }
578 
/*
 * Generate one complete TRSM pass over the panel of B owned by the
 * work group: either the main loop over whole blocks of M
 * (isTailPass == false) or the single tail pass for the remainder
 * of M (isTailPass == true).
 *
 * Each block row is processed in three phases: (1) accumulate the
 * update sum over the already solved part of the matrix, (2) apply
 * the intermediate result update and multiply on the inverted
 * diagonal block kept in LDS, (3) write the solved tile back.
 */
static void
genOneTrsmPass(
    struct KgenContext *ctx,
    BlasGenSettings *gset,
    const char *updateResFnRev,
    const char *updateResGenericFnRev,
    CopyBufFuncs *copyFuncs,
    ZeroFuncs *zeroFuncs,
    bool isTailPass)
{
    const CLBLASKernExtra *kextra = gset->kextra;
    CLBLASKernExtra kextraTmp;
    KernelExtraFlags kflags = kextra->flags;
    char tmp[1024];
    DataType dtype = kextra->dtype;
    unsigned int vecLen = gset->kextra->vecLen;
    SubproblemDim *subdims = gset->subdims;
    int tra, trb;
    UpdateResultFlags upFlags;
    TilePostFetchPrivate pfpriv;
    TileMulOpts mulOpts;
    TailFetch tf;
    TailStatus tailStatus = 0;

    memset(&pfpriv, 0, sizeof(pfpriv));

    // multiply options
    mulOpts.memA = CLMEM_GLOBAL_MEMORY;
    mulOpts.memB = CLMEM_GLOBAL_MEMORY;
    mulOpts.core = TILEMUL_MAD;//TILEMUL_MULADD;
    mulOpts.postFetch = NULL;
    mulOpts.flags = kextraToTilemulFlags(CLBLAS_TRSM, kflags);
    mulOpts.flags |= TILEMUL_EXTERN_RDECL;
    mulOpts.flags |= getCyclicFlags(subdims, kflags, isTailPass, vecLen);

    tra = isMatrixAccessColMaj(CLBLAS_TRSM, kflags, MATRIX_A);
    trb = isMatrixAccessColMaj(CLBLAS_TRSM, kflags, MATRIX_B);

    // determine which tail fetches the multiplier must guard against
    tf = checkForTailFetches(CLBLAS_TRSM, &subdims[1], kextra, MATRIX_B,
                             false, false);
    if (trb) {
        tf &= ~FETCH_TAIL_COL;
    }

    /*
     * For lower triangular matrix we proceed upto the diagonal, so we
     * can't exceed matrix bound and zeroing is not needed
     */
    if (isMatrixUpper(kflags)) {
        tf |= checkForTailFetches(CLBLAS_TRSM, &subdims[1], kextra,
                                  MATRIX_A, false, false);
        if (tra && trb) {
            tf &= ~FETCH_TAIL_COL;
        }
    }

    if (tf != FETCH_NO_TAILS) {
        memset(&pfpriv, 0, sizeof(pfpriv));
        pfpriv.funcID = CLBLAS_TRSM;
        pfpriv.gset = gset;
    }

    // loop over M
    if (!isTailPass) {
        sprintf(tmp, "for (m0 = 0; m0 < M / %lu * %lu; m0 += %lu)",
                subdims->y, subdims->y, subdims->y);
        kgenBeginBranch(ctx, tmp);
    }

    genSetupCoordinates(ctx, subdims, kflags);
    genZeroResult(ctx, dtype, subdims, vecLen);

    if (!isMatrixUpper(kflags) && isTailPass) {
        // skip the update loop if the matrix consists of a single block
        sprintf(tmp, "if (M > %lu)", subdims->y);
        kgenBeginBranch(ctx, tmp);
    }

    // Avoid tail adjusting along M.

    memcpy(&kextraTmp, kextra, sizeof(kextraTmp));
    kextraTmp.flags &= ~(KEXTRA_TAILS_M | KEXTRA_TAILS_M_LOWER);

    // update loop is not needed for tail of an upper triangular matrix
    if (!(isTailPass && isMatrixUpper(kflags))) {
        if (isTailPass || (kflags & KEXTRA_TAILS_N)) {
            kgenBeginBranch(ctx, "if (coordB < N)");
        }

        gset->kextra = &kextraTmp;
        tailStatus = checkGenAdjustTailCoords(ctx, CLBLAS_TRSM, gset, NULL);
        gset->kextra = kextra;

        genInternalLoopCtl(ctx, subdims, kflags);           // loop over K

        // multiplication for the step-by-step block updating
        subdims[0].bwidth = subdims[1].bwidth;
        tileMulGen(ctx, gset, &mulOpts);
        subdims[0].bwidth = subdims[0].y;

        genInternalLoopEnd(ctx);                             // loop over K
        kgenAddBlankLine(ctx);

        // invoke once again, in order to process tails along K
        if (isMatrixUpper(kflags) && (tf != FETCH_NO_TAILS)) {
            subdims[0].bwidth = subdims[1].bwidth;

            if (!(tra && trb)) {
                mulOpts.flags |= TILEMUL_WRAP_AROUND_TAIL;
            }
            mulOpts.flags |= TILEMUL_GLOBAL_CYCLIC_K;

            mulOpts.postFetchPriv = &pfpriv;
            mulOpts.postFetch = defaultTilePostFetch;

            subdims[0].bwidth = subdims[1].bwidth;
            tileMulGen(ctx, gset, &mulOpts);
            subdims[0].bwidth = subdims[0].y;

            mulOpts.postFetch = NULL;
            mulOpts.postFetchPriv = NULL;
        }

        gset->kextra = &kextraTmp;
        checkGenRestoreTailCoords(ctx, gset, tailStatus);
        gset->kextra = kextra;

        if (isTailPass || (kflags & KEXTRA_TAILS_N)) {
            kgenEndBranch(ctx, NULL);
        }
    }
    else if (!trb && (kflags & KEXTRA_TAILS_N)) {
        tailStatus |= TAIL_B_RAISED;
    }

    mulOpts.flags &= ~(TILEMUL_WRAP_AROUND_TAIL | TILEMUL_GLOBAL_CYCLIC_A |
                       TILEMUL_GLOBAL_CYCLIC_K);

    if (!isMatrixUpper(kflags) && isTailPass) {
        /*
         * end of branch for non single block tail processing of
         * the lower triangular matrix
         */
        kgenEndBranch(ctx, NULL);
    }

    /*
     * Final phase: update the accumulated result, multiply on an inverted
     *              block and write back the result
     */
    if (isMatrixUpper(kflags) || ((kflags & KEXTRA_VENDOR_AMD) != 0)) {
        kgenAddStmt(ctx, "k0 = currM;\n");
    }
    else {
        kgenAddStmt(ctx, "k0 = m0;\n");
    }

    genReadDiagBlock(ctx, subdims, dtype, copyFuncs, zeroFuncs,
                     kflags, 'C');
    genInvertDiagBlock(ctx, gset, zeroFuncs);

    // Avoid generating not executed non optimal path
    gset->kextra = &kextraTmp;
    if (isTailPass) {
        kextraTmp.flags |= (KEXTRA_TAILS_M | KEXTRA_TAILS_M_LOWER);
    }
    genUpdateIntermTrsmResult(ctx, gset, updateResFnRev,
                              updateResGenericFnRev, true);
    gset->kextra = kextra;

    /*
     * Heap to LDS.
     * Zero unuseful part along columns since it will have an influence
     * on the result at multiplication on an inverted block
     */
    if (isTailPass) {
        genZeroResultTrash(ctx, &subdims[1], kextra);
    }
    genLdsCopy(ctx, gset);
    kgenAddBarrier(ctx, CLK_LOCAL_MEM_FENCE);
    genZeroResult(ctx, dtype, subdims, vecLen);

    genMulOnDiagBlock(ctx, gset, &mulOpts);

    // write back the tile evaluated
    upFlags = kextraToUpresFlags(CLBLAS_TRSM, kflags);
    upFlags |= tailStatusToUpresFlags(tailStatus);
    upFlags |= UPRES_EXCEED_PROBLEM_CONDITION;
    setupVdepUpresFlags(kflags, &upFlags);

    gset->kextra = &kextraTmp;

    genResultUpdateWithFlags(ctx, CLBLAS_TRSM, gset, upFlags,
                             NULL, NULL, NULL);
    gset->kextra = kextra;

    kgenAddBarrier(ctx, CLK_GLOBAL_MEM_FENCE);

    // advance to the next block row of A
    if (isMatrixUpper(kflags)) {
        sprintf(tmp, "currM -= %lu;\n", subdims[0].y);
    }
    else {
        sprintf(tmp, "currM += %lu;\n", subdims[0].y);
    }
    kgenAddStmt(ctx, tmp);

    if (!isTailPass) {
        kgenEndBranch(ctx, NULL);                       // loop over M
    }
}
789 
/*
 * Main kernel source generator for the cached TRSM implementation.
 *
 * Emits, in order: the typed union pointer declarations, complex math
 * helpers (for complex types), the intermediate result update
 * functions, the buffer copy and zeroing helpers, the diagonal block
 * inversion function, and finally the kernel body itself which runs
 * one or two TRSM passes (the main pass over whole blocks of M and,
 * when M has a tail, a tail pass; their order depends on the
 * triangularity).
 *
 * Returns the generated source size + 1 on success, -EINVAL for an
 * unsupported work group dimensionality, -ENOMEM if a generator
 * context cannot be created, or -EOVERFLOW if the buffer was too
 * small.
 */
static ssize_t
generator(
   char *buf,
   size_t buflen,
   const struct SubproblemDim *subdims,
   const struct PGranularity *pgran,
   void *extra)
{
    char tmp[1024];
    struct KgenContext *ctx;
    CLBLASKernExtra *kextra = (CLBLASKernExtra*)extra;
    KernelExtraFlags kflags = kextra->flags;
    DataType dtype = kextra->dtype;
    BlasGenSettings gset;
    char updateResFnRev[FUNC_NAME_MAXLEN];
    char updateResGenericFnRev[FUNC_NAME_MAXLEN];
    CopyBufFuncs copyFuncs;
    ZeroFuncs zeroFuncs;
    UpdateResultFlags upFlags;
    const char *ptrName;
    bool b;
    ssize_t ret;
    unsigned int l1Pans = (unsigned int)(subdims[0].x / subdims[1].x);
    bool tailMarker[2] = {false, true};
    int triang;
    int i;

    // only 1D work groups are supported by this generator
    if (pgran->wgDim != 1) {
        return -EINVAL;
    }

    // propagate tail flags to their "lower" counterparts
    if (kflags & KEXTRA_TAILS_M) {
        kflags |= KEXTRA_TAILS_M_LOWER;
    }
    if (kflags & KEXTRA_TAILS_N) {
        kflags |= KEXTRA_TAILS_N_LOWER;
    }
    if (kflags & KEXTRA_TAILS_K) {
        kflags |= KEXTRA_TAILS_K_LOWER;
    }
    kextra->flags = kflags;

    ctx = createKgenContext(buf, buflen, true);
    if (ctx == NULL) {
        return -ENOMEM;
    }

    triang = isMatrixUpper(kflags);

    memset(&gset, 0, sizeof(gset));
    memcpy(gset.subdims, subdims, sizeof(gset.subdims));
    gset.kextra = kextra;
    gset.pgran = pgran;

    initKernelVarNames(&gset.varNames);

    // declare typed union pointers; doubles need the double based variant
    b = isDoubleBasedType(dtype);
    kgenDeclareUptrs(ctx, b);
    if (isComplexType(dtype)) {
        genComplexMathOperators(ctx, dtype);
    }

    /*
     * For intermediate result after blocks modification.
     * Take into account tails adjusting
     */
    upFlags = kextraToUpresFlags(CLBLAS_TRSM, kflags);
    upFlags |= UPRES_WITH_BETA | UPRES_PRIV_DEST;

    if (!isMatrixAccessColMaj(CLBLAS_TRSM, kflags, MATRIX_B) &&
        (kflags & KEXTRA_TAILS_N)) {

        upFlags |= UPRES_TAIL_COL;
    }

    setupVdepUpresFlags(kflags, &upFlags);
    initTiles(&gset);
    genUpresFuncsWithFlags(ctx, &gset, upFlags, updateResFnRev,
                           updateResGenericFnRev);

    generateBufCopyFuncs(&copyFuncs, ctx, CLBLAS_TRSM, &gset, BCHF_MATRIX_A);
    generateZeroingFuncs(&zeroFuncs, ctx, &subdims[0], pgran, dtype,
                         ZF_MATRIX_A);

    //matrix inversion function
    genInvertingBlockFunc(ctx, subdims[0].bwidth, dtype, kflags);
    kgenAddBlankLine(ctx);

    // now, generate the kernel
    declareTrxmKernel(ctx, dtype, pgran, kflags, CLBLAS_TRSM, "Cached", false,
                      true);
    ret = kgenBeginFuncBody(ctx);

    declareLocalVariables(ctx, &gset);
    prepareTilesForMainLoop(&gset);

    sprintf(tmp, "currN = gid * %lu;\n", subdims[0].x);
    kgenAddStmt(ctx, tmp);
    genInitCurrM(ctx, subdims, kflags);

    if (kflags & KEXTRA_A_OFF_NOT_ZERO) {
        kgenAddStmt(ctx, "A += offA;\n");
    }
    genTrxmBMatrShift(ctx, kflags, false);

    ptrName = dtypeUPtrField(dtype);
    sprintf(tmp, "uA.%s = A;\n"
                 "uB.%s = B;\n\n",
            ptrName, ptrName);
    kgenAddStmt(ctx, tmp);

    /*
     * B matrix is divided on panels, each work group
     * multiply such a panel on the whole matrix A.
     */

    sprintf(tmp, "coordB = gid * %lu + lid %% %u * %lu;\n",
            subdims[0].x, l1Pans, subdims[1].x);
    kgenAddStmt(ctx, tmp);

    /*
     * Run the two passes; the tail pass (b == true) is generated only
     * when M actually has tails.  The upper triangular case starts
     * with the tail, the lower one ends with it.
     */
    for (i = 0; i < 2; i++) {
        b = (i) ? tailMarker[1 - triang] : tailMarker[triang];
        if (!b || (kflags & KEXTRA_TAILS_M)) {
            genOneTrsmPass(ctx, &gset, updateResFnRev, updateResGenericFnRev,
                           &copyFuncs, &zeroFuncs, b);
        }
    }

    kgenEndFuncBody(ctx);
    ret = kgenAddBlankLine(ctx);

    if (!ret) {
        ret = (ssize_t)kgenSourceSize(ctx) + 1;
    }

    destroyKgenContext(ctx);

    // a failed generator call means the output buffer was too small
    return (ret < 0) ? -EOVERFLOW : ret;
}
929 
930 static bool
isFitToLDS(SubproblemDim * dim,DataType dtype,cl_ulong ldsSize,const void * kernelArgs)931 isFitToLDS(
932     SubproblemDim *dim,
933     DataType dtype,
934     cl_ulong ldsSize,
935     const void *kernelArgs)
936 {
937     cl_ulong sizeA, sizeC;
938     const CLBlasKargs *kargs = (const CLBlasKargs*)kernelArgs;
939 
940     /*
941      * It's needed one block for matrix A,
942      * and one block of size maximal of this one for
943      * matrix A and matrix C
944      */
945 
946     sizeA = matrBlockSize(dim, MATRIX_A, dtype, kargs->side);
947     sizeC = matrBlockSize(dim, MATRIX_B, dtype, kargs->side);
948     if (sizeA > sizeC) {
949         sizeC = sizeA;
950     }
951 
952     return ((sizeA + sizeC) * dtypeSize(dtype) <= ldsSize);
953 }
954 
955 static SolverFlags
solverFlags(void)956 solverFlags(void)
957 {
958     return (SF_WSPACE_1D | SF_TOP_INPUT_SQUARE_BLOCKS);
959 }
960 
961 static void
assignKargs(KernelArg * args,const void * params,const void * extra)962 assignKargs(KernelArg *args, const void *params, const void *extra)
963 {
964     const CLBlasKargs *blasArgs = (CLBlasKargs*)params;
965     KernelExtraFlags kflags = ((const CLBLASKernExtra*)extra)->flags;
966     int idx = 7;
967 
968     initSizeKarg(&args[0], blasArgs->M);
969     initSizeKarg(&args[1], blasArgs->N);
970     assignScalarKarg(&args[2], &(blasArgs->alpha), blasArgs->dtype);
971     initMemobjKarg(&args[3], blasArgs->A, NULL, 0, 0);
972     initSizeKarg(&args[4], blasArgs->lda.matrix);
973     initMemobjKarg(&args[5], blasArgs->B, NULL, 0, 0);
974     initSizeKarg(&args[6], blasArgs->ldb.matrix);
975     if (kflags & KEXTRA_A_OFF_NOT_ZERO) {
976         initSizeKarg(&args[idx++], blasArgs->offA);
977     }
978     if (kflags & KEXTRA_BX_OFF_NOT_ZERO) {
979         initSizeKarg(&args[idx++], blasArgs->offBX);
980     }
981 }
982 
983 static void
fixupArgs(void * args,SubproblemDim * subdims,void * extra)984 fixupArgs(void *args, SubproblemDim *subdims, void *extra)
985 {
986     (void)extra;
987     (void)subdims;
988 
989     fixupTrxmKargs((CLBlasKargs*)args);
990 }
991 
992 void
initTrsmCachedPattern(MemoryPattern * mempat)993 initTrsmCachedPattern(MemoryPattern *mempat)
994 {
995     mempat->name = "Cached global memory based block trsm";
996     mempat->nrLevels = 2;
997     mempat->cuLevel = 0;
998     mempat->thLevel = 0;
999     mempat->sops = &trsmSops;
1000 
1001     mpatExtra.aMset = CLMEM_LEVEL_L1;
1002     mpatExtra.mobjA = CLMEM_BUFFER;
1003     mpatExtra.mobjB = CLMEM_BUFFER;
1004     mempat->extra = &mpatExtra;
1005 }
1006