1 /*
2  * Copyright © Microsoft Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "dxil_nir.h"
25 
26 #include "nir_builder.h"
27 #include "nir_deref.h"
28 #include "nir_to_dxil.h"
29 #include "util/u_math.h"
30 
31 static void
cl_type_size_align(const struct glsl_type * type,unsigned * size,unsigned * align)32 cl_type_size_align(const struct glsl_type *type, unsigned *size,
33                    unsigned *align)
34 {
35    *size = glsl_get_cl_size(type);
36    *align = glsl_get_cl_alignment(type);
37 }
38 
39 static void
extract_comps_from_vec32(nir_builder * b,nir_ssa_def * vec32,unsigned dst_bit_size,nir_ssa_def ** dst_comps,unsigned num_dst_comps)40 extract_comps_from_vec32(nir_builder *b, nir_ssa_def *vec32,
41                          unsigned dst_bit_size,
42                          nir_ssa_def **dst_comps,
43                          unsigned num_dst_comps)
44 {
45    unsigned step = DIV_ROUND_UP(dst_bit_size, 32);
46    unsigned comps_per32b = 32 / dst_bit_size;
47    nir_ssa_def *tmp;
48 
49    for (unsigned i = 0; i < vec32->num_components; i += step) {
50       switch (dst_bit_size) {
51       case 64:
52          tmp = nir_pack_64_2x32_split(b, nir_channel(b, vec32, i),
53                                          nir_channel(b, vec32, i + 1));
54          dst_comps[i / 2] = tmp;
55          break;
56       case 32:
57          dst_comps[i] = nir_channel(b, vec32, i);
58          break;
59       case 16:
60       case 8: {
61          unsigned dst_offs = i * comps_per32b;
62 
63          tmp = nir_unpack_bits(b, nir_channel(b, vec32, i), dst_bit_size);
64          for (unsigned j = 0; j < comps_per32b && dst_offs + j < num_dst_comps; j++)
65             dst_comps[dst_offs + j] = nir_channel(b, tmp, j);
66          }
67 
68          break;
69       }
70    }
71 }
72 
73 static nir_ssa_def *
load_comps_to_vec32(nir_builder * b,unsigned src_bit_size,nir_ssa_def ** src_comps,unsigned num_src_comps)74 load_comps_to_vec32(nir_builder *b, unsigned src_bit_size,
75                     nir_ssa_def **src_comps, unsigned num_src_comps)
76 {
77    unsigned num_vec32comps = DIV_ROUND_UP(num_src_comps * src_bit_size, 32);
78    unsigned step = DIV_ROUND_UP(src_bit_size, 32);
79    unsigned comps_per32b = 32 / src_bit_size;
80    nir_ssa_def *vec32comps[4];
81 
82    for (unsigned i = 0; i < num_vec32comps; i += step) {
83       switch (src_bit_size) {
84       case 64:
85          vec32comps[i] = nir_unpack_64_2x32_split_x(b, src_comps[i / 2]);
86          vec32comps[i + 1] = nir_unpack_64_2x32_split_y(b, src_comps[i / 2]);
87          break;
88       case 32:
89          vec32comps[i] = src_comps[i];
90          break;
91       case 16:
92       case 8: {
93          unsigned src_offs = i * comps_per32b;
94 
95          vec32comps[i] = nir_u2u32(b, src_comps[src_offs]);
96          for (unsigned j = 1; j < comps_per32b && src_offs + j < num_src_comps; j++) {
97             nir_ssa_def *tmp = nir_ishl(b, nir_u2u32(b, src_comps[src_offs + j]),
98                                            nir_imm_int(b, j * src_bit_size));
99             vec32comps[i] = nir_ior(b, vec32comps[i], tmp);
100          }
101          break;
102       }
103       }
104    }
105 
106    return nir_vec(b, vec32comps, num_vec32comps);
107 }
108 
109 static nir_ssa_def *
build_load_ptr_dxil(nir_builder * b,nir_deref_instr * deref,nir_ssa_def * idx)110 build_load_ptr_dxil(nir_builder *b, nir_deref_instr *deref, nir_ssa_def *idx)
111 {
112    return nir_load_ptr_dxil(b, 1, 32, &deref->dest.ssa, idx);
113 }
114 
115 static bool
lower_load_deref(nir_builder * b,nir_intrinsic_instr * intr)116 lower_load_deref(nir_builder *b, nir_intrinsic_instr *intr)
117 {
118    assert(intr->dest.is_ssa);
119 
120    b->cursor = nir_before_instr(&intr->instr);
121 
122    nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
123    if (!nir_deref_mode_is(deref, nir_var_shader_temp))
124       return false;
125    nir_ssa_def *ptr = nir_u2u32(b, nir_build_deref_offset(b, deref, cl_type_size_align));
126    nir_ssa_def *offset = nir_iand(b, ptr, nir_inot(b, nir_imm_int(b, 3)));
127 
128    assert(intr->dest.is_ssa);
129    unsigned num_components = nir_dest_num_components(intr->dest);
130    unsigned bit_size = nir_dest_bit_size(intr->dest);
131    unsigned load_size = MAX2(32, bit_size);
132    unsigned num_bits = num_components * bit_size;
133    nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS];
134    unsigned comp_idx = 0;
135 
136    nir_deref_path path;
137    nir_deref_path_init(&path, deref, NULL);
138    nir_ssa_def *base_idx = nir_ishr(b, offset, nir_imm_int(b, 2 /* log2(32 / 8) */));
139 
140    /* Split loads into 32-bit chunks */
141    for (unsigned i = 0; i < num_bits; i += load_size) {
142       unsigned subload_num_bits = MIN2(num_bits - i, load_size);
143       nir_ssa_def *idx = nir_iadd(b, base_idx, nir_imm_int(b, i / 32));
144       nir_ssa_def *vec32 = build_load_ptr_dxil(b, path.path[0], idx);
145 
146       if (load_size == 64) {
147          idx = nir_iadd(b, idx, nir_imm_int(b, 1));
148          vec32 = nir_vec2(b, vec32,
149                              build_load_ptr_dxil(b, path.path[0], idx));
150       }
151 
152       /* If we have 2 bytes or less to load we need to adjust the u32 value so
153        * we can always extract the LSB.
154        */
155       if (subload_num_bits <= 16) {
156          nir_ssa_def *shift = nir_imul(b, nir_iand(b, ptr, nir_imm_int(b, 3)),
157                                           nir_imm_int(b, 8));
158          vec32 = nir_ushr(b, vec32, shift);
159       }
160 
161       /* And now comes the pack/unpack step to match the original type. */
162       extract_comps_from_vec32(b, vec32, bit_size, &comps[comp_idx],
163                                subload_num_bits / bit_size);
164       comp_idx += subload_num_bits / bit_size;
165    }
166 
167    nir_deref_path_finish(&path);
168    assert(comp_idx == num_components);
169    nir_ssa_def *result = nir_vec(b, comps, num_components);
170    nir_ssa_def_rewrite_uses(&intr->dest.ssa, result);
171    nir_instr_remove(&intr->instr);
172    return true;
173 }
174 
175 static nir_ssa_def *
ubo_load_select_32b_comps(nir_builder * b,nir_ssa_def * vec32,nir_ssa_def * offset,unsigned num_bytes)176 ubo_load_select_32b_comps(nir_builder *b, nir_ssa_def *vec32,
177                           nir_ssa_def *offset, unsigned num_bytes)
178 {
179    assert(num_bytes == 16 || num_bytes == 12 || num_bytes == 8 ||
180           num_bytes == 4 || num_bytes == 3 || num_bytes == 2 ||
181           num_bytes == 1);
182    assert(vec32->num_components == 4);
183 
184    /* 16 and 12 byte types are always aligned on 16 bytes. */
185    if (num_bytes > 8)
186       return vec32;
187 
188    nir_ssa_def *comps[4];
189    nir_ssa_def *cond;
190 
191    for (unsigned i = 0; i < 4; i++)
192       comps[i] = nir_channel(b, vec32, i);
193 
194    /* If we have 8bytes or less to load, select which half the vec4 should
195     * be used.
196     */
197    cond = nir_ine(b, nir_iand(b, offset, nir_imm_int(b, 0x8)),
198                                  nir_imm_int(b, 0));
199 
200    comps[0] = nir_bcsel(b, cond, comps[2], comps[0]);
201    comps[1] = nir_bcsel(b, cond, comps[3], comps[1]);
202 
203    /* Thanks to the CL alignment constraints, if we want 8 bytes we're done. */
204    if (num_bytes == 8)
205       return nir_vec(b, comps, 2);
206 
207    /* 4 bytes or less needed, select which of the 32bit component should be
208     * used and return it. The sub-32bit split is handled in
209     * extract_comps_from_vec32().
210     */
211    cond = nir_ine(b, nir_iand(b, offset, nir_imm_int(b, 0x4)),
212                                  nir_imm_int(b, 0));
213    return nir_bcsel(b, cond, comps[1], comps[0]);
214 }
215 
216 nir_ssa_def *
build_load_ubo_dxil(nir_builder * b,nir_ssa_def * buffer,nir_ssa_def * offset,unsigned num_components,unsigned bit_size)217 build_load_ubo_dxil(nir_builder *b, nir_ssa_def *buffer,
218                     nir_ssa_def *offset, unsigned num_components,
219                     unsigned bit_size)
220 {
221    nir_ssa_def *idx = nir_ushr(b, offset, nir_imm_int(b, 4));
222    nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS];
223    unsigned num_bits = num_components * bit_size;
224    unsigned comp_idx = 0;
225 
226    /* We need to split loads in 16byte chunks because that's the
227     * granularity of cBufferLoadLegacy().
228     */
229    for (unsigned i = 0; i < num_bits; i += (16 * 8)) {
230       /* For each 16byte chunk (or smaller) we generate a 32bit ubo vec
231        * load.
232        */
233       unsigned subload_num_bits = MIN2(num_bits - i, 16 * 8);
234       nir_ssa_def *vec32 =
235          nir_load_ubo_dxil(b, 4, 32, buffer, nir_iadd(b, idx, nir_imm_int(b, i / (16 * 8))));
236 
237       /* First re-arrange the vec32 to account for intra 16-byte offset. */
238       vec32 = ubo_load_select_32b_comps(b, vec32, offset, subload_num_bits / 8);
239 
240       /* If we have 2 bytes or less to load we need to adjust the u32 value so
241        * we can always extract the LSB.
242        */
243       if (subload_num_bits <= 16) {
244          nir_ssa_def *shift = nir_imul(b, nir_iand(b, offset,
245                                                       nir_imm_int(b, 3)),
246                                           nir_imm_int(b, 8));
247          vec32 = nir_ushr(b, vec32, shift);
248       }
249 
250       /* And now comes the pack/unpack step to match the original type. */
251       extract_comps_from_vec32(b, vec32, bit_size, &comps[comp_idx],
252                                subload_num_bits / bit_size);
253       comp_idx += subload_num_bits / bit_size;
254    }
255 
256    assert(comp_idx == num_components);
257    return nir_vec(b, comps, num_components);
258 }
259 
260 static bool
lower_load_ssbo(nir_builder * b,nir_intrinsic_instr * intr)261 lower_load_ssbo(nir_builder *b, nir_intrinsic_instr *intr)
262 {
263    assert(intr->dest.is_ssa);
264    assert(intr->src[0].is_ssa);
265    assert(intr->src[1].is_ssa);
266 
267    b->cursor = nir_before_instr(&intr->instr);
268 
269    nir_ssa_def *buffer = intr->src[0].ssa;
270    nir_ssa_def *offset = nir_iand(b, intr->src[1].ssa, nir_imm_int(b, ~3));
271    enum gl_access_qualifier access = nir_intrinsic_access(intr);
272    unsigned bit_size = nir_dest_bit_size(intr->dest);
273    unsigned num_components = nir_dest_num_components(intr->dest);
274    unsigned num_bits = num_components * bit_size;
275 
276    nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS];
277    unsigned comp_idx = 0;
278 
279    /* We need to split loads in 16byte chunks because that's the optimal
280     * granularity of bufferLoad(). Minimum alignment is 4byte, which saves
281     * from us from extra complexity to extract >= 32 bit components.
282     */
283    for (unsigned i = 0; i < num_bits; i += 4 * 32) {
284       /* For each 16byte chunk (or smaller) we generate a 32bit ssbo vec
285        * load.
286        */
287       unsigned subload_num_bits = MIN2(num_bits - i, 4 * 32);
288 
289       /* The number of components to store depends on the number of bytes. */
290       nir_ssa_def *vec32 =
291          nir_load_ssbo(b, DIV_ROUND_UP(subload_num_bits, 32), 32,
292                        buffer, nir_iadd(b, offset, nir_imm_int(b, i / 8)),
293                        .align_mul = 4,
294                        .align_offset = 0,
295                        .access = access);
296 
297       /* If we have 2 bytes or less to load we need to adjust the u32 value so
298        * we can always extract the LSB.
299        */
300       if (subload_num_bits <= 16) {
301          nir_ssa_def *shift = nir_imul(b, nir_iand(b, intr->src[1].ssa, nir_imm_int(b, 3)),
302                                           nir_imm_int(b, 8));
303          vec32 = nir_ushr(b, vec32, shift);
304       }
305 
306       /* And now comes the pack/unpack step to match the original type. */
307       extract_comps_from_vec32(b, vec32, bit_size, &comps[comp_idx],
308                                subload_num_bits / bit_size);
309       comp_idx += subload_num_bits / bit_size;
310    }
311 
312    assert(comp_idx == num_components);
313    nir_ssa_def *result = nir_vec(b, comps, num_components);
314    nir_ssa_def_rewrite_uses(&intr->dest.ssa, result);
315    nir_instr_remove(&intr->instr);
316    return true;
317 }
318 
319 static bool
lower_store_ssbo(nir_builder * b,nir_intrinsic_instr * intr)320 lower_store_ssbo(nir_builder *b, nir_intrinsic_instr *intr)
321 {
322    b->cursor = nir_before_instr(&intr->instr);
323 
324    assert(intr->src[0].is_ssa);
325    assert(intr->src[1].is_ssa);
326    assert(intr->src[2].is_ssa);
327 
328    nir_ssa_def *val = intr->src[0].ssa;
329    nir_ssa_def *buffer = intr->src[1].ssa;
330    nir_ssa_def *offset = nir_iand(b, intr->src[2].ssa, nir_imm_int(b, ~3));
331 
332    unsigned bit_size = val->bit_size;
333    unsigned num_components = val->num_components;
334    unsigned num_bits = num_components * bit_size;
335 
336    nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS];
337    unsigned comp_idx = 0;
338 
339    for (unsigned i = 0; i < num_components; i++)
340       comps[i] = nir_channel(b, val, i);
341 
342    /* We split stores in 16byte chunks because that's the optimal granularity
343     * of bufferStore(). Minimum alignment is 4byte, which saves from us from
344     * extra complexity to store >= 32 bit components.
345     */
346    for (unsigned i = 0; i < num_bits; i += 4 * 32) {
347       /* For each 16byte chunk (or smaller) we generate a 32bit ssbo vec
348        * store.
349        */
350       unsigned substore_num_bits = MIN2(num_bits - i, 4 * 32);
351       nir_ssa_def *local_offset = nir_iadd(b, offset, nir_imm_int(b, i / 8));
352       nir_ssa_def *vec32 = load_comps_to_vec32(b, bit_size, &comps[comp_idx],
353                                                substore_num_bits / bit_size);
354       nir_intrinsic_instr *store;
355 
356       if (substore_num_bits < 32) {
357          nir_ssa_def *mask = nir_imm_int(b, (1 << substore_num_bits) - 1);
358 
359         /* If we have 16 bits or less to store we need to place them
360          * correctly in the u32 component. Anything greater than 16 bits
361          * (including uchar3) is naturally aligned on 32bits.
362          */
363          if (substore_num_bits <= 16) {
364             nir_ssa_def *pos = nir_iand(b, intr->src[2].ssa, nir_imm_int(b, 3));
365             nir_ssa_def *shift = nir_imul_imm(b, pos, 8);
366 
367             vec32 = nir_ishl(b, vec32, shift);
368             mask = nir_ishl(b, mask, shift);
369          }
370 
371          store = nir_intrinsic_instr_create(b->shader,
372                                             nir_intrinsic_store_ssbo_masked_dxil);
373          store->src[0] = nir_src_for_ssa(vec32);
374          store->src[1] = nir_src_for_ssa(nir_inot(b, mask));
375          store->src[2] = nir_src_for_ssa(buffer);
376          store->src[3] = nir_src_for_ssa(local_offset);
377       } else {
378          store = nir_intrinsic_instr_create(b->shader,
379                                             nir_intrinsic_store_ssbo);
380          store->src[0] = nir_src_for_ssa(vec32);
381          store->src[1] = nir_src_for_ssa(buffer);
382          store->src[2] = nir_src_for_ssa(local_offset);
383 
384          nir_intrinsic_set_align(store, 4, 0);
385       }
386 
387       /* The number of components to store depends on the number of bits. */
388       store->num_components = DIV_ROUND_UP(substore_num_bits, 32);
389       nir_builder_instr_insert(b, &store->instr);
390       comp_idx += substore_num_bits / bit_size;
391    }
392 
393    nir_instr_remove(&intr->instr);
394    return true;
395 }
396 
397 static void
lower_load_vec32(nir_builder * b,nir_ssa_def * index,unsigned num_comps,nir_ssa_def ** comps,nir_intrinsic_op op)398 lower_load_vec32(nir_builder *b, nir_ssa_def *index, unsigned num_comps, nir_ssa_def **comps, nir_intrinsic_op op)
399 {
400    for (unsigned i = 0; i < num_comps; i++) {
401       nir_intrinsic_instr *load =
402          nir_intrinsic_instr_create(b->shader, op);
403 
404       load->num_components = 1;
405       load->src[0] = nir_src_for_ssa(nir_iadd(b, index, nir_imm_int(b, i)));
406       nir_ssa_dest_init(&load->instr, &load->dest, 1, 32, NULL);
407       nir_builder_instr_insert(b, &load->instr);
408       comps[i] = &load->dest.ssa;
409    }
410 }
411 
412 static bool
lower_32b_offset_load(nir_builder * b,nir_intrinsic_instr * intr)413 lower_32b_offset_load(nir_builder *b, nir_intrinsic_instr *intr)
414 {
415    assert(intr->dest.is_ssa);
416    unsigned bit_size = nir_dest_bit_size(intr->dest);
417    unsigned num_components = nir_dest_num_components(intr->dest);
418    unsigned num_bits = num_components * bit_size;
419 
420    b->cursor = nir_before_instr(&intr->instr);
421    nir_intrinsic_op op = intr->intrinsic;
422 
423    assert(intr->src[0].is_ssa);
424    nir_ssa_def *offset = intr->src[0].ssa;
425    if (op == nir_intrinsic_load_shared) {
426       offset = nir_iadd(b, offset, nir_imm_int(b, nir_intrinsic_base(intr)));
427       op = nir_intrinsic_load_shared_dxil;
428    } else {
429       offset = nir_u2u32(b, offset);
430       op = nir_intrinsic_load_scratch_dxil;
431    }
432    nir_ssa_def *index = nir_ushr(b, offset, nir_imm_int(b, 2));
433    nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS];
434    nir_ssa_def *comps_32bit[NIR_MAX_VEC_COMPONENTS * 2];
435 
436    /* We need to split loads in 32-bit accesses because the buffer
437     * is an i32 array and DXIL does not support type casts.
438     */
439    unsigned num_32bit_comps = DIV_ROUND_UP(num_bits, 32);
440    lower_load_vec32(b, index, num_32bit_comps, comps_32bit, op);
441    unsigned num_comps_per_pass = MIN2(num_32bit_comps, 4);
442 
443    for (unsigned i = 0; i < num_32bit_comps; i += num_comps_per_pass) {
444       unsigned num_vec32_comps = MIN2(num_32bit_comps - i, 4);
445       unsigned num_dest_comps = num_vec32_comps * 32 / bit_size;
446       nir_ssa_def *vec32 = nir_vec(b, &comps_32bit[i], num_vec32_comps);
447 
448       /* If we have 16 bits or less to load we need to adjust the u32 value so
449        * we can always extract the LSB.
450        */
451       if (num_bits <= 16) {
452          nir_ssa_def *shift =
453             nir_imul(b, nir_iand(b, offset, nir_imm_int(b, 3)),
454                         nir_imm_int(b, 8));
455          vec32 = nir_ushr(b, vec32, shift);
456       }
457 
458       /* And now comes the pack/unpack step to match the original type. */
459       unsigned dest_index = i * 32 / bit_size;
460       extract_comps_from_vec32(b, vec32, bit_size, &comps[dest_index], num_dest_comps);
461    }
462 
463    nir_ssa_def *result = nir_vec(b, comps, num_components);
464    nir_ssa_def_rewrite_uses(&intr->dest.ssa, result);
465    nir_instr_remove(&intr->instr);
466 
467    return true;
468 }
469 
470 static void
lower_store_vec32(nir_builder * b,nir_ssa_def * index,nir_ssa_def * vec32,nir_intrinsic_op op)471 lower_store_vec32(nir_builder *b, nir_ssa_def *index, nir_ssa_def *vec32, nir_intrinsic_op op)
472 {
473 
474    for (unsigned i = 0; i < vec32->num_components; i++) {
475       nir_intrinsic_instr *store =
476          nir_intrinsic_instr_create(b->shader, op);
477 
478       store->src[0] = nir_src_for_ssa(nir_channel(b, vec32, i));
479       store->src[1] = nir_src_for_ssa(nir_iadd(b, index, nir_imm_int(b, i)));
480       store->num_components = 1;
481       nir_builder_instr_insert(b, &store->instr);
482    }
483 }
484 
485 static void
lower_masked_store_vec32(nir_builder * b,nir_ssa_def * offset,nir_ssa_def * index,nir_ssa_def * vec32,unsigned num_bits,nir_intrinsic_op op)486 lower_masked_store_vec32(nir_builder *b, nir_ssa_def *offset, nir_ssa_def *index,
487                          nir_ssa_def *vec32, unsigned num_bits, nir_intrinsic_op op)
488 {
489    nir_ssa_def *mask = nir_imm_int(b, (1 << num_bits) - 1);
490 
491    /* If we have 16 bits or less to store we need to place them correctly in
492     * the u32 component. Anything greater than 16 bits (including uchar3) is
493     * naturally aligned on 32bits.
494     */
495    if (num_bits <= 16) {
496       nir_ssa_def *shift =
497          nir_imul_imm(b, nir_iand(b, offset, nir_imm_int(b, 3)), 8);
498 
499       vec32 = nir_ishl(b, vec32, shift);
500       mask = nir_ishl(b, mask, shift);
501    }
502 
503    if (op == nir_intrinsic_store_shared_dxil) {
504       /* Use the dedicated masked intrinsic */
505       nir_store_shared_masked_dxil(b, vec32, nir_inot(b, mask), index);
506    } else {
507       /* For scratch, since we don't need atomics, just generate the read-modify-write in NIR */
508       nir_ssa_def *load = nir_load_scratch_dxil(b, 1, 32, index);
509 
510       nir_ssa_def *new_val = nir_ior(b, vec32,
511                                      nir_iand(b,
512                                               nir_inot(b, mask),
513                                               load));
514 
515       lower_store_vec32(b, index, new_val, op);
516    }
517 }
518 
519 static bool
lower_32b_offset_store(nir_builder * b,nir_intrinsic_instr * intr)520 lower_32b_offset_store(nir_builder *b, nir_intrinsic_instr *intr)
521 {
522    assert(intr->src[0].is_ssa);
523    unsigned num_components = nir_src_num_components(intr->src[0]);
524    unsigned bit_size = nir_src_bit_size(intr->src[0]);
525    unsigned num_bits = num_components * bit_size;
526 
527    b->cursor = nir_before_instr(&intr->instr);
528    nir_intrinsic_op op = intr->intrinsic;
529 
530    nir_ssa_def *offset = intr->src[1].ssa;
531    if (op == nir_intrinsic_store_shared) {
532       offset = nir_iadd(b, offset, nir_imm_int(b, nir_intrinsic_base(intr)));
533       op = nir_intrinsic_store_shared_dxil;
534    } else {
535       offset = nir_u2u32(b, offset);
536       op = nir_intrinsic_store_scratch_dxil;
537    }
538    nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS];
539 
540    unsigned comp_idx = 0;
541    for (unsigned i = 0; i < num_components; i++)
542       comps[i] = nir_channel(b, intr->src[0].ssa, i);
543 
544    for (unsigned i = 0; i < num_bits; i += 4 * 32) {
545       /* For each 4byte chunk (or smaller) we generate a 32bit scalar store.
546        */
547       unsigned substore_num_bits = MIN2(num_bits - i, 4 * 32);
548       nir_ssa_def *local_offset = nir_iadd(b, offset, nir_imm_int(b, i / 8));
549       nir_ssa_def *vec32 = load_comps_to_vec32(b, bit_size, &comps[comp_idx],
550                                                substore_num_bits / bit_size);
551       nir_ssa_def *index = nir_ushr(b, local_offset, nir_imm_int(b, 2));
552 
553       /* For anything less than 32bits we need to use the masked version of the
554        * intrinsic to preserve data living in the same 32bit slot.
555        */
556       if (num_bits < 32) {
557          lower_masked_store_vec32(b, local_offset, index, vec32, num_bits, op);
558       } else {
559          lower_store_vec32(b, index, vec32, op);
560       }
561 
562       comp_idx += substore_num_bits / bit_size;
563    }
564 
565    nir_instr_remove(&intr->instr);
566 
567    return true;
568 }
569 
570 static void
ubo_to_temp_patch_deref_mode(nir_deref_instr * deref)571 ubo_to_temp_patch_deref_mode(nir_deref_instr *deref)
572 {
573    deref->modes = nir_var_shader_temp;
574    nir_foreach_use(use_src, &deref->dest.ssa) {
575       if (use_src->parent_instr->type != nir_instr_type_deref)
576 	 continue;
577 
578       nir_deref_instr *parent = nir_instr_as_deref(use_src->parent_instr);
579       ubo_to_temp_patch_deref_mode(parent);
580    }
581 }
582 
583 static void
ubo_to_temp_update_entry(nir_deref_instr * deref,struct hash_entry * he)584 ubo_to_temp_update_entry(nir_deref_instr *deref, struct hash_entry *he)
585 {
586    assert(nir_deref_mode_is(deref, nir_var_mem_constant));
587    assert(deref->dest.is_ssa);
588    assert(he->data);
589 
590    nir_foreach_use(use_src, &deref->dest.ssa) {
591       if (use_src->parent_instr->type == nir_instr_type_deref) {
592          ubo_to_temp_update_entry(nir_instr_as_deref(use_src->parent_instr), he);
593       } else if (use_src->parent_instr->type == nir_instr_type_intrinsic) {
594          nir_intrinsic_instr *intr = nir_instr_as_intrinsic(use_src->parent_instr);
595          if (intr->intrinsic != nir_intrinsic_load_deref)
596             he->data = NULL;
597       } else {
598          he->data = NULL;
599       }
600 
601       if (!he->data)
602          break;
603    }
604 }
605 
606 bool
dxil_nir_lower_ubo_to_temp(nir_shader * nir)607 dxil_nir_lower_ubo_to_temp(nir_shader *nir)
608 {
609    struct hash_table *ubo_to_temp = _mesa_pointer_hash_table_create(NULL);
610    bool progress = false;
611 
612    /* First pass: collect all UBO accesses that could be turned into
613     * shader temp accesses.
614     */
615    foreach_list_typed(nir_function, func, node, &nir->functions) {
616       if (!func->is_entrypoint)
617          continue;
618       assert(func->impl);
619 
620       nir_foreach_block(block, func->impl) {
621          nir_foreach_instr_safe(instr, block) {
622             if (instr->type != nir_instr_type_deref)
623                continue;
624 
625             nir_deref_instr *deref = nir_instr_as_deref(instr);
626             if (!nir_deref_mode_is(deref, nir_var_mem_constant) ||
627                 deref->deref_type != nir_deref_type_var)
628                   continue;
629 
630             struct hash_entry *he =
631                _mesa_hash_table_search(ubo_to_temp, deref->var);
632 
633             if (!he)
634                he = _mesa_hash_table_insert(ubo_to_temp, deref->var, deref->var);
635 
636             if (!he->data)
637                continue;
638 
639             ubo_to_temp_update_entry(deref, he);
640          }
641       }
642    }
643 
644    hash_table_foreach(ubo_to_temp, he) {
645       nir_variable *var = he->data;
646 
647       if (!var)
648          continue;
649 
650       /* Change the variable mode. */
651       var->data.mode = nir_var_shader_temp;
652 
653       /* Make sure the variable has a name.
654        * DXIL variables must have names.
655        */
656       if (!var->name)
657          var->name = ralloc_asprintf(nir, "global_%d", exec_list_length(&nir->variables));
658 
659       progress = true;
660    }
661    _mesa_hash_table_destroy(ubo_to_temp, NULL);
662 
663    /* Second pass: patch all derefs that were accessing the converted UBOs
664     * variables.
665     */
666    foreach_list_typed(nir_function, func, node, &nir->functions) {
667       if (!func->is_entrypoint)
668          continue;
669       assert(func->impl);
670 
671       nir_foreach_block(block, func->impl) {
672          nir_foreach_instr_safe(instr, block) {
673             if (instr->type != nir_instr_type_deref)
674                continue;
675 
676             nir_deref_instr *deref = nir_instr_as_deref(instr);
677             if (nir_deref_mode_is(deref, nir_var_mem_constant) &&
678                 deref->deref_type == nir_deref_type_var &&
679                 deref->var->data.mode == nir_var_shader_temp)
680                ubo_to_temp_patch_deref_mode(deref);
681          }
682       }
683    }
684 
685    return progress;
686 }
687 
688 static bool
lower_load_ubo(nir_builder * b,nir_intrinsic_instr * intr)689 lower_load_ubo(nir_builder *b, nir_intrinsic_instr *intr)
690 {
691    assert(intr->dest.is_ssa);
692    assert(intr->src[0].is_ssa);
693    assert(intr->src[1].is_ssa);
694 
695    b->cursor = nir_before_instr(&intr->instr);
696 
697    nir_ssa_def *result =
698       build_load_ubo_dxil(b, intr->src[0].ssa, intr->src[1].ssa,
699                              nir_dest_num_components(intr->dest),
700                              nir_dest_bit_size(intr->dest));
701 
702    nir_ssa_def_rewrite_uses(&intr->dest.ssa, result);
703    nir_instr_remove(&intr->instr);
704    return true;
705 }
706 
707 bool
dxil_nir_lower_loads_stores_to_dxil(nir_shader * nir)708 dxil_nir_lower_loads_stores_to_dxil(nir_shader *nir)
709 {
710    bool progress = false;
711 
712    foreach_list_typed(nir_function, func, node, &nir->functions) {
713       if (!func->is_entrypoint)
714          continue;
715       assert(func->impl);
716 
717       nir_builder b;
718       nir_builder_init(&b, func->impl);
719 
720       nir_foreach_block(block, func->impl) {
721          nir_foreach_instr_safe(instr, block) {
722             if (instr->type != nir_instr_type_intrinsic)
723                continue;
724             nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
725 
726             switch (intr->intrinsic) {
727             case nir_intrinsic_load_deref:
728                progress |= lower_load_deref(&b, intr);
729                break;
730             case nir_intrinsic_load_shared:
731             case nir_intrinsic_load_scratch:
732                progress |= lower_32b_offset_load(&b, intr);
733                break;
734             case nir_intrinsic_load_ssbo:
735                progress |= lower_load_ssbo(&b, intr);
736                break;
737             case nir_intrinsic_load_ubo:
738                progress |= lower_load_ubo(&b, intr);
739                break;
740             case nir_intrinsic_store_shared:
741             case nir_intrinsic_store_scratch:
742                progress |= lower_32b_offset_store(&b, intr);
743                break;
744             case nir_intrinsic_store_ssbo:
745                progress |= lower_store_ssbo(&b, intr);
746                break;
747             default:
748                break;
749             }
750          }
751       }
752    }
753 
754    return progress;
755 }
756 
757 static bool
lower_shared_atomic(nir_builder * b,nir_intrinsic_instr * intr,nir_intrinsic_op dxil_op)758 lower_shared_atomic(nir_builder *b, nir_intrinsic_instr *intr,
759                     nir_intrinsic_op dxil_op)
760 {
761    b->cursor = nir_before_instr(&intr->instr);
762 
763    assert(intr->src[0].is_ssa);
764    nir_ssa_def *offset =
765       nir_iadd(b, intr->src[0].ssa, nir_imm_int(b, nir_intrinsic_base(intr)));
766    nir_ssa_def *index = nir_ushr(b, offset, nir_imm_int(b, 2));
767 
768    nir_intrinsic_instr *atomic = nir_intrinsic_instr_create(b->shader, dxil_op);
769    atomic->src[0] = nir_src_for_ssa(index);
770    assert(intr->src[1].is_ssa);
771    atomic->src[1] = nir_src_for_ssa(intr->src[1].ssa);
772    if (dxil_op == nir_intrinsic_shared_atomic_comp_swap_dxil) {
773       assert(intr->src[2].is_ssa);
774       atomic->src[2] = nir_src_for_ssa(intr->src[2].ssa);
775    }
776    atomic->num_components = 0;
777    nir_ssa_dest_init(&atomic->instr, &atomic->dest, 1, 32, NULL);
778 
779    nir_builder_instr_insert(b, &atomic->instr);
780    nir_ssa_def_rewrite_uses(&intr->dest.ssa, &atomic->dest.ssa);
781    nir_instr_remove(&intr->instr);
782    return true;
783 }
784 
785 bool
dxil_nir_lower_atomics_to_dxil(nir_shader * nir)786 dxil_nir_lower_atomics_to_dxil(nir_shader *nir)
787 {
788    bool progress = false;
789 
790    foreach_list_typed(nir_function, func, node, &nir->functions) {
791       if (!func->is_entrypoint)
792          continue;
793       assert(func->impl);
794 
795       nir_builder b;
796       nir_builder_init(&b, func->impl);
797 
798       nir_foreach_block(block, func->impl) {
799          nir_foreach_instr_safe(instr, block) {
800             if (instr->type != nir_instr_type_intrinsic)
801                continue;
802             nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
803 
804             switch (intr->intrinsic) {
805 
806 #define ATOMIC(op)                                                            \
807   case nir_intrinsic_shared_atomic_##op:                                     \
808      progress |= lower_shared_atomic(&b, intr,                                \
809                                      nir_intrinsic_shared_atomic_##op##_dxil); \
810      break
811 
812             ATOMIC(add);
813             ATOMIC(imin);
814             ATOMIC(umin);
815             ATOMIC(imax);
816             ATOMIC(umax);
817             ATOMIC(and);
818             ATOMIC(or);
819             ATOMIC(xor);
820             ATOMIC(exchange);
821             ATOMIC(comp_swap);
822 
823 #undef ATOMIC
824             default:
825                break;
826             }
827          }
828       }
829    }
830 
831    return progress;
832 }
833 
834 static bool
lower_deref_ssbo(nir_builder * b,nir_deref_instr * deref)835 lower_deref_ssbo(nir_builder *b, nir_deref_instr *deref)
836 {
837    assert(nir_deref_mode_is(deref, nir_var_mem_ssbo));
838    assert(deref->deref_type == nir_deref_type_var ||
839           deref->deref_type == nir_deref_type_cast);
840    nir_variable *var = deref->var;
841 
842    b->cursor = nir_before_instr(&deref->instr);
843 
844    if (deref->deref_type == nir_deref_type_var) {
845       /* We turn all deref_var into deref_cast and build a pointer value based on
846        * the var binding which encodes the UAV id.
847        */
848       nir_ssa_def *ptr = nir_imm_int64(b, (uint64_t)var->data.binding << 32);
849       nir_deref_instr *deref_cast =
850          nir_build_deref_cast(b, ptr, nir_var_mem_ssbo, deref->type,
851                               glsl_get_explicit_stride(var->type));
852       nir_ssa_def_rewrite_uses(&deref->dest.ssa,
853                                &deref_cast->dest.ssa);
854       nir_instr_remove(&deref->instr);
855 
856       deref = deref_cast;
857       return true;
858    }
859    return false;
860 }
861 
862 bool
dxil_nir_lower_deref_ssbo(nir_shader * nir)863 dxil_nir_lower_deref_ssbo(nir_shader *nir)
864 {
865    bool progress = false;
866 
867    foreach_list_typed(nir_function, func, node, &nir->functions) {
868       if (!func->is_entrypoint)
869          continue;
870       assert(func->impl);
871 
872       nir_builder b;
873       nir_builder_init(&b, func->impl);
874 
875       nir_foreach_block(block, func->impl) {
876          nir_foreach_instr_safe(instr, block) {
877             if (instr->type != nir_instr_type_deref)
878                continue;
879 
880             nir_deref_instr *deref = nir_instr_as_deref(instr);
881 
882             if (!nir_deref_mode_is(deref, nir_var_mem_ssbo) ||
883                 (deref->deref_type != nir_deref_type_var &&
884                  deref->deref_type != nir_deref_type_cast))
885                continue;
886 
887             progress |= lower_deref_ssbo(&b, deref);
888          }
889       }
890    }
891 
892    return progress;
893 }
894 
895 static bool
lower_alu_deref_srcs(nir_builder * b,nir_alu_instr * alu)896 lower_alu_deref_srcs(nir_builder *b, nir_alu_instr *alu)
897 {
898    const nir_op_info *info = &nir_op_infos[alu->op];
899    bool progress = false;
900 
901    b->cursor = nir_before_instr(&alu->instr);
902 
903    for (unsigned i = 0; i < info->num_inputs; i++) {
904       nir_deref_instr *deref = nir_src_as_deref(alu->src[i].src);
905 
906       if (!deref)
907          continue;
908 
909       nir_deref_path path;
910       nir_deref_path_init(&path, deref, NULL);
911       nir_deref_instr *root_deref = path.path[0];
912       nir_deref_path_finish(&path);
913 
914       if (root_deref->deref_type != nir_deref_type_cast)
915          continue;
916 
917       nir_ssa_def *ptr =
918          nir_iadd(b, root_deref->parent.ssa,
919                      nir_build_deref_offset(b, deref, cl_type_size_align));
920       nir_instr_rewrite_src(&alu->instr, &alu->src[i].src, nir_src_for_ssa(ptr));
921       progress = true;
922    }
923 
924    return progress;
925 }
926 
927 bool
dxil_nir_opt_alu_deref_srcs(nir_shader * nir)928 dxil_nir_opt_alu_deref_srcs(nir_shader *nir)
929 {
930    bool progress = false;
931 
932    foreach_list_typed(nir_function, func, node, &nir->functions) {
933       if (!func->is_entrypoint)
934          continue;
935       assert(func->impl);
936 
937       bool progress = false;
938       nir_builder b;
939       nir_builder_init(&b, func->impl);
940 
941       nir_foreach_block(block, func->impl) {
942          nir_foreach_instr_safe(instr, block) {
943             if (instr->type != nir_instr_type_alu)
944                continue;
945 
946             nir_alu_instr *alu = nir_instr_as_alu(instr);
947             progress |= lower_alu_deref_srcs(&b, alu);
948          }
949       }
950    }
951 
952    return progress;
953 }
954 
955 static nir_ssa_def *
memcpy_load_deref_elem(nir_builder * b,nir_deref_instr * parent,nir_ssa_def * index)956 memcpy_load_deref_elem(nir_builder *b, nir_deref_instr *parent,
957                        nir_ssa_def *index)
958 {
959    nir_deref_instr *deref;
960 
961    index = nir_i2i(b, index, nir_dest_bit_size(parent->dest));
962    assert(parent->deref_type == nir_deref_type_cast);
963    deref = nir_build_deref_ptr_as_array(b, parent, index);
964 
965    return nir_load_deref(b, deref);
966 }
967 
968 static void
memcpy_store_deref_elem(nir_builder * b,nir_deref_instr * parent,nir_ssa_def * index,nir_ssa_def * value)969 memcpy_store_deref_elem(nir_builder *b, nir_deref_instr *parent,
970                         nir_ssa_def *index, nir_ssa_def *value)
971 {
972    nir_deref_instr *deref;
973 
974    index = nir_i2i(b, index, nir_dest_bit_size(parent->dest));
975    assert(parent->deref_type == nir_deref_type_cast);
976    deref = nir_build_deref_ptr_as_array(b, parent, index);
977    nir_store_deref(b, deref, value, 1);
978 }
979 
980 static bool
lower_memcpy_deref(nir_builder * b,nir_intrinsic_instr * intr)981 lower_memcpy_deref(nir_builder *b, nir_intrinsic_instr *intr)
982 {
983    nir_deref_instr *dst_deref = nir_src_as_deref(intr->src[0]);
984    nir_deref_instr *src_deref = nir_src_as_deref(intr->src[1]);
985    assert(intr->src[2].is_ssa);
986    nir_ssa_def *num_bytes = intr->src[2].ssa;
987 
988    assert(dst_deref && src_deref);
989 
990    b->cursor = nir_after_instr(&intr->instr);
991 
992    dst_deref = nir_build_deref_cast(b, &dst_deref->dest.ssa, dst_deref->modes,
993                                        glsl_uint8_t_type(), 1);
994    src_deref = nir_build_deref_cast(b, &src_deref->dest.ssa, src_deref->modes,
995                                        glsl_uint8_t_type(), 1);
996 
997    /*
998     * We want to avoid 64b instructions, so let's assume we'll always be
999     * passed a value that fits in a 32b type and truncate the 64b value.
1000     */
1001    num_bytes = nir_u2u32(b, num_bytes);
1002 
1003    nir_variable *loop_index_var =
1004      nir_local_variable_create(b->impl, glsl_uint_type(), "loop_index");
1005    nir_deref_instr *loop_index_deref = nir_build_deref_var(b, loop_index_var);
1006    nir_store_deref(b, loop_index_deref, nir_imm_int(b, 0), 1);
1007 
1008    nir_loop *loop = nir_push_loop(b);
1009    nir_ssa_def *loop_index = nir_load_deref(b, loop_index_deref);
1010    nir_ssa_def *cmp = nir_ige(b, loop_index, num_bytes);
1011    nir_if *loop_check = nir_push_if(b, cmp);
1012    nir_jump(b, nir_jump_break);
1013    nir_pop_if(b, loop_check);
1014    nir_ssa_def *val = memcpy_load_deref_elem(b, src_deref, loop_index);
1015    memcpy_store_deref_elem(b, dst_deref, loop_index, val);
1016    nir_store_deref(b, loop_index_deref, nir_iadd_imm(b, loop_index, 1), 1);
1017    nir_pop_loop(b, loop);
1018    nir_instr_remove(&intr->instr);
1019    return true;
1020 }
1021 
1022 bool
dxil_nir_lower_memcpy_deref(nir_shader * nir)1023 dxil_nir_lower_memcpy_deref(nir_shader *nir)
1024 {
1025    bool progress = false;
1026 
1027    foreach_list_typed(nir_function, func, node, &nir->functions) {
1028       if (!func->is_entrypoint)
1029          continue;
1030       assert(func->impl);
1031 
1032       nir_builder b;
1033       nir_builder_init(&b, func->impl);
1034 
1035       nir_foreach_block(block, func->impl) {
1036          nir_foreach_instr_safe(instr, block) {
1037             if (instr->type != nir_instr_type_intrinsic)
1038                continue;
1039 
1040             nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
1041 
1042             if (intr->intrinsic == nir_intrinsic_memcpy_deref)
1043                progress |= lower_memcpy_deref(&b, intr);
1044          }
1045       }
1046    }
1047 
1048    return progress;
1049 }
1050 
1051 static void
cast_phi(nir_builder * b,nir_phi_instr * phi,unsigned new_bit_size)1052 cast_phi(nir_builder *b, nir_phi_instr *phi, unsigned new_bit_size)
1053 {
1054    nir_phi_instr *lowered = nir_phi_instr_create(b->shader);
1055    int num_components = 0;
1056    int old_bit_size = phi->dest.ssa.bit_size;
1057 
1058    nir_op upcast_op = nir_type_conversion_op(nir_type_uint | old_bit_size,
1059                                              nir_type_uint | new_bit_size,
1060                                              nir_rounding_mode_undef);
1061    nir_op downcast_op = nir_type_conversion_op(nir_type_uint | new_bit_size,
1062                                                nir_type_uint | old_bit_size,
1063                                                nir_rounding_mode_undef);
1064 
1065    nir_foreach_phi_src(src, phi) {
1066       assert(num_components == 0 || num_components == src->src.ssa->num_components);
1067       num_components = src->src.ssa->num_components;
1068 
1069       b->cursor = nir_after_instr_and_phis(src->src.ssa->parent_instr);
1070 
1071       nir_ssa_def *cast = nir_build_alu(b, upcast_op, src->src.ssa, NULL, NULL, NULL);
1072       nir_phi_instr_add_src(lowered, src->pred, nir_src_for_ssa(cast));
1073    }
1074 
1075    nir_ssa_dest_init(&lowered->instr, &lowered->dest,
1076                      num_components, new_bit_size, NULL);
1077 
1078    b->cursor = nir_before_instr(&phi->instr);
1079    nir_builder_instr_insert(b, &lowered->instr);
1080 
1081    b->cursor = nir_after_phis(nir_cursor_current_block(b->cursor));
1082    nir_ssa_def *result = nir_build_alu(b, downcast_op, &lowered->dest.ssa, NULL, NULL, NULL);
1083 
1084    nir_ssa_def_rewrite_uses(&phi->dest.ssa, result);
1085    nir_instr_remove(&phi->instr);
1086 }
1087 
1088 static bool
upcast_phi_impl(nir_function_impl * impl,unsigned min_bit_size)1089 upcast_phi_impl(nir_function_impl *impl, unsigned min_bit_size)
1090 {
1091    nir_builder b;
1092    nir_builder_init(&b, impl);
1093    bool progress = false;
1094 
1095    nir_foreach_block_reverse(block, impl) {
1096       nir_foreach_instr_safe(instr, block) {
1097          if (instr->type != nir_instr_type_phi)
1098             continue;
1099 
1100          nir_phi_instr *phi = nir_instr_as_phi(instr);
1101          assert(phi->dest.is_ssa);
1102 
1103          if (phi->dest.ssa.bit_size == 1 ||
1104              phi->dest.ssa.bit_size >= min_bit_size)
1105             continue;
1106 
1107          cast_phi(&b, phi, min_bit_size);
1108          progress = true;
1109       }
1110    }
1111 
1112    if (progress) {
1113       nir_metadata_preserve(impl, nir_metadata_block_index |
1114                                   nir_metadata_dominance);
1115    } else {
1116       nir_metadata_preserve(impl, nir_metadata_all);
1117    }
1118 
1119    return progress;
1120 }
1121 
1122 bool
dxil_nir_lower_upcast_phis(nir_shader * shader,unsigned min_bit_size)1123 dxil_nir_lower_upcast_phis(nir_shader *shader, unsigned min_bit_size)
1124 {
1125    bool progress = false;
1126 
1127    nir_foreach_function(function, shader) {
1128       if (function->impl)
1129          progress |= upcast_phi_impl(function->impl, min_bit_size);
1130    }
1131 
1132    return progress;
1133 }
1134 
1135 struct dxil_nir_split_clip_cull_distance_params {
1136    nir_variable *new_var;
1137    nir_shader *shader;
1138 };
1139 
1140 /* In GLSL and SPIR-V, clip and cull distance are arrays of floats (with a limit of 8).
1141  * In DXIL, clip and cull distances are up to 2 float4s combined.
1142  * Coming from GLSL, we can request this 2 float4 format, but coming from SPIR-V,
1143  * we can't, and have to accept a "compact" array of scalar floats.
1144  *
1145  * To help emitting a valid input signature for this case, split the variables so that they
1146  * match what we need to put in the signature (e.g. { float clip[4]; float clip1; float cull[3]; })
1147  */
1148 static bool
dxil_nir_split_clip_cull_distance_instr(nir_builder * b,nir_instr * instr,void * cb_data)1149 dxil_nir_split_clip_cull_distance_instr(nir_builder *b,
1150                                         nir_instr *instr,
1151                                         void *cb_data)
1152 {
1153    struct dxil_nir_split_clip_cull_distance_params *params = cb_data;
1154    nir_variable *new_var = params->new_var;
1155 
1156    if (instr->type != nir_instr_type_deref)
1157       return false;
1158 
1159    nir_deref_instr *deref = nir_instr_as_deref(instr);
1160    nir_variable *var = nir_deref_instr_get_variable(deref);
1161    if (!var ||
1162        var->data.location < VARYING_SLOT_CLIP_DIST0 ||
1163        var->data.location > VARYING_SLOT_CULL_DIST1 ||
1164        !var->data.compact)
1165       return false;
1166 
1167    /* The location should only be inside clip distance, because clip
1168     * and cull should've been merged by nir_lower_clip_cull_distance_arrays()
1169     */
1170    assert(var->data.location == VARYING_SLOT_CLIP_DIST0 ||
1171           var->data.location == VARYING_SLOT_CLIP_DIST1);
1172 
1173    /* The deref chain to the clip/cull variables should be simple, just the
1174     * var and an array with a constant index, otherwise more lowering/optimization
1175     * might be needed before this pass, e.g. copy prop, lower_io_to_temporaries,
1176     * split_var_copies, and/or lower_var_copies
1177     */
1178    assert(deref->deref_type == nir_deref_type_var ||
1179           deref->deref_type == nir_deref_type_array);
1180 
1181    b->cursor = nir_before_instr(instr);
1182    if (!new_var) {
1183       /* Update lengths for new and old vars */
1184       int old_length = glsl_array_size(var->type);
1185       int new_length = (old_length + var->data.location_frac) - 4;
1186       old_length -= new_length;
1187 
1188       /* The existing variable fits in the float4 */
1189       if (new_length <= 0)
1190          return false;
1191 
1192       new_var = nir_variable_clone(var, params->shader);
1193       nir_shader_add_variable(params->shader, new_var);
1194       assert(glsl_get_base_type(glsl_get_array_element(var->type)) == GLSL_TYPE_FLOAT);
1195       var->type = glsl_array_type(glsl_float_type(), old_length, 0);
1196       new_var->type = glsl_array_type(glsl_float_type(), new_length, 0);
1197       new_var->data.location++;
1198       new_var->data.location_frac = 0;
1199       params->new_var = new_var;
1200    }
1201 
1202    /* Update the type for derefs of the old var */
1203    if (deref->deref_type == nir_deref_type_var) {
1204       deref->type = var->type;
1205       return false;
1206    }
1207 
1208    nir_const_value *index = nir_src_as_const_value(deref->arr.index);
1209    assert(index);
1210 
1211    /* Treat this array as a vector starting at the component index in location_frac,
1212     * so if location_frac is 1 and index is 0, then it's accessing the 'y' component
1213     * of the vector. If index + location_frac is >= 4, there's no component there,
1214     * so we need to add a new variable and adjust the index.
1215     */
1216    unsigned total_index = index->u32 + var->data.location_frac;
1217    if (total_index < 4)
1218       return false;
1219 
1220    nir_deref_instr *new_var_deref = nir_build_deref_var(b, new_var);
1221    nir_deref_instr *new_array_deref = nir_build_deref_array(b, new_var_deref, nir_imm_int(b, total_index % 4));
1222    nir_ssa_def_rewrite_uses(&deref->dest.ssa, &new_array_deref->dest.ssa);
1223    return true;
1224 }
1225 
1226 bool
dxil_nir_split_clip_cull_distance(nir_shader * shader)1227 dxil_nir_split_clip_cull_distance(nir_shader *shader)
1228 {
1229    struct dxil_nir_split_clip_cull_distance_params params = {
1230       .new_var = NULL,
1231       .shader = shader,
1232    };
1233    nir_shader_instructions_pass(shader,
1234                                 dxil_nir_split_clip_cull_distance_instr,
1235                                 nir_metadata_block_index |
1236                                 nir_metadata_dominance |
1237                                 nir_metadata_loop_analysis,
1238                                 &params);
1239    return params.new_var != NULL;
1240 }
1241 
1242 static bool
dxil_nir_lower_double_math_instr(nir_builder * b,nir_instr * instr,UNUSED void * cb_data)1243 dxil_nir_lower_double_math_instr(nir_builder *b,
1244                                  nir_instr *instr,
1245                                  UNUSED void *cb_data)
1246 {
1247    if (instr->type != nir_instr_type_alu)
1248       return false;
1249 
1250    nir_alu_instr *alu = nir_instr_as_alu(instr);
1251 
1252    /* TODO: See if we can apply this explicitly to packs/unpacks that are then
1253     * used as a double. As-is, if we had an app explicitly do a 64bit integer op,
1254     * then try to bitcast to double (not expressible in HLSL, but it is in other
1255     * source languages), this would unpack the integer and repack as a double, when
1256     * we probably want to just send the bitcast through to the backend.
1257     */
1258 
1259    b->cursor = nir_before_instr(&alu->instr);
1260 
1261    bool progress = false;
1262    for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; ++i) {
1263       if (nir_alu_type_get_base_type(nir_op_infos[alu->op].input_types[i]) == nir_type_float &&
1264           alu->src[i].src.ssa->bit_size == 64) {
1265          nir_ssa_def *packed_double = nir_channel(b, alu->src[i].src.ssa, alu->src[i].swizzle[0]);
1266          nir_ssa_def *unpacked_double = nir_unpack_64_2x32(b, packed_double);
1267          nir_ssa_def *repacked_double = nir_pack_double_2x32_dxil(b, unpacked_double);
1268          nir_instr_rewrite_src_ssa(instr, &alu->src[i].src, repacked_double);
1269          memset(alu->src[i].swizzle, 0, ARRAY_SIZE(alu->src[i].swizzle));
1270          progress = true;
1271       }
1272    }
1273 
1274    if (nir_alu_type_get_base_type(nir_op_infos[alu->op].output_type) == nir_type_float &&
1275        alu->dest.dest.ssa.bit_size == 64) {
1276       b->cursor = nir_after_instr(&alu->instr);
1277       nir_ssa_def *packed_double = &alu->dest.dest.ssa;
1278       nir_ssa_def *unpacked_double = nir_unpack_double_2x32_dxil(b, packed_double);
1279       nir_ssa_def *repacked_double = nir_pack_64_2x32(b, unpacked_double);
1280       nir_ssa_def_rewrite_uses_after(packed_double, repacked_double, unpacked_double->parent_instr);
1281       progress = true;
1282    }
1283 
1284    return progress;
1285 }
1286 
1287 bool
dxil_nir_lower_double_math(nir_shader * shader)1288 dxil_nir_lower_double_math(nir_shader *shader)
1289 {
1290    return nir_shader_instructions_pass(shader,
1291                                        dxil_nir_lower_double_math_instr,
1292                                        nir_metadata_block_index |
1293                                        nir_metadata_dominance |
1294                                        nir_metadata_loop_analysis,
1295                                        NULL);
1296 }
1297 
1298 typedef struct {
1299    gl_system_value *values;
1300    uint32_t count;
1301 } zero_system_values_state;
1302 
1303 static bool
lower_system_value_to_zero_filter(const nir_instr * instr,const void * cb_state)1304 lower_system_value_to_zero_filter(const nir_instr* instr, const void* cb_state)
1305 {
1306    if (instr->type != nir_instr_type_intrinsic) {
1307       return false;
1308    }
1309 
1310    nir_intrinsic_instr* intrin = nir_instr_as_intrinsic(instr);
1311 
1312    /* All the intrinsics we care about are loads */
1313    if (!nir_intrinsic_infos[intrin->intrinsic].has_dest)
1314       return false;
1315 
1316    assert(intrin->dest.is_ssa);
1317 
1318    zero_system_values_state* state = (zero_system_values_state*)cb_state;
1319    for (uint32_t i = 0; i < state->count; ++i) {
1320       gl_system_value value = state->values[i];
1321       nir_intrinsic_op value_op = nir_intrinsic_from_system_value(value);
1322 
1323       if (intrin->intrinsic == value_op) {
1324          return true;
1325       } else if (intrin->intrinsic == nir_intrinsic_load_deref) {
1326          nir_deref_instr* deref = nir_src_as_deref(intrin->src[0]);
1327          if (!nir_deref_mode_is(deref, nir_var_system_value))
1328             return false;
1329 
1330          nir_variable* var = deref->var;
1331          if (var->data.location == value) {
1332             return true;
1333          }
1334       }
1335    }
1336 
1337    return false;
1338 }
1339 
1340 static nir_ssa_def*
lower_system_value_to_zero_instr(nir_builder * b,nir_instr * instr,void * _state)1341 lower_system_value_to_zero_instr(nir_builder* b, nir_instr* instr, void* _state)
1342 {
1343    return nir_imm_int(b, 0);
1344 }
1345 
1346 bool
dxil_nir_lower_system_values_to_zero(nir_shader * shader,gl_system_value * system_values,uint32_t count)1347 dxil_nir_lower_system_values_to_zero(nir_shader* shader,
1348                                      gl_system_value* system_values,
1349                                      uint32_t count)
1350 {
1351    zero_system_values_state state = { system_values, count };
1352    return nir_shader_lower_instructions(shader,
1353       lower_system_value_to_zero_filter,
1354       lower_system_value_to_zero_instr,
1355       &state);
1356 }
1357 
1358 static const struct glsl_type *
get_bare_samplers_for_type(const struct glsl_type * type)1359 get_bare_samplers_for_type(const struct glsl_type *type)
1360 {
1361    if (glsl_type_is_sampler(type)) {
1362       if (glsl_sampler_type_is_shadow(type))
1363          return glsl_bare_shadow_sampler_type();
1364       else
1365          return glsl_bare_sampler_type();
1366    } else if (glsl_type_is_array(type)) {
1367       return glsl_array_type(
1368          get_bare_samplers_for_type(glsl_get_array_element(type)),
1369          glsl_get_length(type),
1370          0 /*explicit size*/);
1371    }
1372    assert(!"Unexpected type");
1373    return NULL;
1374 }
1375 
static bool
redirect_sampler_derefs(struct nir_builder *b, nir_instr *instr, void *data)
{
   if (instr->type != nir_instr_type_tex)
      return false;

   nir_tex_instr *tex = nir_instr_as_tex(instr);
   if (!nir_tex_instr_need_sampler(tex))
      return false;

   int sampler_idx = nir_tex_instr_src_index(tex, nir_tex_src_sampler_deref);
   if (sampler_idx == -1) {
      /* No derefs, must be using indices */
      nir_variable *bare_sampler = _mesa_hash_table_u64_search(data, tex->sampler_index);

      /* Already have a bare sampler here */
      if (bare_sampler)
         return false;

      nir_variable *typed_sampler = NULL;
      nir_foreach_variable_with_modes(var, b->shader, nir_var_uniform) {
         if (var->data.binding <= tex->sampler_index &&
             var->data.binding + glsl_type_get_sampler_count(var->type) > tex->sampler_index) {
            /* Already have a bare sampler for this binding, add it to the table */
            if (glsl_get_sampler_result_type(glsl_without_array(var->type)) == GLSL_TYPE_VOID) {
               _mesa_hash_table_u64_insert(data, tex->sampler_index, var);
               return false;
            }

            typed_sampler = var;
         }
      }

      /* Clone the typed sampler to a bare sampler and we're done */
      assert(typed_sampler);
      bare_sampler = nir_variable_clone(typed_sampler, b->shader);
      bare_sampler->type = get_bare_samplers_for_type(typed_sampler->type);
      nir_shader_add_variable(b->shader, bare_sampler);
      _mesa_hash_table_u64_insert(data, tex->sampler_index, bare_sampler);
      return true;
   }

   /* Using derefs means we have to rewrite the deref chain in addition to
    * cloning the variable. */
   nir_deref_instr *final_deref = nir_src_as_deref(tex->src[sampler_idx].src);
   nir_deref_path path;
   nir_deref_path_init(&path, final_deref, NULL);

   nir_deref_instr *old_tail = path.path[0];
   assert(old_tail->deref_type == nir_deref_type_var);
   nir_variable *old_var = old_tail->var;
   if (glsl_get_sampler_result_type(glsl_without_array(old_var->type)) == GLSL_TYPE_VOID) {
      nir_deref_path_finish(&path);
      return false;
   }

   nir_variable *new_var = _mesa_hash_table_u64_search(data, old_var->data.binding);
   if (!new_var) {
      new_var = nir_variable_clone(old_var, b->shader);
      new_var->type = get_bare_samplers_for_type(old_var->type);
      nir_shader_add_variable(b->shader, new_var);
      _mesa_hash_table_u64_insert(data, old_var->data.binding, new_var);
   }

   b->cursor = nir_after_instr(&old_tail->instr);
   nir_deref_instr *new_tail = nir_build_deref_var(b, new_var);

   for (unsigned i = 1; path.path[i]; ++i) {
      b->cursor = nir_after_instr(&path.path[i]->instr);
      new_tail = nir_build_deref_follower(b, new_tail, path.path[i]);
   }

   nir_deref_path_finish(&path);
   nir_instr_rewrite_src_ssa(&tex->instr, &tex->src[sampler_idx].src, &new_tail->dest.ssa);

   return true;
}

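/* Ensure that every sampler referenced by a texture instruction is backed by
 * a bare sampler variable, cloning typed sampler variables as needed. A
 * plausible way to run the pass, sketched here rather than copied from a
 * driver:
 *
 *    bool progress = false;
 *    NIR_PASS(progress, nir, dxil_nir_create_bare_samplers);
 */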
bool
dxil_nir_create_bare_samplers(nir_shader *nir)
{
   struct hash_table_u64 *sampler_to_bare = _mesa_hash_table_u64_create(NULL);

   bool progress = nir_shader_instructions_pass(nir, redirect_sampler_derefs,
      nir_metadata_block_index | nir_metadata_dominance | nir_metadata_loop_analysis, sampler_to_bare);

   _mesa_hash_table_u64_destroy(sampler_to_bare);
   return progress;
}


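/* Boolean shader inputs (including the front-face system value) cannot be
 * kept as 1-bit values here: the lowering below rewrites the variable to a
 * 32-bit uint vector of the same width, loads that, and converts the result
 * back to a boolean with i2b. Roughly:
 *
 *    val = load_deref(bool_var)   ->   val = i2b(load_deref(uint_var))
 */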
static bool
lower_bool_input_filter(const nir_instr *instr,
                        UNUSED const void *_options)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
   if (intr->intrinsic == nir_intrinsic_load_front_face)
      return true;

   if (intr->intrinsic == nir_intrinsic_load_deref) {
      nir_deref_instr *deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
      nir_variable *var = nir_deref_instr_get_variable(deref);
      return var->data.mode == nir_var_shader_in &&
             glsl_get_base_type(var->type) == GLSL_TYPE_BOOL;
   }

   return false;
}

static nir_ssa_def *
lower_bool_input_impl(nir_builder *b, nir_instr *instr,
                      UNUSED void *_options)
{
   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

   if (intr->intrinsic == nir_intrinsic_load_deref) {
      nir_deref_instr *deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
      nir_variable *var = nir_deref_instr_get_variable(deref);

      /* Rewrite the variable (and its deref) to a 32-bit uint vector of the
       * same width. */
      var->type = glsl_vector_type(GLSL_TYPE_UINT,
                                   glsl_get_vector_elements(var->type));
      deref->type = var->type;
   }

   intr->dest.ssa.bit_size = 32;
   return nir_i2b1(b, &intr->dest.ssa);
}

bool
dxil_nir_lower_bool_input(struct nir_shader *s)
{
   return nir_shader_lower_instructions(s, lower_bool_input_filter,
                                        lower_bool_input_impl, NULL);
}

/* Comparison function to sort IO values so that normal varyings come first,
 * then system values, and then system-generated values.
 */
static int
variable_location_cmp(const nir_variable* a, const nir_variable* b)
{
   // Sort by driver_location, location, then index
   return a->data.driver_location != b->data.driver_location ?
            a->data.driver_location - b->data.driver_location :
            a->data.location != b->data.location ?
               a->data.location - b->data.location :
               a->data.index - b->data.index;
}

/* Order varyings according to driver location */
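/* The returned mask has bit var->data.location set for every variable
 * matching `modes`, i.e. it records which locations are in use.
 */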
uint64_t
dxil_sort_by_driver_location(nir_shader* s, nir_variable_mode modes)
{
   nir_sort_variables_with_modes(s, variable_location_cmp, modes);

   uint64_t result = 0;
   nir_foreach_variable_with_modes(var, s, modes) {
      result |= 1ull << var->data.location;
   }
   return result;
}

/* Sort PS outputs so that color outputs come first */
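/* After sorting, color outputs occupy driver locations 0..N-1, followed by
 * depth, stencil and sample mask (in that order, when present).
 */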
void
dxil_sort_ps_outputs(nir_shader* s)
{
   nir_foreach_variable_with_modes_safe(var, s, nir_var_shader_out) {
      /* We reuse driver_location as the sort key to avoid introducing a new
       * struct member; the real, updated driver location is written below,
       * after sorting. */
      switch (var->data.location) {
      case FRAG_RESULT_DEPTH:
         var->data.driver_location = 1;
         break;
      case FRAG_RESULT_STENCIL:
         var->data.driver_location = 2;
         break;
      case FRAG_RESULT_SAMPLE_MASK:
         var->data.driver_location = 3;
         break;
      default:
         var->data.driver_location = 0;
      }
   }

   nir_sort_variables_with_modes(s, variable_location_cmp,
                                 nir_var_shader_out);

   unsigned driver_loc = 0;
   nir_foreach_variable_with_modes(var, s, nir_var_shader_out) {
      var->data.driver_location = driver_loc++;
   }
}

/* Order inter-stage values so that normal varyings come first, then system
 * values, and then system-generated values.
 */
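/* A sketch of how a caller might combine this with the helpers above when
 * linking two stages (hypothetical, not taken from an actual driver):
 *
 *    uint64_t vs_outputs =
 *       dxil_sort_by_driver_location(vs_nir, nir_var_shader_out);
 *    dxil_reassign_driver_locations(fs_nir, nir_var_shader_in, vs_outputs);
 */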
uint64_t
dxil_reassign_driver_locations(nir_shader* s, nir_variable_mode modes,
   uint64_t other_stage_mask)
{
   nir_foreach_variable_with_modes_safe(var, s, modes) {
      /* We reuse driver_location as the sort key to avoid introducing a new
       * struct member; the real, updated driver location is written below,
       * after sorting. */
      var->data.driver_location = nir_var_to_dxil_sysvalue_type(var, other_stage_mask);
   }

   nir_sort_variables_with_modes(s, variable_location_cmp, modes);

   uint64_t result = 0;
   unsigned driver_loc = 0;
   nir_foreach_variable_with_modes(var, s, modes) {
      result |= 1ull << var->data.location;
      var->data.driver_location = driver_loc++;
   }
   return result;
}