1 /* ************************************************************************
2 * Copyright 2013 Advanced Micro Devices, Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * ************************************************************************/
16
17
18 #include <stdio.h>
19 #include <clblas_stddef.h>
20 #include "xxmv_common.h"
21
/*
 * Write into buf the text of an index expression combining an optional
 * symbolic addend (sum), a constant (val) and an optional symbolic
 * multiplier (mul):
 *
 *   mul == NULL, sum == NULL :  "<val>"
 *   mul == NULL, sum != NULL :  "<sum>" if val == 0, else "<sum> + <val>"
 *   mul != NULL, sum == NULL :  "0" if val == 0, "<mul>" if val == 1,
 *                               else "mad24((T)<val>, (T)<mul>, (T)0)"
 *   mul != NULL, sum != NULL :  "mad24((T)<sum>, (T)<mul>, (T)0)" if val == 0,
 *                               else "mad24((T)<sum> + <val>, (T)<mul>, (T)0)"
 *
 * where T is the `type` string; it is read only by the mad24 forms.
 *
 * buf must be large enough for the resulting text; no bounds checking is
 * done here.
 */
static void
genMul(char *buf, size_t val, const char* type, const char* sum, const char* mul)
{
    /*
     * "%lu" expects an unsigned long. size_t is a different type on some
     * targets, so convert explicitly instead of passing val directly.
     */
    const unsigned long num = (unsigned long)val;

    if (mul == NULL) {
        /* no multiplier: plain constant or "sum + constant" */
        if (sum == NULL) {
            sprintf(buf, "%lu", num);
        }
        else if (val == 0) {
            sprintf(buf, "%s", sum);
        }
        else {
            sprintf(buf, "%s + %lu", sum, num);
        }
        return;
    }

    if (sum == NULL) {
        if (val == 0) {
            /* 0 * mul */
            sprintf(buf, "0");
        }
        else if (val == 1) {
            /* 1 * mul */
            sprintf(buf, "%s", mul);
        }
        else {
            sprintf(buf, "mad24((%s)%lu, (%s)%s, (%s)0)",
                    type, num, type, mul, type);
        }
        return;
    }

    if (val == 0) {
        sprintf(buf, "mad24((%s)%s, (%s)%s, (%s)0)",
                type, sum, type, mul, type);
    }
    else {
        sprintf(buf, "mad24((%s)%s + %lu, (%s)%s, (%s)0)",
                type, sum, num, type, mul, type);
    }
}
68
69
70 void
genFetchX(struct KgenContext * ctx,Tile * tile,unsigned int vecLen,DataType dtype,const KernelVarNames * varNames,TileMulFlags tflags,KernelExtraFlags kflags)71 genFetchX(
72 struct KgenContext *ctx,
73 Tile *tile,
74 unsigned int vecLen,
75 DataType dtype,
76 const KernelVarNames *varNames,
77 TileMulFlags tflags,
78 KernelExtraFlags kflags)
79 {
80 Kstring kstr[1];
81 Tile memtile;
82 char tmp[1024], strMul[128];
83 unsigned int n;
84 const char *ptrName;
85 bool tailN = (tflags & TILEMUL_SKEW_B) != 0;
86 bool incxOne = ((kflags & KEXTRA_INCX_ONE) != 0);
87 bool elemFetch = ((kflags & KEXTRA_NO_COPY_VEC_B) != 0);
88 unsigned int nfetch = !tailN && incxOne && !elemFetch ? vecLen : 1;
89
90 (void)dtype;
91 initTile(&memtile, NULL, tile->nrRows, tile->nrCols, nfetch,
92 tile->dtype, tile->storType, tile->trans, tile->packed);
93 getVectorTypeName(tile->dtype, vecLen, NULL, &ptrName);
94
95 if (!tailN && incxOne && !elemFetch) {
96 sprintf(tmp, "const uint xk = %s / %u;\n", varNames->k, vecLen);
97 kgenAddStmt(ctx, tmp);
98 for (n = 0; forEachTile(kstr, n, 0, 2, tile, &memtile); n++) {
99 sprintf(tmp,"%s = %s.%s[xk + %u];\n",
100 kstr[0].buf, varNames->B, ptrName, n);
101 kgenAddStmt(ctx, tmp);
102 }
103 }
104 else {
105 for (n = 0; forEachTile(kstr, n, 0, 2, tile, &memtile); n++) {
106 genMul(strMul, n, "int", "k", incxOne ? NULL : "incx");
107 if (tailN) {
108 sprintf(tmp,"%s = X[k + %u < %s ? %s : 0];\n",
109 kstr[0].buf, n, varNames->sizeK, strMul);
110 }
111 else {
112 sprintf(tmp,"%s = X[%s];\n",kstr[0].buf, strMul);
113 }
114 kgenAddStmt(ctx, tmp);
115 }
116 }
117
118 if (tailN) {
119 for (n = 0; forEachTile(kstr, n, 0, 2, tile, &memtile); n++) {
120 sprintf(tmp,"%s = k + %u < %s ? %s : 0;\n",
121 kstr[0].buf, n, varNames->sizeK, kstr[0].buf);
122 kgenAddStmt(ctx, tmp);
123 }
124 }
125 }
126
127 void
setResultPos(struct KgenContext * ctx,KernelExtraFlags kflags,const char * axVar)128 setResultPos(
129 struct KgenContext *ctx,
130 KernelExtraFlags kflags,
131 const char *axVar)
132 {
133 bool incyOne = ((kflags & KEXTRA_INCY_ONE) != 0);
134
135 char tmp[2048];
136
137 if (incyOne) {
138 sprintf(tmp, "Y += %s;\n", axVar);
139 }
140 else {
141 sprintf(tmp, "Y += incy * (int)%s;\n", axVar);
142 }
143 kgenAddStmt(ctx, tmp);
144 }
145
146 void
updateResultVectorTiled(struct KgenContext * ctx,KernelExtraFlags kflags,unsigned int vecLen,Tile * tile)147 updateResultVectorTiled(
148 struct KgenContext *ctx,
149 KernelExtraFlags kflags,
150 unsigned int vecLen,
151 Tile *tile)
152 {
153 bool beta0 = ((kflags & KEXTRA_BETA_ZERO) != 0);
154 bool incyOne = ((kflags & KEXTRA_INCY_ONE) != 0);
155 bool tailM = ((kflags & KEXTRA_TAILS_M) != 0);
156 bool isComplex = isComplexType(tile->dtype);
157 unsigned int n, i;
158 const char *outTypeName, *outPtrName;
159 Tile result, memtile;
160
161 char tmp[2048],strMul[256];
162 Kstring kstr[2];
163
164 if (isComplex) {
165 vecLen = 1;
166 }
167 initTile(&result, "r", tile->nrRows, tile->nrCols, tile->nrRows,
168 tile->dtype, tile->storType, true, tile->packed);
169 declareOneTileStorage(ctx, &result);
170
171 memtile = result;
172 memtile.baseName = NULL;
173 memtile.vecLen = !tailM && incyOne ? vecLen : 1;
174 getVectorTypeName(memtile.dtype, memtile.vecLen, &outTypeName, &outPtrName);
175
176 sprintf(tmp,"GPtr uC;\n"
177 "uC.f = Y;\n");
178 kgenAddStmt(ctx, tmp);
179
180 if (!tailM && incyOne) {
181 for (n = 0; forEachTile(kstr, n, 0, 2, &result, &memtile); n++) {
182 sprintf(tmp,"%s = uC.%s[%u];\n",
183 kstr[0].buf, outPtrName, n);
184 kgenAddStmt(ctx, tmp);
185 }
186 }
187 else {
188 for (n = 0; forEachTile(kstr, n, 0, 2, &result, &memtile); n++) {
189 genMul(strMul, n, "int", NULL, incyOne ? NULL : "incy");
190 if (tailM) {
191 sprintf(tmp,"%s = Y[coordA + %u >= M ? 0 : %s];\n",
192 kstr[0].buf, n, strMul);
193 }
194 else {
195 sprintf(tmp,"%s = Y[%s];\n",
196 kstr[0].buf, strMul);
197 }
198 kgenAddStmt(ctx, tmp);
199 }
200 }
201
202 if (isComplex) {
203 const char *complVec =
204 isDoubleBasedType(tile->dtype) ? "double2" : "float2";
205 Tile onetile = result;
206 onetile.baseName = NULL;
207 onetile.vecLen = 1;
208 for (n = 0; forEachTile(kstr, n, 0, 3, &result, tile, &onetile); n++) {
209 if (beta0) {
210 sprintf(tmp,
211 "%s = %s * alpha.x + %s.yx * (%s)(-alpha.y, alpha.y);\n",
212 kstr[0].buf, kstr[1].buf, kstr[1].buf, complVec);
213 }
214 else {
215 sprintf(tmp,
216 "%s = %s * beta.x + %s.yx * (%s)(-beta.y, beta.y) + "
217 "%s * alpha.x + %s.yx * (%s)(-alpha.y, alpha.y);\n",
218 kstr[0].buf, kstr[0].buf, kstr[0].buf, complVec,
219 kstr[1].buf, kstr[1].buf, complVec);
220 }
221 kgenAddStmt(ctx, tmp);
222 }
223 }
224 else {
225 for (n = 0; forEachTile(kstr, n, 0, 2, &result, tile); n++) {
226 if (beta0) {
227 sprintf(tmp, "%s = alpha * %s;\n", kstr[0].buf, kstr[1].buf);
228 }
229 else {
230 sprintf(tmp, "%s = beta * %s + alpha * %s;\n",
231 kstr[0].buf, kstr[0].buf, kstr[1].buf);
232 }
233 kgenAddStmt(ctx, tmp);
234 }
235 }
236
237 if (!tailM && incyOne) {
238 for (i = 0; forEachTile(kstr, i, 0, 2, &result, &memtile); i++) {
239 sprintf(tmp,"uC.%s[%u] = %s;\n",
240 outPtrName, i, kstr[0].buf);
241 kgenAddStmt(ctx, tmp);
242 }
243 }
244 else {
245 if (!tailM) {
246 for (i = 0; forEachTile(kstr, i, 0, 2, &result, &memtile); i++) {
247 sprintf(tmp,"*Y = %s;\n", kstr[0].buf);
248 //sprintf(tmp,"Y[%u * incy] = %s;\n", i, kstr.buf);
249 kgenAddStmt(ctx, tmp);
250 kgenAddStmt(ctx, "Y += incy;\n");
251 }
252 }
253 else {
254 for (n = forEachTile(NULL, 0, 0, 2, &result, &memtile);
255 n != 0; n--) {
256 i = n - 1;
257 forEachTile(kstr, i, 0, 2, &result, &memtile);
258 genMul(strMul, i, "int", NULL, incyOne ? NULL : "incy");
259 sprintf(tmp,"Y[coordA + %u >= M ? 0 : %s] = %s;\n",
260 i, strMul, kstr[0].buf);
261 kgenAddStmt(ctx, tmp);
262 }
263 }
264 }
265 }
266
267 void
genIncPointers(struct KgenContext * ctx,KernelExtraFlags kflags)268 genIncPointers(
269 struct KgenContext *ctx,
270 KernelExtraFlags kflags)
271 {
272 bool incxOne = ((kflags & KEXTRA_INCX_ONE) != 0);
273 bool incyOne = ((kflags & KEXTRA_INCY_ONE) != 0);
274
275 if (kflags & KEXTRA_A_OFF_NOT_ZERO) {
276 kgenAddStmt(ctx, "A += offA;\n");
277 }
278 if (kflags & KEXTRA_BX_OFF_NOT_ZERO) {
279 kgenAddStmt(ctx, "X += offX;\n");
280 }
281 if (kflags & KEXTRA_CY_OFF_NOT_ZERO) {
282 kgenAddStmt(ctx, "Y += offY;\n");
283 }
284
285 if (!incxOne) {
286 kgenAddStmt(ctx, "X += incx > 0 ? 0 : (N - 1) * abs(incx);\n");
287 }
288 if (!incyOne) {
289 kgenAddStmt(ctx, "Y += incy > 0 ? 0 : (M - 1) * abs(incy);\n");
290 }
291 }
292
293 void
genStoreLocalResult(struct KgenContext * ctx,Tile * tile,const char * lid)294 genStoreLocalResult(
295 struct KgenContext *ctx,
296 Tile *tile,
297 const char *lid)
298 {
299 Kstring kstr;
300 char tmp[1024];
301 unsigned int i;
302
303 for (i = 0; forEachTile(&kstr, i, 0, 1, tile); i++) {
304 sprintf(tmp, "localRes[%s][%u] = %s;\n", lid, i, kstr.buf);
305 kgenAddStmt(ctx, tmp);
306 }
307 }
308
309 void
genAddLocalResult(struct KgenContext * ctx,Tile * tile,const char * lid,unsigned int cLocal,unsigned int bStep)310 genAddLocalResult(
311 struct KgenContext *ctx,
312 Tile *tile,
313 const char *lid,
314 unsigned int cLocal,
315 unsigned int bStep)
316 {
317 Kstring kstr;
318 char tmp[1024];
319 unsigned int i;
320
321 sprintf(tmp, "for (uint i = 1; i < %u; i++)", cLocal);
322 kgenBeginBranch(ctx, tmp);
323 for (i = 0; forEachTile(&kstr, i, 0, 1, tile); i++) {
324 sprintf(tmp, "%s += localRes[%s + i*%u][%u];\n",
325 kstr.buf, lid, bStep, i);
326 kgenAddStmt(ctx, tmp);
327 }
328 kgenEndBranch(ctx, NULL);
329 }
330
331 void
genMergeResults(struct KgenContext * ctx,Tile * result,Tile * source)332 genMergeResults(
333 struct KgenContext *ctx,
334 Tile *result,
335 Tile *source)
336 {
337 unsigned int i;
338 Kstring kstr[2];
339 char tmp[2048];
340
341 for (i = 0; forEachTile(kstr, i, 0, 2, result, source); i++) {
342 sprintf(tmp, "%s += %s;\n", kstr[0].buf, kstr[1].buf);
343 kgenAddStmt(ctx, tmp);
344 }
345 }
346
347