1 /* ************************************************************************
2 * Copyright 2013 Advanced Micro Devices, Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * ************************************************************************/
16
17
/*
 * Cached global buffers based gemm tail generator: builds the kernel that
 * processes the M/N tail regions (compiled with -DTAIL_RUN).
 */
21
22 #include <string.h>
23 #include <stdio.h>
24 #include <assert.h>
25 #include <clblas_stddef.h>
26 #include <clBLAS.h>
27 #include <blas_mempat.h>
28 #include <clkern.h>
29 #include <clblas-internal.h>
30 #include <kprintf.hpp>
31 #include <gemm.clT>
32 #include <symm_helper.clT>
33 #include <solution_seq.h>
34
// Forward declarations, with C linkage, of the N-tail / M-tail query helpers.
// No definition and no call site is visible in this file.
extern "C" int
gemmHasNTail(size_t N, int vecLen, clblasOrder order, clblasTranspose transA, clblasTranspose transB);

extern "C" int
gemmHasMTail(size_t M, int vecLen, clblasOrder order, clblasTranspose transA, clblasTranspose transB);
40
41
// Uncomment to enable the printf() tracing guarded by DEBUG_GEMM_TAIL below.
//#define DEBUG_GEMM_TAIL
// Pattern-specific data; initGemmV2TailCachedPattern() fills it in and
// publishes its address through mempat->extra.
static CLBLASMpatExtra mpatExtra;

// One character per data type, indexed by the TYPE_* constants. Populated by
// initGemmV2TailCachedPattern() and passed to the kprintf constructor in
// generator().
static char Prefix[4];
46
// Forward declarations of the callbacks referenced by the gemmSops table
// below; their definitions follow later in this file.

// Writes the kernel source into buf, or reports the required size when buf is NULL.
static ssize_t
generator(
    char *buf,
    size_t buflen,
    const struct SubproblemDim *subdims,
    const struct PGranularity *pgran,
    void *extra);

// Fills the kernel argument array from the BLAS call arguments.
static void
assignKargs(KernelArg *args, const void *params, const void *extra);

// Returns the solver flags of this pattern.
static SolverFlags
solverFlags(void);

// Appends the preprocessor defines needed by the tail kernel to buildOptStr.
static void
setBuildOpts(
    char * buildOptStr,
    const void *args);

// Computes the global thread counts for the tail kernel launch.
static void
calcNrThreads(
    size_t threads[2],
    const SubproblemDim *subdims,
    const PGranularity *pgran,
    const void *args,
    const void *extra);
73
// Callback table for the gemm tail pattern; installed into mempat->sops by
// initGemmV2TailCachedPattern(). The initializer is positional, so the meaning
// of each slot is fixed by the SolverOps declaration (not visible in this
// file). NULL slots are presumably operations this pattern does not provide --
// TODO confirm against the SolverOps definition.
static SolverOps gemmSops = {
    generator,
    assignKargs,
    NULL,
    NULL,
    NULL,
    calcNrThreads,
    NULL,
    solverFlags,
    NULL,
    NULL,
    NULL,
    setBuildOpts,
    NULL
};
89
90 static void
setBuildOpts(char * buildOptStr,const void * args)91 setBuildOpts(
92 char * buildOptStr,
93 const void *args)
94 {
95 const SolutionStep *step = (const SolutionStep *)args;
96 const CLBlasKargs *kargs = (const CLBlasKargs *)(&step->args);
97 KernelExtraFlags kflags = step->extraFlags;
98
99 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DTAIL_RUN -DM_TAIL_PRESENT -DN_TAIL_PRESENT");
100 if ( kargs->dtype == TYPE_DOUBLE || kargs->dtype == TYPE_COMPLEX_DOUBLE)
101 {
102 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DDOUBLE_PRECISION");
103 #ifdef DEBUG_GEMM_TAIL
104 printf("Setting build options ... Double... for DOUBLE PRECISION support\n");
105 #endif
106 }
107
108 if (isComplexType(kargs->dtype))
109 {
110 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DCOMPLEX");
111 }
112
113 if (kflags & KEXTRA_CONJUGATE_A)
114 {
115 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DCONJUGATE_A");
116 }
117 if (kflags & KEXTRA_CONJUGATE_B)
118 {
119 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DCONJUGATE_B");
120 }
121
122
123 switch(kargs->pigFuncID)
124 {
125 case CLBLAS_GEMM2:
126 case CLBLAS_GEMM_TAIL:
127 break;
128
129 case CLBLAS_HERK:
130 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DHERK");
131 if(kargs->uplo == clblasLower)
132 {
133 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DHERK_LOWER_TRIANGLE");
134 }
135 else if(kargs->uplo == clblasUpper)
136 {
137 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DHERK_UPPER_TRIANGLE");
138 }
139 break;
140
141 case CLBLAS_HEMM:
142 case CLBLAS_SYMM_DIAGONAL:
143 case CLBLAS_HEMM_DIAGONAL:
144 case CLBLAS_SYMM:
145 #ifdef DEBUG_GEMM_2
146 printf("GEMM2: setBuildOpts: Setting options for SYMM\n");
147 #endif
148 if (kargs->side == clblasLeft)
149 {
150 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__SYMM_LEFT__");
151 }
152 if (kargs->side == clblasRight)
153 {
154 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__SYMM_RIGHT__");
155 }
156 if (kargs->uplo == clblasLower)
157 {
158 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__SYMM_LOWER__");
159 }
160 if (kargs->uplo == clblasUpper)
161 {
162 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__SYMM_UPPER__");
163 }
164 // Define the order for Legacy sake.
165 if (kargs->order == clblasColumnMajor)
166 {
167 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__SYMM_COLMAJOR__");
168 } else {
169 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__SYMM_ROWMAJOR__");
170 }
171 if ((kargs->pigFuncID == CLBLAS_SYMM_DIAGONAL) || (kargs->pigFuncID == CLBLAS_HEMM_DIAGONAL))
172 {
173 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__SYMM_DIAGONAL__");
174 }
175 if (kargs->pigFuncID == CLBLAS_HEMM_DIAGONAL)
176 {
177 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__HEMM__");
178 }
179 break;
180
181 default:
182 printf("GEMM TAIL: Unknown pigFuncID\n");
183 break;
184 }
185 #ifdef DEBUG_GEMM_TAIL
186 printf("GEMMTAIL: Build options = %s\n", buildOptStr);
187 #endif
188 }
189
190 static void
calcNrThreads(size_t threads[2],const SubproblemDim * subdims,const PGranularity * pgran,const void * args,const void * extra)191 calcNrThreads(
192 size_t threads[2],
193 const SubproblemDim *subdims,
194 const PGranularity *pgran,
195 const void *args,
196 const void *extra)
197 {
198 int BLOCKSIZE = pgran->wgSize[0]; // 1D Block
199 size_t tailM, tailN, M, N;
200 size_t Y, X;
201 size_t nWorkGroupsAY, nWorkGroupsAX, nWorkGroupsA;
202 size_t nWorkGroupsBY, nWorkGroupsBX, nWorkGroupsB;
203 size_t totalWorkGroups;
204 #ifdef DEBUG_GEMM_TAIL
205 printf("calcNrThreads called from gemm_tail.cpp\n");
206 #endif
207 const CLBlasKargs *kargs = (const CLBlasKargs *)args;
208 const CLBLASKernExtra *kextra = ( CLBLASKernExtra *)extra;
209 KernelExtraFlags kflags = kextra->flags;
210
211 //
212 // RowMajor GEMM can be expressed in terms of Column Major GEMM
213 //
214 if ((kflags & KEXTRA_COLUMN_MAJOR) == 0)
215 {
216 printf("calcNrThreads: FIXME: RowMajor is NOT supported \n");
217 return;
218 }
219
220 if (kextra->vecLenA != 1)
221 {
222 printf("GEMM_TAIL: calcNrThreads(): Vector Length must be 1 for TAIL. Non-one Vector Length Requested\n");
223 return;
224 }
225
226 tailM = kargs->tailStartM;
227 tailN = kargs->tailStartN;
228 M = kargs->M;
229 N = kargs->N;
230
231 Y = 8;
232 if (Y != subdims->y)
233 {
234 Y = subdims->y;
235 }
236 X = BLOCKSIZE/Y;
237 /*
238 LEGACY CODE: Outdated now. TAIL can handle this condition now using MTAIL_PRESENT and NTAIL_PRESENT
239 if (tailN % X)
240 {
241 printf("GEMM_TAIL: calcNrThreads(): WARNING: tailN is not divisible by X. Will produce Wrong results!\n");
242 }
243 */
244
245 //
246 // A Tail Workgroup will process YxX panel
247 //
248 /*
249 ______________
250 | | |
251 | | |
252 | | | B Tail panel
253 |___________| |
254 |___________|__|
255 <--- A -->
256 */
257 if(tailM != M)
258 {
259 #ifdef DEBUG_GEMM_TAIL
260 printf("GEMM_TAIL: M has TAIL\n");
261 #endif
262 nWorkGroupsAY = ((M - tailM -1)/Y + 1);
263 nWorkGroupsAX = ((tailN - 1)/X + 1);
264 nWorkGroupsA = nWorkGroupsAY * nWorkGroupsAX;
265 } else {
266 nWorkGroupsA = 0;
267 }
268
269 if (tailN != N)
270 {
271 #ifdef DEBUG_GEMM_TAIL
272 printf("GEMM_TAIL: N has TAIL\n");
273 #endif
274 nWorkGroupsBY = ((M-1)/Y) + 1;
275 nWorkGroupsBX = ((N-tailN-1)/X) + 1;
276 nWorkGroupsB = nWorkGroupsBY * nWorkGroupsBX;
277 } else {
278 nWorkGroupsB = 0;
279 }
280
281 totalWorkGroups = nWorkGroupsA + nWorkGroupsB;
282
283 threads[0] = totalWorkGroups * BLOCKSIZE;
284 threads[1] = 1;
285 #ifdef DEBUG_GEMM_TAIL
286 printf("GEMM_TAIL: calcNrThreads(): vlen:%d, <tailM:%lu, M:%lu>, <tailN:%lu, N:%lu, nWorkGroupsA<%lu,%lu>, nWorkGroupsB<%lu,%lu>\n",
287 kextra->vecLenA, tailM, M, tailN, N, nWorkGroupsAY, nWorkGroupsAX, nWorkGroupsBY, nWorkGroupsBX);
288 printf("GEMM_TAIL: calcNrThreads(): globalThreads0=%lu, globalThreads1=%lu\n", threads[0], threads[1]);
289 #endif
290 return;
291 }
292
293 static ssize_t
generator(char * buf,size_t buflen,const struct SubproblemDim * subdims,const struct PGranularity * pgran,void * extra)294 generator(
295 char *buf,
296 size_t buflen,
297 const struct SubproblemDim *subdims,
298 const struct PGranularity *pgran,
299 void *extra)
300 {
301 CLBLASKernExtra *kextra = (CLBLASKernExtra*)extra;
302 KernelExtraFlags kflags = kextra->flags;
303 DataType dtype = kextra->dtype;
304 char tempTemplate[32*1024];
305 char itemx[10], itemy[10], width[10], itemy_by_width[10], itemx_by_width[10];
306 size_t Y, X, BLOCKSIZE, ITEMX, ITEMY;
307
308 if (buf == NULL)
309 {
310 buflen = 32*1024*sizeof(char);
311 return (ssize_t)buflen;
312 }
313
314 //
315 // PENDING: Add Support for Row Major at the xAPI.c level
316 // Row major calcs can be expressed in terms of column major
317 //
318 if ((kflags & KEXTRA_COLUMN_MAJOR) == 0)
319 {
320 return 0;
321 }
322
323 kprintf kobj(Prefix[dtype], 1, false, false); // Only Scalar Access
324
325 BLOCKSIZE = pgran->wgSize[0];
326 #ifdef DEBUG_GEMM_TAIL
327 printf("GEMM- generator(): Blocksize passed = %lu, subdimy = %lu, subdimx = %lu, veclen = %d \n", BLOCKSIZE, subdims->y, subdims->x, kextra->vecLenA);
328 #endif
329
330 Y = 8;
331 if (Y != subdims->y)
332 {
333 //printf("GEMM_TAIL: generator(): WARNING: subdims->y is un-suitable.\n");
334 Y = subdims->y;
335 }
336 X = BLOCKSIZE/Y;
337 ITEMY = (subdims->y) / Y;
338 ITEMX = (subdims->x) / X;
339 if (ITEMX == 0)
340 {
341 ITEMX = 1;
342 }
343
344 if ((BLOCKSIZE % Y) || ((subdims->y) % Y) || ((subdims->x)%X) || (ITEMY % kextra->vecLenA) || ((X*ITEMX) % kextra->vecLenA))
345 {
346 printf("WARNING: GEMM TAIL - generator: subdim and blocksize in-compatible. This code should never execute!\n");
347 }
348
349 sprintf(width, "%lu", Y);
350 sprintf(itemy, "%lu", ITEMY);
351 sprintf(itemx, "%lu", ITEMX);
352 sprintf(itemy_by_width, "%lu", (size_t) ITEMY/kextra->vecLenA);
353 sprintf(itemx_by_width, "%lu", (size_t) ITEMX/kextra->vecLenA);
354
355 kobj.put("%WIDTH", width);
356 kobj.put("%ITEMX", itemx);
357 kobj.put("%ITEMY", itemy);
358 kobj.put("%ITEMY_BY_V", itemy_by_width);
359 kobj.put("%ITEMX_BY_V", itemx_by_width);
360 kobj.put("%PANEL", "1");
361 kobj.put("%PANEL_BY_V", "1");
362 #ifdef DEBUG_GEMM_TAIL
363 printf("ColMajor GEMM - WIDTH = %s, ITEMX = %s, ITEMY = %s\n", width, itemx, itemy);
364 #endif
365
366 strcpy(tempTemplate, SYMM_HEMM_HELPER);
367 if ((kflags & KEXTRA_TRANS_A) == 0)
368 {
369 if (kflags & KEXTRA_TRANS_B)
370 {
371 #ifdef DEBUG_GEMM_TAIL
372 printf("GEMM_TAIL: Using GEMM_NT_KERNEL\n");
373 #endif
374 strcat(tempTemplate, GEMM_NT_KERNEL);
375 } else {
376 #ifdef DEBUG_GEMM_TAIL
377 printf("GEMM_TAIL: Using GEMM_NN_KERNEL\n");
378 #endif
379 strcat(tempTemplate, GEMM_NN_KERNEL);
380 }
381 } else {
382 //
383 // NOTE: A^T * B Never leaves any tails. This should NEVER be called.
384 // PENDING: A^T * B^T support is PENDING
385 tempTemplate[0] = 0;
386 }
387
388 kobj.spit(buf, tempTemplate);
389 //#ifdef DEBUG_GEMM_TAIL
390 //printf("Kernel = \n%s\n", buf);
391 //#endif
392 size_t tail = strlen(buf) + 1;
393 while(tail < 32*1024)
394 {
395 buf[tail++] = 0;
396 }
397 return 32*1024*sizeof(char);
398 }
399
400 static void
assignKargs(KernelArg * args,const void * params,const void *)401 assignKargs(KernelArg *args, const void *params, const void*)
402 {
403 CLBlasKargs *blasArgs = (CLBlasKargs*)params;
404
405 #ifdef DEBUG_GEMM_TAIL
406 printf("SAlpha=%f, DAlpha=%f, CAlpha =<%f, %f>, DAlpha=<%f, %f>\n",
407 blasArgs->alpha.argFloat, blasArgs->alpha.argDouble, CREAL(blasArgs->alpha.argFloatComplex), CIMAG(blasArgs->alpha.argFloatComplex),
408 CREAL(blasArgs->alpha.argDoubleComplex) , CIMAG(blasArgs->alpha.argDoubleComplex));
409 printf("SBeta=%f, DBeta=%f, CBeta=<%f, %f>, DBeta=<%f, %f>\n",
410 blasArgs->beta.argFloat, blasArgs->beta.argDouble, CREAL(blasArgs->beta.argFloatComplex), CIMAG(blasArgs->beta.argFloatComplex),
411 CREAL(blasArgs->beta.argDoubleComplex) , CIMAG(blasArgs->beta.argDoubleComplex));
412 printf("TailStartM = %lu, TailStartN = %lu\n", blasArgs->tailStartM, blasArgs->tailStartN);
413 #endif
414
415 INIT_KARG(&args[0], blasArgs->A); //A - input matrix - argument
416 INIT_KARG(&args[1], blasArgs->B); //x - result buffer = _xnew argument
417 INIT_KARG(&args[2], blasArgs->C); //y - scratch == _x_vector argument
418 initSizeKarg(&args[3], blasArgs->M);
419 initSizeKarg(&args[4], blasArgs->N);
420 initSizeKarg(&args[5], blasArgs->K);
421 initSizeKarg(&args[6], blasArgs->lda.matrix);
422 initSizeKarg(&args[7], blasArgs->ldb.matrix);
423 initSizeKarg(&args[8], blasArgs->ldc.matrix);
424 initSizeKarg(&args[9], blasArgs->offA);
425 initSizeKarg(&args[10], blasArgs->offBX);
426 initSizeKarg(&args[11], blasArgs->offCY);
427 assignScalarKarg(&args[12], &(blasArgs->alpha), blasArgs->dtype);
428 assignScalarKarg(&args[13], &(blasArgs->beta), blasArgs->dtype);
429 initSizeKarg(&args[14], blasArgs->tailStartM);
430 initSizeKarg(&args[15], blasArgs->tailStartN);
431 return;
432 }
433
434 static SolverFlags
solverFlags(void)435 solverFlags(void)
436 {
437 return (SF_WSPACE_1D);
438 }
439
440 extern "C"
441 void
initGemmV2TailCachedPattern(MemoryPattern * mempat)442 initGemmV2TailCachedPattern(MemoryPattern *mempat)
443 {
444 mempat->name = "Cached global memory based gemm tail";
445 mempat->nrLevels = 2;
446 mempat->cuLevel = 0;
447 mempat->thLevel = 1;
448 mempat->sops = &gemmSops;
449
450 mpatExtra.aMset = CLMEM_LEVEL_L1;
451 mpatExtra.bMset = CLMEM_LEVEL_L1;
452 mpatExtra.mobjA = CLMEM_BUFFER;
453 mpatExtra.mobjB = CLMEM_BUFFER;
454 mempat->extra = &mpatExtra;
455
456
457 Prefix[TYPE_FLOAT] = 'S';
458 Prefix[TYPE_DOUBLE] = 'D';
459 Prefix[TYPE_COMPLEX_FLOAT] = 'C';
460 Prefix[TYPE_COMPLEX_DOUBLE] = 'Z';
461 }
462
463