/* ************************************************************************
 * Copyright 2013 Advanced Micro Devices, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * ************************************************************************/


/*
 * gemm image based generators
 */

#include <string.h>
#include <stdio.h>
#include <math.h>
#include <clBLAS.h>
#include <matrix_dims.h>
#include <blas_mempat.h>
#include <clkern.h>
#include <clblas-internal.h>
#include <dis_warning.h>

#include "blas_kgen_legacy.h"
#include "../gen_helper.h"
#include "gen_helper_legacy.h"

static CLBLASMpatExtra mpatExtra;

static const char *prepareImagesGemmDeclA =
    "void __kernel\n"
    "%cprepareImageA(\n"
    "    clblasOrder order,\n"
    "    clblasTranspose transA,\n"
    "    uint M,\n"
    "    uint K,\n"
    "    __global %s *A,\n"
    "    uint lda,\n"
    "    __write_only image2d_t imgA,\n"
    "    uint offsetA)\n";

static const char *prepareImagesGemmDeclB =
    "void __kernel\n"
    "%cprepareImageB(\n"
    "    clblasOrder order,\n"
    "    clblasTranspose transB,\n"
    "    uint N,\n"
    "    uint K,\n"
    "    __global %s *B,\n"
    "    uint ldb,\n"
    "    __write_only image2d_t imgB,\n"
    "    uint offsetB)\n";


static const char *imgGemmDecl =
    "__attribute__((reqd_work_group_size(%lu, %lu, 1)))\n"
    "void __kernel\n"
    "%cgemmImg(\n"
    "    const uint M,\n"
    "    const uint N,\n"
    "    const uint K,\n"
    "    const %s alpha,\n"
    "    const __read_only image2d_t A,\n"
    "    const __read_only image2d_t B,\n"
    "    const %s beta,\n"
    "    __global %s *C,\n"
    "    const uint ldc,\n"
    "    const uint offsetC)\n";

static ssize_t
generator(
    char *buf,
    size_t buflen,
    const struct SubproblemDim *subdims,
    const struct PGranularity *pgran,
    void *extra);

static ssize_t
preparator(
    char *buf,
    size_t buflen,
    const struct SubproblemDim *subdims,
    const struct PGranularity *pgran,
    void *extra);

static ssize_t
genWrapper(
    char *buf,
    size_t buflen,
    const struct SubproblemDim *subdims,
    const struct PGranularity *pgran,
    void *extra)
{
    CLBLASKernExtra *kextra = (CLBLASKernExtra*)extra;
    if (kextra->kernType == CLBLAS_COMPUTING_KERNEL) {
        return generator(buf, buflen, subdims, pgran, extra);
    }
    else {
        return preparator(buf, buflen, subdims, pgran, extra);
    }
}

static void
assignKargs(KernelArg *args, const void *params, const void *extra);

static bool
isFitToLDS(
    SubproblemDim *dim,
    DataType dtype,
    cl_ulong ldsSize,
    const void *kernelArgs);

static SolverFlags
solverFlags(void);

static void
calcNrThreads(
    size_t threads[2],
    const SubproblemDim *subdims,
    const PGranularity *pgran,
    const void *args,
    const void *extra);

static int
imgGetPerf(
    unsigned int kflags,
    const void *args);

static SolverOps imgSops = {
    genWrapper,
    assignKargs,
    isFitToLDS,
    imgGetPerf,
    NULL,
    calcNrThreads,
    NULL,
    solverFlags,
    NULL, //fixupKargs
    NULL, //getDefaultDecomp
    NULL, //getDecompList
    NULL,
    NULL
};

// Preparation function for the image based kernel generator
static ssize_t
preparator(
    char *buf,
    size_t buflen,
    const struct SubproblemDim *subdims,
    const struct PGranularity *pgran,
    void *extra)
{
    struct KgenContext *ctx;
    char tmp[4096], conjStr[1024];
    CLBLASKernExtra *kextra = (CLBLASKernExtra*)extra;
    CopyImgFuncs copyImgFuncs;
    DataType dtype = kextra->dtype;
    BlasGenSettings gset;
    unsigned int vecLen;
    unsigned int tsize;
    const char *typeName;
    char fpref;
    bool b;
    size_t localBufSize;
    ssize_t ret;
    const char *conjCond;

    const char *functionHeadA =
        "int tra, aligned;\n"
        "const uint bpr = (K + %lu) / %lu;\n"
        "uint m = (gid / bpr) * %lu;\n"
        "uint k = (gid %% bpr) * %lu;\n"
        "uint x, y;\n"
        "__local %s temp[%lu];\n"
        "\n"
        "A += offsetA;\n"
        "tra = (!transA && order == clblasColumnMajor) ||\n"
        "      (transA && order == clblasRowMajor);\n"
        "if (m >= M) {\n"
        "    return;\n"
        "}\n";

    const char *functionHeadB =
        "int trb, aligned;\n"
        "const uint bpr = (K + %lu) / %lu;\n"
        "const uint n = (gid / bpr) * %lu;\n"
        "const uint k = (gid %% bpr) * %lu;\n"
        "uint x, y;\n"
        "__local %s temp[%lu];\n"
        "\n"
        "B += offsetB;\n"
        "trb = (!transB && order == clblasRowMajor) ||\n"
        "      (transB && order == clblasColumnMajor);\n"
        "if (n >= N) {\n"
        "    return;\n"
        "}\n";

    // Distribute blocks across compute units and copy matrix A to the image.
    // In the unaligned cases, transposition and zero padding are done through
    // a buffer in local memory.
    const char *copyToImageA =
        "//copy matrix A block\n"
        "y = m + %u <= M ? %u : M - m;\n"
        "x = k + %u <= K ? %u : K - k;\n"
        "aligned = (x == %u) && (y == %u) && %d;\n"
        "int atcase = aligned * 10 + tra;\n"
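        /*
         * atcase encodes the copy path: the tens digit is the alignment flag
         * and the ones digit is the transposition flag, so 10 = aligned and
         * not transposed (direct global-to-image copy), 11 = aligned and
         * transposed, 1 = unaligned and transposed, 0 = unaligned and not
         * transposed. For complex types, 100 is added when the block is
         * aligned, not transposed, and conjugated (case 110 below).
         */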
        "%s" // conjugated check
        "if (atcase != 10) {\n"
        "    %s((__local float4*)temp);\n"
        "    barrier(CLK_LOCAL_MEM_FENCE);\n"
        "}\n"
        "switch (atcase) {\n"
        "case 10: //aligned, not transposed\n"
        "    %s(imgA, k / %u, m, (GPtr)A, m, k, lda);\n"
        "    break;\n"
        "%s" // conjugated case
        "case 1: //not aligned, transposed\n"
        "    // generic transposed global to local\n"
        "    %s((LPtr)temp, (GPtr)A, k, m, x, y, %u, lda);\n"
        "    break;\n"
        "case 0: //not aligned, not transposed\n"
        "    // generic global to local\n"
        "    %s((LPtr)temp, (GPtr)A, m, k, y, x, %u, lda);\n"
        "    break;\n"
        "case 11: //aligned, transposed\n"
        "    // optimized transposed global to local\n"
        "    %s((LPtr)temp, (GPtr)A, k, m, lda);\n"
        "    break;\n"
        "}\n"
        "if (atcase != 10) {\n"
        "    barrier(CLK_LOCAL_MEM_FENCE);\n"
        "    %s(imgA, k / %u, m, (LPtr)temp);\n"
        "}\n"
        "\n";

    const char *copyToImageB =
        "//copy matrix B block\n"
        "y = n + %u <= N ? %u : N - n;\n"
        "x = k + %u <= K ? %u : K - k;\n"
        "aligned = (x == %u) && (y == %u) && %d;\n"
        "int atcase = aligned * 10 + trb;\n"
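        /* atcase follows the same encoding as in the matrix A template above */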
        "%s" // conjugated check
        "if (atcase != 10) {\n"
        "    %s((__local float4*)temp);\n"
        "    barrier(CLK_LOCAL_MEM_FENCE);\n"
        "}\n"
        "switch (atcase) {\n"
        "case 10: //aligned, not transposed\n"
        "    %s(imgB, k / %u, n, (GPtr)B, n, k, ldb);\n"
        "    break;\n"
        "%s" // conjugated case
        "case 1: //not aligned, transposed\n"
        "    // generic transposed global to local\n"
        "    %s((LPtr)temp, (GPtr)B, k, n, x, y, %u, ldb);\n"
        "    break;\n"
        "case 0: //not aligned, not transposed\n"
        "    // generic global to local\n"
        "    %s((LPtr)temp, (GPtr)B, n, k, y, x, %u, ldb);\n"
        "    break;\n"
        "case 11: //transposed, aligned\n"
        "    // optimized transposed global to local\n"
        "    %s((LPtr)temp, (GPtr)B, k, n, ldb);\n"
        "    break;\n"
        "}\n"
        "if (atcase != 10) {\n"
        "    barrier(CLK_LOCAL_MEM_FENCE);\n"
        "    %s(imgB, k / %u, n, (LPtr)temp);\n"
        "}\n"
        "\n";

    memset(&copyImgFuncs, 0, sizeof(copyImgFuncs));
    memset(&gset, 0, sizeof(gset));

    ctx = createKgenContext(buf, buflen, true);
    if (ctx == NULL) {
        return -ENOMEM;
    }

    tsize = dtypeSize(dtype);

    b = isDoubleBasedType(dtype);
    kgenDeclareUptrs(ctx, b);
    declareBlasEnums(ctx);

    memcpy(gset.subdims, subdims, sizeof(gset.subdims));
    gset.kextra = kextra;
    gset.pgran = pgran;
    // generate the necessary memory-to-image copying functions
    generateImageCopyFuncs(&copyImgFuncs, ctx, CLBLAS_GEMM, &gset);

    kgenAddBlankLine(ctx);
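    // number of matrix elements packed into a single float4 image texel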
    vecLen = sizeof(cl_float4) / dtypeSize(dtype);
    typeName = dtypeBuiltinType(dtype);
    fpref = dtypeToBlasPrefix(dtype);

    if (kextra->kernType == CLBLAS_PREP_A_KERNEL) {
        sprintf(tmp, prepareImagesGemmDeclA, fpref, typeName, typeName);
        kgenDeclareFunction(ctx, tmp);
        ret = kgenBeginFuncBody(ctx);

        // same local buffer is used for both matrix A and matrix B blocks
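        // localBufSize is expressed in elements of typeName: block rows times
        // the row width (assumed here to be returned by fl4RowWidth in whole
        // float4 words) times the number of elements per float4 word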
        localBufSize = subdims[1].y * fl4RowWidth(subdims[1].bwidth, tsize);
        localBufSize *= vecLen;

        kgenDeclareGroupID(ctx, "gid", pgran);
        sprintf(tmp, functionHeadA,
                subdims[1].bwidth - 1, subdims[1].bwidth,
                subdims[1].y, subdims[1].bwidth,
                typeName, localBufSize);
        kgenAddStmt(ctx, tmp);

        if (isComplexType(dtype)) {
            conjCond = "atcase += ((atcase == 10) && "
                       "(transA == clblasConjTrans)) ? 100 : 0;\n";
            sprintf(conjStr, "case 110: //conjugated, not transposed, aligned\n"
                             "    %s((LPtr)temp, (GPtr)A, m, k, lda);\n"
                             "    break;\n",
                    copyImgFuncs.globalToLocal[MATRIX_A]);
        }
        else {
            conjCond = "";
            strcpy(conjStr, "");
        }

        sprintf(tmp, copyToImageA,
                subdims[1].y, subdims[1].y,           // y = m + dy <= M ?...
                subdims[1].bwidth, subdims[1].bwidth, // x = k + bw <= K ?...
                subdims[1].bwidth, subdims[1].y,      // aligned = (x==bw1)&&(y==dy1)
                (kextra->flags & KEXTRA_NO_COPY_VEC_A) == 0,
                conjCond,
                copyImgFuncs.zeroBlock[MATRIX_A],
                copyImgFuncs.globalToImage[MATRIX_A],
                vecLen,
                conjStr,
                copyImgFuncs.globalToLocalTransposedGeneric[MATRIX_A],
                subdims[1].bwidth,
                copyImgFuncs.globalToLocalGeneric[MATRIX_A],
                subdims[1].bwidth,
                copyImgFuncs.globalToLocalTransposed[MATRIX_A],
                copyImgFuncs.localToImage[MATRIX_A],
                vecLen);
        kgenAddStmt(ctx, tmp);
    }
    else { // PREP_B
        sprintf(tmp, prepareImagesGemmDeclB, fpref, typeName, typeName);
        kgenDeclareFunction(ctx, tmp);
        ret = kgenBeginFuncBody(ctx);

        // same local buffer is used for both matrix A and matrix B blocks
        localBufSize = subdims[1].x * fl4RowWidth(subdims[1].bwidth, tsize);
        localBufSize *= vecLen;

        kgenDeclareGroupID(ctx, "gid", pgran);
        sprintf(tmp, functionHeadB,
                subdims[1].bwidth - 1, subdims[1].bwidth,
                subdims[1].x, subdims[1].bwidth,
                typeName, localBufSize);
        kgenAddStmt(ctx, tmp);

        if (isComplexType(dtype)) {
            conjCond = "atcase += ((atcase == 10) && "
                       "(transB == clblasConjTrans)) ? 100 : 0;\n";
            sprintf(conjStr, "case 110: //conjugated, not transposed, aligned\n"
                             "    %s((LPtr)temp, (GPtr)B, n, k, ldb);\n"
                             "    break;\n",
                    copyImgFuncs.globalToLocal[MATRIX_B]);
        }
        else {
            conjCond = "";
            strcpy(conjStr, "");
        }

        sprintf(tmp, copyToImageB,
                subdims[1].x, subdims[1].x,           // y = n + dy <= N ?...
                subdims[1].bwidth, subdims[1].bwidth, // x = k + bw <= K ?...
                subdims[1].bwidth, subdims[1].x,      // aligned = (x==bw1)&&(y==dx1)
                (kextra->flags & KEXTRA_NO_COPY_VEC_B) == 0,
                conjCond,
                copyImgFuncs.zeroBlock[MATRIX_B],
                copyImgFuncs.globalToImage[MATRIX_B],
                vecLen,
                conjStr,
                copyImgFuncs.globalToLocalTransposedGeneric[MATRIX_B],
                subdims[1].bwidth,
                copyImgFuncs.globalToLocalGeneric[MATRIX_B],
                subdims[1].bwidth,
                copyImgFuncs.globalToLocalTransposed[MATRIX_B],
                copyImgFuncs.localToImage[MATRIX_B],
                vecLen);
        kgenAddStmt(ctx, tmp);
    }

    kgenEndFuncBody(ctx);

    ret = kgenAddBlankLine(ctx);

    if (!ret) {
        ret = (ssize_t)kgenSourceSize(ctx) + 1;
    }
    destroyKgenContext(ctx);

    return (ret < 0) ? -EOVERFLOW : ret;
}

static void
initKernelVarNames(KernelVarNames *kvars, KernelExtraFlags kflags)
{
    kvars->A = "imgA";
    kvars->B = "imgB";
    if (isMatrixAccessColMaj(CLBLAS_GEMM, kflags, MATRIX_A)) {
        kvars->coordA = "coordA.x";
    }
    else {
        kvars->coordA = "coordA.y";
    }
    if (isMatrixAccessColMaj(CLBLAS_GEMM, kflags, MATRIX_B)) {
        kvars->coordB = "coordB.x";
    }
    else {
        kvars->coordB = "coordB.y";
    }
    kvars->sizeM = "M";
    kvars->sizeN = "N";
    kvars->sizeK = "K";
}

// image based gemm computing kernel generator
static ssize_t
generator(
    char *buf,
    size_t buflen,
    const struct SubproblemDim *subdims,
    const struct PGranularity *pgran,
    void *extra)
{
    struct KgenContext *ctx;
    CLBLASKernExtra *kextra = (CLBLASKernExtra*)extra;
    char tmp[4096], tmp1[4096];
    char *p;
    // is the iteration over N, N at the top level
    const char *typeName;
    char fpref;
    DataType dtype = kextra->dtype;
    ssize_t ret;
    BlasGenSettings gset;
    BlkMulOpts mulOpts;
    unsigned int tsize;
    unsigned int vecLen, outVecLen;
    bool b;
    const char *outTypeName;
    unsigned int i;
    unsigned int nrRegs, regPitch;
    int tra, trb;
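    // selects which component of the coordinate vectors is reset below:
    // 'x' when the matrix is accessed column-major, 'y' otherwise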
    char vect[2] = {'y', 'x'};

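    /*
     * coordConstants computes, per work item, the starting row (workItemM)
     * and column (workItemN) of its output tile, the row skew used by the
     * block multiplier, and vectK: K rounded up to a whole number of float4
     * texels (each texel packs vecLen elements).
     */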
    const char *coordConstants =
        "const uint workItemM = get_global_id(0) * %lu;\n"
        "const uint workItemN = get_global_id(1) * %lu;\n"
        "const int2 skewRow = (int2)(0, get_local_id(0) %% %lu);\n"
        "uint vectK = (K + %u) / %u;\n";

    /*
     * local variables of the image based gemm kernel
     * (two dimensional work space); c[] accumulates the result tile
     */
    const char *localVariables =
        "uint k0;\n"
        "int2 coordA = (int2)(0, workItemM);\n"
        "int2 coordB = (int2)(0, workItemN);\n"
        "%s c[%u];\n\n";

    tsize = dtypeSize(dtype);
    vecLen = sizeof(cl_float4) / dtypeSize(dtype);
    if (isComplexType(dtype)) {
        regPitch = (unsigned int)subdims[1].x;
    }
    else {
        regPitch = (unsigned int)fl4RowWidth(subdims[1].x, tsize) *
                   sizeof(cl_float4) / tsize;
    }

    memset(&gset, 0, sizeof(gset));
    memcpy(gset.subdims, subdims, sizeof(gset.subdims));
    gset.kextra = kextra;
    gset.pgran = pgran;
    initKernelVarNames(&gset.varNames, kextra->flags);

    ctx = createKgenContext(buf, buflen, true);
    if (ctx == NULL) {
        return -ENOMEM;
    }

    // first, generate the needed declarations and auxiliary functions
    b = isDoubleBasedType(dtype);
    kgenDeclareUptrs(ctx, b);

    typeName = dtypeBuiltinType(dtype);
    fpref = dtypeToBlasPrefix(dtype);

    // now, generate the kernel

    sprintf(tmp, imgGemmDecl, pgran->wgSize[0], pgran->wgSize[1], fpref,
            typeName, typeName, typeName);
    kgenDeclareFunction(ctx, tmp);
    ret = kgenBeginFuncBody(ctx);

    // constants
    sprintf(tmp, coordConstants,
            subdims[1].y, subdims[1].x, subdims[1].y,
            vecLen - 1, vecLen);
    kgenAddStmt(ctx, tmp);

    /*
     * Determine how many private registers are needed for the result tile
     * and their element type, then declare the kernel's local variables
     */
    getResultGPRsInfo(dtype, &subdims[1], vecLen, &nrRegs, &outTypeName);

    sprintf(tmp, localVariables, outTypeName, nrRegs);
    kgenAddStmt(ctx, tmp);

    // check if offset exceeds matrix
    kgenAddStmt(ctx, "if ((workItemM >= M) ||"
                     "(workItemN >= N)) {\n"
                     "    return;\n"
                     "}\n");

    kgenAddStmt(ctx, "C += offsetC;\n");

    // zero C block
    sprintf(tmp, "for (k0 = 0; k0 < %u; k0++) {\n"
                 "    c[k0] = 0;\n"
                 "}\n\n",
            nrRegs);
    kgenAddStmt(ctx, tmp);

    // block multiplication inlined function
    sprintf(tmp, "for (k0 = 0; k0 < vectK; k0 += %lu)",
            subdims[1].bwidth / vecLen);
    kgenBeginBranch(ctx, tmp);

    mulOpts.aMobj = CLMEM_IMAGE;
    mulOpts.bMobj = CLMEM_IMAGE;
    mulOpts.flags = BLKMUL_OUTPUT_PRIVATE | BLKMUL_SKEW_ROW | BLKMUL_INLINE;
    if (isComplexType(dtype)) {
        mulOpts.core = BLKMUL_SEPARATE_MULADD;
    }
    else {
        mulOpts.core = BLKMUL_MAD;
    }
    mulOpts.argNames.coordA = "coordA";
    mulOpts.argNames.coordB = "coordB";
    mulOpts.argNames.skewCol = "skewCol";
    mulOpts.argNames.skewRow = "skewRow";
    mulOpts.argNames.k = "k0";
    mulOpts.argNames.vectBoundK = "vectK";
    ret = blkMulGen(ctx, subdims, dtype, &mulOpts);
    if (ret) {
        destroyKgenContext(ctx);
        return -EOVERFLOW;
    }

    // update image coordinates
    sprintf(tmp, "\ncoordA.x += %lu;\n"
                 "coordB.x += %lu;\n",
            subdims[1].bwidth / vecLen, subdims[1].bwidth / vecLen);
    kgenAddStmt(ctx, tmp);

    kgenEndBranch(ctx, NULL);

    // undo the row skew: rotate the accumulated c[] registers back into place
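    /*
     * The loop below emits, for each column of the register block, a shift of
     * the per-row registers by one position; the generated kernel repeats it
     * skewRow.y times, restoring the natural row order of the result tile.
     */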
    outVecLen = isComplexType(dtype) ? 1 : vecLen;
    p = tmp1;
    for (i = 0; i < regPitch / outVecLen; i++) {
        unsigned int k = (unsigned int)(subdims[1].y - 1) *
                         regPitch / outVecLen + i;

        sprintf(p, "\n"
                   "    tmp = c[%u];\n"
                   "    for (j = %lu; j >= 0; j--) {\n"
                   "        c[(j+1) * %u + %u] = c[j * %u + %u];\n"
                   "    }\n"
                   "    c[%u] = tmp;\n",
                k, subdims[1].y - 2, regPitch / outVecLen,
                i, regPitch / outVecLen, i, i);
        p += strlen(p);
    }
    sprintf(tmp, "\n"
                 "for (k0 = 0; k0 < skewRow.y; k0++) {\n"
                 "    int j;\n"
                 "    %s tmp;\n"
                 "%s"
                 "}\n"
                 "\n",
            outTypeName, tmp1);
    kgenAddStmt(ctx, tmp);

    tra = isMatrixAccessColMaj(CLBLAS_GEMM, kextra->flags, MATRIX_A);
    trb = isMatrixAccessColMaj(CLBLAS_GEMM, kextra->flags, MATRIX_B);
    sprintf(tmp, "coordA.%c = workItemM;\n"
                 "coordB.%c = workItemN;\n\n",
            vect[tra], vect[trb]);
    kgenAddStmt(ctx, tmp);

    // write back the evaluated tile
    generateResultUpdateOld(ctx, CLBLAS_GEMM, &gset, NULL, NULL);

    kgenEndFuncBody(ctx);
    ret = kgenAddBlankLine(ctx);

    if (!ret) {
        ret = (ssize_t)kgenSourceSize(ctx) + 1;
    }

    destroyKgenContext(ctx);

    return (ret < 0) ? -EOVERFLOW : ret;
}

static void
assignKargs(KernelArg *args, const void *params, const void *extra)
{
    const CLBlasKargs *blasArgs = (const CLBlasKargs*)params;

    (void)extra;

    switch (blasArgs->kernType) {
    case CLBLAS_COMPUTING_KERNEL:
        // arguments for computational kernel
        initSizeKarg(&args[0], blasArgs->M);
        initSizeKarg(&args[1], blasArgs->N);
        initSizeKarg(&args[2], blasArgs->K);
        assignScalarKarg(&args[3], &(blasArgs->alpha), blasArgs->dtype);
        INIT_KARG(&args[4], blasArgs->scimage[0]);
        INIT_KARG(&args[5], blasArgs->scimage[1]);
        assignScalarKarg(&args[6], &(blasArgs->beta), blasArgs->dtype);
        initMemobjKarg(&args[7], blasArgs->C, NULL, 0, 0);
        initSizeKarg(&args[8], blasArgs->ldc.matrix);
        initSizeKarg(&args[9], blasArgs->offCY);
        break;
    case CLBLAS_PREP_A_KERNEL:
        INIT_KARG(&args[0], blasArgs->order);
        INIT_KARG(&args[1], blasArgs->transA);
        initSizeKarg(&args[2], blasArgs->M);
        initSizeKarg(&args[3], blasArgs->K);
        initMemobjKarg(&args[4], blasArgs->A, NULL, 0, 0);
        initSizeKarg(&args[5], blasArgs->lda.matrix);
        INIT_KARG(&args[6], blasArgs->scimage[0]);
        initSizeKarg(&args[7], blasArgs->offA);
        break;
    case CLBLAS_PREP_B_KERNEL:
        INIT_KARG(&args[0], blasArgs->order);
        INIT_KARG(&args[1], blasArgs->transB);
        initSizeKarg(&args[2], blasArgs->N);
        initSizeKarg(&args[3], blasArgs->K);
        initMemobjKarg(&args[4], blasArgs->B, NULL, 0, 0);
        initSizeKarg(&args[5], blasArgs->ldb.matrix);
        INIT_KARG(&args[6], blasArgs->scimage[1]);
        initSizeKarg(&args[7], blasArgs->offBX);
        break;
    default:
        // this should not happen
        break;
    }
}

static bool
isFitToLDS(
    SubproblemDim *dim,
    DataType dtype,
    cl_ulong ldsSize,
    const void *kernelArgs)
{
    cl_ulong size;
    const CLBlasKargs *kargs = (const CLBlasKargs*)kernelArgs;
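    // check whether one block of matrix C, in bytes, fits into local memory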
    size = matrBlockSize(&dim[1], MATRIX_C, dtype, kargs->side);
    return (size * dtypeSize(dtype) <= ldsSize);
}

static void
calcNrThreads(
    size_t threads[2],
    const SubproblemDim *subdims,
    const PGranularity *pgran,
    const void *args,
    const void *extra)
{
    const CLBlasKargs *kargs = args;
    (void)extra;

    if (kargs->kernType != CLBLAS_COMPUTING_KERNEL) {
        const size_t *whole, *part;
        size_t nrGroups;

        // each work group prepares one block of the source matrix

        if (kargs->kernType == CLBLAS_PREP_A_KERNEL) {
            whole = &kargs->M;
            part = &subdims[0].itemY;
        }
        else {
            whole = &kargs->N;
            part = &subdims[0].itemX;
        }

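        // total groups = ceil(whole / part) * ceil(K / bwidth),
        // i.e. one group per block of the matrix being prepared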
        nrGroups = *whole / *part + (*whole % *part != 0);
        nrGroups *= (kargs->K / subdims[0].bwidth +
                     (kargs->K % subdims[0].bwidth != 0));
        threads[0] = pgran->wgSize[0] * nrGroups;
        threads[1] = pgran->wgSize[1];
    }
    else {
        calcGlobalThreads(threads, &subdims[0], pgran, kargs->M, kargs->N);
    }
}

static SolverFlags
solverFlags(void)
{
    return (SF_WSPACE_2D);
}

void
initGemmImgPattern(MemoryPattern *mempat)
{
    mempat->name = "Image based block gemm";
    mempat->nrLevels = 2;
    mempat->cuLevel = 0;
    mempat->thLevel = 1;
    mempat->sops = &imgSops;

    mpatExtra.aMset = CLMEM_LEVEL_L1 | CLMEM_LEVEL_LDS;
    mpatExtra.bMset = CLMEM_LEVEL_L1 | CLMEM_LEVEL_LDS;
    mpatExtra.mobjA = CLMEM_IMAGE;
    mpatExtra.mobjB = CLMEM_IMAGE;
    mempat->extra = &mpatExtra;
}

static int
imgGetPerf(
    unsigned int kflags,
    const void *args)
{
    (void)args;
    (void)kflags;

    return PPERF_POOR;
}