1 /* ************************************************************************
2  * Copyright 2013 Advanced Micro Devices, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  * ************************************************************************/
16 
17 
18 #include <stdlib.h>
19 #include <stdio.h>
20 #include <string.h>
21 #include <ctype.h>
22 #include <stdarg.h>
23 #include <assert.h>
24 
25 #include <kerngen.h>
26 #include <mempat.h>
27 
/*
 * OpenCL C source text declaring the GPtr, LPtr and PPtr unions with both
 * float- and double-based pointer members (scalar plus 2-, 4-, 8- and
 * 16-element vectors).
 *
 * The text begins by enabling cl_khr_fp64 when that macro is defined and
 * cl_amd_fp64 otherwise. GPtr members are __global pointers, LPtr members
 * are __local pointers and PPtr members carry no address space qualifier.
 * kgenDeclareUptrs() emits this text when called with withDouble set.
 */
const char *uptrsFullDeclaration =
    "#ifdef cl_khr_fp64\n"
    "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n"
    "#else\n"
    "#pragma OPENCL EXTENSION cl_amd_fp64 : enable\n"
    "#endif\n"
    "\n"
    "typedef union GPtr {\n"
    "    __global float *f;\n"
    "    __global double *d;\n"
    "    __global float2 *f2v;\n"
    "    __global double2 *d2v;\n"
    "    __global float4 *f4v;\n"
    "    __global double4 *d4v;\n"
    "    __global float8 *f8v;\n"
    "    __global double8 *d8v;\n"
    "    __global float16 *f16v;\n"
    "    __global double16 *d16v;\n"
    "} GPtr;\n"
    "\n"
    "typedef union LPtr {\n"
    "    __local float *f;\n"
    "    __local double *d;\n"
    "    __local float2 *f2v;\n"
    "    __local double2 *d2v;\n"
    "    __local float4 *f4v;\n"
    "    __local double4 *d4v;\n"
    "    __local float8 *f8v;\n"
    "    __local double8 *d8v;\n"
    "    __local float16 *f16v;\n"
    "    __local double16 *d16v;\n"
    "} LPtr;\n"
    "\n"
    "typedef union PPtr {\n"
    "    float *f;\n"
    "    double *d;\n"
    "    float2 *f2v;\n"
    "    double2 *d2v;\n"
    "    float4 *f4v;\n"
    "    double4 *d4v;\n"
    "    float8 *f8v;\n"
    "    double8 *d8v;\n"
    "    float16 *f16v;\n"
    "    double16 *d16v;\n"
    "} PPtr;\n\n";
73 
/*
 * OpenCL C source text declaring the GPtr, LPtr and PPtr unions with
 * float-based pointer members only (scalar plus 2-, 4-, 8- and 16-element
 * vectors); no fp64 extension pragma is included.
 *
 * GPtr members are __global pointers, LPtr members are __local pointers and
 * PPtr members carry no address space qualifier. kgenDeclareUptrs() emits
 * this text when called with withDouble cleared.
 */
const char *uptrsSingleDeclaration =
    "typedef union GPtr {\n"
    "    __global float *f;\n"
    "    __global float2 *f2v;\n"
    "    __global float4 *f4v;\n"
    "    __global float8 *f8v;\n"
    "    __global float16 *f16v;\n"
    "} GPtr;\n"
    "\n"
    "typedef union LPtr {\n"
    "    __local float *f;\n"
    "    __local float2 *f2v;\n"
    "    __local float4 *f4v;\n"
    "    __local float8 *f8v;\n"
    "    __local float16 *f16v;\n"
    "} LPtr;\n"
    "\n"
    "typedef union PPtr {\n"
    "    float *f;\n"
    "    float2 *f2v;\n"
    "    float4 *f4v;\n"
    "    float8 *f8v;\n"
    "    float16 *f16v;\n"
    "} PPtr;\n\n";
98 
99 const char
uptrTypeName(UptrType type)100 *uptrTypeName(UptrType type)
101 {
102     const char *s = NULL;
103 
104     switch(type) {
105     case UPTR_GLOBAL:
106         s = "GPtr";
107         break;
108     case UPTR_LOCAL:
109         s = "LPtr";
110         break;
111     case UPTR_PRIVATE:
112         s = "PPtr";
113         break;
114     }
115 
116     return s;
117 }
118 
119 char
dtypeToPrefix(DataType type)120 dtypeToPrefix(DataType type)
121 {
122     char c;
123 
124     switch (type) {
125     case TYPE_FLOAT:
126         c = 'f';
127         break;
128     case TYPE_DOUBLE:
129         c = 'd';
130         break;
131     case TYPE_COMPLEX_FLOAT:
132         c = 'c';
133         break;
134     case TYPE_COMPLEX_DOUBLE:
135         c = 'z';
136         break;
137     default:
138         c = 0;
139         break;
140     }
141 
142     return c;
143 }
144 
145 const char
dtypeBuiltinType(DataType dtype)146 *dtypeBuiltinType(DataType dtype)
147 {
148     const char *s;
149 
150     switch (dtype) {
151     case TYPE_FLOAT:
152         s = "float";
153         break;
154     case TYPE_DOUBLE:
155         s = "double";
156         break;
157     case TYPE_COMPLEX_FLOAT:
158         s = "float2";
159         break;
160     case TYPE_COMPLEX_DOUBLE:
161         s = "double2";
162         break;
163     default:
164         s = NULL;
165         break;
166     }
167 
168     return s;
169 }
170 
171 const char
dtypeUPtrField(DataType dtype)172 *dtypeUPtrField(DataType dtype)
173 {
174     const char *s;
175 
176     switch (dtype) {
177     case TYPE_FLOAT:
178         s = "f";
179         break;
180     case TYPE_DOUBLE:
181         s = "d";
182         break;
183     case TYPE_COMPLEX_FLOAT:
184         s = "f2v";
185         break;
186     case TYPE_COMPLEX_DOUBLE:
187         s = "d2v";
188         break;
189     default:
190         s = NULL;
191         break;
192     }
193 
194     return s;
195 }
196 
197 const char
strOne(DataType dtype)198 *strOne(DataType dtype)
199 {
200     const char *s;
201 
202     if (isComplexType(dtype)) {
203         if (isDoubleBasedType(dtype)) {
204             s = "(double2)(1, 0)";
205         }
206         else {
207             s = "(float2)(1, 0)";
208         }
209     }
210     else {
211         s = "1";
212     }
213 
214     return s;
215 }
216 
217 void
getVectorTypeName(DataType dtype,unsigned int vecLen,const char ** typeName,const char ** typePtrName)218 getVectorTypeName(
219     DataType dtype,
220     unsigned int vecLen,
221     const char **typeName,
222     const char **typePtrName)
223 {
224     char *tn = "";
225     char *tpn = "";
226 
227     if (isDoubleBasedType(dtype)) {
228         switch (vecLen * dtypeSize(dtype)) {
229         case sizeof(cl_double):
230             tn = "double";
231             tpn = "d";
232             break;
233         case sizeof(cl_double2):
234             tn = "double2";
235             tpn = "d2v";
236             break;
237         case sizeof(cl_double4):
238             tn = "double4";
239             tpn = "d4v";
240             break;
241         case sizeof(cl_double8):
242             tn = "double8";
243             tpn = "d8v";
244             break;
245         case sizeof(cl_double16):
246             tn = "double16";
247             tpn = "d16v";
248             break;
249         };
250     }
251     else {
252         switch (vecLen * dtypeSize(dtype)) {
253         case sizeof(cl_float):
254             tn = "float";
255             tpn = "f";
256             break;
257         case sizeof(cl_float2):
258             tn = "float2";
259             tpn = "f2v";
260             break;
261         case sizeof(cl_float4):
262             tn = "float4";
263             tpn = "f4v";
264             break;
265         case sizeof(cl_float8):
266             tn = "float8";
267             tpn = "f8v";
268             break;
269         case sizeof(cl_float16):
270             tn = "float16";
271             tpn = "f16v";
272             break;
273         };
274     }
275     if (typeName != NULL) {
276         *typeName = tn;
277     }
278     if (typePtrName != NULL) {
279         *typePtrName = tpn;
280     }
281 }
282 
283 int
kgenAddBarrier(struct KgenContext * ctx,CLMemFence fence)284 kgenAddBarrier(
285     struct KgenContext *ctx,
286     CLMemFence fence)
287 {
288     int ret;
289 
290     if (fence == CLK_LOCAL_MEM_FENCE) {
291         ret = kgenAddStmt(ctx, "barrier(CLK_LOCAL_MEM_FENCE);\n");
292     }
293     else {
294         ret = kgenAddStmt(ctx, "barrier(CLK_GLOBAL_MEM_FENCE);\n");
295     }
296     if (ret) {
297         ret = -EOVERFLOW;
298     }
299 
300     return ret;
301 }
302 
303 int
kgenAddMemFence(struct KgenContext * ctx,CLMemFence fence)304 kgenAddMemFence(
305     struct KgenContext *ctx,
306     CLMemFence fence)
307 {
308     int ret;
309 
310     if (fence == CLK_LOCAL_MEM_FENCE) {
311         ret = kgenAddStmt(ctx, "mem_fence(CLK_LOCAL_MEM_FENCE);\n");
312     }
313     else {
314         ret = kgenAddStmt(ctx, "mem_fence(CLK_GLOBAL_MEM_FENCE);\n");
315     }
316     if (ret) {
317         ret = -EOVERFLOW;
318     }
319 
320     return ret;
321 }
322 
323 int
kgenDeclareLocalID(struct KgenContext * ctx,const char * lidName,const PGranularity * pgran)324 kgenDeclareLocalID(
325     struct KgenContext *ctx,
326     const char *lidName,
327     const PGranularity *pgran)
328 {
329     char tmp[128];
330     int r;
331 
332     if (pgran->wgDim == 1) {
333         sprintf(tmp, "const int %s = get_local_id(0);\n", lidName);
334     }
335     else {
336         sprintf(tmp, "const int %s = get_local_id(1) * %u + "
337                      "get_local_id(0);\n",
338                 lidName, pgran->wgSize[0]);
339     }
340 
341     r = kgenAddStmt(ctx, tmp);
342 
343     return (r) ? -EOVERFLOW : 0;
344 }
345 
346 int
kgenDeclareGroupID(struct KgenContext * ctx,const char * gidName,const PGranularity * pgran)347 kgenDeclareGroupID(
348     struct KgenContext *ctx,
349     const char *gidName,
350     const PGranularity *pgran)
351 {
352     char tmp[128];
353     int r;
354 
355     if (pgran->wgDim == 1) {
356         sprintf(tmp, "const int %s = get_global_id(0) / %u;\n",
357                 gidName, pgran->wgSize[0]);
358     }
359     else {
360         sprintf(tmp, "const int %s = (get_global_id(1) / %u) * "
361                      "(get_global_size(0) / %u) + "
362                      "get_global_id(0) / %u;\n",
363                      gidName, pgran->wgSize[1], pgran->wgSize[0],
364                      pgran->wgSize[0]);
365     }
366 
367     r = kgenAddStmt(ctx, tmp);
368 
369     return (r) ? -EOVERFLOW : 0;
370 }
371 
372 int
kgenDeclareUptrs(struct KgenContext * ctx,bool withDouble)373 kgenDeclareUptrs(struct KgenContext *ctx, bool withDouble)
374 {
375     int ret;
376     const char *s;
377 
378     s = (withDouble) ? uptrsFullDeclaration : uptrsSingleDeclaration;
379     ret = kgenAddStmt(ctx, s);
380 
381     return ret ? -EOVERFLOW: 0;
382 }
383 
384 void
kstrcpy(Kstring * kstr,const char * str)385 kstrcpy(Kstring *kstr, const char *str)
386 {
387     const int lastByte = sizeof(kstr->buf) - 1;
388 
389     kstr->buf[lastByte] = '\0';
390     strncpy(kstr->buf, str, sizeof(kstr->buf));
391     assert(kstr->buf[lastByte] == '\0');
392 }
393 
394 void
ksprintf(Kstring * kstr,const char * fmt,...)395 ksprintf(Kstring *kstr, const char *fmt,...)
396 {
397     va_list ap;
398     int len;
399 
400     va_start(ap, fmt);
401     len = vsnprintf(kstr->buf, sizeof(kstr->buf), fmt, ap);
402     va_end(ap);
403 
404     // to mute GCC with its warning regarding set but unused variables
405 #ifdef NDEBUG
406     (void)len;
407 #endif
408 
409     assert((size_t)len < sizeof(kstr->buf));
410 }
411 
412 void
kstrcatf(Kstring * kstr,const char * fmt,...)413 kstrcatf(Kstring *kstr, const char *fmt,...)
414 {
415     va_list ap;
416     int len, maxlen;
417 
418     va_start(ap, fmt);
419     len = (int)strlen(kstr->buf);
420     maxlen = sizeof(kstr->buf) - len;
421     len = vsnprintf(kstr->buf + len, maxlen, fmt, ap);
422     va_end(ap);
423 
424     assert(len < maxlen);
425 }
426 
427 
428