/* ************************************************************************
 * Copyright 2013 Advanced Micro Devices, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * ************************************************************************/
16
17
18 #include <stdio.h>
19 #include "gen_helper_legacy.h"
20 #include "blas_kgen_legacy.h"
21 #include "../gen_helper.h"
22
23 typedef struct CopyPattern {
24 SubproblemDim dim;
25 const PGranularity *pgran;
26 DataType dtype;
27 DBlockCopyDirection dir;
28 DBlockCopyFlags flags;
29 bool generic;
30 bool zeroing;
31 } CopyPattern;
32
33 static int
cpyImgGenCallback(struct KgenContext * ctx,const void * pattern)34 cpyImgGenCallback(struct KgenContext *ctx, const void *pattern)
35 {
36 const CopyPattern *pat = (CopyPattern*)pattern;
37 const void *dim = (pat->generic) ? NULL : &pat->dim;
38 if(pat->zeroing) {
39 return f4zeroBlockGen(ctx, dim, pat->pgran, "__local");
40 }
41 else {
42 return copyDataBlockGen(ctx, dim, pat->pgran, pat->dtype, pat->dir,
43 pat->flags);
44 }
45 }
46
47 int
generateImageCopyFuncs(CopyImgFuncs * copyFuncs,struct KgenContext * ctx,BlasFunctionID funcID,const BlasGenSettings * gset)48 generateImageCopyFuncs(
49 CopyImgFuncs *copyFuncs,
50 struct KgenContext *ctx,
51 BlasFunctionID funcID,
52 const BlasGenSettings *gset)
53 {
54 const SubproblemDim *dims = gset->subdims;
55 KernelExtraFlags kflags = gset->kextra->flags;
56 DataType dtype = gset->kextra->dtype;
57 const PGranularity *pgran = gset->pgran;
58 CopyPattern pattern;
59 // mandatory flags for global to local copying
60 DBlockCopyFlags glcpFlags[2] = {0, 0};
61 struct KgenGuard *guard;
62 unsigned int tsize;
63 int ret = 0;
64 bool isTra, areTails, isConjA;
65 bool customize;
66
67 if (kflags & KEXTRA_NO_COPY_VEC_A) {
68 glcpFlags[0] = DBLOCK_COPY_NOT_VECTORIZE;
69 }
70 if (kflags & KEXTRA_NO_COPY_VEC_B) {
71 glcpFlags[1] = DBLOCK_COPY_NOT_VECTORIZE;
72 }
73
74 tsize = dtypeSize(dtype);
75 isTra = isMatrixAccessColMaj(funcID, kflags, MATRIX_A);
76 isConjA = isMatrixConj(kflags, MATRIX_A);
77 areTails = (kflags & (KEXTRA_TAILS_M | KEXTRA_TAILS_N));
78 customize = (funcID == CLBLAS_TRMM);
79
80 guard = createKgenGuard(ctx, cpyImgGenCallback, sizeof(CopyPattern));
81 if (guard == NULL) {
82 return -ENOMEM;
83 }
84
85 memset(&pattern, 0, sizeof(pattern));
86
87 pattern.zeroing = false;
88 pattern.dim = dims[0];
89 pattern.dir = DBLOCK_GLOBAL_TO_IMAGE;
90 pattern.dtype = dtype;
91 pattern.flags = 0;
92 pattern.generic = false;
93 pattern.pgran = pgran;
94
95 if (!(customize && (isTra || isConjA))) {
96 pattern.dim.x = dims[0].bwidth;
97 pattern.dim.y = dims[0].y;
98 findGenerateFunction(guard, &pattern, copyFuncs->globalToImage[MATRIX_A],
99 FUNC_NAME_MAXLEN);
100 kgenAddBlankLine(ctx);
101 }
102
103 pattern.dim.x = dims[0].bwidth;
104 pattern.dim.y = dims[0].x;
105 findGenerateFunction(guard, &pattern, copyFuncs->globalToImage[MATRIX_B],
106 FUNC_NAME_MAXLEN);
107 kgenAddBlankLine(ctx);
108
109 pattern.dim.x = dims[0].bwidth;
110 pattern.dim.y = dims[1].y;
111 pattern.dir = DBLOCK_LOCAL_TO_IMAGE;
112 findGenerateFunction(guard, &pattern, copyFuncs->localToImage[MATRIX_A],
113 FUNC_NAME_MAXLEN);
114 kgenAddBlankLine(ctx);
115
116 pattern.dim.x = dims[0].bwidth;
117 pattern.dim.y = dims[1].x;
118 pattern.dir = DBLOCK_LOCAL_TO_IMAGE;
119 findGenerateFunction(guard, &pattern, copyFuncs->localToImage[MATRIX_B],
120 FUNC_NAME_MAXLEN);
121 kgenAddBlankLine(ctx);
122
123 // Global to local optimized
124 pattern.dir = DBLOCK_GLOBAL_TO_LOCAL;
125 if (customize || isComplexType(dtype)) {
126 pattern.flags = (!customize || isConjA) ? DBLOCK_COPY_CONJUGATE : 0;
127 pattern.flags |= glcpFlags[0];
128 pattern.dim.x = dims[0].bwidth;
129 pattern.dim.y = dims[1].y;
130 findGenerateFunction(guard, &pattern, copyFuncs->globalToLocal[MATRIX_A],
131 FUNC_NAME_MAXLEN);
132 kgenAddBlankLine(ctx);
133 }
134
135 if ((funcID == CLBLAS_GEMM) && isComplexType(dtype)) {
136 pattern.flags = DBLOCK_COPY_CONJUGATE | glcpFlags[1];
137 pattern.dim.x = dims[0].bwidth;
138 pattern.dim.y = dims[1].x;
139 findGenerateFunction(guard, &pattern, copyFuncs->globalToLocal[MATRIX_B],
140 FUNC_NAME_MAXLEN);
141 kgenAddBlankLine(ctx);
142 }
143
144 // Global to local generic
145 pattern.dim = dims[0];
146 pattern.dir = DBLOCK_GLOBAL_TO_LOCAL;
147 pattern.generic = true;
148 if (!customize || areTails) {
149 pattern.flags = (isConjA) ? DBLOCK_COPY_CONJUGATE : 0;
150 pattern.flags |= glcpFlags[0];
151 findGenerateFunction(guard, &pattern,
152 copyFuncs->globalToLocalGeneric[MATRIX_A],
153 FUNC_NAME_MAXLEN);
154 kgenAddBlankLine(ctx);
155 }
156
157 pattern.flags = (kflags & KEXTRA_CONJUGATE_B) ? DBLOCK_COPY_CONJUGATE : 0;
158 pattern.flags |= glcpFlags[1];
159 findGenerateFunction(guard, &pattern,
160 copyFuncs->globalToLocalGeneric[MATRIX_B],
161 FUNC_NAME_MAXLEN);
162 kgenAddBlankLine(ctx);
163
164 // Global to local transposed functions
165 pattern.dir = DBLOCK_GLOBAL_TO_LOCAL;
166 pattern.flags = (kflags & KEXTRA_NO_COPY_VEC_A) ?
167 DBLOCK_COPY_NOT_VECTORIZE : 0;
168 pattern.flags |= glcpFlags[0];
169 if (!customize || isTra) {
170 pattern.generic = false;
171 if (isConjA) {
172 pattern.flags |= DBLOCK_COPY_TRANSPOSE | DBLOCK_COPY_CONJUGATE;
173 }
174 else {
175 pattern.flags |= DBLOCK_COPY_TRANSPOSE;
176 }
177 pattern.dim.x = dims[1].y;
178 pattern.dim.y = dims[0].bwidth;
179
180 findGenerateFunction(guard, &pattern,
181 copyFuncs->globalToLocalTransposed[MATRIX_A],
182 FUNC_NAME_MAXLEN);
183 kgenAddBlankLine(ctx);
184 }
185
186 if (!customize || (isTra && areTails)) {
187 pattern.generic = true;
188 pattern.dim.x = 0;
189 pattern.dim.y = 0;
190 findGenerateFunction(guard, &pattern,
191 copyFuncs->globalToLocalTransposedGeneric[MATRIX_A],
192 FUNC_NAME_MAXLEN);
193 kgenAddBlankLine(ctx);
194 }
195
196 pattern.generic = false;
197 pattern.dim.x = dims[1].x;
198 pattern.dim.y = dims[0].bwidth;
199 if (kflags & KEXTRA_CONJUGATE_B) {
200 pattern.flags = DBLOCK_COPY_TRANSPOSE | DBLOCK_COPY_CONJUGATE;
201 }
202 else {
203 pattern.flags = DBLOCK_COPY_TRANSPOSE;
204 }
205 pattern.flags |= glcpFlags[1];
206 findGenerateFunction(guard, &pattern,
207 copyFuncs->globalToLocalTransposed[MATRIX_B],
208 FUNC_NAME_MAXLEN);
209 kgenAddBlankLine(ctx);
210
211 pattern.generic = true;
212 pattern.dim.x = 0;
213 pattern.dim.y = 0;
214 findGenerateFunction(guard, &pattern,
215 copyFuncs->globalToLocalTransposedGeneric[MATRIX_B],
216 FUNC_NAME_MAXLEN);
217 kgenAddBlankLine(ctx);
218
219 // generate two local zeroing functions for matrix A and matrix B blocks
220 pattern.zeroing = true;
221 pattern.dim = dims[0];
222 pattern.generic = false;
223 pattern.flags = 0;
224 pattern.dim.y = 1;
225 pattern.dim.x = fl4RowWidth(dims[0].bwidth, tsize) * dims[1].y;
226
227 findGenerateFunction(guard, &pattern,
228 copyFuncs->zeroBlock[MATRIX_A],
229 FUNC_NAME_MAXLEN);
230 kgenAddBlankLine(ctx);
231
232 pattern.dim.x = fl4RowWidth(dims[0].bwidth, tsize) * dims[1].x;
233 findGenerateFunction(guard, &pattern,
234 copyFuncs->zeroBlock[MATRIX_B],
235 FUNC_NAME_MAXLEN);
236 ret = kgenAddBlankLine(ctx);
237
238 destroyKgenGuard(guard);
239 return ret;
240 }
241
242 int
generateResultUpdateOld(struct KgenContext * ctx,BlasFunctionID funcID,const BlasGenSettings * gset,const char * optFuncName,const char * genericFuncName)243 generateResultUpdateOld(
244 struct KgenContext *ctx,
245 BlasFunctionID funcID,
246 const BlasGenSettings *gset,
247 const char *optFuncName,
248 const char *genericFuncName)
249 {
250 UpdateResultFlags flags;
251
252 flags = kextraToUpresFlags(funcID, gset->kextra->flags);
253
254 return genResultUpdateWithFlagsOld(ctx, funcID, gset, flags,
255 optFuncName, genericFuncName, NULL);
256 }
257
258 int
genResultUpdateWithFlagsOld(struct KgenContext * ctx,BlasFunctionID funcID,const BlasGenSettings * gset,UpdateResultFlags flags,const char * optFuncName,const char * genericFuncName,const char * cachedName)259 genResultUpdateWithFlagsOld(
260 struct KgenContext *ctx,
261 BlasFunctionID funcID,
262 const BlasGenSettings *gset,
263 UpdateResultFlags flags,
264 const char *optFuncName,
265 const char *genericFuncName,
266 const char *cachedName)
267 {
268 KernelExtraFlags kflags = gset->kextra->flags;
269 UpdateResultOp op;
270 char tmp[1024];
271 int ret = 0;
272 const char *coordY, *coordX;
273 UpresVarNames uvars;
274 const KernelVarNames *kvarNames = &gset->varNames;
275 const SubproblemDim *dim = &gset->subdims[1];
276 bool areTails, useCondition;
277
278 memset(&uvars, 0, sizeof(uvars));
279
280 coordX = kvarNames->coordB;
281 coordY = kvarNames->coordA;
282
283 if (funcHasTriangMatrix(funcID)) {
284 if (flags & UPRES_TRIANG_WRITE_C) {
285 uvars.result = "C";
286 }
287 else {
288 uvars.result = "B";
289 }
290 uvars.ld = "ldb";
291 }
292 else {
293 uvars.result = "C";
294 uvars.ld = "ldc";
295 }
296
297 uvars.cachedName = cachedName;
298
299 /* For now, kernels that do not use UPRES_EXCEED_PROBLEM_CONDITION
300 * must return in case problem exceeds more precise lower level conditions
301 * (KEXTRA_TAILS_M_LOWER, KEXTRA_TAILS_N_LOWER) before updating result
302 */
303 areTails = (kflags & (KEXTRA_TAILS_M | KEXTRA_TAILS_N));
304 useCondition = areTails && ((flags & UPRES_EXCEED_PROBLEM_CONDITION) != 0);
305 if (useCondition) {
306 bool tailM = (kflags & KEXTRA_TAILS_M) != 0;
307 bool tailN = (kflags & KEXTRA_TAILS_N) != 0;
308
309 if (tailM) {
310 if (tailN) {
311 sprintf(tmp, "if ((%s < %s) && (%s < %s))",
312 coordY, kvarNames->sizeM, coordX, kvarNames->sizeN);
313 }
314 else {
315 sprintf(tmp, "if (%s < %s)", coordY, kvarNames->sizeM);
316 }
317 }
318 else {
319 // here tailN is true
320 sprintf(tmp, "if (%s < %s)", coordX, kvarNames->sizeN);
321 }
322 kgenBeginBranch(ctx, tmp);
323 }
324 else {
325 kgenAddBlankLine(ctx);
326 }
327
328 if (optFuncName) {
329 const char *betaStr;
330 betaStr = (flags & UPRES_WITH_BETA) ? ", beta" : "";
331
332 // update with functions invoking
333 if (!(kflags & (KEXTRA_TAILS_M_LOWER | KEXTRA_TAILS_N_LOWER))) {
334 sprintf(tmp, "%s(%s, c, alpha, %s, %s, %s%s);\n",
335 optFuncName, uvars.result, coordY, coordX,
336 uvars.ld, betaStr);
337 }
338 else {
339 sprintf(tmp, "uint y = min(%luu, %s - (uint)%s);\n"
340 "uint x = min(%luu, %s - (uint)%s);\n"
341
342 "if ((y == %lu) && (x == %lu)) {\n"
343 " %s(%s, c, alpha, %s, %s, %s%s);\n"
344 "}\n"
345 "else {\n"
346 " %s(%s, c, alpha, %s, %s, %s%s, y, x);\n"
347 "}\n",
348 dim->y, kvarNames->sizeM, coordY,
349 dim->x, kvarNames->sizeN, coordX,
350 dim->y, dim->x,
351 optFuncName, uvars.result, coordY, coordX, uvars.ld,
352 betaStr,
353 genericFuncName, uvars.result, coordY, coordX, uvars.ld,
354 betaStr);
355 }
356
357 kgenAddStmt(ctx, tmp);
358 }
359 else {
360 // inline result update
361 flags |= UPRES_INLINE;
362
363 op = (flags & UPRES_WITH_BETA) ? UPRES_SUM : UPRES_SET;
364
365 uvars.startRow = coordY;
366 uvars.startCol = coordX;
367 uvars.nrRows = "y";
368 uvars.nrCols = "x";
369
370 if (!(kflags & (KEXTRA_TAILS_M_LOWER | KEXTRA_TAILS_N_LOWER))) {
371 ret = updateResultGenOld(ctx, gset, op, flags, &uvars);
372 }
373 else {
374 sprintf(tmp, "uint y = min(%luu, %s - (uint)%s);\n"
375 "uint x = min(%luu, %s - (uint)%s);\n",
376 dim->y, kvarNames->sizeM, coordY,
377 dim->x, kvarNames->sizeN, coordX);
378 kgenAddStmt(ctx, tmp);
379
380 sprintf(tmp, "if ((y == %lu) && (x == %lu))",
381 dim->y, dim->x);
382 kgenBeginBranch(ctx, tmp);
383 // optimized update
384 updateResultGenOld(ctx, gset, op, flags, &uvars);
385 kgenEndBranch(ctx, NULL);
386
387 flags |= UPRES_GENERIC;
388 kgenBeginBranch(ctx, "else ");
389 // not optimized update
390 updateResultGenOld(ctx, gset, op, flags, &uvars);
391 ret = kgenEndBranch(ctx, NULL);
392 }
393 }
394
395 if (useCondition) {
396 ret = kgenEndBranch(ctx, NULL);
397 }
398
399 return (ret) ? -EOVERFLOW : 0;
400 }
401
402 int
genUpresFuncsWithFlags(struct KgenContext * ctx,const BlasGenSettings * gset,UpdateResultFlags flags,char optFuncName[FUNC_NAME_MAXLEN],char genericFuncName[FUNC_NAME_MAXLEN])403 genUpresFuncsWithFlags(
404 struct KgenContext *ctx,
405 const BlasGenSettings *gset,
406 UpdateResultFlags flags,
407 char optFuncName[FUNC_NAME_MAXLEN],
408 char genericFuncName[FUNC_NAME_MAXLEN])
409 {
410 KernelExtraFlags kflags = gset->kextra->flags;
411 UpdateResultOp op;
412 int ret;
413
414 op = (flags & UPRES_WITH_BETA) ? UPRES_SUM : UPRES_SET;
415
416 updateResultGenOld(ctx, gset, op, flags, NULL);
417 ret = kgenAddBlankLine(ctx);
418 if (ret) {
419 return -EOVERFLOW;
420 }
421
422 kgenGetLastFuncName(optFuncName, FUNC_NAME_MAXLEN, ctx);
423
424 if (kflags & (KEXTRA_TAILS_M | KEXTRA_TAILS_N)) {
425 flags |= UPRES_GENERIC;
426 updateResultGenOld(ctx, gset, op, flags, NULL);
427 kgenAddBlankLine(ctx);
428 kgenGetLastFuncName(genericFuncName, FUNC_NAME_MAXLEN, ctx);
429 }
430
431 return (ret) ? -EOVERFLOW : 0;
432 }
433
434 int
generateUpresFuncs(struct KgenContext * ctx,BlasFunctionID funcID,const BlasGenSettings * gset,char optFuncName[FUNC_NAME_MAXLEN],char genericFuncName[FUNC_NAME_MAXLEN])435 generateUpresFuncs(
436 struct KgenContext *ctx,
437 BlasFunctionID funcID,
438 const BlasGenSettings *gset,
439 char optFuncName[FUNC_NAME_MAXLEN],
440 char genericFuncName[FUNC_NAME_MAXLEN])
441 {
442 UpdateResultFlags flags;
443
444 flags = kextraToUpresFlags(funcID, gset->kextra->flags);
445
446 return genUpresFuncsWithFlags(ctx, gset, flags,
447 optFuncName, genericFuncName);
448 }
449