/*
 * Copyright © 2021 Valve Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 *
 */

#include "ac_nir.h"
#include "nir_builder.h"

/*
 * These NIR passes are used to lower NIR cross-stage I/O intrinsics into the
 * memory accesses that actually happen on the HW.
 *
 * Each input and output has a 16-byte (4 dwords) slot reserved for it, and
 * can have up to 4 components. Each component is 32 bits.
 *
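 * As a rough orientation (a sketch of what nir_build_calc_io_offset and the
 * .align_offset math below amount to; `s` and `c` are just placeholder names
 * for the slot and component indices), the byte offset of component `c` of
 * slot `s`, relative to its vertex or patch, is:
 *
 * ```
 * offset = s * 16 + c * 4;   // 4 dwords per slot, 4 bytes per dword
 * ```
 *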
 * ## VS-TCS-TES I/O - Terminology:
 *
 * * patch - Group of vertices, used instead of primitives in tessellation
 * * per-vertex - input or output which can be different for every vertex.
 * * per-patch - input or output which applies to a patch (a group of vertices)
 *
 * ## VS-TCS-TES I/O - How it works:
 *
 * ```
 * SW model:    SW VS         SW TCS    tessellator    SW TES
 *                ┊             ┊             ┊          ┊
 *              ┌────┐        ┌────┐        ┌────┐    ┌─────┐
 * HW pipeline: │ LS │─╮   ╭─>│ HS │─╮   ╭─>│ FF │ ╭─>│VS/ES│
 *              └────┘ │   │  └────┘ │   │  └────┘ │  └─────┘
 * Memory:             ╰─>LDS<──╯    ╰─>VRAM───────╯
 * ```
 *
 * * SW VS runs as a HW LS (Local Shader, merged into HS on GFX9+),
 *   and SW TCS runs as HW HS (Hull Shader).
 *   SW TES runs as either HW VS or HW ES (Export Shader).
 * * LS and HS share the same LDS space.
 * * LS (SW VS) stores outputs to LDS to be read by HS (SW TCS).
 * * HS (SW TCS) stores outputs in LDS if the HS (SW TCS) reads them.
 * * HS (SW TCS) stores outputs in VRAM if the next stage (SW TES) reads them.
 *
 * Side note: some old HW supports having TES read from the same LDS space where LS/HS write, but
 * Mesa always stores HS outputs to VRAM to avoid forcing TES waves to run on the same CU as the LS/HS waves.
 *
 * ### Passing VS-TCS I/O in registers
 *
 * On GPUs that run SW VS and SW TCS on the same HW stage (HS on GFX9+),
 * I/O can be passed through registers instead of LDS when the following conditions are met:
 *
 * 1. TCS input and output patch size match
 * 2. Floating point execution modes in SW VS and SW TCS match
 * 3. The SW VS output is not written indirectly, and the corresponding SW TCS input is not read indirectly
 *
 * Some HS outputs could be passed through registers too, but this is a TODO.
 *
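 * When the above holds for an input, the caller sets the corresponding bit in
 * tcs_temp_only_inputs (see ac_nir_lower_ls_outputs_to_mem), and the LS output
 * store keeps the value in registers instead of going through LDS. A rough
 * sketch of the check this turns into (mirroring lower_ls_output_store below):
 *
 * ```
 * // Only honored when tcs_in_out_eq; see ac_nir_lower_ls_outputs_to_mem.
 * if (match_mask(MESA_SHADER_VERTEX, intrin, st->tcs_temp_only_inputs, false))
 *    return false; // keep the store_output intrinsic, skip the LDS store
 * ```
 *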
 * ### LDS layout used by VS-TCS:
 *
 * ```
 * TCS per-vertex inputs for patch 0  <─── 0
 * TCS per-vertex inputs for patch 1
 * TCS per-vertex inputs for patch 2  <─── hs_per_vertex_input_lds_offset (rel_patch_id = 2)
 * ...
 * TCS per-vertex outputs for patch 0 <─── output_patch0_offset
 * TCS per-patch outputs for patch 0  <─── output_patch0_patch_data_offset
 * TCS per-vertex outputs for patch 1
 * TCS per-patch outputs for patch 1
 * TCS per-vertex outputs for patch 2 <─── hs_output_lds_offset (rel_patch_id = 2, per-vertex)
 * TCS per-patch outputs for patch 2  <─── hs_output_lds_offset (rel_patch_id = 2, per-patch)
 * ...
 * ```
 *
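 * A rough sketch of the byte offsets this layout implies (this is what
 * hs_per_vertex_input_lds_offset and hs_output_lds_offset compute below;
 * the variable names are only illustrative):
 *
 * ```
 * input_vertex_stride  = tcs_num_reserved_inputs * 16;
 * input_patch_stride   = tcs_in_vtxcnt * input_vertex_stride;
 * output_vertex_stride = tcs_num_reserved_outputs * 16;
 * output_patch_stride  = tcs_vertices_out * output_vertex_stride +
 *                        tcs_num_reserved_patch_outputs * 16;
 *
 * per_vertex_input_offset  = rel_patch_id * input_patch_stride +
 *                            vertex_index * input_vertex_stride + io_offset;
 *
 * output_patch0_offset     = tcs_num_patches * input_patch_stride;
 * per_vertex_output_offset = output_patch0_offset +
 *                            rel_patch_id * output_patch_stride +
 *                            vertex_index * output_vertex_stride + io_offset;
 * per_patch_output_offset  = output_patch0_offset +
 *                            rel_patch_id * output_patch_stride +
 *                            tcs_vertices_out * output_vertex_stride + io_offset;
 * ```
 *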
 * ### VRAM layout used by TCS-TES I/O:
 *
 * ```
 * attr 0 of patch 0 vertex 0   <─── "off-chip LDS" offset
 * attr 0 of patch 0 vertex 1
 * attr 0 of patch 0 vertex 2
 * ...
 * attr 0 of patch 1 vertex 0
 * attr 0 of patch 1 vertex 1
 * attr 0 of patch 1 vertex 2   <─── hs_per_vertex_output_vmem_offset (attribute slot = 0, rel_patch_id = 1, vertex index = 1)
 * ...
 * attr 0 of patch 2 vertex 0
 * attr 0 of patch 2 vertex 1
 * attr 0 of patch 2 vertex 2
 * ...
 * attr 1 of patch 0 vertex 0
 * attr 1 of patch 0 vertex 1
 * attr 1 of patch 0 vertex 2
 * ...
 * ...
 * per-patch attr 0 of patch 0
 * per-patch attr 0 of patch 1
 * per-patch attr 0 of patch 2  <─── hs_per_patch_output_vmem_offset (attribute slot = 0, rel_patch_id = 2)
 * ...
 * per-patch attr 1 of patch 0
 * per-patch attr 1 of patch 1
 * per-patch attr 1 of patch 2
 * ...
 * ```
 *
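 * A rough sketch of the byte offsets for this layout (this is what
 * hs_per_vertex_output_vmem_offset and hs_per_patch_output_vmem_offset compute
 * below; the variable names are only illustrative):
 *
 * ```
 * attr_stride = tcs_num_patches * out_vertices_per_patch * 16;
 *
 * per_vertex_vmem_offset = attr_slot * attr_stride +
 *                          rel_patch_id * out_vertices_per_patch * 16 +
 *                          vertex_index * 16;
 *
 * per_patch_base        = tcs_num_patches * out_vertices_per_patch *
 *                         tcs_num_reserved_outputs * 16;
 * per_patch_vmem_offset = per_patch_base +
 *                         attr_slot * tcs_num_patches * 16 +
 *                         rel_patch_id * 16;
 * ```
 *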
 */

typedef struct {
   /* Which hardware generation we're dealing with */
   enum chip_class chip_class;

   /* True if the merged VS+TCS (on GFX9+) has the same
    * input and output patch size.
    */
   bool tcs_in_out_eq;

   /* Bit mask of TCS per-vertex inputs (VS outputs) which
    * are passed between the two stages only in temporaries (registers).
    */
   uint64_t tcs_temp_only_inputs;

   /* Bit mask of TCS outputs read by TES. */
   uint64_t tes_inputs_read;
   uint64_t tes_patch_inputs_read;

   /* Whether TES reads the tess factors. */
   bool tes_reads_tessfactors;

   /* Number of inputs for which memory should be reserved.
    * When compacted, this should be the number of linked inputs.
    */
   unsigned tcs_num_reserved_inputs;
   unsigned tcs_num_reserved_outputs;
   unsigned tcs_num_reserved_patch_outputs;

   /* Byte offset (relative to the start of the per-patch outputs)
    * where the tessellation levels are stored.
    */
   unsigned tcs_tess_lvl_in_loc;
   unsigned tcs_tess_lvl_out_loc;

} lower_tess_io_state;

static bool
match_mask(gl_shader_stage stage,
           nir_intrinsic_instr *intrin,
           uint64_t mask,
           bool match_indirect)
{
   bool indirect = !nir_src_is_const(*nir_get_io_offset_src(intrin));
   if (indirect)
      return match_indirect;

   uint64_t slot = nir_intrinsic_io_semantics(intrin).location;
   if (stage == MESA_SHADER_TESS_CTRL &&
       intrin->intrinsic != nir_intrinsic_load_per_vertex_input &&
       intrin->intrinsic != nir_intrinsic_store_per_vertex_output)
      slot -= VARYING_SLOT_PATCH0;

   return (UINT64_C(1) << slot) & mask;
}

static bool
tcs_output_needs_vmem(nir_intrinsic_instr *intrin,
                      lower_tess_io_state *st)
{
   uint64_t mask = intrin->intrinsic == nir_intrinsic_store_per_vertex_output
                   ? st->tes_inputs_read
                   : st->tes_patch_inputs_read;

   return match_mask(MESA_SHADER_TESS_CTRL, intrin, mask, true);
}

static bool
tcs_output_needs_lds(nir_intrinsic_instr *intrin,
                     nir_shader *shader)
{
   uint64_t mask = intrin->intrinsic == nir_intrinsic_store_per_vertex_output
                   ? shader->info.outputs_read
                   : shader->info.patch_outputs_read;

   return match_mask(MESA_SHADER_TESS_CTRL, intrin, mask, true);
}

static bool
lower_ls_output_store(nir_builder *b,
                      nir_instr *instr,
                      void *state)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   if (intrin->intrinsic != nir_intrinsic_store_output)
      return false;

   lower_tess_io_state *st = (lower_tess_io_state *) state;

   /* If this is a temp-only TCS input, we don't need to use shared memory at all. */
   if (match_mask(MESA_SHADER_VERTEX, intrin, st->tcs_temp_only_inputs, false))
      return false;

   b->cursor = nir_before_instr(instr);

   nir_ssa_def *vertex_idx = nir_build_load_local_invocation_index(b);
   nir_ssa_def *base_off_var = nir_imul_imm(b, vertex_idx, st->tcs_num_reserved_inputs * 16u);

   nir_ssa_def *io_off = nir_build_calc_io_offset(b, intrin, nir_imm_int(b, 16u), 4u);
   unsigned write_mask = nir_intrinsic_write_mask(intrin);

   nir_ssa_def *off = nir_iadd_nuw(b, base_off_var, io_off);
   nir_build_store_shared(b, intrin->src[0].ssa, off, .write_mask = write_mask,
                          .align_mul = 16u, .align_offset = (nir_intrinsic_component(intrin) * 4u) % 16u);

   /* NOTE: don't remove the store_output intrinsic on GFX9+ when tcs_in_out_eq,
    * it will be used by same-invocation TCS input loads.
    */
   if (!st->tcs_in_out_eq)
      nir_instr_remove(instr);

   return true;
}

static bool
filter_load_tcs_per_vertex_input(const nir_instr *instr,
                                 const void *state)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   lower_tess_io_state *st = (lower_tess_io_state *) state;
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   if (intrin->intrinsic != nir_intrinsic_load_per_vertex_input)
      return false;
   if (!st->tcs_in_out_eq)
      return true;

   /* tcs_in_out_eq: a same-invocation input load, without indirect offset,
    * can use temporaries, no need to use shared memory.
    */
   nir_src *off_src = nir_get_io_offset_src(intrin);
   nir_src *vertex_index_src = nir_get_io_vertex_index_src(intrin);
   nir_instr *vertex_index_instr = vertex_index_src->ssa->parent_instr;

   bool can_use_temps = nir_src_is_const(*off_src) &&
                        vertex_index_instr->type == nir_instr_type_intrinsic &&
                        nir_instr_as_intrinsic(vertex_index_instr)->intrinsic == nir_intrinsic_load_invocation_id;

   return !can_use_temps;
}

static nir_ssa_def *
hs_per_vertex_input_lds_offset(nir_builder *b,
                               lower_tess_io_state *st,
                               nir_intrinsic_instr *instr)
{
   unsigned tcs_in_vertex_stride = st->tcs_num_reserved_inputs * 16u;
   nir_ssa_def *tcs_in_vtxcnt = nir_build_load_patch_vertices_in(b);
   nir_ssa_def *rel_patch_id = nir_build_load_tess_rel_patch_id_amd(b);

   nir_ssa_def *tcs_in_patch_stride = nir_imul_imm(b, tcs_in_vtxcnt, tcs_in_vertex_stride);
   nir_ssa_def *tcs_in_current_patch_offset = nir_imul(b, rel_patch_id, tcs_in_patch_stride);

   nir_ssa_def *vertex_index = nir_get_io_vertex_index_src(instr)->ssa;
   nir_ssa_def *vertex_index_off = nir_imul_imm(b, vertex_index, tcs_in_vertex_stride);

   nir_ssa_def *io_offset = nir_build_calc_io_offset(b, instr, nir_imm_int(b, 16u), 4u);

   return nir_iadd_nuw(b, nir_iadd_nuw(b, tcs_in_current_patch_offset, vertex_index_off), io_offset);
}

static nir_ssa_def *
hs_output_lds_offset(nir_builder *b,
                     lower_tess_io_state *st,
                     nir_intrinsic_instr *intrin)
{
   bool per_vertex = intrin &&
                     (intrin->intrinsic == nir_intrinsic_store_per_vertex_output ||
                      intrin->intrinsic == nir_intrinsic_load_per_vertex_output);

   unsigned output_vertex_size = st->tcs_num_reserved_outputs * 16u;
   unsigned pervertex_output_patch_size = b->shader->info.tess.tcs_vertices_out * output_vertex_size;
   unsigned output_patch_stride = pervertex_output_patch_size + st->tcs_num_reserved_patch_outputs * 16u;

   nir_ssa_def *tcs_in_vtxcnt = nir_build_load_patch_vertices_in(b);
   nir_ssa_def *tcs_num_patches = nir_build_load_tcs_num_patches_amd(b);
   nir_ssa_def *input_patch_size = nir_imul_imm(b, tcs_in_vtxcnt, st->tcs_num_reserved_inputs * 16u);
   nir_ssa_def *output_patch0_offset = nir_imul(b, input_patch_size, tcs_num_patches);

   nir_ssa_def *off = intrin
                    ? nir_build_calc_io_offset(b, intrin, nir_imm_int(b, 16u), 4u)
                    : nir_imm_int(b, 0);

   nir_ssa_def *rel_patch_id = nir_build_load_tess_rel_patch_id_amd(b);
   nir_ssa_def *patch_offset = nir_imul_imm(b, rel_patch_id, output_patch_stride);
   nir_ssa_def *output_patch_offset = nir_iadd_nuw(b, patch_offset, output_patch0_offset);

   if (per_vertex) {
      nir_ssa_def *vertex_index = nir_ssa_for_src(b, *nir_get_io_vertex_index_src(intrin), 1);
      nir_ssa_def *vertex_index_off = nir_imul_imm(b, vertex_index, output_vertex_size);

      off = nir_iadd_nuw(b, off, vertex_index_off);
      return nir_iadd_nuw(b, off, output_patch_offset);
   } else {
      off = nir_iadd_imm_nuw(b, off, pervertex_output_patch_size);
      return nir_iadd_nuw(b, off, output_patch_offset);
   }
}

static nir_ssa_def *
hs_per_vertex_output_vmem_offset(nir_builder *b,
                                 lower_tess_io_state *st,
                                 nir_intrinsic_instr *intrin)
{
   nir_ssa_def *out_vertices_per_patch = b->shader->info.stage == MESA_SHADER_TESS_CTRL
                                         ? nir_imm_int(b, b->shader->info.tess.tcs_vertices_out)
                                         : nir_build_load_patch_vertices_in(b);

   nir_ssa_def *tcs_num_patches = nir_build_load_tcs_num_patches_amd(b);
   nir_ssa_def *attr_stride = nir_imul(b, tcs_num_patches, nir_imul_imm(b, out_vertices_per_patch, 16u));
   nir_ssa_def *io_offset = nir_build_calc_io_offset(b, intrin, attr_stride, 4u);

   nir_ssa_def *rel_patch_id = nir_build_load_tess_rel_patch_id_amd(b);
   nir_ssa_def *patch_offset = nir_imul(b, rel_patch_id, nir_imul_imm(b, out_vertices_per_patch, 16u));

   nir_ssa_def *vertex_index = nir_ssa_for_src(b, *nir_get_io_vertex_index_src(intrin), 1);
   nir_ssa_def *vertex_index_off = nir_imul_imm(b, vertex_index, 16u);

   return nir_iadd_nuw(b, nir_iadd_nuw(b, patch_offset, vertex_index_off), io_offset);
}

static nir_ssa_def *
hs_per_patch_output_vmem_offset(nir_builder *b,
                                lower_tess_io_state *st,
                                nir_intrinsic_instr *intrin,
                                unsigned const_base_offset)
{
   nir_ssa_def *out_vertices_per_patch = b->shader->info.stage == MESA_SHADER_TESS_CTRL
                                         ? nir_imm_int(b, b->shader->info.tess.tcs_vertices_out)
                                         : nir_build_load_patch_vertices_in(b);

   nir_ssa_def *tcs_num_patches = nir_build_load_tcs_num_patches_amd(b);
   nir_ssa_def *per_vertex_output_patch_size = nir_imul_imm(b, out_vertices_per_patch, st->tcs_num_reserved_outputs * 16u);
   nir_ssa_def *per_patch_data_offset = nir_imul(b, tcs_num_patches, per_vertex_output_patch_size);

   nir_ssa_def *off = intrin
                    ? nir_build_calc_io_offset(b, intrin, nir_imul_imm(b, tcs_num_patches, 16u), 4u)
                    : nir_imm_int(b, 0);

   if (const_base_offset)
      off = nir_iadd_nuw(b, off, nir_imul_imm(b, tcs_num_patches, const_base_offset));

   nir_ssa_def *rel_patch_id = nir_build_load_tess_rel_patch_id_amd(b);
   nir_ssa_def *patch_offset = nir_imul_imm(b, rel_patch_id, 16u);
   off = nir_iadd_nuw(b, off, per_patch_data_offset);
   return nir_iadd_nuw(b, off, patch_offset);
}

static nir_ssa_def *
lower_hs_per_vertex_input_load(nir_builder *b,
                               nir_instr *instr,
                               void *state)
{
   lower_tess_io_state *st = (lower_tess_io_state *) state;
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   nir_ssa_def *off = hs_per_vertex_input_lds_offset(b, st, intrin);
   return nir_build_load_shared(b, intrin->dest.ssa.num_components, intrin->dest.ssa.bit_size, off,
                                .align_mul = 16u, .align_offset = (nir_intrinsic_component(intrin) * 4u) % 16u);
}

static void
lower_hs_output_store(nir_builder *b,
                      nir_intrinsic_instr *intrin,
                      lower_tess_io_state *st)
{
   assert(intrin->intrinsic == nir_intrinsic_store_per_vertex_output ||
          intrin->intrinsic == nir_intrinsic_store_output);

   nir_io_semantics semantics = nir_intrinsic_io_semantics(intrin);
   nir_ssa_def *store_val = intrin->src[0].ssa;
   unsigned write_mask = nir_intrinsic_write_mask(intrin);
   bool is_tess_factor = semantics.location == VARYING_SLOT_TESS_LEVEL_INNER ||
                         semantics.location == VARYING_SLOT_TESS_LEVEL_OUTER;
   bool write_to_vmem = !is_tess_factor && tcs_output_needs_vmem(intrin, st);
   bool write_to_lds = is_tess_factor || tcs_output_needs_lds(intrin, b->shader);

   if (write_to_vmem) {
      nir_ssa_def *vmem_off = intrin->intrinsic == nir_intrinsic_store_per_vertex_output
                            ? hs_per_vertex_output_vmem_offset(b, st, intrin)
                            : hs_per_patch_output_vmem_offset(b, st, intrin, 0);

      nir_ssa_def *hs_ring_tess_offchip = nir_build_load_ring_tess_offchip_amd(b);
      nir_ssa_def *offchip_offset = nir_build_load_ring_tess_offchip_offset_amd(b);
      nir_build_store_buffer_amd(b, store_val, hs_ring_tess_offchip, vmem_off, offchip_offset, .write_mask = write_mask, .memory_modes = nir_var_shader_out);
   }

   if (write_to_lds) {
      /* Remember driver location of tess factors, so we can read them later */
      if (semantics.location == VARYING_SLOT_TESS_LEVEL_INNER)
         st->tcs_tess_lvl_in_loc = nir_intrinsic_base(intrin) * 16u;
      else if (semantics.location == VARYING_SLOT_TESS_LEVEL_OUTER)
         st->tcs_tess_lvl_out_loc = nir_intrinsic_base(intrin) * 16u;

      nir_ssa_def *lds_off = hs_output_lds_offset(b, st, intrin);
      nir_build_store_shared(b, store_val, lds_off, .write_mask = write_mask,
                             .align_mul = 16u, .align_offset = (nir_intrinsic_component(intrin) * 4u) % 16u);
   }
}

static nir_ssa_def *
lower_hs_output_load(nir_builder *b,
                     nir_intrinsic_instr *intrin,
                     lower_tess_io_state *st)
{
   nir_ssa_def *off = hs_output_lds_offset(b, st, intrin);
   return nir_build_load_shared(b, intrin->dest.ssa.num_components, intrin->dest.ssa.bit_size, off,
                                .align_mul = 16u, .align_offset = (nir_intrinsic_component(intrin) * 4u) % 16u);
}

static void
update_hs_scoped_barrier(nir_intrinsic_instr *intrin)
{
   /* Output loads and stores are lowered to shared memory access,
    * so we have to update the barriers to also reflect this.
    */
   unsigned mem_modes = nir_intrinsic_memory_modes(intrin);
   if (mem_modes & nir_var_shader_out)
      mem_modes |= nir_var_mem_shared;
   nir_intrinsic_set_memory_modes(intrin, mem_modes);
}

static nir_ssa_def *
lower_hs_output_access(nir_builder *b,
                       nir_instr *instr,
                       void *state)
{
   lower_tess_io_state *st = (lower_tess_io_state *) state;
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   if (intrin->intrinsic == nir_intrinsic_store_output ||
       intrin->intrinsic == nir_intrinsic_store_per_vertex_output) {
      lower_hs_output_store(b, intrin, st);
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
   } else if (intrin->intrinsic == nir_intrinsic_load_output ||
              intrin->intrinsic == nir_intrinsic_load_per_vertex_output) {
      return lower_hs_output_load(b, intrin, st);
   } else if (intrin->intrinsic == nir_intrinsic_scoped_barrier) {
      update_hs_scoped_barrier(intrin);
      return NIR_LOWER_INSTR_PROGRESS;
   } else {
      unreachable("intrinsic not supported by lower_hs_output_access");
   }
}

static void
hs_emit_write_tess_factors(nir_shader *shader,
                           lower_tess_io_state *st)
{
   unsigned outer_comps;
   unsigned inner_comps;

   switch (shader->info.tess.primitive_mode) {
   case GL_ISOLINES:
      outer_comps = 2;
      inner_comps = 0;
      break;
   case GL_TRIANGLES:
      outer_comps = 3;
      inner_comps = 1;
      break;
   case GL_QUADS:
      outer_comps = 4;
      inner_comps = 2;
      break;
   default:
      unreachable("invalid primitive mode");
      return;
   }

   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   assert(impl);
   nir_block *last_block = nir_impl_last_block(impl);
   assert(last_block);

   /* We assume there is always a single end block in the shader. */

   nir_builder builder;
   nir_builder *b = &builder; /* This is to avoid writing &builder on every use. */
   nir_builder_init(b, impl);
   b->cursor = nir_after_block(last_block);

   nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
                         .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_shader_out|nir_var_mem_shared);

   nir_ssa_def *invocation_id = nir_build_load_invocation_id(b);

   /* Only the 1st invocation of each patch needs to do this. */
   nir_if *invocation_id_zero = nir_push_if(b, nir_ieq_imm(b, invocation_id, 0));

   /* The descriptor where tess factors have to be stored by the shader. */
   nir_ssa_def *tessfactor_ring = nir_build_load_ring_tess_factors_amd(b);

   /* Base LDS address of per-patch outputs in the current patch. */
   nir_ssa_def *lds_base = hs_output_lds_offset(b, st, NULL);

   /* Load all tessellation factors (aka. tess levels) from LDS. */
   nir_ssa_def *tessfactors_outer = nir_build_load_shared(b, outer_comps, 32, lds_base, .base = st->tcs_tess_lvl_out_loc,
                                                          .align_mul = 16u, .align_offset = st->tcs_tess_lvl_out_loc % 16u);
   nir_ssa_def *tessfactors_inner = inner_comps
                                    ? nir_build_load_shared(b, inner_comps, 32, lds_base, .base = st->tcs_tess_lvl_in_loc,
                                                            .align_mul = 16u, .align_offset = st->tcs_tess_lvl_in_loc % 16u)
                                    : NULL;

   nir_ssa_def *rel_patch_id = nir_build_load_tess_rel_patch_id_amd(b);
   nir_ssa_def *tess_factors_base = nir_build_load_ring_tess_factors_offset_amd(b);
   nir_ssa_def *tess_factors_offset = nir_imul_imm(b, rel_patch_id, (inner_comps + outer_comps) * 4u);
   unsigned tess_factors_const_offset = 0;

   if (st->chip_class <= GFX8) {
      /* Store the dynamic HS control word. */
      nir_if *rel_patch_id_zero = nir_push_if(b, nir_ieq_imm(b, rel_patch_id, 0));
      nir_ssa_def *ctrlw = nir_imm_int(b, 0x80000000u);
      nir_build_store_buffer_amd(b, ctrlw, tessfactor_ring, nir_imm_zero(b, 1, 32), tess_factors_base, .write_mask = 0x1u);
      tess_factors_const_offset += 4;
      nir_pop_if(b, rel_patch_id_zero);
   }

   /* Store tess factors for the tessellator */
   if (shader->info.tess.primitive_mode == GL_ISOLINES) {
      /* LINES reversal */
      nir_ssa_def *t = nir_vec2(b, nir_channel(b, tessfactors_outer, 1), nir_channel(b, tessfactors_outer, 0));
      nir_build_store_buffer_amd(b, t, tessfactor_ring, tess_factors_offset, tess_factors_base, .base = tess_factors_const_offset, .write_mask = 0xfu);
   } else if (shader->info.tess.primitive_mode == GL_TRIANGLES) {
      nir_ssa_def *t = nir_vec4(b, nir_channel(b, tessfactors_outer, 0), nir_channel(b, tessfactors_outer, 1),
                                nir_channel(b, tessfactors_outer, 2), nir_channel(b, tessfactors_inner, 0));
      nir_build_store_buffer_amd(b, t, tessfactor_ring, tess_factors_offset, tess_factors_base, .base = tess_factors_const_offset, .write_mask = 0xfu);
   } else {
      nir_build_store_buffer_amd(b, tessfactors_outer, tessfactor_ring, tess_factors_offset, tess_factors_base, .base = tess_factors_const_offset, .write_mask = 0xfu);
      nir_build_store_buffer_amd(b, tessfactors_inner, tessfactor_ring, tess_factors_offset, tess_factors_base, .base = tess_factors_const_offset + 4u * outer_comps, .write_mask = 0xfu);
   }

   if (st->tes_reads_tessfactors) {
      /* Store to offchip for TES to read - only if TES actually reads them */
      nir_ssa_def *hs_ring_tess_offchip = nir_build_load_ring_tess_offchip_amd(b);
      nir_ssa_def *offchip_offset = nir_build_load_ring_tess_offchip_offset_amd(b);

      nir_ssa_def *vmem_off_outer = hs_per_patch_output_vmem_offset(b, st, NULL, st->tcs_tess_lvl_out_loc);
      nir_build_store_buffer_amd(b, tessfactors_outer, hs_ring_tess_offchip, vmem_off_outer, offchip_offset, .write_mask = 0xfu, .memory_modes = nir_var_shader_out);

      if (inner_comps) {
         nir_ssa_def *vmem_off_inner = hs_per_patch_output_vmem_offset(b, st, NULL, st->tcs_tess_lvl_in_loc);
         nir_build_store_buffer_amd(b, tessfactors_inner, hs_ring_tess_offchip, vmem_off_inner, offchip_offset, .write_mask = 0xfu, .memory_modes = nir_var_shader_out);
      }
   }

   nir_pop_if(b, invocation_id_zero);

   nir_metadata_preserve(impl, nir_metadata_none);
}

static nir_ssa_def *
lower_tes_input_load(nir_builder *b,
                     nir_instr *instr,
                     void *state)
{
   lower_tess_io_state *st = (lower_tess_io_state *) state;
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   nir_ssa_def *offchip_ring = nir_build_load_ring_tess_offchip_amd(b);
   nir_ssa_def *offchip_offset = nir_build_load_ring_tess_offchip_offset_amd(b);
   nir_ssa_def *off = intrin->intrinsic == nir_intrinsic_load_per_vertex_input
                    ? hs_per_vertex_output_vmem_offset(b, st, intrin)
                    : hs_per_patch_output_vmem_offset(b, st, intrin, 0);

   return nir_build_load_buffer_amd(b, intrin->dest.ssa.num_components, intrin->dest.ssa.bit_size, offchip_ring, off, offchip_offset);
}

static bool
filter_hs_output_access(const nir_instr *instr,
                        UNUSED const void *st)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   return intrin->intrinsic == nir_intrinsic_store_output ||
          intrin->intrinsic == nir_intrinsic_store_per_vertex_output ||
          intrin->intrinsic == nir_intrinsic_load_output ||
          intrin->intrinsic == nir_intrinsic_load_per_vertex_output ||
          intrin->intrinsic == nir_intrinsic_scoped_barrier;
}

static bool
filter_any_input_access(const nir_instr *instr,
                        UNUSED const void *st)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   return intrin->intrinsic == nir_intrinsic_load_input ||
          intrin->intrinsic == nir_intrinsic_load_per_vertex_input;
}

void
ac_nir_lower_ls_outputs_to_mem(nir_shader *shader,
                               bool tcs_in_out_eq,
                               uint64_t tcs_temp_only_inputs,
                               unsigned num_reserved_ls_outputs)
{
   assert(shader->info.stage == MESA_SHADER_VERTEX);

   lower_tess_io_state state = {
      .tcs_num_reserved_inputs = num_reserved_ls_outputs,
      .tcs_in_out_eq = tcs_in_out_eq,
      .tcs_temp_only_inputs = tcs_in_out_eq ? tcs_temp_only_inputs : 0,
   };

   nir_shader_instructions_pass(shader,
                                lower_ls_output_store,
                                nir_metadata_block_index | nir_metadata_dominance,
                                &state);
}

void
ac_nir_lower_hs_inputs_to_mem(nir_shader *shader,
                              bool tcs_in_out_eq,
                              unsigned num_reserved_tcs_inputs)
{
   assert(shader->info.stage == MESA_SHADER_TESS_CTRL);

   lower_tess_io_state state = {
      .tcs_in_out_eq = tcs_in_out_eq,
      .tcs_num_reserved_inputs = num_reserved_tcs_inputs,
   };

   nir_shader_lower_instructions(shader,
                                 filter_load_tcs_per_vertex_input,
                                 lower_hs_per_vertex_input_load,
                                 &state);
}

void
ac_nir_lower_hs_outputs_to_mem(nir_shader *shader,
                               enum chip_class chip_class,
                               bool tes_reads_tessfactors,
                               uint64_t tes_inputs_read,
                               uint64_t tes_patch_inputs_read,
                               unsigned num_reserved_tcs_inputs,
                               unsigned num_reserved_tcs_outputs,
                               unsigned num_reserved_tcs_patch_outputs,
                               bool emit_tess_factor_write)
{
   assert(shader->info.stage == MESA_SHADER_TESS_CTRL);

   lower_tess_io_state state = {
      .chip_class = chip_class,
      .tes_reads_tessfactors = tes_reads_tessfactors,
      .tes_inputs_read = tes_inputs_read,
      .tes_patch_inputs_read = tes_patch_inputs_read,
      .tcs_num_reserved_inputs = num_reserved_tcs_inputs,
      .tcs_num_reserved_outputs = num_reserved_tcs_outputs,
      .tcs_num_reserved_patch_outputs = num_reserved_tcs_patch_outputs,
   };

   nir_shader_lower_instructions(shader,
                                 filter_hs_output_access,
                                 lower_hs_output_access,
                                 &state);

   if (emit_tess_factor_write)
      hs_emit_write_tess_factors(shader, &state);
}

void
ac_nir_lower_tes_inputs_to_mem(nir_shader *shader,
                               unsigned num_reserved_tcs_outputs,
                               unsigned num_reserved_tcs_patch_outputs)
{
   assert(shader->info.stage == MESA_SHADER_TESS_EVAL);

   lower_tess_io_state state = {
      .tcs_num_reserved_outputs = num_reserved_tcs_outputs,
      .tcs_num_reserved_patch_outputs = num_reserved_tcs_patch_outputs,
   };

   nir_shader_lower_instructions(shader,
                                 filter_any_input_access,
                                 lower_tes_input_load,
                                 &state);
}

typedef struct
{
   unsigned patch_vtx_in;
   unsigned tcs_num_patches;
   unsigned options;
} lower_tess_to_const_state;

static bool
filter_const_lowerable_tess_intrinsics(const nir_instr *instr,
                                       const void *state)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   lower_tess_to_const_state *st = (lower_tess_to_const_state *) state;
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   return ((st->options & ac_nir_lower_patch_vtx_in) && intrin->intrinsic == nir_intrinsic_load_patch_vertices_in) ||
          ((st->options & ac_nir_lower_num_patches) && intrin->intrinsic == nir_intrinsic_load_tcs_num_patches_amd);
}

static nir_ssa_def *
lower_tess_intrinsics_to_const(nir_builder *b,
                               nir_instr *instr,
                               void *state)
{
   lower_tess_to_const_state *st = (lower_tess_to_const_state *) state;
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   switch (intrin->intrinsic) {
   case nir_intrinsic_load_patch_vertices_in:
      return nir_imm_int(b, st->patch_vtx_in);
   case nir_intrinsic_load_tcs_num_patches_amd:
      return nir_imm_int(b, st->tcs_num_patches);
   default:
      unreachable("Unsupported tess intrinsic.");
   }
}

void
ac_nir_lower_tess_to_const(nir_shader *shader,
                           unsigned patch_vtx_in,
                           unsigned tcs_num_patches,
                           unsigned options)
{
   lower_tess_to_const_state st = {
      .patch_vtx_in = patch_vtx_in,
      .tcs_num_patches = tcs_num_patches,
      .options = options,
   };

   nir_shader_lower_instructions(shader,
                                 filter_const_lowerable_tess_intrinsics,
                                 lower_tess_intrinsics_to_const,
                                 &st);
}