/*
 * Copyright © 2017 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "nir.h"
#include "nir_builder.h"
#include "util/u_math.h"

/**
 * \file nir_lower_subgroups.c
 *
 * Lowers subgroup intrinsics to forms the backend can handle, as described
 * by nir_lower_subgroups_options.
 */

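/* Emit a single-component copy of a 64-bit subgroup intrinsic that operates
 * on one 32-bit half of its first source.  component == 0 selects the low
 * half and component == 1 the high half.
 */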
static nir_intrinsic_instr *
lower_subgroups_64bit_split_intrinsic(nir_builder *b, nir_intrinsic_instr *intrin,
                                      unsigned int component)
{
   nir_ssa_def *comp;
   if (component == 0)
      comp = nir_unpack_64_2x32_split_x(b, intrin->src[0].ssa);
   else
      comp = nir_unpack_64_2x32_split_y(b, intrin->src[0].ssa);

   nir_intrinsic_instr *intr = nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
   nir_ssa_dest_init(&intr->instr, &intr->dest, 1, 32, NULL);
   intr->const_index[0] = intrin->const_index[0];
   intr->const_index[1] = intrin->const_index[1];
   intr->src[0] = nir_src_for_ssa(comp);
   if (nir_intrinsic_infos[intrin->intrinsic].num_srcs == 2)
      nir_src_copy(&intr->src[1], &intrin->src[1]);

   intr->num_components = 1;
   nir_builder_instr_insert(b, &intr->instr);
   return intr;
}

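/* Lower a 64-bit subgroup operation to two 32-bit operations on the unpacked
 * halves of the source, then repack the two 32-bit results into one 64-bit
 * value.
 */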
static nir_ssa_def *
lower_subgroup_op_to_32bit(nir_builder *b, nir_intrinsic_instr *intrin)
{
   assert(intrin->src[0].ssa->bit_size == 64);
   nir_intrinsic_instr *intr_x = lower_subgroups_64bit_split_intrinsic(b, intrin, 0);
   nir_intrinsic_instr *intr_y = lower_subgroups_64bit_split_intrinsic(b, intrin, 1);
   return nir_pack_64_2x32_split(b, &intr_x->dest.ssa, &intr_y->dest.ssa);
}

static nir_ssa_def *
ballot_type_to_uint(nir_builder *b, nir_ssa_def *value,
                    const nir_lower_subgroups_options *options)
{
   /* Only the new-style SPIR-V subgroup instructions take a ballot result as
    * an argument, so we only use this on uvec4 types.
    */
   assert(value->num_components == 4 && value->bit_size == 32);

   return nir_extract_bits(b, &value, 1, 0, options->ballot_components,
                           options->ballot_bit_size);
}

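/* Convert a ballot in the lowering's native uint format back into the vector
 * type the original intrinsic expects, zero-padding or truncating components
 * as needed.
 */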
static nir_ssa_def *
uint_to_ballot_type(nir_builder *b, nir_ssa_def *value,
                    unsigned num_components, unsigned bit_size)
{
   assert(util_is_power_of_two_nonzero(num_components));
   assert(util_is_power_of_two_nonzero(value->num_components));

   unsigned total_bits = bit_size * num_components;

   /* If the source doesn't have enough bits, zero-pad */
   if (total_bits > value->bit_size * value->num_components)
      value = nir_pad_vector_imm_int(b, value, 0, total_bits / value->bit_size);

   value = nir_bitcast_vector(b, value, bit_size);

   /* If the source has too many components, truncate.  This can happen if,
    * for instance, we're implementing GL_ARB_shader_ballot or
    * VK_EXT_shader_subgroup_ballot which have 64-bit ballot values on an
    * architecture with a native 128-bit uvec4 ballot.  This comes up in Zink
    * for OpenGL on Vulkan.  It's the job of the driver calling this lowering
    * pass to ensure that it has restricted the subgroup size sufficiently
    * that we have enough ballot bits.
    */
   if (value->num_components > num_components)
      value = nir_channels(b, value, BITFIELD_MASK(num_components));

   return value;
}

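/* Lower a vector subgroup operation to one scalar operation per component
 * and recombine the results with a vec.  Optionally also lowers 64-bit
 * scalars to pairs of 32-bit operations.
 */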
static nir_ssa_def *
lower_subgroup_op_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin,
                            bool lower_to_32bit)
{
   /* This is safe to call on scalar things but it would be silly */
   assert(intrin->dest.ssa.num_components > 1);

   nir_ssa_def *value = nir_ssa_for_src(b, intrin->src[0],
                                           intrin->num_components);
   nir_ssa_def *reads[NIR_MAX_VEC_COMPONENTS];

   for (unsigned i = 0; i < intrin->num_components; i++) {
      nir_intrinsic_instr *chan_intrin =
         nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
      nir_ssa_dest_init(&chan_intrin->instr, &chan_intrin->dest,
                        1, intrin->dest.ssa.bit_size, NULL);
      chan_intrin->num_components = 1;

      /* value */
      chan_intrin->src[0] = nir_src_for_ssa(nir_channel(b, value, i));
      /* invocation */
      if (nir_intrinsic_infos[intrin->intrinsic].num_srcs > 1) {
         assert(nir_intrinsic_infos[intrin->intrinsic].num_srcs == 2);
         nir_src_copy(&chan_intrin->src[1], &intrin->src[1]);
      }

      chan_intrin->const_index[0] = intrin->const_index[0];
      chan_intrin->const_index[1] = intrin->const_index[1];

      if (lower_to_32bit && chan_intrin->src[0].ssa->bit_size == 64) {
         reads[i] = lower_subgroup_op_to_32bit(b, chan_intrin);
      } else {
         nir_builder_instr_insert(b, &chan_intrin->instr);
         reads[i] = &chan_intrin->dest.ssa;
      }
   }

   return nir_vec(b, reads, intrin->num_components);
}

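/* Scalarize vote_feq/vote_ieq: vote on each component separately, then AND
 * the per-component results together.
 */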
static nir_ssa_def *
lower_vote_eq_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
{
   assert(intrin->src[0].is_ssa);
   nir_ssa_def *value = intrin->src[0].ssa;

   nir_ssa_def *result = NULL;
   for (unsigned i = 0; i < intrin->num_components; i++) {
      nir_intrinsic_instr *chan_intrin =
         nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
      nir_ssa_dest_init(&chan_intrin->instr, &chan_intrin->dest,
                        1, intrin->dest.ssa.bit_size, NULL);
      chan_intrin->num_components = 1;
      chan_intrin->src[0] = nir_src_for_ssa(nir_channel(b, value, i));
      nir_builder_instr_insert(b, &chan_intrin->instr);

      if (result) {
         result = nir_iand(b, result, &chan_intrin->dest.ssa);
      } else {
         result = &chan_intrin->dest.ssa;
      }
   }

   return result;
}

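/* Lower vote_feq/vote_ieq to a comparison of each component with the value
 * read from the first invocation, followed by vote_all on the combined
 * result.
 */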
static nir_ssa_def *
lower_vote_eq(nir_builder *b, nir_intrinsic_instr *intrin)
{
   assert(intrin->src[0].is_ssa);
   nir_ssa_def *value = intrin->src[0].ssa;

   /* We have to implicitly lower to scalar */
   nir_ssa_def *all_eq = NULL;
   for (unsigned i = 0; i < intrin->num_components; i++) {
      nir_ssa_def *rfi = nir_read_first_invocation(b, nir_channel(b, value, i));

      nir_ssa_def *is_eq;
      if (intrin->intrinsic == nir_intrinsic_vote_feq) {
         is_eq = nir_feq(b, rfi, nir_channel(b, value, i));
      } else {
         is_eq = nir_ieq(b, rfi, nir_channel(b, value, i));
      }

      if (all_eq == NULL) {
         all_eq = is_eq;
      } else {
         all_eq = nir_iand(b, all_eq, is_eq);
      }
   }

   return nir_vote_all(b, 1, all_eq);
}

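/* Lower shuffle_xor with a constant mask below 32 to AMD's masked swizzle,
 * which can only address lanes within a group of 32 invocations.
 */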
static nir_ssa_def *
lower_shuffle_to_swizzle(nir_builder *b, nir_intrinsic_instr *intrin,
                         const nir_lower_subgroups_options *options)
{
   unsigned mask = nir_src_as_uint(intrin->src[1]);

   if (mask >= 32)
      return NULL;

   nir_intrinsic_instr *swizzle = nir_intrinsic_instr_create(
      b->shader, nir_intrinsic_masked_swizzle_amd);
   swizzle->num_components = intrin->num_components;
   nir_src_copy(&swizzle->src[0], &intrin->src[0]);
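   /* The 15-bit swizzle mask encodes and_mask in bits 0-4, or_mask in bits
    * 5-9, and xor_mask in bits 10-14; the new lane is computed as
    * ((lane & and_mask) | or_mask) ^ xor_mask.  With and_mask = 0x1f and
    * or_mask = 0 this yields lane ^ mask, i.e. exactly shuffle_xor.
    */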
   nir_intrinsic_set_swizzle_mask(swizzle, (mask << 10) | 0x1f);
   nir_ssa_dest_init(&swizzle->instr, &swizzle->dest,
                     intrin->dest.ssa.num_components,
                     intrin->dest.ssa.bit_size, NULL);

   if (options->lower_to_scalar && swizzle->num_components > 1) {
      return lower_subgroup_op_to_scalar(b, swizzle, options->lower_shuffle_to_32bit);
   } else if (options->lower_shuffle_to_32bit && swizzle->src[0].ssa->bit_size == 64) {
      return lower_subgroup_op_to_32bit(b, swizzle);
   } else {
      nir_builder_instr_insert(b, &swizzle->instr);
      return &swizzle->dest.ssa;
   }
}

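/* Lower the relative shuffles and quad operations to a plain shuffle with a
 * computed invocation index.
 */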
static nir_ssa_def *
lower_shuffle(nir_builder *b, nir_intrinsic_instr *intrin,
              const nir_lower_subgroups_options *options)
{
   if (intrin->intrinsic == nir_intrinsic_shuffle_xor &&
       options->lower_shuffle_to_swizzle_amd &&
       nir_src_is_const(intrin->src[1])) {
      nir_ssa_def *result =
         lower_shuffle_to_swizzle(b, intrin, options);
      if (result)
         return result;
   }

   nir_ssa_def *index = nir_load_subgroup_invocation(b);
   bool is_shuffle = false;
   switch (intrin->intrinsic) {
   case nir_intrinsic_shuffle_xor:
      assert(intrin->src[1].is_ssa);
      index = nir_ixor(b, index, intrin->src[1].ssa);
      is_shuffle = true;
      break;
   case nir_intrinsic_shuffle_up:
      assert(intrin->src[1].is_ssa);
      index = nir_isub(b, index, intrin->src[1].ssa);
      is_shuffle = true;
      break;
   case nir_intrinsic_shuffle_down:
      assert(intrin->src[1].is_ssa);
      index = nir_iadd(b, index, intrin->src[1].ssa);
      is_shuffle = true;
      break;
   case nir_intrinsic_quad_broadcast:
      assert(intrin->src[1].is_ssa);
      index = nir_ior(b, nir_iand(b, index, nir_imm_int(b, ~0x3)),
                         intrin->src[1].ssa);
      break;
   case nir_intrinsic_quad_swap_horizontal:
      /* For quad operations, subgroups are divided into quads where
       * (invocation % 4) is the index into a square arranged as follows:
       *
       *    +---+---+
       *    | 0 | 1 |
       *    +---+---+
       *    | 2 | 3 |
       *    +---+---+
       */
      index = nir_ixor(b, index, nir_imm_int(b, 0x1));
      break;
   case nir_intrinsic_quad_swap_vertical:
      index = nir_ixor(b, index, nir_imm_int(b, 0x2));
      break;
   case nir_intrinsic_quad_swap_diagonal:
      index = nir_ixor(b, index, nir_imm_int(b, 0x3));
      break;
   default:
      unreachable("Invalid intrinsic");
   }

   nir_intrinsic_instr *shuffle =
      nir_intrinsic_instr_create(b->shader, nir_intrinsic_shuffle);
   shuffle->num_components = intrin->num_components;
   nir_src_copy(&shuffle->src[0], &intrin->src[0]);
   shuffle->src[1] = nir_src_for_ssa(index);
   nir_ssa_dest_init(&shuffle->instr, &shuffle->dest,
                     intrin->dest.ssa.num_components,
                     intrin->dest.ssa.bit_size, NULL);

   bool lower_to_32bit = options->lower_shuffle_to_32bit && is_shuffle;
   if (options->lower_to_scalar && shuffle->num_components > 1) {
      return lower_subgroup_op_to_scalar(b, shuffle, lower_to_32bit);
   } else if (lower_to_32bit && shuffle->src[0].ssa->bit_size == 64) {
      return lower_subgroup_op_to_32bit(b, shuffle);
   } else {
      nir_builder_instr_insert(b, &shuffle->instr);
      return &shuffle->dest.ssa;
   }
}

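/* Consider every intrinsic; lower_subgroups_instr returns NULL for anything
 * it doesn't actually lower.
 */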
static bool
lower_subgroups_filter(const nir_instr *instr, const void *_options)
{
   return instr->type == nir_instr_type_intrinsic;
}

/* Return a ballot-mask-sized value which represents "val" sign-extended and
 * then shifted left by "shift".  Only particular values of "val" are
 * supported, see the assert below.
 */
static nir_ssa_def *
build_ballot_imm_ishl(nir_builder *b, int64_t val, nir_ssa_def *shift,
                      const nir_lower_subgroups_options *options)
{
   /* This only works if all the high bits are the same as bit 1. */
   assert(((val << 62) >> 62) == val);

   /* First compute the result assuming one ballot component. */
   nir_ssa_def *result =
      nir_ishl(b, nir_imm_intN_t(b, val, options->ballot_bit_size), shift);

   if (options->ballot_components == 1)
      return result;

   /* Fix up the result when there is > 1 component.  The idea is that
    * nir_ishl already masks out the high bits of the shift value, so when
    * there's more than one component, the component which 1 would be shifted
    * into already has the right value and all we have to do is fix up the
    * other components.  Components below it should always be 0, and
    * components above it must be either 0 or ~0 because of the assert above.
    * For example, if the target ballot size is 2 x uint32 and we're shifting
    * 1 by 33, then we'll feed 33 into ishl, which will mask it off to get 1,
    * so we'll compute a single-component result of 2.  That's correct for
    * the second component, but the first component needs to be 0, which we
    * get by comparing the high bits of the shift with 0 and selecting the
    * original answer or 0 for the first component (and something similar for
    * the second component).  This idea is generalized here for any component
    * count.
    */
   nir_const_value min_shift[4] = { 0 };
   for (unsigned i = 0; i < options->ballot_components; i++)
      min_shift[i].i32 = i * options->ballot_bit_size;
   nir_ssa_def *min_shift_val = nir_build_imm(b, options->ballot_components, 32, min_shift);

   nir_const_value max_shift[4] = { 0 };
   for (unsigned i = 0; i < options->ballot_components; i++)
      max_shift[i].i32 = (i + 1) * options->ballot_bit_size;
   nir_ssa_def *max_shift_val = nir_build_imm(b, options->ballot_components, 32, max_shift);

   return nir_bcsel(b, nir_ult(b, shift, max_shift_val),
                    nir_bcsel(b, nir_ult(b, shift, min_shift_val),
                              nir_imm_intN_t(b, val >> 63, result->bit_size),
                              result),
                    nir_imm_intN_t(b, 0, result->bit_size));
}

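/* The following helpers build the SPIR-V subgroup masks from the invocation
 * index.  For example, for invocation 3 with a single 32-bit ballot
 * component:
 *
 *    eq_mask =  1 << 3 = 0x00000008
 *    ge_mask = ~0 << 3 = 0xfffffff8
 *    gt_mask = ~1 << 3 = 0xfffffff0
 */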
static nir_ssa_def *
build_subgroup_eq_mask(nir_builder *b,
                       const nir_lower_subgroups_options *options)
{
   nir_ssa_def *subgroup_idx = nir_load_subgroup_invocation(b);

   return build_ballot_imm_ishl(b, 1, subgroup_idx, options);
}

static nir_ssa_def *
build_subgroup_ge_mask(nir_builder *b,
                       const nir_lower_subgroups_options *options)
{
   nir_ssa_def *subgroup_idx = nir_load_subgroup_invocation(b);

   return build_ballot_imm_ishl(b, ~0ull, subgroup_idx, options);
}

static nir_ssa_def *
build_subgroup_gt_mask(nir_builder *b,
                       const nir_lower_subgroups_options *options)
{
   nir_ssa_def *subgroup_idx = nir_load_subgroup_invocation(b);

   return build_ballot_imm_ishl(b, ~1ull, subgroup_idx, options);
}

/* Return a mask which is 1 for threads up to the run-time subgroup size, i.e.
 * 1 for the entire subgroup.  SPIR-V requires us to return 0 for indices at
 * or above the subgroup size for the masks, but gt_mask and ge_mask make
 * them 1, so we have to "and" with this mask.
 */
static nir_ssa_def *
build_subgroup_mask(nir_builder *b,
                    const nir_lower_subgroups_options *options)
{
   nir_ssa_def *subgroup_size = nir_load_subgroup_size(b);

   /* First compute the result assuming one ballot component. */
   nir_ssa_def *result =
      nir_ushr(b, nir_imm_intN_t(b, ~0ull, options->ballot_bit_size),
                  nir_isub_imm(b, options->ballot_bit_size,
                               subgroup_size));

   /* Since the subgroup size and ballot bitsize are both powers of two, there
    * are two possible cases to consider:
    *
    * (1) The subgroup size is less than the ballot bitsize.  We need to
    * return "result" in the first component and 0 in every other component.
    * (2) The subgroup size is a multiple of the ballot bitsize.  We need to
    * return ~0 for each component whose index is less than the subgroup size
    * divided by the ballot bitsize and 0 for the rest.  For example, with a
    * target ballot type of 4 x uint32 and subgroup_size = 64 we'd need to
    * return { ~0, ~0, 0, 0 }.
    *
    * In case (2) it turns out that "result" will be ~0, because
    * "ballot_bit_size - subgroup_size" is also a multiple of
    * "ballot_bit_size" and since nir_ushr masks the shift value, the shift
    * amount becomes 0.  This means that the first component can just be
    * "result" in all cases.  The other components will also get the correct
    * value in case (1) if we just use the rule in case (2), so we'll get the
    * correct result if we just follow (2) and then replace the first
    * component with "result".
    */
   nir_const_value min_idx[4] = { 0 };
   for (unsigned i = 0; i < options->ballot_components; i++)
      min_idx[i].i32 = i * options->ballot_bit_size;
   nir_ssa_def *min_idx_val = nir_build_imm(b, options->ballot_components, 32, min_idx);

   nir_ssa_def *result_extended =
      nir_pad_vector_imm_int(b, result, ~0ull, options->ballot_components);

   return nir_bcsel(b, nir_ult(b, min_idx_val, subgroup_size),
                    result_extended, nir_imm_intN_t(b, 0, options->ballot_bit_size));
}

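/* Total number of set bits across all components of a ballot vector. */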
static nir_ssa_def *
vec_bit_count(nir_builder *b, nir_ssa_def *value)
{
   nir_ssa_def *vec_result = nir_bit_count(b, value);
   nir_ssa_def *result = nir_channel(b, vec_result, 0);
   for (unsigned i = 1; i < value->num_components; i++)
      result = nir_iadd(b, result, nir_channel(b, vec_result, i));
   return result;
}

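/* Index of the lowest set bit across all components of a ballot vector, or
 * -1 if no bit is set.  The loop runs from the top so that the lowest
 * matching component wins.
 */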
static nir_ssa_def *
vec_find_lsb(nir_builder *b, nir_ssa_def *value)
{
   nir_ssa_def *vec_result = nir_find_lsb(b, value);
   nir_ssa_def *result = nir_imm_int(b, -1);
   for (int i = value->num_components - 1; i >= 0; i--) {
      nir_ssa_def *channel = nir_channel(b, vec_result, i);
      /* result = channel >= 0 ? (i * bitsize + channel) : result */
      result = nir_bcsel(b, nir_ige(b, channel, nir_imm_int(b, 0)),
                         nir_iadd_imm(b, channel, i * value->bit_size),
                         result);
   }
   return result;
}

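/* Index of the highest set bit across all components of a ballot vector, or
 * -1 if no bit is set.  The loop runs from the bottom so that the highest
 * matching component wins.
 */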
static nir_ssa_def *
vec_find_msb(nir_builder *b, nir_ssa_def *value)
{
   nir_ssa_def *vec_result = nir_ufind_msb(b, value);
   nir_ssa_def *result = nir_imm_int(b, -1);
   for (unsigned i = 0; i < value->num_components; i++) {
      nir_ssa_def *channel = nir_channel(b, vec_result, i);
      /* result = channel >= 0 ? (i * bitsize + channel) : result */
      result = nir_bcsel(b, nir_ige(b, channel, nir_imm_int(b, 0)),
                         nir_iadd_imm(b, channel, i * value->bit_size),
                         result);
   }
   return result;
}

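/* Lower quad_broadcast with a non-constant index by emitting all four
 * constant-index quad broadcasts and selecting between them with a chain of
 * bcsels on the index.
 */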
static nir_ssa_def *
lower_dynamic_quad_broadcast(nir_builder *b, nir_intrinsic_instr *intrin,
                             const nir_lower_subgroups_options *options)
{
   if (!options->lower_quad_broadcast_dynamic_to_const)
      return lower_shuffle(b, intrin, options);

   nir_ssa_def *dst = NULL;

   for (unsigned i = 0; i < 4; ++i) {
      nir_intrinsic_instr *qbcst =
         nir_intrinsic_instr_create(b->shader, nir_intrinsic_quad_broadcast);

      qbcst->num_components = intrin->num_components;
      qbcst->src[1] = nir_src_for_ssa(nir_imm_int(b, i));
      nir_src_copy(&qbcst->src[0], &intrin->src[0]);
      nir_ssa_dest_init(&qbcst->instr, &qbcst->dest,
                        intrin->dest.ssa.num_components,
                        intrin->dest.ssa.bit_size, NULL);

      nir_ssa_def *qbcst_dst = NULL;

      if (options->lower_to_scalar && qbcst->num_components > 1) {
         qbcst_dst = lower_subgroup_op_to_scalar(b, qbcst, false);
      } else {
         nir_builder_instr_insert(b, &qbcst->instr);
         qbcst_dst = &qbcst->dest.ssa;
      }

      if (i)
         dst = nir_bcsel(b, nir_ieq(b, intrin->src[1].ssa,
                                    nir_imm_int(b, i)),
                         qbcst_dst, dst);
      else
         dst = qbcst_dst;
   }

   return dst;
}

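/* Lower read_invocation to ir3's read_invocation_cond, whose condition is
 * true only on the invocation being read.
 */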
static nir_ssa_def *
lower_read_invocation_to_cond(nir_builder *b, nir_intrinsic_instr *intrin)
{
   return nir_read_invocation_cond_ir3(b, intrin->dest.ssa.bit_size,
                                       intrin->src[0].ssa,
                                       nir_ieq(b, intrin->src[1].ssa,
                                               nir_load_subgroup_invocation(b)));
}

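/* Per-instruction lowering.  Returns the replacement value for the
 * intrinsic, NIR_LOWER_INSTR_PROGRESS if it was modified in place, or NULL
 * if it is left alone.
 */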
static nir_ssa_def *
lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
{
   const nir_lower_subgroups_options *options = _options;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   switch (intrin->intrinsic) {
   case nir_intrinsic_vote_any:
   case nir_intrinsic_vote_all:
      if (options->lower_vote_trivial)
         return nir_ssa_for_src(b, intrin->src[0], 1);
      break;

   case nir_intrinsic_vote_feq:
   case nir_intrinsic_vote_ieq:
      if (options->lower_vote_trivial)
         return nir_imm_true(b);

      if (options->lower_vote_eq)
         return lower_vote_eq(b, intrin);

      if (options->lower_to_scalar && intrin->num_components > 1)
         return lower_vote_eq_to_scalar(b, intrin);
      break;

   case nir_intrinsic_load_subgroup_size:
      if (options->subgroup_size)
         return nir_imm_int(b, options->subgroup_size);
      break;

   case nir_intrinsic_read_invocation:
      if (options->lower_to_scalar && intrin->num_components > 1)
         return lower_subgroup_op_to_scalar(b, intrin, false);

      if (options->lower_read_invocation_to_cond)
         return lower_read_invocation_to_cond(b, intrin);

      break;

   case nir_intrinsic_read_first_invocation:
      if (options->lower_to_scalar && intrin->num_components > 1)
         return lower_subgroup_op_to_scalar(b, intrin, false);
      break;

   case nir_intrinsic_load_subgroup_eq_mask:
   case nir_intrinsic_load_subgroup_ge_mask:
   case nir_intrinsic_load_subgroup_gt_mask:
   case nir_intrinsic_load_subgroup_le_mask:
   case nir_intrinsic_load_subgroup_lt_mask: {
      if (!options->lower_subgroup_masks)
         return NULL;

      nir_ssa_def *val;
      switch (intrin->intrinsic) {
      case nir_intrinsic_load_subgroup_eq_mask:
         val = build_subgroup_eq_mask(b, options);
         break;
      case nir_intrinsic_load_subgroup_ge_mask:
         val = nir_iand(b, build_subgroup_ge_mask(b, options),
                           build_subgroup_mask(b, options));
         break;
      case nir_intrinsic_load_subgroup_gt_mask:
         val = nir_iand(b, build_subgroup_gt_mask(b, options),
                           build_subgroup_mask(b, options));
         break;
      case nir_intrinsic_load_subgroup_le_mask:
         val = nir_inot(b, build_subgroup_gt_mask(b, options));
         break;
      case nir_intrinsic_load_subgroup_lt_mask:
         val = nir_inot(b, build_subgroup_ge_mask(b, options));
         break;
      default:
         unreachable("you seriously can't tell this is unreachable?");
      }

      return uint_to_ballot_type(b, val,
                                 intrin->dest.ssa.num_components,
                                 intrin->dest.ssa.bit_size);
   }

   case nir_intrinsic_ballot: {
      if (intrin->dest.ssa.num_components == options->ballot_components &&
          intrin->dest.ssa.bit_size == options->ballot_bit_size)
         return NULL;

      nir_ssa_def *ballot =
         nir_ballot(b, options->ballot_components, options->ballot_bit_size,
                    intrin->src[0].ssa);

      return uint_to_ballot_type(b, ballot,
                                 intrin->dest.ssa.num_components,
                                 intrin->dest.ssa.bit_size);
   }

   case nir_intrinsic_ballot_bitfield_extract:
   case nir_intrinsic_ballot_bit_count_reduce:
   case nir_intrinsic_ballot_find_lsb:
   case nir_intrinsic_ballot_find_msb: {
      assert(intrin->src[0].is_ssa);
      nir_ssa_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa,
                                                 options);

      if (intrin->intrinsic != nir_intrinsic_ballot_bitfield_extract &&
          intrin->intrinsic != nir_intrinsic_ballot_find_lsb) {
         /* For OpGroupNonUniformBallotFindMSB, the SPIR-V Spec says:
          *
          *    "Find the most significant bit set to 1 in Value, considering
          *    only the bits in Value required to represent all bits of the
          *    group’s invocations.  If none of the considered bits is set to
          *    1, the result is undefined."
          *
          * It has similar text for the other three.  This means that, in case
          * the subgroup size is less than 32, we have to mask off the unused
          * bits.  If the subgroup size is fixed and greater than or equal to
          * 32, the mask will be 0xffffffff and nir_opt_algebraic will delete
          * the iand.
          *
          * We only have to worry about this for BitCount and FindMSB because
          * FindLSB counts from the bottom and BitfieldExtract selects
          * individual bits.  In either case, if run outside the range of
          * valid bits, we hit the undefined results case and we can return
          * anything we want.
          */
         int_val = nir_iand(b, int_val, build_subgroup_mask(b, options));
      }

      switch (intrin->intrinsic) {
      case nir_intrinsic_ballot_bitfield_extract: {
         assert(intrin->src[1].is_ssa);
         nir_ssa_def *idx = intrin->src[1].ssa;
         if (int_val->num_components > 1) {
            /* idx will be truncated by nir_ushr, so we just need to select
             * the right component using the bits of idx that are truncated in
             * the shift.
             */
            int_val =
               nir_vector_extract(b, int_val,
                                  nir_udiv_imm(b, idx, int_val->bit_size));
         }

         return nir_i2b(b, nir_iand_imm(b, nir_ushr(b, int_val, idx), 1));
      }
      case nir_intrinsic_ballot_bit_count_reduce:
         return vec_bit_count(b, int_val);
      case nir_intrinsic_ballot_find_lsb:
         return vec_find_lsb(b, int_val);
      case nir_intrinsic_ballot_find_msb:
         return vec_find_msb(b, int_val);
      default:
         unreachable("you seriously can't tell this is unreachable?");
      }
   }

   case nir_intrinsic_ballot_bit_count_exclusive:
   case nir_intrinsic_ballot_bit_count_inclusive: {
      nir_ssa_def *mask;
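      /* Inclusive counts bits at or below the current invocation (~gt_mask),
       * exclusive counts bits strictly below it (~ge_mask).
       */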
      if (intrin->intrinsic == nir_intrinsic_ballot_bit_count_inclusive) {
         mask = nir_inot(b, build_subgroup_gt_mask(b, options));
      } else {
         mask = nir_inot(b, build_subgroup_ge_mask(b, options));
      }

      assert(intrin->src[0].is_ssa);
      nir_ssa_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa,
                                                 options);

      return vec_bit_count(b, nir_iand(b, int_val, mask));
   }

   case nir_intrinsic_elect: {
      if (!options->lower_elect)
         return NULL;

      return nir_ieq(b, nir_load_subgroup_invocation(b), nir_first_invocation(b));
   }

   case nir_intrinsic_shuffle:
      if (options->lower_to_scalar && intrin->num_components > 1)
         return lower_subgroup_op_to_scalar(b, intrin, options->lower_shuffle_to_32bit);
      else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
         return lower_subgroup_op_to_32bit(b, intrin);
      break;
   case nir_intrinsic_shuffle_xor:
   case nir_intrinsic_shuffle_up:
   case nir_intrinsic_shuffle_down:
      if (options->lower_shuffle)
         return lower_shuffle(b, intrin, options);
      else if (options->lower_to_scalar && intrin->num_components > 1)
         return lower_subgroup_op_to_scalar(b, intrin, options->lower_shuffle_to_32bit);
      else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
         return lower_subgroup_op_to_32bit(b, intrin);
      break;

   case nir_intrinsic_quad_broadcast:
   case nir_intrinsic_quad_swap_horizontal:
   case nir_intrinsic_quad_swap_vertical:
   case nir_intrinsic_quad_swap_diagonal:
      if (options->lower_quad ||
          (options->lower_quad_broadcast_dynamic &&
           intrin->intrinsic == nir_intrinsic_quad_broadcast &&
           !nir_src_is_const(intrin->src[1])))
         return lower_dynamic_quad_broadcast(b, intrin, options);
      else if (options->lower_to_scalar && intrin->num_components > 1)
         return lower_subgroup_op_to_scalar(b, intrin, false);
      break;

   case nir_intrinsic_reduce: {
      nir_ssa_def *ret = NULL;
      /* A cluster size greater than the subgroup size is implementation-defined */
      if (options->subgroup_size &&
          nir_intrinsic_cluster_size(intrin) >= options->subgroup_size) {
         nir_intrinsic_set_cluster_size(intrin, 0);
         ret = NIR_LOWER_INSTR_PROGRESS;
      }
      if (options->lower_to_scalar && intrin->num_components > 1)
         ret = lower_subgroup_op_to_scalar(b, intrin, false);
      return ret;
   }
   case nir_intrinsic_inclusive_scan:
   case nir_intrinsic_exclusive_scan:
      if (options->lower_to_scalar && intrin->num_components > 1)
         return lower_subgroup_op_to_scalar(b, intrin, false);
      break;

   default:
      break;
   }

   return NULL;
}

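/** Lower subgroup intrinsics as described by the given options. */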
bool
nir_lower_subgroups(nir_shader *shader,
                    const nir_lower_subgroups_options *options)
{
   return nir_shader_lower_instructions(shader,
                                        lower_subgroups_filter,
                                        lower_subgroups_instr,
                                        (void *)options);
}