/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "vtn_private.h"

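/* Emit a single NIR subgroup intrinsic for the given source value.  An
 * optional 32-bit index source and up to two constant indices are forwarded
 * to the intrinsic; composite (struct/array) sources are handled by
 * recursing per element.
 */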
static struct vtn_ssa_value *
vtn_build_subgroup_instr(struct vtn_builder *b,
                         nir_intrinsic_op nir_op,
                         struct vtn_ssa_value *src0,
                         nir_ssa_def *index,
                         unsigned const_idx0,
                         unsigned const_idx1)
{
   /* Some of the subgroup operations take an index.  SPIR-V allows this to be
    * any integer type.  To make things simpler for drivers, we only support
    * 32-bit indices.
    */
   if (index && index->bit_size != 32)
      index = nir_u2u32(&b->nb, index);

   struct vtn_ssa_value *dst = vtn_create_ssa_value(b, src0->type);

   vtn_assert(dst->type == src0->type);
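   /* Only vectors and scalars map onto a single intrinsic; for composites,
    * emit one intrinsic per element and rebuild the aggregate.
    */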
   if (!glsl_type_is_vector_or_scalar(dst->type)) {
      for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
         dst->elems[i] =
            vtn_build_subgroup_instr(b, nir_op, src0->elems[i], index,
                                     const_idx0, const_idx1);
      }
      return dst;
   }

   nir_intrinsic_instr *intrin =
      nir_intrinsic_instr_create(b->nb.shader, nir_op);
   nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                              dst->type, NULL);
   intrin->num_components = intrin->dest.ssa.num_components;

   intrin->src[0] = nir_src_for_ssa(src0->def);
   if (index)
      intrin->src[1] = nir_src_for_ssa(index);

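   /* The meaning of the constant indices depends on the intrinsic; the
    * callers below use them for the reduction opcode and cluster size of
    * reduce/scan operations.
    */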
   intrin->const_index[0] = const_idx0;
   intrin->const_index[1] = const_idx1;

   nir_builder_instr_insert(&b->nb, &intrin->instr);

   dst->def = &intrin->dest.ssa;

   return dst;
}

void
vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
                    const uint32_t *w, unsigned count)
{
   struct vtn_type *dest_type = vtn_get_type(b, w[1]);

   switch (opcode) {
   case SpvOpGroupNonUniformElect: {
      vtn_fail_if(dest_type->type != glsl_bool_type(),
                  "OpGroupNonUniformElect must return a Bool");
      nir_intrinsic_instr *elect =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_elect);
      nir_ssa_dest_init_for_type(&elect->instr, &elect->dest,
                                 dest_type->type, NULL);
      nir_builder_instr_insert(&b->nb, &elect->instr);
      vtn_push_nir_ssa(b, w[2], &elect->dest.ssa);
      break;
   }

   case SpvOpGroupNonUniformBallot:
   case SpvOpSubgroupBallotKHR: {
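      /* OpSubgroupBallotKHR (SPV_KHR_shader_ballot) has no Execution scope
       * operand, so its predicate sits one word earlier than in
       * OpGroupNonUniformBallot.
       */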
      bool has_scope = (opcode != SpvOpSubgroupBallotKHR);
      vtn_fail_if(dest_type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
                  "OpGroupNonUniformBallot must return a uvec4");
      nir_intrinsic_instr *ballot =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
      ballot->src[0] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[3 + has_scope]));
      nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL);
      ballot->num_components = 4;
      nir_builder_instr_insert(&b->nb, &ballot->instr);
      vtn_push_nir_ssa(b, w[2], &ballot->dest.ssa);
      break;
   }

   case SpvOpGroupNonUniformInverseBallot: {
      /* This one is just a BallotBitfieldExtract with subgroup invocation.
       * We could add a NIR intrinsic but it's easier to just lower it on the
       * spot.
       */
      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader,
                                    nir_intrinsic_ballot_bitfield_extract);

      intrin->src[0] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[4]));
      intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb));

      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 dest_type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
      break;
   }

   case SpvOpGroupNonUniformBallotBitExtract:
   case SpvOpGroupNonUniformBallotBitCount:
   case SpvOpGroupNonUniformBallotFindLSB:
   case SpvOpGroupNonUniformBallotFindMSB: {
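      /* All of these consume a ballot value.  BitExtract additionally takes
       * a bit index, and BitCount takes a group operation selecting reduce
       * or inclusive/exclusive scan.
       */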
      nir_ssa_def *src0, *src1 = NULL;
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformBallotBitExtract:
         op = nir_intrinsic_ballot_bitfield_extract;
         src0 = vtn_get_nir_ssa(b, w[4]);
         src1 = vtn_get_nir_ssa(b, w[5]);
         break;
      case SpvOpGroupNonUniformBallotBitCount:
         switch ((SpvGroupOperation)w[4]) {
         case SpvGroupOperationReduce:
            op = nir_intrinsic_ballot_bit_count_reduce;
            break;
         case SpvGroupOperationInclusiveScan:
            op = nir_intrinsic_ballot_bit_count_inclusive;
            break;
         case SpvGroupOperationExclusiveScan:
            op = nir_intrinsic_ballot_bit_count_exclusive;
            break;
         default:
            unreachable("Invalid group operation");
         }
         src0 = vtn_get_nir_ssa(b, w[5]);
         break;
      case SpvOpGroupNonUniformBallotFindLSB:
         op = nir_intrinsic_ballot_find_lsb;
         src0 = vtn_get_nir_ssa(b, w[4]);
         break;
      case SpvOpGroupNonUniformBallotFindMSB:
         op = nir_intrinsic_ballot_find_msb;
         src0 = vtn_get_nir_ssa(b, w[4]);
         break;
      default:
         unreachable("Unhandled opcode");
      }

      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);

      intrin->src[0] = nir_src_for_ssa(src0);
      if (src1)
         intrin->src[1] = nir_src_for_ssa(src1);

      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 dest_type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
      break;
   }

   case SpvOpGroupNonUniformBroadcastFirst:
   case SpvOpSubgroupFirstInvocationKHR: {
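      /* As with the ballot case above, the KHR opcode lacks the Execution
       * scope operand, so the value operand shifts by one word.
       */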
      bool has_scope = (opcode != SpvOpSubgroupFirstInvocationKHR);
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
                                  vtn_ssa_value(b, w[3 + has_scope]),
                                  NULL, 0, 0));
      break;
   }

   case SpvOpGroupNonUniformBroadcast:
   case SpvOpGroupBroadcast:
   case SpvOpSubgroupReadInvocationKHR: {
      bool has_scope = (opcode != SpvOpSubgroupReadInvocationKHR);
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
                                  vtn_ssa_value(b, w[3 + has_scope]),
                                  vtn_get_nir_ssa(b, w[4 + has_scope]), 0, 0));
      break;
   }

   case SpvOpGroupNonUniformAll:
   case SpvOpGroupNonUniformAny:
   case SpvOpGroupNonUniformAllEqual:
   case SpvOpGroupAll:
   case SpvOpGroupAny:
   case SpvOpSubgroupAllKHR:
   case SpvOpSubgroupAnyKHR:
   case SpvOpSubgroupAllEqualKHR: {
      vtn_fail_if(dest_type->type != glsl_bool_type(),
                  "OpGroupNonUniform(All|Any|AllEqual) must return a bool");
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformAll:
      case SpvOpGroupAll:
      case SpvOpSubgroupAllKHR:
         op = nir_intrinsic_vote_all;
         break;
      case SpvOpGroupNonUniformAny:
      case SpvOpGroupAny:
      case SpvOpSubgroupAnyKHR:
         op = nir_intrinsic_vote_any;
         break;
      case SpvOpSubgroupAllEqualKHR:
         op = nir_intrinsic_vote_ieq;
         break;
      case SpvOpGroupNonUniformAllEqual:
         switch (glsl_get_base_type(vtn_ssa_value(b, w[4])->type)) {
         case GLSL_TYPE_FLOAT:
         case GLSL_TYPE_FLOAT16:
         case GLSL_TYPE_DOUBLE:
            op = nir_intrinsic_vote_feq;
            break;
         case GLSL_TYPE_UINT:
         case GLSL_TYPE_INT:
         case GLSL_TYPE_UINT8:
         case GLSL_TYPE_INT8:
         case GLSL_TYPE_UINT16:
         case GLSL_TYPE_INT16:
         case GLSL_TYPE_UINT64:
         case GLSL_TYPE_INT64:
         case GLSL_TYPE_BOOL:
            op = nir_intrinsic_vote_ieq;
            break;
         default:
            unreachable("Unhandled type");
         }
         break;
      default:
         unreachable("Unhandled opcode");
      }

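      /* The core and NonUniform opcodes carry an Execution scope, so their
       * predicate/value operand is at w[4]; the KHR vote opcodes put it at
       * w[3].
       */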
      nir_ssa_def *src0;
      if (opcode == SpvOpGroupNonUniformAll || opcode == SpvOpGroupAll ||
          opcode == SpvOpGroupNonUniformAny || opcode == SpvOpGroupAny ||
          opcode == SpvOpGroupNonUniformAllEqual) {
         src0 = vtn_get_nir_ssa(b, w[4]);
      } else {
         src0 = vtn_get_nir_ssa(b, w[3]);
      }
      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);
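      /* Intrinsics whose first source is variable-sized report 0 in
       * nir_intrinsic_infos, so size them from the source; this matters for
       * the equality votes on vector values.
       */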
      if (nir_intrinsic_infos[op].src_components[0] == 0)
         intrin->num_components = src0->num_components;
      intrin->src[0] = nir_src_for_ssa(src0);
      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 dest_type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
      break;
   }

   case SpvOpGroupNonUniformShuffle:
   case SpvOpGroupNonUniformShuffleXor:
   case SpvOpGroupNonUniformShuffleUp:
   case SpvOpGroupNonUniformShuffleDown: {
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformShuffle:
         op = nir_intrinsic_shuffle;
         break;
      case SpvOpGroupNonUniformShuffleXor:
         op = nir_intrinsic_shuffle_xor;
         break;
      case SpvOpGroupNonUniformShuffleUp:
         op = nir_intrinsic_shuffle_up;
         break;
      case SpvOpGroupNonUniformShuffleDown:
         op = nir_intrinsic_shuffle_down;
         break;
      default:
         unreachable("Invalid opcode");
      }
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]),
                                  vtn_get_nir_ssa(b, w[5]), 0, 0));
      break;
   }

   case SpvOpGroupNonUniformQuadBroadcast:
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
                                  vtn_ssa_value(b, w[4]),
                                  vtn_get_nir_ssa(b, w[5]), 0, 0));
      break;

   case SpvOpGroupNonUniformQuadSwap: {
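      /* The Direction operand must be a constant: 0 swaps within a quad
       * horizontally, 1 vertically, 2 diagonally.
       */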
      unsigned direction = vtn_constant_uint(b, w[5]);
      nir_intrinsic_op op;
      switch (direction) {
      case 0:
         op = nir_intrinsic_quad_swap_horizontal;
         break;
      case 1:
         op = nir_intrinsic_quad_swap_vertical;
         break;
      case 2:
         op = nir_intrinsic_quad_swap_diagonal;
         break;
      default:
         vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap");
      }
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]), NULL, 0, 0));
      break;
   }

   case SpvOpGroupNonUniformIAdd:
   case SpvOpGroupNonUniformFAdd:
   case SpvOpGroupNonUniformIMul:
   case SpvOpGroupNonUniformFMul:
   case SpvOpGroupNonUniformSMin:
   case SpvOpGroupNonUniformUMin:
   case SpvOpGroupNonUniformFMin:
   case SpvOpGroupNonUniformSMax:
   case SpvOpGroupNonUniformUMax:
   case SpvOpGroupNonUniformFMax:
   case SpvOpGroupNonUniformBitwiseAnd:
   case SpvOpGroupNonUniformBitwiseOr:
   case SpvOpGroupNonUniformBitwiseXor:
   case SpvOpGroupNonUniformLogicalAnd:
   case SpvOpGroupNonUniformLogicalOr:
   case SpvOpGroupNonUniformLogicalXor:
   case SpvOpGroupIAdd:
   case SpvOpGroupFAdd:
   case SpvOpGroupFMin:
   case SpvOpGroupUMin:
   case SpvOpGroupSMin:
   case SpvOpGroupFMax:
   case SpvOpGroupUMax:
   case SpvOpGroupSMax:
   case SpvOpGroupIAddNonUniformAMD:
   case SpvOpGroupFAddNonUniformAMD:
   case SpvOpGroupFMinNonUniformAMD:
   case SpvOpGroupUMinNonUniformAMD:
   case SpvOpGroupSMinNonUniformAMD:
   case SpvOpGroupFMaxNonUniformAMD:
   case SpvOpGroupUMaxNonUniformAMD:
   case SpvOpGroupSMaxNonUniformAMD: {
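      /* Pick the NIR reduction opcode first; the core group, NonUniform, and
       * SPV_AMD_shader_ballot variants of the same operation all map to the
       * same reduction.
       */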
      nir_op reduction_op;
      switch (opcode) {
      case SpvOpGroupNonUniformIAdd:
      case SpvOpGroupIAdd:
      case SpvOpGroupIAddNonUniformAMD:
         reduction_op = nir_op_iadd;
         break;
      case SpvOpGroupNonUniformFAdd:
      case SpvOpGroupFAdd:
      case SpvOpGroupFAddNonUniformAMD:
         reduction_op = nir_op_fadd;
         break;
      case SpvOpGroupNonUniformIMul:
         reduction_op = nir_op_imul;
         break;
      case SpvOpGroupNonUniformFMul:
         reduction_op = nir_op_fmul;
         break;
      case SpvOpGroupNonUniformSMin:
      case SpvOpGroupSMin:
      case SpvOpGroupSMinNonUniformAMD:
         reduction_op = nir_op_imin;
         break;
      case SpvOpGroupNonUniformUMin:
      case SpvOpGroupUMin:
      case SpvOpGroupUMinNonUniformAMD:
         reduction_op = nir_op_umin;
         break;
      case SpvOpGroupNonUniformFMin:
      case SpvOpGroupFMin:
      case SpvOpGroupFMinNonUniformAMD:
         reduction_op = nir_op_fmin;
         break;
      case SpvOpGroupNonUniformSMax:
      case SpvOpGroupSMax:
      case SpvOpGroupSMaxNonUniformAMD:
         reduction_op = nir_op_imax;
         break;
      case SpvOpGroupNonUniformUMax:
      case SpvOpGroupUMax:
      case SpvOpGroupUMaxNonUniformAMD:
         reduction_op = nir_op_umax;
         break;
      case SpvOpGroupNonUniformFMax:
      case SpvOpGroupFMax:
      case SpvOpGroupFMaxNonUniformAMD:
         reduction_op = nir_op_fmax;
         break;
      case SpvOpGroupNonUniformBitwiseAnd:
      case SpvOpGroupNonUniformLogicalAnd:
         reduction_op = nir_op_iand;
         break;
      case SpvOpGroupNonUniformBitwiseOr:
      case SpvOpGroupNonUniformLogicalOr:
         reduction_op = nir_op_ior;
         break;
      case SpvOpGroupNonUniformBitwiseXor:
      case SpvOpGroupNonUniformLogicalXor:
         reduction_op = nir_op_ixor;
         break;
      default:
         unreachable("Invalid reduction operation");
      }

      nir_intrinsic_op op;
      unsigned cluster_size = 0;
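      /* Reduce and the scans map directly onto NIR intrinsics.  A clustered
       * reduce is a regular reduce with the extra ClusterSize operand passed
       * through the second constant index.
       */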
      switch ((SpvGroupOperation)w[4]) {
      case SpvGroupOperationReduce:
         op = nir_intrinsic_reduce;
         break;
      case SpvGroupOperationInclusiveScan:
         op = nir_intrinsic_inclusive_scan;
         break;
      case SpvGroupOperationExclusiveScan:
         op = nir_intrinsic_exclusive_scan;
         break;
      case SpvGroupOperationClusteredReduce:
         op = nir_intrinsic_reduce;
         assert(count == 7);
         cluster_size = vtn_constant_uint(b, w[6]);
         break;
      default:
         unreachable("Invalid group operation");
      }

      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[5]), NULL,
                                  reduction_op, cluster_size));
      break;
   }

   default:
      unreachable("Invalid SPIR-V opcode");
   }
}