1 /* ************************************************************************
2  * Copyright 2013 Advanced Micro Devices, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  * ************************************************************************/
16 
17 
18 /*
19  * gemm image based generators
20  */
21 
22 #include <string.h>
23 #include <stdio.h>
24 #include <math.h>
25 #include <clBLAS.h>
26 #include <matrix_dims.h>
27 #include <blas_mempat.h>
28 #include <clkern.h>
29 #include <clblas-internal.h>
30 #include <dis_warning.h>
31 
32 #include "blas_kgen_legacy.h"
33 #include "../gen_helper.h"
34 #include "gen_helper_legacy.h"
35 
/* Memory pattern extra data; filled in by initGemmImgPattern(). */
static CLBLASMpatExtra mpatExtra;

/*
 * printf template for the declaration of the kernel that copies
 * blocks of matrix A into an image. '%c' receives the BLAS type
 * prefix (s/d/c/z), '%s' the element type name.
 */
static const char *prepareImagesGemmDeclA =
    "void __kernel\n"
    "%cprepareImageA(\n"
    "    clblasOrder order,\n"
    "    clblasTranspose transA,\n"
    "    uint M,\n"
    "    uint K,\n"
    "    __global %s *A,\n"
    "    uint lda,\n"
    "    __write_only image2d_t imgA,\n"
    "    uint offsetA)\n";

/*
 * printf template for the declaration of the kernel that copies
 * blocks of matrix B into an image; placeholders as above.
 */
static const char *prepareImagesGemmDeclB =
    "void __kernel\n"
    "%cprepareImageB(\n"
    "    clblasOrder order,\n"
    "    clblasTranspose transB,\n"
    "    uint N,\n"
    "    uint K,\n"
    "    __global %s *B,\n"
    "    uint ldb,\n"
    "    __write_only image2d_t imgB,\n"
    "    uint offsetB)\n";


/*
 * printf template for the declaration of the computing GEMM kernel.
 * '%lu' pair receives the work group sizes, '%c' the BLAS type prefix,
 * and the '%s' placeholders the element type name (alpha, beta, C).
 */
static const char *imgGemmDecl =
    "__attribute__((reqd_work_group_size(%lu, %lu, 1)))\n"
    "void __kernel\n"
    "%cgemmImg(\n"
    "    const uint M,\n"
    "    const uint N,\n"
    "    const uint K,\n"
    "    const %s alpha,\n"
    "    const __read_only image2d_t A,\n"
    "    const __read_only image2d_t B,\n"
    "    const %s beta,\n"
    "    __global %s *C,\n"
    "    const uint ldc,\n"
    "    const uint offsetC)\n";
77 
78 static ssize_t
79 generator(
80    char *buf,
81    size_t buflen,
82    const struct SubproblemDim *subdims,
83    const struct PGranularity *pgran,
84    void *extra);
85 
86 static ssize_t
87 preparator(
88    char *buf,
89    size_t buflen,
90    const struct SubproblemDim *subdims,
91    const struct PGranularity *pgran,
92    void *extra);
93 
94 static ssize_t
genWrapper(char * buf,size_t buflen,const struct SubproblemDim * subdims,const struct PGranularity * pgran,void * extra)95 genWrapper(
96     char *buf,
97     size_t buflen,
98     const struct SubproblemDim *subdims,
99     const struct PGranularity *pgran,
100     void *extra)
101 {
102     CLBLASKernExtra *kextra = (CLBLASKernExtra*)extra;
103     if (kextra->kernType == CLBLAS_COMPUTING_KERNEL) {
104         return generator(buf, buflen, subdims, pgran, extra);
105     }
106     else {
107         return preparator(buf, buflen, subdims, pgran, extra);
108     }
109 }
110 
static void
assignKargs(KernelArg *args, const void *params, const void *extra);

static bool
isFitToLDS(
    SubproblemDim *dim,
    DataType dtype,
    cl_ulong ldsSize,
    const void *kernelArgs);

static SolverFlags
solverFlags(void);

static void
calcNrThreads(
    size_t threads[2],
    const SubproblemDim *subdims,
    const PGranularity *pgran,
    const void *args,
    const void *extra);

static int
imgGetPerf(
    unsigned int kflags,
    const void *args);

/*
 * Solver operations table for the image-based GEMM pattern;
 * registered with the framework via initGemmImgPattern().
 */
static SolverOps imgSops = {
    genWrapper,
    assignKargs,
    isFitToLDS,
    imgGetPerf,
    NULL,
    calcNrThreads,
    NULL,
    solverFlags,
    NULL, //fixupKargs
    NULL, //getDefaultDecomp
    NULL, //getDecompList
    NULL,
    NULL
};
152 
153 // Preparation function for images based kernel generator
// Preparation function for images based kernel generator.
// Emits the OpenCL source of the kernel that copies one matrix (A or B,
// selected by kextra->kernType) from global memory into an image,
// transposing and zero-padding unaligned tail blocks through a local
// memory staging buffer. Returns the generated source size + 1 on
// success, -ENOMEM if the context cannot be created, or -EOVERFLOW if
// the output buffer is too small.
static ssize_t
preparator(
   char *buf,
   size_t buflen,
   const struct SubproblemDim *subdims,
   const struct PGranularity *pgran,
   void *extra)
{
    struct KgenContext *ctx;
    char tmp[4096], conjStr[1024];
    CLBLASKernExtra *kextra = (CLBLASKernExtra*)extra;
    CopyImgFuncs copyImgFuncs;        // names of generated copy helpers
    DataType dtype = kextra->dtype;
    BlasGenSettings gset;
    unsigned int vecLen;              // elements per float4 image texel
    unsigned int tsize;               // element size in bytes
    const char *typeName;
    char fpref;                       // BLAS type prefix (s/d/c/z)
    bool b;
    size_t localBufSize;              // staging buffer size, in elements
    ssize_t ret;
    const char *conjCond;

    // Template of the kernel prologue for matrix A: computes the block
    // coordinates (m, k) of this work group from 'gid', declares the
    // local staging buffer, and bails out past the matrix edge.
    const char *functionHeadA =
        "int tra, aligned;\n"
        "const uint bpr = (K + %lu) / %lu;\n"
        "uint m = (gid / bpr) * %lu;\n"
        "uint k = (gid %% bpr) * %lu;\n"
        "uint x, y;\n"
        "__local %s temp[%lu];\n"
        "\n"
        "A += offsetA;\n"
        "tra = (!transA && order == clblasColumnMajor) ||\n"
        "      (transA && order == clblasRowMajor);\n"
        "if (m >= M) {\n"
        "     return;\n"
        "}\n";

    // Same prologue for matrix B with (n, k) block coordinates.
    const char *functionHeadB =
        "int trb, aligned;\n"
        "const uint bpr = (K + %lu) / %lu;\n"
        "const uint n = (gid / bpr) * %lu;\n"
        "const uint k = (gid %% bpr) * %lu;\n"
        "uint x, y;\n"
        "__local %s temp[%lu];\n"
        "\n"
        "B += offsetB;\n"
        "trb = (!transB && order == clblasRowMajor) ||\n"
        "      (transB && order == clblasColumnMajor);\n"
        "if (n >= N) {\n"
        "    return;\n"
        "}\n";

    // Distribute blocks across compute units and copy matrix A to image.
    // Transposition and filling with zeros in unaligned cases is made using
    // buffer in local memory.
    // The 'atcase' value encodes the path: hundreds digit = conjugated,
    // tens digit = aligned, units digit = transposed. Only the fully
    // aligned, non-transposed case (10) writes to the image directly;
    // all others stage through the zero-filled local buffer.
    const char *copyToImageA =
        "//copy matrix A block\n"
        "y = m + %u <= M ? %u : M - m;\n"
        "x = k + %u <= K ? %u : K - k;\n"
        "aligned = (x == %u) && (y == %u) && %d;\n"
        "int atcase = aligned * 10 + tra;\n"
        "%s" // conjugated check
        "if (atcase != 10) {\n"
        "    %s((__local float4*)temp);\n"
        "    barrier(CLK_LOCAL_MEM_FENCE);\n"
        "}\n"
        "switch(atcase) {\n"
        "case 10: //aligned, not transposed\n"
        "    %s(imgA, k / %u, m, (GPtr)A, m, k, lda);\n"
        "    break;\n"
        "%s" // conjugated case
        "case 1: //not aligned, transposed\n"
        "    // generic transposed global to local\n"
        "    %s((LPtr)temp, (GPtr)A, k, m, x, y, %u, lda);\n"
        "    break;\n"
        "case 0: //not aligned, not transposed\n"
        "    // generic global to local\n"
        "    %s((LPtr) temp, (GPtr)A, m, k, y, x, %u, lda);\n"
        "    break;\n"
        "case 11: //aligned, transposed\n"
        "    // optimized transposed global to local\n"
        "    %s((LPtr) temp, (GPtr)A, k, m, lda);\n"
        "    break;\n"
        "}\n"
        "if (atcase != 10) {\n"
        "    barrier(CLK_LOCAL_MEM_FENCE);\n"
        "    %s(imgA, k / %u, m, (LPtr) temp);\n"
        "}\n"
        "\n";

    // Mirror of copyToImageA for matrix B.
    const char *copyToImageB =
            "//copy matrix B block\n"
            "y = n + %u <= N ? %u : N - n;\n"
            "x = k + %u <= K ? %u : K - k;\n"
            "aligned = (x == %u) && (y == %u) && %d;\n"
            "int atcase = aligned * 10 + trb;\n"
            "%s" // conjugated check
            "if (atcase != 10) {\n"
            "    %s((__local float4*)temp);\n"
            "    barrier(CLK_LOCAL_MEM_FENCE);\n"
            "}\n"
            "switch (atcase) {\n"
            "case 10: //aligned, not transposed\n"
            "    %s(imgB, k / %u, n, (GPtr)B, n, k, ldb);\n"
            "    break;\n"
            "%s" // conjugated case
            "case 1: //not aligned, transposed\n"
            "    // generic transposed global to local\n"
            "    %s((LPtr)temp, (GPtr)B, k, n, x, y, %u, ldb);\n"
            "    break;\n"
            "case 0: //not aligned, not transposed\n"
            "    // generic global to local\n"
            "    %s((LPtr)temp, (GPtr)B, n, k, y, x, %u, ldb);\n"
            "    break;\n"
            "case 11: //transposed, aligned\n"
            "    // optimized transposed global to local\n"
            "    %s((LPtr)temp, (GPtr)B, k, n, ldb);\n"
            "    break;\n"
            "}\n"
            "if (atcase != 10) {\n"
            "    barrier(CLK_LOCAL_MEM_FENCE);\n"
            "    %s(imgB, k / %u, n, (LPtr)temp);\n"
            "}\n"
            "\n";

    memset(&copyImgFuncs, 0, sizeof(copyImgFuncs));
    memset(&gset, 0, sizeof(gset));

    ctx = createKgenContext(buf, buflen, true);
    if (ctx == NULL) {
        return -ENOMEM;
    }

    tsize = dtypeSize(dtype);

    b = isDoubleBasedType(dtype);
    kgenDeclareUptrs(ctx, b);
    declareBlasEnums(ctx);

    memcpy(gset.subdims, subdims, sizeof(gset.subdims));
    gset.kextra = kextra;
    gset.pgran = pgran;

    // generate necessary memory to image copying functions
    generateImageCopyFuncs(&copyImgFuncs, ctx, CLBLAS_GEMM, &gset);

    kgenAddBlankLine(ctx);
    vecLen = sizeof(cl_float4) / dtypeSize(dtype);
    typeName = dtypeBuiltinType(dtype);
    fpref = dtypeToBlasPrefix(dtype);

    if (kextra->kernType == CLBLAS_PREP_A_KERNEL) {
        // NOTE(review): the declaration template consumes only the first
        // two varargs (%c, %s); the trailing typeName is ignored.
        sprintf(tmp, prepareImagesGemmDeclA, fpref, typeName, typeName);
        kgenDeclareFunction(ctx, tmp);
        ret = kgenBeginFuncBody(ctx);

        // same local buffer is used for both matrix A and matrix B blocks
        localBufSize = subdims[1].y * fl4RowWidth(subdims[1].bwidth, tsize);
        localBufSize *= vecLen;

        kgenDeclareGroupID(ctx, "gid", pgran);
        sprintf(tmp, functionHeadA,
                subdims[1].bwidth - 1, subdims[1].bwidth,
                subdims[1].y, subdims[1].bwidth,
                typeName, localBufSize);
        kgenAddStmt(ctx, tmp);

        if (isComplexType(dtype)) {
            // route conjugated, aligned, non-transposed blocks (case 110)
            // through the local buffer as well
            conjCond = "atcase += ((atcase == 10) && "
                    "(transA == clblasConjTrans)) ? 100 : 0;\n";
            sprintf(conjStr, "case 110: //conjugated, not transposed, aligned\n"
                             "    %s((LPtr)temp, (GPtr)A, m, k, lda);\n"
                             "    break;\n",
                    copyImgFuncs.globalToLocal[MATRIX_A]);
        }
        else {
            conjCond = "";
            strcpy(conjStr, "");
        }

        sprintf(tmp, copyToImageA,
                subdims[1].y, subdims[1].y, // y = m + dy <= M ?...
                subdims[1].bwidth, subdims[1].bwidth, // x = k + bw <= K ?...
                subdims[1].bwidth, subdims[1].y, // aligned = (x==bw1)&&(y==dy1)
                (kextra->flags & KEXTRA_NO_COPY_VEC_A) == 0,
                conjCond,
                copyImgFuncs.zeroBlock[MATRIX_A],
                copyImgFuncs.globalToImage[MATRIX_A],
                vecLen,
                conjStr,
                copyImgFuncs.globalToLocalTransposedGeneric[MATRIX_A],
                subdims[1].bwidth,
                copyImgFuncs.globalToLocalGeneric[MATRIX_A],
                subdims[1].bwidth,
                copyImgFuncs.globalToLocalTransposed[MATRIX_A],
                copyImgFuncs.localToImage[MATRIX_A],
                vecLen);
        kgenAddStmt(ctx, tmp);
    }
    else { // PREP_B
        sprintf(tmp, prepareImagesGemmDeclB, fpref, typeName, typeName);
        kgenDeclareFunction(ctx, tmp);
        ret = kgenBeginFuncBody(ctx);

        // same local buffer is used for both matrix A and matrix B blocks
        localBufSize = subdims[1].x * fl4RowWidth(subdims[1].bwidth, tsize);
        localBufSize *= vecLen;

        kgenDeclareGroupID(ctx, "gid", pgran);
        sprintf(tmp, functionHeadB,
                subdims[1].bwidth - 1, subdims[1].bwidth,
                subdims[1].x, subdims[1].bwidth,
                typeName, localBufSize);
        kgenAddStmt(ctx, tmp);

        if (isComplexType(dtype)) {
            conjCond = "atcase += ((atcase == 10) && "
                    "(transB == clblasConjTrans)) ? 100 : 0;\n";
            sprintf(conjStr, "case 110: //conjugated, not transposed, aligned\n"
                             "    %s((LPtr)temp, (GPtr)B, n, k, ldb);\n"
                             "    break;\n",
                    copyImgFuncs.globalToLocal[MATRIX_B]);
        }
        else {
            conjCond = "";
            strcpy(conjStr, "");
        }

        sprintf(tmp, copyToImageB,
                subdims[1].x, subdims[1].x, // y = n + dy <= N ?...
                subdims[1].bwidth, subdims[1].bwidth, // x = k + bw <= K ?...
                subdims[1].bwidth, subdims[1].x, // aligned = (x==bw1)&&(y==dx1)
                (kextra->flags & KEXTRA_NO_COPY_VEC_B) == 0,
                conjCond,
                copyImgFuncs.zeroBlock[MATRIX_B],
                copyImgFuncs.globalToImage[MATRIX_B],
                vecLen,
                conjStr,
                copyImgFuncs.globalToLocalTransposedGeneric[MATRIX_B],
                subdims[1].bwidth,
                copyImgFuncs.globalToLocalGeneric[MATRIX_B],
                subdims[1].bwidth,
                copyImgFuncs.globalToLocalTransposed[MATRIX_B],
                copyImgFuncs.localToImage[MATRIX_B],
                vecLen);
        kgenAddStmt(ctx, tmp);
    }

    kgenEndFuncBody(ctx);

    ret = kgenAddBlankLine(ctx);

    if (!ret) {
        ret = (ssize_t)kgenSourceSize(ctx) + 1;
    }
    destroyKgenContext(ctx);

    return (ret < 0) ? -EOVERFLOW : ret;
}
414 
415 static void
initKernelVarNames(KernelVarNames * kvars,KernelExtraFlags kflags)416 initKernelVarNames(KernelVarNames *kvars, KernelExtraFlags kflags)
417 {
418     kvars->A = "imgA";
419     kvars->B = "imgB";
420     if (isMatrixAccessColMaj(CLBLAS_GEMM, kflags, MATRIX_A)) {
421         kvars->coordA = "coordA.x";
422     }
423     else {
424         kvars->coordA = "coordA.y";
425     }
426     if (isMatrixAccessColMaj(CLBLAS_GEMM, kflags, MATRIX_B)) {
427         kvars->coordB = "coordB.x";
428     }
429     else {
430         kvars->coordB = "coordB.y";
431     }
432     kvars->sizeM = "M";
433     kvars->sizeN = "N";
434     kvars->sizeK = "K";
435 }
436 
437 // global memory based kernel generator
438 static ssize_t
generator(char * buf,size_t buflen,const struct SubproblemDim * subdims,const struct PGranularity * pgran,void * extra)439 generator(
440    char *buf,
441    size_t buflen,
442    const struct SubproblemDim *subdims,
443    const struct PGranularity *pgran,
444    void *extra)
445 {
446     struct KgenContext *ctx;
447     CLBLASKernExtra *kextra = (CLBLASKernExtra*)extra;
448     char tmp[4096], tmp1[4096];
449     char *p;
450     // is the iteration over N, N at the top level
451     const char *typeName;
452     char fpref;
453     DataType dtype = kextra->dtype;
454     ssize_t ret;
455     BlasGenSettings gset;
456     BlkMulOpts mulOpts;
457     unsigned int tsize;
458     unsigned int vecLen, outVecLen;
459     bool b;
460     const char *outTypeName;
461     unsigned int i;
462     unsigned int nrRegs, regPitch;
463     int tra, trb;
464     char vect[2] = {'y', 'x'};
465 
466     const char *coordConstants =
467         "const uint workItemM = get_global_id(0) * %lu;\n"
468         "const uint workItemN = get_global_id(1) * %lu;\n"
469         "const int2 skewRow = (int2)(0, get_local_id(0) %% %lu);\n"
470         "uint vectK = (K + %u) / %u;\n";
471 
472     /*
473      *  template for image based gemm preparation part
474      *  for two dimensional work space
475      */
476     const char *localVariables =
477         "uint k0;\n"
478         "int2 coordA = (int2)(0, workItemM);\n"
479         "int2 coordB = (int2)(0, workItemN);\n"
480         "%s c[%u];\n\n";
481 
482     tsize = dtypeSize(dtype);
483     vecLen = sizeof(cl_float4) / dtypeSize(dtype);
484     if (isComplexType(dtype)) {
485         regPitch = (unsigned int)subdims[1].x;
486     }
487     else {
488         regPitch = (unsigned int) fl4RowWidth(subdims[1].x, tsize) *
489                     sizeof(cl_float4) / tsize;
490     }
491 
492     memset(&gset, 0, sizeof(gset));
493     memcpy(gset.subdims, subdims, sizeof(gset.subdims));
494     gset.kextra = kextra;
495     gset.pgran = pgran;
496     initKernelVarNames(&gset.varNames, kextra->flags);
497 
498     ctx = createKgenContext(buf, buflen, true);
499     if (ctx == NULL) {
500         return -ENOMEM;
501     }
502 
503     // at first, generate needed declarations and auxiliary functions
504     b = isDoubleBasedType(dtype);
505     kgenDeclareUptrs(ctx, b);
506 
507     typeName = dtypeBuiltinType(dtype);
508     fpref = dtypeToBlasPrefix(dtype);
509 
510     // now, generate the kernel
511 
512     sprintf(tmp, imgGemmDecl, pgran->wgSize[0], pgran->wgSize[1], fpref,
513             typeName, typeName, typeName);
514     kgenDeclareFunction(ctx, tmp);
515     ret = kgenBeginFuncBody(ctx);
516 
517     // constants
518     sprintf(tmp, coordConstants,
519             subdims[1].y, subdims[1].x, subdims[1].y,
520             vecLen - 1, vecLen);
521     kgenAddStmt(ctx, tmp);
522 
523     /*
524      * Calculate local buffer pitches, and then declare local
525      * variables
526      */
527     getResultGPRsInfo(dtype, &subdims[1], vecLen, &nrRegs, &outTypeName);
528 
529     sprintf(tmp, localVariables, outTypeName, nrRegs);
530     kgenAddStmt(ctx, tmp);
531 
532     // check if offset exceeds matrix
533     kgenAddStmt(ctx, "if ((workItemM >= M) ||"
534                          "(workItemN >= N)) {\n"
535                      "    return;\n"
536                      "}\n");
537 
538     kgenAddStmt(ctx, "C += offsetC;\n");
539 
540     // zero C block
541     sprintf(tmp, "for (k0 = 0; k0 < %u; k0++) {\n"
542                  "    c[k0] = 0;\n"
543                  "}\n\n",
544             nrRegs);
545     kgenAddStmt(ctx, tmp);
546 
547     // block multiplication inlined function
548     sprintf(tmp, "for (k0 = 0; k0 < vectK; k0 += %lu)",
549             subdims[1].bwidth / vecLen);
550     kgenBeginBranch(ctx, tmp);
551 
552     mulOpts.aMobj = CLMEM_IMAGE;
553     mulOpts.bMobj = CLMEM_IMAGE;
554     mulOpts.flags = BLKMUL_OUTPUT_PRIVATE | BLKMUL_SKEW_ROW | BLKMUL_INLINE;
555     if (isComplexType(dtype)) {
556         mulOpts.core = BLKMUL_SEPARATE_MULADD;
557     }
558     else {
559         mulOpts.core = BLKMUL_MAD;
560     }
561     mulOpts.argNames.coordA = "coordA";
562     mulOpts.argNames.coordB = "coordB";
563     mulOpts.argNames.skewCol = "skewCol";
564     mulOpts.argNames.skewRow = "skewRow";
565     mulOpts.argNames.k = "k0";
566     mulOpts.argNames.vectBoundK = "vectK";
567     ret = blkMulGen(ctx, subdims, dtype, &mulOpts);
568     if (ret) {
569         destroyKgenContext(ctx);
570         return -EOVERFLOW;
571     }
572 
573     // update image coordinates
574     sprintf(tmp, "\ncoordA.x += %lu;\n"
575                  "coordB.x += %lu;\n",
576             subdims[1].bwidth / vecLen, subdims[1].bwidth / vecLen);
577     kgenAddStmt(ctx, tmp);
578 
579     kgenEndBranch(ctx, NULL);
580 
581     // reorder the given solution
582     outVecLen = isComplexType(dtype) ? 1 : vecLen;
583     p = tmp1;
584     for (i = 0; i < regPitch / outVecLen; i++) {
585         unsigned int k = (unsigned int)(subdims[1].y - 1) *
586                          regPitch / outVecLen + i;
587 
588         sprintf(p,  "\n"
589                     "    tmp = c[%u];\n"
590                     "    for (j = %lu; j >= 0; j--) {\n"
591                     "        c[(j+1) * %u + %u] = c[j * %u + %u];\n"
592                     "    }\n"
593                     "    c[%u] = tmp;\n",
594                 k, subdims[1].y - 2, regPitch / outVecLen,
595                 i, regPitch / outVecLen, i, i);
596         p += strlen(p);
597     }
598     sprintf(tmp, "\n"
599                  "for (k0 = 0; k0 < skewRow.y; k0++) {\n"
600                  "    int j;\n"
601                  "    %s tmp;\n"
602                  "%s"
603                  "}\n"
604                  "\n",
605                  outTypeName, tmp1);
606     kgenAddStmt(ctx, tmp);
607 
608     tra = isMatrixAccessColMaj(CLBLAS_GEMM, kextra->flags, MATRIX_A);
609     trb = isMatrixAccessColMaj(CLBLAS_GEMM, kextra->flags, MATRIX_B);
610     sprintf(tmp, "coordA.%c = workItemM;\n"
611                  "coordB.%c = workItemN;\n\n",
612             vect[tra], vect[trb]);
613     kgenAddStmt(ctx, tmp);
614 
615     // write back the tile evaluated
616     generateResultUpdateOld(ctx, CLBLAS_GEMM, &gset, NULL, NULL);
617 
618     kgenEndFuncBody(ctx);
619     ret = kgenAddBlankLine(ctx);
620 
621     if (!ret) {
622         ret = (ssize_t)kgenSourceSize(ctx) + 1;
623     }
624 
625     destroyKgenContext(ctx);
626 
627     return (ret < 0) ? -EOVERFLOW : ret;
628 }
629 
// Bind the runtime BLAS arguments to the kernel argument slots.
// The slot order must match the parameter order of the corresponding
// generated kernel declaration (imgGemmDecl / prepareImagesGemmDeclA /
// prepareImagesGemmDeclB above).
static void
assignKargs(KernelArg *args, const void *params, const void *extra)
{
    const CLBlasKargs *blasArgs = (const CLBlasKargs*)params;

    (void)extra;

    switch (blasArgs->kernType) {
    case CLBLAS_COMPUTING_KERNEL:
        // arguments for computational kernel
        initSizeKarg(&args[0], blasArgs->M);
        initSizeKarg(&args[1], blasArgs->N);
        initSizeKarg(&args[2], blasArgs->K);
        assignScalarKarg(&args[3], &(blasArgs->alpha), blasArgs->dtype);
        INIT_KARG(&args[4], blasArgs->scimage[0]);
        INIT_KARG(&args[5], blasArgs->scimage[1]);
        assignScalarKarg(&args[6], &(blasArgs->beta), blasArgs->dtype);
        initMemobjKarg(&args[7], blasArgs->C, NULL, 0, 0);
        initSizeKarg(&args[8], blasArgs->ldc.matrix);
        initSizeKarg(&args[9], blasArgs->offCY);
        break;
    case CLBLAS_PREP_A_KERNEL:
        // arguments for the kernel copying matrix A into scratch image 0
        INIT_KARG(&args[0], blasArgs->order);
        INIT_KARG(&args[1], blasArgs->transA);
        initSizeKarg(&args[2], blasArgs->M);
        initSizeKarg(&args[3], blasArgs->K);
        initMemobjKarg(&args[4], blasArgs->A, NULL, 0, 0);
        initSizeKarg(&args[5], blasArgs->lda.matrix);
        INIT_KARG(&args[6], blasArgs->scimage[0]);
        initSizeKarg(&args[7], blasArgs->offA);
        break;
    case CLBLAS_PREP_B_KERNEL:
        // arguments for the kernel copying matrix B into scratch image 1
        INIT_KARG(&args[0], blasArgs->order);
        INIT_KARG(&args[1], blasArgs->transB);
        initSizeKarg(&args[2], blasArgs->N);
        initSizeKarg(&args[3], blasArgs->K);
        initMemobjKarg(&args[4], blasArgs->B, NULL, 0, 0);
        initSizeKarg(&args[5], blasArgs->ldb.matrix);
        INIT_KARG(&args[6], blasArgs->scimage[1]);
        initSizeKarg(&args[7], blasArgs->offBX);
        break;
    default:
        //this should not happen
        break;
    }
}
676 
677 static bool
isFitToLDS(SubproblemDim * dim,DataType dtype,cl_ulong ldsSize,const void * kernelArgs)678 isFitToLDS(
679     SubproblemDim *dim,
680     DataType dtype,
681     cl_ulong ldsSize,
682     const void *kernelArgs)
683 {
684     cl_ulong size;
685     const CLBlasKargs *kargs = (const CLBlasKargs*)kernelArgs;
686     size = matrBlockSize(&dim[1], MATRIX_C, dtype, kargs->side);
687     return (size * dtypeSize(dtype) <= ldsSize);
688 }
689 
690 static void
calcNrThreads(size_t threads[2],const SubproblemDim * subdims,const PGranularity * pgran,const void * args,const void * extra)691 calcNrThreads(
692     size_t threads[2],
693     const SubproblemDim *subdims,
694     const PGranularity *pgran,
695     const void *args,
696     const void *extra)
697 {
698     const CLBlasKargs *kargs = args;
699     (void)extra;
700 
701     if (kargs->kernType != CLBLAS_COMPUTING_KERNEL) {
702         const size_t *whole, *part;
703         size_t nrGroups;
704 
705         // each thread gets one block
706 
707         if (kargs->kernType == CLBLAS_PREP_A_KERNEL) {
708             whole = &kargs->M;
709             part = &subdims[0].itemY;
710         }
711         else {
712             whole = &kargs->N;
713             part = &subdims[0].itemX;
714         }
715 
716         nrGroups = *whole / *part + (*whole % *part != 0);
717         nrGroups *= (kargs->K / subdims[0].bwidth +
718                     (kargs->K % subdims[0].bwidth != 0));
719         threads[0] = pgran->wgSize[0] * nrGroups;
720         threads[1] = pgran->wgSize[1];
721     }
722     else {
723         calcGlobalThreads(threads, &subdims[0], pgran, kargs->M, kargs->N);
724     }
725 }
726 
727 static SolverFlags
solverFlags(void)728 solverFlags(void)
729 {
730     return (SF_WSPACE_2D);
731 }
732 
733 void
initGemmImgPattern(MemoryPattern * mempat)734 initGemmImgPattern(MemoryPattern *mempat)
735 {
736     mempat->name = "Image based block gemm";
737     mempat->nrLevels = 2;
738     mempat->cuLevel = 0;
739     mempat->thLevel = 1;
740     mempat->sops = &imgSops;
741 
742     mpatExtra.aMset = CLMEM_LEVEL_L1 | CLMEM_LEVEL_LDS;
743     mpatExtra.bMset = CLMEM_LEVEL_L1 | CLMEM_LEVEL_LDS;
744     mpatExtra.mobjA = CLMEM_IMAGE;
745     mpatExtra.mobjB = CLMEM_IMAGE;
746     mempat->extra = &mpatExtra;
747 }
748 
749 static int
imgGetPerf(unsigned int kflags,const void * args)750 imgGetPerf(
751     unsigned int kflags,
752     const void *args)
753 {
754     (void)args;
755     (void)kflags;
756 
757     return PPERF_POOR;
758 }
759