1 /* ************************************************************************
2  * Copyright 2013 Advanced Micro Devices, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  * ************************************************************************/
16 
17 
18 /*
19  * Cached global buffers based gemm generator
20  */
21 
22 #include <string.h>
23 #include <stdio.h>
24 #include <assert.h>
25 #include <clblas_stddef.h>
26 #include <clBLAS.h>
27 #include <blas_mempat.h>
28 #include <clkern.h>
29 #include <clblas-internal.h>
30 #include <kprintf.hpp>
31 #include <gemm.clT>
32 #include <gemm_helper.clT>
33 #include <symm_helper.clT>
34 #include <solution_seq.h>
35 #include "tuned_numbers.h"
36 
37 //#define DEBUG_GEMM_2
38 static CLBLASMpatExtra mpatExtra;
39 
40 static char Prefix[4];
41 
42 /* Function, finding default decomposition */
43 static int
44 getDefaultDecomposition(
45     PGranularity *pgran,
46     SubproblemDim *subdims,
47     unsigned int subdimsNum,
48     void *pArgs);
49 
50 static ssize_t
51 generator(
52    char *buf,
53    size_t buflen,
54    const struct SubproblemDim *subdims,
55    const struct PGranularity *pgran,
56    void *extra);
57 
58 static void
59 assignKargs(KernelArg *args, const void *params, const void *extra);
60 
61 static SolverFlags
62 solverFlags(void);
63 
64 static void
65 setBuildOpts(
66     char * buildOptStr,
67     const void *args);
68 
69 static void
70 calcNrThreads(
71     size_t threads[2],
72     const SubproblemDim *subdims,
73     const PGranularity *pgran,
74     const void *args,
75     const void *extra);
76 
77 static SolverOps gemmSops = {
78     generator,
79     assignKargs,
80     NULL,
81     NULL,
82    	NULL,
83     calcNrThreads,
84     NULL,
85     solverFlags,
86     NULL,
87 	getDefaultDecomposition,
88 	NULL,
89 	setBuildOpts,
90 	NULL
91 };
92 
93 static void
calcNrThreads(size_t threads[2],const SubproblemDim * subdims,const PGranularity * pgran,const void * args,const void * extra)94 calcNrThreads(
95     size_t threads[2],
96     const SubproblemDim *subdims,
97     const PGranularity *pgran,
98     const void *args,
99     const void *extra)
100 {
101     const CLBlasKargs *kargs = (const CLBlasKargs *)args;
102     //const CLBLASKernExtra *kextra = ( CLBLASKernExtra *)extra;
103     //KernelExtraFlags kflags = kextra->flags;
104     size_t M, N;
105 
106     M = kargs->M;
107     N = kargs->N;
108 
109     threads[1] = 1;
110 
111     if ((subdims->x != SUBDIM_UNUSED) &&
112         (subdims->y != SUBDIM_UNUSED)) {
113 
114         size_t groupWorkX, groupWorkY;
115         size_t nrGroupsX, nrGroupsY;
116         int nrDims;
117 
118         groupWorkX = subdims->x;
119         groupWorkY = subdims->y;
120 
121         nrGroupsX = N / groupWorkX;
122         if (N % groupWorkX) {
123             nrGroupsX++;
124         }
125 
126         nrGroupsY = M / groupWorkY;
127         if (M % groupWorkY) {
128             nrGroupsY++;
129         }
130         nrDims = (pgran == NULL) ? 1 : pgran->wgDim;
131         threads[0] = nrGroupsX * nrGroupsY;
132 
133         if(kargs->pigFuncID == CLBLAS_HERK)
134         {
135             threads[0] = (nrGroupsY * (nrGroupsY + 1)) / 2;
136         }
137 
138     }
139 
140     if (pgran != NULL) {
141         threads[0] *= pgran->wgSize[0];
142         threads[1] *= pgran->wgSize[1];
143     }
144 }
145 
146 static void
setBuildOpts(char * buildOptStr,const void * args)147 setBuildOpts(
148     char * buildOptStr,
149     const void *args)
150 {
151 	SolutionStep *step = (SolutionStep *)args;
152     const CLBlasKargs *kargs = (const CLBlasKargs *)(&step->args);
153 	const SubproblemDim *dims = step->subdims;
154 	//size_t vecLen = sizeof(cl_float4)/dtypeSize(kargs->dtype);
155     KernelExtraFlags kflags = step->extraFlags;
156 
157     blockSizes bestSize = bestBlockSizeForDevice( step );
158 
159     if ( kargs->dtype == TYPE_DOUBLE || kargs->dtype == TYPE_COMPLEX_DOUBLE)
160     {
161         addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DDOUBLE_PRECISION");
162     }
163 
164     if (isComplexType(kargs->dtype))
165     {
166         addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DCOMPLEX");
167     }
168 
169     if ((bestSize.useBarrier) == 1)
170     {
171 	    addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DGEMM_NEEDS_BARRIER");
172     }
173 
174     if (kargs->M % dims->y)
175 	{
176 		addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DM_TAIL_PRESENT");
177     }
178 
179 	if (kargs->N % dims->x)
180 	{
181 		addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DN_TAIL_PRESENT");
182 	}
183 
184     if (kflags & KEXTRA_CONJUGATE_A)
185     {
186         addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DCONJUGATE_A");
187     }
188     if (kflags & KEXTRA_CONJUGATE_B)
189     {
190         addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DCONJUGATE_B");
191     }
192 
193     switch(kargs->pigFuncID)
194     {
195         case CLBLAS_HEMM:
196         case CLBLAS_SYMM:
197         case CLBLAS_SYMM_DIAGONAL:
198         case CLBLAS_HEMM_DIAGONAL:
199             #ifdef DEBUG_GEMM_2
200             printf("GEMM2: setBuildOpts: Setting options for SYMM\n");
201             #endif
202             if (kargs->side == clblasLeft)
203             {
204                 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__SYMM_LEFT__");
205             }
206             if (kargs->side == clblasRight)
207             {
208                 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__SYMM_RIGHT__");
209             }
210             if (kargs->uplo == clblasLower)
211             {
212                 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__SYMM_LOWER__");
213             }
214             if (kargs->uplo == clblasUpper)
215             {
216                 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__SYMM_UPPER__");
217             }
218             // Define the order for Legacy sake.
219             if (kargs->order == clblasColumnMajor)
220             {
221                 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__SYMM_COLMAJOR__");
222             } else {
223                 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__SYMM_ROWMAJOR__");
224             }
225             if ((kargs->pigFuncID == CLBLAS_SYMM_DIAGONAL) || (kargs->pigFuncID == CLBLAS_HEMM_DIAGONAL))
226             {
227                 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__SYMM_DIAGONAL__");
228             }
229             if (kargs->pigFuncID == CLBLAS_HEMM_DIAGONAL)
230             {
231                 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__HEMM__");
232             }
233             break;
234 
235          case CLBLAS_HERK:
236             addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DHERK");
237             if(kargs->uplo == clblasLower)
238             {
239                 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DHERK_LOWER_TRIANGLE");
240             }
241             else if(kargs->uplo == clblasUpper)
242             {
243                 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DHERK_UPPER_TRIANGLE");
244             }
245             break;
246 
247          default:
248             break;
249     }
250 
251     #ifdef DEBUG_GEMM_2
252 	printf("buildStr: %s\n", buildOptStr);
253     #endif
254 	return;
255 }
256 
257 static ssize_t
generator(char * buf,size_t buflen,const struct SubproblemDim * subdims,const struct PGranularity * pgran,void * extra)258 generator(
259    char *buf,
260    size_t buflen,
261    const struct SubproblemDim *subdims,
262    const struct PGranularity *pgran,
263    void *extra)
264 {
265     CLBLASKernExtra *kextra = (CLBLASKernExtra*)extra;
266     KernelExtraFlags kflags = kextra->flags;
267     DataType dtype = kextra->dtype;
268     char tempTemplate[64*1024]; //PENDING: Is it safe to have 64K in stack for threadSafety?
269     char itemx[10], itemy[10], width[10], itemy_by_width[10], itemx_by_width[10];
270     char bwidth[10], panel_by_v[10];
271     size_t Y, X, BLOCKSIZE, ITEMX, ITEMY;
272 	bool doVLOAD = false;
273 	unsigned int veclen;
274 
275     if (buf == NULL)
276     {
277         buflen = 64*1024*sizeof(char);
278         return (ssize_t)buflen;
279     }
280 
281     //
282     // PENDING: Add Support for Row Major
283     //
284     if ((kflags & KEXTRA_COLUMN_MAJOR) == 0)
285     {
286         return 0;
287     }
288 
289 	if ((kflags & KEXTRA_NO_COPY_VEC_A) || (kflags & KEXTRA_NO_COPY_VEC_B) || (kflags  & KEXTRA_NO_COPY_VEC_C))
290 	{
291 		#ifdef DEBUG_GEMM_2
292 		printf("GEMM2: Doing un-aligned access\n");
293 		#endif
294 		doVLOAD= true;
295 	} else {
296 		#ifdef DEBUG_GEMM_2
297 		printf("GEMM2: Doing Aligned access\n");
298 		#endif
299 	}
300 
301 
302     BLOCKSIZE = pgran->wgSize[0];
303     #ifdef DEBUG_GEMM_2
304     printf("GEMM2- generator(): Blocksize passed = %lu, subdimy = %lu, subdimx = %lu, veclen = %d \n",
305                                 BLOCKSIZE, subdims->y, subdims->x, kextra->vecLen);
306     #endif
307 
308 	veclen = kextra->vecLen;
309 
310     ITEMY = subdims->itemY;
311     ITEMX = subdims->itemX;
312     Y = subdims->y / ITEMY;
313     X = subdims->x / ITEMX;
314 
315 	//
316 	// Handle in-compatible subdims and workgroup sizes
317 	// We will use "veclen" of 1 as our shield against these in-compatible
318     // geometries.
319 	//
320     if ( (ITEMY % kextra->vecLen) || ((ITEMX % kextra->vecLen) && (kflags & KEXTRA_TRANS_B)) )
321     {
322         //
323         // FIXME:
324         // This kernel must be stored against vecLen of 1 in Kernel Cache.
325         // This needs change in EXTRA structure. However, this is against the API.
326         // We are going against the API by changing fields in EXTRA structure.
327         // One alternate FIX is to return an error.
328         //
329         kextra->vecLen = kextra->vecLenA = kextra->vecLenB = kextra->vecLenC = 1;
330 
331        	doVLOAD = true;
332 		veclen = 1;
333     }
334 
335 	//
336 	// PENDING: Selective Vectorization for A, B and C access has to be added
337 	// 			in KPRINTF module (VLOADA, VLOADB, VLOADC, VSTOREC)
338 	//
339     kprintf kobj(Prefix[dtype], veclen, doVLOAD, doVLOAD); // Only Vectored Access
340     sprintf(width, "%lu", Y);
341     sprintf(itemy, "%lu", ITEMY);
342     sprintf(itemx, "%lu", ITEMX);
343     sprintf(itemy_by_width, "%lu", (size_t) ITEMY/veclen);
344     sprintf(itemx_by_width, "%lu", (size_t) ITEMX/veclen);
345     //sprintf(bwidth, "%lu", subdims->bwidth);
346     //sprintf(panel_by_v, "%lu", (subdims->bwidth / veclen));
347     sprintf(bwidth, "%lu", (size_t) veclen);
348     sprintf(panel_by_v, "%lu", (size_t) 1);
349 
350     kobj.put("%WIDTH", width);
351     kobj.put("%ITEMX", itemx);
352     kobj.put("%ITEMY", itemy);
353     kobj.put("%ITEMY_BY_V", itemy_by_width);
354     kobj.put("%ITEMX_BY_V", itemx_by_width);
355     kobj.put("%PANEL", bwidth);
356     kobj.put("%PANEL_BY_V", panel_by_v);
357     #ifdef DEBUG_GEMM_2
358     printf("ColMajor GEMM - WIDTH = %s, PANEL = %lu, ITEMX = %s, ITEMY = %s, Veclen = %lu\n", width, subdims->bwidth, itemx, itemy, veclen);
359     #endif
360 
361     strcpy(tempTemplate, SYMM_HEMM_HELPER);
362 	if ((kflags & KEXTRA_TRANS_A) == 0)
363 	{
364 		if (kflags & KEXTRA_TRANS_B)
365 		{
366 			#ifdef DEBUG_GEMM_2
367 			printf("Using GEMM_NT_KERNEL\n");
368 			#endif
369     		strcat(tempTemplate, GEMM_HELPER);
370             strcat(tempTemplate, GEMM_NT_KERNEL);
371 		} else {
372 			#ifdef DEBUG_GEMM_2
373 			printf("Using GEMM_NN_KERNEL\n");
374 			#endif
375     		strcat(tempTemplate, GEMM_HELPER);
376     		strcat(tempTemplate, GEMM_NN_KERNEL);
377 		}
378 	} else {
379 		// PENDING:
380 		if (kflags & KEXTRA_TRANS_B)
381 		{
382 		    tempTemplate[0] = 0;
383 		} else {
384 			#ifdef DEBUG_GEMM_2
385 			printf("Using GEMM_TN_KERNEL\n");
386 			#endif
387     		strcat(tempTemplate, GEMM_HELPER);
388     		strcat(tempTemplate, GEMM_TN_KERNEL);
389 	    }
390 	}
391     kobj.spit(buf, tempTemplate);
392     #ifdef DEBUG_GEMM_KPRINTF
393     printf("Kernel = \n%s\n", buf);
394     #endif
395     size_t tail = strlen(buf) + 1;
396     while(tail < 64*1024)
397     {
398         buf[tail++] = 0;
399     }
400     return 64*1024*sizeof(char);
401 }
402 
403 static void
assignKargs(KernelArg * args,const void * params,const void *)404 assignKargs(KernelArg *args, const void *params, const void*)
405 {
406     CLBlasKargs *blasArgs = (CLBlasKargs*)params;
407 
408     #ifdef DEBUG_GEMM_2
409     printf("SAlpha=%f, DAlpha=%f, CAlpha =<%f, %f>, DAlpha=<%f, %f>\n",
410             blasArgs->alpha.argFloat, blasArgs->alpha.argDouble, CREAL(blasArgs->alpha.argFloatComplex), CIMAG(blasArgs->alpha.argFloatComplex),
411             CREAL(blasArgs->alpha.argDoubleComplex) , CIMAG(blasArgs->alpha.argDoubleComplex));
412     printf("SBeta=%f, DBeta=%f, CBeta=<%f, %f>, DBeta=<%f, %f>\n",
413             blasArgs->beta.argFloat, blasArgs->beta.argDouble, CREAL(blasArgs->beta.argFloatComplex), CIMAG(blasArgs->beta.argFloatComplex),
414             CREAL(blasArgs->beta.argDoubleComplex) , CIMAG(blasArgs->beta.argDoubleComplex));
415     #endif
416 
417     INIT_KARG(&args[0], blasArgs->A);   //A - input matrix - argument
418     INIT_KARG(&args[1], blasArgs->B);   //x - result buffer = _xnew argument
419     INIT_KARG(&args[2], blasArgs->C);   //y - scratch == _x_vector argument
420     initSizeKarg(&args[3], blasArgs->M);
421     initSizeKarg(&args[4], blasArgs->N);
422     initSizeKarg(&args[5], blasArgs->K);
423     initSizeKarg(&args[6], blasArgs->lda.matrix);
424     initSizeKarg(&args[7], blasArgs->ldb.matrix);
425     initSizeKarg(&args[8], blasArgs->ldc.matrix);
426     initSizeKarg(&args[9], blasArgs->offA);
427     initSizeKarg(&args[10], blasArgs->offBX);
428     initSizeKarg(&args[11], blasArgs->offCY);
429     assignScalarKarg(&args[12], &(blasArgs->alpha), blasArgs->dtype);
430     assignScalarKarg(&args[13], &(blasArgs->beta), blasArgs->dtype);
431     return;
432 }
433 
434 static SolverFlags
solverFlags(void)435 solverFlags(void)
436 {
437     return (SF_WSPACE_1D);
438 }
439 
440 extern "C"
441 void
initGemmV2CachedPattern(MemoryPattern * mempat)442 initGemmV2CachedPattern(MemoryPattern *mempat)
443 {
444     mempat->name = "Cached global memory based block gemm";
445     mempat->nrLevels = 2;
446     mempat->cuLevel = 0;
447     mempat->thLevel = 1;
448     mempat->sops = &gemmSops;
449 
450     mpatExtra.aMset = CLMEM_LEVEL_L1;
451     mpatExtra.bMset = CLMEM_LEVEL_L1;
452     mpatExtra.mobjA = CLMEM_BUFFER;
453     mpatExtra.mobjB = CLMEM_BUFFER;
454     mempat->extra = &mpatExtra;
455 
456 
457     Prefix[TYPE_FLOAT] = 'S';
458     Prefix[TYPE_DOUBLE] = 'D';
459     Prefix[TYPE_COMPLEX_FLOAT] = 'C';
460     Prefix[TYPE_COMPLEX_DOUBLE] = 'Z';
461 }
462 
463 static int
getDefaultDecomposition(PGranularity * pgran,SubproblemDim * subdims,unsigned int subdimsNum,void * pArgs)464 getDefaultDecomposition(
465     PGranularity *pgran,
466     SubproblemDim *subdims,
467     unsigned int subdimsNum,
468     void *pArgs)
469 {
470 
471     DUMMY_ARG_USAGE(pArgs);
472     //
473     // FIXME:  container_of() - Counts on the fact that "getDefaultDecomposition" is called
474     //          with step->pgran, step->subdims
475     //
476     SolutionStep *step = container_of( pgran , pgran, SolutionStep);
477 
478     blockSizes bestSize = bestBlockSizeForDevice( step );
479 
480     pgran->wgSize[0] = bestSize.TY * bestSize.TX;
481     pgran->wgSize[1] = 1;
482     pgran->wgDim = 1;
483 
484     if (subdimsNum >= 1)
485     {
486         subdims[0].y = bestSize.TY * bestSize.ITEMY;
487         subdims[0].x = bestSize.TX * bestSize.ITEMX;
488         subdims[0].itemY = bestSize.ITEMY;
489         subdims[0].itemX = bestSize.ITEMX;
490         subdims[0].bwidth = 4;
491     }
492     if (subdimsNum >= 2)
493     {
494         subdims[1].y = bestSize.TY * bestSize.ITEMY;
495         subdims[1].x = bestSize.TX * bestSize.ITEMX;
496         subdims[1].itemY = bestSize.ITEMY;
497         subdims[1].itemX = bestSize.ITEMX;
498         subdims[1].bwidth = 4;
499     }
500 
501     return 0;
502 }
503 
504