1 /*
2 * Copyright © 2021 Valve Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 *
23 */
24
25 #include "ac_nir.h"
26 #include "nir_builder.h"
27 #include "u_math.h"
28 #include "u_vector.h"
29
/* Flags kept in nir_instr::pass_flags while the shader is analyzed before culling.
 * They record which kind of output store (transitively) consumes an instruction.
 */
enum {
   /* The instruction feeds a store to VARYING_SLOT_POS. */
   nggc_passflag_used_by_pos = 1,
   /* The instruction feeds a store to any other output. */
   nggc_passflag_used_by_other = 2,
   /* Both of the above flags combined. */
   nggc_passflag_used_by_both = nggc_passflag_used_by_pos | nggc_passflag_used_by_other,
};
35
/* Pairs an SSA definition with a variable.
 * NOTE(review): presumably used to pass reusable SSA values between shader parts
 * through variables; the code that uses it is outside the visible part of the file.
 */
typedef struct
{
   nir_ssa_def *ssa;
   nir_variable *var;
} saved_uniform;
41
/* State shared by the helpers that lower NGG shaders which have no geometry shader. */
typedef struct
{
   /* Variable holding the position output value (see compact_vertices_after_culling). */
   nir_variable *position_value_var;
   /* Variable holding the packed primitive export argument. */
   nir_variable *prim_exp_arg_var;
   /* Boolean variables that are loaded as branch conditions for ES / GS threads. */
   nir_variable *es_accepted_var;
   nir_variable *gs_accepted_var;

   struct u_vector saved_uniforms;

   /* When set, the primitive export argument is loaded pre-packed instead of being built. */
   bool passthrough;
   /* When set in a vertex shader, GS threads store the primitive ID to LDS before the export. */
   bool export_prim_id;
   bool early_prim_export;
   /* When set, the packed primitive export argument starts from the initial edge flags. */
   bool use_edgeflags;
   unsigned wave_size;
   unsigned max_num_waves;
   unsigned num_vertices_per_primitives;
   /* Index of the primitive's vertex whose LDS slot receives the primitive ID. */
   unsigned provoking_vtx_idx;
   unsigned max_es_num_vertices;
   unsigned total_lds_bytes;

   /* Bitmasks of input locations, filled in by analyze_shader_before_culling_walk. */
   uint64_t inputs_needed_by_pos;
   uint64_t inputs_needed_by_others;
   /* Bitmask indexed by (input base - VERT_ATTRIB_GENERIC0), see cleanup_culling_shader_after_dce. */
   uint32_t instance_rate_inputs;

   /* LDS stores of the repacked arguments, remembered so that unused ones can be removed. */
   nir_instr *compact_arg_stores[4];
   /* Intrinsic whose sources are looked up per argument index in remove_compacted_arg. */
   nir_intrinsic_instr *overwrite_args;
} lower_ngg_nogs_state;
69
/* Information about a single component of a GS output, packed into one byte. */
typedef struct
{
   /* bitsize of this component (max 32), or 0 if it's never written at all */
   uint8_t bit_size : 6;
   /* output stream index */
   uint8_t stream : 2;
} gs_output_component_info;
77
/* State for lowering NGG geometry shaders.
 * NOTE(review): none of these fields are used in the visible part of the file;
 * the arrays of size 4 presumably correspond to the 4 streams addressable by the
 * 2-bit stream index of gs_output_component_info - confirm against the users.
 */
typedef struct
{
   nir_variable *output_vars[VARYING_SLOT_MAX][4];
   nir_variable *current_clear_primflag_idx_var;
   int const_out_vtxcnt[4];
   int const_out_prmcnt[4];
   unsigned wave_size;
   unsigned max_num_waves;
   unsigned num_vertices_per_primitive;
   unsigned lds_addr_gs_out_vtx;
   unsigned lds_addr_gs_scratch;
   unsigned lds_bytes_per_gs_out_vertex;
   unsigned lds_offs_primflags;
   bool found_out_vtxcnt[4];
   bool output_compile_time_known;
   bool provoking_vertex_last;
   gs_output_component_info output_component_info[VARYING_SLOT_MAX][4];
} lower_ngg_gs_state;
96
/* Callback state of remove_culling_shader_output. */
typedef struct {
   /* Variable that receives the value of removed position output stores. */
   nir_variable *pre_cull_position_value_var;
} remove_culling_shader_outputs_state;
100
/* Callback state of remove_extra_pos_output. */
typedef struct {
   /* Variable whose channels are loaded in place of the removed position store's value. */
   nir_variable *pos_value_replacement;
} remove_extra_position_output_state;
104
/* Per-vertex LDS layout of culling shaders.
 * All values are byte offsets relative to the start of a vertex's LDS space.
 */
enum {
   /* Position of the ES vertex (at the beginning for alignment reasons) */
   lds_es_pos_x = 0,
   lds_es_pos_y = 4,
   lds_es_pos_z = 8,
   lds_es_pos_w = 12,

   /* 1 when the vertex is accepted, 0 if it should be culled */
   lds_es_vertex_accepted = 16,
   /* ID of the thread which will export the current thread's vertex (accessed as 1 byte) */
   lds_es_exporter_tid = 17,

   /* Repacked arguments - also listed separately for VS and TES */
   lds_es_arg_0 = 20,

   /* VS arguments which need to be repacked */
   lds_es_vs_vertex_id = 20,
   lds_es_vs_instance_id = 24,

   /* TES arguments which need to be repacked */
   lds_es_tes_u = 20,
   lds_es_tes_v = 24,
   lds_es_tes_rel_patch_id = 28,
   lds_es_tes_patch_id = 32,
};
131
/* Return value of repack_invocations_in_workgroup. */
typedef struct {
   /* Number of surviving invocations (of the wave when max_num_waves is 1, else of the workgroup). */
   nir_ssa_def *num_repacked_invocations;
   /* Index of the current invocation among the surviving invocations. */
   nir_ssa_def *repacked_invocation_index;
} wg_repack_result;
136
137 /**
138 * Computes a horizontal sum of 8-bit packed values loaded from LDS.
139 *
140 * Each lane N will sum packed bytes 0 to N-1.
141 * We only care about the results from up to wave_id+1 lanes.
142 * (Other lanes are not deactivated but their calculation is not used.)
143 */
144 static nir_ssa_def *
summarize_repack(nir_builder * b,nir_ssa_def * packed_counts,unsigned num_lds_dwords)145 summarize_repack(nir_builder *b, nir_ssa_def *packed_counts, unsigned num_lds_dwords)
146 {
147 /* We'll use shift to filter out the bytes not needed by the current lane.
148 *
149 * Need to shift by: num_lds_dwords * 4 - lane_id (in bytes).
150 * However, two shifts are needed because one can't go all the way,
151 * so the shift amount is half that (and in bits).
152 *
153 * When v_dot4_u32_u8 is available, we right-shift a series of 0x01 bytes.
154 * This will yield 0x01 at wanted byte positions and 0x00 at unwanted positions,
155 * therefore v_dot can get rid of the unneeded values.
156 * This sequence is preferable because it better hides the latency of the LDS.
157 *
158 * If the v_dot instruction can't be used, we left-shift the packed bytes.
159 * This will shift out the unneeded bytes and shift in zeroes instead,
160 * then we sum them using v_sad_u8.
161 */
162
163 nir_ssa_def *lane_id = nir_load_subgroup_invocation(b);
164 nir_ssa_def *shift = nir_iadd_imm_nuw(b, nir_imul_imm(b, lane_id, -4u), num_lds_dwords * 16);
165 bool use_dot = b->shader->options->has_dot_4x8;
166
167 if (num_lds_dwords == 1) {
168 nir_ssa_def *dot_op = !use_dot ? NULL : nir_ushr(b, nir_ushr(b, nir_imm_int(b, 0x01010101), shift), shift);
169
170 /* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
171 nir_ssa_def *packed = nir_build_lane_permute_16_amd(b, packed_counts, nir_imm_int(b, 0), nir_imm_int(b, 0));
172
173 /* Horizontally add the packed bytes. */
174 if (use_dot) {
175 return nir_udot_4x8_uadd(b, packed, dot_op, nir_imm_int(b, 0));
176 } else {
177 nir_ssa_def *sad_op = nir_ishl(b, nir_ishl(b, packed, shift), shift);
178 return nir_sad_u8x4(b, sad_op, nir_imm_int(b, 0), nir_imm_int(b, 0));
179 }
180 } else if (num_lds_dwords == 2) {
181 nir_ssa_def *dot_op = !use_dot ? NULL : nir_ushr(b, nir_ushr(b, nir_imm_int64(b, 0x0101010101010101), shift), shift);
182
183 /* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
184 nir_ssa_def *packed_dw0 = nir_build_lane_permute_16_amd(b, nir_unpack_64_2x32_split_x(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
185 nir_ssa_def *packed_dw1 = nir_build_lane_permute_16_amd(b, nir_unpack_64_2x32_split_y(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
186
187 /* Horizontally add the packed bytes. */
188 if (use_dot) {
189 nir_ssa_def *sum = nir_udot_4x8_uadd(b, packed_dw0, nir_unpack_64_2x32_split_x(b, dot_op), nir_imm_int(b, 0));
190 return nir_udot_4x8_uadd(b, packed_dw1, nir_unpack_64_2x32_split_y(b, dot_op), sum);
191 } else {
192 nir_ssa_def *sad_op = nir_ishl(b, nir_ishl(b, nir_pack_64_2x32_split(b, packed_dw0, packed_dw1), shift), shift);
193 nir_ssa_def *sum = nir_sad_u8x4(b, nir_unpack_64_2x32_split_x(b, sad_op), nir_imm_int(b, 0), nir_imm_int(b, 0));
194 return nir_sad_u8x4(b, nir_unpack_64_2x32_split_y(b, sad_op), nir_imm_int(b, 0), sum);
195 }
196 } else {
197 unreachable("Unimplemented NGG wave count");
198 }
199 }
200
201 /**
202 * Repacks invocations in the current workgroup to eliminate gaps between them.
203 *
204 * Uses 1 dword of LDS per 4 waves (1 byte of LDS per wave).
205 * Assumes that all invocations in the workgroup are active (exec = -1).
206 */
207 static wg_repack_result
repack_invocations_in_workgroup(nir_builder * b,nir_ssa_def * input_bool,unsigned lds_addr_base,unsigned max_num_waves,unsigned wave_size)208 repack_invocations_in_workgroup(nir_builder *b, nir_ssa_def *input_bool,
209 unsigned lds_addr_base, unsigned max_num_waves,
210 unsigned wave_size)
211 {
212 /* Input boolean: 1 if the current invocation should survive the repack. */
213 assert(input_bool->bit_size == 1);
214
215 /* STEP 1. Count surviving invocations in the current wave.
216 *
217 * Implemented by a scalar instruction that simply counts the number of bits set in a 32/64-bit mask.
218 */
219
220 nir_ssa_def *input_mask = nir_build_ballot(b, 1, wave_size, input_bool);
221 nir_ssa_def *surviving_invocations_in_current_wave = nir_bit_count(b, input_mask);
222
223 /* If we know at compile time that the workgroup has only 1 wave, no further steps are necessary. */
224 if (max_num_waves == 1) {
225 wg_repack_result r = {
226 .num_repacked_invocations = surviving_invocations_in_current_wave,
227 .repacked_invocation_index = nir_build_mbcnt_amd(b, input_mask, nir_imm_int(b, 0)),
228 };
229 return r;
230 }
231
232 /* STEP 2. Waves tell each other their number of surviving invocations.
233 *
234 * Each wave activates only its first lane (exec = 1), which stores the number of surviving
235 * invocations in that wave into the LDS, then reads the numbers from every wave.
236 *
237 * The workgroup size of NGG shaders is at most 256, which means
238 * the maximum number of waves is 4 in Wave64 mode and 8 in Wave32 mode.
239 * Each wave writes 1 byte, so it's up to 8 bytes, so at most 2 dwords are necessary.
240 */
241
242 const unsigned num_lds_dwords = DIV_ROUND_UP(max_num_waves, 4);
243 assert(num_lds_dwords <= 2);
244
245 nir_ssa_def *wave_id = nir_build_load_subgroup_id(b);
246 nir_ssa_def *dont_care = nir_ssa_undef(b, 1, num_lds_dwords * 32);
247 nir_if *if_first_lane = nir_push_if(b, nir_build_elect(b, 1));
248
249 nir_build_store_shared(b, nir_u2u8(b, surviving_invocations_in_current_wave), wave_id, .base = lds_addr_base, .align_mul = 1u, .write_mask = 0x1u);
250
251 nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
252 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
253
254 nir_ssa_def *packed_counts = nir_build_load_shared(b, 1, num_lds_dwords * 32, nir_imm_int(b, 0), .base = lds_addr_base, .align_mul = 8u);
255
256 nir_pop_if(b, if_first_lane);
257
258 packed_counts = nir_if_phi(b, packed_counts, dont_care);
259
260 /* STEP 3. Compute the repacked invocation index and the total number of surviving invocations.
261 *
262 * By now, every wave knows the number of surviving invocations in all waves.
263 * Each number is 1 byte, and they are packed into up to 2 dwords.
264 *
265 * Each lane N will sum the number of surviving invocations from waves 0 to N-1.
266 * If the workgroup has M waves, then each wave will use only its first M+1 lanes for this.
267 * (Other lanes are not deactivated but their calculation is not used.)
268 *
269 * - We read the sum from the lane whose id is the current wave's id.
270 * Add the masked bitcount to this, and we get the repacked invocation index.
271 * - We read the sum from the lane whose id is the number of waves in the workgroup.
272 * This is the total number of surviving invocations in the workgroup.
273 */
274
275 nir_ssa_def *num_waves = nir_build_load_num_subgroups(b);
276 nir_ssa_def *sum = summarize_repack(b, packed_counts, num_lds_dwords);
277
278 nir_ssa_def *wg_repacked_index_base = nir_build_read_invocation(b, sum, wave_id);
279 nir_ssa_def *wg_num_repacked_invocations = nir_build_read_invocation(b, sum, num_waves);
280 nir_ssa_def *wg_repacked_index = nir_build_mbcnt_amd(b, input_mask, wg_repacked_index_base);
281
282 wg_repack_result r = {
283 .num_repacked_invocations = wg_num_repacked_invocations,
284 .repacked_invocation_index = wg_repacked_index,
285 };
286
287 return r;
288 }
289
290 static nir_ssa_def *
pervertex_lds_addr(nir_builder * b,nir_ssa_def * vertex_idx,unsigned per_vtx_bytes)291 pervertex_lds_addr(nir_builder *b, nir_ssa_def *vertex_idx, unsigned per_vtx_bytes)
292 {
293 return nir_imul_imm(b, vertex_idx, per_vtx_bytes);
294 }
295
296 static nir_ssa_def *
emit_pack_ngg_prim_exp_arg(nir_builder * b,unsigned num_vertices_per_primitives,nir_ssa_def * vertex_indices[3],nir_ssa_def * is_null_prim,bool use_edgeflags)297 emit_pack_ngg_prim_exp_arg(nir_builder *b, unsigned num_vertices_per_primitives,
298 nir_ssa_def *vertex_indices[3], nir_ssa_def *is_null_prim,
299 bool use_edgeflags)
300 {
301 nir_ssa_def *arg = use_edgeflags
302 ? nir_build_load_initial_edgeflags_amd(b)
303 : nir_imm_int(b, 0);
304
305 for (unsigned i = 0; i < num_vertices_per_primitives; ++i) {
306 assert(vertex_indices[i]);
307 arg = nir_ior(b, arg, nir_ishl(b, vertex_indices[i], nir_imm_int(b, 10u * i)));
308 }
309
310 if (is_null_prim) {
311 if (is_null_prim->bit_size == 1)
312 is_null_prim = nir_b2i32(b, is_null_prim);
313 assert(is_null_prim->bit_size == 32);
314 arg = nir_ior(b, arg, nir_ishl(b, is_null_prim, nir_imm_int(b, 31u)));
315 }
316
317 return arg;
318 }
319
320 static nir_ssa_def *
ngg_input_primitive_vertex_index(nir_builder * b,unsigned vertex)321 ngg_input_primitive_vertex_index(nir_builder *b, unsigned vertex)
322 {
323 return nir_ubfe(b, nir_build_load_gs_vertex_offset_amd(b, .base = vertex / 2u),
324 nir_imm_int(b, (vertex & 1u) * 16u), nir_imm_int(b, 16u));
325 }
326
327 static nir_ssa_def *
emit_ngg_nogs_prim_exp_arg(nir_builder * b,lower_ngg_nogs_state * st)328 emit_ngg_nogs_prim_exp_arg(nir_builder *b, lower_ngg_nogs_state *st)
329 {
330 if (st->passthrough) {
331 assert(!st->export_prim_id || b->shader->info.stage != MESA_SHADER_VERTEX);
332 return nir_build_load_packed_passthrough_primitive_amd(b);
333 } else {
334 nir_ssa_def *vtx_idx[3] = {0};
335
336 vtx_idx[0] = ngg_input_primitive_vertex_index(b, 0);
337 vtx_idx[1] = st->num_vertices_per_primitives >= 2
338 ? ngg_input_primitive_vertex_index(b, 1)
339 : nir_imm_zero(b, 1, 32);
340 vtx_idx[2] = st->num_vertices_per_primitives >= 3
341 ? ngg_input_primitive_vertex_index(b, 2)
342 : nir_imm_zero(b, 1, 32);
343
344 return emit_pack_ngg_prim_exp_arg(b, st->num_vertices_per_primitives, vtx_idx, NULL, st->use_edgeflags);
345 }
346 }
347
348 static void
emit_ngg_nogs_prim_export(nir_builder * b,lower_ngg_nogs_state * st,nir_ssa_def * arg)349 emit_ngg_nogs_prim_export(nir_builder *b, lower_ngg_nogs_state *st, nir_ssa_def *arg)
350 {
351 nir_ssa_def *gs_thread = st->gs_accepted_var
352 ? nir_load_var(b, st->gs_accepted_var)
353 : nir_build_has_input_primitive_amd(b);
354
355 nir_if *if_gs_thread = nir_push_if(b, gs_thread);
356 {
357 if (!arg)
358 arg = emit_ngg_nogs_prim_exp_arg(b, st);
359
360 if (st->export_prim_id && b->shader->info.stage == MESA_SHADER_VERTEX) {
361 /* Copy Primitive IDs from GS threads to the LDS address corresponding to the ES thread of the provoking vertex. */
362 nir_ssa_def *prim_id = nir_build_load_primitive_id(b);
363 nir_ssa_def *provoking_vtx_idx = ngg_input_primitive_vertex_index(b, st->provoking_vtx_idx);
364 nir_ssa_def *addr = pervertex_lds_addr(b, provoking_vtx_idx, 4u);
365
366 nir_build_store_shared(b, prim_id, addr, .write_mask = 1u, .align_mul = 4u);
367 }
368
369 nir_build_export_primitive_amd(b, arg);
370 }
371 nir_pop_if(b, if_gs_thread);
372 }
373
374 static void
emit_store_ngg_nogs_es_primitive_id(nir_builder * b)375 emit_store_ngg_nogs_es_primitive_id(nir_builder *b)
376 {
377 nir_ssa_def *prim_id = NULL;
378
379 if (b->shader->info.stage == MESA_SHADER_VERTEX) {
380 /* Workgroup barrier - wait for GS threads to store primitive ID in LDS. */
381 nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP, .memory_scope = NIR_SCOPE_WORKGROUP,
382 .memory_semantics = NIR_MEMORY_ACQ_REL, .memory_modes = nir_var_mem_shared);
383
384 /* LDS address where the primitive ID is stored */
385 nir_ssa_def *thread_id_in_threadgroup = nir_build_load_local_invocation_index(b);
386 nir_ssa_def *addr = pervertex_lds_addr(b, thread_id_in_threadgroup, 4u);
387
388 /* Load primitive ID from LDS */
389 prim_id = nir_build_load_shared(b, 1, 32, addr, .align_mul = 4u);
390 } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
391 /* Just use tess eval primitive ID, which is the same as the patch ID. */
392 prim_id = nir_build_load_primitive_id(b);
393 }
394
395 nir_io_semantics io_sem = {
396 .location = VARYING_SLOT_PRIMITIVE_ID,
397 .num_slots = 1,
398 };
399
400 nir_build_store_output(b, prim_id, nir_imm_zero(b, 1, 32),
401 .base = io_sem.location,
402 .write_mask = 1u, .src_type = nir_type_uint32, .io_semantics = io_sem);
403 }
404
405 static bool
remove_culling_shader_output(nir_builder * b,nir_instr * instr,void * state)406 remove_culling_shader_output(nir_builder *b, nir_instr *instr, void *state)
407 {
408 remove_culling_shader_outputs_state *s = (remove_culling_shader_outputs_state *) state;
409
410 if (instr->type != nir_instr_type_intrinsic)
411 return false;
412
413 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
414
415 /* These are not allowed in VS / TES */
416 assert(intrin->intrinsic != nir_intrinsic_store_per_vertex_output &&
417 intrin->intrinsic != nir_intrinsic_load_per_vertex_input);
418
419 /* We are only interested in output stores now */
420 if (intrin->intrinsic != nir_intrinsic_store_output)
421 return false;
422
423 b->cursor = nir_before_instr(instr);
424
425 /* Position output - store the value to a variable, remove output store */
426 nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
427 if (io_sem.location == VARYING_SLOT_POS) {
428 /* TODO: check if it's indirect, etc? */
429 unsigned writemask = nir_intrinsic_write_mask(intrin);
430 nir_ssa_def *store_val = intrin->src[0].ssa;
431 nir_store_var(b, s->pre_cull_position_value_var, store_val, writemask);
432 }
433
434 /* Remove all output stores */
435 nir_instr_remove(instr);
436 return true;
437 }
438
439 static void
remove_culling_shader_outputs(nir_shader * culling_shader,lower_ngg_nogs_state * nogs_state,nir_variable * pre_cull_position_value_var)440 remove_culling_shader_outputs(nir_shader *culling_shader, lower_ngg_nogs_state *nogs_state, nir_variable *pre_cull_position_value_var)
441 {
442 remove_culling_shader_outputs_state s = {
443 .pre_cull_position_value_var = pre_cull_position_value_var,
444 };
445
446 nir_shader_instructions_pass(culling_shader, remove_culling_shader_output,
447 nir_metadata_block_index | nir_metadata_dominance, &s);
448
449 /* Remove dead code resulting from the deleted outputs. */
450 bool progress;
451 do {
452 progress = false;
453 NIR_PASS(progress, culling_shader, nir_opt_dead_write_vars);
454 NIR_PASS(progress, culling_shader, nir_opt_dce);
455 NIR_PASS(progress, culling_shader, nir_opt_dead_cf);
456 } while (progress);
457 }
458
459 static void
rewrite_uses_to_var(nir_builder * b,nir_ssa_def * old_def,nir_variable * replacement_var,unsigned replacement_var_channel)460 rewrite_uses_to_var(nir_builder *b, nir_ssa_def *old_def, nir_variable *replacement_var, unsigned replacement_var_channel)
461 {
462 if (old_def->parent_instr->type == nir_instr_type_load_const)
463 return;
464
465 b->cursor = nir_after_instr(old_def->parent_instr);
466 if (b->cursor.instr->type == nir_instr_type_phi)
467 b->cursor = nir_after_phis(old_def->parent_instr->block);
468
469 nir_ssa_def *pos_val_rep = nir_load_var(b, replacement_var);
470 nir_ssa_def *replacement = nir_channel(b, pos_val_rep, replacement_var_channel);
471
472 if (old_def->num_components > 1) {
473 /* old_def uses a swizzled vector component.
474 * There is no way to replace the uses of just a single vector component,
475 * so instead create a new vector and replace all uses of the old vector.
476 */
477 nir_ssa_def *old_def_elements[NIR_MAX_VEC_COMPONENTS] = {0};
478 for (unsigned j = 0; j < old_def->num_components; ++j)
479 old_def_elements[j] = nir_channel(b, old_def, j);
480 replacement = nir_vec(b, old_def_elements, old_def->num_components);
481 }
482
483 nir_ssa_def_rewrite_uses_after(old_def, replacement, replacement->parent_instr);
484 }
485
486 static bool
remove_extra_pos_output(nir_builder * b,nir_instr * instr,void * state)487 remove_extra_pos_output(nir_builder *b, nir_instr *instr, void *state)
488 {
489 remove_extra_position_output_state *s = (remove_extra_position_output_state *) state;
490
491 if (instr->type != nir_instr_type_intrinsic)
492 return false;
493
494 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
495
496 /* These are not allowed in VS / TES */
497 assert(intrin->intrinsic != nir_intrinsic_store_per_vertex_output &&
498 intrin->intrinsic != nir_intrinsic_load_per_vertex_input);
499
500 /* We are only interested in output stores now */
501 if (intrin->intrinsic != nir_intrinsic_store_output)
502 return false;
503
504 nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
505 if (io_sem.location != VARYING_SLOT_POS)
506 return false;
507
508 b->cursor = nir_before_instr(instr);
509
510 /* In case other outputs use what we calculated for pos,
511 * try to avoid calculating it again by rewriting the usages
512 * of the store components here.
513 */
514 nir_ssa_def *store_val = intrin->src[0].ssa;
515 unsigned store_pos_component = nir_intrinsic_component(intrin);
516
517 nir_instr_remove(instr);
518
519 if (store_val->parent_instr->type == nir_instr_type_alu) {
520 nir_alu_instr *alu = nir_instr_as_alu(store_val->parent_instr);
521 if (nir_op_is_vec(alu->op)) {
522 /* Output store uses a vector, we can easily rewrite uses of each vector element. */
523
524 unsigned num_vec_src = 0;
525 if (alu->op == nir_op_mov)
526 num_vec_src = 1;
527 else if (alu->op == nir_op_vec2)
528 num_vec_src = 2;
529 else if (alu->op == nir_op_vec3)
530 num_vec_src = 3;
531 else if (alu->op == nir_op_vec4)
532 num_vec_src = 4;
533 assert(num_vec_src);
534
535 /* Remember the current components whose uses we wish to replace.
536 * This is needed because rewriting one source can affect the others too.
537 */
538 nir_ssa_def *vec_comps[NIR_MAX_VEC_COMPONENTS] = {0};
539 for (unsigned i = 0; i < num_vec_src; i++)
540 vec_comps[i] = alu->src[i].src.ssa;
541
542 for (unsigned i = 0; i < num_vec_src; i++)
543 rewrite_uses_to_var(b, vec_comps[i], s->pos_value_replacement, store_pos_component + i);
544 } else {
545 rewrite_uses_to_var(b, store_val, s->pos_value_replacement, store_pos_component);
546 }
547 } else {
548 rewrite_uses_to_var(b, store_val, s->pos_value_replacement, store_pos_component);
549 }
550
551 return true;
552 }
553
554 static void
remove_extra_pos_outputs(nir_shader * shader,lower_ngg_nogs_state * nogs_state)555 remove_extra_pos_outputs(nir_shader *shader, lower_ngg_nogs_state *nogs_state)
556 {
557 remove_extra_position_output_state s = {
558 .pos_value_replacement = nogs_state->position_value_var,
559 };
560
561 nir_shader_instructions_pass(shader, remove_extra_pos_output,
562 nir_metadata_block_index | nir_metadata_dominance, &s);
563 }
564
565 static bool
remove_compacted_arg(lower_ngg_nogs_state * state,nir_builder * b,unsigned idx)566 remove_compacted_arg(lower_ngg_nogs_state *state, nir_builder *b, unsigned idx)
567 {
568 nir_instr *store_instr = state->compact_arg_stores[idx];
569 if (!store_instr)
570 return false;
571
572 /* Simply remove the store. */
573 nir_instr_remove(store_instr);
574
575 /* Find the intrinsic that overwrites the shader arguments,
576 * and change its corresponding source.
577 * This will cause NIR's DCE to recognize the load and its phis as dead.
578 */
579 b->cursor = nir_before_instr(&state->overwrite_args->instr);
580 nir_ssa_def *undef_arg = nir_ssa_undef(b, 1, 32);
581 nir_ssa_def_rewrite_uses(state->overwrite_args->src[idx].ssa, undef_arg);
582
583 state->compact_arg_stores[idx] = NULL;
584 return true;
585 }
586
587 static bool
cleanup_culling_shader_after_dce(nir_shader * shader,nir_function_impl * function_impl,lower_ngg_nogs_state * state)588 cleanup_culling_shader_after_dce(nir_shader *shader,
589 nir_function_impl *function_impl,
590 lower_ngg_nogs_state *state)
591 {
592 bool uses_vs_vertex_id = false;
593 bool uses_vs_instance_id = false;
594 bool uses_tes_u = false;
595 bool uses_tes_v = false;
596 bool uses_tes_rel_patch_id = false;
597 bool uses_tes_patch_id = false;
598
599 bool progress = false;
600 nir_builder b;
601 nir_builder_init(&b, function_impl);
602
603 nir_foreach_block_reverse_safe(block, function_impl) {
604 nir_foreach_instr_reverse_safe(instr, block) {
605 if (instr->type != nir_instr_type_intrinsic)
606 continue;
607
608 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
609
610 switch (intrin->intrinsic) {
611 case nir_intrinsic_alloc_vertices_and_primitives_amd:
612 goto cleanup_culling_shader_after_dce_done;
613 case nir_intrinsic_load_vertex_id:
614 case nir_intrinsic_load_vertex_id_zero_base:
615 uses_vs_vertex_id = true;
616 break;
617 case nir_intrinsic_load_instance_id:
618 uses_vs_instance_id = true;
619 break;
620 case nir_intrinsic_load_input:
621 if (state->instance_rate_inputs &
622 (1 << (nir_intrinsic_base(intrin) - VERT_ATTRIB_GENERIC0)))
623 uses_vs_instance_id = true;
624 else
625 uses_vs_vertex_id = true;
626 break;
627 case nir_intrinsic_load_tess_coord:
628 uses_tes_u = uses_tes_v = true;
629 break;
630 case nir_intrinsic_load_tess_rel_patch_id_amd:
631 uses_tes_rel_patch_id = true;
632 break;
633 case nir_intrinsic_load_primitive_id:
634 if (shader->info.stage == MESA_SHADER_TESS_EVAL)
635 uses_tes_patch_id = true;
636 break;
637 default:
638 break;
639 }
640 }
641 }
642
643 cleanup_culling_shader_after_dce_done:
644
645 if (shader->info.stage == MESA_SHADER_VERTEX) {
646 if (!uses_vs_vertex_id)
647 progress |= remove_compacted_arg(state, &b, 0);
648 if (!uses_vs_instance_id)
649 progress |= remove_compacted_arg(state, &b, 1);
650 } else if (shader->info.stage == MESA_SHADER_TESS_EVAL) {
651 if (!uses_tes_u)
652 progress |= remove_compacted_arg(state, &b, 0);
653 if (!uses_tes_v)
654 progress |= remove_compacted_arg(state, &b, 1);
655 if (!uses_tes_rel_patch_id)
656 progress |= remove_compacted_arg(state, &b, 2);
657 if (!uses_tes_patch_id)
658 progress |= remove_compacted_arg(state, &b, 3);
659 }
660
661 return progress;
662 }
663
664 /**
665 * Perform vertex compaction after culling.
666 *
667 * 1. Repack surviving ES invocations (this determines which lane will export which vertex)
668 * 2. Surviving ES vertex invocations store their data to LDS
669 * 3. Emit GS_ALLOC_REQ
670 * 4. Repacked invocations load the vertex data from LDS
671 * 5. GS threads update their vertex indices
672 */
673 static void
compact_vertices_after_culling(nir_builder * b,lower_ngg_nogs_state * nogs_state,nir_variable ** repacked_arg_vars,nir_variable ** gs_vtxaddr_vars,nir_ssa_def * invocation_index,nir_ssa_def * es_vertex_lds_addr,nir_ssa_def * es_exporter_tid,nir_ssa_def * num_live_vertices_in_workgroup,nir_ssa_def * fully_culled,unsigned ngg_scratch_lds_base_addr,unsigned pervertex_lds_bytes,unsigned max_exported_args)674 compact_vertices_after_culling(nir_builder *b,
675 lower_ngg_nogs_state *nogs_state,
676 nir_variable **repacked_arg_vars,
677 nir_variable **gs_vtxaddr_vars,
678 nir_ssa_def *invocation_index,
679 nir_ssa_def *es_vertex_lds_addr,
680 nir_ssa_def *es_exporter_tid,
681 nir_ssa_def *num_live_vertices_in_workgroup,
682 nir_ssa_def *fully_culled,
683 unsigned ngg_scratch_lds_base_addr,
684 unsigned pervertex_lds_bytes,
685 unsigned max_exported_args)
686 {
687 nir_variable *es_accepted_var = nogs_state->es_accepted_var;
688 nir_variable *gs_accepted_var = nogs_state->gs_accepted_var;
689 nir_variable *position_value_var = nogs_state->position_value_var;
690 nir_variable *prim_exp_arg_var = nogs_state->prim_exp_arg_var;
691
692 nir_if *if_es_accepted = nir_push_if(b, nir_load_var(b, es_accepted_var));
693 {
694 nir_ssa_def *exporter_addr = pervertex_lds_addr(b, es_exporter_tid, pervertex_lds_bytes);
695
696 /* Store the exporter thread's index to the LDS space of the current thread so GS threads can load it */
697 nir_build_store_shared(b, nir_u2u8(b, es_exporter_tid), es_vertex_lds_addr, .base = lds_es_exporter_tid, .align_mul = 1u, .write_mask = 0x1u);
698
699 /* Store the current thread's position output to the exporter thread's LDS space */
700 nir_ssa_def *pos = nir_load_var(b, position_value_var);
701 nir_build_store_shared(b, pos, exporter_addr, .base = lds_es_pos_x, .align_mul = 4u, .write_mask = 0xfu);
702
703 /* Store the current thread's repackable arguments to the exporter thread's LDS space */
704 for (unsigned i = 0; i < max_exported_args; ++i) {
705 nir_ssa_def *arg_val = nir_load_var(b, repacked_arg_vars[i]);
706 nir_intrinsic_instr *store = nir_build_store_shared(b, arg_val, exporter_addr, .base = lds_es_arg_0 + 4u * i, .align_mul = 4u, .write_mask = 0x1u);
707
708 nogs_state->compact_arg_stores[i] = &store->instr;
709 }
710 }
711 nir_pop_if(b, if_es_accepted);
712
713 /* TODO: Consider adding a shortcut exit.
714 * Waves that have no vertices and primitives left can s_endpgm right here.
715 */
716
717 nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
718 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
719
720 nir_ssa_def *es_survived = nir_ilt(b, invocation_index, num_live_vertices_in_workgroup);
721 nir_if *if_packed_es_thread = nir_push_if(b, es_survived);
722 {
723 /* Read position from the current ES thread's LDS space (written by the exported vertex's ES thread) */
724 nir_ssa_def *exported_pos = nir_build_load_shared(b, 4, 32, es_vertex_lds_addr, .base = lds_es_pos_x, .align_mul = 4u);
725 nir_store_var(b, position_value_var, exported_pos, 0xfu);
726
727 /* Read the repacked arguments */
728 for (unsigned i = 0; i < max_exported_args; ++i) {
729 nir_ssa_def *arg_val = nir_build_load_shared(b, 1, 32, es_vertex_lds_addr, .base = lds_es_arg_0 + 4u * i, .align_mul = 4u);
730 nir_store_var(b, repacked_arg_vars[i], arg_val, 0x1u);
731 }
732 }
733 nir_push_else(b, if_packed_es_thread);
734 {
735 nir_store_var(b, position_value_var, nir_ssa_undef(b, 4, 32), 0xfu);
736 for (unsigned i = 0; i < max_exported_args; ++i)
737 nir_store_var(b, repacked_arg_vars[i], nir_ssa_undef(b, 1, 32), 0x1u);
738 }
739 nir_pop_if(b, if_packed_es_thread);
740
741 nir_if *if_gs_accepted = nir_push_if(b, nir_load_var(b, gs_accepted_var));
742 {
743 nir_ssa_def *exporter_vtx_indices[3] = {0};
744
745 /* Load the index of the ES threads that will export the current GS thread's vertices */
746 for (unsigned v = 0; v < 3; ++v) {
747 nir_ssa_def *vtx_addr = nir_load_var(b, gs_vtxaddr_vars[v]);
748 nir_ssa_def *exporter_vtx_idx = nir_build_load_shared(b, 1, 8, vtx_addr, .base = lds_es_exporter_tid, .align_mul = 1u);
749 exporter_vtx_indices[v] = nir_u2u32(b, exporter_vtx_idx);
750 }
751
752 nir_ssa_def *prim_exp_arg = emit_pack_ngg_prim_exp_arg(b, 3, exporter_vtx_indices, NULL, nogs_state->use_edgeflags);
753 nir_store_var(b, prim_exp_arg_var, prim_exp_arg, 0x1u);
754 }
755 nir_pop_if(b, if_gs_accepted);
756
757 nir_store_var(b, es_accepted_var, es_survived, 0x1u);
758 nir_store_var(b, gs_accepted_var, nir_bcsel(b, fully_culled, nir_imm_false(b), nir_build_has_input_primitive_amd(b)), 0x1u);
759 }
760
761 static void
analyze_shader_before_culling_walk(nir_ssa_def * ssa,uint8_t flag,lower_ngg_nogs_state * nogs_state)762 analyze_shader_before_culling_walk(nir_ssa_def *ssa,
763 uint8_t flag,
764 lower_ngg_nogs_state *nogs_state)
765 {
766 nir_instr *instr = ssa->parent_instr;
767 uint8_t old_pass_flags = instr->pass_flags;
768 instr->pass_flags |= flag;
769
770 if (instr->pass_flags == old_pass_flags)
771 return; /* Already visited. */
772
773 switch (instr->type) {
774 case nir_instr_type_intrinsic: {
775 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
776
777 /* VS input loads and SSBO loads are actually VRAM reads on AMD HW. */
778 switch (intrin->intrinsic) {
779 case nir_intrinsic_load_input: {
780 nir_io_semantics in_io_sem = nir_intrinsic_io_semantics(intrin);
781 uint64_t in_mask = UINT64_C(1) << (uint64_t) in_io_sem.location;
782 if (instr->pass_flags & nggc_passflag_used_by_pos)
783 nogs_state->inputs_needed_by_pos |= in_mask;
784 else if (instr->pass_flags & nggc_passflag_used_by_other)
785 nogs_state->inputs_needed_by_others |= in_mask;
786 break;
787 }
788 default:
789 break;
790 }
791
792 break;
793 }
794 case nir_instr_type_alu: {
795 nir_alu_instr *alu = nir_instr_as_alu(instr);
796 unsigned num_srcs = nir_op_infos[alu->op].num_inputs;
797
798 for (unsigned i = 0; i < num_srcs; ++i) {
799 analyze_shader_before_culling_walk(alu->src[i].src.ssa, flag, nogs_state);
800 }
801
802 break;
803 }
804 case nir_instr_type_phi: {
805 nir_phi_instr *phi = nir_instr_as_phi(instr);
806 nir_foreach_phi_src_safe(phi_src, phi) {
807 analyze_shader_before_culling_walk(phi_src->src.ssa, flag, nogs_state);
808 }
809
810 break;
811 }
812 default:
813 break;
814 }
815 }
816
817 static void
analyze_shader_before_culling(nir_shader * shader,lower_ngg_nogs_state * nogs_state)818 analyze_shader_before_culling(nir_shader *shader, lower_ngg_nogs_state *nogs_state)
819 {
820 nir_foreach_function(func, shader) {
821 nir_foreach_block(block, func->impl) {
822 nir_foreach_instr(instr, block) {
823 instr->pass_flags = 0;
824
825 if (instr->type != nir_instr_type_intrinsic)
826 continue;
827
828 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
829 if (intrin->intrinsic != nir_intrinsic_store_output)
830 continue;
831
832 nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
833 nir_ssa_def *store_val = intrin->src[0].ssa;
834 uint8_t flag = io_sem.location == VARYING_SLOT_POS ? nggc_passflag_used_by_pos : nggc_passflag_used_by_other;
835 analyze_shader_before_culling_walk(store_val, flag, nogs_state);
836 }
837 }
838 }
839 }
840
841 /**
842 * Save the reusable SSA definitions to variables so that the
843 * bottom shader part can reuse them from the top part.
844 *
845 * 1. We create a new function temporary variable for reusables,
846 * and insert a store+load.
847 * 2. The shader is cloned (the top part is created), then the
848 * control flow is reinserted (for the bottom part.)
849 * 3. For reusables, we delete the variable stores from the
850 * bottom part. This will make them use the variables from
851 * the top part and DCE the redundant instructions.
852 */
853 static void
save_reusable_variables(nir_builder * b,lower_ngg_nogs_state * nogs_state)854 save_reusable_variables(nir_builder *b, lower_ngg_nogs_state *nogs_state)
855 {
856 ASSERTED int vec_ok = u_vector_init(&nogs_state->saved_uniforms, 4, sizeof(saved_uniform));
857 assert(vec_ok);
858
859 nir_block *block = nir_start_block(b->impl);
860 while (block) {
861 /* Process the instructions in the current block. */
862 nir_foreach_instr_safe(instr, block) {
863 /* Find instructions whose SSA definitions are used by both
864 * the top and bottom parts of the shader (before and after culling).
865 * Only in this case, it makes sense for the bottom part
866 * to try to reuse these from the top part.
867 */
868 if ((instr->pass_flags & nggc_passflag_used_by_both) != nggc_passflag_used_by_both)
869 continue;
870
871 /* Determine if we can reuse the current SSA value.
872 * When vertex compaction is used, it is possible that the same shader invocation
873 * processes a different vertex in the top and bottom part of the shader.
874 * Therefore, we only reuse uniform values.
875 */
876 nir_ssa_def *ssa = NULL;
877 switch (instr->type) {
878 case nir_instr_type_alu: {
879 nir_alu_instr *alu = nir_instr_as_alu(instr);
880 if (alu->dest.dest.ssa.divergent)
881 continue;
882 /* Ignore uniform floats because they regress VGPR usage too much */
883 if (nir_op_infos[alu->op].output_type & nir_type_float)
884 continue;
885 ssa = &alu->dest.dest.ssa;
886 break;
887 }
888 case nir_instr_type_intrinsic: {
889 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
890 if (!nir_intrinsic_can_reorder(intrin) ||
891 !nir_intrinsic_infos[intrin->intrinsic].has_dest ||
892 intrin->dest.ssa.divergent)
893 continue;
894 ssa = &intrin->dest.ssa;
895 break;
896 }
897 case nir_instr_type_phi: {
898 nir_phi_instr *phi = nir_instr_as_phi(instr);
899 if (phi->dest.ssa.divergent)
900 continue;
901 ssa = &phi->dest.ssa;
902 break;
903 }
904 default:
905 continue;
906 }
907
908 assert(ssa);
909
910 /* Determine a suitable type for the SSA value. */
911 enum glsl_base_type base_type = GLSL_TYPE_UINT;
912 switch (ssa->bit_size) {
913 case 8: base_type = GLSL_TYPE_UINT8; break;
914 case 16: base_type = GLSL_TYPE_UINT16; break;
915 case 32: base_type = GLSL_TYPE_UINT; break;
916 case 64: base_type = GLSL_TYPE_UINT64; break;
917 default: continue;
918 }
919
920 const struct glsl_type *t = ssa->num_components == 1
921 ? glsl_scalar_type(base_type)
922 : glsl_vector_type(base_type, ssa->num_components);
923
924 saved_uniform *saved = (saved_uniform *) u_vector_add(&nogs_state->saved_uniforms);
925 assert(saved);
926
927 /* Create a new NIR variable where we store the reusable value.
928 * Then, we reload the variable and replace the uses of the value
929 * with the reloaded variable.
930 */
931 saved->var = nir_local_variable_create(b->impl, t, NULL);
932 saved->ssa = ssa;
933
934 b->cursor = instr->type == nir_instr_type_phi
935 ? nir_after_instr_and_phis(instr)
936 : nir_after_instr(instr);
937 nir_store_var(b, saved->var, saved->ssa, BITFIELD_MASK(ssa->num_components));
938 nir_ssa_def *reloaded = nir_load_var(b, saved->var);
939 nir_ssa_def_rewrite_uses_after(ssa, reloaded, reloaded->parent_instr);
940 }
941
942 /* Look at the next CF node. */
943 nir_cf_node *next_cf_node = nir_cf_node_next(&block->cf_node);
944 if (next_cf_node) {
945 /* It makes no sense to try to reuse things from within loops. */
946 bool next_is_loop = next_cf_node->type == nir_cf_node_loop;
947
948 /* Don't reuse if we're in divergent control flow.
949 *
950 * Thanks to vertex repacking, the same shader invocation may process a different vertex
951 * in the top and bottom part, and it's even possible that this different vertex was initially
952 * processed in a different wave. So the two parts may take a different divergent code path.
953 * Therefore, these variables in divergent control flow may stay undefined.
954 *
955 * Note that this problem doesn't exist if vertices are not repacked or if the
956 * workgroup only has a single wave.
957 */
958 bool next_is_divergent_if =
959 next_cf_node->type == nir_cf_node_if &&
960 nir_cf_node_as_if(next_cf_node)->condition.ssa->divergent;
961
962 if (next_is_loop || next_is_divergent_if) {
963 block = nir_cf_node_cf_tree_next(next_cf_node);
964 continue;
965 }
966 }
967
968 /* Go to the next block. */
969 block = nir_block_cf_tree_next(block);
970 }
971 }
972
973 /**
974 * Reuses suitable variables from the top part of the shader,
975 * by deleting their stores from the bottom part.
976 */
977 static void
apply_reusable_variables(nir_builder * b,lower_ngg_nogs_state * nogs_state)978 apply_reusable_variables(nir_builder *b, lower_ngg_nogs_state *nogs_state)
979 {
980 if (!u_vector_length(&nogs_state->saved_uniforms)) {
981 u_vector_finish(&nogs_state->saved_uniforms);
982 return;
983 }
984
985 nir_foreach_block_reverse_safe(block, b->impl) {
986 nir_foreach_instr_reverse_safe(instr, block) {
987 if (instr->type != nir_instr_type_intrinsic)
988 continue;
989 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
990
991 /* When we found any of these intrinsics, it means
992 * we reached the top part and we must stop.
993 */
994 if (intrin->intrinsic == nir_intrinsic_alloc_vertices_and_primitives_amd)
995 goto done;
996
997 if (intrin->intrinsic != nir_intrinsic_store_deref)
998 continue;
999 nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
1000 if (deref->deref_type != nir_deref_type_var)
1001 continue;
1002
1003 saved_uniform *saved;
1004 u_vector_foreach(saved, &nogs_state->saved_uniforms) {
1005 if (saved->var == deref->var) {
1006 nir_instr_remove(instr);
1007 }
1008 }
1009 }
1010 }
1011
1012 done:
1013 u_vector_finish(&nogs_state->saved_uniforms);
1014 }
1015
1016 static void
add_deferred_attribute_culling(nir_builder * b,nir_cf_list * original_extracted_cf,lower_ngg_nogs_state * nogs_state)1017 add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_cf, lower_ngg_nogs_state *nogs_state)
1018 {
1019 assert(b->shader->info.outputs_written & (1 << VARYING_SLOT_POS));
1020
1021 bool uses_instance_id = BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_INSTANCE_ID);
1022 bool uses_tess_primitive_id = BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_PRIMITIVE_ID);
1023
1024 unsigned max_exported_args = b->shader->info.stage == MESA_SHADER_VERTEX ? 2 : 4;
1025 if (b->shader->info.stage == MESA_SHADER_VERTEX && !uses_instance_id)
1026 max_exported_args--;
1027 else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL && !uses_tess_primitive_id)
1028 max_exported_args--;
1029
1030 unsigned pervertex_lds_bytes = lds_es_arg_0 + max_exported_args * 4u;
1031 unsigned total_es_lds_bytes = pervertex_lds_bytes * nogs_state->max_es_num_vertices;
1032 unsigned max_num_waves = nogs_state->max_num_waves;
1033 unsigned ngg_scratch_lds_base_addr = ALIGN(total_es_lds_bytes, 8u);
1034 unsigned ngg_scratch_lds_bytes = DIV_ROUND_UP(max_num_waves, 4u);
1035 nogs_state->total_lds_bytes = ngg_scratch_lds_base_addr + ngg_scratch_lds_bytes;
1036
1037 nir_function_impl *impl = nir_shader_get_entrypoint(b->shader);
1038
1039 /* Create some helper variables. */
1040 nir_variable *position_value_var = nogs_state->position_value_var;
1041 nir_variable *prim_exp_arg_var = nogs_state->prim_exp_arg_var;
1042 nir_variable *gs_accepted_var = nogs_state->gs_accepted_var;
1043 nir_variable *es_accepted_var = nogs_state->es_accepted_var;
1044 nir_variable *gs_vtxaddr_vars[3] = {
1045 nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx0_addr"),
1046 nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx1_addr"),
1047 nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx2_addr"),
1048 };
1049 nir_variable *repacked_arg_vars[4] = {
1050 nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_0"),
1051 nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_1"),
1052 nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_2"),
1053 nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_3"),
1054 };
1055
1056 /* Top part of the culling shader (aka. position shader part)
1057 *
1058 * We clone the full ES shader and emit it here, but we only really care
1059 * about its position output, so we delete every other output from this part.
1060 * The position output is stored into a temporary variable, and reloaded later.
1061 */
1062
1063 b->cursor = nir_before_cf_list(&impl->body);
1064
1065 nir_ssa_def *es_thread = nir_build_has_input_vertex_amd(b);
1066 nir_if *if_es_thread = nir_push_if(b, es_thread);
1067 {
1068 /* Initialize the position output variable to zeroes, in case not all VS/TES invocations store the output.
1069 * The spec doesn't require it, but we use (0, 0, 0, 1) because some games rely on that.
1070 */
1071 nir_store_var(b, position_value_var, nir_imm_vec4(b, 0.0f, 0.0f, 0.0f, 1.0f), 0xfu);
1072
1073 /* Now reinsert a clone of the shader code */
1074 struct hash_table *remap_table = _mesa_pointer_hash_table_create(NULL);
1075 nir_cf_list_clone_and_reinsert(original_extracted_cf, &if_es_thread->cf_node, b->cursor, remap_table);
1076 _mesa_hash_table_destroy(remap_table, NULL);
1077 b->cursor = nir_after_cf_list(&if_es_thread->then_list);
1078
1079 /* Remember the current thread's shader arguments */
1080 if (b->shader->info.stage == MESA_SHADER_VERTEX) {
1081 nir_store_var(b, repacked_arg_vars[0], nir_build_load_vertex_id_zero_base(b), 0x1u);
1082 if (uses_instance_id)
1083 nir_store_var(b, repacked_arg_vars[1], nir_build_load_instance_id(b), 0x1u);
1084 } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
1085 nir_ssa_def *tess_coord = nir_build_load_tess_coord(b);
1086 nir_store_var(b, repacked_arg_vars[0], nir_channel(b, tess_coord, 0), 0x1u);
1087 nir_store_var(b, repacked_arg_vars[1], nir_channel(b, tess_coord, 1), 0x1u);
1088 nir_store_var(b, repacked_arg_vars[2], nir_build_load_tess_rel_patch_id_amd(b), 0x1u);
1089 if (uses_tess_primitive_id)
1090 nir_store_var(b, repacked_arg_vars[3], nir_build_load_primitive_id(b), 0x1u);
1091 } else {
1092 unreachable("Should be VS or TES.");
1093 }
1094 }
1095 nir_pop_if(b, if_es_thread);
1096
1097 nir_store_var(b, es_accepted_var, es_thread, 0x1u);
1098 nir_store_var(b, gs_accepted_var, nir_build_has_input_primitive_amd(b), 0x1u);
1099
1100 /* Remove all non-position outputs, and put the position output into the variable. */
1101 nir_metadata_preserve(impl, nir_metadata_none);
1102 remove_culling_shader_outputs(b->shader, nogs_state, position_value_var);
1103 b->cursor = nir_after_cf_list(&impl->body);
1104
1105 /* Run culling algorithms if culling is enabled.
1106 *
1107 * NGG culling can be enabled or disabled in runtime.
1108 * This is determined by a SGPR shader argument which is acccessed
1109 * by the following NIR intrinsic.
1110 */
1111
1112 nir_if *if_cull_en = nir_push_if(b, nir_build_load_cull_any_enabled_amd(b));
1113 {
1114 nir_ssa_def *invocation_index = nir_build_load_local_invocation_index(b);
1115 nir_ssa_def *es_vertex_lds_addr = pervertex_lds_addr(b, invocation_index, pervertex_lds_bytes);
1116
1117 /* ES invocations store their vertex data to LDS for GS threads to read. */
1118 if_es_thread = nir_push_if(b, nir_build_has_input_vertex_amd(b));
1119 {
1120 /* Store position components that are relevant to culling in LDS */
1121 nir_ssa_def *pre_cull_pos = nir_load_var(b, position_value_var);
1122 nir_ssa_def *pre_cull_w = nir_channel(b, pre_cull_pos, 3);
1123 nir_build_store_shared(b, pre_cull_w, es_vertex_lds_addr, .write_mask = 0x1u, .align_mul = 4, .base = lds_es_pos_w);
1124 nir_ssa_def *pre_cull_x_div_w = nir_fdiv(b, nir_channel(b, pre_cull_pos, 0), pre_cull_w);
1125 nir_ssa_def *pre_cull_y_div_w = nir_fdiv(b, nir_channel(b, pre_cull_pos, 1), pre_cull_w);
1126 nir_build_store_shared(b, nir_vec2(b, pre_cull_x_div_w, pre_cull_y_div_w), es_vertex_lds_addr, .write_mask = 0x3u, .align_mul = 4, .base = lds_es_pos_x);
1127
1128 /* Clear out the ES accepted flag in LDS */
1129 nir_build_store_shared(b, nir_imm_zero(b, 1, 8), es_vertex_lds_addr, .write_mask = 0x1u, .align_mul = 4, .base = lds_es_vertex_accepted);
1130 }
1131 nir_pop_if(b, if_es_thread);
1132
1133 nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
1134 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
1135
1136 nir_store_var(b, gs_accepted_var, nir_imm_bool(b, false), 0x1u);
1137 nir_store_var(b, prim_exp_arg_var, nir_imm_int(b, 1 << 31), 0x1u);
1138
1139 /* GS invocations load the vertex data and perform the culling. */
1140 nir_if *if_gs_thread = nir_push_if(b, nir_build_has_input_primitive_amd(b));
1141 {
1142 /* Load vertex indices from input VGPRs */
1143 nir_ssa_def *vtx_idx[3] = {0};
1144 for (unsigned vertex = 0; vertex < 3; ++vertex)
1145 vtx_idx[vertex] = ngg_input_primitive_vertex_index(b, vertex);
1146
1147 nir_ssa_def *vtx_addr[3] = {0};
1148 nir_ssa_def *pos[3][4] = {0};
1149
1150 /* Load W positions of vertices first because the culling code will use these first */
1151 for (unsigned vtx = 0; vtx < 3; ++vtx) {
1152 vtx_addr[vtx] = pervertex_lds_addr(b, vtx_idx[vtx], pervertex_lds_bytes);
1153 pos[vtx][3] = nir_build_load_shared(b, 1, 32, vtx_addr[vtx], .align_mul = 4u, .base = lds_es_pos_w);
1154 nir_store_var(b, gs_vtxaddr_vars[vtx], vtx_addr[vtx], 0x1u);
1155 }
1156
1157 /* Load the X/W, Y/W positions of vertices */
1158 for (unsigned vtx = 0; vtx < 3; ++vtx) {
1159 nir_ssa_def *xy = nir_build_load_shared(b, 2, 32, vtx_addr[vtx], .align_mul = 4u, .base = lds_es_pos_x);
1160 pos[vtx][0] = nir_channel(b, xy, 0);
1161 pos[vtx][1] = nir_channel(b, xy, 1);
1162 }
1163
1164 /* See if the current primitive is accepted */
1165 nir_ssa_def *accepted = ac_nir_cull_triangle(b, nir_imm_bool(b, true), pos);
1166 nir_store_var(b, gs_accepted_var, accepted, 0x1u);
1167
1168 nir_if *if_gs_accepted = nir_push_if(b, accepted);
1169 {
1170 /* Store the accepted state to LDS for ES threads */
1171 for (unsigned vtx = 0; vtx < 3; ++vtx)
1172 nir_build_store_shared(b, nir_imm_intN_t(b, 0xff, 8), vtx_addr[vtx], .base = lds_es_vertex_accepted, .align_mul = 4u, .write_mask = 0x1u);
1173 }
1174 nir_pop_if(b, if_gs_accepted);
1175 }
1176 nir_pop_if(b, if_gs_thread);
1177
1178 nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
1179 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
1180
1181 nir_store_var(b, es_accepted_var, nir_imm_bool(b, false), 0x1u);
1182
1183 /* ES invocations load their accepted flag from LDS. */
1184 if_es_thread = nir_push_if(b, nir_build_has_input_vertex_amd(b));
1185 {
1186 nir_ssa_def *accepted = nir_build_load_shared(b, 1, 8u, es_vertex_lds_addr, .base = lds_es_vertex_accepted, .align_mul = 4u);
1187 nir_ssa_def *accepted_bool = nir_ine(b, accepted, nir_imm_intN_t(b, 0, 8));
1188 nir_store_var(b, es_accepted_var, accepted_bool, 0x1u);
1189 }
1190 nir_pop_if(b, if_es_thread);
1191
1192 nir_ssa_def *es_accepted = nir_load_var(b, es_accepted_var);
1193
1194 /* Repack the vertices that survived the culling. */
1195 wg_repack_result rep = repack_invocations_in_workgroup(b, es_accepted, ngg_scratch_lds_base_addr,
1196 nogs_state->max_num_waves, nogs_state->wave_size);
1197 nir_ssa_def *num_live_vertices_in_workgroup = rep.num_repacked_invocations;
1198 nir_ssa_def *es_exporter_tid = rep.repacked_invocation_index;
1199
1200 /* If all vertices are culled, set primitive count to 0 as well. */
1201 nir_ssa_def *num_exported_prims = nir_build_load_workgroup_num_input_primitives_amd(b);
1202 nir_ssa_def *fully_culled = nir_ieq_imm(b, num_live_vertices_in_workgroup, 0u);
1203 num_exported_prims = nir_bcsel(b, fully_culled, nir_imm_int(b, 0u), num_exported_prims);
1204
1205 nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_build_load_subgroup_id(b), nir_imm_int(b, 0)));
1206 {
1207 /* Tell the final vertex and primitive count to the HW. */
1208 nir_build_alloc_vertices_and_primitives_amd(b, num_live_vertices_in_workgroup, num_exported_prims);
1209 }
1210 nir_pop_if(b, if_wave_0);
1211
1212 /* Vertex compaction. */
1213 compact_vertices_after_culling(b, nogs_state,
1214 repacked_arg_vars, gs_vtxaddr_vars,
1215 invocation_index, es_vertex_lds_addr,
1216 es_exporter_tid, num_live_vertices_in_workgroup, fully_culled,
1217 ngg_scratch_lds_base_addr, pervertex_lds_bytes, max_exported_args);
1218 }
1219 nir_push_else(b, if_cull_en);
1220 {
1221 /* When culling is disabled, we do the same as we would without culling. */
1222 nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_build_load_subgroup_id(b), nir_imm_int(b, 0)));
1223 {
1224 nir_ssa_def *vtx_cnt = nir_build_load_workgroup_num_input_vertices_amd(b);
1225 nir_ssa_def *prim_cnt = nir_build_load_workgroup_num_input_primitives_amd(b);
1226 nir_build_alloc_vertices_and_primitives_amd(b, vtx_cnt, prim_cnt);
1227 }
1228 nir_pop_if(b, if_wave_0);
1229 nir_store_var(b, prim_exp_arg_var, emit_ngg_nogs_prim_exp_arg(b, nogs_state), 0x1u);
1230 }
1231 nir_pop_if(b, if_cull_en);
1232
1233 /* Update shader arguments.
1234 *
1235 * The registers which hold information about the subgroup's
1236 * vertices and primitives are updated here, so the rest of the shader
1237 * doesn't need to worry about the culling.
1238 *
1239 * These "overwrite" intrinsics must be at top level control flow,
1240 * otherwise they can mess up the backend (eg. ACO's SSA).
1241 *
1242 * TODO:
1243 * A cleaner solution would be to simply replace all usages of these args
1244 * with the load of the variables.
1245 * However, this wouldn't work right now because the backend uses the arguments
1246 * for purposes not expressed in NIR, eg. VS input loads, etc.
1247 * This can change if VS input loads and other stuff are lowered to eg. load_buffer_amd.
1248 */
1249
1250 if (b->shader->info.stage == MESA_SHADER_VERTEX)
1251 nogs_state->overwrite_args =
1252 nir_build_overwrite_vs_arguments_amd(b,
1253 nir_load_var(b, repacked_arg_vars[0]), nir_load_var(b, repacked_arg_vars[1]));
1254 else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL)
1255 nogs_state->overwrite_args =
1256 nir_build_overwrite_tes_arguments_amd(b,
1257 nir_load_var(b, repacked_arg_vars[0]), nir_load_var(b, repacked_arg_vars[1]),
1258 nir_load_var(b, repacked_arg_vars[2]), nir_load_var(b, repacked_arg_vars[3]));
1259 else
1260 unreachable("Should be VS or TES.");
1261 }
1262
1263 void
ac_nir_lower_ngg_nogs(nir_shader * shader,unsigned max_num_es_vertices,unsigned num_vertices_per_primitives,unsigned max_workgroup_size,unsigned wave_size,bool can_cull,bool early_prim_export,bool passthrough,bool export_prim_id,bool provoking_vtx_last,bool use_edgeflags,uint32_t instance_rate_inputs)1264 ac_nir_lower_ngg_nogs(nir_shader *shader,
1265 unsigned max_num_es_vertices,
1266 unsigned num_vertices_per_primitives,
1267 unsigned max_workgroup_size,
1268 unsigned wave_size,
1269 bool can_cull,
1270 bool early_prim_export,
1271 bool passthrough,
1272 bool export_prim_id,
1273 bool provoking_vtx_last,
1274 bool use_edgeflags,
1275 uint32_t instance_rate_inputs)
1276 {
1277 nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1278 assert(impl);
1279 assert(max_num_es_vertices && max_workgroup_size && wave_size);
1280 assert(!(can_cull && passthrough));
1281
1282 nir_variable *position_value_var = nir_local_variable_create(impl, glsl_vec4_type(), "position_value");
1283 nir_variable *prim_exp_arg_var = nir_local_variable_create(impl, glsl_uint_type(), "prim_exp_arg");
1284 nir_variable *es_accepted_var = can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "es_accepted") : NULL;
1285 nir_variable *gs_accepted_var = can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "gs_accepted") : NULL;
1286
1287 lower_ngg_nogs_state state = {
1288 .passthrough = passthrough,
1289 .export_prim_id = export_prim_id,
1290 .early_prim_export = early_prim_export,
1291 .use_edgeflags = use_edgeflags,
1292 .num_vertices_per_primitives = num_vertices_per_primitives,
1293 .provoking_vtx_idx = provoking_vtx_last ? (num_vertices_per_primitives - 1) : 0,
1294 .position_value_var = position_value_var,
1295 .prim_exp_arg_var = prim_exp_arg_var,
1296 .es_accepted_var = es_accepted_var,
1297 .gs_accepted_var = gs_accepted_var,
1298 .max_num_waves = DIV_ROUND_UP(max_workgroup_size, wave_size),
1299 .max_es_num_vertices = max_num_es_vertices,
1300 .wave_size = wave_size,
1301 .instance_rate_inputs = instance_rate_inputs,
1302 };
1303
1304 /* We need LDS space when VS needs to export the primitive ID. */
1305 if (shader->info.stage == MESA_SHADER_VERTEX && export_prim_id)
1306 state.total_lds_bytes = max_num_es_vertices * 4u;
1307
1308 nir_builder builder;
1309 nir_builder *b = &builder; /* This is to avoid the & */
1310 nir_builder_init(b, impl);
1311
1312 if (can_cull) {
1313 /* We need divergence info for culling shaders. */
1314 nir_divergence_analysis(shader);
1315 analyze_shader_before_culling(shader, &state);
1316 save_reusable_variables(b, &state);
1317 }
1318
1319 nir_cf_list extracted;
1320 nir_cf_extract(&extracted, nir_before_cf_list(&impl->body), nir_after_cf_list(&impl->body));
1321 b->cursor = nir_before_cf_list(&impl->body);
1322
1323 if (!can_cull) {
1324 /* Allocate export space on wave 0 - confirm to the HW that we want to use all possible space */
1325 nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_build_load_subgroup_id(b), nir_imm_int(b, 0)));
1326 {
1327 nir_ssa_def *vtx_cnt = nir_build_load_workgroup_num_input_vertices_amd(b);
1328 nir_ssa_def *prim_cnt = nir_build_load_workgroup_num_input_primitives_amd(b);
1329 nir_build_alloc_vertices_and_primitives_amd(b, vtx_cnt, prim_cnt);
1330 }
1331 nir_pop_if(b, if_wave_0);
1332
1333 /* Take care of early primitive export, otherwise just pack the primitive export argument */
1334 if (state.early_prim_export)
1335 emit_ngg_nogs_prim_export(b, &state, NULL);
1336 else
1337 nir_store_var(b, prim_exp_arg_var, emit_ngg_nogs_prim_exp_arg(b, &state), 0x1u);
1338 } else {
1339 add_deferred_attribute_culling(b, &extracted, &state);
1340 b->cursor = nir_after_cf_list(&impl->body);
1341
1342 if (state.early_prim_export)
1343 emit_ngg_nogs_prim_export(b, &state, nir_load_var(b, state.prim_exp_arg_var));
1344 }
1345
1346 nir_intrinsic_instr *export_vertex_instr;
1347 nir_ssa_def *es_thread = can_cull ? nir_load_var(b, es_accepted_var) : nir_build_has_input_vertex_amd(b);
1348
1349 nir_if *if_es_thread = nir_push_if(b, es_thread);
1350 {
1351 /* Run the actual shader */
1352 nir_cf_reinsert(&extracted, b->cursor);
1353 b->cursor = nir_after_cf_list(&if_es_thread->then_list);
1354
1355 /* Export all vertex attributes (except primitive ID) */
1356 export_vertex_instr = nir_build_export_vertex_amd(b);
1357
1358 /* Export primitive ID (in case of early primitive export or TES) */
1359 if (state.export_prim_id && (state.early_prim_export || shader->info.stage != MESA_SHADER_VERTEX))
1360 emit_store_ngg_nogs_es_primitive_id(b);
1361 }
1362 nir_pop_if(b, if_es_thread);
1363
1364 /* Take care of late primitive export */
1365 if (!state.early_prim_export) {
1366 emit_ngg_nogs_prim_export(b, &state, nir_load_var(b, prim_exp_arg_var));
1367 if (state.export_prim_id && shader->info.stage == MESA_SHADER_VERTEX) {
1368 if_es_thread = nir_push_if(b, can_cull ? es_thread : nir_build_has_input_vertex_amd(b));
1369 emit_store_ngg_nogs_es_primitive_id(b);
1370 nir_pop_if(b, if_es_thread);
1371 }
1372 }
1373
1374 if (can_cull) {
1375 /* Replace uniforms. */
1376 apply_reusable_variables(b, &state);
1377
1378 /* Remove the redundant position output. */
1379 remove_extra_pos_outputs(shader, &state);
1380
1381 /* After looking at the performance in apps eg. Doom Eternal, and The Witcher 3,
1382 * it seems that it's best to put the position export always at the end, and
1383 * then let ACO schedule it up (slightly) only when early prim export is used.
1384 */
1385 b->cursor = nir_before_instr(&export_vertex_instr->instr);
1386
1387 nir_ssa_def *pos_val = nir_load_var(b, state.position_value_var);
1388 nir_io_semantics io_sem = { .location = VARYING_SLOT_POS, .num_slots = 1 };
1389 nir_build_store_output(b, pos_val, nir_imm_int(b, 0), .base = VARYING_SLOT_POS, .component = 0, .io_semantics = io_sem, .write_mask = 0xfu);
1390 }
1391
1392 nir_metadata_preserve(impl, nir_metadata_none);
1393 nir_validate_shader(shader, "after emitting NGG VS/TES");
1394
1395 /* Cleanup */
1396 nir_opt_dead_write_vars(shader);
1397 nir_lower_vars_to_ssa(shader);
1398 nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
1399 nir_lower_alu_to_scalar(shader, NULL, NULL);
1400 nir_lower_phis_to_scalar(shader, true);
1401
1402 if (can_cull) {
1403 /* It's beneficial to redo these opts after splitting the shader. */
1404 nir_opt_sink(shader, nir_move_load_input | nir_move_const_undef | nir_move_copies);
1405 nir_opt_move(shader, nir_move_load_input | nir_move_copies | nir_move_const_undef);
1406 }
1407
1408 bool progress;
1409 do {
1410 progress = false;
1411 NIR_PASS(progress, shader, nir_opt_undef);
1412 NIR_PASS(progress, shader, nir_opt_dce);
1413 NIR_PASS(progress, shader, nir_opt_dead_cf);
1414
1415 if (can_cull)
1416 progress |= cleanup_culling_shader_after_dce(shader, b->impl, &state);
1417 } while (progress);
1418
1419 shader->info.shared_size = state.total_lds_bytes;
1420 }
1421
1422 static nir_ssa_def *
ngg_gs_out_vertex_addr(nir_builder * b,nir_ssa_def * out_vtx_idx,lower_ngg_gs_state * s)1423 ngg_gs_out_vertex_addr(nir_builder *b, nir_ssa_def *out_vtx_idx, lower_ngg_gs_state *s)
1424 {
1425 unsigned write_stride_2exp = ffs(MAX2(b->shader->info.gs.vertices_out, 1)) - 1;
1426
1427 /* gs_max_out_vertices = 2^(write_stride_2exp) * some odd number */
1428 if (write_stride_2exp) {
1429 nir_ssa_def *row = nir_ushr_imm(b, out_vtx_idx, 5);
1430 nir_ssa_def *swizzle = nir_iand_imm(b, row, (1u << write_stride_2exp) - 1u);
1431 out_vtx_idx = nir_ixor(b, out_vtx_idx, swizzle);
1432 }
1433
1434 nir_ssa_def *out_vtx_offs = nir_imul_imm(b, out_vtx_idx, s->lds_bytes_per_gs_out_vertex);
1435 return nir_iadd_imm_nuw(b, out_vtx_offs, s->lds_addr_gs_out_vtx);
1436 }
1437
1438 static nir_ssa_def *
ngg_gs_emit_vertex_addr(nir_builder * b,nir_ssa_def * gs_vtx_idx,lower_ngg_gs_state * s)1439 ngg_gs_emit_vertex_addr(nir_builder *b, nir_ssa_def *gs_vtx_idx, lower_ngg_gs_state *s)
1440 {
1441 nir_ssa_def *tid_in_tg = nir_build_load_local_invocation_index(b);
1442 nir_ssa_def *gs_out_vtx_base = nir_imul_imm(b, tid_in_tg, b->shader->info.gs.vertices_out);
1443 nir_ssa_def *out_vtx_idx = nir_iadd_nuw(b, gs_out_vtx_base, gs_vtx_idx);
1444
1445 return ngg_gs_out_vertex_addr(b, out_vtx_idx, s);
1446 }
1447
1448 static void
ngg_gs_clear_primflags(nir_builder * b,nir_ssa_def * num_vertices,unsigned stream,lower_ngg_gs_state * s)1449 ngg_gs_clear_primflags(nir_builder *b, nir_ssa_def *num_vertices, unsigned stream, lower_ngg_gs_state *s)
1450 {
1451 nir_ssa_def *zero_u8 = nir_imm_zero(b, 1, 8);
1452 nir_store_var(b, s->current_clear_primflag_idx_var, num_vertices, 0x1u);
1453
1454 nir_loop *loop = nir_push_loop(b);
1455 {
1456 nir_ssa_def *current_clear_primflag_idx = nir_load_var(b, s->current_clear_primflag_idx_var);
1457 nir_if *if_break = nir_push_if(b, nir_uge(b, current_clear_primflag_idx, nir_imm_int(b, b->shader->info.gs.vertices_out)));
1458 {
1459 nir_jump(b, nir_jump_break);
1460 }
1461 nir_push_else(b, if_break);
1462 {
1463 nir_ssa_def *emit_vtx_addr = ngg_gs_emit_vertex_addr(b, current_clear_primflag_idx, s);
1464 nir_build_store_shared(b, zero_u8, emit_vtx_addr, .base = s->lds_offs_primflags + stream, .align_mul = 1, .write_mask = 0x1u);
1465 nir_store_var(b, s->current_clear_primflag_idx_var, nir_iadd_imm_nuw(b, current_clear_primflag_idx, 1), 0x1u);
1466 }
1467 nir_pop_if(b, if_break);
1468 }
1469 nir_pop_loop(b, loop);
1470 }
1471
1472 static void
ngg_gs_shader_query(nir_builder * b,nir_intrinsic_instr * intrin,lower_ngg_gs_state * s)1473 ngg_gs_shader_query(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
1474 {
1475 nir_if *if_shader_query = nir_push_if(b, nir_build_load_shader_query_enabled_amd(b));
1476 nir_ssa_def *num_prims_in_wave = NULL;
1477
1478 /* Calculate the "real" number of emitted primitives from the emitted GS vertices and primitives.
1479 * GS emits points, line strips or triangle strips.
1480 * Real primitives are points, lines or triangles.
1481 */
1482 if (nir_src_is_const(intrin->src[0]) && nir_src_is_const(intrin->src[1])) {
1483 unsigned gs_vtx_cnt = nir_src_as_uint(intrin->src[0]);
1484 unsigned gs_prm_cnt = nir_src_as_uint(intrin->src[1]);
1485 unsigned total_prm_cnt = gs_vtx_cnt - gs_prm_cnt * (s->num_vertices_per_primitive - 1u);
1486 nir_ssa_def *num_threads = nir_bit_count(b, nir_build_ballot(b, 1, s->wave_size, nir_imm_bool(b, true)));
1487 num_prims_in_wave = nir_imul_imm(b, num_threads, total_prm_cnt);
1488 } else {
1489 nir_ssa_def *gs_vtx_cnt = intrin->src[0].ssa;
1490 nir_ssa_def *prm_cnt = intrin->src[1].ssa;
1491 if (s->num_vertices_per_primitive > 1)
1492 prm_cnt = nir_iadd_nuw(b, nir_imul_imm(b, prm_cnt, -1u * (s->num_vertices_per_primitive - 1)), gs_vtx_cnt);
1493 num_prims_in_wave = nir_build_reduce(b, prm_cnt, .reduction_op = nir_op_iadd);
1494 }
1495
1496 /* Store the query result to GDS using an atomic add. */
1497 nir_if *if_first_lane = nir_push_if(b, nir_build_elect(b, 1));
1498 nir_build_gds_atomic_add_amd(b, 32, num_prims_in_wave, nir_imm_int(b, 0), nir_imm_int(b, 0x100));
1499 nir_pop_if(b, if_first_lane);
1500
1501 nir_pop_if(b, if_shader_query);
1502 }
1503
1504 static bool
lower_ngg_gs_store_output(nir_builder * b,nir_intrinsic_instr * intrin,lower_ngg_gs_state * s)1505 lower_ngg_gs_store_output(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
1506 {
1507 assert(nir_src_is_const(intrin->src[1]));
1508 b->cursor = nir_before_instr(&intrin->instr);
1509
1510 unsigned writemask = nir_intrinsic_write_mask(intrin);
1511 unsigned base = nir_intrinsic_base(intrin);
1512 unsigned component_offset = nir_intrinsic_component(intrin);
1513 unsigned base_offset = nir_src_as_uint(intrin->src[1]);
1514 nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
1515
1516 assert((base + base_offset) < VARYING_SLOT_MAX);
1517
1518 nir_ssa_def *store_val = intrin->src[0].ssa;
1519
1520 for (unsigned comp = 0; comp < 4; ++comp) {
1521 if (!(writemask & (1 << comp)))
1522 continue;
1523 unsigned stream = (io_sem.gs_streams >> (comp * 2)) & 0x3;
1524 if (!(b->shader->info.gs.active_stream_mask & (1 << stream)))
1525 continue;
1526
1527 /* Small bitsize components consume the same amount of space as 32-bit components,
1528 * but 64-bit ones consume twice as many. (Vulkan spec 15.1.5)
1529 */
1530 unsigned num_consumed_components = MIN2(1, DIV_ROUND_UP(store_val->bit_size, 32));
1531 nir_ssa_def *element = nir_channel(b, store_val, comp);
1532 if (num_consumed_components > 1)
1533 element = nir_extract_bits(b, &element, 1, 0, num_consumed_components, 32);
1534
1535 for (unsigned c = 0; c < num_consumed_components; ++c) {
1536 unsigned component_index = (comp * num_consumed_components) + c + component_offset;
1537 unsigned base_index = base + base_offset + component_index / 4;
1538 component_index %= 4;
1539
1540 /* Save output usage info */
1541 gs_output_component_info *info = &s->output_component_info[base_index][component_index];
1542 info->bit_size = MAX2(info->bit_size, MIN2(store_val->bit_size, 32));
1543 info->stream = stream;
1544
1545 /* Store the current component element */
1546 nir_ssa_def *component_element = element;
1547 if (num_consumed_components > 1)
1548 component_element = nir_channel(b, component_element, c);
1549 if (component_element->bit_size != 32)
1550 component_element = nir_u2u32(b, component_element);
1551
1552 nir_store_var(b, s->output_vars[base_index][component_index], component_element, 0x1u);
1553 }
1554 }
1555
1556 nir_instr_remove(&intrin->instr);
1557 return true;
1558 }
1559
1560 static bool
lower_ngg_gs_emit_vertex_with_counter(nir_builder * b,nir_intrinsic_instr * intrin,lower_ngg_gs_state * s)1561 lower_ngg_gs_emit_vertex_with_counter(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
1562 {
1563 b->cursor = nir_before_instr(&intrin->instr);
1564
1565 unsigned stream = nir_intrinsic_stream_id(intrin);
1566 if (!(b->shader->info.gs.active_stream_mask & (1 << stream))) {
1567 nir_instr_remove(&intrin->instr);
1568 return true;
1569 }
1570
1571 nir_ssa_def *gs_emit_vtx_idx = intrin->src[0].ssa;
1572 nir_ssa_def *current_vtx_per_prim = intrin->src[1].ssa;
1573 nir_ssa_def *gs_emit_vtx_addr = ngg_gs_emit_vertex_addr(b, gs_emit_vtx_idx, s);
1574
1575 for (unsigned slot = 0; slot < VARYING_SLOT_MAX; ++slot) {
1576 unsigned packed_location = util_bitcount64((b->shader->info.outputs_written & BITFIELD64_MASK(slot)));
1577
1578 for (unsigned comp = 0; comp < 4; ++comp) {
1579 gs_output_component_info *info = &s->output_component_info[slot][comp];
1580 if (info->stream != stream || !info->bit_size)
1581 continue;
1582
1583 /* Store the output to LDS */
1584 nir_ssa_def *out_val = nir_load_var(b, s->output_vars[slot][comp]);
1585 if (info->bit_size != 32)
1586 out_val = nir_u2u(b, out_val, info->bit_size);
1587
1588 nir_build_store_shared(b, out_val, gs_emit_vtx_addr, .base = packed_location * 16 + comp * 4, .align_mul = 4, .write_mask = 0x1u);
1589
1590 /* Clear the variable that holds the output */
1591 nir_store_var(b, s->output_vars[slot][comp], nir_ssa_undef(b, 1, 32), 0x1u);
1592 }
1593 }
1594
1595 /* Calculate and store per-vertex primitive flags based on vertex counts:
1596 * - bit 0: whether this vertex finishes a primitive (a real primitive, not the strip)
1597 * - bit 1: whether the primitive index is odd (if we are emitting triangle strips, otherwise always 0)
1598 * - bit 2: always 1 (so that we can use it for determining vertex liveness)
1599 */
1600
1601 nir_ssa_def *completes_prim = nir_ige(b, current_vtx_per_prim, nir_imm_int(b, s->num_vertices_per_primitive - 1));
1602 nir_ssa_def *prim_flag = nir_bcsel(b, completes_prim, nir_imm_int(b, 0b101u), nir_imm_int(b, 0b100u));
1603
1604 if (s->num_vertices_per_primitive == 3) {
1605 nir_ssa_def *odd = nir_iand_imm(b, current_vtx_per_prim, 1);
1606 prim_flag = nir_iadd_nuw(b, prim_flag, nir_ishl(b, odd, nir_imm_int(b, 1)));
1607 }
1608
1609 nir_build_store_shared(b, nir_u2u8(b, prim_flag), gs_emit_vtx_addr, .base = s->lds_offs_primflags + stream, .align_mul = 4u, .write_mask = 0x1u);
1610 nir_instr_remove(&intrin->instr);
1611 return true;
1612 }
1613
1614 static bool
lower_ngg_gs_end_primitive_with_counter(nir_builder * b,nir_intrinsic_instr * intrin,UNUSED lower_ngg_gs_state * s)1615 lower_ngg_gs_end_primitive_with_counter(nir_builder *b, nir_intrinsic_instr *intrin, UNUSED lower_ngg_gs_state *s)
1616 {
1617 b->cursor = nir_before_instr(&intrin->instr);
1618
1619 /* These are not needed, we can simply remove them */
1620 nir_instr_remove(&intrin->instr);
1621 return true;
1622 }
1623
1624 static bool
lower_ngg_gs_set_vertex_and_primitive_count(nir_builder * b,nir_intrinsic_instr * intrin,lower_ngg_gs_state * s)1625 lower_ngg_gs_set_vertex_and_primitive_count(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
1626 {
1627 b->cursor = nir_before_instr(&intrin->instr);
1628
1629 unsigned stream = nir_intrinsic_stream_id(intrin);
1630 if (stream > 0 && !(b->shader->info.gs.active_stream_mask & (1 << stream))) {
1631 nir_instr_remove(&intrin->instr);
1632 return true;
1633 }
1634
1635 s->found_out_vtxcnt[stream] = true;
1636
1637 /* Clear the primitive flags of non-emitted vertices */
1638 if (!nir_src_is_const(intrin->src[0]) || nir_src_as_uint(intrin->src[0]) < b->shader->info.gs.vertices_out)
1639 ngg_gs_clear_primflags(b, intrin->src[0].ssa, stream, s);
1640
1641 ngg_gs_shader_query(b, intrin, s);
1642 nir_instr_remove(&intrin->instr);
1643 return true;
1644 }
1645
1646 static bool
lower_ngg_gs_intrinsic(nir_builder * b,nir_instr * instr,void * state)1647 lower_ngg_gs_intrinsic(nir_builder *b, nir_instr *instr, void *state)
1648 {
1649 lower_ngg_gs_state *s = (lower_ngg_gs_state *) state;
1650
1651 if (instr->type != nir_instr_type_intrinsic)
1652 return false;
1653
1654 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1655
1656 if (intrin->intrinsic == nir_intrinsic_store_output)
1657 return lower_ngg_gs_store_output(b, intrin, s);
1658 else if (intrin->intrinsic == nir_intrinsic_emit_vertex_with_counter)
1659 return lower_ngg_gs_emit_vertex_with_counter(b, intrin, s);
1660 else if (intrin->intrinsic == nir_intrinsic_end_primitive_with_counter)
1661 return lower_ngg_gs_end_primitive_with_counter(b, intrin, s);
1662 else if (intrin->intrinsic == nir_intrinsic_set_vertex_and_primitive_count)
1663 return lower_ngg_gs_set_vertex_and_primitive_count(b, intrin, s);
1664
1665 return false;
1666 }
1667
1668 static void
lower_ngg_gs_intrinsics(nir_shader * shader,lower_ngg_gs_state * s)1669 lower_ngg_gs_intrinsics(nir_shader *shader, lower_ngg_gs_state *s)
1670 {
1671 nir_shader_instructions_pass(shader, lower_ngg_gs_intrinsic, nir_metadata_none, s);
1672 }
1673
1674 static void
ngg_gs_export_primitives(nir_builder * b,nir_ssa_def * max_num_out_prims,nir_ssa_def * tid_in_tg,nir_ssa_def * exporter_tid_in_tg,nir_ssa_def * primflag_0,lower_ngg_gs_state * s)1675 ngg_gs_export_primitives(nir_builder *b, nir_ssa_def *max_num_out_prims, nir_ssa_def *tid_in_tg,
1676 nir_ssa_def *exporter_tid_in_tg, nir_ssa_def *primflag_0,
1677 lower_ngg_gs_state *s)
1678 {
1679 nir_if *if_prim_export_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_prims));
1680
1681 /* Only bit 0 matters here - set it to 1 when the primitive should be null */
1682 nir_ssa_def *is_null_prim = nir_ixor(b, primflag_0, nir_imm_int(b, -1u));
1683
1684 nir_ssa_def *vtx_indices[3] = {0};
1685 vtx_indices[s->num_vertices_per_primitive - 1] = exporter_tid_in_tg;
1686 if (s->num_vertices_per_primitive >= 2)
1687 vtx_indices[s->num_vertices_per_primitive - 2] = nir_isub(b, exporter_tid_in_tg, nir_imm_int(b, 1));
1688 if (s->num_vertices_per_primitive == 3)
1689 vtx_indices[s->num_vertices_per_primitive - 3] = nir_isub(b, exporter_tid_in_tg, nir_imm_int(b, 2));
1690
1691 if (s->num_vertices_per_primitive == 3) {
1692 /* API GS outputs triangle strips, but NGG HW understands triangles.
1693 * We already know the triangles due to how we set the primitive flags, but we need to
1694 * make sure the vertex order is so that the front/back is correct, and the provoking vertex is kept.
1695 */
1696
1697 nir_ssa_def *is_odd = nir_ubfe(b, primflag_0, nir_imm_int(b, 1), nir_imm_int(b, 1));
1698 if (!s->provoking_vertex_last) {
1699 vtx_indices[1] = nir_iadd(b, vtx_indices[1], is_odd);
1700 vtx_indices[2] = nir_isub(b, vtx_indices[2], is_odd);
1701 } else {
1702 vtx_indices[0] = nir_iadd(b, vtx_indices[0], is_odd);
1703 vtx_indices[1] = nir_isub(b, vtx_indices[1], is_odd);
1704 }
1705 }
1706
1707 nir_ssa_def *arg = emit_pack_ngg_prim_exp_arg(b, s->num_vertices_per_primitive, vtx_indices, is_null_prim, false);
1708 nir_build_export_primitive_amd(b, arg);
1709 nir_pop_if(b, if_prim_export_thread);
1710 }
1711
1712 static void
ngg_gs_export_vertices(nir_builder * b,nir_ssa_def * max_num_out_vtx,nir_ssa_def * tid_in_tg,nir_ssa_def * out_vtx_lds_addr,lower_ngg_gs_state * s)1713 ngg_gs_export_vertices(nir_builder *b, nir_ssa_def *max_num_out_vtx, nir_ssa_def *tid_in_tg,
1714 nir_ssa_def *out_vtx_lds_addr, lower_ngg_gs_state *s)
1715 {
1716 nir_if *if_vtx_export_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_vtx));
1717 nir_ssa_def *exported_out_vtx_lds_addr = out_vtx_lds_addr;
1718
1719 if (!s->output_compile_time_known) {
1720 /* Vertex compaction.
1721 * The current thread will export a vertex that was live in another invocation.
1722 * Load the index of the vertex that the current thread will have to export.
1723 */
1724 nir_ssa_def *exported_vtx_idx = nir_build_load_shared(b, 1, 8, out_vtx_lds_addr, .base = s->lds_offs_primflags + 1, .align_mul = 1u);
1725 exported_out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, nir_u2u32(b, exported_vtx_idx), s);
1726 }
1727
1728 for (unsigned slot = 0; slot < VARYING_SLOT_MAX; ++slot) {
1729 if (!(b->shader->info.outputs_written & BITFIELD64_BIT(slot)))
1730 continue;
1731
1732 unsigned packed_location = util_bitcount64((b->shader->info.outputs_written & BITFIELD64_MASK(slot)));
1733 nir_io_semantics io_sem = { .location = slot, .num_slots = 1 };
1734
1735 for (unsigned comp = 0; comp < 4; ++comp) {
1736 gs_output_component_info *info = &s->output_component_info[slot][comp];
1737 if (info->stream != 0 || info->bit_size == 0)
1738 continue;
1739
1740 nir_ssa_def *load = nir_build_load_shared(b, 1, info->bit_size, exported_out_vtx_lds_addr, .base = packed_location * 16u + comp * 4u, .align_mul = 4u);
1741 nir_build_store_output(b, load, nir_imm_int(b, 0), .write_mask = 0x1u, .base = slot, .component = comp, .io_semantics = io_sem);
1742 }
1743 }
1744
1745 nir_build_export_vertex_amd(b);
1746 nir_pop_if(b, if_vtx_export_thread);
1747 }
1748
1749 static void
ngg_gs_setup_vertex_compaction(nir_builder * b,nir_ssa_def * vertex_live,nir_ssa_def * tid_in_tg,nir_ssa_def * exporter_tid_in_tg,lower_ngg_gs_state * s)1750 ngg_gs_setup_vertex_compaction(nir_builder *b, nir_ssa_def *vertex_live, nir_ssa_def *tid_in_tg,
1751 nir_ssa_def *exporter_tid_in_tg, lower_ngg_gs_state *s)
1752 {
1753 assert(vertex_live->bit_size == 1);
1754 nir_if *if_vertex_live = nir_push_if(b, vertex_live);
1755 {
1756 /* Setup the vertex compaction.
1757 * Save the current thread's id for the thread which will export the current vertex.
1758 * We reuse stream 1 of the primitive flag of the other thread's vertex for storing this.
1759 */
1760
1761 nir_ssa_def *exporter_lds_addr = ngg_gs_out_vertex_addr(b, exporter_tid_in_tg, s);
1762 nir_ssa_def *tid_in_tg_u8 = nir_u2u8(b, tid_in_tg);
1763 nir_build_store_shared(b, tid_in_tg_u8, exporter_lds_addr, .base = s->lds_offs_primflags + 1, .align_mul = 1u, .write_mask = 0x1u);
1764 }
1765 nir_pop_if(b, if_vertex_live);
1766 }
1767
1768 static nir_ssa_def *
ngg_gs_load_out_vtx_primflag_0(nir_builder * b,nir_ssa_def * tid_in_tg,nir_ssa_def * vtx_lds_addr,nir_ssa_def * max_num_out_vtx,lower_ngg_gs_state * s)1769 ngg_gs_load_out_vtx_primflag_0(nir_builder *b, nir_ssa_def *tid_in_tg, nir_ssa_def *vtx_lds_addr,
1770 nir_ssa_def *max_num_out_vtx, lower_ngg_gs_state *s)
1771 {
1772 nir_ssa_def *zero = nir_imm_int(b, 0);
1773
1774 nir_if *if_outvtx_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_vtx));
1775 nir_ssa_def *primflag_0 = nir_build_load_shared(b, 1, 8, vtx_lds_addr, .base = s->lds_offs_primflags, .align_mul = 4u);
1776 primflag_0 = nir_u2u32(b, primflag_0);
1777 nir_pop_if(b, if_outvtx_thread);
1778
1779 return nir_if_phi(b, primflag_0, zero);
1780 }
1781
1782 static void
ngg_gs_finale(nir_builder * b,lower_ngg_gs_state * s)1783 ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
1784 {
1785 nir_ssa_def *tid_in_tg = nir_build_load_local_invocation_index(b);
1786 nir_ssa_def *max_vtxcnt = nir_build_load_workgroup_num_input_vertices_amd(b);
1787 nir_ssa_def *max_prmcnt = max_vtxcnt; /* They are currently practically the same; both RADV and RadeonSI do this. */
1788 nir_ssa_def *out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, tid_in_tg, s);
1789
1790 if (s->output_compile_time_known) {
1791 /* When the output is compile-time known, the GS writes all possible vertices and primitives it can.
1792 * The gs_alloc_req needs to happen on one wave only, otherwise the HW hangs.
1793 */
1794 nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_build_load_subgroup_id(b), nir_imm_zero(b, 1, 32)));
1795 nir_build_alloc_vertices_and_primitives_amd(b, max_vtxcnt, max_prmcnt);
1796 nir_pop_if(b, if_wave_0);
1797 }
1798
1799 /* Workgroup barrier: wait for all GS threads to finish */
1800 nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
1801 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
1802
1803 nir_ssa_def *out_vtx_primflag_0 = ngg_gs_load_out_vtx_primflag_0(b, tid_in_tg, out_vtx_lds_addr, max_vtxcnt, s);
1804
1805 if (s->output_compile_time_known) {
1806 ngg_gs_export_primitives(b, max_vtxcnt, tid_in_tg, tid_in_tg, out_vtx_primflag_0, s);
1807 ngg_gs_export_vertices(b, max_vtxcnt, tid_in_tg, out_vtx_lds_addr, s);
1808 return;
1809 }
1810
1811 /* When the output vertex count is not known at compile time:
1812 * There may be gaps between invocations that have live vertices, but NGG hardware
1813 * requires that the invocations that export vertices are packed (ie. compact).
1814 * To ensure this, we need to repack invocations that have a live vertex.
1815 */
1816 nir_ssa_def *vertex_live = nir_ine(b, out_vtx_primflag_0, nir_imm_zero(b, 1, out_vtx_primflag_0->bit_size));
1817 wg_repack_result rep = repack_invocations_in_workgroup(b, vertex_live, s->lds_addr_gs_scratch, s->max_num_waves, s->wave_size);
1818
1819 nir_ssa_def *workgroup_num_vertices = rep.num_repacked_invocations;
1820 nir_ssa_def *exporter_tid_in_tg = rep.repacked_invocation_index;
1821
1822 /* When the workgroup emits 0 total vertices, we also must export 0 primitives (otherwise the HW can hang). */
1823 nir_ssa_def *any_output = nir_ine(b, workgroup_num_vertices, nir_imm_int(b, 0));
1824 max_prmcnt = nir_bcsel(b, any_output, max_prmcnt, nir_imm_int(b, 0));
1825
1826 /* Allocate export space. We currently don't compact primitives, just use the maximum number. */
1827 nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_build_load_subgroup_id(b), nir_imm_zero(b, 1, 32)));
1828 nir_build_alloc_vertices_and_primitives_amd(b, workgroup_num_vertices, max_prmcnt);
1829 nir_pop_if(b, if_wave_0);
1830
1831 /* Vertex compaction. This makes sure there are no gaps between threads that export vertices. */
1832 ngg_gs_setup_vertex_compaction(b, vertex_live, tid_in_tg, exporter_tid_in_tg, s);
1833
1834 /* Workgroup barrier: wait for all LDS stores to finish. */
1835 nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
1836 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
1837
1838 ngg_gs_export_primitives(b, max_prmcnt, tid_in_tg, exporter_tid_in_tg, out_vtx_primflag_0, s);
1839 ngg_gs_export_vertices(b, workgroup_num_vertices, tid_in_tg, out_vtx_lds_addr, s);
1840 }
1841
1842 void
ac_nir_lower_ngg_gs(nir_shader * shader,unsigned wave_size,unsigned max_workgroup_size,unsigned esgs_ring_lds_bytes,unsigned gs_out_vtx_bytes,unsigned gs_total_out_vtx_bytes,bool provoking_vertex_last)1843 ac_nir_lower_ngg_gs(nir_shader *shader,
1844 unsigned wave_size,
1845 unsigned max_workgroup_size,
1846 unsigned esgs_ring_lds_bytes,
1847 unsigned gs_out_vtx_bytes,
1848 unsigned gs_total_out_vtx_bytes,
1849 bool provoking_vertex_last)
1850 {
1851 nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1852 assert(impl);
1853
1854 lower_ngg_gs_state state = {
1855 .max_num_waves = DIV_ROUND_UP(max_workgroup_size, wave_size),
1856 .wave_size = wave_size,
1857 .lds_addr_gs_out_vtx = esgs_ring_lds_bytes,
1858 .lds_addr_gs_scratch = ALIGN(esgs_ring_lds_bytes + gs_total_out_vtx_bytes, 8u /* for the repacking code */),
1859 .lds_offs_primflags = gs_out_vtx_bytes,
1860 .lds_bytes_per_gs_out_vertex = gs_out_vtx_bytes + 4u,
1861 .provoking_vertex_last = provoking_vertex_last,
1862 };
1863
1864 unsigned lds_scratch_bytes = DIV_ROUND_UP(state.max_num_waves, 4u) * 4u;
1865 unsigned total_lds_bytes = state.lds_addr_gs_scratch + lds_scratch_bytes;
1866 shader->info.shared_size = total_lds_bytes;
1867
1868 nir_gs_count_vertices_and_primitives(shader, state.const_out_vtxcnt, state.const_out_prmcnt, 4u);
1869 state.output_compile_time_known = state.const_out_vtxcnt[0] == shader->info.gs.vertices_out &&
1870 state.const_out_prmcnt[0] != -1;
1871
1872 if (!state.output_compile_time_known)
1873 state.current_clear_primflag_idx_var = nir_local_variable_create(impl, glsl_uint_type(), "current_clear_primflag_idx");
1874
1875 if (shader->info.gs.output_primitive == GL_POINTS)
1876 state.num_vertices_per_primitive = 1;
1877 else if (shader->info.gs.output_primitive == GL_LINE_STRIP)
1878 state.num_vertices_per_primitive = 2;
1879 else if (shader->info.gs.output_primitive == GL_TRIANGLE_STRIP)
1880 state.num_vertices_per_primitive = 3;
1881 else
1882 unreachable("Invalid GS output primitive.");
1883
1884 /* Extract the full control flow. It is going to be wrapped in an if statement. */
1885 nir_cf_list extracted;
1886 nir_cf_extract(&extracted, nir_before_cf_list(&impl->body), nir_after_cf_list(&impl->body));
1887
1888 nir_builder builder;
1889 nir_builder *b = &builder; /* This is to avoid the & */
1890 nir_builder_init(b, impl);
1891 b->cursor = nir_before_cf_list(&impl->body);
1892
1893 /* Workgroup barrier: wait for ES threads */
1894 nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
1895 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
1896
1897 /* Wrap the GS control flow. */
1898 nir_if *if_gs_thread = nir_push_if(b, nir_build_has_input_primitive_amd(b));
1899
1900 /* Create and initialize output variables */
1901 for (unsigned slot = 0; slot < VARYING_SLOT_MAX; ++slot) {
1902 for (unsigned comp = 0; comp < 4; ++comp) {
1903 state.output_vars[slot][comp] = nir_local_variable_create(impl, glsl_uint_type(), "output");
1904 }
1905 }
1906
1907 nir_cf_reinsert(&extracted, b->cursor);
1908 b->cursor = nir_after_cf_list(&if_gs_thread->then_list);
1909 nir_pop_if(b, if_gs_thread);
1910
1911 /* Lower the GS intrinsics */
1912 lower_ngg_gs_intrinsics(shader, &state);
1913 b->cursor = nir_after_cf_list(&impl->body);
1914
1915 if (!state.found_out_vtxcnt[0]) {
1916 fprintf(stderr, "Could not find set_vertex_and_primitive_count for stream 0. This would hang your GPU.");
1917 abort();
1918 }
1919
1920 /* Emit the finale sequence */
1921 ngg_gs_finale(b, &state);
1922 nir_validate_shader(shader, "after emitting NGG GS");
1923
1924 /* Cleanup */
1925 nir_lower_vars_to_ssa(shader);
1926 nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
1927 nir_metadata_preserve(impl, nir_metadata_none);
1928 }
1929