/*
 * Copyright © 2021 Valve Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 *
 */

#include "ac_nir.h"
#include "nir_builder.h"

/*
 * Lower NIR cross-stage I/O intrinsics into the memory accesses that actually happen on the HW.
 *
 * These HW stages are used only when a Geometry Shader is used.
 * The Export Shader (ES) is the HW stage that runs the SW stage before GS; it can be either VS or TES.
 *
 * * GFX6-8:
 *   ES and GS are separate HW stages.
 *   I/O is passed between them through VRAM.
 * * GFX9+:
 *   ES and GS are merged into a single HW stage.
 *   I/O is passed between them through LDS.
 *
 */

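/* A note on the layout assumed by the code below: each ES output location
 * occupies 16 bytes (a full vec4 of dwords), so every ES vertex reserves
 * num_reserved_es_outputs * 16 bytes. On GFX9+ this item lives in LDS and is
 * indexed by the local invocation index; on GFX6-8 the same data goes through
 * the swizzled ESGS ring buffer in VRAM.
 */
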
typedef struct {
   /* Which hardware generation we're dealing with */
   enum chip_class chip_class;

   /* Number of ES outputs for which memory should be reserved.
    * When compacted, this should be the number of linked ES outputs.
    */
   unsigned num_reserved_es_outputs;
} lower_esgs_io_state;

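/* Emit a (possibly unaligned) buffer load by splitting it into full 32-bit
 * loads plus at most one smaller load for the remaining tail bytes, then
 * reassemble the requested vector from the loaded pieces with nir_extract_bits.
 */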
static nir_ssa_def *
emit_split_buffer_load(nir_builder *b, nir_ssa_def *desc, nir_ssa_def *v_off, nir_ssa_def *s_off,
                       unsigned component_stride, unsigned num_components, unsigned bit_size)
{
   unsigned total_bytes = num_components * bit_size / 8u;
   unsigned full_dwords = total_bytes / 4u;
   unsigned remaining_bytes = total_bytes - full_dwords * 4u;

   /* Accommodate max number of split 64-bit loads */
   nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS * 2u];

   /* Assume that 1x32-bit load is better than 1x16-bit + 1x8-bit */
   if (remaining_bytes == 3) {
      remaining_bytes = 0;
      full_dwords++;
   }

   for (unsigned i = 0; i < full_dwords; ++i)
      comps[i] = nir_build_load_buffer_amd(b, 1, 32, desc, v_off, s_off,
                                           .base = component_stride * i, .memory_modes = nir_var_shader_in);

   if (remaining_bytes)
      comps[full_dwords] = nir_build_load_buffer_amd(b, 1, remaining_bytes * 8, desc, v_off, s_off,
                                                     .base = component_stride * full_dwords, .memory_modes = nir_var_shader_in);

   return nir_extract_bits(b, comps, full_dwords + !!remaining_bytes, 0, num_components, bit_size);
}

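/* Emit a buffer store for each consecutive range of the writemask, splitting
 * every range into stores of at most 4 bytes that never cross a dword
 * boundary, so each emitted store stays naturally aligned.
 */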
static void
emit_split_buffer_store(nir_builder *b, nir_ssa_def *d, nir_ssa_def *desc, nir_ssa_def *v_off, nir_ssa_def *s_off,
                        unsigned component_stride, unsigned num_components, unsigned bit_size,
                        unsigned writemask, bool swizzled, bool slc)
{
   while (writemask) {
      int start, count;
      u_bit_scan_consecutive_range(&writemask, &start, &count);
      assert(start >= 0 && count >= 0);

      unsigned bytes = count * bit_size / 8u;
      unsigned start_byte = start * bit_size / 8u;

      while (bytes) {
         unsigned store_bytes = MIN2(bytes, 4u);
         if ((start_byte % 4) == 1 || (start_byte % 4) == 3)
            store_bytes = MIN2(store_bytes, 1);
         else if ((start_byte % 4) == 2)
            store_bytes = MIN2(store_bytes, 2);

         nir_ssa_def *store_val = nir_extract_bits(b, &d, 1, start_byte * 8u, 1, store_bytes * 8u);
         nir_build_store_buffer_amd(b, store_val, desc, v_off, s_off, .is_swizzled = swizzled, .slc_amd = slc,
                                    .base = start_byte, .write_mask = 1u, .memory_modes = nir_var_shader_out);

         start_byte += store_bytes;
         bytes -= store_bytes;
      }
   }
}

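/* Lower an ES output store:
 * on GFX6-8 the value is written to the ESGS ring buffer in VRAM,
 * on GFX9+ it is written to LDS within the current vertex's item.
 */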
static bool
lower_es_output_store(nir_builder *b,
                      nir_instr *instr,
                      void *state)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   if (intrin->intrinsic != nir_intrinsic_store_output)
      return false;

   lower_esgs_io_state *st = (lower_esgs_io_state *) state;
   unsigned write_mask = nir_intrinsic_write_mask(intrin);

   b->cursor = nir_before_instr(instr);
   nir_ssa_def *io_off = nir_build_calc_io_offset(b, intrin, nir_imm_int(b, 16u), 4u);

   if (st->chip_class <= GFX8) {
      /* GFX6-8: ES is a separate HW stage, data is passed from ES to GS in VRAM. */
      nir_ssa_def *ring = nir_build_load_ring_esgs_amd(b);
      nir_ssa_def *es2gs_off = nir_build_load_ring_es2gs_offset_amd(b);
      emit_split_buffer_store(b, intrin->src[0].ssa, ring, io_off, es2gs_off, 4u,
                              intrin->src[0].ssa->num_components, intrin->src[0].ssa->bit_size,
                              write_mask, true, true);
   } else {
      /* GFX9+: ES is merged into GS, data is passed through LDS. */
      unsigned esgs_itemsize = st->num_reserved_es_outputs * 16u;
      nir_ssa_def *vertex_idx = nir_build_load_local_invocation_index(b);
      nir_ssa_def *off = nir_iadd(b, nir_imul_imm(b, vertex_idx, esgs_itemsize), io_off);
      nir_build_store_shared(b, intrin->src[0].ssa, off, .write_mask = write_mask,
                             .align_mul = 16u, .align_offset = (nir_intrinsic_component(intrin) * 4u) % 16u);
   }

   nir_instr_remove(instr);
   return true;
}

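/* GFX6-8: each input vertex has its own GS vertex offset system value.
 * A constant vertex index selects the right one directly; a dynamic index
 * is handled with a chain of bcsel instructions.
 */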
static nir_ssa_def *
gs_per_vertex_input_vertex_offset_gfx6(nir_builder *b, nir_src *vertex_src)
{
   if (nir_src_is_const(*vertex_src))
      return nir_build_load_gs_vertex_offset_amd(b, .base = nir_src_as_uint(*vertex_src));

   nir_ssa_def *vertex_offset = nir_build_load_gs_vertex_offset_amd(b, .base = 0);

   for (unsigned i = 1; i < b->shader->info.gs.vertices_in; ++i) {
      nir_ssa_def *cond = nir_ieq_imm(b, vertex_src->ssa, i);
      nir_ssa_def *elem = nir_build_load_gs_vertex_offset_amd(b, .base = i);
      vertex_offset = nir_bcsel(b, cond, elem, vertex_offset);
   }

   return vertex_offset;
}

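/* GFX9+: two 16-bit vertex offsets are packed into each 32-bit system value,
 * so the low or high half is extracted depending on the vertex index.
 */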
static nir_ssa_def *
gs_per_vertex_input_vertex_offset_gfx9(nir_builder *b, nir_src *vertex_src)
{
   if (nir_src_is_const(*vertex_src)) {
      unsigned vertex = nir_src_as_uint(*vertex_src);
      return nir_ubfe(b, nir_build_load_gs_vertex_offset_amd(b, .base = vertex / 2u),
                      nir_imm_int(b, (vertex & 1u) * 16u), nir_imm_int(b, 16u));
   }

   nir_ssa_def *vertex_offset = nir_build_load_gs_vertex_offset_amd(b, .base = 0);

   for (unsigned i = 1; i < b->shader->info.gs.vertices_in; i++) {
      nir_ssa_def *cond = nir_ieq_imm(b, vertex_src->ssa, i);
      nir_ssa_def *elem = nir_build_load_gs_vertex_offset_amd(b, .base = i / 2u * 2u);
      if (i % 2u)
         elem = nir_ishr_imm(b, elem, 16u);

      vertex_offset = nir_bcsel(b, cond, elem, vertex_offset);
   }

   return nir_iand_imm(b, vertex_offset, 0xffffu);
}

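/* Calculate the offset of a GS per-vertex input: the vertex offset (where that
 * vertex's ES outputs start) plus the IO offset of the addressed location and
 * component. Both are counted in dwords and converted to bytes at the end.
 */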
static nir_ssa_def *
gs_per_vertex_input_offset(nir_builder *b,
                           lower_esgs_io_state *st,
                           nir_intrinsic_instr *instr)
{
   nir_src *vertex_src = nir_get_io_vertex_index_src(instr);
   nir_ssa_def *vertex_offset = st->chip_class >= GFX9
                                ? gs_per_vertex_input_vertex_offset_gfx9(b, vertex_src)
                                : gs_per_vertex_input_vertex_offset_gfx6(b, vertex_src);

   unsigned base_stride = st->chip_class >= GFX9 ? 1 : 64 /* Wave size on GFX6-8 */;
   nir_ssa_def *io_off = nir_build_calc_io_offset(b, instr, nir_imm_int(b, base_stride * 4u), base_stride);
   nir_ssa_def *off = nir_iadd(b, io_off, vertex_offset);
   return nir_imul_imm(b, off, 4u);
}

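/* Lower a GS per-vertex input load:
 * on GFX9+ the value is read back from LDS,
 * on GFX6-8 it is read from the ESGS ring buffer in VRAM.
 */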
static nir_ssa_def *
lower_gs_per_vertex_input_load(nir_builder *b,
                               nir_instr *instr,
                               void *state)
{
   lower_esgs_io_state *st = (lower_esgs_io_state *) state;
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   nir_ssa_def *off = gs_per_vertex_input_offset(b, st, intrin);

   if (st->chip_class >= GFX9)
      return nir_build_load_shared(b, intrin->dest.ssa.num_components, intrin->dest.ssa.bit_size, off,
                                   .align_mul = 16u, .align_offset = (nir_intrinsic_component(intrin) * 4u) % 16u);

   unsigned wave_size = 64u; /* GFX6-8 only support wave64 */
   nir_ssa_def *ring = nir_build_load_ring_esgs_amd(b);
   return emit_split_buffer_load(b, ring, off, nir_imm_zero(b, 1, 32), 4u * wave_size,
                                 intrin->dest.ssa.num_components, intrin->dest.ssa.bit_size);
}

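/* Only lower load_per_vertex_input intrinsics. */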
static bool
filter_load_per_vertex_input(const nir_instr *instr, UNUSED const void *state)
{
   return instr->type == nir_instr_type_intrinsic && nir_instr_as_intrinsic(instr)->intrinsic == nir_intrinsic_load_per_vertex_input;
}

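/* Lower all ES output stores to ESGS ring (GFX6-8) or LDS (GFX9+) stores. */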
void
ac_nir_lower_es_outputs_to_mem(nir_shader *shader,
                               enum chip_class chip_class,
                               unsigned num_reserved_es_outputs)
{
   lower_esgs_io_state state = {
      .chip_class = chip_class,
      .num_reserved_es_outputs = num_reserved_es_outputs,
   };

   nir_shader_instructions_pass(shader,
                                lower_es_output_store,
                                nir_metadata_block_index | nir_metadata_dominance,
                                &state);
}

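/* Lower all GS per-vertex input loads to ESGS ring (GFX6-8) or LDS (GFX9+) loads. */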
void
ac_nir_lower_gs_inputs_to_mem(nir_shader *shader,
                              enum chip_class chip_class,
                              unsigned num_reserved_es_outputs)
{
   lower_esgs_io_state state = {
      .chip_class = chip_class,
      .num_reserved_es_outputs = num_reserved_es_outputs,
   };

   nir_shader_lower_instructions(shader,
                                 filter_load_per_vertex_input,
                                 lower_gs_per_vertex_input_load,
                                 &state);
}