1 /*
2  * Copyright © 2018 Red Hat
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  *
23  * Authors:
24  *    Rob Clark (robdclark@gmail.com)
25  */
26 
27 #include "math.h"
28 #include "nir/nir_builtin_builder.h"
29 
30 #include "vtn_private.h"
31 #include "OpenCL.std.h"
32 
33 typedef nir_ssa_def *(*nir_handler)(struct vtn_builder *b,
34                                     enum OpenCLstd_Entrypoints opcode,
35                                     unsigned num_srcs, nir_ssa_def **srcs,
36                                     const struct glsl_type *dest_type);
37 
38 static void
handle_instr(struct vtn_builder * b,enum OpenCLstd_Entrypoints opcode,const uint32_t * w,unsigned count,nir_handler handler)39 handle_instr(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
40              const uint32_t *w, unsigned count, nir_handler handler)
41 {
42    const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
43 
44    unsigned num_srcs = count - 5;
45    nir_ssa_def *srcs[3] = { NULL };
46    vtn_assert(num_srcs <= ARRAY_SIZE(srcs));
47    for (unsigned i = 0; i < num_srcs; i++) {
48       srcs[i] = vtn_get_nir_ssa(b, w[i + 5]);
49    }
50 
51    nir_ssa_def *result = handler(b, opcode, num_srcs, srcs, dest_type);
52    if (result) {
53       vtn_push_nir_ssa(b, w[2], result);
54    } else {
55       vtn_assert(dest_type == glsl_void_type());
56    }
57 }
58 
59 static nir_op
nir_alu_op_for_opencl_opcode(struct vtn_builder * b,enum OpenCLstd_Entrypoints opcode)60 nir_alu_op_for_opencl_opcode(struct vtn_builder *b,
61                              enum OpenCLstd_Entrypoints opcode)
62 {
63    switch (opcode) {
64    case OpenCLstd_Fabs: return nir_op_fabs;
65    case OpenCLstd_SAbs: return nir_op_iabs;
66    case OpenCLstd_SAdd_sat: return nir_op_iadd_sat;
67    case OpenCLstd_UAdd_sat: return nir_op_uadd_sat;
68    case OpenCLstd_Ceil: return nir_op_fceil;
69    case OpenCLstd_Cos: return nir_op_fcos;
70    case OpenCLstd_Exp2: return nir_op_fexp2;
71    case OpenCLstd_Log2: return nir_op_flog2;
72    case OpenCLstd_Floor: return nir_op_ffloor;
73    case OpenCLstd_SHadd: return nir_op_ihadd;
74    case OpenCLstd_UHadd: return nir_op_uhadd;
75    case OpenCLstd_Fma: return nir_op_ffma;
76    case OpenCLstd_Fmax: return nir_op_fmax;
77    case OpenCLstd_SMax: return nir_op_imax;
78    case OpenCLstd_UMax: return nir_op_umax;
79    case OpenCLstd_Fmin: return nir_op_fmin;
80    case OpenCLstd_SMin: return nir_op_imin;
81    case OpenCLstd_UMin: return nir_op_umin;
82    case OpenCLstd_Fmod: return nir_op_fmod;
83    case OpenCLstd_Mix: return nir_op_flrp;
84    case OpenCLstd_Native_cos: return nir_op_fcos;
85    case OpenCLstd_Native_divide: return nir_op_fdiv;
86    case OpenCLstd_Native_exp2: return nir_op_fexp2;
87    case OpenCLstd_Native_log2: return nir_op_flog2;
88    case OpenCLstd_Native_powr: return nir_op_fpow;
89    case OpenCLstd_Native_recip: return nir_op_frcp;
90    case OpenCLstd_Native_rsqrt: return nir_op_frsq;
91    case OpenCLstd_Native_sin: return nir_op_fsin;
92    case OpenCLstd_Native_sqrt: return nir_op_fsqrt;
93    case OpenCLstd_SMul_hi: return nir_op_imul_high;
94    case OpenCLstd_UMul_hi: return nir_op_umul_high;
95    case OpenCLstd_Popcount: return nir_op_bit_count;
96    case OpenCLstd_Pow: return nir_op_fpow;
97    case OpenCLstd_Remainder: return nir_op_frem;
98    case OpenCLstd_SRhadd: return nir_op_irhadd;
99    case OpenCLstd_URhadd: return nir_op_urhadd;
100    case OpenCLstd_Rsqrt: return nir_op_frsq;
101    case OpenCLstd_Sign: return nir_op_fsign;
102    case OpenCLstd_Sin: return nir_op_fsin;
103    case OpenCLstd_Sqrt: return nir_op_fsqrt;
104    case OpenCLstd_SSub_sat: return nir_op_isub_sat;
105    case OpenCLstd_USub_sat: return nir_op_usub_sat;
106    case OpenCLstd_Trunc: return nir_op_ftrunc;
107    case OpenCLstd_Rint: return nir_op_fround_even;
108    /* uhm... */
109    case OpenCLstd_UAbs: return nir_op_mov;
110    default:
111       vtn_fail("No NIR equivalent");
112    }
113 }
114 
115 static nir_ssa_def *
handle_alu(struct vtn_builder * b,enum OpenCLstd_Entrypoints opcode,unsigned num_srcs,nir_ssa_def ** srcs,const struct glsl_type * dest_type)116 handle_alu(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
117            unsigned num_srcs, nir_ssa_def **srcs,
118            const struct glsl_type *dest_type)
119 {
120    return nir_build_alu(&b->nb, nir_alu_op_for_opencl_opcode(b, opcode),
121                         srcs[0], srcs[1], srcs[2], NULL);
122 }
123 
124 static nir_ssa_def *
handle_special(struct vtn_builder * b,enum OpenCLstd_Entrypoints opcode,unsigned num_srcs,nir_ssa_def ** srcs,const struct glsl_type * dest_type)125 handle_special(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
126                unsigned num_srcs, nir_ssa_def **srcs,
127                const struct glsl_type *dest_type)
128 {
129    nir_builder *nb = &b->nb;
130 
131    switch (opcode) {
132    case OpenCLstd_SAbs_diff:
133       return nir_iabs_diff(nb, srcs[0], srcs[1]);
134    case OpenCLstd_UAbs_diff:
135       return nir_uabs_diff(nb, srcs[0], srcs[1]);
136    case OpenCLstd_Bitselect:
137       return nir_bitselect(nb, srcs[0], srcs[1], srcs[2]);
138    case OpenCLstd_SMad_hi:
139       return nir_imad_hi(nb, srcs[0], srcs[1], srcs[2]);
140    case OpenCLstd_UMad_hi:
141       return nir_umad_hi(nb, srcs[0], srcs[1], srcs[2]);
142    case OpenCLstd_SMul24:
143       return nir_imul24(nb, srcs[0], srcs[1]);
144    case OpenCLstd_UMul24:
145       return nir_umul24(nb, srcs[0], srcs[1]);
146    case OpenCLstd_SMad24:
147       return nir_imad24(nb, srcs[0], srcs[1], srcs[2]);
148    case OpenCLstd_UMad24:
149       return nir_umad24(nb, srcs[0], srcs[1], srcs[2]);
150    case OpenCLstd_FClamp:
151       return nir_fclamp(nb, srcs[0], srcs[1], srcs[2]);
152    case OpenCLstd_SClamp:
153       return nir_iclamp(nb, srcs[0], srcs[1], srcs[2]);
154    case OpenCLstd_UClamp:
155       return nir_uclamp(nb, srcs[0], srcs[1], srcs[2]);
156    case OpenCLstd_Copysign:
157       return nir_copysign(nb, srcs[0], srcs[1]);
158    case OpenCLstd_Cross:
159       if (glsl_get_components(dest_type) == 4)
160          return nir_cross4(nb, srcs[0], srcs[1]);
161       return nir_cross3(nb, srcs[0], srcs[1]);
162    case OpenCLstd_Degrees:
163       return nir_degrees(nb, srcs[0]);
164    case OpenCLstd_Fdim:
165       return nir_fdim(nb, srcs[0], srcs[1]);
166    case OpenCLstd_Distance:
167       return nir_distance(nb, srcs[0], srcs[1]);
168    case OpenCLstd_Fast_distance:
169       return nir_fast_distance(nb, srcs[0], srcs[1]);
170    case OpenCLstd_Fast_length:
171       return nir_fast_length(nb, srcs[0]);
172    case OpenCLstd_Fast_normalize:
173       return nir_fast_normalize(nb, srcs[0]);
174    case OpenCLstd_Length:
175       return nir_length(nb, srcs[0]);
176    case OpenCLstd_Mad:
177       return nir_fmad(nb, srcs[0], srcs[1], srcs[2]);
178    case OpenCLstd_Maxmag:
179       return nir_maxmag(nb, srcs[0], srcs[1]);
180    case OpenCLstd_Minmag:
181       return nir_minmag(nb, srcs[0], srcs[1]);
182    case OpenCLstd_Nan:
183       return nir_nan(nb, srcs[0]);
184    case OpenCLstd_Nextafter:
185       return nir_nextafter(nb, srcs[0], srcs[1]);
186    case OpenCLstd_Normalize:
187       return nir_normalize(nb, srcs[0]);
188    case OpenCLstd_Radians:
189       return nir_radians(nb, srcs[0]);
190    case OpenCLstd_Rotate:
191       return nir_rotate(nb, srcs[0], srcs[1]);
192    case OpenCLstd_Smoothstep:
193       return nir_smoothstep(nb, srcs[0], srcs[1], srcs[2]);
194    case OpenCLstd_Clz:
195       return nir_clz_u(nb, srcs[0]);
196    case OpenCLstd_Select:
197       return nir_select(nb, srcs[0], srcs[1], srcs[2]);
198    case OpenCLstd_Step:
199       return nir_sge(nb, srcs[1], srcs[0]);
200    case OpenCLstd_S_Upsample:
201    case OpenCLstd_U_Upsample:
202       return nir_upsample(nb, srcs[0], srcs[1]);
203    case OpenCLstd_Native_exp:
204       return nir_fexp(nb, srcs[0]);
205    case OpenCLstd_Native_exp10:
206       return nir_fexp2(nb, nir_fmul_imm(nb, srcs[0], log(10) / log(2)));
207    case OpenCLstd_Native_log:
208       return nir_flog(nb, srcs[0]);
209    case OpenCLstd_Native_log10:
210       return nir_fmul_imm(nb, nir_flog2(nb, srcs[0]), log(2) / log(10));
211    case OpenCLstd_Native_tan:
212       return nir_ftan(nb, srcs[0]);
213    default:
214       vtn_fail("No NIR equivalent");
215       return NULL;
216    }
217 }
218 
219 static void
_handle_v_load_store(struct vtn_builder * b,enum OpenCLstd_Entrypoints opcode,const uint32_t * w,unsigned count,bool load)220 _handle_v_load_store(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
221                      const uint32_t *w, unsigned count, bool load)
222 {
223    struct vtn_type *type;
224    if (load)
225       type = vtn_get_type(b, w[1]);
226    else
227       type = vtn_get_value_type(b, w[5]);
228    unsigned a = load ? 0 : 1;
229 
230    const struct glsl_type *dest_type = type->type;
231    unsigned components = glsl_get_vector_elements(dest_type);
232 
233    nir_ssa_def *offset = vtn_get_nir_ssa(b, w[5 + a]);
234    struct vtn_value *p = vtn_value(b, w[6 + a], vtn_value_type_pointer);
235 
236    struct vtn_ssa_value *comps[NIR_MAX_VEC_COMPONENTS];
237    nir_ssa_def *ncomps[NIR_MAX_VEC_COMPONENTS];
238 
239    nir_ssa_def *moffset = nir_imul_imm(&b->nb, offset, components);
240    nir_deref_instr *deref = vtn_pointer_to_deref(b, p->pointer);
241 
242    for (int i = 0; i < components; i++) {
243       nir_ssa_def *coffset = nir_iadd_imm(&b->nb, moffset, i);
244       nir_deref_instr *arr_deref = nir_build_deref_ptr_as_array(&b->nb, deref, coffset);
245 
246       if (load) {
247          comps[i] = vtn_local_load(b, arr_deref, p->type->access);
248          ncomps[i] = comps[i]->def;
249       } else {
250          struct vtn_ssa_value *ssa = vtn_create_ssa_value(b, glsl_scalar_type(glsl_get_base_type(dest_type)));
251          struct vtn_ssa_value *val = vtn_ssa_value(b, w[5]);
252          ssa->def = nir_channel(&b->nb, val->def, i);
253          vtn_local_store(b, ssa, arr_deref, p->type->access);
254       }
255    }
256    if (load) {
257       vtn_push_nir_ssa(b, w[2], nir_vec(&b->nb, ncomps, components));
258    }
259 }
260 
261 static void
vtn_handle_opencl_vload(struct vtn_builder * b,enum OpenCLstd_Entrypoints opcode,const uint32_t * w,unsigned count)262 vtn_handle_opencl_vload(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
263                         const uint32_t *w, unsigned count)
264 {
265    _handle_v_load_store(b, opcode, w, count, true);
266 }
267 
268 static void
vtn_handle_opencl_vstore(struct vtn_builder * b,enum OpenCLstd_Entrypoints opcode,const uint32_t * w,unsigned count)269 vtn_handle_opencl_vstore(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
270                          const uint32_t *w, unsigned count)
271 {
272    _handle_v_load_store(b, opcode, w, count, false);
273 }
274 
275 static nir_ssa_def *
handle_printf(struct vtn_builder * b,enum OpenCLstd_Entrypoints opcode,unsigned num_srcs,nir_ssa_def ** srcs,const struct glsl_type * dest_type)276 handle_printf(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
277               unsigned num_srcs, nir_ssa_def **srcs,
278               const struct glsl_type *dest_type)
279 {
280    /* hahah, yeah, right.. */
281    return nir_imm_int(&b->nb, -1);
282 }
283 
284 static nir_ssa_def *
handle_shuffle(struct vtn_builder * b,enum OpenCLstd_Entrypoints opcode,unsigned num_srcs,nir_ssa_def ** srcs,const struct glsl_type * dest_type)285 handle_shuffle(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, unsigned num_srcs,
286                nir_ssa_def **srcs, const struct glsl_type *dest_type)
287 {
288    struct nir_ssa_def *input = srcs[0];
289    struct nir_ssa_def *mask = srcs[1];
290 
291    unsigned out_elems = glsl_get_vector_elements(dest_type);
292    nir_ssa_def *outres[NIR_MAX_VEC_COMPONENTS];
293    unsigned in_elems = input->num_components;
294    if (mask->bit_size != 32)
295       mask = nir_u2u32(&b->nb, mask);
296    mask = nir_iand(&b->nb, mask, nir_imm_intN_t(&b->nb, in_elems - 1, mask->bit_size));
297    for (unsigned i = 0; i < out_elems; i++)
298       outres[i] = nir_vector_extract(&b->nb, input, nir_channel(&b->nb, mask, i));
299 
300    return nir_vec(&b->nb, outres, out_elems);
301 }
302 
303 static nir_ssa_def *
handle_shuffle2(struct vtn_builder * b,enum OpenCLstd_Entrypoints opcode,unsigned num_srcs,nir_ssa_def ** srcs,const struct glsl_type * dest_type)304 handle_shuffle2(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, unsigned num_srcs,
305                 nir_ssa_def **srcs, const struct glsl_type *dest_type)
306 {
307    struct nir_ssa_def *input0 = srcs[0];
308    struct nir_ssa_def *input1 = srcs[1];
309    struct nir_ssa_def *mask = srcs[2];
310 
311    unsigned out_elems = glsl_get_vector_elements(dest_type);
312    nir_ssa_def *outres[NIR_MAX_VEC_COMPONENTS];
313    unsigned in_elems = input0->num_components;
314    unsigned total_mask = 2 * in_elems - 1;
315    unsigned half_mask = in_elems - 1;
316    if (mask->bit_size != 32)
317       mask = nir_u2u32(&b->nb, mask);
318    mask = nir_iand(&b->nb, mask, nir_imm_intN_t(&b->nb, total_mask, mask->bit_size));
319    for (unsigned i = 0; i < out_elems; i++) {
320       nir_ssa_def *this_mask = nir_channel(&b->nb, mask, i);
321       nir_ssa_def *vmask = nir_iand(&b->nb, this_mask, nir_imm_intN_t(&b->nb, half_mask, mask->bit_size));
322       nir_ssa_def *val0 = nir_vector_extract(&b->nb, input0, vmask);
323       nir_ssa_def *val1 = nir_vector_extract(&b->nb, input1, vmask);
324       nir_ssa_def *sel = nir_ilt(&b->nb, this_mask, nir_imm_intN_t(&b->nb, in_elems, mask->bit_size));
325       outres[i] = nir_bcsel(&b->nb, sel, val0, val1);
326    }
327    return nir_vec(&b->nb, outres, out_elems);
328 }
329 
330 bool
vtn_handle_opencl_instruction(struct vtn_builder * b,SpvOp ext_opcode,const uint32_t * w,unsigned count)331 vtn_handle_opencl_instruction(struct vtn_builder *b, SpvOp ext_opcode,
332                               const uint32_t *w, unsigned count)
333 {
334    enum OpenCLstd_Entrypoints cl_opcode = (enum OpenCLstd_Entrypoints) ext_opcode;
335 
336    switch (cl_opcode) {
337    case OpenCLstd_Fabs:
338    case OpenCLstd_SAbs:
339    case OpenCLstd_UAbs:
340    case OpenCLstd_SAdd_sat:
341    case OpenCLstd_UAdd_sat:
342    case OpenCLstd_Ceil:
343    case OpenCLstd_Cos:
344    case OpenCLstd_Exp2:
345    case OpenCLstd_Log2:
346    case OpenCLstd_Floor:
347    case OpenCLstd_Fma:
348    case OpenCLstd_Fmax:
349    case OpenCLstd_SHadd:
350    case OpenCLstd_UHadd:
351    case OpenCLstd_SMax:
352    case OpenCLstd_UMax:
353    case OpenCLstd_Fmin:
354    case OpenCLstd_SMin:
355    case OpenCLstd_UMin:
356    case OpenCLstd_Mix:
357    case OpenCLstd_Native_cos:
358    case OpenCLstd_Native_divide:
359    case OpenCLstd_Native_exp2:
360    case OpenCLstd_Native_log2:
361    case OpenCLstd_Native_powr:
362    case OpenCLstd_Native_recip:
363    case OpenCLstd_Native_rsqrt:
364    case OpenCLstd_Native_sin:
365    case OpenCLstd_Native_sqrt:
366    case OpenCLstd_Fmod:
367    case OpenCLstd_SMul_hi:
368    case OpenCLstd_UMul_hi:
369    case OpenCLstd_Popcount:
370    case OpenCLstd_Pow:
371    case OpenCLstd_Remainder:
372    case OpenCLstd_SRhadd:
373    case OpenCLstd_URhadd:
374    case OpenCLstd_Rsqrt:
375    case OpenCLstd_Sign:
376    case OpenCLstd_Sin:
377    case OpenCLstd_Sqrt:
378    case OpenCLstd_SSub_sat:
379    case OpenCLstd_USub_sat:
380    case OpenCLstd_Trunc:
381    case OpenCLstd_Rint:
382       handle_instr(b, cl_opcode, w, count, handle_alu);
383       return true;
384    case OpenCLstd_SAbs_diff:
385    case OpenCLstd_UAbs_diff:
386    case OpenCLstd_SMad_hi:
387    case OpenCLstd_UMad_hi:
388    case OpenCLstd_SMad24:
389    case OpenCLstd_UMad24:
390    case OpenCLstd_SMul24:
391    case OpenCLstd_UMul24:
392    case OpenCLstd_Bitselect:
393    case OpenCLstd_FClamp:
394    case OpenCLstd_SClamp:
395    case OpenCLstd_UClamp:
396    case OpenCLstd_Copysign:
397    case OpenCLstd_Cross:
398    case OpenCLstd_Degrees:
399    case OpenCLstd_Fdim:
400    case OpenCLstd_Distance:
401    case OpenCLstd_Fast_distance:
402    case OpenCLstd_Fast_length:
403    case OpenCLstd_Fast_normalize:
404    case OpenCLstd_Length:
405    case OpenCLstd_Mad:
406    case OpenCLstd_Maxmag:
407    case OpenCLstd_Minmag:
408    case OpenCLstd_Nan:
409    case OpenCLstd_Nextafter:
410    case OpenCLstd_Normalize:
411    case OpenCLstd_Radians:
412    case OpenCLstd_Rotate:
413    case OpenCLstd_Select:
414    case OpenCLstd_Step:
415    case OpenCLstd_Smoothstep:
416    case OpenCLstd_S_Upsample:
417    case OpenCLstd_U_Upsample:
418    case OpenCLstd_Clz:
419    case OpenCLstd_Native_exp:
420    case OpenCLstd_Native_exp10:
421    case OpenCLstd_Native_log:
422    case OpenCLstd_Native_log10:
423    case OpenCLstd_Native_tan:
424       handle_instr(b, cl_opcode, w, count, handle_special);
425       return true;
426    case OpenCLstd_Vloadn:
427       vtn_handle_opencl_vload(b, cl_opcode, w, count);
428       return true;
429    case OpenCLstd_Vstoren:
430       vtn_handle_opencl_vstore(b, cl_opcode, w, count);
431       return true;
432    case OpenCLstd_Shuffle:
433       handle_instr(b, cl_opcode, w, count, handle_shuffle);
434       return true;
435    case OpenCLstd_Shuffle2:
436       handle_instr(b, cl_opcode, w, count, handle_shuffle2);
437       return true;
438    case OpenCLstd_Printf:
439       handle_instr(b, cl_opcode, w, count, handle_printf);
440       return true;
441    case OpenCLstd_Prefetch:
442       /* TODO maybe add a nir instruction for this? */
443       return true;
444    default:
445       vtn_fail("unhandled opencl opc: %u\n", ext_opcode);
446       return false;
447    }
448 }
449