1 /* ************************************************************************
2  * Copyright 2013 Advanced Micro Devices, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  * ************************************************************************/
16 
17 
18 #include <stdio.h>
19 #include <clblas_stddef.h>
20 #include "xxmv_common.h"
21 
/*
 * Write into buf the text of an index expression combining a constant,
 * an optional textual addend and an optional textual multiplier.
 *
 *  buf  - destination string; no bound is checked here, so the caller
 *         must provide enough room for the whole expression
 *  val  - constant part of the expression
 *  type - type name used to cast every mad24() operand
 *  sum  - text added to val, or NULL when there is no addend
 *  mul  - text the sum is multiplied by, or NULL for no multiplication
 *
 * Produced text:
 *  mul == NULL, sum == NULL:  "<val>"
 *  mul == NULL, sum != NULL:  "<sum>" when val == 0, else "<sum> + <val>"
 *  mul != NULL, sum == NULL:  "0" when val == 0, "<mul>" when val == 1,
 *                             else "mad24((T)<val>, (T)<mul>, (T)0)"
 *  mul != NULL, sum != NULL:  "mad24((T)<sum>, (T)<mul>, (T)0)" when
 *                             val == 0, else
 *                             "mad24((T)<sum> + <val>, (T)<mul>, (T)0)"
 */
static void
genMul(char *buf, size_t val, const char* type, const char* sum, const char* mul)
{
    /*
     * size_t is not guaranteed to be unsigned long, so passing it straight
     * to a %lu conversion is undefined behaviour on such platforms.
     * Convert once and print the converted value everywhere.
     */
    const unsigned long uval = (unsigned long)val;

    if (mul == NULL) {
        if (sum == NULL) {
            sprintf(buf, "%lu", uval);
        }
        else if (val == 0) {
            /* adding zero: the addend alone is the whole expression */
            sprintf(buf, "%s", sum);
        }
        else {
            sprintf(buf, "%s + %lu", sum, uval);
        }
        return;
    }

    if (sum == NULL) {
        if (val == 0) {
            /* 0 * mul folds to the constant zero */
            sprintf(buf, "0");
        }
        else if (val == 1) {
            /* 1 * mul folds to the multiplier itself */
            sprintf(buf, "%s", mul);
        }
        else {
            sprintf(buf, "mad24((%s)%lu, (%s)%s, (%s)0)",
                    type, uval, type, mul, type);
        }
        return;
    }

    if (val == 0) {
        sprintf(buf, "mad24((%s)%s, (%s)%s, (%s)0)",
                type, sum, type, mul, type);
    }
    else {
        sprintf(buf, "mad24((%s)%s + %lu, (%s)%s, (%s)0)",
                type, sum, uval, type, mul, type);
    }
}
68 
69 
70 void
genFetchX(struct KgenContext * ctx,Tile * tile,unsigned int vecLen,DataType dtype,const KernelVarNames * varNames,TileMulFlags tflags,KernelExtraFlags kflags)71 genFetchX(
72     struct KgenContext *ctx,
73     Tile *tile,
74     unsigned int vecLen,
75     DataType dtype,
76     const KernelVarNames *varNames,
77     TileMulFlags tflags,
78     KernelExtraFlags kflags)
79 {
80     Kstring kstr[1];
81     Tile memtile;
82     char tmp[1024], strMul[128];
83     unsigned int n;
84     const char *ptrName;
85     bool tailN = (tflags & TILEMUL_SKEW_B) != 0;
86     bool incxOne = ((kflags & KEXTRA_INCX_ONE) != 0);
87     bool elemFetch = ((kflags & KEXTRA_NO_COPY_VEC_B) != 0);
88     unsigned int nfetch = !tailN && incxOne && !elemFetch ? vecLen : 1;
89 
90     (void)dtype;
91     initTile(&memtile, NULL, tile->nrRows, tile->nrCols, nfetch,
92              tile->dtype, tile->storType,  tile->trans, tile->packed);
93     getVectorTypeName(tile->dtype, vecLen, NULL, &ptrName);
94 
95     if (!tailN && incxOne && !elemFetch) {
96         sprintf(tmp, "const uint xk = %s / %u;\n", varNames->k, vecLen);
97         kgenAddStmt(ctx, tmp);
98         for (n = 0; forEachTile(kstr, n, 0, 2, tile, &memtile); n++) {
99             sprintf(tmp,"%s = %s.%s[xk + %u];\n",
100                         kstr[0].buf, varNames->B, ptrName, n);
101             kgenAddStmt(ctx, tmp);
102         }
103     }
104     else {
105         for (n = 0; forEachTile(kstr, n, 0, 2, tile, &memtile); n++) {
106             genMul(strMul, n, "int", "k", incxOne ? NULL : "incx");
107             if (tailN) {
108                 sprintf(tmp,"%s = X[k + %u < %s ? %s : 0];\n",
109                  kstr[0].buf, n, varNames->sizeK, strMul);
110             }
111             else {
112                 sprintf(tmp,"%s = X[%s];\n",kstr[0].buf, strMul);
113             }
114             kgenAddStmt(ctx, tmp);
115         }
116     }
117 
118     if (tailN) {
119         for (n = 0; forEachTile(kstr, n, 0, 2, tile, &memtile); n++) {
120             sprintf(tmp,"%s = k + %u < %s ? %s : 0;\n",
121                         kstr[0].buf, n, varNames->sizeK, kstr[0].buf);
122             kgenAddStmt(ctx, tmp);
123         }
124     }
125 }
126 
127 void
setResultPos(struct KgenContext * ctx,KernelExtraFlags kflags,const char * axVar)128 setResultPos(
129     struct KgenContext *ctx,
130     KernelExtraFlags kflags,
131     const char *axVar)
132 {
133     bool incyOne = ((kflags & KEXTRA_INCY_ONE) != 0);
134 
135     char tmp[2048];
136 
137     if (incyOne) {
138         sprintf(tmp, "Y += %s;\n", axVar);
139     }
140     else {
141         sprintf(tmp, "Y += incy * (int)%s;\n", axVar);
142     }
143     kgenAddStmt(ctx, tmp);
144 }
145 
146 void
updateResultVectorTiled(struct KgenContext * ctx,KernelExtraFlags kflags,unsigned int vecLen,Tile * tile)147 updateResultVectorTiled(
148     struct KgenContext *ctx,
149     KernelExtraFlags kflags,
150     unsigned int vecLen,
151     Tile *tile)
152 {
153     bool beta0 = ((kflags & KEXTRA_BETA_ZERO) != 0);
154     bool incyOne = ((kflags & KEXTRA_INCY_ONE) != 0);
155     bool tailM = ((kflags & KEXTRA_TAILS_M) != 0);
156     bool isComplex = isComplexType(tile->dtype);
157     unsigned int n, i;
158     const char *outTypeName, *outPtrName;
159     Tile result, memtile;
160 
161     char tmp[2048],strMul[256];
162     Kstring kstr[2];
163 
164     if (isComplex) {
165         vecLen = 1;
166     }
167     initTile(&result, "r", tile->nrRows, tile->nrCols, tile->nrRows,
168                     tile->dtype, tile->storType, true, tile->packed);
169     declareOneTileStorage(ctx, &result);
170 
171     memtile = result;
172     memtile.baseName = NULL;
173     memtile.vecLen = !tailM && incyOne ? vecLen : 1;
174     getVectorTypeName(memtile.dtype, memtile.vecLen, &outTypeName, &outPtrName);
175 
176     sprintf(tmp,"GPtr uC;\n"
177                 "uC.f = Y;\n");
178     kgenAddStmt(ctx, tmp);
179 
180     if (!tailM && incyOne) {
181         for (n = 0; forEachTile(kstr, n, 0, 2, &result, &memtile); n++) {
182             sprintf(tmp,"%s = uC.%s[%u];\n",
183                         kstr[0].buf, outPtrName, n);
184             kgenAddStmt(ctx, tmp);
185         }
186     }
187     else {
188         for (n = 0; forEachTile(kstr, n, 0, 2, &result, &memtile); n++) {
189             genMul(strMul, n, "int", NULL, incyOne ? NULL : "incy");
190             if (tailM) {
191                 sprintf(tmp,"%s = Y[coordA + %u >= M ? 0 : %s];\n",
192                         kstr[0].buf, n, strMul);
193             }
194             else {
195                 sprintf(tmp,"%s = Y[%s];\n",
196                         kstr[0].buf, strMul);
197             }
198             kgenAddStmt(ctx, tmp);
199         }
200     }
201 
202     if (isComplex) {
203         const char *complVec =
204                     isDoubleBasedType(tile->dtype) ? "double2" : "float2";
205         Tile onetile = result;
206         onetile.baseName = NULL;
207         onetile.vecLen = 1;
208         for (n = 0; forEachTile(kstr, n, 0, 3, &result, tile, &onetile); n++) {
209             if (beta0) {
210                 sprintf(tmp,
211                        "%s = %s * alpha.x + %s.yx * (%s)(-alpha.y, alpha.y);\n",
212                        kstr[0].buf, kstr[1].buf, kstr[1].buf, complVec);
213             }
214             else {
215                 sprintf(tmp,
216                         "%s = %s * beta.x + %s.yx * (%s)(-beta.y, beta.y) + "
217                         "%s * alpha.x + %s.yx * (%s)(-alpha.y, alpha.y);\n",
218                         kstr[0].buf, kstr[0].buf, kstr[0].buf, complVec,
219                         kstr[1].buf, kstr[1].buf, complVec);
220             }
221             kgenAddStmt(ctx, tmp);
222         }
223     }
224     else {
225         for (n = 0; forEachTile(kstr, n, 0, 2, &result, tile); n++) {
226             if (beta0) {
227                 sprintf(tmp, "%s = alpha * %s;\n", kstr[0].buf, kstr[1].buf);
228             }
229             else {
230                 sprintf(tmp, "%s = beta * %s + alpha * %s;\n",
231                              kstr[0].buf, kstr[0].buf, kstr[1].buf);
232             }
233             kgenAddStmt(ctx, tmp);
234         }
235     }
236 
237     if (!tailM && incyOne) {
238         for (i = 0; forEachTile(kstr, i, 0, 2, &result, &memtile); i++) {
239             sprintf(tmp,"uC.%s[%u] = %s;\n",
240                         outPtrName, i, kstr[0].buf);
241             kgenAddStmt(ctx, tmp);
242         }
243     }
244     else {
245         if (!tailM) {
246             for (i = 0; forEachTile(kstr, i, 0, 2, &result, &memtile); i++) {
247                 sprintf(tmp,"*Y = %s;\n", kstr[0].buf);
248                 //sprintf(tmp,"Y[%u * incy] = %s;\n", i, kstr.buf);
249                 kgenAddStmt(ctx, tmp);
250                 kgenAddStmt(ctx, "Y += incy;\n");
251             }
252         }
253         else {
254             for (n = forEachTile(NULL, 0, 0, 2, &result, &memtile);
255                      n != 0; n--) {
256                 i = n - 1;
257                 forEachTile(kstr, i, 0, 2, &result, &memtile);
258                 genMul(strMul, i, "int", NULL, incyOne ? NULL : "incy");
259                 sprintf(tmp,"Y[coordA + %u >= M ? 0 : %s] = %s;\n",
260                         i, strMul, kstr[0].buf);
261                 kgenAddStmt(ctx, tmp);
262             }
263         }
264     }
265 }
266 
267 void
genIncPointers(struct KgenContext * ctx,KernelExtraFlags kflags)268 genIncPointers(
269     struct KgenContext *ctx,
270     KernelExtraFlags kflags)
271 {
272     bool incxOne = ((kflags & KEXTRA_INCX_ONE) != 0);
273     bool incyOne = ((kflags & KEXTRA_INCY_ONE) != 0);
274 
275     if (kflags & KEXTRA_A_OFF_NOT_ZERO) {
276         kgenAddStmt(ctx, "A += offA;\n");
277     }
278     if (kflags & KEXTRA_BX_OFF_NOT_ZERO) {
279         kgenAddStmt(ctx, "X += offX;\n");
280     }
281     if (kflags & KEXTRA_CY_OFF_NOT_ZERO) {
282         kgenAddStmt(ctx, "Y += offY;\n");
283     }
284 
285     if (!incxOne) {
286         kgenAddStmt(ctx, "X += incx > 0 ? 0 : (N - 1) * abs(incx);\n");
287     }
288     if (!incyOne) {
289         kgenAddStmt(ctx, "Y += incy > 0 ? 0 : (M - 1) * abs(incy);\n");
290     }
291 }
292 
293 void
genStoreLocalResult(struct KgenContext * ctx,Tile * tile,const char * lid)294 genStoreLocalResult(
295     struct KgenContext *ctx,
296     Tile *tile,
297     const char *lid)
298 {
299     Kstring kstr;
300     char tmp[1024];
301     unsigned int i;
302 
303     for (i = 0; forEachTile(&kstr, i, 0, 1, tile); i++) {
304         sprintf(tmp, "localRes[%s][%u] = %s;\n", lid, i, kstr.buf);
305         kgenAddStmt(ctx, tmp);
306     }
307 }
308 
309 void
genAddLocalResult(struct KgenContext * ctx,Tile * tile,const char * lid,unsigned int cLocal,unsigned int bStep)310 genAddLocalResult(
311     struct KgenContext *ctx,
312     Tile *tile,
313     const char *lid,
314     unsigned int cLocal,
315     unsigned int bStep)
316 {
317     Kstring kstr;
318     char tmp[1024];
319     unsigned int i;
320 
321     sprintf(tmp, "for (uint i = 1; i < %u; i++)", cLocal);
322     kgenBeginBranch(ctx, tmp);
323     for (i = 0; forEachTile(&kstr, i, 0, 1, tile); i++) {
324         sprintf(tmp, "%s += localRes[%s + i*%u][%u];\n",
325                      kstr.buf, lid, bStep, i);
326         kgenAddStmt(ctx, tmp);
327     }
328     kgenEndBranch(ctx, NULL);
329 }
330 
331 void
genMergeResults(struct KgenContext * ctx,Tile * result,Tile * source)332 genMergeResults(
333     struct KgenContext *ctx,
334     Tile *result,
335     Tile *source)
336 {
337     unsigned int i;
338     Kstring kstr[2];
339     char tmp[2048];
340 
341     for (i = 0; forEachTile(kstr, i, 0, 2, result, source); i++) {
342         sprintf(tmp, "%s += %s;\n", kstr[0].buf, kstr[1].buf);
343         kgenAddStmt(ctx, tmp);
344     }
345 }
346 
347