1 /* ************************************************************************
2  * Copyright 2013 Advanced Micro Devices, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  * ************************************************************************/
16 
17 
18 #include <stdio.h>
19 #include "gen_helper_legacy.h"
20 #include "blas_kgen_legacy.h"
21 #include "../gen_helper.h"
22 
23 typedef struct CopyPattern {
24     SubproblemDim dim;
25     const PGranularity *pgran;
26     DataType dtype;
27     DBlockCopyDirection dir;
28     DBlockCopyFlags flags;
29     bool generic;
30     bool zeroing;
31 } CopyPattern;
32 
33 static int
cpyImgGenCallback(struct KgenContext * ctx,const void * pattern)34 cpyImgGenCallback(struct KgenContext *ctx, const void *pattern)
35 {
36     const CopyPattern *pat = (CopyPattern*)pattern;
37     const void *dim = (pat->generic) ? NULL : &pat->dim;
38     if(pat->zeroing) {
39         return f4zeroBlockGen(ctx, dim, pat->pgran, "__local");
40     }
41     else {
42         return copyDataBlockGen(ctx, dim, pat->pgran, pat->dtype, pat->dir,
43                                 pat->flags);
44     }
45 }
46 
47 int
generateImageCopyFuncs(CopyImgFuncs * copyFuncs,struct KgenContext * ctx,BlasFunctionID funcID,const BlasGenSettings * gset)48 generateImageCopyFuncs(
49     CopyImgFuncs *copyFuncs,
50     struct KgenContext *ctx,
51     BlasFunctionID funcID,
52     const BlasGenSettings *gset)
53 {
54     const SubproblemDim *dims = gset->subdims;
55     KernelExtraFlags kflags = gset->kextra->flags;
56     DataType dtype = gset->kextra->dtype;
57     const PGranularity *pgran = gset->pgran;
58     CopyPattern pattern;
59     // mandatory flags for global to local copying
60     DBlockCopyFlags glcpFlags[2] = {0, 0};
61     struct KgenGuard *guard;
62     unsigned int tsize;
63     int ret = 0;
64     bool isTra, areTails, isConjA;
65     bool customize;
66 
67     if (kflags & KEXTRA_NO_COPY_VEC_A) {
68         glcpFlags[0] = DBLOCK_COPY_NOT_VECTORIZE;
69     }
70     if (kflags & KEXTRA_NO_COPY_VEC_B) {
71         glcpFlags[1] = DBLOCK_COPY_NOT_VECTORIZE;
72     }
73 
74     tsize = dtypeSize(dtype);
75     isTra = isMatrixAccessColMaj(funcID, kflags, MATRIX_A);
76     isConjA = isMatrixConj(kflags, MATRIX_A);
77     areTails = (kflags & (KEXTRA_TAILS_M | KEXTRA_TAILS_N));
78     customize = (funcID == CLBLAS_TRMM);
79 
80     guard = createKgenGuard(ctx, cpyImgGenCallback, sizeof(CopyPattern));
81     if (guard == NULL) {
82         return -ENOMEM;
83     }
84 
85     memset(&pattern, 0, sizeof(pattern));
86 
87     pattern.zeroing = false;
88     pattern.dim = dims[0];
89     pattern.dir = DBLOCK_GLOBAL_TO_IMAGE;
90     pattern.dtype = dtype;
91     pattern.flags = 0;
92     pattern.generic = false;
93     pattern.pgran = pgran;
94 
95     if (!(customize && (isTra || isConjA))) {
96         pattern.dim.x = dims[0].bwidth;
97         pattern.dim.y = dims[0].y;
98         findGenerateFunction(guard, &pattern, copyFuncs->globalToImage[MATRIX_A],
99                              FUNC_NAME_MAXLEN);
100         kgenAddBlankLine(ctx);
101     }
102 
103     pattern.dim.x = dims[0].bwidth;
104     pattern.dim.y = dims[0].x;
105     findGenerateFunction(guard, &pattern, copyFuncs->globalToImage[MATRIX_B],
106                          FUNC_NAME_MAXLEN);
107     kgenAddBlankLine(ctx);
108 
109     pattern.dim.x = dims[0].bwidth;
110     pattern.dim.y = dims[1].y;
111     pattern.dir = DBLOCK_LOCAL_TO_IMAGE;
112     findGenerateFunction(guard, &pattern, copyFuncs->localToImage[MATRIX_A],
113                          FUNC_NAME_MAXLEN);
114     kgenAddBlankLine(ctx);
115 
116     pattern.dim.x = dims[0].bwidth;
117     pattern.dim.y = dims[1].x;
118     pattern.dir = DBLOCK_LOCAL_TO_IMAGE;
119     findGenerateFunction(guard, &pattern, copyFuncs->localToImage[MATRIX_B],
120                          FUNC_NAME_MAXLEN);
121     kgenAddBlankLine(ctx);
122 
123     // Global to local optimized
124     pattern.dir = DBLOCK_GLOBAL_TO_LOCAL;
125     if (customize || isComplexType(dtype)) {
126         pattern.flags = (!customize || isConjA) ? DBLOCK_COPY_CONJUGATE : 0;
127         pattern.flags |= glcpFlags[0];
128         pattern.dim.x = dims[0].bwidth;
129         pattern.dim.y = dims[1].y;
130         findGenerateFunction(guard, &pattern, copyFuncs->globalToLocal[MATRIX_A],
131                              FUNC_NAME_MAXLEN);
132         kgenAddBlankLine(ctx);
133     }
134 
135     if ((funcID == CLBLAS_GEMM) && isComplexType(dtype)) {
136         pattern.flags = DBLOCK_COPY_CONJUGATE | glcpFlags[1];
137         pattern.dim.x = dims[0].bwidth;
138         pattern.dim.y = dims[1].x;
139         findGenerateFunction(guard, &pattern, copyFuncs->globalToLocal[MATRIX_B],
140                              FUNC_NAME_MAXLEN);
141         kgenAddBlankLine(ctx);
142     }
143 
144     // Global to local generic
145     pattern.dim = dims[0];
146     pattern.dir = DBLOCK_GLOBAL_TO_LOCAL;
147     pattern.generic = true;
148     if (!customize || areTails) {
149         pattern.flags = (isConjA) ? DBLOCK_COPY_CONJUGATE : 0;
150         pattern.flags |= glcpFlags[0];
151         findGenerateFunction(guard, &pattern,
152                              copyFuncs->globalToLocalGeneric[MATRIX_A],
153                              FUNC_NAME_MAXLEN);
154         kgenAddBlankLine(ctx);
155     }
156 
157     pattern.flags = (kflags & KEXTRA_CONJUGATE_B) ? DBLOCK_COPY_CONJUGATE : 0;
158     pattern.flags |= glcpFlags[1];
159     findGenerateFunction(guard, &pattern,
160                          copyFuncs->globalToLocalGeneric[MATRIX_B],
161                          FUNC_NAME_MAXLEN);
162     kgenAddBlankLine(ctx);
163 
164     // Global to local transposed functions
165     pattern.dir = DBLOCK_GLOBAL_TO_LOCAL;
166     pattern.flags = (kflags & KEXTRA_NO_COPY_VEC_A) ?
167                     DBLOCK_COPY_NOT_VECTORIZE : 0;
168     pattern.flags |= glcpFlags[0];
169     if (!customize || isTra) {
170         pattern.generic = false;
171         if (isConjA) {
172             pattern.flags |= DBLOCK_COPY_TRANSPOSE | DBLOCK_COPY_CONJUGATE;
173         }
174         else {
175             pattern.flags |= DBLOCK_COPY_TRANSPOSE;
176         }
177         pattern.dim.x = dims[1].y;
178         pattern.dim.y = dims[0].bwidth;
179 
180         findGenerateFunction(guard, &pattern,
181                              copyFuncs->globalToLocalTransposed[MATRIX_A],
182                              FUNC_NAME_MAXLEN);
183         kgenAddBlankLine(ctx);
184     }
185 
186     if (!customize || (isTra && areTails)) {
187         pattern.generic = true;
188         pattern.dim.x = 0;
189         pattern.dim.y = 0;
190         findGenerateFunction(guard, &pattern,
191                          copyFuncs->globalToLocalTransposedGeneric[MATRIX_A],
192                          FUNC_NAME_MAXLEN);
193         kgenAddBlankLine(ctx);
194     }
195 
196     pattern.generic = false;
197     pattern.dim.x = dims[1].x;
198     pattern.dim.y = dims[0].bwidth;
199     if (kflags & KEXTRA_CONJUGATE_B) {
200         pattern.flags = DBLOCK_COPY_TRANSPOSE | DBLOCK_COPY_CONJUGATE;
201     }
202     else {
203         pattern.flags = DBLOCK_COPY_TRANSPOSE;
204     }
205     pattern.flags |= glcpFlags[1];
206     findGenerateFunction(guard, &pattern,
207                          copyFuncs->globalToLocalTransposed[MATRIX_B],
208                          FUNC_NAME_MAXLEN);
209     kgenAddBlankLine(ctx);
210 
211     pattern.generic = true;
212     pattern.dim.x = 0;
213     pattern.dim.y = 0;
214     findGenerateFunction(guard, &pattern,
215                          copyFuncs->globalToLocalTransposedGeneric[MATRIX_B],
216                          FUNC_NAME_MAXLEN);
217     kgenAddBlankLine(ctx);
218 
219     // generate two local zeroing functions for matrix A and matrix B blocks
220     pattern.zeroing = true;
221     pattern.dim = dims[0];
222     pattern.generic = false;
223     pattern.flags = 0;
224     pattern.dim.y = 1;
225     pattern.dim.x = fl4RowWidth(dims[0].bwidth, tsize) * dims[1].y;
226 
227     findGenerateFunction(guard, &pattern,
228                          copyFuncs->zeroBlock[MATRIX_A],
229                          FUNC_NAME_MAXLEN);
230     kgenAddBlankLine(ctx);
231 
232     pattern.dim.x = fl4RowWidth(dims[0].bwidth, tsize) * dims[1].x;
233     findGenerateFunction(guard, &pattern,
234                          copyFuncs->zeroBlock[MATRIX_B],
235                          FUNC_NAME_MAXLEN);
236     ret = kgenAddBlankLine(ctx);
237 
238     destroyKgenGuard(guard);
239     return ret;
240 }
241 
242 int
generateResultUpdateOld(struct KgenContext * ctx,BlasFunctionID funcID,const BlasGenSettings * gset,const char * optFuncName,const char * genericFuncName)243 generateResultUpdateOld(
244     struct KgenContext *ctx,
245     BlasFunctionID funcID,
246     const BlasGenSettings *gset,
247     const char *optFuncName,
248     const char *genericFuncName)
249 {
250     UpdateResultFlags flags;
251 
252     flags = kextraToUpresFlags(funcID, gset->kextra->flags);
253 
254     return genResultUpdateWithFlagsOld(ctx, funcID, gset, flags,
255                                        optFuncName, genericFuncName, NULL);
256 }
257 
258 int
genResultUpdateWithFlagsOld(struct KgenContext * ctx,BlasFunctionID funcID,const BlasGenSettings * gset,UpdateResultFlags flags,const char * optFuncName,const char * genericFuncName,const char * cachedName)259 genResultUpdateWithFlagsOld(
260     struct KgenContext *ctx,
261     BlasFunctionID funcID,
262     const BlasGenSettings *gset,
263     UpdateResultFlags flags,
264     const char *optFuncName,
265     const char *genericFuncName,
266     const char *cachedName)
267 {
268     KernelExtraFlags kflags = gset->kextra->flags;
269     UpdateResultOp op;
270     char tmp[1024];
271     int ret = 0;
272     const char *coordY, *coordX;
273     UpresVarNames uvars;
274     const KernelVarNames *kvarNames = &gset->varNames;
275     const SubproblemDim *dim = &gset->subdims[1];
276     bool areTails, useCondition;
277 
278     memset(&uvars, 0, sizeof(uvars));
279 
280     coordX = kvarNames->coordB;
281     coordY = kvarNames->coordA;
282 
283     if (funcHasTriangMatrix(funcID)) {
284         if (flags & UPRES_TRIANG_WRITE_C) {
285             uvars.result = "C";
286         }
287         else {
288             uvars.result = "B";
289         }
290         uvars.ld = "ldb";
291     }
292     else {
293         uvars.result = "C";
294         uvars.ld = "ldc";
295     }
296 
297     uvars.cachedName = cachedName;
298 
299     /* For now, kernels that do not use UPRES_EXCEED_PROBLEM_CONDITION
300      * must return in case problem exceeds more precise lower level conditions
301      * (KEXTRA_TAILS_M_LOWER, KEXTRA_TAILS_N_LOWER) before updating result
302     */
303     areTails = (kflags & (KEXTRA_TAILS_M | KEXTRA_TAILS_N));
304     useCondition = areTails && ((flags & UPRES_EXCEED_PROBLEM_CONDITION) != 0);
305     if (useCondition) {
306         bool tailM = (kflags & KEXTRA_TAILS_M) != 0;
307         bool tailN = (kflags & KEXTRA_TAILS_N) != 0;
308 
309         if (tailM) {
310             if (tailN) {
311                 sprintf(tmp, "if ((%s < %s) && (%s < %s))",
312                         coordY, kvarNames->sizeM, coordX, kvarNames->sizeN);
313             }
314             else {
315                 sprintf(tmp, "if (%s < %s)", coordY, kvarNames->sizeM);
316             }
317         }
318         else {
319             // here tailN is true
320             sprintf(tmp, "if (%s < %s)", coordX, kvarNames->sizeN);
321         }
322         kgenBeginBranch(ctx, tmp);
323     }
324     else {
325         kgenAddBlankLine(ctx);
326     }
327 
328     if (optFuncName) {
329         const char *betaStr;
330         betaStr = (flags & UPRES_WITH_BETA) ? ", beta" : "";
331 
332         // update with functions invoking
333         if (!(kflags & (KEXTRA_TAILS_M_LOWER | KEXTRA_TAILS_N_LOWER))) {
334             sprintf(tmp, "%s(%s, c, alpha, %s, %s, %s%s);\n",
335                     optFuncName, uvars.result, coordY, coordX,
336                     uvars.ld, betaStr);
337         }
338         else {
339             sprintf(tmp, "uint y = min(%luu, %s - (uint)%s);\n"
340                          "uint x = min(%luu, %s - (uint)%s);\n"
341 
342                          "if ((y == %lu) && (x == %lu)) {\n"
343                          "    %s(%s, c, alpha, %s, %s, %s%s);\n"
344                          "}\n"
345                          "else {\n"
346                          "    %s(%s, c, alpha, %s, %s, %s%s, y, x);\n"
347                          "}\n",
348                      dim->y, kvarNames->sizeM, coordY,
349                      dim->x, kvarNames->sizeN, coordX,
350                      dim->y, dim->x,
351                      optFuncName, uvars.result, coordY, coordX, uvars.ld,
352                      betaStr,
353                      genericFuncName, uvars.result, coordY, coordX, uvars.ld,
354                      betaStr);
355         }
356 
357         kgenAddStmt(ctx, tmp);
358     }
359     else {
360         // inline result update
361         flags |= UPRES_INLINE;
362 
363         op = (flags & UPRES_WITH_BETA) ? UPRES_SUM : UPRES_SET;
364 
365         uvars.startRow = coordY;
366         uvars.startCol = coordX;
367         uvars.nrRows = "y";
368         uvars.nrCols = "x";
369 
370         if (!(kflags & (KEXTRA_TAILS_M_LOWER | KEXTRA_TAILS_N_LOWER))) {
371             ret = updateResultGenOld(ctx, gset, op, flags, &uvars);
372         }
373         else {
374             sprintf(tmp, "uint y = min(%luu, %s - (uint)%s);\n"
375                          "uint x = min(%luu, %s - (uint)%s);\n",
376                     dim->y, kvarNames->sizeM, coordY,
377                     dim->x, kvarNames->sizeN, coordX);
378             kgenAddStmt(ctx, tmp);
379 
380             sprintf(tmp, "if ((y == %lu) && (x == %lu))",
381                     dim->y, dim->x);
382             kgenBeginBranch(ctx, tmp);
383             // optimized update
384             updateResultGenOld(ctx, gset, op, flags, &uvars);
385             kgenEndBranch(ctx, NULL);
386 
387             flags |= UPRES_GENERIC;
388             kgenBeginBranch(ctx, "else ");
389             // not optimized update
390             updateResultGenOld(ctx, gset, op, flags, &uvars);
391             ret = kgenEndBranch(ctx, NULL);
392         }
393     }
394 
395     if (useCondition) {
396         ret = kgenEndBranch(ctx, NULL);
397     }
398 
399     return (ret) ? -EOVERFLOW : 0;
400 }
401 
402 int
genUpresFuncsWithFlags(struct KgenContext * ctx,const BlasGenSettings * gset,UpdateResultFlags flags,char optFuncName[FUNC_NAME_MAXLEN],char genericFuncName[FUNC_NAME_MAXLEN])403 genUpresFuncsWithFlags(
404     struct KgenContext *ctx,
405     const BlasGenSettings *gset,
406     UpdateResultFlags flags,
407     char optFuncName[FUNC_NAME_MAXLEN],
408     char genericFuncName[FUNC_NAME_MAXLEN])
409 {
410     KernelExtraFlags kflags = gset->kextra->flags;
411     UpdateResultOp op;
412     int ret;
413 
414     op = (flags & UPRES_WITH_BETA) ? UPRES_SUM : UPRES_SET;
415 
416     updateResultGenOld(ctx, gset, op, flags, NULL);
417     ret = kgenAddBlankLine(ctx);
418     if (ret) {
419         return -EOVERFLOW;
420     }
421 
422     kgenGetLastFuncName(optFuncName, FUNC_NAME_MAXLEN, ctx);
423 
424     if (kflags & (KEXTRA_TAILS_M | KEXTRA_TAILS_N)) {
425         flags |= UPRES_GENERIC;
426         updateResultGenOld(ctx, gset, op, flags, NULL);
427         kgenAddBlankLine(ctx);
428         kgenGetLastFuncName(genericFuncName, FUNC_NAME_MAXLEN, ctx);
429     }
430 
431     return (ret) ? -EOVERFLOW : 0;
432 }
433 
434 int
generateUpresFuncs(struct KgenContext * ctx,BlasFunctionID funcID,const BlasGenSettings * gset,char optFuncName[FUNC_NAME_MAXLEN],char genericFuncName[FUNC_NAME_MAXLEN])435 generateUpresFuncs(
436     struct KgenContext *ctx,
437     BlasFunctionID funcID,
438     const BlasGenSettings *gset,
439     char optFuncName[FUNC_NAME_MAXLEN],
440     char genericFuncName[FUNC_NAME_MAXLEN])
441 {
442     UpdateResultFlags flags;
443 
444     flags = kextraToUpresFlags(funcID, gset->kextra->flags);
445 
446     return genUpresFuncsWithFlags(ctx, gset, flags,
447                                   optFuncName, genericFuncName);
448 }
449