1 /*
2  * Copyright © 2021 Valve Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  *
23  */
24 
25 #include "ac_nir.h"
26 #include "nir_builder.h"
27 #include "u_math.h"
28 #include "u_vector.h"
29 
30 enum {
31    nggc_passflag_used_by_pos = 1,
32    nggc_passflag_used_by_other = 2,
33    nggc_passflag_used_by_both = nggc_passflag_used_by_pos | nggc_passflag_used_by_other,
34 };
35 
36 typedef struct
37 {
38    nir_ssa_def *ssa;
39    nir_variable *var;
40 } saved_uniform;
41 
42 typedef struct
43 {
44    nir_variable *position_value_var;
45    nir_variable *prim_exp_arg_var;
46    nir_variable *es_accepted_var;
47    nir_variable *gs_accepted_var;
48 
49    struct u_vector saved_uniforms;
50 
51    bool passthrough;
52    bool export_prim_id;
53    bool early_prim_export;
54    bool use_edgeflags;
55    unsigned wave_size;
56    unsigned max_num_waves;
57    unsigned num_vertices_per_primitives;
58    unsigned provoking_vtx_idx;
59    unsigned max_es_num_vertices;
60    unsigned total_lds_bytes;
61 
62    uint64_t inputs_needed_by_pos;
63    uint64_t inputs_needed_by_others;
64    uint32_t instance_rate_inputs;
65 
66    nir_instr *compact_arg_stores[4];
67    nir_intrinsic_instr *overwrite_args;
68 } lower_ngg_nogs_state;
69 
70 typedef struct
71 {
72    /* bitsize of this component (max 32), or 0 if it's never written at all */
73    uint8_t bit_size : 6;
74    /* output stream index  */
75    uint8_t stream : 2;
76 } gs_output_component_info;
77 
78 typedef struct
79 {
80    nir_variable *output_vars[VARYING_SLOT_MAX][4];
81    nir_variable *current_clear_primflag_idx_var;
82    int const_out_vtxcnt[4];
83    int const_out_prmcnt[4];
84    unsigned wave_size;
85    unsigned max_num_waves;
86    unsigned num_vertices_per_primitive;
87    unsigned lds_addr_gs_out_vtx;
88    unsigned lds_addr_gs_scratch;
89    unsigned lds_bytes_per_gs_out_vertex;
90    unsigned lds_offs_primflags;
91    bool found_out_vtxcnt[4];
92    bool output_compile_time_known;
93    bool provoking_vertex_last;
94    gs_output_component_info output_component_info[VARYING_SLOT_MAX][4];
95 } lower_ngg_gs_state;
96 
97 typedef struct {
98    nir_variable *pre_cull_position_value_var;
99 } remove_culling_shader_outputs_state;
100 
101 typedef struct {
102    nir_variable *pos_value_replacement;
103 } remove_extra_position_output_state;
104 
105 /* Per-vertex LDS layout of culling shaders */
106 enum {
107    /* Position of the ES vertex (at the beginning for alignment reasons) */
108    lds_es_pos_x = 0,
109    lds_es_pos_y = 4,
110    lds_es_pos_z = 8,
111    lds_es_pos_w = 12,
112 
113    /* 1 when the vertex is accepted, 0 if it should be culled */
114    lds_es_vertex_accepted = 16,
115    /* ID of the thread which will export the current thread's vertex */
116    lds_es_exporter_tid = 17,
117 
118    /* Repacked arguments - also listed separately for VS and TES */
119    lds_es_arg_0 = 20,
120 
121    /* VS arguments which need to be repacked */
122    lds_es_vs_vertex_id = 20,
123    lds_es_vs_instance_id = 24,
124 
125    /* TES arguments which need to be repacked */
126    lds_es_tes_u = 20,
127    lds_es_tes_v = 24,
128    lds_es_tes_rel_patch_id = 28,
129    lds_es_tes_patch_id = 32,
130 };
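/* With this layout, a culling shader needs lds_es_arg_0 + 4 * max_exported_args bytes of
 * LDS per vertex: for example 20 + 2 * 4 = 28 bytes for a VS that uses both VertexID and
 * InstanceID, or 20 + 4 * 4 = 36 bytes for a TES that also reads the patch primitive ID.
 * This matches how add_deferred_attribute_culling computes pervertex_lds_bytes below.
 */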
131 
132 typedef struct {
133    nir_ssa_def *num_repacked_invocations;
134    nir_ssa_def *repacked_invocation_index;
135 } wg_repack_result;
136 
137 /**
138  * Computes a horizontal sum of 8-bit packed values loaded from LDS.
139  *
140  * Each lane N will sum packed bytes 0 to N-1.
141  * We only care about the results from up to wave_id+1 lanes.
142  * (Other lanes are not deactivated but their calculation is not used.)
143  */
144 static nir_ssa_def *
145 summarize_repack(nir_builder *b, nir_ssa_def *packed_counts, unsigned num_lds_dwords)
146 {
147    /* We'll use shift to filter out the bytes not needed by the current lane.
148     *
149     * Need to shift by: num_lds_dwords * 4 - lane_id (in bytes).
150     * However, two shifts are needed because one can't go all the way,
151     * so the shift amount is half that (and in bits).
152     *
153     * When v_dot4_u32_u8 is available, we right-shift a series of 0x01 bytes.
154     * This will yield 0x01 at wanted byte positions and 0x00 at unwanted positions,
155     * therefore v_dot can get rid of the unneeded values.
156     * This sequence is preferable because it better hides the latency of the LDS.
157     *
158     * If the v_dot instruction can't be used, we left-shift the packed bytes.
159     * This will shift out the unneeded bytes and shift in zeroes instead,
160     * then we sum them using v_sad_u8.
161     */
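   /* Worked example with 1 LDS dword (up to 4 waves): per-wave counts {12, 7, 30, 5}
    * are packed as 0x051E070C. For lane 2, shift = 16 - 2 * 4 = 8, applied twice = 16 bits.
    * - SAD path: 0x051E070C << 16 = 0x070C0000, and v_sad_u8 sums its bytes: 7 + 12 = 19.
    * - Dot path: 0x01010101 >> 16 = 0x00000101, so v_dot4_u32_u8 yields 12 * 1 + 7 * 1 = 19.
    * Either way, lane 2 ends up with the survivor count of waves 0 and 1.
    */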
162 
163    nir_ssa_def *lane_id = nir_load_subgroup_invocation(b);
164    nir_ssa_def *shift = nir_iadd_imm_nuw(b, nir_imul_imm(b, lane_id, -4u), num_lds_dwords * 16);
165    bool use_dot = b->shader->options->has_dot_4x8;
166 
167    if (num_lds_dwords == 1) {
168       nir_ssa_def *dot_op = !use_dot ? NULL : nir_ushr(b, nir_ushr(b, nir_imm_int(b, 0x01010101), shift), shift);
169 
170       /* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
171       nir_ssa_def *packed = nir_build_lane_permute_16_amd(b, packed_counts, nir_imm_int(b, 0), nir_imm_int(b, 0));
172 
173       /* Horizontally add the packed bytes. */
174       if (use_dot) {
175          return nir_udot_4x8_uadd(b, packed, dot_op, nir_imm_int(b, 0));
176       } else {
177          nir_ssa_def *sad_op = nir_ishl(b, nir_ishl(b, packed, shift), shift);
178          return nir_sad_u8x4(b, sad_op, nir_imm_int(b, 0), nir_imm_int(b, 0));
179       }
180    } else if (num_lds_dwords == 2) {
181       nir_ssa_def *dot_op = !use_dot ? NULL : nir_ushr(b, nir_ushr(b, nir_imm_int64(b, 0x0101010101010101), shift), shift);
182 
183       /* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
184       nir_ssa_def *packed_dw0 = nir_build_lane_permute_16_amd(b, nir_unpack_64_2x32_split_x(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
185       nir_ssa_def *packed_dw1 = nir_build_lane_permute_16_amd(b, nir_unpack_64_2x32_split_y(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
186 
187       /* Horizontally add the packed bytes. */
188       if (use_dot) {
189          nir_ssa_def *sum = nir_udot_4x8_uadd(b, packed_dw0, nir_unpack_64_2x32_split_x(b, dot_op), nir_imm_int(b, 0));
190          return nir_udot_4x8_uadd(b, packed_dw1, nir_unpack_64_2x32_split_y(b, dot_op), sum);
191       } else {
192          nir_ssa_def *sad_op = nir_ishl(b, nir_ishl(b, nir_pack_64_2x32_split(b, packed_dw0, packed_dw1), shift), shift);
193          nir_ssa_def *sum = nir_sad_u8x4(b, nir_unpack_64_2x32_split_x(b, sad_op), nir_imm_int(b, 0), nir_imm_int(b, 0));
194          return nir_sad_u8x4(b, nir_unpack_64_2x32_split_y(b, sad_op), nir_imm_int(b, 0), sum);
195       }
196    } else {
197       unreachable("Unimplemented NGG wave count");
198    }
199 }
200 
201 /**
202  * Repacks invocations in the current workgroup to eliminate gaps between them.
203  *
204  * Uses 1 dword of LDS per 4 waves (1 byte of LDS per wave).
205  * Assumes that all invocations in the workgroup are active (exec = -1).
206  */
207 static wg_repack_result
208 repack_invocations_in_workgroup(nir_builder *b, nir_ssa_def *input_bool,
209                                 unsigned lds_addr_base, unsigned max_num_waves,
210                                 unsigned wave_size)
211 {
212    /* Input boolean: 1 if the current invocation should survive the repack. */
213    assert(input_bool->bit_size == 1);
214 
215    /* STEP 1. Count surviving invocations in the current wave.
216     *
217     * Implemented by a scalar instruction that simply counts the number of bits set in a 32/64-bit mask.
218     */
219 
220    nir_ssa_def *input_mask = nir_build_ballot(b, 1, wave_size, input_bool);
221    nir_ssa_def *surviving_invocations_in_current_wave = nir_bit_count(b, input_mask);
222 
223    /* If we know at compile time that the workgroup has only 1 wave, no further steps are necessary. */
224    if (max_num_waves == 1) {
225       wg_repack_result r = {
226          .num_repacked_invocations = surviving_invocations_in_current_wave,
227          .repacked_invocation_index = nir_build_mbcnt_amd(b, input_mask, nir_imm_int(b, 0)),
228       };
229       return r;
230    }
231 
232    /* STEP 2. Waves tell each other their number of surviving invocations.
233     *
234     * Each wave activates only its first lane (exec = 1), which stores the number of surviving
235     * invocations in that wave into the LDS, then reads the numbers from every wave.
236     *
237     * The workgroup size of NGG shaders is at most 256, which means
238     * the maximum number of waves is 4 in Wave64 mode and 8 in Wave32 mode.
239     * Each wave writes 1 byte, so it's up to 8 bytes, so at most 2 dwords are necessary.
240     */
241 
242    const unsigned num_lds_dwords = DIV_ROUND_UP(max_num_waves, 4);
243    assert(num_lds_dwords <= 2);
244 
245    nir_ssa_def *wave_id = nir_build_load_subgroup_id(b);
246    nir_ssa_def *dont_care = nir_ssa_undef(b, 1, num_lds_dwords * 32);
247    nir_if *if_first_lane = nir_push_if(b, nir_build_elect(b, 1));
248 
249    nir_build_store_shared(b, nir_u2u8(b, surviving_invocations_in_current_wave), wave_id, .base = lds_addr_base, .align_mul = 1u, .write_mask = 0x1u);
250 
251    nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
252                          .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
253 
254    nir_ssa_def *packed_counts = nir_build_load_shared(b, 1, num_lds_dwords * 32, nir_imm_int(b, 0), .base = lds_addr_base, .align_mul = 8u);
255 
256    nir_pop_if(b, if_first_lane);
257 
258    packed_counts = nir_if_phi(b, packed_counts, dont_care);
259 
260    /* STEP 3. Compute the repacked invocation index and the total number of surviving invocations.
261     *
262     * By now, every wave knows the number of surviving invocations in all waves.
263     * Each number is 1 byte, and they are packed into up to 2 dwords.
264     *
265     * Each lane N will sum the number of surviving invocations from waves 0 to N-1.
266     * If the workgroup has M waves, then each wave will use only its first M+1 lanes for this.
267     * (Other lanes are not deactivated but their calculation is not used.)
268     *
269     * - We read the sum from the lane whose id is the current wave's id.
270     *   Add the masked bitcount to this, and we get the repacked invocation index.
271     * - We read the sum from the lane whose id is the number of waves in the workgroup.
272     *   This is the total number of surviving invocations in the workgroup.
273     */
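   /* Worked example: a workgroup of 3 waves with {12, 7, 30} surviving invocations.
    * summarize_repack gives lane N the sum of waves 0..N-1, so lane 1 holds 12 and
    * lane 3 holds 49. A surviving invocation in wave 1 whose masked bitcount is 4 thus
    * gets repacked index 12 + 4 = 16, and every wave reads the workgroup total (49)
    * from lane 3 (= num_waves).
    */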
274 
275    nir_ssa_def *num_waves = nir_build_load_num_subgroups(b);
276    nir_ssa_def *sum = summarize_repack(b, packed_counts, num_lds_dwords);
277 
278    nir_ssa_def *wg_repacked_index_base = nir_build_read_invocation(b, sum, wave_id);
279    nir_ssa_def *wg_num_repacked_invocations = nir_build_read_invocation(b, sum, num_waves);
280    nir_ssa_def *wg_repacked_index = nir_build_mbcnt_amd(b, input_mask, wg_repacked_index_base);
281 
282    wg_repack_result r = {
283       .num_repacked_invocations = wg_num_repacked_invocations,
284       .repacked_invocation_index = wg_repacked_index,
285    };
286 
287    return r;
288 }
289 
290 static nir_ssa_def *
291 pervertex_lds_addr(nir_builder *b, nir_ssa_def *vertex_idx, unsigned per_vtx_bytes)
292 {
293    return nir_imul_imm(b, vertex_idx, per_vtx_bytes);
294 }
295 
296 static nir_ssa_def *
297 emit_pack_ngg_prim_exp_arg(nir_builder *b, unsigned num_vertices_per_primitives,
298                            nir_ssa_def *vertex_indices[3], nir_ssa_def *is_null_prim,
299                            bool use_edgeflags)
300 {
301    nir_ssa_def *arg = use_edgeflags
302                       ? nir_build_load_initial_edgeflags_amd(b)
303                       : nir_imm_int(b, 0);
304 
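   /* Each vertex index is packed into its own 10-bit field: vertex i is shifted into
    * bits [10*i .. 10*i + 9], the edge flags (when enabled) are already in place in the
    * initial value loaded above, and bit 31 marks a null (culled) primitive.
    */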
305    for (unsigned i = 0; i < num_vertices_per_primitives; ++i) {
306       assert(vertex_indices[i]);
307       arg = nir_ior(b, arg, nir_ishl(b, vertex_indices[i], nir_imm_int(b, 10u * i)));
308    }
309 
310    if (is_null_prim) {
311       if (is_null_prim->bit_size == 1)
312          is_null_prim = nir_b2i32(b, is_null_prim);
313       assert(is_null_prim->bit_size == 32);
314       arg = nir_ior(b, arg, nir_ishl(b, is_null_prim, nir_imm_int(b, 31u)));
315    }
316 
317    return arg;
318 }
319 
320 static nir_ssa_def *
321 ngg_input_primitive_vertex_index(nir_builder *b, unsigned vertex)
322 {
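   /* Each gs_vertex_offset argument packs two 16-bit vertex indices per dword:
    * vertices 0 and 1 share dword 0 (low/high halves) and vertex 2 sits in the low half
    * of dword 1, which is what the base = vertex / 2 and (vertex & 1) * 16 bit offset select.
    */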
323    return nir_ubfe(b, nir_build_load_gs_vertex_offset_amd(b, .base = vertex / 2u),
324                       nir_imm_int(b, (vertex & 1u) * 16u), nir_imm_int(b, 16u));
325 }
326 
327 static nir_ssa_def *
328 emit_ngg_nogs_prim_exp_arg(nir_builder *b, lower_ngg_nogs_state *st)
329 {
330    if (st->passthrough) {
331       assert(!st->export_prim_id || b->shader->info.stage != MESA_SHADER_VERTEX);
332       return nir_build_load_packed_passthrough_primitive_amd(b);
333    } else {
334       nir_ssa_def *vtx_idx[3] = {0};
335 
336       vtx_idx[0] = ngg_input_primitive_vertex_index(b, 0);
337       vtx_idx[1] = st->num_vertices_per_primitives >= 2
338                ? ngg_input_primitive_vertex_index(b, 1)
339                : nir_imm_zero(b, 1, 32);
340       vtx_idx[2] = st->num_vertices_per_primitives >= 3
341                ? ngg_input_primitive_vertex_index(b, 2)
342                : nir_imm_zero(b, 1, 32);
343 
344       return emit_pack_ngg_prim_exp_arg(b, st->num_vertices_per_primitives, vtx_idx, NULL, st->use_edgeflags);
345    }
346 }
347 
348 static void
349 emit_ngg_nogs_prim_export(nir_builder *b, lower_ngg_nogs_state *st, nir_ssa_def *arg)
350 {
351    nir_ssa_def *gs_thread = st->gs_accepted_var
352                             ? nir_load_var(b, st->gs_accepted_var)
353                             : nir_build_has_input_primitive_amd(b);
354 
355    nir_if *if_gs_thread = nir_push_if(b, gs_thread);
356    {
357       if (!arg)
358          arg = emit_ngg_nogs_prim_exp_arg(b, st);
359 
360       if (st->export_prim_id && b->shader->info.stage == MESA_SHADER_VERTEX) {
361          /* Copy Primitive IDs from GS threads to the LDS address corresponding to the ES thread of the provoking vertex. */
362          nir_ssa_def *prim_id = nir_build_load_primitive_id(b);
363          nir_ssa_def *provoking_vtx_idx = ngg_input_primitive_vertex_index(b, st->provoking_vtx_idx);
364          nir_ssa_def *addr = pervertex_lds_addr(b, provoking_vtx_idx, 4u);
365 
366          nir_build_store_shared(b, prim_id, addr, .write_mask = 1u, .align_mul = 4u);
367       }
368 
369       nir_build_export_primitive_amd(b, arg);
370    }
371    nir_pop_if(b, if_gs_thread);
372 }
373 
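/**
 * Store the primitive ID as an ES per-vertex output.
 *
 * For VS, the value was written to LDS by the GS thread of the provoking vertex
 * (see emit_ngg_nogs_prim_export), so each ES thread reads it back from its own
 * per-vertex LDS slot after a workgroup barrier. For TES, the patch primitive ID
 * is available directly.
 */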
374 static void
375 emit_store_ngg_nogs_es_primitive_id(nir_builder *b)
376 {
377    nir_ssa_def *prim_id = NULL;
378 
379    if (b->shader->info.stage == MESA_SHADER_VERTEX) {
380       /* Workgroup barrier - wait for GS threads to store primitive ID in LDS. */
381       nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP, .memory_scope = NIR_SCOPE_WORKGROUP,
382                             .memory_semantics = NIR_MEMORY_ACQ_REL, .memory_modes = nir_var_mem_shared);
383 
384       /* LDS address where the primitive ID is stored */
385       nir_ssa_def *thread_id_in_threadgroup = nir_build_load_local_invocation_index(b);
386       nir_ssa_def *addr = pervertex_lds_addr(b, thread_id_in_threadgroup, 4u);
387 
388       /* Load primitive ID from LDS */
389       prim_id = nir_build_load_shared(b, 1, 32, addr, .align_mul = 4u);
390    } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
391       /* Just use tess eval primitive ID, which is the same as the patch ID. */
392       prim_id = nir_build_load_primitive_id(b);
393    }
394 
395    nir_io_semantics io_sem = {
396       .location = VARYING_SLOT_PRIMITIVE_ID,
397       .num_slots = 1,
398    };
399 
400    nir_build_store_output(b, prim_id, nir_imm_zero(b, 1, 32),
401                           .base = io_sem.location,
402                           .write_mask = 1u, .src_type = nir_type_uint32, .io_semantics = io_sem);
403 }
404 
405 static bool
406 remove_culling_shader_output(nir_builder *b, nir_instr *instr, void *state)
407 {
408    remove_culling_shader_outputs_state *s = (remove_culling_shader_outputs_state *) state;
409 
410    if (instr->type != nir_instr_type_intrinsic)
411       return false;
412 
413    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
414 
415    /* These are not allowed in VS / TES */
416    assert(intrin->intrinsic != nir_intrinsic_store_per_vertex_output &&
417           intrin->intrinsic != nir_intrinsic_load_per_vertex_input);
418 
419    /* We are only interested in output stores now */
420    if (intrin->intrinsic != nir_intrinsic_store_output)
421       return false;
422 
423    b->cursor = nir_before_instr(instr);
424 
425    /* Position output - store the value to a variable, remove output store */
426    nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
427    if (io_sem.location == VARYING_SLOT_POS) {
428       /* TODO: check if it's indirect, etc? */
429       unsigned writemask = nir_intrinsic_write_mask(intrin);
430       nir_ssa_def *store_val = intrin->src[0].ssa;
431       nir_store_var(b, s->pre_cull_position_value_var, store_val, writemask);
432    }
433 
434    /* Remove all output stores */
435    nir_instr_remove(instr);
436    return true;
437 }
438 
439 static void
440 remove_culling_shader_outputs(nir_shader *culling_shader, lower_ngg_nogs_state *nogs_state, nir_variable *pre_cull_position_value_var)
441 {
442    remove_culling_shader_outputs_state s = {
443       .pre_cull_position_value_var = pre_cull_position_value_var,
444    };
445 
446    nir_shader_instructions_pass(culling_shader, remove_culling_shader_output,
447                                 nir_metadata_block_index | nir_metadata_dominance, &s);
448 
449    /* Remove dead code resulting from the deleted outputs. */
450    bool progress;
451    do {
452       progress = false;
453       NIR_PASS(progress, culling_shader, nir_opt_dead_write_vars);
454       NIR_PASS(progress, culling_shader, nir_opt_dce);
455       NIR_PASS(progress, culling_shader, nir_opt_dead_cf);
456    } while (progress);
457 }
458 
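/**
 * Replace the uses of an SSA value with a channel of a replacement variable.
 *
 * Used when deleting the extra position output store: computations that fed the
 * position store can reload the saved position value instead of being redone in
 * the bottom shader part.
 */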
459 static void
460 rewrite_uses_to_var(nir_builder *b, nir_ssa_def *old_def, nir_variable *replacement_var, unsigned replacement_var_channel)
461 {
462    if (old_def->parent_instr->type == nir_instr_type_load_const)
463       return;
464 
465    b->cursor = nir_after_instr(old_def->parent_instr);
466    if (b->cursor.instr->type == nir_instr_type_phi)
467       b->cursor = nir_after_phis(old_def->parent_instr->block);
468 
469    nir_ssa_def *pos_val_rep = nir_load_var(b, replacement_var);
470    nir_ssa_def *replacement = nir_channel(b, pos_val_rep, replacement_var_channel);
471 
472    if (old_def->num_components > 1) {
473       /* old_def uses a swizzled vector component.
474        * There is no way to replace the uses of just a single vector component,
475        * so instead create a new vector and replace all uses of the old vector.
476        */
477       nir_ssa_def *old_def_elements[NIR_MAX_VEC_COMPONENTS] = {0};
478       for (unsigned j = 0; j < old_def->num_components; ++j)
479          old_def_elements[j] = nir_channel(b, old_def, j);
480       replacement = nir_vec(b, old_def_elements, old_def->num_components);
481    }
482 
483    nir_ssa_def_rewrite_uses_after(old_def, replacement, replacement->parent_instr);
484 }
485 
486 static bool
487 remove_extra_pos_output(nir_builder *b, nir_instr *instr, void *state)
488 {
489    remove_extra_position_output_state *s = (remove_extra_position_output_state *) state;
490 
491    if (instr->type != nir_instr_type_intrinsic)
492       return false;
493 
494    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
495 
496    /* These are not allowed in VS / TES */
497    assert(intrin->intrinsic != nir_intrinsic_store_per_vertex_output &&
498           intrin->intrinsic != nir_intrinsic_load_per_vertex_input);
499 
500    /* We are only interested in output stores now */
501    if (intrin->intrinsic != nir_intrinsic_store_output)
502       return false;
503 
504    nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
505    if (io_sem.location != VARYING_SLOT_POS)
506       return false;
507 
508    b->cursor = nir_before_instr(instr);
509 
510    /* In case other outputs use what we calculated for pos,
511     * try to avoid calculating it again by rewriting the usages
512     * of the store components here.
513     */
514    nir_ssa_def *store_val = intrin->src[0].ssa;
515    unsigned store_pos_component = nir_intrinsic_component(intrin);
516 
517    nir_instr_remove(instr);
518 
519    if (store_val->parent_instr->type == nir_instr_type_alu) {
520       nir_alu_instr *alu = nir_instr_as_alu(store_val->parent_instr);
521       if (nir_op_is_vec(alu->op)) {
522          /* Output store uses a vector, we can easily rewrite uses of each vector element. */
523 
524          unsigned num_vec_src = 0;
525          if (alu->op == nir_op_mov)
526             num_vec_src = 1;
527          else if (alu->op == nir_op_vec2)
528             num_vec_src = 2;
529          else if (alu->op == nir_op_vec3)
530             num_vec_src = 3;
531          else if (alu->op == nir_op_vec4)
532             num_vec_src = 4;
533          assert(num_vec_src);
534 
535          /* Remember the current components whose uses we wish to replace.
536           * This is needed because rewriting one source can affect the others too.
537           */
538          nir_ssa_def *vec_comps[NIR_MAX_VEC_COMPONENTS] = {0};
539          for (unsigned i = 0; i < num_vec_src; i++)
540             vec_comps[i] = alu->src[i].src.ssa;
541 
542          for (unsigned i = 0; i < num_vec_src; i++)
543             rewrite_uses_to_var(b, vec_comps[i], s->pos_value_replacement, store_pos_component + i);
544       } else {
545          rewrite_uses_to_var(b, store_val, s->pos_value_replacement, store_pos_component);
546       }
547    } else {
548       rewrite_uses_to_var(b, store_val, s->pos_value_replacement, store_pos_component);
549    }
550 
551    return true;
552 }
553 
554 static void
555 remove_extra_pos_outputs(nir_shader *shader, lower_ngg_nogs_state *nogs_state)
556 {
557    remove_extra_position_output_state s = {
558       .pos_value_replacement = nogs_state->position_value_var,
559    };
560 
561    nir_shader_instructions_pass(shader, remove_extra_pos_output,
562                                 nir_metadata_block_index | nir_metadata_dominance, &s);
563 }
564 
565 static bool
566 remove_compacted_arg(lower_ngg_nogs_state *state, nir_builder *b, unsigned idx)
567 {
568    nir_instr *store_instr = state->compact_arg_stores[idx];
569    if (!store_instr)
570       return false;
571 
572    /* Simply remove the store. */
573    nir_instr_remove(store_instr);
574 
575    /* Find the intrinsic that overwrites the shader arguments,
576     * and change its corresponding source.
577     * This will cause NIR's DCE to recognize the load and its phis as dead.
578     */
579    b->cursor = nir_before_instr(&state->overwrite_args->instr);
580    nir_ssa_def *undef_arg = nir_ssa_undef(b, 1, 32);
581    nir_ssa_def_rewrite_uses(state->overwrite_args->src[idx].ssa, undef_arg);
582 
583    state->compact_arg_stores[idx] = NULL;
584    return true;
585 }
586 
587 static bool
588 cleanup_culling_shader_after_dce(nir_shader *shader,
589                                  nir_function_impl *function_impl,
590                                  lower_ngg_nogs_state *state)
591 {
592    bool uses_vs_vertex_id = false;
593    bool uses_vs_instance_id = false;
594    bool uses_tes_u = false;
595    bool uses_tes_v = false;
596    bool uses_tes_rel_patch_id = false;
597    bool uses_tes_patch_id = false;
598 
599    bool progress = false;
600    nir_builder b;
601    nir_builder_init(&b, function_impl);
602 
603    nir_foreach_block_reverse_safe(block, function_impl) {
604       nir_foreach_instr_reverse_safe(instr, block) {
605          if (instr->type != nir_instr_type_intrinsic)
606             continue;
607 
608          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
609 
610          switch (intrin->intrinsic) {
611          case nir_intrinsic_alloc_vertices_and_primitives_amd:
612             goto cleanup_culling_shader_after_dce_done;
613          case nir_intrinsic_load_vertex_id:
614          case nir_intrinsic_load_vertex_id_zero_base:
615             uses_vs_vertex_id = true;
616             break;
617          case nir_intrinsic_load_instance_id:
618             uses_vs_instance_id = true;
619             break;
620          case nir_intrinsic_load_input:
621             if (state->instance_rate_inputs &
622                 (1 << (nir_intrinsic_base(intrin) - VERT_ATTRIB_GENERIC0)))
623                uses_vs_instance_id = true;
624             else
625                uses_vs_vertex_id = true;
626             break;
627          case nir_intrinsic_load_tess_coord:
628             uses_tes_u = uses_tes_v = true;
629             break;
630          case nir_intrinsic_load_tess_rel_patch_id_amd:
631             uses_tes_rel_patch_id = true;
632             break;
633          case nir_intrinsic_load_primitive_id:
634             if (shader->info.stage == MESA_SHADER_TESS_EVAL)
635                uses_tes_patch_id = true;
636             break;
637          default:
638             break;
639          }
640       }
641    }
642 
643    cleanup_culling_shader_after_dce_done:
644 
645    if (shader->info.stage == MESA_SHADER_VERTEX) {
646       if (!uses_vs_vertex_id)
647          progress |= remove_compacted_arg(state, &b, 0);
648       if (!uses_vs_instance_id)
649          progress |= remove_compacted_arg(state, &b, 1);
650    } else if (shader->info.stage == MESA_SHADER_TESS_EVAL) {
651       if (!uses_tes_u)
652          progress |= remove_compacted_arg(state, &b, 0);
653       if (!uses_tes_v)
654          progress |= remove_compacted_arg(state, &b, 1);
655       if (!uses_tes_rel_patch_id)
656          progress |= remove_compacted_arg(state, &b, 2);
657       if (!uses_tes_patch_id)
658          progress |= remove_compacted_arg(state, &b, 3);
659    }
660 
661    return progress;
662 }
663 
664 /**
665  * Perform vertex compaction after culling.
666  *
667  * 1. Repack surviving ES invocations (this determines which lane will export which vertex)
668  * 2. Surviving ES vertex invocations store their data to LDS
669  * 3. Emit GS_ALLOC_REQ
670  * 4. Repacked invocations load the vertex data from LDS
671  * 5. GS threads update their vertex indices
672  */
673 static void
674 compact_vertices_after_culling(nir_builder *b,
675                                lower_ngg_nogs_state *nogs_state,
676                                nir_variable **repacked_arg_vars,
677                                nir_variable **gs_vtxaddr_vars,
678                                nir_ssa_def *invocation_index,
679                                nir_ssa_def *es_vertex_lds_addr,
680                                nir_ssa_def *es_exporter_tid,
681                                nir_ssa_def *num_live_vertices_in_workgroup,
682                                nir_ssa_def *fully_culled,
683                                unsigned ngg_scratch_lds_base_addr,
684                                unsigned pervertex_lds_bytes,
685                                unsigned max_exported_args)
686 {
687    nir_variable *es_accepted_var = nogs_state->es_accepted_var;
688    nir_variable *gs_accepted_var = nogs_state->gs_accepted_var;
689    nir_variable *position_value_var = nogs_state->position_value_var;
690    nir_variable *prim_exp_arg_var = nogs_state->prim_exp_arg_var;
691 
692    nir_if *if_es_accepted = nir_push_if(b, nir_load_var(b, es_accepted_var));
693    {
694       nir_ssa_def *exporter_addr = pervertex_lds_addr(b, es_exporter_tid, pervertex_lds_bytes);
695 
696       /* Store the exporter thread's index to the LDS space of the current thread so GS threads can load it */
697       nir_build_store_shared(b, nir_u2u8(b, es_exporter_tid), es_vertex_lds_addr, .base = lds_es_exporter_tid, .align_mul = 1u, .write_mask = 0x1u);
698 
699       /* Store the current thread's position output to the exporter thread's LDS space */
700       nir_ssa_def *pos = nir_load_var(b, position_value_var);
701       nir_build_store_shared(b, pos, exporter_addr, .base = lds_es_pos_x, .align_mul = 4u, .write_mask = 0xfu);
702 
703       /* Store the current thread's repackable arguments to the exporter thread's LDS space */
704       for (unsigned i = 0; i < max_exported_args; ++i) {
705          nir_ssa_def *arg_val = nir_load_var(b, repacked_arg_vars[i]);
706          nir_intrinsic_instr *store = nir_build_store_shared(b, arg_val, exporter_addr, .base = lds_es_arg_0 + 4u * i, .align_mul = 4u, .write_mask = 0x1u);
707 
708          nogs_state->compact_arg_stores[i] = &store->instr;
709       }
710    }
711    nir_pop_if(b, if_es_accepted);
712 
713    /* TODO: Consider adding a shortcut exit.
714     * Waves that have no vertices and primitives left can s_endpgm right here.
715     */
716 
717    nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
718                          .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
719 
720    nir_ssa_def *es_survived = nir_ilt(b, invocation_index, num_live_vertices_in_workgroup);
721    nir_if *if_packed_es_thread = nir_push_if(b, es_survived);
722    {
723       /* Read position from the current ES thread's LDS space (written by the exported vertex's ES thread) */
724       nir_ssa_def *exported_pos = nir_build_load_shared(b, 4, 32, es_vertex_lds_addr, .base = lds_es_pos_x, .align_mul = 4u);
725       nir_store_var(b, position_value_var, exported_pos, 0xfu);
726 
727       /* Read the repacked arguments */
728       for (unsigned i = 0; i < max_exported_args; ++i) {
729          nir_ssa_def *arg_val = nir_build_load_shared(b, 1, 32, es_vertex_lds_addr, .base = lds_es_arg_0 + 4u * i, .align_mul = 4u);
730          nir_store_var(b, repacked_arg_vars[i], arg_val, 0x1u);
731       }
732    }
733    nir_push_else(b, if_packed_es_thread);
734    {
735       nir_store_var(b, position_value_var, nir_ssa_undef(b, 4, 32), 0xfu);
736       for (unsigned i = 0; i < max_exported_args; ++i)
737          nir_store_var(b, repacked_arg_vars[i], nir_ssa_undef(b, 1, 32), 0x1u);
738    }
739    nir_pop_if(b, if_packed_es_thread);
740 
741    nir_if *if_gs_accepted = nir_push_if(b, nir_load_var(b, gs_accepted_var));
742    {
743       nir_ssa_def *exporter_vtx_indices[3] = {0};
744 
745       /* Load the index of the ES threads that will export the current GS thread's vertices */
746       for (unsigned v = 0; v < 3; ++v) {
747          nir_ssa_def *vtx_addr = nir_load_var(b, gs_vtxaddr_vars[v]);
748          nir_ssa_def *exporter_vtx_idx = nir_build_load_shared(b, 1, 8, vtx_addr, .base = lds_es_exporter_tid, .align_mul = 1u);
749          exporter_vtx_indices[v] = nir_u2u32(b, exporter_vtx_idx);
750       }
751 
752       nir_ssa_def *prim_exp_arg = emit_pack_ngg_prim_exp_arg(b, 3, exporter_vtx_indices, NULL, nogs_state->use_edgeflags);
753       nir_store_var(b, prim_exp_arg_var, prim_exp_arg, 0x1u);
754    }
755    nir_pop_if(b, if_gs_accepted);
756 
757    nir_store_var(b, es_accepted_var, es_survived, 0x1u);
758    nir_store_var(b, gs_accepted_var, nir_bcsel(b, fully_culled, nir_imm_false(b), nir_build_has_input_primitive_amd(b)), 0x1u);
759 }
760 
761 static void
762 analyze_shader_before_culling_walk(nir_ssa_def *ssa,
763                                    uint8_t flag,
764                                    lower_ngg_nogs_state *nogs_state)
765 {
766    nir_instr *instr = ssa->parent_instr;
767    uint8_t old_pass_flags = instr->pass_flags;
768    instr->pass_flags |= flag;
769 
770    if (instr->pass_flags == old_pass_flags)
771       return; /* Already visited. */
772 
773    switch (instr->type) {
774    case nir_instr_type_intrinsic: {
775       nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
776 
777       /* VS input loads and SSBO loads are actually VRAM reads on AMD HW. */
778       switch (intrin->intrinsic) {
779       case nir_intrinsic_load_input: {
780          nir_io_semantics in_io_sem = nir_intrinsic_io_semantics(intrin);
781          uint64_t in_mask = UINT64_C(1) << (uint64_t) in_io_sem.location;
782          if (instr->pass_flags & nggc_passflag_used_by_pos)
783             nogs_state->inputs_needed_by_pos |= in_mask;
784          else if (instr->pass_flags & nggc_passflag_used_by_other)
785             nogs_state->inputs_needed_by_others |= in_mask;
786          break;
787       }
788       default:
789          break;
790       }
791 
792       break;
793    }
794    case nir_instr_type_alu: {
795       nir_alu_instr *alu = nir_instr_as_alu(instr);
796       unsigned num_srcs = nir_op_infos[alu->op].num_inputs;
797 
798       for (unsigned i = 0; i < num_srcs; ++i) {
799          analyze_shader_before_culling_walk(alu->src[i].src.ssa, flag, nogs_state);
800       }
801 
802       break;
803    }
804    case nir_instr_type_phi: {
805       nir_phi_instr *phi = nir_instr_as_phi(instr);
806       nir_foreach_phi_src_safe(phi_src, phi) {
807          analyze_shader_before_culling_walk(phi_src->src.ssa, flag, nogs_state);
808       }
809 
810       break;
811    }
812    default:
813       break;
814    }
815 }
816 
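/**
 * Analyze the shader before splitting it for culling.
 *
 * Walks backwards from every output store and flags each instruction in the use chain
 * with whether it feeds the position output, other outputs, or both. The flags are later
 * used to decide which uniform values the bottom shader part can reuse from the top part
 * (save_reusable_variables), and which VS inputs feed the position (inputs_needed_by_pos)
 * versus only the other outputs (inputs_needed_by_others).
 */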
817 static void
818 analyze_shader_before_culling(nir_shader *shader, lower_ngg_nogs_state *nogs_state)
819 {
820    nir_foreach_function(func, shader) {
821       nir_foreach_block(block, func->impl) {
822          nir_foreach_instr(instr, block) {
823             instr->pass_flags = 0;
824 
825             if (instr->type != nir_instr_type_intrinsic)
826                continue;
827 
828             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
829             if (intrin->intrinsic != nir_intrinsic_store_output)
830                continue;
831 
832             nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
833             nir_ssa_def *store_val = intrin->src[0].ssa;
834             uint8_t flag = io_sem.location == VARYING_SLOT_POS ? nggc_passflag_used_by_pos : nggc_passflag_used_by_other;
835             analyze_shader_before_culling_walk(store_val, flag, nogs_state);
836          }
837       }
838    }
839 }
840 
841 /**
842  * Save the reusable SSA definitions to variables so that the
843  * bottom shader part can reuse them from the top part.
844  *
845  * 1. We create a new function temporary variable for reusables,
846  *    and insert a store+load.
847  * 2. The shader is cloned (the top part is created), then the
848  *    control flow is reinserted (for the bottom part.)
849  * 3. For reusables, we delete the variable stores from the
850  *    bottom part. This will make them use the variables from
851  *    the top part and DCE the redundant instructions.
852  */
853 static void
854 save_reusable_variables(nir_builder *b, lower_ngg_nogs_state *nogs_state)
855 {
856    ASSERTED int vec_ok = u_vector_init(&nogs_state->saved_uniforms, 4, sizeof(saved_uniform));
857    assert(vec_ok);
858 
859    nir_block *block = nir_start_block(b->impl);
860    while (block) {
861       /* Process the instructions in the current block. */
862       nir_foreach_instr_safe(instr, block) {
863          /* Find instructions whose SSA definitions are used by both
864           * the top and bottom parts of the shader (before and after culling).
865           * Only in this case, it makes sense for the bottom part
866           * to try to reuse these from the top part.
867           */
868          if ((instr->pass_flags & nggc_passflag_used_by_both) != nggc_passflag_used_by_both)
869             continue;
870 
871          /* Determine if we can reuse the current SSA value.
872           * When vertex compaction is used, it is possible that the same shader invocation
873           * processes a different vertex in the top and bottom part of the shader.
874           * Therefore, we only reuse uniform values.
875           */
876          nir_ssa_def *ssa = NULL;
877          switch (instr->type) {
878          case nir_instr_type_alu: {
879             nir_alu_instr *alu = nir_instr_as_alu(instr);
880             if (alu->dest.dest.ssa.divergent)
881                continue;
882             /* Ignore uniform floats because they regress VGPR usage too much */
883             if (nir_op_infos[alu->op].output_type & nir_type_float)
884                continue;
885             ssa = &alu->dest.dest.ssa;
886             break;
887          }
888          case nir_instr_type_intrinsic: {
889             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
890             if (!nir_intrinsic_can_reorder(intrin) ||
891                 !nir_intrinsic_infos[intrin->intrinsic].has_dest ||
892                 intrin->dest.ssa.divergent)
893                continue;
894             ssa = &intrin->dest.ssa;
895             break;
896          }
897          case nir_instr_type_phi: {
898             nir_phi_instr *phi = nir_instr_as_phi(instr);
899             if (phi->dest.ssa.divergent)
900                continue;
901             ssa = &phi->dest.ssa;
902             break;
903          }
904          default:
905             continue;
906          }
907 
908          assert(ssa);
909 
910          /* Determine a suitable type for the SSA value. */
911          enum glsl_base_type base_type = GLSL_TYPE_UINT;
912          switch (ssa->bit_size) {
913          case 8: base_type = GLSL_TYPE_UINT8; break;
914          case 16: base_type = GLSL_TYPE_UINT16; break;
915          case 32: base_type = GLSL_TYPE_UINT; break;
916          case 64: base_type = GLSL_TYPE_UINT64; break;
917          default: continue;
918          }
919 
920          const struct glsl_type *t = ssa->num_components == 1
921                                      ? glsl_scalar_type(base_type)
922                                      : glsl_vector_type(base_type, ssa->num_components);
923 
924          saved_uniform *saved = (saved_uniform *) u_vector_add(&nogs_state->saved_uniforms);
925          assert(saved);
926 
927          /* Create a new NIR variable where we store the reusable value.
928           * Then, we reload the variable and replace the uses of the value
929           * with the reloaded variable.
930           */
931          saved->var = nir_local_variable_create(b->impl, t, NULL);
932          saved->ssa = ssa;
933 
934          b->cursor = instr->type == nir_instr_type_phi
935                      ? nir_after_instr_and_phis(instr)
936                      : nir_after_instr(instr);
937          nir_store_var(b, saved->var, saved->ssa, BITFIELD_MASK(ssa->num_components));
938          nir_ssa_def *reloaded = nir_load_var(b, saved->var);
939          nir_ssa_def_rewrite_uses_after(ssa, reloaded, reloaded->parent_instr);
940       }
941 
942       /* Look at the next CF node. */
943       nir_cf_node *next_cf_node = nir_cf_node_next(&block->cf_node);
944       if (next_cf_node) {
945          /* It makes no sense to try to reuse things from within loops. */
946          bool next_is_loop = next_cf_node->type == nir_cf_node_loop;
947 
948          /* Don't reuse if we're in divergent control flow.
949           *
950           * Thanks to vertex repacking, the same shader invocation may process a different vertex
951           * in the top and bottom part, and it's even possible that this different vertex was initially
952           * processed in a different wave. So the two parts may take a different divergent code path.
953           * Therefore, these variables in divergent control flow may stay undefined.
954           *
955           * Note that this problem doesn't exist if vertices are not repacked or if the
956           * workgroup only has a single wave.
957           */
958          bool next_is_divergent_if =
959             next_cf_node->type == nir_cf_node_if &&
960             nir_cf_node_as_if(next_cf_node)->condition.ssa->divergent;
961 
962          if (next_is_loop || next_is_divergent_if) {
963             block = nir_cf_node_cf_tree_next(next_cf_node);
964             continue;
965          }
966       }
967 
968       /* Go to the next block. */
969       block = nir_block_cf_tree_next(block);
970    }
971 }
972 
973 /**
974  * Reuses suitable variables from the top part of the shader,
975  * by deleting their stores from the bottom part.
976  */
977 static void
978 apply_reusable_variables(nir_builder *b, lower_ngg_nogs_state *nogs_state)
979 {
980    if (!u_vector_length(&nogs_state->saved_uniforms)) {
981       u_vector_finish(&nogs_state->saved_uniforms);
982       return;
983    }
984 
985    nir_foreach_block_reverse_safe(block, b->impl) {
986       nir_foreach_instr_reverse_safe(instr, block) {
987          if (instr->type != nir_instr_type_intrinsic)
988             continue;
989          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
990 
991          /* When we find any of these intrinsics, it means
992           * we have reached the top part and must stop.
993           */
994          if (intrin->intrinsic == nir_intrinsic_alloc_vertices_and_primitives_amd)
995             goto done;
996 
997          if (intrin->intrinsic != nir_intrinsic_store_deref)
998             continue;
999          nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
1000          if (deref->deref_type != nir_deref_type_var)
1001             continue;
1002 
1003          saved_uniform *saved;
1004          u_vector_foreach(saved, &nogs_state->saved_uniforms) {
1005             if (saved->var == deref->var) {
1006                nir_instr_remove(instr);
1007             }
1008          }
1009       }
1010    }
1011 
1012    done:
1013    u_vector_finish(&nogs_state->saved_uniforms);
1014 }
1015 
1016 static void
1017 add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_cf, lower_ngg_nogs_state *nogs_state)
1018 {
1019    assert(b->shader->info.outputs_written & (1 << VARYING_SLOT_POS));
1020 
1021    bool uses_instance_id = BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_INSTANCE_ID);
1022    bool uses_tess_primitive_id = BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_PRIMITIVE_ID);
1023 
1024    unsigned max_exported_args = b->shader->info.stage == MESA_SHADER_VERTEX ? 2 : 4;
1025    if (b->shader->info.stage == MESA_SHADER_VERTEX && !uses_instance_id)
1026       max_exported_args--;
1027    else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL && !uses_tess_primitive_id)
1028       max_exported_args--;
1029 
1030    unsigned pervertex_lds_bytes = lds_es_arg_0 + max_exported_args * 4u;
1031    unsigned total_es_lds_bytes = pervertex_lds_bytes * nogs_state->max_es_num_vertices;
1032    unsigned max_num_waves = nogs_state->max_num_waves;
1033    unsigned ngg_scratch_lds_base_addr = ALIGN(total_es_lds_bytes, 8u);
1034    unsigned ngg_scratch_lds_bytes = DIV_ROUND_UP(max_num_waves, 4u);
1035    nogs_state->total_lds_bytes = ngg_scratch_lds_base_addr + ngg_scratch_lds_bytes;
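   /* Example: a VS that uses InstanceID with max_es_num_vertices = 128 gets
    * pervertex_lds_bytes = 20 + 2 * 4 = 28, so 28 * 128 = 3584 bytes of per-vertex LDS
    * (already 8-byte aligned), plus 1 or 2 bytes of repack scratch depending on
    * max_num_waves, for a total of 3585 or 3586 bytes.
    */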
1036 
1037    nir_function_impl *impl = nir_shader_get_entrypoint(b->shader);
1038 
1039    /* Create some helper variables. */
1040    nir_variable *position_value_var = nogs_state->position_value_var;
1041    nir_variable *prim_exp_arg_var = nogs_state->prim_exp_arg_var;
1042    nir_variable *gs_accepted_var = nogs_state->gs_accepted_var;
1043    nir_variable *es_accepted_var = nogs_state->es_accepted_var;
1044    nir_variable *gs_vtxaddr_vars[3] = {
1045       nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx0_addr"),
1046       nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx1_addr"),
1047       nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx2_addr"),
1048    };
1049    nir_variable *repacked_arg_vars[4] = {
1050       nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_0"),
1051       nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_1"),
1052       nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_2"),
1053       nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_3"),
1054    };
1055 
1056    /* Top part of the culling shader (aka. position shader part)
1057     *
1058     * We clone the full ES shader and emit it here, but we only really care
1059     * about its position output, so we delete every other output from this part.
1060     * The position output is stored into a temporary variable, and reloaded later.
1061     */
1062 
1063    b->cursor = nir_before_cf_list(&impl->body);
1064 
1065    nir_ssa_def *es_thread = nir_build_has_input_vertex_amd(b);
1066    nir_if *if_es_thread = nir_push_if(b, es_thread);
1067    {
1068       /* Initialize the position output variable in case not all VS/TES invocations store the output.
1069        * The spec doesn't require a particular value, but we use (0, 0, 0, 1) because some games rely on that.
1070        */
1071       nir_store_var(b, position_value_var, nir_imm_vec4(b, 0.0f, 0.0f, 0.0f, 1.0f), 0xfu);
1072 
1073       /* Now reinsert a clone of the shader code */
1074       struct hash_table *remap_table = _mesa_pointer_hash_table_create(NULL);
1075       nir_cf_list_clone_and_reinsert(original_extracted_cf, &if_es_thread->cf_node, b->cursor, remap_table);
1076       _mesa_hash_table_destroy(remap_table, NULL);
1077       b->cursor = nir_after_cf_list(&if_es_thread->then_list);
1078 
1079       /* Remember the current thread's shader arguments */
1080       if (b->shader->info.stage == MESA_SHADER_VERTEX) {
1081          nir_store_var(b, repacked_arg_vars[0], nir_build_load_vertex_id_zero_base(b), 0x1u);
1082          if (uses_instance_id)
1083             nir_store_var(b, repacked_arg_vars[1], nir_build_load_instance_id(b), 0x1u);
1084       } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
1085          nir_ssa_def *tess_coord = nir_build_load_tess_coord(b);
1086          nir_store_var(b, repacked_arg_vars[0], nir_channel(b, tess_coord, 0), 0x1u);
1087          nir_store_var(b, repacked_arg_vars[1], nir_channel(b, tess_coord, 1), 0x1u);
1088          nir_store_var(b, repacked_arg_vars[2], nir_build_load_tess_rel_patch_id_amd(b), 0x1u);
1089          if (uses_tess_primitive_id)
1090             nir_store_var(b, repacked_arg_vars[3], nir_build_load_primitive_id(b), 0x1u);
1091       } else {
1092          unreachable("Should be VS or TES.");
1093       }
1094    }
1095    nir_pop_if(b, if_es_thread);
1096 
1097    nir_store_var(b, es_accepted_var, es_thread, 0x1u);
1098    nir_store_var(b, gs_accepted_var, nir_build_has_input_primitive_amd(b), 0x1u);
1099 
1100    /* Remove all non-position outputs, and put the position output into the variable. */
1101    nir_metadata_preserve(impl, nir_metadata_none);
1102    remove_culling_shader_outputs(b->shader, nogs_state, position_value_var);
1103    b->cursor = nir_after_cf_list(&impl->body);
1104 
1105    /* Run culling algorithms if culling is enabled.
1106     *
1107     * NGG culling can be enabled or disabled at runtime.
1108     * This is determined by an SGPR shader argument which is accessed
1109     * by the following NIR intrinsic.
1110     */
1111 
1112    nir_if *if_cull_en = nir_push_if(b, nir_build_load_cull_any_enabled_amd(b));
1113    {
1114       nir_ssa_def *invocation_index = nir_build_load_local_invocation_index(b);
1115       nir_ssa_def *es_vertex_lds_addr = pervertex_lds_addr(b, invocation_index, pervertex_lds_bytes);
1116 
1117       /* ES invocations store their vertex data to LDS for GS threads to read. */
1118       if_es_thread = nir_push_if(b, nir_build_has_input_vertex_amd(b));
1119       {
1120          /* Store position components that are relevant to culling in LDS */
1121          nir_ssa_def *pre_cull_pos = nir_load_var(b, position_value_var);
1122          nir_ssa_def *pre_cull_w = nir_channel(b, pre_cull_pos, 3);
1123          nir_build_store_shared(b, pre_cull_w, es_vertex_lds_addr, .write_mask = 0x1u, .align_mul = 4, .base = lds_es_pos_w);
1124          nir_ssa_def *pre_cull_x_div_w = nir_fdiv(b, nir_channel(b, pre_cull_pos, 0), pre_cull_w);
1125          nir_ssa_def *pre_cull_y_div_w = nir_fdiv(b, nir_channel(b, pre_cull_pos, 1), pre_cull_w);
1126          nir_build_store_shared(b, nir_vec2(b, pre_cull_x_div_w, pre_cull_y_div_w), es_vertex_lds_addr, .write_mask = 0x3u, .align_mul = 4, .base = lds_es_pos_x);
1127 
1128          /* Clear out the ES accepted flag in LDS */
1129          nir_build_store_shared(b, nir_imm_zero(b, 1, 8), es_vertex_lds_addr, .write_mask = 0x1u, .align_mul = 4, .base = lds_es_vertex_accepted);
1130       }
1131       nir_pop_if(b, if_es_thread);
1132 
1133       nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
1134                             .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
1135 
1136       nir_store_var(b, gs_accepted_var, nir_imm_bool(b, false), 0x1u);
1137       nir_store_var(b, prim_exp_arg_var, nir_imm_int(b, 1u << 31), 0x1u);
1138 
1139       /* GS invocations load the vertex data and perform the culling. */
1140       nir_if *if_gs_thread = nir_push_if(b, nir_build_has_input_primitive_amd(b));
1141       {
1142          /* Load vertex indices from input VGPRs */
1143          nir_ssa_def *vtx_idx[3] = {0};
1144          for (unsigned vertex = 0; vertex < 3; ++vertex)
1145             vtx_idx[vertex] = ngg_input_primitive_vertex_index(b, vertex);
1146 
1147          nir_ssa_def *vtx_addr[3] = {0};
1148          nir_ssa_def *pos[3][4] = {0};
1149 
1150          /* Load W positions of vertices first because the culling code will use these first */
1151          for (unsigned vtx = 0; vtx < 3; ++vtx) {
1152             vtx_addr[vtx] = pervertex_lds_addr(b, vtx_idx[vtx], pervertex_lds_bytes);
1153             pos[vtx][3] = nir_build_load_shared(b, 1, 32, vtx_addr[vtx], .align_mul = 4u, .base = lds_es_pos_w);
1154             nir_store_var(b, gs_vtxaddr_vars[vtx], vtx_addr[vtx], 0x1u);
1155          }
1156 
1157          /* Load the X/W, Y/W positions of vertices */
1158          for (unsigned vtx = 0; vtx < 3; ++vtx) {
1159             nir_ssa_def *xy = nir_build_load_shared(b, 2, 32, vtx_addr[vtx], .align_mul = 4u, .base = lds_es_pos_x);
1160             pos[vtx][0] = nir_channel(b, xy, 0);
1161             pos[vtx][1] = nir_channel(b, xy, 1);
1162          }
1163 
1164          /* See if the current primitive is accepted */
1165          nir_ssa_def *accepted = ac_nir_cull_triangle(b, nir_imm_bool(b, true), pos);
1166          nir_store_var(b, gs_accepted_var, accepted, 0x1u);
1167 
1168          nir_if *if_gs_accepted = nir_push_if(b, accepted);
1169          {
1170             /* Store the accepted state to LDS for ES threads */
1171             for (unsigned vtx = 0; vtx < 3; ++vtx)
1172                nir_build_store_shared(b, nir_imm_intN_t(b, 0xff, 8), vtx_addr[vtx], .base = lds_es_vertex_accepted, .align_mul = 4u, .write_mask = 0x1u);
1173          }
1174          nir_pop_if(b, if_gs_accepted);
1175       }
1176       nir_pop_if(b, if_gs_thread);
1177 
1178       nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
1179                             .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
1180 
1181       nir_store_var(b, es_accepted_var, nir_imm_bool(b, false), 0x1u);
1182 
1183       /* ES invocations load their accepted flag from LDS. */
1184       if_es_thread = nir_push_if(b, nir_build_has_input_vertex_amd(b));
1185       {
1186          nir_ssa_def *accepted = nir_build_load_shared(b, 1, 8u, es_vertex_lds_addr, .base = lds_es_vertex_accepted, .align_mul = 4u);
1187          nir_ssa_def *accepted_bool = nir_ine(b, accepted, nir_imm_intN_t(b, 0, 8));
1188          nir_store_var(b, es_accepted_var, accepted_bool, 0x1u);
1189       }
1190       nir_pop_if(b, if_es_thread);
1191 
1192       nir_ssa_def *es_accepted = nir_load_var(b, es_accepted_var);
1193 
1194       /* Repack the vertices that survived the culling. */
1195       wg_repack_result rep = repack_invocations_in_workgroup(b, es_accepted, ngg_scratch_lds_base_addr,
1196                                                             nogs_state->max_num_waves, nogs_state->wave_size);
1197       nir_ssa_def *num_live_vertices_in_workgroup = rep.num_repacked_invocations;
1198       nir_ssa_def *es_exporter_tid = rep.repacked_invocation_index;
1199 
1200       /* If all vertices are culled, set primitive count to 0 as well. */
1201       nir_ssa_def *num_exported_prims = nir_build_load_workgroup_num_input_primitives_amd(b);
1202       nir_ssa_def *fully_culled = nir_ieq_imm(b, num_live_vertices_in_workgroup, 0u);
1203       num_exported_prims = nir_bcsel(b, fully_culled, nir_imm_int(b, 0u), num_exported_prims);
1204 
1205       nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_build_load_subgroup_id(b), nir_imm_int(b, 0)));
1206       {
1207          /* Tell the final vertex and primitive count to the HW. */
1208          nir_build_alloc_vertices_and_primitives_amd(b, num_live_vertices_in_workgroup, num_exported_prims);
1209       }
1210       nir_pop_if(b, if_wave_0);
1211 
1212       /* Vertex compaction. */
1213       compact_vertices_after_culling(b, nogs_state,
1214                                      repacked_arg_vars, gs_vtxaddr_vars,
1215                                      invocation_index, es_vertex_lds_addr,
1216                                      es_exporter_tid, num_live_vertices_in_workgroup, fully_culled,
1217                                      ngg_scratch_lds_base_addr, pervertex_lds_bytes, max_exported_args);
1218    }
1219    nir_push_else(b, if_cull_en);
1220    {
1221       /* When culling is disabled, we do the same as we would without culling. */
1222       nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_build_load_subgroup_id(b), nir_imm_int(b, 0)));
1223       {
1224          nir_ssa_def *vtx_cnt = nir_build_load_workgroup_num_input_vertices_amd(b);
1225          nir_ssa_def *prim_cnt = nir_build_load_workgroup_num_input_primitives_amd(b);
1226          nir_build_alloc_vertices_and_primitives_amd(b, vtx_cnt, prim_cnt);
1227       }
1228       nir_pop_if(b, if_wave_0);
1229       nir_store_var(b, prim_exp_arg_var, emit_ngg_nogs_prim_exp_arg(b, nogs_state), 0x1u);
1230    }
1231    nir_pop_if(b, if_cull_en);
1232 
1233    /* Update shader arguments.
1234     *
1235     * The registers which hold information about the subgroup's
1236     * vertices and primitives are updated here, so the rest of the shader
1237     * doesn't need to worry about the culling.
1238     *
1239     * These "overwrite" intrinsics must be at top-level control flow,
1240     * otherwise they can confuse the backend (e.g. ACO's SSA construction).
1241     *
1242     * TODO:
1243     * A cleaner solution would be to simply replace all uses of these args
1244     * with loads of the variables.
1245     * However, this doesn't work right now because the backend uses the arguments
1246     * for purposes not expressed in NIR, e.g. VS input loads.
1247     * This can change once VS input loads and other such operations are lowered to e.g. load_buffer_amd.
1248     */
1249 
1250    if (b->shader->info.stage == MESA_SHADER_VERTEX)
1251       nogs_state->overwrite_args =
1252          nir_build_overwrite_vs_arguments_amd(b,
1253             nir_load_var(b, repacked_arg_vars[0]), nir_load_var(b, repacked_arg_vars[1]));
1254    else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL)
1255       nogs_state->overwrite_args =
1256          nir_build_overwrite_tes_arguments_amd(b,
1257             nir_load_var(b, repacked_arg_vars[0]), nir_load_var(b, repacked_arg_vars[1]),
1258             nir_load_var(b, repacked_arg_vars[2]), nir_load_var(b, repacked_arg_vars[3]));
1259    else
1260       unreachable("Should be VS or TES.");
1261 }
1262 
1263 void
1264 ac_nir_lower_ngg_nogs(nir_shader *shader,
1265                       unsigned max_num_es_vertices,
1266                       unsigned num_vertices_per_primitives,
1267                       unsigned max_workgroup_size,
1268                       unsigned wave_size,
1269                       bool can_cull,
1270                       bool early_prim_export,
1271                       bool passthrough,
1272                       bool export_prim_id,
1273                       bool provoking_vtx_last,
1274                       bool use_edgeflags,
1275                       uint32_t instance_rate_inputs)
1276 {
1277    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1278    assert(impl);
1279    assert(max_num_es_vertices && max_workgroup_size && wave_size);
1280    assert(!(can_cull && passthrough));
1281 
1282    nir_variable *position_value_var = nir_local_variable_create(impl, glsl_vec4_type(), "position_value");
1283    nir_variable *prim_exp_arg_var = nir_local_variable_create(impl, glsl_uint_type(), "prim_exp_arg");
1284    nir_variable *es_accepted_var = can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "es_accepted") : NULL;
1285    nir_variable *gs_accepted_var = can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "gs_accepted") : NULL;
1286 
1287    lower_ngg_nogs_state state = {
1288       .passthrough = passthrough,
1289       .export_prim_id = export_prim_id,
1290       .early_prim_export = early_prim_export,
1291       .use_edgeflags = use_edgeflags,
1292       .num_vertices_per_primitives = num_vertices_per_primitives,
1293       .provoking_vtx_idx = provoking_vtx_last ? (num_vertices_per_primitives - 1) : 0,
1294       .position_value_var = position_value_var,
1295       .prim_exp_arg_var = prim_exp_arg_var,
1296       .es_accepted_var = es_accepted_var,
1297       .gs_accepted_var = gs_accepted_var,
1298       .max_num_waves = DIV_ROUND_UP(max_workgroup_size, wave_size),
1299       .max_es_num_vertices = max_num_es_vertices,
1300       .wave_size = wave_size,
1301       .instance_rate_inputs = instance_rate_inputs,
1302    };
1303 
1304    /* We need LDS space when VS needs to export the primitive ID. */
1305    if (shader->info.stage == MESA_SHADER_VERTEX && export_prim_id)
1306       state.total_lds_bytes = max_num_es_vertices * 4u;
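   /* One dword per ES vertex: presumably the GS thread of the provoking vertex stores
    * the primitive ID there, and the ES thread that exports the vertex reads it back
    * (see emit_store_ngg_nogs_es_primitive_id).
    */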
1307 
1308    nir_builder builder;
1309    nir_builder *b = &builder; /* This is to avoid the & */
1310    nir_builder_init(b, impl);
1311 
1312    if (can_cull) {
1313       /* We need divergence info for culling shaders. */
1314       nir_divergence_analysis(shader);
1315       analyze_shader_before_culling(shader, &state);
1316       save_reusable_variables(b, &state);
1317    }
1318 
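   /* Extract the whole shader body; it is reinserted below, wrapped in a check that the
    * current thread actually has an input vertex (and, with culling, that it was accepted).
    */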
1319    nir_cf_list extracted;
1320    nir_cf_extract(&extracted, nir_before_cf_list(&impl->body), nir_after_cf_list(&impl->body));
1321    b->cursor = nir_before_cf_list(&impl->body);
1322 
1323    if (!can_cull) {
1324       /* Allocate export space on wave 0 - confirm to the HW that we want to use all possible space */
1325       nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_build_load_subgroup_id(b), nir_imm_int(b, 0)));
1326       {
1327          nir_ssa_def *vtx_cnt = nir_build_load_workgroup_num_input_vertices_amd(b);
1328          nir_ssa_def *prim_cnt = nir_build_load_workgroup_num_input_primitives_amd(b);
1329          nir_build_alloc_vertices_and_primitives_amd(b, vtx_cnt, prim_cnt);
1330       }
1331       nir_pop_if(b, if_wave_0);
1332 
1333       /* Take care of early primitive export, otherwise just pack the primitive export argument */
1334       if (state.early_prim_export)
1335          emit_ngg_nogs_prim_export(b, &state, NULL);
1336       else
1337          nir_store_var(b, prim_exp_arg_var, emit_ngg_nogs_prim_exp_arg(b, &state), 0x1u);
1338    } else {
1339       add_deferred_attribute_culling(b, &extracted, &state);
1340       b->cursor = nir_after_cf_list(&impl->body);
1341 
1342       if (state.early_prim_export)
1343          emit_ngg_nogs_prim_export(b, &state, nir_load_var(b, state.prim_exp_arg_var));
1344    }
1345 
1346    nir_intrinsic_instr *export_vertex_instr;
1347    nir_ssa_def *es_thread = can_cull ? nir_load_var(b, es_accepted_var) : nir_build_has_input_vertex_amd(b);
1348 
1349    nir_if *if_es_thread = nir_push_if(b, es_thread);
1350    {
1351       /* Run the actual shader */
1352       nir_cf_reinsert(&extracted, b->cursor);
1353       b->cursor = nir_after_cf_list(&if_es_thread->then_list);
1354 
1355       /* Export all vertex attributes (except primitive ID) */
1356       export_vertex_instr = nir_build_export_vertex_amd(b);
1357 
1358       /* Export primitive ID (in case of early primitive export or TES) */
1359       if (state.export_prim_id && (state.early_prim_export || shader->info.stage != MESA_SHADER_VERTEX))
1360          emit_store_ngg_nogs_es_primitive_id(b);
1361    }
1362    nir_pop_if(b, if_es_thread);
1363 
1364    /* Take care of late primitive export */
1365    if (!state.early_prim_export) {
1366       emit_ngg_nogs_prim_export(b, &state, nir_load_var(b, prim_exp_arg_var));
1367       if (state.export_prim_id && shader->info.stage == MESA_SHADER_VERTEX) {
1368          if_es_thread = nir_push_if(b, can_cull ? es_thread : nir_build_has_input_vertex_amd(b));
1369          emit_store_ngg_nogs_es_primitive_id(b);
1370          nir_pop_if(b, if_es_thread);
1371       }
1372    }
1373 
1374    if (can_cull) {
1375       /* Replace uniforms. */
1376       apply_reusable_variables(b, &state);
1377 
1378       /* Remove the redundant position output. */
1379       remove_extra_pos_outputs(shader, &state);
1380 
1381       /* After looking at performance in apps such as Doom Eternal and The Witcher 3,
1382        * it seems best to always put the position export at the end, and
1383        * then let ACO schedule it up (slightly) only when early prim export is used.
1384        */
1385       b->cursor = nir_before_instr(&export_vertex_instr->instr);
1386 
1387       nir_ssa_def *pos_val = nir_load_var(b, state.position_value_var);
1388       nir_io_semantics io_sem = { .location = VARYING_SLOT_POS, .num_slots = 1 };
1389       nir_build_store_output(b, pos_val, nir_imm_int(b, 0), .base = VARYING_SLOT_POS, .component = 0, .io_semantics = io_sem, .write_mask = 0xfu);
1390    }
1391 
1392    nir_metadata_preserve(impl, nir_metadata_none);
1393    nir_validate_shader(shader, "after emitting NGG VS/TES");
1394 
1395    /* Cleanup */
1396    nir_opt_dead_write_vars(shader);
1397    nir_lower_vars_to_ssa(shader);
1398    nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
1399    nir_lower_alu_to_scalar(shader, NULL, NULL);
1400    nir_lower_phis_to_scalar(shader, true);
1401 
1402    if (can_cull) {
1403       /* It's beneficial to redo these opts after splitting the shader. */
1404       nir_opt_sink(shader, nir_move_load_input | nir_move_const_undef | nir_move_copies);
1405       nir_opt_move(shader, nir_move_load_input | nir_move_copies | nir_move_const_undef);
1406    }
1407 
1408    bool progress;
1409    do {
1410       progress = false;
1411       NIR_PASS(progress, shader, nir_opt_undef);
1412       NIR_PASS(progress, shader, nir_opt_dce);
1413       NIR_PASS(progress, shader, nir_opt_dead_cf);
1414 
1415       if (can_cull)
1416          progress |= cleanup_culling_shader_after_dce(shader, b->impl, &state);
1417    } while (progress);
1418 
1419    shader->info.shared_size = state.total_lds_bytes;
1420 }
1421 
1422 static nir_ssa_def *
1423 ngg_gs_out_vertex_addr(nir_builder *b, nir_ssa_def *out_vtx_idx, lower_ngg_gs_state *s)
1424 {
1425    unsigned write_stride_2exp = ffs(MAX2(b->shader->info.gs.vertices_out, 1)) - 1;
1426 
1427    /* gs_max_out_vertices = 2^(write_stride_2exp) * some odd number */
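   /* Swizzle the vertex index, presumably to reduce LDS bank conflicts between threads
    * writing their output vertices: every group of 32 vertex slots gets a different XOR
    * pattern. For example, with vertices_out = 4, write_stride_2exp is 2, so indices
    * 32..63 are XORed with 1, 64..95 with 2, 96..127 with 3, and so on.
    */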
1428    if (write_stride_2exp) {
1429       nir_ssa_def *row = nir_ushr_imm(b, out_vtx_idx, 5);
1430       nir_ssa_def *swizzle = nir_iand_imm(b, row, (1u << write_stride_2exp) - 1u);
1431       out_vtx_idx = nir_ixor(b, out_vtx_idx, swizzle);
1432    }
1433 
1434    nir_ssa_def *out_vtx_offs = nir_imul_imm(b, out_vtx_idx, s->lds_bytes_per_gs_out_vertex);
1435    return nir_iadd_imm_nuw(b, out_vtx_offs, s->lds_addr_gs_out_vtx);
1436 }
1437 
1438 static nir_ssa_def *
1439 ngg_gs_emit_vertex_addr(nir_builder *b, nir_ssa_def *gs_vtx_idx, lower_ngg_gs_state *s)
1440 {
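   /* Each invocation owns a contiguous range of gs.vertices_out output vertex slots,
    * so the vertex emitted with index gs_vtx_idx lands in workgroup-wide slot
    * (tid_in_tg * vertices_out + gs_vtx_idx).
    */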
1441    nir_ssa_def *tid_in_tg = nir_build_load_local_invocation_index(b);
1442    nir_ssa_def *gs_out_vtx_base = nir_imul_imm(b, tid_in_tg, b->shader->info.gs.vertices_out);
1443    nir_ssa_def *out_vtx_idx = nir_iadd_nuw(b, gs_out_vtx_base, gs_vtx_idx);
1444 
1445    return ngg_gs_out_vertex_addr(b, out_vtx_idx, s);
1446 }
1447 
1448 static void
1449 ngg_gs_clear_primflags(nir_builder *b, nir_ssa_def *num_vertices, unsigned stream, lower_ngg_gs_state *s)
1450 {
1451    nir_ssa_def *zero_u8 = nir_imm_zero(b, 1, 8);
1452    nir_store_var(b, s->current_clear_primflag_idx_var, num_vertices, 0x1u);
1453 
1454    nir_loop *loop = nir_push_loop(b);
1455    {
1456       nir_ssa_def *current_clear_primflag_idx = nir_load_var(b, s->current_clear_primflag_idx_var);
1457       nir_if *if_break = nir_push_if(b, nir_uge(b, current_clear_primflag_idx, nir_imm_int(b, b->shader->info.gs.vertices_out)));
1458       {
1459          nir_jump(b, nir_jump_break);
1460       }
1461       nir_push_else(b, if_break);
1462       {
1463          nir_ssa_def *emit_vtx_addr = ngg_gs_emit_vertex_addr(b, current_clear_primflag_idx, s);
1464          nir_build_store_shared(b, zero_u8, emit_vtx_addr, .base = s->lds_offs_primflags + stream, .align_mul = 1, .write_mask = 0x1u);
1465          nir_store_var(b, s->current_clear_primflag_idx_var, nir_iadd_imm_nuw(b, current_clear_primflag_idx, 1), 0x1u);
1466       }
1467       nir_pop_if(b, if_break);
1468    }
1469    nir_pop_loop(b, loop);
1470 }
1471 
1472 static void
1473 ngg_gs_shader_query(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
1474 {
1475    nir_if *if_shader_query = nir_push_if(b, nir_build_load_shader_query_enabled_amd(b));
1476    nir_ssa_def *num_prims_in_wave = NULL;
1477 
1478    /* Calculate the "real" number of emitted primitives from the emitted GS vertices and primitives.
1479     * GS emits points, line strips or triangle strips.
1480     * Real primitives are points, lines or triangles.
1481     */
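   /* A strip of N vertices yields N - (num_vertices_per_primitive - 1) real primitives,
    * so summed over all strips: real_prims = gs_vertices - gs_strips * (verts_per_prim - 1),
    * which is the formula used below. E.g. two triangle strips with 5 and 4 vertices:
    * 9 - 2 * 2 = 5 triangles.
    */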
1482    if (nir_src_is_const(intrin->src[0]) && nir_src_is_const(intrin->src[1])) {
1483       unsigned gs_vtx_cnt = nir_src_as_uint(intrin->src[0]);
1484       unsigned gs_prm_cnt = nir_src_as_uint(intrin->src[1]);
1485       unsigned total_prm_cnt = gs_vtx_cnt - gs_prm_cnt * (s->num_vertices_per_primitive - 1u);
1486       nir_ssa_def *num_threads = nir_bit_count(b, nir_build_ballot(b, 1, s->wave_size, nir_imm_bool(b, true)));
1487       num_prims_in_wave = nir_imul_imm(b, num_threads, total_prm_cnt);
1488    } else {
1489       nir_ssa_def *gs_vtx_cnt = intrin->src[0].ssa;
1490       nir_ssa_def *prm_cnt = intrin->src[1].ssa;
1491       if (s->num_vertices_per_primitive > 1)
1492          prm_cnt = nir_iadd_nuw(b, nir_imul_imm(b, prm_cnt, -1u * (s->num_vertices_per_primitive - 1)), gs_vtx_cnt);
1493       num_prims_in_wave = nir_build_reduce(b, prm_cnt, .reduction_op = nir_op_iadd);
1494    }
1495 
1496    /* Store the query result to GDS using an atomic add. */
1497    nir_if *if_first_lane = nir_push_if(b, nir_build_elect(b, 1));
1498    nir_build_gds_atomic_add_amd(b, 32, num_prims_in_wave, nir_imm_int(b, 0), nir_imm_int(b, 0x100));
1499    nir_pop_if(b, if_first_lane);
1500 
1501    nir_pop_if(b, if_shader_query);
1502 }
1503 
1504 static bool
1505 lower_ngg_gs_store_output(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
1506 {
1507    assert(nir_src_is_const(intrin->src[1]));
1508    b->cursor = nir_before_instr(&intrin->instr);
1509 
1510    unsigned writemask = nir_intrinsic_write_mask(intrin);
1511    unsigned base = nir_intrinsic_base(intrin);
1512    unsigned component_offset = nir_intrinsic_component(intrin);
1513    unsigned base_offset = nir_src_as_uint(intrin->src[1]);
1514    nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
1515 
1516    assert((base + base_offset) < VARYING_SLOT_MAX);
1517 
1518    nir_ssa_def *store_val = intrin->src[0].ssa;
1519 
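   /* Instead of writing to LDS right away, remember each output component in a temporary
    * variable; lower_ngg_gs_emit_vertex_with_counter flushes the saved components of the
    * emitted vertex to LDS.
    */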
1520    for (unsigned comp = 0; comp < 4; ++comp) {
1521       if (!(writemask & (1 << comp)))
1522          continue;
1523       unsigned stream = (io_sem.gs_streams >> (comp * 2)) & 0x3;
1524       if (!(b->shader->info.gs.active_stream_mask & (1 << stream)))
1525          continue;
1526 
1527       /* Small bitsize components consume the same amount of space as 32-bit components,
1528        * but 64-bit ones consume twice as many. (Vulkan spec 15.1.5)
1529        */
1530          unsigned num_consumed_components = DIV_ROUND_UP(store_val->bit_size, 32);
1531       nir_ssa_def *element = nir_channel(b, store_val, comp);
1532       if (num_consumed_components > 1)
1533          element = nir_extract_bits(b, &element, 1, 0, num_consumed_components, 32);
1534 
1535       for (unsigned c = 0; c < num_consumed_components; ++c) {
1536          unsigned component_index =  (comp * num_consumed_components) + c + component_offset;
1537          unsigned base_index = base + base_offset + component_index / 4;
1538          component_index %= 4;
1539 
1540          /* Save output usage info */
1541          gs_output_component_info *info = &s->output_component_info[base_index][component_index];
1542          info->bit_size = MAX2(info->bit_size, MIN2(store_val->bit_size, 32));
1543          info->stream = stream;
1544 
1545          /* Store the current component element */
1546          nir_ssa_def *component_element = element;
1547          if (num_consumed_components > 1)
1548             component_element = nir_channel(b, component_element, c);
1549          if (component_element->bit_size != 32)
1550             component_element = nir_u2u32(b, component_element);
1551 
1552          nir_store_var(b, s->output_vars[base_index][component_index], component_element, 0x1u);
1553       }
1554    }
1555 
1556    nir_instr_remove(&intrin->instr);
1557    return true;
1558 }
1559 
1560 static bool
1561 lower_ngg_gs_emit_vertex_with_counter(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
1562 {
1563    b->cursor = nir_before_instr(&intrin->instr);
1564 
1565    unsigned stream = nir_intrinsic_stream_id(intrin);
1566    if (!(b->shader->info.gs.active_stream_mask & (1 << stream))) {
1567       nir_instr_remove(&intrin->instr);
1568       return true;
1569    }
1570 
1571    nir_ssa_def *gs_emit_vtx_idx = intrin->src[0].ssa;
1572    nir_ssa_def *current_vtx_per_prim = intrin->src[1].ssa;
1573    nir_ssa_def *gs_emit_vtx_addr = ngg_gs_emit_vertex_addr(b, gs_emit_vtx_idx, s);
1574 
1575    for (unsigned slot = 0; slot < VARYING_SLOT_MAX; ++slot) {
1576       unsigned packed_location = util_bitcount64((b->shader->info.outputs_written & BITFIELD64_MASK(slot)));
1577 
1578       for (unsigned comp = 0; comp < 4; ++comp) {
1579          gs_output_component_info *info = &s->output_component_info[slot][comp];
1580          if (info->stream != stream || !info->bit_size)
1581             continue;
1582 
1583          /* Store the output to LDS */
1584          nir_ssa_def *out_val = nir_load_var(b, s->output_vars[slot][comp]);
1585          if (info->bit_size != 32)
1586             out_val = nir_u2u(b, out_val, info->bit_size);
1587 
1588          nir_build_store_shared(b, out_val, gs_emit_vtx_addr, .base = packed_location * 16 + comp * 4, .align_mul = 4, .write_mask = 0x1u);
1589 
1590          /* Clear the variable that holds the output */
1591          nir_store_var(b, s->output_vars[slot][comp], nir_ssa_undef(b, 1, 32), 0x1u);
1592       }
1593    }
1594 
1595    /* Calculate and store per-vertex primitive flags based on vertex counts:
1596     * - bit 0: whether this vertex finishes a primitive (a real primitive, not the strip)
1597     * - bit 1: whether the primitive index is odd (if we are emitting triangle strips, otherwise always 0)
1598     * - bit 2: always 1 (so that we can use it for determining vertex liveness)
1599     */
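   /* For example, a triangle-strip invocation that emits 5 vertices produces the
    * per-vertex flags 0b100, 0b110, 0b101, 0b111, 0b101 (bit 1 simply mirrors the
    * parity of the running vertex count).
    */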
1600 
1601    nir_ssa_def *completes_prim = nir_ige(b, current_vtx_per_prim, nir_imm_int(b, s->num_vertices_per_primitive - 1));
1602    nir_ssa_def *prim_flag = nir_bcsel(b, completes_prim, nir_imm_int(b, 0b101u), nir_imm_int(b, 0b100u));
1603 
1604    if (s->num_vertices_per_primitive == 3) {
1605       nir_ssa_def *odd = nir_iand_imm(b, current_vtx_per_prim, 1);
1606       prim_flag = nir_iadd_nuw(b, prim_flag, nir_ishl(b, odd, nir_imm_int(b, 1)));
1607    }
1608 
1609    nir_build_store_shared(b, nir_u2u8(b, prim_flag), gs_emit_vtx_addr, .base = s->lds_offs_primflags + stream, .align_mul = 4u, .write_mask = 0x1u);
1610    nir_instr_remove(&intrin->instr);
1611    return true;
1612 }
1613 
1614 static bool
1615 lower_ngg_gs_end_primitive_with_counter(nir_builder *b, nir_intrinsic_instr *intrin, UNUSED lower_ngg_gs_state *s)
1616 {
1617    b->cursor = nir_before_instr(&intrin->instr);
1618 
1619    /* These are not needed, we can simply remove them */
1620    nir_instr_remove(&intrin->instr);
1621    return true;
1622 }
1623 
1624 static bool
1625 lower_ngg_gs_set_vertex_and_primitive_count(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
1626 {
1627    b->cursor = nir_before_instr(&intrin->instr);
1628 
1629    unsigned stream = nir_intrinsic_stream_id(intrin);
1630    if (stream > 0 && !(b->shader->info.gs.active_stream_mask & (1 << stream))) {
1631       nir_instr_remove(&intrin->instr);
1632       return true;
1633    }
1634 
1635    s->found_out_vtxcnt[stream] = true;
1636 
1637    /* Clear the primitive flags of non-emitted vertices */
1638    if (!nir_src_is_const(intrin->src[0]) || nir_src_as_uint(intrin->src[0]) < b->shader->info.gs.vertices_out)
1639       ngg_gs_clear_primflags(b, intrin->src[0].ssa, stream, s);
1640 
1641    ngg_gs_shader_query(b, intrin, s);
1642    nir_instr_remove(&intrin->instr);
1643    return true;
1644 }
1645 
1646 static bool
1647 lower_ngg_gs_intrinsic(nir_builder *b, nir_instr *instr, void *state)
1648 {
1649    lower_ngg_gs_state *s = (lower_ngg_gs_state *) state;
1650 
1651    if (instr->type != nir_instr_type_intrinsic)
1652       return false;
1653 
1654    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1655 
1656    if (intrin->intrinsic == nir_intrinsic_store_output)
1657       return lower_ngg_gs_store_output(b, intrin, s);
1658    else if (intrin->intrinsic == nir_intrinsic_emit_vertex_with_counter)
1659       return lower_ngg_gs_emit_vertex_with_counter(b, intrin, s);
1660    else if (intrin->intrinsic == nir_intrinsic_end_primitive_with_counter)
1661       return lower_ngg_gs_end_primitive_with_counter(b, intrin, s);
1662    else if (intrin->intrinsic == nir_intrinsic_set_vertex_and_primitive_count)
1663       return lower_ngg_gs_set_vertex_and_primitive_count(b, intrin, s);
1664 
1665    return false;
1666 }
1667 
1668 static void
1669 lower_ngg_gs_intrinsics(nir_shader *shader, lower_ngg_gs_state *s)
1670 {
1671    nir_shader_instructions_pass(shader, lower_ngg_gs_intrinsic, nir_metadata_none, s);
1672 }
1673 
1674 static void
1675 ngg_gs_export_primitives(nir_builder *b, nir_ssa_def *max_num_out_prims, nir_ssa_def *tid_in_tg,
1676                          nir_ssa_def *exporter_tid_in_tg, nir_ssa_def *primflag_0,
1677                          lower_ngg_gs_state *s)
1678 {
1679    nir_if *if_prim_export_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_prims));
1680 
1681    /* Only bit 0 matters here - set it to 1 when the primitive should be null */
1682    nir_ssa_def *is_null_prim = nir_ixor(b, primflag_0, nir_imm_int(b, -1u));
1683 
1684    nir_ssa_def *vtx_indices[3] = {0};
1685    vtx_indices[s->num_vertices_per_primitive - 1] = exporter_tid_in_tg;
1686    if (s->num_vertices_per_primitive >= 2)
1687       vtx_indices[s->num_vertices_per_primitive - 2] = nir_isub(b, exporter_tid_in_tg, nir_imm_int(b, 1));
1688    if (s->num_vertices_per_primitive == 3)
1689       vtx_indices[s->num_vertices_per_primitive - 3] = nir_isub(b, exporter_tid_in_tg, nir_imm_int(b, 2));
1690 
1691    if (s->num_vertices_per_primitive == 3) {
1692       /* API GS outputs triangle strips, but NGG HW understands triangles.
1693        * We already know the triangles due to how we set the primitive flags, but we need to
1694        * make sure the vertex order is such that front/back face orientation is correct, and the provoking vertex is kept.
1695        */
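      /* E.g. for exporter thread t: an even triangle is exported as (t-2, t-1, t), while
       * for an odd triangle two of the indices are swapped (e.g. (t-2, t, t-1) when the
       * provoking vertex comes first), which keeps the winding consistent.
       */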
1696 
1697       nir_ssa_def *is_odd = nir_ubfe(b, primflag_0, nir_imm_int(b, 1), nir_imm_int(b, 1));
1698       if (!s->provoking_vertex_last) {
1699          vtx_indices[1] = nir_iadd(b, vtx_indices[1], is_odd);
1700          vtx_indices[2] = nir_isub(b, vtx_indices[2], is_odd);
1701       } else {
1702          vtx_indices[0] = nir_iadd(b, vtx_indices[0], is_odd);
1703          vtx_indices[1] = nir_isub(b, vtx_indices[1], is_odd);
1704       }
1705    }
1706 
1707    nir_ssa_def *arg = emit_pack_ngg_prim_exp_arg(b, s->num_vertices_per_primitive, vtx_indices, is_null_prim, false);
1708    nir_build_export_primitive_amd(b, arg);
1709    nir_pop_if(b, if_prim_export_thread);
1710 }
1711 
1712 static void
1713 ngg_gs_export_vertices(nir_builder *b, nir_ssa_def *max_num_out_vtx, nir_ssa_def *tid_in_tg,
1714                        nir_ssa_def *out_vtx_lds_addr, lower_ngg_gs_state *s)
1715 {
1716    nir_if *if_vtx_export_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_vtx));
1717    nir_ssa_def *exported_out_vtx_lds_addr = out_vtx_lds_addr;
1718 
1719    if (!s->output_compile_time_known) {
1720       /* Vertex compaction.
1721        * The current thread will export a vertex that was live in another invocation.
1722        * Load the index of the vertex that the current thread will have to export.
1723        */
1724       nir_ssa_def *exported_vtx_idx = nir_build_load_shared(b, 1, 8, out_vtx_lds_addr, .base = s->lds_offs_primflags + 1, .align_mul = 1u);
1725       exported_out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, nir_u2u32(b, exported_vtx_idx), s);
1726    }
1727 
1728    for (unsigned slot = 0; slot < VARYING_SLOT_MAX; ++slot) {
1729       if (!(b->shader->info.outputs_written & BITFIELD64_BIT(slot)))
1730          continue;
1731 
1732       unsigned packed_location = util_bitcount64((b->shader->info.outputs_written & BITFIELD64_MASK(slot)));
1733       nir_io_semantics io_sem = { .location = slot, .num_slots = 1 };
1734 
1735       for (unsigned comp = 0; comp < 4; ++comp) {
1736          gs_output_component_info *info = &s->output_component_info[slot][comp];
1737          if (info->stream != 0 || info->bit_size == 0)
1738             continue;
1739 
1740          nir_ssa_def *load = nir_build_load_shared(b, 1, info->bit_size, exported_out_vtx_lds_addr, .base = packed_location * 16u + comp * 4u, .align_mul = 4u);
1741          nir_build_store_output(b, load, nir_imm_int(b, 0), .write_mask = 0x1u, .base = slot, .component = comp, .io_semantics = io_sem);
1742       }
1743    }
1744 
1745    nir_build_export_vertex_amd(b);
1746    nir_pop_if(b, if_vtx_export_thread);
1747 }
1748 
1749 static void
1750 ngg_gs_setup_vertex_compaction(nir_builder *b, nir_ssa_def *vertex_live, nir_ssa_def *tid_in_tg,
1751                                nir_ssa_def *exporter_tid_in_tg, lower_ngg_gs_state *s)
1752 {
1753    assert(vertex_live->bit_size == 1);
1754    nir_if *if_vertex_live = nir_push_if(b, vertex_live);
1755    {
1756       /* Setup the vertex compaction.
1757        * Save the current thread's id for the thread which will export the current vertex.
1758        * We reuse stream 1 of the primitive flag of the other thread's vertex for storing this.
1759        */
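      /* E.g. if only threads 0, 2 and 5 have a live vertex, their exporter_tid_in_tg
       * values are 0, 1 and 2; thread 5 writes "5" into the primflag byte of vertex
       * slot 2, so exporting thread 2 later knows to fetch its outputs from slot 5.
       */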
1760 
1761       nir_ssa_def *exporter_lds_addr = ngg_gs_out_vertex_addr(b, exporter_tid_in_tg, s);
1762       nir_ssa_def *tid_in_tg_u8 = nir_u2u8(b, tid_in_tg);
1763       nir_build_store_shared(b, tid_in_tg_u8, exporter_lds_addr, .base = s->lds_offs_primflags + 1, .align_mul = 1u, .write_mask = 0x1u);
1764    }
1765    nir_pop_if(b, if_vertex_live);
1766 }
1767 
1768 static nir_ssa_def *
1769 ngg_gs_load_out_vtx_primflag_0(nir_builder *b, nir_ssa_def *tid_in_tg, nir_ssa_def *vtx_lds_addr,
1770                                nir_ssa_def *max_num_out_vtx, lower_ngg_gs_state *s)
1771 {
1772    nir_ssa_def *zero = nir_imm_int(b, 0);
1773 
1774    nir_if *if_outvtx_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_vtx));
1775    nir_ssa_def *primflag_0 = nir_build_load_shared(b, 1, 8, vtx_lds_addr, .base = s->lds_offs_primflags, .align_mul = 4u);
1776    primflag_0 = nir_u2u32(b, primflag_0);
1777    nir_pop_if(b, if_outvtx_thread);
1778 
1779    return nir_if_phi(b, primflag_0, zero);
1780 }
1781 
1782 static void
1783 ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
1784 {
1785    nir_ssa_def *tid_in_tg = nir_build_load_local_invocation_index(b);
1786    nir_ssa_def *max_vtxcnt = nir_build_load_workgroup_num_input_vertices_amd(b);
1787    nir_ssa_def *max_prmcnt = max_vtxcnt; /* They are currently practically the same; both RADV and RadeonSI do this. */
1788    nir_ssa_def *out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, tid_in_tg, s);
1789 
1790    if (s->output_compile_time_known) {
1791       /* When the output is compile-time known, the GS writes all possible vertices and primitives it can.
1792        * The gs_alloc_req needs to happen on one wave only, otherwise the HW hangs.
1793        */
1794       nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_build_load_subgroup_id(b), nir_imm_zero(b, 1, 32)));
1795       nir_build_alloc_vertices_and_primitives_amd(b, max_vtxcnt, max_prmcnt);
1796       nir_pop_if(b, if_wave_0);
1797    }
1798 
1799    /* Workgroup barrier: wait for all GS threads to finish */
1800    nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
1801                          .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
1802 
1803    nir_ssa_def *out_vtx_primflag_0 = ngg_gs_load_out_vtx_primflag_0(b, tid_in_tg, out_vtx_lds_addr, max_vtxcnt, s);
1804 
1805    if (s->output_compile_time_known) {
1806       ngg_gs_export_primitives(b, max_vtxcnt, tid_in_tg, tid_in_tg, out_vtx_primflag_0, s);
1807       ngg_gs_export_vertices(b, max_vtxcnt, tid_in_tg, out_vtx_lds_addr, s);
1808       return;
1809    }
1810 
1811    /* When the output vertex count is not known at compile time:
1812     * There may be gaps between invocations that have live vertices, but NGG hardware
1813     * requires that the invocations that export vertices are packed (i.e. compact).
1814     * To ensure this, we need to repack invocations that have a live vertex.
1815     */
1816    nir_ssa_def *vertex_live = nir_ine(b, out_vtx_primflag_0, nir_imm_zero(b, 1, out_vtx_primflag_0->bit_size));
1817    wg_repack_result rep = repack_invocations_in_workgroup(b, vertex_live, s->lds_addr_gs_scratch, s->max_num_waves, s->wave_size);
1818 
1819    nir_ssa_def *workgroup_num_vertices = rep.num_repacked_invocations;
1820    nir_ssa_def *exporter_tid_in_tg = rep.repacked_invocation_index;
1821 
1822    /* When the workgroup emits 0 total vertices, we also must export 0 primitives (otherwise the HW can hang). */
1823    nir_ssa_def *any_output = nir_ine(b, workgroup_num_vertices, nir_imm_int(b, 0));
1824    max_prmcnt = nir_bcsel(b, any_output, max_prmcnt, nir_imm_int(b, 0));
1825 
1826    /* Allocate export space. We currently don't compact primitives, just use the maximum number. */
1827    nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_build_load_subgroup_id(b), nir_imm_zero(b, 1, 32)));
1828    nir_build_alloc_vertices_and_primitives_amd(b, workgroup_num_vertices, max_prmcnt);
1829    nir_pop_if(b, if_wave_0);
1830 
1831    /* Vertex compaction. This makes sure there are no gaps between threads that export vertices. */
1832    ngg_gs_setup_vertex_compaction(b, vertex_live, tid_in_tg, exporter_tid_in_tg, s);
1833 
1834    /* Workgroup barrier: wait for all LDS stores to finish. */
1835    nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
1836                         .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
1837 
1838    ngg_gs_export_primitives(b, max_prmcnt, tid_in_tg, exporter_tid_in_tg, out_vtx_primflag_0, s);
1839    ngg_gs_export_vertices(b, workgroup_num_vertices, tid_in_tg, out_vtx_lds_addr, s);
1840 }
1841 
1842 void
1843 ac_nir_lower_ngg_gs(nir_shader *shader,
1844                     unsigned wave_size,
1845                     unsigned max_workgroup_size,
1846                     unsigned esgs_ring_lds_bytes,
1847                     unsigned gs_out_vtx_bytes,
1848                     unsigned gs_total_out_vtx_bytes,
1849                     bool provoking_vertex_last)
1850 {
1851    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1852    assert(impl);
1853 
1854    lower_ngg_gs_state state = {
1855       .max_num_waves = DIV_ROUND_UP(max_workgroup_size, wave_size),
1856       .wave_size = wave_size,
1857       .lds_addr_gs_out_vtx = esgs_ring_lds_bytes,
1858       .lds_addr_gs_scratch = ALIGN(esgs_ring_lds_bytes + gs_total_out_vtx_bytes, 8u /* for the repacking code */),
1859       .lds_offs_primflags = gs_out_vtx_bytes,
1860       .lds_bytes_per_gs_out_vertex = gs_out_vtx_bytes + 4u,
1861       .provoking_vertex_last = provoking_vertex_last,
1862    };
1863 
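   /* Scratch space for repack_invocations_in_workgroup: roughly one byte per wave
    * (holding the per-wave surviving-invocation counts), rounded up to a whole dword.
    */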
1864    unsigned lds_scratch_bytes = DIV_ROUND_UP(state.max_num_waves, 4u) * 4u;
1865    unsigned total_lds_bytes = state.lds_addr_gs_scratch + lds_scratch_bytes;
1866    shader->info.shared_size = total_lds_bytes;
1867 
1868    nir_gs_count_vertices_and_primitives(shader, state.const_out_vtxcnt, state.const_out_prmcnt, 4u);
1869    state.output_compile_time_known = state.const_out_vtxcnt[0] == shader->info.gs.vertices_out &&
1870                                      state.const_out_prmcnt[0] != -1;
1871 
1872    if (!state.output_compile_time_known)
1873       state.current_clear_primflag_idx_var = nir_local_variable_create(impl, glsl_uint_type(), "current_clear_primflag_idx");
1874 
1875    if (shader->info.gs.output_primitive == GL_POINTS)
1876       state.num_vertices_per_primitive = 1;
1877    else if (shader->info.gs.output_primitive == GL_LINE_STRIP)
1878       state.num_vertices_per_primitive = 2;
1879    else if (shader->info.gs.output_primitive == GL_TRIANGLE_STRIP)
1880       state.num_vertices_per_primitive = 3;
1881    else
1882       unreachable("Invalid GS output primitive.");
1883 
1884    /* Extract the full control flow. It is going to be wrapped in an if statement. */
1885    nir_cf_list extracted;
1886    nir_cf_extract(&extracted, nir_before_cf_list(&impl->body), nir_after_cf_list(&impl->body));
1887 
1888    nir_builder builder;
1889    nir_builder *b = &builder; /* This is to avoid the & */
1890    nir_builder_init(b, impl);
1891    b->cursor = nir_before_cf_list(&impl->body);
1892 
1893    /* Workgroup barrier: wait for ES threads */
1894    nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
1895                          .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
1896 
1897    /* Wrap the GS control flow. */
1898    nir_if *if_gs_thread = nir_push_if(b, nir_build_has_input_primitive_amd(b));
1899 
1900    /* Create and initialize output variables */
1901    for (unsigned slot = 0; slot < VARYING_SLOT_MAX; ++slot) {
1902       for (unsigned comp = 0; comp < 4; ++comp) {
1903          state.output_vars[slot][comp] = nir_local_variable_create(impl, glsl_uint_type(), "output");
1904       }
1905    }
1906 
1907    nir_cf_reinsert(&extracted, b->cursor);
1908    b->cursor = nir_after_cf_list(&if_gs_thread->then_list);
1909    nir_pop_if(b, if_gs_thread);
1910 
1911    /* Lower the GS intrinsics */
1912    lower_ngg_gs_intrinsics(shader, &state);
1913    b->cursor = nir_after_cf_list(&impl->body);
1914 
1915    if (!state.found_out_vtxcnt[0]) {
1916       fprintf(stderr, "Could not find set_vertex_and_primitive_count for stream 0. This would hang your GPU.\n");
1917       abort();
1918    }
1919 
1920    /* Emit the finale sequence */
1921    ngg_gs_finale(b, &state);
1922    nir_validate_shader(shader, "after emitting NGG GS");
1923 
1924    /* Cleanup */
1925    nir_lower_vars_to_ssa(shader);
1926    nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
1927    nir_metadata_preserve(impl, nir_metadata_none);
1928 }
1929