1 /* ************************************************************************
2  * Copyright 2013 Advanced Micro Devices, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  * ************************************************************************/
16 
17 
18 /*
19  * Cached global buffers based gemm generator
20  */
21 
22 #include <string.h>
23 #include <stdio.h>
24 #include <assert.h>
25 #include <clblas_stddef.h>
26 #include <clBLAS.h>
27 #include <blas_mempat.h>
28 #include <clkern.h>
29 #include <clblas-internal.h>
30 #include <kprintf.hpp>
31 #include <gemm.clT>
32 #include <symm_helper.clT>
33 #include <solution_seq.h>
34 
35 extern "C" int
36 gemmHasNTail(size_t N, int vecLen, clblasOrder order, clblasTranspose transA, clblasTranspose transB);
37 
38 extern "C" int
39 gemmHasMTail(size_t M, int vecLen, clblasOrder order, clblasTranspose transA, clblasTranspose transB);
40 
41 
42 //#define DEBUG_GEMM_TAIL
43 static CLBLASMpatExtra mpatExtra;
44 
45 static char Prefix[4];
46 
47 static ssize_t
48 generator(
49    char *buf,
50    size_t buflen,
51    const struct SubproblemDim *subdims,
52    const struct PGranularity *pgran,
53    void *extra);
54 
55 static void
56 assignKargs(KernelArg *args, const void *params, const void *extra);
57 
58 static SolverFlags
59 solverFlags(void);
60 
61 static void
62 setBuildOpts(
63     char * buildOptStr,
64     const void *args);
65 
66 static void
67 calcNrThreads(
68     size_t threads[2],
69     const SubproblemDim *subdims,
70     const PGranularity *pgran,
71     const void *args,
72     const void *extra);
73 
74 static SolverOps gemmSops = {
75     generator,
76     assignKargs,
77     NULL,
78     NULL,
79    	NULL,
80     calcNrThreads,
81     NULL,
82     solverFlags,
83     NULL,
84 	NULL,
85 	NULL,
86 	setBuildOpts,
87 	NULL
88 };
89 
90 static void
setBuildOpts(char * buildOptStr,const void * args)91 setBuildOpts(
92     char * buildOptStr,
93     const void *args)
94 {
95 	const SolutionStep *step = (const SolutionStep *)args;
96     const CLBlasKargs *kargs = (const CLBlasKargs *)(&step->args);
97     KernelExtraFlags kflags = step->extraFlags;
98 
99 	addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DTAIL_RUN -DM_TAIL_PRESENT -DN_TAIL_PRESENT");
100     if ( kargs->dtype == TYPE_DOUBLE || kargs->dtype == TYPE_COMPLEX_DOUBLE)
101     {
102         addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DDOUBLE_PRECISION");
103         #ifdef DEBUG_GEMM_TAIL
104         printf("Setting build options ... Double... for DOUBLE PRECISION support\n");
105         #endif
106     }
107 
108     if (isComplexType(kargs->dtype))
109     {
110         addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DCOMPLEX");
111     }
112 
113     if (kflags & KEXTRA_CONJUGATE_A)
114     {
115         addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DCONJUGATE_A");
116     }
117     if (kflags & KEXTRA_CONJUGATE_B)
118     {
119         addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DCONJUGATE_B");
120     }
121 
122 
123     switch(kargs->pigFuncID)
124     {
125         case CLBLAS_GEMM2:
126         case CLBLAS_GEMM_TAIL:
127             break;
128 
129         case CLBLAS_HERK:
130             addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DHERK");
131             if(kargs->uplo == clblasLower)
132             {
133                 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DHERK_LOWER_TRIANGLE");
134             }
135             else if(kargs->uplo == clblasUpper)
136             {
137                 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DHERK_UPPER_TRIANGLE");
138             }
139             break;
140 
141         case CLBLAS_HEMM:
142         case CLBLAS_SYMM_DIAGONAL:
143         case CLBLAS_HEMM_DIAGONAL:
144         case CLBLAS_SYMM:
145             #ifdef DEBUG_GEMM_2
146             printf("GEMM2: setBuildOpts: Setting options for SYMM\n");
147             #endif
148             if (kargs->side == clblasLeft)
149             {
150                 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__SYMM_LEFT__");
151             }
152             if (kargs->side == clblasRight)
153             {
154                 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__SYMM_RIGHT__");
155             }
156             if (kargs->uplo == clblasLower)
157             {
158                 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__SYMM_LOWER__");
159             }
160             if (kargs->uplo == clblasUpper)
161             {
162                 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__SYMM_UPPER__");
163             }
164             // Define the order for Legacy sake.
165             if (kargs->order == clblasColumnMajor)
166             {
167                 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__SYMM_COLMAJOR__");
168             } else {
169                 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__SYMM_ROWMAJOR__");
170             }
171             if ((kargs->pigFuncID == CLBLAS_SYMM_DIAGONAL) || (kargs->pigFuncID == CLBLAS_HEMM_DIAGONAL))
172             {
173                 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__SYMM_DIAGONAL__");
174             }
175             if (kargs->pigFuncID == CLBLAS_HEMM_DIAGONAL)
176             {
177                 addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-D__HEMM__");
178             }
179             break;
180 
181         default:
182             printf("GEMM TAIL: Unknown pigFuncID\n");
183             break;
184     }
185     #ifdef DEBUG_GEMM_TAIL
186     printf("GEMMTAIL: Build options = %s\n", buildOptStr);
187     #endif
188 }
189 
190 static void
calcNrThreads(size_t threads[2],const SubproblemDim * subdims,const PGranularity * pgran,const void * args,const void * extra)191 calcNrThreads(
192     size_t threads[2],
193     const SubproblemDim *subdims,
194     const PGranularity *pgran,
195     const void *args,
196     const void *extra)
197 {
198     int BLOCKSIZE = pgran->wgSize[0]; // 1D Block
199 	size_t tailM, tailN, M, N;
200 	size_t Y, X;
201 	size_t nWorkGroupsAY, nWorkGroupsAX, nWorkGroupsA;
202 	size_t nWorkGroupsBY, nWorkGroupsBX, nWorkGroupsB;
203 	size_t totalWorkGroups;
204     #ifdef DEBUG_GEMM_TAIL
205     printf("calcNrThreads called from gemm_tail.cpp\n");
206     #endif
207     const CLBlasKargs *kargs = (const CLBlasKargs *)args;
208     const CLBLASKernExtra *kextra = ( CLBLASKernExtra *)extra;
209 	KernelExtraFlags kflags = kextra->flags;
210 
211 	//
212 	// RowMajor GEMM can be expressed in terms of Column Major GEMM
213 	//
214     if ((kflags & KEXTRA_COLUMN_MAJOR) == 0)
215     {
216     	printf("calcNrThreads: FIXME: RowMajor is NOT supported \n");
217         return;
218     }
219 
220 	if (kextra->vecLenA != 1)
221 	{
222     	printf("GEMM_TAIL: calcNrThreads(): Vector Length must be 1 for TAIL. Non-one Vector Length Requested\n");
223 		return;
224 	}
225 
226 	tailM = kargs->tailStartM;
227 	tailN = kargs->tailStartN;
228 	M = kargs->M;
229 	N = kargs->N;
230 
231     Y = 8;
232     if (Y != subdims->y)
233 	{
234 		Y = subdims->y;
235 	}
236     X = BLOCKSIZE/Y;
237     /*
238     LEGACY CODE: Outdated now. TAIL can handle this condition now using MTAIL_PRESENT and NTAIL_PRESENT
239 	if (tailN % X)
240 	{
241 		printf("GEMM_TAIL: calcNrThreads(): WARNING: tailN is not divisible by X. Will produce Wrong results!\n");
242 	}
243     */
244 
245 	//
246 	// A Tail Workgroup will process YxX panel
247 	//
248 	/*
249 			 ______________
250 			|			|  |
251 			|			|  |
252 			|			|  | B Tail panel
253 			|___________|  |
254 			|___________|__|
255 		    <---  A   -->
256 	 */
257 	if(tailM != M)
258 	{
259 		#ifdef DEBUG_GEMM_TAIL
260 		printf("GEMM_TAIL: M has TAIL\n");
261 		#endif
262 		nWorkGroupsAY = ((M - tailM -1)/Y + 1);
263 		nWorkGroupsAX = ((tailN - 1)/X + 1);
264 		nWorkGroupsA = nWorkGroupsAY * nWorkGroupsAX;
265 	} else {
266 		nWorkGroupsA = 0;
267 	}
268 
269 	if (tailN != N)
270 	{
271 		#ifdef DEBUG_GEMM_TAIL
272 		printf("GEMM_TAIL: N has TAIL\n");
273 		#endif
274 		nWorkGroupsBY = ((M-1)/Y) + 1;
275 		nWorkGroupsBX = ((N-tailN-1)/X) + 1;
276 		nWorkGroupsB = nWorkGroupsBY * nWorkGroupsBX;
277 	} else {
278 		nWorkGroupsB = 0;
279 	}
280 
281 	totalWorkGroups = nWorkGroupsA + nWorkGroupsB;
282 
283 	threads[0] = totalWorkGroups * BLOCKSIZE;
284 	threads[1] = 1;
285 	#ifdef DEBUG_GEMM_TAIL
286 	printf("GEMM_TAIL: calcNrThreads(): vlen:%d, <tailM:%lu, M:%lu>, <tailN:%lu, N:%lu, nWorkGroupsA<%lu,%lu>, nWorkGroupsB<%lu,%lu>\n",
287 			kextra->vecLenA, tailM, M, tailN, N, nWorkGroupsAY, nWorkGroupsAX, nWorkGroupsBY, nWorkGroupsBX);
288 	printf("GEMM_TAIL: calcNrThreads(): globalThreads0=%lu, globalThreads1=%lu\n", threads[0], threads[1]);
289 	#endif
290 	return;
291 }
292 
293 static ssize_t
generator(char * buf,size_t buflen,const struct SubproblemDim * subdims,const struct PGranularity * pgran,void * extra)294 generator(
295    char *buf,
296    size_t buflen,
297    const struct SubproblemDim *subdims,
298    const struct PGranularity *pgran,
299    void *extra)
300 {
301     CLBLASKernExtra *kextra = (CLBLASKernExtra*)extra;
302     KernelExtraFlags kflags = kextra->flags;
303     DataType dtype = kextra->dtype;
304     char tempTemplate[32*1024];
305     char itemx[10], itemy[10], width[10], itemy_by_width[10], itemx_by_width[10];
306     size_t Y, X, BLOCKSIZE, ITEMX, ITEMY;
307 
308     if (buf == NULL)
309     {
310         buflen = 32*1024*sizeof(char);
311         return (ssize_t)buflen;
312     }
313 
314     //
315     // PENDING: Add Support for Row Major at the xAPI.c level
316 	// Row major calcs can be expressed in terms of column major
317     //
318     if ((kflags & KEXTRA_COLUMN_MAJOR) == 0)
319     {
320         return 0;
321     }
322 
323     kprintf kobj(Prefix[dtype], 1, false, false); // Only Scalar Access
324 
325     BLOCKSIZE = pgran->wgSize[0];
326     #ifdef DEBUG_GEMM_TAIL
327     printf("GEMM- generator(): Blocksize passed = %lu, subdimy = %lu, subdimx = %lu, veclen = %d \n", BLOCKSIZE, subdims->y, subdims->x, kextra->vecLenA);
328     #endif
329 
330     Y = 8;
331     if (Y != subdims->y)
332 	{
333 		//printf("GEMM_TAIL: generator(): WARNING: subdims->y is un-suitable.\n");
334 		Y = subdims->y;
335 	}
336     X = BLOCKSIZE/Y;
337     ITEMY = (subdims->y) / Y;
338     ITEMX = (subdims->x) / X;
339     if (ITEMX == 0)
340     {
341         ITEMX = 1;
342     }
343 
344     if ((BLOCKSIZE % Y) || ((subdims->y) % Y) || ((subdims->x)%X) || (ITEMY % kextra->vecLenA) || ((X*ITEMX) % kextra->vecLenA))
345     {
346         printf("WARNING: GEMM TAIL - generator: subdim and blocksize in-compatible. This code should never execute!\n");
347     }
348 
349     sprintf(width, "%lu", Y);
350     sprintf(itemy, "%lu", ITEMY);
351     sprintf(itemx, "%lu", ITEMX);
352     sprintf(itemy_by_width, "%lu", (size_t) ITEMY/kextra->vecLenA);
353     sprintf(itemx_by_width, "%lu", (size_t) ITEMX/kextra->vecLenA);
354 
355     kobj.put("%WIDTH", width);
356     kobj.put("%ITEMX", itemx);
357     kobj.put("%ITEMY", itemy);
358     kobj.put("%ITEMY_BY_V", itemy_by_width);
359     kobj.put("%ITEMX_BY_V", itemx_by_width);
360     kobj.put("%PANEL", "1");
361     kobj.put("%PANEL_BY_V", "1");
362     #ifdef DEBUG_GEMM_TAIL
363     printf("ColMajor GEMM - WIDTH = %s, ITEMX = %s, ITEMY = %s\n", width, itemx, itemy);
364     #endif
365 
366     strcpy(tempTemplate, SYMM_HEMM_HELPER);
367     if ((kflags & KEXTRA_TRANS_A) == 0)
368     {
369         if (kflags & KEXTRA_TRANS_B)
370         {
371 			#ifdef DEBUG_GEMM_TAIL
372 			printf("GEMM_TAIL: Using GEMM_NT_KERNEL\n");
373 			#endif
374             strcat(tempTemplate, GEMM_NT_KERNEL);
375         } else {
376 			#ifdef DEBUG_GEMM_TAIL
377 			printf("GEMM_TAIL: Using GEMM_NN_KERNEL\n");
378 			#endif
379             strcat(tempTemplate, GEMM_NN_KERNEL);
380 		}
381     } else {
382         //
383         // NOTE: A^T * B Never leaves any tails. This should NEVER be called.
384         // PENDING: A^T * B^T support is PENDING
385         tempTemplate[0] = 0;
386     }
387 
388     kobj.spit(buf, tempTemplate);
389     //#ifdef DEBUG_GEMM_TAIL
390     //printf("Kernel = \n%s\n", buf);
391     //#endif
392     size_t tail = strlen(buf) + 1;
393     while(tail < 32*1024)
394     {
395         buf[tail++] = 0;
396     }
397     return 32*1024*sizeof(char);
398 }
399 
400 static void
assignKargs(KernelArg * args,const void * params,const void *)401 assignKargs(KernelArg *args, const void *params, const void*)
402 {
403     CLBlasKargs *blasArgs = (CLBlasKargs*)params;
404 
405     #ifdef DEBUG_GEMM_TAIL
406     printf("SAlpha=%f, DAlpha=%f, CAlpha =<%f, %f>, DAlpha=<%f, %f>\n",
407             blasArgs->alpha.argFloat, blasArgs->alpha.argDouble, CREAL(blasArgs->alpha.argFloatComplex), CIMAG(blasArgs->alpha.argFloatComplex),
408             CREAL(blasArgs->alpha.argDoubleComplex) , CIMAG(blasArgs->alpha.argDoubleComplex));
409     printf("SBeta=%f, DBeta=%f, CBeta=<%f, %f>, DBeta=<%f, %f>\n",
410             blasArgs->beta.argFloat, blasArgs->beta.argDouble, CREAL(blasArgs->beta.argFloatComplex), CIMAG(blasArgs->beta.argFloatComplex),
411             CREAL(blasArgs->beta.argDoubleComplex) , CIMAG(blasArgs->beta.argDoubleComplex));
412 	printf("TailStartM = %lu, TailStartN = %lu\n", blasArgs->tailStartM, blasArgs->tailStartN);
413     #endif
414 
415     INIT_KARG(&args[0], blasArgs->A);   //A - input matrix - argument
416     INIT_KARG(&args[1], blasArgs->B);   //x - result buffer = _xnew argument
417     INIT_KARG(&args[2], blasArgs->C);   //y - scratch == _x_vector argument
418     initSizeKarg(&args[3], blasArgs->M);
419     initSizeKarg(&args[4], blasArgs->N);
420     initSizeKarg(&args[5], blasArgs->K);
421     initSizeKarg(&args[6], blasArgs->lda.matrix);
422     initSizeKarg(&args[7], blasArgs->ldb.matrix);
423     initSizeKarg(&args[8], blasArgs->ldc.matrix);
424     initSizeKarg(&args[9], blasArgs->offA);
425     initSizeKarg(&args[10], blasArgs->offBX);
426     initSizeKarg(&args[11], blasArgs->offCY);
427     assignScalarKarg(&args[12], &(blasArgs->alpha), blasArgs->dtype);
428     assignScalarKarg(&args[13], &(blasArgs->beta), blasArgs->dtype);
429     initSizeKarg(&args[14], blasArgs->tailStartM);
430     initSizeKarg(&args[15], blasArgs->tailStartN);
431     return;
432 }
433 
434 static SolverFlags
solverFlags(void)435 solverFlags(void)
436 {
437     return (SF_WSPACE_1D);
438 }
439 
440 extern "C"
441 void
initGemmV2TailCachedPattern(MemoryPattern * mempat)442 initGemmV2TailCachedPattern(MemoryPattern *mempat)
443 {
444     mempat->name = "Cached global memory based gemm tail";
445     mempat->nrLevels = 2;
446     mempat->cuLevel = 0;
447     mempat->thLevel = 1;
448     mempat->sops = &gemmSops;
449 
450     mpatExtra.aMset = CLMEM_LEVEL_L1;
451     mpatExtra.bMset = CLMEM_LEVEL_L1;
452     mpatExtra.mobjA = CLMEM_BUFFER;
453     mpatExtra.mobjB = CLMEM_BUFFER;
454     mempat->extra = &mpatExtra;
455 
456 
457     Prefix[TYPE_FLOAT] = 'S';
458     Prefix[TYPE_DOUBLE] = 'D';
459     Prefix[TYPE_COMPLEX_FLOAT] = 'C';
460     Prefix[TYPE_COMPLEX_DOUBLE] = 'Z';
461 }
462 
463