1 /* ************************************************************************
2 * Copyright 2013 Advanced Micro Devices, Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * ************************************************************************/
16
17
18 /*
19 * Cached global buffers based gemm generator
20 */
21
22 #include <string.h>
23 #include <stdio.h>
24 #include <assert.h>
25 #include <clblas_stddef.h>
26 #include <clBLAS.h>
27 #include <blas_mempat.h>
28 #include <clkern.h>
29 #include <clblas-internal.h>
30 #include <kprintf.hpp>
31 #include <gemm.clT>
32 #include <gemm_helper.clT>
33 #include <symm_helper.clT>
34 #include <solution_seq.h>
35 #include "tuned_numbers.h"
36
//#define DEBUG_GEMM_2

/*
 * Memory-pattern extra descriptor for this pattern; filled in and published
 * through mempat->extra by initGemmV2CachedPattern().
 */
static CLBLASMpatExtra mpatExtra;

/*
 * One prefix character per DataType ('S', 'D', 'C', 'Z'), indexed by the
 * TYPE_* constants. Filled in by initGemmV2CachedPattern() and passed as the
 * first argument to the kprintf template expander in generator().
 */
static char Prefix[4];

/* Function, finding default decomposition */
static int
getDefaultDecomposition(
    PGranularity *pgran,
    SubproblemDim *subdims,
    unsigned int subdimsNum,
    void *pArgs);

/* Kernel source generator; a NULL buf queries the required buffer size. */
static ssize_t
generator(
    char *buf,
    size_t buflen,
    const struct SubproblemDim *subdims,
    const struct PGranularity *pgran,
    void *extra);

/* Fills the kernel argument array from the BLAS call parameters. */
static void
assignKargs(KernelArg *args, const void *params, const void *extra);

/* Solver flags reported for this memory pattern. */
static SolverFlags
solverFlags(void);

/* Appends the "-D..." OpenCL build options for one solution step. */
static void
setBuildOpts(
    char * buildOptStr,
    const void *args);

/* Computes the global work size (threads[0], threads[1]) for a launch. */
static void
calcNrThreads(
    size_t threads[2],
    const SubproblemDim *subdims,
    const PGranularity *pgran,
    const void *args,
    const void *extra);

/*
 * Solver callbacks for this memory pattern, installed by
 * initGemmV2CachedPattern(). This is a positional initializer: every slot
 * must line up with the field order of SolverOps (declared elsewhere).
 * NOTE(review): the NULL slots are presumably treated as "hook not
 * provided" by the solver framework - confirm against SolverOps.
 */
static SolverOps gemmSops = {
    generator,
    assignKargs,
    NULL,
    NULL,
    NULL,
    calcNrThreads,
    NULL,
    solverFlags,
    NULL,
    getDefaultDecomposition,
    NULL,
    setBuildOpts,
    NULL
};
92
93 static void
calcNrThreads(size_t threads[2],const SubproblemDim * subdims,const PGranularity * pgran,const void * args,const void * extra)94 calcNrThreads(
95 size_t threads[2],
96 const SubproblemDim *subdims,
97 const PGranularity *pgran,
98 const void *args,
99 const void *extra)
100 {
101 const CLBlasKargs *kargs = (const CLBlasKargs *)args;
102 //const CLBLASKernExtra *kextra = ( CLBLASKernExtra *)extra;
103 //KernelExtraFlags kflags = kextra->flags;
104 size_t M, N;
105
106 M = kargs->M;
107 N = kargs->N;
108
109 threads[1] = 1;
110
111 if ((subdims->x != SUBDIM_UNUSED) &&
112 (subdims->y != SUBDIM_UNUSED)) {
113
114 size_t groupWorkX, groupWorkY;
115 size_t nrGroupsX, nrGroupsY;
116 int nrDims;
117
118 groupWorkX = subdims->x;
119 groupWorkY = subdims->y;
120
121 nrGroupsX = N / groupWorkX;
122 if (N % groupWorkX) {
123 nrGroupsX++;
124 }
125
126 nrGroupsY = M / groupWorkY;
127 if (M % groupWorkY) {
128 nrGroupsY++;
129 }
130 nrDims = (pgran == NULL) ? 1 : pgran->wgDim;
131 threads[0] = nrGroupsX * nrGroupsY;
132
133 if(kargs->pigFuncID == CLBLAS_HERK)
134 {
135 threads[0] = (nrGroupsY * (nrGroupsY + 1)) / 2;
136 }
137
138 }
139
140 if (pgran != NULL) {
141 threads[0] *= pgran->wgSize[0];
142 threads[1] *= pgran->wgSize[1];
143 }
144 }
145
146 static void
setBuildOpts(char * buildOptStr,const void * args)147 setBuildOpts(
148 char * buildOptStr,
149 const void *args)
150 {
151 SolutionStep *step = (SolutionStep *)args;
152 const CLBlasKargs *kargs = (const CLBlasKargs *)(&step->args);
153 const SubproblemDim *dims = step->subdims;
154 //size_t vecLen = sizeof(cl_float4)/dtypeSize(kargs->dtype);
155 KernelExtraFlags kflags = step->extraFlags;
156
157 blockSizes bestSize = bestBlockSizeForDevice( step );
158
159 if ( kargs->dtype == TYPE_DOUBLE || kargs->dtype == TYPE_COMPLEX_DOUBLE)
160 {
161 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DDOUBLE_PRECISION");
162 }
163
164 if (isComplexType(kargs->dtype))
165 {
166 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DCOMPLEX");
167 }
168
169 if ((bestSize.useBarrier) == 1)
170 {
171 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DGEMM_NEEDS_BARRIER");
172 }
173
174 if (kargs->M % dims->y)
175 {
176 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DM_TAIL_PRESENT");
177 }
178
179 if (kargs->N % dims->x)
180 {
181 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DN_TAIL_PRESENT");
182 }
183
184 if (kflags & KEXTRA_CONJUGATE_A)
185 {
186 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DCONJUGATE_A");
187 }
188 if (kflags & KEXTRA_CONJUGATE_B)
189 {
190 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DCONJUGATE_B");
191 }
192
193 switch(kargs->pigFuncID)
194 {
195 case CLBLAS_HEMM:
196 case CLBLAS_SYMM:
197 case CLBLAS_SYMM_DIAGONAL:
198 case CLBLAS_HEMM_DIAGONAL:
199 #ifdef DEBUG_GEMM_2
200 printf("GEMM2: setBuildOpts: Setting options for SYMM\n");
201 #endif
202 if (kargs->side == clblasLeft)
203 {
204 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__SYMM_LEFT__");
205 }
206 if (kargs->side == clblasRight)
207 {
208 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__SYMM_RIGHT__");
209 }
210 if (kargs->uplo == clblasLower)
211 {
212 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__SYMM_LOWER__");
213 }
214 if (kargs->uplo == clblasUpper)
215 {
216 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__SYMM_UPPER__");
217 }
218 // Define the order for Legacy sake.
219 if (kargs->order == clblasColumnMajor)
220 {
221 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__SYMM_COLMAJOR__");
222 } else {
223 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__SYMM_ROWMAJOR__");
224 }
225 if ((kargs->pigFuncID == CLBLAS_SYMM_DIAGONAL) || (kargs->pigFuncID == CLBLAS_HEMM_DIAGONAL))
226 {
227 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__SYMM_DIAGONAL__");
228 }
229 if (kargs->pigFuncID == CLBLAS_HEMM_DIAGONAL)
230 {
231 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__HEMM__");
232 }
233 break;
234
235 case CLBLAS_HERK:
236 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DHERK");
237 if(kargs->uplo == clblasLower)
238 {
239 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DHERK_LOWER_TRIANGLE");
240 }
241 else if(kargs->uplo == clblasUpper)
242 {
243 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DHERK_UPPER_TRIANGLE");
244 }
245 break;
246
247 default:
248 break;
249 }
250
251 #ifdef DEBUG_GEMM_2
252 printf("buildStr: %s\n", buildOptStr);
253 #endif
254 return;
255 }
256
257 static ssize_t
generator(char * buf,size_t buflen,const struct SubproblemDim * subdims,const struct PGranularity * pgran,void * extra)258 generator(
259 char *buf,
260 size_t buflen,
261 const struct SubproblemDim *subdims,
262 const struct PGranularity *pgran,
263 void *extra)
264 {
265 CLBLASKernExtra *kextra = (CLBLASKernExtra*)extra;
266 KernelExtraFlags kflags = kextra->flags;
267 DataType dtype = kextra->dtype;
268 char tempTemplate[64*1024]; //PENDING: Is it safe to have 64K in stack for threadSafety?
269 char itemx[10], itemy[10], width[10], itemy_by_width[10], itemx_by_width[10];
270 char bwidth[10], panel_by_v[10];
271 size_t Y, X, BLOCKSIZE, ITEMX, ITEMY;
272 bool doVLOAD = false;
273 unsigned int veclen;
274
275 if (buf == NULL)
276 {
277 buflen = 64*1024*sizeof(char);
278 return (ssize_t)buflen;
279 }
280
281 //
282 // PENDING: Add Support for Row Major
283 //
284 if ((kflags & KEXTRA_COLUMN_MAJOR) == 0)
285 {
286 return 0;
287 }
288
289 if ((kflags & KEXTRA_NO_COPY_VEC_A) || (kflags & KEXTRA_NO_COPY_VEC_B) || (kflags & KEXTRA_NO_COPY_VEC_C))
290 {
291 #ifdef DEBUG_GEMM_2
292 printf("GEMM2: Doing un-aligned access\n");
293 #endif
294 doVLOAD= true;
295 } else {
296 #ifdef DEBUG_GEMM_2
297 printf("GEMM2: Doing Aligned access\n");
298 #endif
299 }
300
301
302 BLOCKSIZE = pgran->wgSize[0];
303 #ifdef DEBUG_GEMM_2
304 printf("GEMM2- generator(): Blocksize passed = %lu, subdimy = %lu, subdimx = %lu, veclen = %d \n",
305 BLOCKSIZE, subdims->y, subdims->x, kextra->vecLen);
306 #endif
307
308 veclen = kextra->vecLen;
309
310 ITEMY = subdims->itemY;
311 ITEMX = subdims->itemX;
312 Y = subdims->y / ITEMY;
313 X = subdims->x / ITEMX;
314
315 //
316 // Handle in-compatible subdims and workgroup sizes
317 // We will use "veclen" of 1 as our shield against these in-compatible
318 // geometries.
319 //
320 if ( (ITEMY % kextra->vecLen) || ((ITEMX % kextra->vecLen) && (kflags & KEXTRA_TRANS_B)) )
321 {
322 //
323 // FIXME:
324 // This kernel must be stored against vecLen of 1 in Kernel Cache.
325 // This needs change in EXTRA structure. However, this is against the API.
326 // We are going against the API by changing fields in EXTRA structure.
327 // One alternate FIX is to return an error.
328 //
329 kextra->vecLen = kextra->vecLenA = kextra->vecLenB = kextra->vecLenC = 1;
330
331 doVLOAD = true;
332 veclen = 1;
333 }
334
335 //
336 // PENDING: Selective Vectorization for A, B and C access has to be added
337 // in KPRINTF module (VLOADA, VLOADB, VLOADC, VSTOREC)
338 //
339 kprintf kobj(Prefix[dtype], veclen, doVLOAD, doVLOAD); // Only Vectored Access
340 sprintf(width, "%lu", Y);
341 sprintf(itemy, "%lu", ITEMY);
342 sprintf(itemx, "%lu", ITEMX);
343 sprintf(itemy_by_width, "%lu", (size_t) ITEMY/veclen);
344 sprintf(itemx_by_width, "%lu", (size_t) ITEMX/veclen);
345 //sprintf(bwidth, "%lu", subdims->bwidth);
346 //sprintf(panel_by_v, "%lu", (subdims->bwidth / veclen));
347 sprintf(bwidth, "%lu", (size_t) veclen);
348 sprintf(panel_by_v, "%lu", (size_t) 1);
349
350 kobj.put("%WIDTH", width);
351 kobj.put("%ITEMX", itemx);
352 kobj.put("%ITEMY", itemy);
353 kobj.put("%ITEMY_BY_V", itemy_by_width);
354 kobj.put("%ITEMX_BY_V", itemx_by_width);
355 kobj.put("%PANEL", bwidth);
356 kobj.put("%PANEL_BY_V", panel_by_v);
357 #ifdef DEBUG_GEMM_2
358 printf("ColMajor GEMM - WIDTH = %s, PANEL = %lu, ITEMX = %s, ITEMY = %s, Veclen = %lu\n", width, subdims->bwidth, itemx, itemy, veclen);
359 #endif
360
361 strcpy(tempTemplate, SYMM_HEMM_HELPER);
362 if ((kflags & KEXTRA_TRANS_A) == 0)
363 {
364 if (kflags & KEXTRA_TRANS_B)
365 {
366 #ifdef DEBUG_GEMM_2
367 printf("Using GEMM_NT_KERNEL\n");
368 #endif
369 strcat(tempTemplate, GEMM_HELPER);
370 strcat(tempTemplate, GEMM_NT_KERNEL);
371 } else {
372 #ifdef DEBUG_GEMM_2
373 printf("Using GEMM_NN_KERNEL\n");
374 #endif
375 strcat(tempTemplate, GEMM_HELPER);
376 strcat(tempTemplate, GEMM_NN_KERNEL);
377 }
378 } else {
379 // PENDING:
380 if (kflags & KEXTRA_TRANS_B)
381 {
382 tempTemplate[0] = 0;
383 } else {
384 #ifdef DEBUG_GEMM_2
385 printf("Using GEMM_TN_KERNEL\n");
386 #endif
387 strcat(tempTemplate, GEMM_HELPER);
388 strcat(tempTemplate, GEMM_TN_KERNEL);
389 }
390 }
391 kobj.spit(buf, tempTemplate);
392 #ifdef DEBUG_GEMM_KPRINTF
393 printf("Kernel = \n%s\n", buf);
394 #endif
395 size_t tail = strlen(buf) + 1;
396 while(tail < 64*1024)
397 {
398 buf[tail++] = 0;
399 }
400 return 64*1024*sizeof(char);
401 }
402
403 static void
assignKargs(KernelArg * args,const void * params,const void *)404 assignKargs(KernelArg *args, const void *params, const void*)
405 {
406 CLBlasKargs *blasArgs = (CLBlasKargs*)params;
407
408 #ifdef DEBUG_GEMM_2
409 printf("SAlpha=%f, DAlpha=%f, CAlpha =<%f, %f>, DAlpha=<%f, %f>\n",
410 blasArgs->alpha.argFloat, blasArgs->alpha.argDouble, CREAL(blasArgs->alpha.argFloatComplex), CIMAG(blasArgs->alpha.argFloatComplex),
411 CREAL(blasArgs->alpha.argDoubleComplex) , CIMAG(blasArgs->alpha.argDoubleComplex));
412 printf("SBeta=%f, DBeta=%f, CBeta=<%f, %f>, DBeta=<%f, %f>\n",
413 blasArgs->beta.argFloat, blasArgs->beta.argDouble, CREAL(blasArgs->beta.argFloatComplex), CIMAG(blasArgs->beta.argFloatComplex),
414 CREAL(blasArgs->beta.argDoubleComplex) , CIMAG(blasArgs->beta.argDoubleComplex));
415 #endif
416
417 INIT_KARG(&args[0], blasArgs->A); //A - input matrix - argument
418 INIT_KARG(&args[1], blasArgs->B); //x - result buffer = _xnew argument
419 INIT_KARG(&args[2], blasArgs->C); //y - scratch == _x_vector argument
420 initSizeKarg(&args[3], blasArgs->M);
421 initSizeKarg(&args[4], blasArgs->N);
422 initSizeKarg(&args[5], blasArgs->K);
423 initSizeKarg(&args[6], blasArgs->lda.matrix);
424 initSizeKarg(&args[7], blasArgs->ldb.matrix);
425 initSizeKarg(&args[8], blasArgs->ldc.matrix);
426 initSizeKarg(&args[9], blasArgs->offA);
427 initSizeKarg(&args[10], blasArgs->offBX);
428 initSizeKarg(&args[11], blasArgs->offCY);
429 assignScalarKarg(&args[12], &(blasArgs->alpha), blasArgs->dtype);
430 assignScalarKarg(&args[13], &(blasArgs->beta), blasArgs->dtype);
431 return;
432 }
433
434 static SolverFlags
solverFlags(void)435 solverFlags(void)
436 {
437 return (SF_WSPACE_1D);
438 }
439
440 extern "C"
441 void
initGemmV2CachedPattern(MemoryPattern * mempat)442 initGemmV2CachedPattern(MemoryPattern *mempat)
443 {
444 mempat->name = "Cached global memory based block gemm";
445 mempat->nrLevels = 2;
446 mempat->cuLevel = 0;
447 mempat->thLevel = 1;
448 mempat->sops = &gemmSops;
449
450 mpatExtra.aMset = CLMEM_LEVEL_L1;
451 mpatExtra.bMset = CLMEM_LEVEL_L1;
452 mpatExtra.mobjA = CLMEM_BUFFER;
453 mpatExtra.mobjB = CLMEM_BUFFER;
454 mempat->extra = &mpatExtra;
455
456
457 Prefix[TYPE_FLOAT] = 'S';
458 Prefix[TYPE_DOUBLE] = 'D';
459 Prefix[TYPE_COMPLEX_FLOAT] = 'C';
460 Prefix[TYPE_COMPLEX_DOUBLE] = 'Z';
461 }
462
463 static int
getDefaultDecomposition(PGranularity * pgran,SubproblemDim * subdims,unsigned int subdimsNum,void * pArgs)464 getDefaultDecomposition(
465 PGranularity *pgran,
466 SubproblemDim *subdims,
467 unsigned int subdimsNum,
468 void *pArgs)
469 {
470
471 DUMMY_ARG_USAGE(pArgs);
472 //
473 // FIXME: container_of() - Counts on the fact that "getDefaultDecomposition" is called
474 // with step->pgran, step->subdims
475 //
476 SolutionStep *step = container_of( pgran , pgran, SolutionStep);
477
478 blockSizes bestSize = bestBlockSizeForDevice( step );
479
480 pgran->wgSize[0] = bestSize.TY * bestSize.TX;
481 pgran->wgSize[1] = 1;
482 pgran->wgDim = 1;
483
484 if (subdimsNum >= 1)
485 {
486 subdims[0].y = bestSize.TY * bestSize.ITEMY;
487 subdims[0].x = bestSize.TX * bestSize.ITEMX;
488 subdims[0].itemY = bestSize.ITEMY;
489 subdims[0].itemX = bestSize.ITEMX;
490 subdims[0].bwidth = 4;
491 }
492 if (subdimsNum >= 2)
493 {
494 subdims[1].y = bestSize.TY * bestSize.ITEMY;
495 subdims[1].x = bestSize.TX * bestSize.ITEMX;
496 subdims[1].itemY = bestSize.ITEMY;
497 subdims[1].itemX = bestSize.ITEMX;
498 subdims[1].bwidth = 4;
499 }
500
501 return 0;
502 }
503
504