1 /*
2  * Copyright © 2021 Google
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "radv_acceleration_structure.h"
25 #include "radv_debug.h"
26 #include "radv_private.h"
27 #include "radv_shader.h"
28 
29 #include "nir/nir.h"
30 #include "nir/nir_builder.h"
31 #include "nir/nir_builtin_builder.h"
32 
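/*
 * Merge the stages and groups of pCreateInfo with those of all referenced pipeline
 * libraries into a single create info. Shader indices inside library groups are rebased
 * onto the combined stage array. The returned pStages/pGroups arrays are malloc'ed and
 * must be freed by the caller.
 */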
33 static VkRayTracingPipelineCreateInfoKHR
34 radv_create_merged_rt_create_info(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo)
35 {
36    VkRayTracingPipelineCreateInfoKHR local_create_info = *pCreateInfo;
37    uint32_t total_stages = pCreateInfo->stageCount;
38    uint32_t total_groups = pCreateInfo->groupCount;
39 
40    if (pCreateInfo->pLibraryInfo) {
41       for (unsigned i = 0; i < pCreateInfo->pLibraryInfo->libraryCount; ++i) {
42          RADV_FROM_HANDLE(radv_pipeline, library, pCreateInfo->pLibraryInfo->pLibraries[i]);
43          total_stages += library->library.stage_count;
44          total_groups += library->library.group_count;
45       }
46    }
47    VkPipelineShaderStageCreateInfo *stages = NULL;
48    VkRayTracingShaderGroupCreateInfoKHR *groups = NULL;
49    local_create_info.stageCount = total_stages;
50    local_create_info.groupCount = total_groups;
51    local_create_info.pStages = stages =
52       malloc(sizeof(VkPipelineShaderStageCreateInfo) * total_stages);
53    local_create_info.pGroups = groups =
54       malloc(sizeof(VkRayTracingShaderGroupCreateInfoKHR) * total_groups);
55    if (!local_create_info.pStages || !local_create_info.pGroups)
56       return local_create_info;
57 
58    total_stages = pCreateInfo->stageCount;
59    total_groups = pCreateInfo->groupCount;
60    for (unsigned j = 0; j < pCreateInfo->stageCount; ++j)
61       stages[j] = pCreateInfo->pStages[j];
62    for (unsigned j = 0; j < pCreateInfo->groupCount; ++j)
63       groups[j] = pCreateInfo->pGroups[j];
64 
65    if (pCreateInfo->pLibraryInfo) {
66       for (unsigned i = 0; i < pCreateInfo->pLibraryInfo->libraryCount; ++i) {
67          RADV_FROM_HANDLE(radv_pipeline, library, pCreateInfo->pLibraryInfo->pLibraries[i]);
68          for (unsigned j = 0; j < library->library.stage_count; ++j)
69             stages[total_stages + j] = library->library.stages[j];
70          for (unsigned j = 0; j < library->library.group_count; ++j) {
71             VkRayTracingShaderGroupCreateInfoKHR *dst = &groups[total_groups + j];
72             *dst = library->library.groups[j];
73             if (dst->generalShader != VK_SHADER_UNUSED_KHR)
74                dst->generalShader += total_stages;
75             if (dst->closestHitShader != VK_SHADER_UNUSED_KHR)
76                dst->closestHitShader += total_stages;
77             if (dst->anyHitShader != VK_SHADER_UNUSED_KHR)
78                dst->anyHitShader += total_stages;
79             if (dst->intersectionShader != VK_SHADER_UNUSED_KHR)
80                dst->intersectionShader += total_stages;
81          }
82          total_stages += library->library.stage_count;
83          total_groups += library->library.group_count;
84       }
85    }
86    return local_create_info;
87 }
88 
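/*
 * Create a pipeline library: nothing is compiled here, we only store deep copies of the
 * (merged) stage and group infos so they can be linked into a complete ray tracing
 * pipeline later.
 */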
89 static VkResult
90 radv_rt_pipeline_library_create(VkDevice _device, VkPipelineCache _cache,
91                                 const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
92                                 const VkAllocationCallbacks *pAllocator, VkPipeline *pPipeline)
93 {
94    RADV_FROM_HANDLE(radv_device, device, _device);
95    struct radv_pipeline *pipeline;
96 
97    pipeline = vk_zalloc2(&device->vk.alloc, pAllocator, sizeof(*pipeline), 8,
98                          VK_SYSTEM_ALLOCATION_SCOPE_OBJECT);
99    if (pipeline == NULL)
100       return vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY);
101 
102    vk_object_base_init(&device->vk, &pipeline->base, VK_OBJECT_TYPE_PIPELINE);
103    pipeline->type = RADV_PIPELINE_LIBRARY;
104 
105    VkRayTracingPipelineCreateInfoKHR local_create_info =
106       radv_create_merged_rt_create_info(pCreateInfo);
107    if (!local_create_info.pStages || !local_create_info.pGroups)
108       goto fail;
109 
110    if (local_create_info.stageCount) {
111       size_t size = sizeof(VkPipelineShaderStageCreateInfo) * local_create_info.stageCount;
112       pipeline->library.stage_count = local_create_info.stageCount;
113       pipeline->library.stages = malloc(size);
114       if (!pipeline->library.stages)
115          goto fail;
116       memcpy(pipeline->library.stages, local_create_info.pStages, size);
117    }
118 
119    if (local_create_info.groupCount) {
120       size_t size = sizeof(VkRayTracingShaderGroupCreateInfoKHR) * local_create_info.groupCount;
121       pipeline->library.group_count = local_create_info.groupCount;
122       pipeline->library.groups = malloc(size);
123       if (!pipeline->library.groups)
124          goto fail;
125       memcpy(pipeline->library.groups, local_create_info.pGroups, size);
126    }
127 
128    *pPipeline = radv_pipeline_to_handle(pipeline);
129 
130    free((void *)local_create_info.pGroups);
131    free((void *)local_create_info.pStages);
132    return VK_SUCCESS;
133 fail:
134    free(pipeline->library.groups);
135    free(pipeline->library.stages);
136    free((void *)local_create_info.pGroups);
137    free((void *)local_create_info.pStages);
138    return VK_ERROR_OUT_OF_HOST_MEMORY;
139 }
140 
141 /*
142  * Global variables for an RT pipeline
143  */
144 struct rt_variables {
145    /* idx of the next shader to run in the next iteration of the main loop */
146    nir_variable *idx;
147 
148    /* scratch offset of the argument area relative to stack_ptr */
149    nir_variable *arg;
150 
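   /* current top of the driver-managed call stack, as a byte offset into scratch memory */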
151    nir_variable *stack_ptr;
152 
153    /* global address of the SBT entry used for the shader */
154    nir_variable *shader_record_ptr;
155 
156    /* trace_ray arguments */
157    nir_variable *accel_struct;
158    nir_variable *flags;
159    nir_variable *cull_mask;
160    nir_variable *sbt_offset;
161    nir_variable *sbt_stride;
162    nir_variable *miss_index;
163    nir_variable *origin;
164    nir_variable *tmin;
165    nir_variable *direction;
166    nir_variable *tmax;
167 
168    /* from the acceleration structure instance currently being visited */
169    nir_variable *custom_instance_and_mask;
170 
171    /* Properties of the primitive currently being visited. */
172    nir_variable *primitive_id;
173    nir_variable *geometry_id_and_flags;
174    nir_variable *instance_id;
175    nir_variable *instance_addr;
176    nir_variable *hit_kind;
177    nir_variable *opaque;
178 
179    /* Safeguard to ensure we don't end up in an infinite loop if a non-existent case is selected.
180     * Should not be needed, but is extra anti-hang safety during bring-up. */
181    nir_variable *main_loop_case_visited;
182 
183    /* Output variable for intersection & anyhit shaders. */
184    nir_variable *ahit_status;
185 
186    /* Array of stack-size structs recording the max stack size for each shader group. */
187    struct radv_pipeline_shader_stack_size *stack_sizes;
188    unsigned group_idx;
189 };
190 
191 static struct rt_variables
192 create_rt_variables(nir_shader *shader, struct radv_pipeline_shader_stack_size *stack_sizes)
193 {
194    struct rt_variables vars = {
195       NULL,
196    };
197    vars.idx = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "idx");
198    vars.arg = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "arg");
199    vars.stack_ptr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "stack_ptr");
200    vars.shader_record_ptr =
201       nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "shader_record_ptr");
202 
203    const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
204    vars.accel_struct =
205       nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "accel_struct");
206    vars.flags = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "ray_flags");
207    vars.cull_mask = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "cull_mask");
208    vars.sbt_offset =
209       nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "sbt_offset");
210    vars.sbt_stride =
211       nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "sbt_stride");
212    vars.miss_index =
213       nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "miss_index");
214    vars.origin = nir_variable_create(shader, nir_var_shader_temp, vec3_type, "ray_origin");
215    vars.tmin = nir_variable_create(shader, nir_var_shader_temp, glsl_float_type(), "ray_tmin");
216    vars.direction = nir_variable_create(shader, nir_var_shader_temp, vec3_type, "ray_direction");
217    vars.tmax = nir_variable_create(shader, nir_var_shader_temp, glsl_float_type(), "ray_tmax");
218 
219    vars.custom_instance_and_mask = nir_variable_create(
220       shader, nir_var_shader_temp, glsl_uint_type(), "custom_instance_and_mask");
221    vars.primitive_id =
222       nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "primitive_id");
223    vars.geometry_id_and_flags =
224       nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "geometry_id_and_flags");
225    vars.instance_id =
226       nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "instance_id");
227    vars.instance_addr =
228       nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr");
229    vars.hit_kind = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "hit_kind");
230    vars.opaque = nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "opaque");
231 
232    vars.main_loop_case_visited =
233       nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "main_loop_case_visited");
234    vars.ahit_status =
235       nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "ahit_status");
236 
237    vars.stack_sizes = stack_sizes;
238    return vars;
239 }
240 
241 /*
242  * Remap all the variables between the two rt_variables struct for inlining.
243  */
244 static void
245 map_rt_variables(struct hash_table *var_remap, struct rt_variables *src,
246                  const struct rt_variables *dst)
247 {
248    _mesa_hash_table_insert(var_remap, src->idx, dst->idx);
249    _mesa_hash_table_insert(var_remap, src->arg, dst->arg);
250    _mesa_hash_table_insert(var_remap, src->stack_ptr, dst->stack_ptr);
251    _mesa_hash_table_insert(var_remap, src->shader_record_ptr, dst->shader_record_ptr);
252 
253    _mesa_hash_table_insert(var_remap, src->accel_struct, dst->accel_struct);
254    _mesa_hash_table_insert(var_remap, src->flags, dst->flags);
255    _mesa_hash_table_insert(var_remap, src->cull_mask, dst->cull_mask);
256    _mesa_hash_table_insert(var_remap, src->sbt_offset, dst->sbt_offset);
257    _mesa_hash_table_insert(var_remap, src->sbt_stride, dst->sbt_stride);
258    _mesa_hash_table_insert(var_remap, src->miss_index, dst->miss_index);
259    _mesa_hash_table_insert(var_remap, src->origin, dst->origin);
260    _mesa_hash_table_insert(var_remap, src->tmin, dst->tmin);
261    _mesa_hash_table_insert(var_remap, src->direction, dst->direction);
262    _mesa_hash_table_insert(var_remap, src->tmax, dst->tmax);
263 
264    _mesa_hash_table_insert(var_remap, src->custom_instance_and_mask, dst->custom_instance_and_mask);
265    _mesa_hash_table_insert(var_remap, src->primitive_id, dst->primitive_id);
266    _mesa_hash_table_insert(var_remap, src->geometry_id_and_flags, dst->geometry_id_and_flags);
267    _mesa_hash_table_insert(var_remap, src->instance_id, dst->instance_id);
268    _mesa_hash_table_insert(var_remap, src->instance_addr, dst->instance_addr);
269    _mesa_hash_table_insert(var_remap, src->hit_kind, dst->hit_kind);
270    _mesa_hash_table_insert(var_remap, src->opaque, dst->opaque);
271    _mesa_hash_table_insert(var_remap, src->ahit_status, dst->ahit_status);
272 
273    src->stack_sizes = dst->stack_sizes;
274    src->group_idx = dst->group_idx;
275 }
276 
277 /*
278  * Create a copy of the global rt variables where the primitive/instance related variables are
279  * independent. This is needed because we have to keep the old values of the global variables
280  * around in case e.g. an any-hit shader rejects the intersection. So there are inner variables
281  * that get copied to the outer variables once we commit to a better hit.
282  */
283 static struct rt_variables
284 create_inner_vars(nir_builder *b, const struct rt_variables *vars)
285 {
286    struct rt_variables inner_vars = *vars;
287    inner_vars.idx =
288       nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_idx");
289    inner_vars.shader_record_ptr = nir_variable_create(
290       b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "inner_shader_record_ptr");
291    inner_vars.primitive_id =
292       nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_primitive_id");
293    inner_vars.geometry_id_and_flags = nir_variable_create(
294       b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_geometry_id_and_flags");
295    inner_vars.tmax =
296       nir_variable_create(b->shader, nir_var_shader_temp, glsl_float_type(), "inner_tmax");
297    inner_vars.instance_id =
298       nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_instance_id");
299    inner_vars.instance_addr = nir_variable_create(b->shader, nir_var_shader_temp,
300                                                   glsl_uint64_t_type(), "inner_instance_addr");
301    inner_vars.hit_kind =
302       nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_hit_kind");
303    inner_vars.custom_instance_and_mask = nir_variable_create(
304       b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_custom_instance_and_mask");
305 
306    return inner_vars;
307 }
308 
309 /* The hit attributes are stored on the stack. This is their offset relative to the current stack
310  * pointer. */
311 const uint32_t RADV_HIT_ATTRIB_OFFSET = -(16 + RADV_MAX_HIT_ATTRIB_SIZE);
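/* A rough sketch of the frame a caller builds around a shader call (see the
 * rt_execute_callable/rt_trace_ray lowering below): the caller bumps stack_ptr by
 * align(stack_size, 16) + RADV_MAX_HIT_ATTRIB_SIZE, stores the 32-bit resume index, and then
 * bumps stack_ptr by another 16 bytes. From the callee's point of view the layout below its
 * stack_ptr is therefore roughly:
 *
 *    stack_ptr - 16                                resume index (16-byte aligned slot)
 *    stack_ptr - 16 - RADV_MAX_HIT_ATTRIB_SIZE     hit attributes (RADV_HIT_ATTRIB_OFFSET)
 *    ...                                           the caller's scratch frame (spilled args etc.)
 */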
312 
313 static void
314 insert_rt_return(nir_builder *b, const struct rt_variables *vars)
315 {
316    nir_store_var(b, vars->stack_ptr,
317                  nir_iadd(b, nir_load_var(b, vars->stack_ptr), nir_imm_int(b, -16)), 1);
318    nir_store_var(b, vars->idx,
319                  nir_load_scratch(b, 1, 32, nir_load_var(b, vars->stack_ptr), .align_mul = 16), 1);
320 }
321 
322 enum sbt_type {
323    SBT_RAYGEN,
324    SBT_MISS,
325    SBT_HIT,
326    SBT_CALLABLE,
327 };
328 
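/* Compute the 64-bit GPU address of SBT record 'idx' in the given table: the first two dwords
 * of the SBT descriptor hold the base address and the third dword holds the record stride. */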
329 static nir_ssa_def *
330 get_sbt_ptr(nir_builder *b, nir_ssa_def *idx, enum sbt_type binding)
331 {
332    nir_ssa_def *desc = nir_load_sbt_amd(b, 4, .binding = binding);
333    nir_ssa_def *base_addr = nir_pack_64_2x32(b, nir_channels(b, desc, 0x3));
334    nir_ssa_def *stride = nir_channel(b, desc, 2);
335 
336    nir_ssa_def *ret = nir_imul(b, idx, stride);
337    ret = nir_iadd(b, base_addr, nir_u2u64(b, ret));
338 
339    return ret;
340 }
341 
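/* Load the 32-bit shader index stored at 'offset' bytes into the selected SBT record into
 * vars->idx, and point vars->shader_record_ptr at the application-visible shader record data
 * that follows the RADV_RT_HANDLE_SIZE handle. */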
342 static void
343 load_sbt_entry(nir_builder *b, const struct rt_variables *vars, nir_ssa_def *idx,
344                enum sbt_type binding, unsigned offset)
345 {
346    nir_ssa_def *addr = get_sbt_ptr(b, idx, binding);
347 
348    nir_ssa_def *load_addr = addr;
349    if (offset)
350       load_addr = nir_iadd(b, load_addr, nir_imm_int64(b, offset));
351    nir_ssa_def *v_idx =
352       nir_build_load_global(b, 1, 32, load_addr, .align_mul = 4, .align_offset = 0);
353 
354    nir_store_var(b, vars->idx, v_idx, 1);
355 
356    nir_ssa_def *record_addr = nir_iadd(b, addr, nir_imm_int64(b, RADV_RT_HANDLE_SIZE));
357    nir_store_var(b, vars->shader_record_ptr, record_addr, 1);
358 }
359 
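/* Multiply a vec3 by a 3x4 matrix given as three vec4 rows. When 'translation' is set the
 * fourth column is added as well, i.e. the vector is transformed as a point instead of a
 * direction. */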
360 static nir_ssa_def *
361 nir_build_vec3_mat_mult(nir_builder *b, nir_ssa_def *vec, nir_ssa_def *matrix[], bool translation)
362 {
363    nir_ssa_def *result_components[3] = {
364       nir_channel(b, matrix[0], 3),
365       nir_channel(b, matrix[1], 3),
366       nir_channel(b, matrix[2], 3),
367    };
368    for (unsigned i = 0; i < 3; ++i) {
369       for (unsigned j = 0; j < 3; ++j) {
370          nir_ssa_def *v =
371             nir_fmul(b, nir_channels(b, vec, 1 << j), nir_channels(b, matrix[i], 1 << j));
372          result_components[i] = (translation || j) ? nir_fadd(b, result_components[i], v) : v;
373       }
374    }
375    return nir_vec(b, result_components, 3);
376 }
377 
378 static nir_ssa_def *
379 nir_build_vec3_mat_mult_pre(nir_builder *b, nir_ssa_def *vec, nir_ssa_def *matrix[])
380 {
381    nir_ssa_def *result_components[3] = {
382       nir_channel(b, matrix[0], 3),
383       nir_channel(b, matrix[1], 3),
384       nir_channel(b, matrix[2], 3),
385    };
386    return nir_build_vec3_mat_mult(b, nir_fsub(b, vec, nir_vec(b, result_components, 3)), matrix,
387                                   false);
388 }
389 
390 static void
391 nir_build_wto_matrix_load(nir_builder *b, nir_ssa_def *instance_addr, nir_ssa_def **out)
392 {
393    unsigned offset = offsetof(struct radv_bvh_instance_node, wto_matrix);
394    for (unsigned i = 0; i < 3; ++i) {
395       out[i] = nir_build_load_global(b, 4, 32,
396                                      nir_iadd(b, instance_addr, nir_imm_int64(b, offset + i * 16)),
397                                      .align_mul = 64, .align_offset = offset + i * 16);
398    }
399 }
400 
401 /* This lowers all the RT instructions that we do not want to pass on to the combined shader and
402  * that we can implement using the variables from the shader we are going to inline into. */
403 static void
404 lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, unsigned call_idx_base)
405 {
406    nir_builder b_shader;
407    nir_builder_init(&b_shader, nir_shader_get_entrypoint(shader));
408 
409    nir_foreach_block (block, nir_shader_get_entrypoint(shader)) {
410       nir_foreach_instr_safe (instr, block) {
411          switch (instr->type) {
412          case nir_instr_type_intrinsic: {
413             nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
414             switch (intr->intrinsic) {
415             case nir_intrinsic_rt_execute_callable: {
416                uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE;
417                uint32_t ret = call_idx_base + nir_intrinsic_call_idx(intr) + 1;
418                b_shader.cursor = nir_instr_remove(instr);
419 
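               /* Grow the stack by the caller's frame (spilled args + hit attributes), store
                * the index of the resume shader, and move the stack pointer past the 16-byte
                * return slot before jumping to the callee. */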
420                nir_store_var(&b_shader, vars->stack_ptr,
421                              nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr),
422                                       nir_imm_int(&b_shader, size)),
423                              1);
424                nir_store_scratch(&b_shader, nir_imm_int(&b_shader, ret),
425                                  nir_load_var(&b_shader, vars->stack_ptr), .align_mul = 16,
426                                  .write_mask = 1);
427 
428                nir_store_var(&b_shader, vars->stack_ptr,
429                              nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr),
430                                       nir_imm_int(&b_shader, 16)),
431                              1);
432                load_sbt_entry(&b_shader, vars, intr->src[0].ssa, SBT_CALLABLE, 0);
433 
434                nir_store_var(
435                   &b_shader, vars->arg,
436                   nir_isub(&b_shader, intr->src[1].ssa, nir_imm_int(&b_shader, size + 16)), 1);
437 
438                vars->stack_sizes[vars->group_idx].recursive_size =
439                   MAX2(vars->stack_sizes[vars->group_idx].recursive_size, size + 16);
440                break;
441             }
442             case nir_intrinsic_rt_trace_ray: {
443                uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE;
444                uint32_t ret = call_idx_base + nir_intrinsic_call_idx(intr) + 1;
445                b_shader.cursor = nir_instr_remove(instr);
446 
447                nir_store_var(&b_shader, vars->stack_ptr,
448                              nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr),
449                                       nir_imm_int(&b_shader, size)),
450                              1);
451                nir_store_scratch(&b_shader, nir_imm_int(&b_shader, ret),
452                                  nir_load_var(&b_shader, vars->stack_ptr), .align_mul = 16,
453                                  .write_mask = 1);
454 
455                nir_store_var(&b_shader, vars->stack_ptr,
456                              nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr),
457                                       nir_imm_int(&b_shader, 16)),
458                              1);
459 
460                nir_store_var(&b_shader, vars->idx, nir_imm_int(&b_shader, 1), 1);
461                nir_store_var(
462                   &b_shader, vars->arg,
463                   nir_isub(&b_shader, intr->src[10].ssa, nir_imm_int(&b_shader, size + 16)), 1);
464 
465                vars->stack_sizes[vars->group_idx].recursive_size =
466                   MAX2(vars->stack_sizes[vars->group_idx].recursive_size, size + 16);
467 
468                /* Per the SPIR-V extension spec we have to ignore some bits for some arguments. */
469                nir_store_var(&b_shader, vars->accel_struct, intr->src[0].ssa, 0x1);
470                nir_store_var(&b_shader, vars->flags, intr->src[1].ssa, 0x1);
471                nir_store_var(&b_shader, vars->cull_mask,
472                              nir_iand(&b_shader, intr->src[2].ssa, nir_imm_int(&b_shader, 0xff)),
473                              0x1);
474                nir_store_var(&b_shader, vars->sbt_offset,
475                              nir_iand(&b_shader, intr->src[3].ssa, nir_imm_int(&b_shader, 0xf)),
476                              0x1);
477                nir_store_var(&b_shader, vars->sbt_stride,
478                              nir_iand(&b_shader, intr->src[4].ssa, nir_imm_int(&b_shader, 0xf)),
479                              0x1);
480                nir_store_var(&b_shader, vars->miss_index,
481                              nir_iand(&b_shader, intr->src[5].ssa, nir_imm_int(&b_shader, 0xffff)),
482                              0x1);
483                nir_store_var(&b_shader, vars->origin, intr->src[6].ssa, 0x7);
484                nir_store_var(&b_shader, vars->tmin, intr->src[7].ssa, 0x1);
485                nir_store_var(&b_shader, vars->direction, intr->src[8].ssa, 0x7);
486                nir_store_var(&b_shader, vars->tmax, intr->src[9].ssa, 0x1);
487                break;
488             }
489             case nir_intrinsic_rt_resume: {
490                uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE;
491                b_shader.cursor = nir_instr_remove(instr);
492 
493                nir_store_var(&b_shader, vars->stack_ptr,
494                              nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr),
495                                       nir_imm_int(&b_shader, -size)),
496                              1);
497                break;
498             }
499             case nir_intrinsic_rt_return_amd: {
500                b_shader.cursor = nir_instr_remove(instr);
501 
502                if (shader->info.stage == MESA_SHADER_RAYGEN) {
503                   nir_store_var(&b_shader, vars->idx, nir_imm_int(&b_shader, 0), 1);
504                   break;
505                }
506                insert_rt_return(&b_shader, vars);
507                break;
508             }
509             case nir_intrinsic_load_scratch: {
510                b_shader.cursor = nir_before_instr(instr);
511                nir_instr_rewrite_src_ssa(
512                   instr, &intr->src[0],
513                   nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), intr->src[0].ssa));
514                break;
515             }
516             case nir_intrinsic_store_scratch: {
517                b_shader.cursor = nir_before_instr(instr);
518                nir_instr_rewrite_src_ssa(
519                   instr, &intr->src[1],
520                   nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), intr->src[1].ssa));
521                break;
522             }
523             case nir_intrinsic_load_rt_arg_scratch_offset_amd: {
524                b_shader.cursor = nir_instr_remove(instr);
525                nir_ssa_def *ret = nir_load_var(&b_shader, vars->arg);
526                nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
527                break;
528             }
529             case nir_intrinsic_load_shader_record_ptr: {
530                b_shader.cursor = nir_instr_remove(instr);
531                nir_ssa_def *ret = nir_load_var(&b_shader, vars->shader_record_ptr);
532                nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
533                break;
534             }
535             case nir_intrinsic_load_ray_launch_id: {
536                b_shader.cursor = nir_instr_remove(instr);
537                nir_ssa_def *ret = nir_load_global_invocation_id(&b_shader, 32);
538                nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
539                break;
540             }
541             case nir_intrinsic_load_ray_t_min: {
542                b_shader.cursor = nir_instr_remove(instr);
543                nir_ssa_def *ret = nir_load_var(&b_shader, vars->tmin);
544                nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
545                break;
546             }
547             case nir_intrinsic_load_ray_t_max: {
548                b_shader.cursor = nir_instr_remove(instr);
549                nir_ssa_def *ret = nir_load_var(&b_shader, vars->tmax);
550                nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
551                break;
552             }
553             case nir_intrinsic_load_ray_world_origin: {
554                b_shader.cursor = nir_instr_remove(instr);
555                nir_ssa_def *ret = nir_load_var(&b_shader, vars->origin);
556                nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
557                break;
558             }
559             case nir_intrinsic_load_ray_world_direction: {
560                b_shader.cursor = nir_instr_remove(instr);
561                nir_ssa_def *ret = nir_load_var(&b_shader, vars->direction);
562                nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
563                break;
564             }
565             case nir_intrinsic_load_ray_instance_custom_index: {
566                b_shader.cursor = nir_instr_remove(instr);
567                nir_ssa_def *ret = nir_load_var(&b_shader, vars->custom_instance_and_mask);
568                ret = nir_iand(&b_shader, ret, nir_imm_int(&b_shader, 0xFFFFFF));
569                nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
570                break;
571             }
572             case nir_intrinsic_load_primitive_id: {
573                b_shader.cursor = nir_instr_remove(instr);
574                nir_ssa_def *ret = nir_load_var(&b_shader, vars->primitive_id);
575                nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
576                break;
577             }
578             case nir_intrinsic_load_ray_geometry_index: {
579                b_shader.cursor = nir_instr_remove(instr);
580                nir_ssa_def *ret = nir_load_var(&b_shader, vars->geometry_id_and_flags);
581                ret = nir_iand(&b_shader, ret, nir_imm_int(&b_shader, 0xFFFFFFF));
582                nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
583                break;
584             }
585             case nir_intrinsic_load_instance_id: {
586                b_shader.cursor = nir_instr_remove(instr);
587                nir_ssa_def *ret = nir_load_var(&b_shader, vars->instance_id);
588                nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
589                break;
590             }
591             case nir_intrinsic_load_ray_flags: {
592                b_shader.cursor = nir_instr_remove(instr);
593                nir_ssa_def *ret = nir_load_var(&b_shader, vars->flags);
594                nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
595                break;
596             }
597             case nir_intrinsic_load_ray_hit_kind: {
598                b_shader.cursor = nir_instr_remove(instr);
599                nir_ssa_def *ret = nir_load_var(&b_shader, vars->hit_kind);
600                nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
601                break;
602             }
603             case nir_intrinsic_load_ray_world_to_object: {
604                unsigned c = nir_intrinsic_column(intr);
605                nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
606                nir_ssa_def *wto_matrix[3];
607                nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix);
608 
609                nir_ssa_def *vals[3];
610                for (unsigned i = 0; i < 3; ++i)
611                   vals[i] = nir_channel(&b_shader, wto_matrix[i], c);
612 
613                nir_ssa_def *val = nir_vec(&b_shader, vals, 3);
614                if (c == 3)
615                   val = nir_fneg(&b_shader,
616                                  nir_build_vec3_mat_mult(&b_shader, val, wto_matrix, false));
617                b_shader.cursor = nir_instr_remove(instr);
618                nir_ssa_def_rewrite_uses(&intr->dest.ssa, val);
619                break;
620             }
621             case nir_intrinsic_load_ray_object_to_world: {
622                unsigned c = nir_intrinsic_column(intr);
623                nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
624                nir_ssa_def *val;
625                if (c == 3) {
626                   nir_ssa_def *wto_matrix[3];
627                   nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix);
628 
629                   nir_ssa_def *vals[3];
630                   for (unsigned i = 0; i < 3; ++i)
631                      vals[i] = nir_channel(&b_shader, wto_matrix[i], c);
632 
633                   val = nir_vec(&b_shader, vals, 3);
634                } else {
635                   val = nir_build_load_global(
636                      &b_shader, 3, 32,
637                      nir_iadd(&b_shader, instance_node_addr, nir_imm_int64(&b_shader, 92 + c * 12)),
638                      .align_mul = 4, .align_offset = 0);
639                }
640                b_shader.cursor = nir_instr_remove(instr);
641                nir_ssa_def_rewrite_uses(&intr->dest.ssa, val);
642                break;
643             }
644             case nir_intrinsic_load_ray_object_origin: {
645                nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
646                nir_ssa_def *wto_matrix[] = {
647                   nir_build_load_global(
648                      &b_shader, 4, 32,
649                      nir_iadd(&b_shader, instance_node_addr, nir_imm_int64(&b_shader, 16)),
650                      .align_mul = 64, .align_offset = 16),
651                   nir_build_load_global(
652                      &b_shader, 4, 32,
653                      nir_iadd(&b_shader, instance_node_addr, nir_imm_int64(&b_shader, 32)),
654                      .align_mul = 64, .align_offset = 32),
655                   nir_build_load_global(
656                      &b_shader, 4, 32,
657                      nir_iadd(&b_shader, instance_node_addr, nir_imm_int64(&b_shader, 48)),
658                      .align_mul = 64, .align_offset = 48)};
659                nir_ssa_def *val = nir_build_vec3_mat_mult_pre(
660                   &b_shader, nir_load_var(&b_shader, vars->origin), wto_matrix);
661                b_shader.cursor = nir_instr_remove(instr);
662                nir_ssa_def_rewrite_uses(&intr->dest.ssa, val);
663                break;
664             }
665             case nir_intrinsic_load_ray_object_direction: {
666                nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
667                nir_ssa_def *wto_matrix[3];
668                nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix);
669                nir_ssa_def *val = nir_build_vec3_mat_mult(
670                   &b_shader, nir_load_var(&b_shader, vars->direction), wto_matrix, false);
671                b_shader.cursor = nir_instr_remove(instr);
672                nir_ssa_def_rewrite_uses(&intr->dest.ssa, val);
673                break;
674             }
675             case nir_intrinsic_load_intersection_opaque_amd: {
676                b_shader.cursor = nir_instr_remove(instr);
677                nir_ssa_def *ret = nir_load_var(&b_shader, vars->opaque);
678                nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
679                break;
680             }
681             case nir_intrinsic_ignore_ray_intersection: {
682                b_shader.cursor = nir_instr_remove(instr);
683                nir_store_var(&b_shader, vars->ahit_status, nir_imm_int(&b_shader, 1), 1);
684 
685                /* The if is a workaround to avoid having to fix up control flow manually */
686                nir_push_if(&b_shader, nir_imm_true(&b_shader));
687                nir_jump(&b_shader, nir_jump_return);
688                nir_pop_if(&b_shader, NULL);
689                break;
690             }
691             case nir_intrinsic_terminate_ray: {
692                b_shader.cursor = nir_instr_remove(instr);
693                nir_store_var(&b_shader, vars->ahit_status, nir_imm_int(&b_shader, 2), 1);
694 
695                /* The if is a workaround to avoid having to fix up control flow manually */
696                nir_push_if(&b_shader, nir_imm_true(&b_shader));
697                nir_jump(&b_shader, nir_jump_return);
698                nir_pop_if(&b_shader, NULL);
699                break;
700             }
701             case nir_intrinsic_report_ray_intersection: {
702                b_shader.cursor = nir_instr_remove(instr);
703                nir_push_if(
704                   &b_shader,
705                   nir_iand(
706                      &b_shader,
707                      nir_flt(&b_shader, intr->src[0].ssa, nir_load_var(&b_shader, vars->tmax)),
708                      nir_fge(&b_shader, intr->src[0].ssa, nir_load_var(&b_shader, vars->tmin))));
709                {
710                   nir_store_var(&b_shader, vars->ahit_status, nir_imm_int(&b_shader, 0), 1);
711                   nir_store_var(&b_shader, vars->tmax, intr->src[0].ssa, 1);
712                   nir_store_var(&b_shader, vars->hit_kind, intr->src[1].ssa, 1);
713                }
714                nir_pop_if(&b_shader, NULL);
715                break;
716             }
717             default:
718                break;
719             }
720             break;
721          }
722          case nir_instr_type_jump: {
723             nir_jump_instr *jump = nir_instr_as_jump(instr);
724             if (jump->type == nir_jump_halt) {
725                b_shader.cursor = nir_instr_remove(instr);
726                nir_jump(&b_shader, nir_jump_return);
727             }
728             break;
729          }
730          default:
731             break;
732          }
733       }
734    }
735 
736    nir_metadata_preserve(nir_shader_get_entrypoint(shader), nir_metadata_none);
737 }
738 
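/* Insert one case into the main loop: lower the shader's RT intrinsics against the common
 * variables, wrap the inlined body in 'if (idx == call_idx)', and record its scratch usage in
 * the stack sizes of the current shader group. */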
739 static void
740 insert_rt_case(nir_builder *b, nir_shader *shader, const struct rt_variables *vars,
741                nir_ssa_def *idx, uint32_t call_idx_base, uint32_t call_idx)
742 {
743    struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL);
744 
745    nir_opt_dead_cf(shader);
746 
747    struct rt_variables src_vars = create_rt_variables(shader, vars->stack_sizes);
748    map_rt_variables(var_remap, &src_vars, vars);
749 
750    NIR_PASS_V(shader, lower_rt_instructions, &src_vars, call_idx_base);
751 
752    NIR_PASS_V(shader, nir_opt_remove_phis);
753    NIR_PASS_V(shader, nir_lower_returns);
754    NIR_PASS_V(shader, nir_opt_dce);
755 
756    if (b->shader->info.stage == MESA_SHADER_ANY_HIT ||
757        b->shader->info.stage == MESA_SHADER_INTERSECTION) {
758       src_vars.stack_sizes[src_vars.group_idx].non_recursive_size =
759          MAX2(src_vars.stack_sizes[src_vars.group_idx].non_recursive_size, shader->scratch_size);
760    } else {
761       src_vars.stack_sizes[src_vars.group_idx].recursive_size =
762          MAX2(src_vars.stack_sizes[src_vars.group_idx].recursive_size, shader->scratch_size);
763    }
764 
765    nir_push_if(b, nir_ieq(b, idx, nir_imm_int(b, call_idx)));
766    nir_store_var(b, vars->main_loop_case_visited, nir_imm_bool(b, true), 1);
767    nir_inline_function_impl(b, nir_shader_get_entrypoint(shader), NULL, var_remap);
768    nir_pop_if(b, NULL);
769 
770    /* Adopt the instructions from the source shader, since they are merely moved, not cloned. */
771    ralloc_adopt(ralloc_context(b->shader), ralloc_context(shader));
772 
773    ralloc_free(var_remap);
774 }
775 
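/* Rewrite shader_call_data and ray_hit_attrib derefs into function_temp derefs rooted at the
 * argument scratch offset resp. the fixed hit attribute offset, so that the later explicit IO
 * lowering turns them into scratch accesses relative to the stack pointer. */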
776 static bool
777 lower_rt_derefs(nir_shader *shader)
778 {
779    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
780 
781    bool progress = false;
782 
783    nir_builder b;
784    nir_builder_init(&b, impl);
785 
786    b.cursor = nir_before_cf_list(&impl->body);
787    nir_ssa_def *arg_offset = nir_load_rt_arg_scratch_offset_amd(&b);
788 
789    nir_foreach_block (block, impl) {
790       nir_foreach_instr_safe (instr, block) {
791          switch (instr->type) {
792          case nir_instr_type_deref: {
793             if (instr->type != nir_instr_type_deref)
794                continue;
795 
796             nir_deref_instr *deref = nir_instr_as_deref(instr);
797             if (nir_deref_mode_is(deref, nir_var_shader_call_data)) {
798                deref->modes = nir_var_function_temp;
799                if (deref->deref_type == nir_deref_type_var) {
800                   b.cursor = nir_before_instr(&deref->instr);
801                   nir_deref_instr *cast = nir_build_deref_cast(
802                      &b, arg_offset, nir_var_function_temp, deref->var->type, 0);
803                   nir_ssa_def_rewrite_uses(&deref->dest.ssa, &cast->dest.ssa);
804                   nir_instr_remove(&deref->instr);
805                }
806                progress = true;
807             } else if (nir_deref_mode_is(deref, nir_var_ray_hit_attrib)) {
808                deref->modes = nir_var_function_temp;
809                if (deref->deref_type == nir_deref_type_var) {
810                   b.cursor = nir_before_instr(&deref->instr);
811                   nir_deref_instr *cast =
812                      nir_build_deref_cast(&b, nir_imm_int(&b, RADV_HIT_ATTRIB_OFFSET),
813                                           nir_var_function_temp, deref->type, 0);
814                   nir_ssa_def_rewrite_uses(&deref->dest.ssa, &cast->dest.ssa);
815                   nir_instr_remove(&deref->instr);
816                }
817                progress = true;
818             }
819             break;
820          }
821          default:
822             break;
823          }
824       }
825    }
826 
827    if (progress) {
828       nir_metadata_preserve(impl, nir_metadata_block_index | nir_metadata_dominance);
829    } else {
830       nir_metadata_preserve(impl, nir_metadata_all);
831    }
832 
833    return progress;
834 }
835 
836 static gl_shader_stage
837 convert_rt_stage(VkShaderStageFlagBits vk_stage)
838 {
839    switch (vk_stage) {
840    case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
841       return MESA_SHADER_RAYGEN;
842    case VK_SHADER_STAGE_ANY_HIT_BIT_KHR:
843       return MESA_SHADER_ANY_HIT;
844    case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
845       return MESA_SHADER_CLOSEST_HIT;
846    case VK_SHADER_STAGE_MISS_BIT_KHR:
847       return MESA_SHADER_MISS;
848    case VK_SHADER_STAGE_INTERSECTION_BIT_KHR:
849       return MESA_SHADER_INTERSECTION;
850    case VK_SHADER_STAGE_CALLABLE_BIT_KHR:
851       return MESA_SHADER_CALLABLE;
852    default:
853       unreachable("Unhandled RT stage");
854    }
855 }
856 
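/* Compile a single ray tracing stage to NIR and massage it into the shape the monolithic
 * shader expects: stages that can return get an explicit rt_return_amd at the end, and the
 * call data / hit attribute variables are given an explicit scratch layout. */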
857 static nir_shader *
858 parse_rt_stage(struct radv_device *device, struct radv_pipeline_layout *layout,
859                const VkPipelineShaderStageCreateInfo *stage)
860 {
861    struct radv_pipeline_key key;
862    memset(&key, 0, sizeof(key));
863 
864    nir_shader *shader = radv_shader_compile_to_nir(
865       device, vk_shader_module_from_handle(stage->module), stage->pName,
866       convert_rt_stage(stage->stage), stage->pSpecializationInfo, layout, &key);
867 
868    if (shader->info.stage == MESA_SHADER_RAYGEN || shader->info.stage == MESA_SHADER_CLOSEST_HIT ||
869        shader->info.stage == MESA_SHADER_CALLABLE || shader->info.stage == MESA_SHADER_MISS) {
870       nir_block *last_block = nir_impl_last_block(nir_shader_get_entrypoint(shader));
871       nir_builder b_inner;
872       nir_builder_init(&b_inner, nir_shader_get_entrypoint(shader));
873       b_inner.cursor = nir_after_block(last_block);
874       nir_rt_return_amd(&b_inner);
875    }
876 
877    NIR_PASS_V(shader, nir_lower_vars_to_explicit_types,
878               nir_var_function_temp | nir_var_shader_call_data | nir_var_ray_hit_attrib,
879               glsl_get_natural_size_align_bytes);
880 
881    NIR_PASS_V(shader, lower_rt_derefs);
882 
883    NIR_PASS_V(shader, nir_lower_explicit_io, nir_var_function_temp,
884               nir_address_format_32bit_offset);
885 
886    return shader;
887 }
888 
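/* Turn an any-hit shader into a function taking (commit pointer, hit T, hit kind) so that it
 * can be inlined at every reportIntersection() call site of an intersection shader. */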
889 static nir_function_impl *
890 lower_any_hit_for_intersection(nir_shader *any_hit)
891 {
892    nir_function_impl *impl = nir_shader_get_entrypoint(any_hit);
893 
894    /* Any-hit shaders need three parameters */
895    assert(impl->function->num_params == 0);
896    nir_parameter params[] = {
897       {
898          /* A pointer to a boolean value for whether or not the hit was
899           * accepted.
900           */
901          .num_components = 1,
902          .bit_size = 32,
903       },
904       {
905          /* The hit T value */
906          .num_components = 1,
907          .bit_size = 32,
908       },
909       {
910          /* The hit kind */
911          .num_components = 1,
912          .bit_size = 32,
913       },
914    };
915    impl->function->num_params = ARRAY_SIZE(params);
916    impl->function->params = ralloc_array(any_hit, nir_parameter, ARRAY_SIZE(params));
917    memcpy(impl->function->params, params, sizeof(params));
918 
919    nir_builder build;
920    nir_builder_init(&build, impl);
921    nir_builder *b = &build;
922 
923    b->cursor = nir_before_cf_list(&impl->body);
924 
925    nir_ssa_def *commit_ptr = nir_load_param(b, 0);
926    nir_ssa_def *hit_t = nir_load_param(b, 1);
927    nir_ssa_def *hit_kind = nir_load_param(b, 2);
928 
929    nir_deref_instr *commit =
930       nir_build_deref_cast(b, commit_ptr, nir_var_function_temp, glsl_bool_type(), 0);
931 
932    nir_foreach_block_safe (block, impl) {
933       nir_foreach_instr_safe (instr, block) {
934          switch (instr->type) {
935          case nir_instr_type_intrinsic: {
936             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
937             switch (intrin->intrinsic) {
938             case nir_intrinsic_ignore_ray_intersection:
939                b->cursor = nir_instr_remove(&intrin->instr);
940                /* We put the newly emitted code inside a dummy if because it's
941                 * going to contain a jump instruction and we don't want to
942                 * deal with that mess here.  It'll get dealt with by our
943                 * control-flow optimization passes.
944                 */
945                nir_store_deref(b, commit, nir_imm_false(b), 0x1);
946                nir_push_if(b, nir_imm_true(b));
947                nir_jump(b, nir_jump_halt);
948                nir_pop_if(b, NULL);
949                break;
950 
951             case nir_intrinsic_terminate_ray:
952                /* The "normal" handling of terminateRay works fine in
953                 * intersection shaders.
954                 */
955                break;
956 
957             case nir_intrinsic_load_ray_t_max:
958                nir_ssa_def_rewrite_uses(&intrin->dest.ssa, hit_t);
959                nir_instr_remove(&intrin->instr);
960                break;
961 
962             case nir_intrinsic_load_ray_hit_kind:
963                nir_ssa_def_rewrite_uses(&intrin->dest.ssa, hit_kind);
964                nir_instr_remove(&intrin->instr);
965                break;
966 
967             default:
968                break;
969             }
970             break;
971          }
972          case nir_instr_type_jump: {
973             nir_jump_instr *jump = nir_instr_as_jump(instr);
974             if (jump->type == nir_jump_halt) {
975                b->cursor = nir_instr_remove(instr);
976                nir_jump(b, nir_jump_return);
977             }
978             break;
979          }
980 
981          default:
982             break;
983          }
984       }
985    }
986 
987    nir_validate_shader(any_hit, "after initial any-hit lowering");
988 
989    nir_lower_returns_impl(impl);
990 
991    nir_validate_shader(any_hit, "after lowering returns");
992 
993    return impl;
994 }
995 
996 /* Inline the any_hit shader into the intersection shader so we don't have
997  * to implement yet another shader call interface here, nor deal with any recursion.
998  */
999 static void
1000 nir_lower_intersection_shader(nir_shader *intersection, nir_shader *any_hit)
1001 {
1002    void *dead_ctx = ralloc_context(intersection);
1003 
1004    nir_function_impl *any_hit_impl = NULL;
1005    struct hash_table *any_hit_var_remap = NULL;
1006    if (any_hit) {
1007       any_hit = nir_shader_clone(dead_ctx, any_hit);
1008       NIR_PASS_V(any_hit, nir_opt_dce);
1009       any_hit_impl = lower_any_hit_for_intersection(any_hit);
1010       any_hit_var_remap = _mesa_pointer_hash_table_create(dead_ctx);
1011    }
1012 
1013    nir_function_impl *impl = nir_shader_get_entrypoint(intersection);
1014 
1015    nir_builder build;
1016    nir_builder_init(&build, impl);
1017    nir_builder *b = &build;
1018 
1019    b->cursor = nir_before_cf_list(&impl->body);
1020 
1021    nir_variable *commit = nir_local_variable_create(impl, glsl_bool_type(), "ray_commit");
1022    nir_store_var(b, commit, nir_imm_false(b), 0x1);
1023 
1024    nir_foreach_block_safe (block, impl) {
1025       nir_foreach_instr_safe (instr, block) {
1026          if (instr->type != nir_instr_type_intrinsic)
1027             continue;
1028 
1029          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1030          if (intrin->intrinsic != nir_intrinsic_report_ray_intersection)
1031             continue;
1032 
1033          b->cursor = nir_instr_remove(&intrin->instr);
1034          nir_ssa_def *hit_t = nir_ssa_for_src(b, intrin->src[0], 1);
1035          nir_ssa_def *hit_kind = nir_ssa_for_src(b, intrin->src[1], 1);
1036          nir_ssa_def *min_t = nir_load_ray_t_min(b);
1037          nir_ssa_def *max_t = nir_load_ray_t_max(b);
1038 
1039          /* bool commit_tmp = false; */
1040          nir_variable *commit_tmp = nir_local_variable_create(impl, glsl_bool_type(), "commit_tmp");
1041          nir_store_var(b, commit_tmp, nir_imm_false(b), 0x1);
1042 
1043          nir_push_if(b, nir_iand(b, nir_fge(b, hit_t, min_t), nir_fge(b, max_t, hit_t)));
1044          {
1045             /* Any-hit defaults to commit */
1046             nir_store_var(b, commit_tmp, nir_imm_true(b), 0x1);
1047 
1048             if (any_hit_impl != NULL) {
1049                nir_push_if(b, nir_inot(b, nir_load_intersection_opaque_amd(b)));
1050                {
1051                   nir_ssa_def *params[] = {
1052                      &nir_build_deref_var(b, commit_tmp)->dest.ssa,
1053                      hit_t,
1054                      hit_kind,
1055                   };
1056                   nir_inline_function_impl(b, any_hit_impl, params, any_hit_var_remap);
1057                }
1058                nir_pop_if(b, NULL);
1059             }
1060 
1061             nir_push_if(b, nir_load_var(b, commit_tmp));
1062             {
1063                nir_report_ray_intersection(b, 1, hit_t, hit_kind);
1064             }
1065             nir_pop_if(b, NULL);
1066          }
1067          nir_pop_if(b, NULL);
1068 
1069          nir_ssa_def *accepted = nir_load_var(b, commit_tmp);
1070          nir_ssa_def_rewrite_uses(&intrin->dest.ssa, accepted);
1071       }
1072    }
1073 
1074    /* We did some inlining; have to re-index SSA defs */
1075    nir_index_ssa_defs(impl);
1076 
1077    /* Eliminate the casts introduced for the commit return of the any-hit shader. */
1078    NIR_PASS_V(intersection, nir_opt_deref);
1079 
1080    ralloc_free(dead_ctx);
1081 }
1082 
1083 /* Variables used only internally by ray traversal. This is data that describes
1084  * the current state of the traversal vs. what we'd hand to a shader, e.g. which
1085  * instance we are currently visiting vs. which instance belongs to the closest
1086  * hit found so far. */
1087 struct rt_traversal_vars {
1088    nir_variable *origin;
1089    nir_variable *dir;
1090    nir_variable *inv_dir;
1091    nir_variable *sbt_offset_and_flags;
1092    nir_variable *instance_id;
1093    nir_variable *custom_instance_and_mask;
1094    nir_variable *instance_addr;
1095    nir_variable *should_return;
1096    nir_variable *bvh_base;
1097    nir_variable *stack;
1098    nir_variable *top_stack;
1099 };
1100 
1101 static struct rt_traversal_vars
1102 init_traversal_vars(nir_builder *b)
1103 {
1104    const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
1105    struct rt_traversal_vars ret;
1106 
1107    ret.origin = nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_origin");
1108    ret.dir = nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_dir");
1109    ret.inv_dir =
1110       nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_inv_dir");
1111    ret.sbt_offset_and_flags = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(),
1112                                                   "traversal_sbt_offset_and_flags");
1113    ret.instance_id = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(),
1114                                          "traversal_instance_id");
1115    ret.custom_instance_and_mask = nir_variable_create(
1116       b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_custom_instance_and_mask");
1117    ret.instance_addr =
1118       nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr");
1119    ret.should_return = nir_variable_create(b->shader, nir_var_shader_temp, glsl_bool_type(),
1120                                            "traversal_should_return");
1121    ret.bvh_base = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(),
1122                                       "traversal_bvh_base");
1123    ret.stack =
1124       nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_stack_ptr");
1125    ret.top_stack = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(),
1126                                        "traversal_top_stack_ptr");
1127    return ret;
1128 }
1129 
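/* BVH node ids are the node's virtual address divided by 8, with the node type stored in the
 * low 3 bits. These helpers convert between node ids and full 64-bit addresses. */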
1130 static nir_ssa_def *
1131 build_addr_to_node(nir_builder *b, nir_ssa_def *addr)
1132 {
1133    const uint64_t bvh_size = 1ull << 42;
1134    nir_ssa_def *node = nir_ushr(b, addr, nir_imm_int(b, 3));
1135    return nir_iand(b, node, nir_imm_int64(b, (bvh_size - 1) << 3));
1136 }
1137 
1138 static nir_ssa_def *
1139 build_node_to_addr(struct radv_device *device, nir_builder *b, nir_ssa_def *node)
1140 {
1141    nir_ssa_def *addr = nir_iand(b, node, nir_imm_int64(b, ~7ull));
1142    addr = nir_ishl(b, addr, nir_imm_int(b, 3));
1143    /* Assumes everything is in the top half of address space, which is true in
1144     * GFX9+ for now. */
1145    return device->physical_device->rad_info.chip_class >= GFX9
1146       ? nir_ior(b, addr, nir_imm_int64(b, 0xffffull << 48))
1147       : addr;
1148 }
1149 
1150 /* When a hit is opaque, the any_hit shader is skipped for it and the hit is
1151  * accepted as a confirmed intersection. */
1152 static nir_ssa_def *
1153 hit_is_opaque(nir_builder *b, const struct rt_variables *vars,
1154               const struct rt_traversal_vars *trav_vars, nir_ssa_def *geometry_id_and_flags)
1155 {
1156    nir_ssa_def *geom_force_opaque = nir_ine(
1157       b, nir_iand(b, geometry_id_and_flags, nir_imm_int(b, 1u << 28 /* VK_GEOMETRY_OPAQUE_BIT */)),
1158       nir_imm_int(b, 0));
1159    nir_ssa_def *instance_force_opaque =
1160       nir_ine(b,
1161               nir_iand(b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
1162                        nir_imm_int(b, 4 << 24 /* VK_GEOMETRY_INSTANCE_FORCE_OPAQUE_BIT */)),
1163               nir_imm_int(b, 0));
1164    nir_ssa_def *instance_force_non_opaque =
1165       nir_ine(b,
1166               nir_iand(b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
1167                        nir_imm_int(b, 8 << 24 /* VK_GEOMETRY_INSTANCE_FORCE_NO_OPAQUE_BIT */)),
1168               nir_imm_int(b, 0));
1169 
1170    nir_ssa_def *opaque = geom_force_opaque;
1171    opaque = nir_bcsel(b, instance_force_opaque, nir_imm_bool(b, true), opaque);
1172    opaque = nir_bcsel(b, instance_force_non_opaque, nir_imm_bool(b, false), opaque);
1173 
1174    nir_ssa_def *ray_force_opaque =
1175       nir_ine(b, nir_iand(b, nir_load_var(b, vars->flags), nir_imm_int(b, 1 /* RayFlagsOpaque */)),
1176               nir_imm_int(b, 0));
1177    nir_ssa_def *ray_force_non_opaque = nir_ine(
1178       b, nir_iand(b, nir_load_var(b, vars->flags), nir_imm_int(b, 2 /* RayFlagsNoOpaque */)),
1179       nir_imm_int(b, 0));
1180 
1181    opaque = nir_bcsel(b, ray_force_opaque, nir_imm_bool(b, true), opaque);
1182    opaque = nir_bcsel(b, ray_force_non_opaque, nir_imm_bool(b, false), opaque);
1183    return opaque;
1184 }
1185 
1186 static void
1187 visit_any_hit_shaders(struct radv_device *device,
1188                       const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, nir_builder *b,
1189                       struct rt_variables *vars)
1190 {
1191    RADV_FROM_HANDLE(radv_pipeline_layout, layout, pCreateInfo->layout);
1192    nir_ssa_def *sbt_idx = nir_load_var(b, vars->idx);
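   /* Inline the any-hit shader of every triangle hit group as one case of a switch on
    * the SBT index (idx == 0 means there is no any-hit shader to run). Procedural hit
    * groups are not handled here: insert_traversal_aabb_case() folds their any-hit
    * shader into the intersection shader via nir_lower_intersection_shader() instead.
    */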
1193 
1194    nir_push_if(b, nir_ine(b, sbt_idx, nir_imm_int(b, 0)));
1195    for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
1196       const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i];
1197       uint32_t shader_id = VK_SHADER_UNUSED_KHR;
1198 
1199       switch (group_info->type) {
1200       case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
1201          shader_id = group_info->anyHitShader;
1202          break;
1203       default:
1204          break;
1205       }
1206       if (shader_id == VK_SHADER_UNUSED_KHR)
1207          continue;
1208 
1209       const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id];
1210       nir_shader *nir_stage = parse_rt_stage(device, layout, stage);
1211 
1212       vars->group_idx = i;
1213       insert_rt_case(b, nir_stage, vars, sbt_idx, 0, i + 2);
1214    }
1215    nir_pop_if(b, NULL);
1216 }
1217 
1218 static void
1219 insert_traversal_triangle_case(struct radv_device *device,
1220                                const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, nir_builder *b,
1221                                nir_ssa_def *result, const struct rt_variables *vars,
1222                                const struct rt_traversal_vars *trav_vars, nir_ssa_def *bvh_node)
1223 {
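   /* The 4-component triangle intersection result (consumed the same way whether it
    * came from the hardware intrinsic or from intersect_ray_amd_software_tri):
    * component 0 is the unnormalized hit distance, component 1 the determinant (its
    * sign encodes the facing), and components 2/3 the unnormalized barycentric I/J.
    * Everything below divides by component 1.
    */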
1224    nir_ssa_def *dist = nir_vector_extract(b, result, nir_imm_int(b, 0));
1225    nir_ssa_def *div = nir_vector_extract(b, result, nir_imm_int(b, 1));
1226    dist = nir_fdiv(b, dist, div);
1227    nir_ssa_def *frontface = nir_flt(b, nir_imm_float(b, 0), div);
1228    nir_ssa_def *switch_ccw = nir_ine(
1229       b,
1230       nir_iand(
1231          b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
1232          nir_imm_int(b, 2 << 24 /* VK_GEOMETRY_INSTANCE_TRIANGLE_FRONT_COUNTERCLOCKWISE_BIT */)),
1233       nir_imm_int(b, 0));
1234    frontface = nir_ixor(b, frontface, switch_ccw);
1235 
1236    nir_ssa_def *not_cull = nir_ieq(
1237       b, nir_iand(b, nir_load_var(b, vars->flags), nir_imm_int(b, 256 /* RayFlagsSkipTriangles */)),
1238       nir_imm_int(b, 0));
1239    nir_ssa_def *not_facing_cull = nir_ieq(
1240       b,
1241       nir_iand(b, nir_load_var(b, vars->flags),
1242                nir_bcsel(b, frontface, nir_imm_int(b, 32 /* RayFlagsCullFrontFacingTriangles */),
1243                          nir_imm_int(b, 16 /* RayFlagsCullBackFacingTriangles */))),
1244       nir_imm_int(b, 0));
1245 
1246    not_cull = nir_iand(
1247       b, not_cull,
1248       nir_ior(
1249          b, not_facing_cull,
1250          nir_ine(
1251             b,
1252             nir_iand(
1253                b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
1254                nir_imm_int(b, 1 << 24 /* VK_GEOMETRY_INSTANCE_TRIANGLE_FACING_CULL_DISABLE_BIT */)),
1255             nir_imm_int(b, 0))));
1256 
1257    nir_push_if(b, nir_iand(b,
1258                            nir_iand(b, nir_flt(b, dist, nir_load_var(b, vars->tmax)),
1259                                     nir_fge(b, dist, nir_load_var(b, vars->tmin))),
1260                            not_cull));
1261    {
1262 
1263       nir_ssa_def *triangle_info = nir_build_load_global(
1264          b, 2, 32,
1265          nir_iadd(b, build_node_to_addr(device, b, bvh_node),
1266                   nir_imm_int64(b, offsetof(struct radv_bvh_triangle_node, triangle_id))),
1267          .align_mul = 4, .align_offset = 0);
1268       nir_ssa_def *primitive_id = nir_channel(b, triangle_info, 0);
1269       nir_ssa_def *geometry_id_and_flags = nir_channel(b, triangle_info, 1);
1270       nir_ssa_def *geometry_id = nir_iand(b, geometry_id_and_flags, nir_imm_int(b, 0xfffffff));
1271       nir_ssa_def *is_opaque = hit_is_opaque(b, vars, trav_vars, geometry_id_and_flags);
1272 
1273       not_cull =
1274          nir_ieq(b,
1275                  nir_iand(b, nir_load_var(b, vars->flags),
1276                           nir_bcsel(b, is_opaque, nir_imm_int(b, 0x40), nir_imm_int(b, 0x80))),
1277                  nir_imm_int(b, 0));
1278       nir_push_if(b, not_cull);
1279       {
1280          nir_ssa_def *sbt_idx =
1281             nir_iadd(b,
1282                      nir_iadd(b, nir_load_var(b, vars->sbt_offset),
1283                               nir_iand(b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
1284                                        nir_imm_int(b, 0xffffff))),
1285                      nir_imul(b, nir_load_var(b, vars->sbt_stride), geometry_id));
1286          nir_ssa_def *divs[2] = {div, div};
1287          nir_ssa_def *ij = nir_fdiv(b, nir_channels(b, result, 0xc), nir_vec(b, divs, 2));
1288          nir_ssa_def *hit_kind =
1289             nir_bcsel(b, frontface, nir_imm_int(b, 0xFE), nir_imm_int(b, 0xFF));
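         /* 0xFE / 0xFF are the SPIR-V HitKindFrontFacingTriangleKHR /
          * HitKindBackFacingTriangleKHR values. */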
1290 
1291          nir_store_scratch(
1292             b, ij,
1293             nir_iadd(b, nir_load_var(b, vars->stack_ptr), nir_imm_int(b, RADV_HIT_ATTRIB_OFFSET)),
1294             .align_mul = 16, .write_mask = 3);
1295 
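         /* ahit_status convention used below: 0 = hit accepted, 1 = hit ignored /
          * nothing committed, 2 = hit accepted and traversal terminated (presumably
          * stored by the any-hit lowering when the shader ignores or terminates). */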
1296          nir_store_var(b, vars->ahit_status, nir_imm_int(b, 0), 1);
1297 
1298          nir_push_if(b, nir_ine(b, is_opaque, nir_imm_bool(b, true)));
1299          {
1300             struct rt_variables inner_vars = create_inner_vars(b, vars);
1301 
1302             nir_store_var(b, inner_vars.primitive_id, primitive_id, 1);
1303             nir_store_var(b, inner_vars.geometry_id_and_flags, geometry_id_and_flags, 1);
1304             nir_store_var(b, inner_vars.tmax, dist, 0x1);
1305             nir_store_var(b, inner_vars.instance_id, nir_load_var(b, trav_vars->instance_id), 0x1);
1306             nir_store_var(b, inner_vars.instance_addr, nir_load_var(b, trav_vars->instance_addr),
1307                           0x1);
1308             nir_store_var(b, inner_vars.hit_kind, hit_kind, 0x1);
1309             nir_store_var(b, inner_vars.custom_instance_and_mask,
1310                           nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1);
1311 
1312             load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, 4);
1313 
1314             visit_any_hit_shaders(device, pCreateInfo, b, &inner_vars);
1315 
1316             nir_push_if(b, nir_ieq(b, nir_load_var(b, vars->ahit_status), nir_imm_int(b, 1)));
1317             {
1318                nir_jump(b, nir_jump_continue);
1319             }
1320             nir_pop_if(b, NULL);
1321          }
1322          nir_pop_if(b, NULL);
1323 
1324          nir_store_var(b, vars->primitive_id, primitive_id, 1);
1325          nir_store_var(b, vars->geometry_id_and_flags, geometry_id_and_flags, 1);
1326          nir_store_var(b, vars->tmax, dist, 0x1);
1327          nir_store_var(b, vars->instance_id, nir_load_var(b, trav_vars->instance_id), 0x1);
1328          nir_store_var(b, vars->instance_addr, nir_load_var(b, trav_vars->instance_addr), 0x1);
1329          nir_store_var(b, vars->hit_kind, hit_kind, 0x1);
1330          nir_store_var(b, vars->custom_instance_and_mask,
1331                        nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1);
1332 
1333          load_sbt_entry(b, vars, sbt_idx, SBT_HIT, 0);
1334 
1335          nir_store_var(b, trav_vars->should_return,
1336                        nir_ior(b,
1337                                nir_ine(b,
1338                                        nir_iand(b, nir_load_var(b, vars->flags),
1339                                                 nir_imm_int(b, 8 /* SkipClosestHitShader */)),
1340                                        nir_imm_int(b, 0)),
1341                                nir_ieq(b, nir_load_var(b, vars->idx), nir_imm_int(b, 0))),
1342                        1);
1343 
1344          nir_ssa_def *terminate_on_first_hit =
1345             nir_ine(b,
1346                     nir_iand(b, nir_load_var(b, vars->flags),
1347                              nir_imm_int(b, 4 /* TerminateOnFirstHitKHR */)),
1348                     nir_imm_int(b, 0));
1349          nir_ssa_def *ray_terminated =
1350             nir_ieq(b, nir_load_var(b, vars->ahit_status), nir_imm_int(b, 2));
1351          nir_push_if(b, nir_ior(b, terminate_on_first_hit, ray_terminated));
1352          {
1353             nir_jump(b, nir_jump_break);
1354          }
1355          nir_pop_if(b, NULL);
1356       }
1357       nir_pop_if(b, NULL);
1358    }
1359    nir_pop_if(b, NULL);
1360 }
1361 
1362 static void
1363 insert_traversal_aabb_case(struct radv_device *device,
1364                            const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, nir_builder *b,
1365                            const struct rt_variables *vars,
1366                            const struct rt_traversal_vars *trav_vars, nir_ssa_def *bvh_node)
1367 {
1368    RADV_FROM_HANDLE(radv_pipeline_layout, layout, pCreateInfo->layout);
1369 
1370    nir_ssa_def *node_addr = build_node_to_addr(device, b, bvh_node);
1371    nir_ssa_def *triangle_info = nir_build_load_global(
1372       b, 2, 32, nir_iadd(b, node_addr, nir_imm_int64(b, 24)), .align_mul = 4, .align_offset = 0);
1373    nir_ssa_def *primitive_id = nir_channel(b, triangle_info, 0);
1374    nir_ssa_def *geometry_id_and_flags = nir_channel(b, triangle_info, 1);
1375    nir_ssa_def *geometry_id = nir_iand(b, geometry_id_and_flags, nir_imm_int(b, 0xfffffff));
1376    nir_ssa_def *is_opaque = hit_is_opaque(b, vars, trav_vars, geometry_id_and_flags);
1377 
1378    nir_ssa_def *not_cull =
1379       nir_ieq(b,
1380               nir_iand(b, nir_load_var(b, vars->flags),
1381                        nir_bcsel(b, is_opaque, nir_imm_int(b, 0x40), nir_imm_int(b, 0x80))),
1382               nir_imm_int(b, 0));
1383    nir_push_if(b, not_cull);
1384    {
1385       nir_ssa_def *sbt_idx =
1386          nir_iadd(b,
1387                   nir_iadd(b, nir_load_var(b, vars->sbt_offset),
1388                            nir_iand(b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
1389                                     nir_imm_int(b, 0xffffff))),
1390                   nir_imul(b, nir_load_var(b, vars->sbt_stride), geometry_id));
1391 
1392       struct rt_variables inner_vars = create_inner_vars(b, vars);
1393 
1394       /* For AABBs the intersection shader writes the hit kind, and only does so if the hit
1395        * is the next closest candidate. */
1396       inner_vars.hit_kind = vars->hit_kind;
1397 
1398       nir_store_var(b, inner_vars.primitive_id, primitive_id, 1);
1399       nir_store_var(b, inner_vars.geometry_id_and_flags, geometry_id_and_flags, 1);
1400       nir_store_var(b, inner_vars.tmax, nir_load_var(b, vars->tmax), 0x1);
1401       nir_store_var(b, inner_vars.instance_id, nir_load_var(b, trav_vars->instance_id), 0x1);
1402       nir_store_var(b, inner_vars.instance_addr, nir_load_var(b, trav_vars->instance_addr), 0x1);
1403       nir_store_var(b, inner_vars.custom_instance_and_mask,
1404                     nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1);
1405       nir_store_var(b, inner_vars.opaque, is_opaque, 1);
1406 
1407       load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, 4);
1408 
1409       nir_store_var(b, vars->ahit_status, nir_imm_int(b, 1), 1);
1410 
1411       nir_push_if(b, nir_ine(b, nir_load_var(b, inner_vars.idx), nir_imm_int(b, 0)));
1412       for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
1413          const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i];
1414          uint32_t shader_id = VK_SHADER_UNUSED_KHR;
1415          uint32_t any_hit_shader_id = VK_SHADER_UNUSED_KHR;
1416 
1417          switch (group_info->type) {
1418          case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
1419             shader_id = group_info->intersectionShader;
1420             any_hit_shader_id = group_info->anyHitShader;
1421             break;
1422          default:
1423             break;
1424          }
1425          if (shader_id == VK_SHADER_UNUSED_KHR)
1426             continue;
1427 
1428          const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id];
1429          nir_shader *nir_stage = parse_rt_stage(device, layout, stage);
1430 
1431          nir_shader *any_hit_stage = NULL;
1432          if (any_hit_shader_id != VK_SHADER_UNUSED_KHR) {
1433             stage = &pCreateInfo->pStages[any_hit_shader_id];
1434             any_hit_stage = parse_rt_stage(device, layout, stage);
1435 
1436             nir_lower_intersection_shader(nir_stage, any_hit_stage);
1437             ralloc_free(any_hit_stage);
1438          }
1439 
1440          inner_vars.group_idx = i;
1441          insert_rt_case(b, nir_stage, &inner_vars, nir_load_var(b, inner_vars.idx), 0, i + 2);
1442       }
1443       nir_push_else(b, NULL);
1444       {
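         /* No intersection shader for this SBT entry: approximate by slab-testing the
          * ray against the stored AABB and, if it overlaps [tmin, tmax], committing a
          * hit at max(t_min, tmin). */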
1445          nir_ssa_def *vec3_zero = nir_channels(b, nir_imm_vec4(b, 0, 0, 0, 0), 0x7);
1446          nir_ssa_def *vec3_inf =
1447             nir_channels(b, nir_imm_vec4(b, INFINITY, INFINITY, INFINITY, 0), 0x7);
1448 
1449          nir_ssa_def *bvh_lo =
1450             nir_build_load_global(b, 3, 32, nir_iadd(b, node_addr, nir_imm_int64(b, 0)),
1451                                   .align_mul = 4, .align_offset = 0);
1452          nir_ssa_def *bvh_hi =
1453             nir_build_load_global(b, 3, 32, nir_iadd(b, node_addr, nir_imm_int64(b, 12)),
1454                                   .align_mul = 4, .align_offset = 0);
1455 
1456          bvh_lo = nir_fsub(b, bvh_lo, nir_load_var(b, trav_vars->origin));
1457          bvh_hi = nir_fsub(b, bvh_hi, nir_load_var(b, trav_vars->origin));
1458          nir_ssa_def *t_vec = nir_fmin(b, nir_fmul(b, bvh_lo, nir_load_var(b, trav_vars->inv_dir)),
1459                                        nir_fmul(b, bvh_hi, nir_load_var(b, trav_vars->inv_dir)));
1460          nir_ssa_def *t2_vec = nir_fmax(b, nir_fmul(b, bvh_lo, nir_load_var(b, trav_vars->inv_dir)),
1461                                         nir_fmul(b, bvh_hi, nir_load_var(b, trav_vars->inv_dir)));
1462          /* If we run parallel to one of the edges the range should be [0, inf) not [0,0] */
1463          t2_vec =
1464             nir_bcsel(b, nir_feq(b, nir_load_var(b, trav_vars->dir), vec3_zero), vec3_inf, t2_vec);
1465 
1466          nir_ssa_def *t_min = nir_fmax(b, nir_channel(b, t_vec, 0), nir_channel(b, t_vec, 1));
1467          t_min = nir_fmax(b, t_min, nir_channel(b, t_vec, 2));
1468 
1469          nir_ssa_def *t_max = nir_fmin(b, nir_channel(b, t2_vec, 0), nir_channel(b, t2_vec, 1));
1470          t_max = nir_fmin(b, t_max, nir_channel(b, t2_vec, 2));
1471 
1472          nir_push_if(b, nir_iand(b, nir_flt(b, t_min, nir_load_var(b, vars->tmax)),
1473                                  nir_fge(b, t_max, nir_load_var(b, vars->tmin))));
1474          {
1475             nir_store_var(b, vars->ahit_status, nir_imm_int(b, 0), 1);
1476             nir_store_var(b, vars->tmax, nir_fmax(b, t_min, nir_load_var(b, vars->tmin)), 1);
1477          }
1478          nir_pop_if(b, NULL);
1479       }
1480       nir_pop_if(b, NULL);
1481 
1482       nir_push_if(b, nir_ine(b, nir_load_var(b, vars->ahit_status), nir_imm_int(b, 1)));
1483       {
1484          nir_store_var(b, vars->primitive_id, primitive_id, 1);
1485          nir_store_var(b, vars->geometry_id_and_flags, geometry_id_and_flags, 1);
1486          nir_store_var(b, vars->tmax, nir_load_var(b, inner_vars.tmax), 0x1);
1487          nir_store_var(b, vars->instance_id, nir_load_var(b, trav_vars->instance_id), 0x1);
1488          nir_store_var(b, vars->instance_addr, nir_load_var(b, trav_vars->instance_addr), 0x1);
1489          nir_store_var(b, vars->custom_instance_and_mask,
1490                        nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1);
1491 
1492          load_sbt_entry(b, vars, sbt_idx, SBT_HIT, 0);
1493 
1494          nir_store_var(b, trav_vars->should_return,
1495                        nir_ior(b,
1496                                nir_ine(b,
1497                                        nir_iand(b, nir_load_var(b, vars->flags),
1498                                                 nir_imm_int(b, 8 /* SkipClosestHitShader */)),
1499                                        nir_imm_int(b, 0)),
1500                                nir_ieq(b, nir_load_var(b, vars->idx), nir_imm_int(b, 0))),
1501                        1);
1502 
1503          nir_ssa_def *terminate_on_first_hit =
1504             nir_ine(b,
1505                     nir_iand(b, nir_load_var(b, vars->flags),
1506                              nir_imm_int(b, 4 /* TerminateOnFirstHitKHR */)),
1507                     nir_imm_int(b, 0));
1508          nir_ssa_def *ray_terminated =
1509             nir_ieq(b, nir_load_var(b, vars->ahit_status), nir_imm_int(b, 2));
1510          nir_push_if(b, nir_ior(b, terminate_on_first_hit, ray_terminated));
1511          {
1512             nir_jump(b, nir_jump_break);
1513          }
1514          nir_pop_if(b, NULL);
1515       }
1516       nir_pop_if(b, NULL);
1517    }
1518    nir_pop_if(b, NULL);
1519 }
1520 
1521 static void
1522 nir_sort_hit_pair(nir_builder *b, nir_variable *var_distances, nir_variable *var_indices, uint32_t chan_1, uint32_t chan_2)
1523 {
1524    nir_ssa_def *ssa_distances = nir_load_var(b, var_distances);
1525    nir_ssa_def *ssa_indices = nir_load_var(b, var_indices);
1526    /* if (distances[chan_2] < distances[chan_1]) { */
1527    nir_push_if(b, nir_flt(b, nir_channel(b, ssa_distances, chan_2), nir_channel(b, ssa_distances, chan_1)));
1528    {
1529       /* swap(distances[chan_2], distances[chan_1]); */
1530       nir_ssa_def *new_distances[4] = {nir_ssa_undef(b, 1, 32), nir_ssa_undef(b, 1, 32), nir_ssa_undef(b, 1, 32), nir_ssa_undef(b, 1, 32)};
1531       nir_ssa_def *new_indices[4]   = {nir_ssa_undef(b, 1, 32), nir_ssa_undef(b, 1, 32), nir_ssa_undef(b, 1, 32), nir_ssa_undef(b, 1, 32)};
1532       new_distances[chan_2] = nir_channel(b, ssa_distances, chan_1);
1533       new_distances[chan_1] = nir_channel(b, ssa_distances, chan_2);
1534       new_indices[chan_2] = nir_channel(b, ssa_indices, chan_1);
1535       new_indices[chan_1] = nir_channel(b, ssa_indices, chan_2);
1536       nir_store_var(b, var_distances, nir_vec(b, new_distances, 4), (1u << chan_1) | (1u << chan_2));
1537       nir_store_var(b, var_indices, nir_vec(b, new_indices, 4), (1u << chan_1) | (1u << chan_2));
1538    }
1539    /* } */
1540    nir_pop_if(b, NULL);
1541 }
1542 
1543 static nir_ssa_def *
1544 intersect_ray_amd_software_box(struct radv_device *device,
1545                                nir_builder *b, nir_ssa_def *bvh_node,
1546                                nir_ssa_def *ray_tmax, nir_ssa_def *origin,
1547                                nir_ssa_def *dir, nir_ssa_def *inv_dir)
1548 {
1549    const struct glsl_type *vec4_type = glsl_vector_type(GLSL_TYPE_FLOAT, 4);
1550    const struct glsl_type *uvec4_type = glsl_vector_type(GLSL_TYPE_UINT, 4);
1551 
1552    nir_ssa_def *node_addr = build_node_to_addr(device, b, bvh_node);
1553 
1554    /* vec4 distances = vec4(INF, INF, INF, INF); */
1555    nir_variable *distances = nir_variable_create(b->shader, nir_var_shader_temp, vec4_type, "distances");
1556    nir_store_var(b, distances, nir_imm_vec4(b, INFINITY, INFINITY, INFINITY, INFINITY), 0xf);
1557 
1558    /* uvec4 child_indices = uvec4(0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff); */
1559    nir_variable *child_indices = nir_variable_create(b->shader, nir_var_shader_temp, uvec4_type, "child_indices");
1560    nir_store_var(b, child_indices, nir_imm_ivec4(b, 0xffffffffu, 0xffffffffu, 0xffffffffu, 0xffffffffu), 0xf);
1561 
1562    /* Need to remove infinities here because otherwise we get nasty NaN propagation
1563     * if the direction has 0s in it. */
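   /* Example: dir.x == 0 gives inv_dir.x == +/-INF; if a box also has
    * coords.x == origin.x, the product 0 * INF would be NaN, whereas after the
    * clamp it is 0 * +/-FLT_MAX == 0. */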
1564    /* inv_dir = clamp(inv_dir, -FLT_MAX, FLT_MAX); */
1565    inv_dir = nir_fclamp(b, inv_dir, nir_imm_float(b, -FLT_MAX), nir_imm_float(b, FLT_MAX));
1566 
1567    for (int i = 0; i < 4; i++) {
1568       const uint32_t child_offset  = offsetof(struct radv_bvh_box32_node, children[i]);
1569       const uint32_t coord_offsets[2] = {
1570          offsetof(struct radv_bvh_box32_node, coords[i][0][0]),
1571          offsetof(struct radv_bvh_box32_node, coords[i][1][0]),
1572       };
1573 
1574       /* node->children[i] -> uint */
1575       nir_ssa_def *child_index = nir_build_load_global(b, 1, 32, nir_iadd(b, node_addr, nir_imm_int64(b, child_offset)),  .align_mul = 64, .align_offset = child_offset  % 64 );
1576       /* node->coords[i][0], node->coords[i][1] -> vec3 */
1577       nir_ssa_def *node_coords[2] = {
1578          nir_build_load_global(b, 3, 32, nir_iadd(b, node_addr, nir_imm_int64(b, coord_offsets[0])), .align_mul = 64, .align_offset = coord_offsets[0] % 64 ),
1579          nir_build_load_global(b, 3, 32, nir_iadd(b, node_addr, nir_imm_int64(b, coord_offsets[1])), .align_mul = 64, .align_offset = coord_offsets[1] % 64 ),
1580       };
1581 
1582       /* If x of the aabb min is NaN, then this is an inactive aabb.
1583        * We don't need to care about any other components being NaN as that is UB.
1584        * https://www.khronos.org/registry/vulkan/specs/1.2-extensions/html/chap36.html#VkAabbPositionsKHR */
1585       nir_ssa_def *min_x = nir_channel(b, node_coords[0], 0);
1586       nir_ssa_def *min_x_is_not_nan = nir_inot(b, nir_fneu(b, min_x, min_x)); /* NaN != NaN -> true */
1587 
1588       /* vec3 bound0 = (node->coords[i][0] - origin) * inv_dir; */
1589       nir_ssa_def *bound0 = nir_fmul(b, nir_fsub(b, node_coords[0], origin), inv_dir);
1590       /* vec3 bound1 = (node->coords[i][1] - origin) * inv_dir; */
1591       nir_ssa_def *bound1 = nir_fmul(b, nir_fsub(b, node_coords[1], origin), inv_dir);
1592 
1593       /* float tmin = max(max(min(bound0.x, bound1.x), min(bound0.y, bound1.y)), min(bound0.z, bound1.z)); */
1594       nir_ssa_def *tmin = nir_fmax(b, nir_fmax(b,
1595          nir_fmin(b, nir_channel(b, bound0, 0), nir_channel(b, bound1, 0)),
1596          nir_fmin(b, nir_channel(b, bound0, 1), nir_channel(b, bound1, 1))),
1597          nir_fmin(b, nir_channel(b, bound0, 2), nir_channel(b, bound1, 2)));
1598 
1599       /* float tmax = min(min(max(bound0.x, bound1.x), max(bound0.y, bound1.y)), max(bound0.z, bound1.z)); */
1600       nir_ssa_def *tmax = nir_fmin(b, nir_fmin(b,
1601          nir_fmax(b, nir_channel(b, bound0, 0), nir_channel(b, bound1, 0)),
1602          nir_fmax(b, nir_channel(b, bound0, 1), nir_channel(b, bound1, 1))),
1603          nir_fmax(b, nir_channel(b, bound0, 2), nir_channel(b, bound1, 2)));
1604 
1605       /* if (!isnan(node->coords[i][0].x) && tmax >= max(0.0f, tmin) && tmin < ray_tmax) { */
1606       nir_push_if(b,
1607          nir_iand(b,
1608             min_x_is_not_nan,
1609             nir_iand(b,
1610                nir_fge(b, tmax, nir_fmax(b, nir_imm_float(b, 0.0f), tmin)),
1611                nir_flt(b, tmin, ray_tmax))));
1612       {
1613          /* child_indices[i] = node->children[i]; */
1614          nir_ssa_def *new_child_indices[4] = {child_index, child_index, child_index, child_index};
1615          nir_store_var(b, child_indices, nir_vec(b, new_child_indices, 4), 1u << i);
1616 
1617          /* distances[i] = tmin; */
1618          nir_ssa_def *new_distances[4] = {tmin, tmin, tmin, tmin};
1619          nir_store_var(b, distances, nir_vec(b, new_distances, 4), 1u << i);
1620 
1621       }
1622       /* } */
1623       nir_pop_if(b, NULL);
1624    }
1625 
1626    /* Sort our distances with a sorting network. */
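   /* The comparator sequence (0,1) (2,3) (0,2) (1,3) (1,2) is a 5-comparator network
    * that fully sorts 4 elements. */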
1627    nir_sort_hit_pair(b, distances, child_indices, 0, 1);
1628    nir_sort_hit_pair(b, distances, child_indices, 2, 3);
1629    nir_sort_hit_pair(b, distances, child_indices, 0, 2);
1630    nir_sort_hit_pair(b, distances, child_indices, 1, 3);
1631    nir_sort_hit_pair(b, distances, child_indices, 1, 2);
1632 
1633    return nir_load_var(b, child_indices);
1634 }
1635 
1636 static nir_ssa_def *
1637 intersect_ray_amd_software_tri(struct radv_device *device,
1638                                nir_builder *b, nir_ssa_def *bvh_node,
1639                                nir_ssa_def *ray_tmax, nir_ssa_def *origin,
1640                                nir_ssa_def *dir, nir_ssa_def *inv_dir)
1641 {
1642    const struct glsl_type *vec4_type = glsl_vector_type(GLSL_TYPE_FLOAT, 4);
1643 
1644    nir_ssa_def *node_addr = build_node_to_addr(device, b, bvh_node);
1645 
1646    const uint32_t coord_offsets[3] = {
1647       offsetof(struct radv_bvh_triangle_node, coords[0]),
1648       offsetof(struct radv_bvh_triangle_node, coords[1]),
1649       offsetof(struct radv_bvh_triangle_node, coords[2]),
1650    };
1651 
1652    /* node->coords[0], node->coords[1], node->coords[2] -> vec3 */
1653    nir_ssa_def *node_coords[3] = {
1654       nir_build_load_global(b, 3, 32, nir_iadd(b, node_addr, nir_imm_int64(b, coord_offsets[0])), .align_mul = 64, .align_offset = coord_offsets[0] % 64 ),
1655       nir_build_load_global(b, 3, 32, nir_iadd(b, node_addr, nir_imm_int64(b, coord_offsets[1])), .align_mul = 64, .align_offset = coord_offsets[1] % 64 ),
1656       nir_build_load_global(b, 3, 32, nir_iadd(b, node_addr, nir_imm_int64(b, coord_offsets[2])), .align_mul = 64, .align_offset = coord_offsets[2] % 64 ),
1657    };
1658 
1659    nir_variable *result = nir_variable_create(b->shader, nir_var_shader_temp, vec4_type, "result");
1660    nir_store_var(b, result, nir_imm_vec4(b, INFINITY, 1.0f, 0.0f, 0.0f), 0xf);
1661 
1662    /* Based on watertight Ray/Triangle intersection from
1663     * http://jcgt.org/published/0002/01/05/paper.pdf */
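   /* Rough outline of the method as implemented below: pick the dominant direction
    * component as kz (swapping kx/ky when dir[kz] < 0 to keep the winding), shear the
    * translated vertices so the ray becomes the +Z axis, and compute the 2D edge
    * functions u, v, w. The ray hits iff u, v, w all have the same sign; the hit
    * distance is (u*az + v*bz + w*cz) / (u + v + w), returned unnormalized together
    * with the determinant so the caller can do the division. */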
1664 
1665    /* Calculate the dimension where the ray direction is largest */
1666    nir_ssa_def *abs_dir = nir_fabs(b, dir);
1667 
1668    nir_ssa_def *abs_dirs[3] = {
1669       nir_channel(b, abs_dir, 0),
1670       nir_channel(b, abs_dir, 1),
1671       nir_channel(b, abs_dir, 2),
1672    };
1673    /* Find index of greatest value of abs_dir and put that as kz. */
1674    nir_ssa_def *kz = nir_bcsel(b, nir_fge(b, abs_dirs[0], abs_dirs[1]),
1675          nir_bcsel(b, nir_fge(b, abs_dirs[0], abs_dirs[2]),
1676             nir_imm_int(b, 0), nir_imm_int(b, 2)),
1677          nir_bcsel(b, nir_fge(b, abs_dirs[1], abs_dirs[2]),
1678             nir_imm_int(b, 1), nir_imm_int(b, 2)));
1679    nir_ssa_def *kx = nir_imod(b, nir_iadd(b, kz, nir_imm_int(b, 1)), nir_imm_int(b, 3));
1680    nir_ssa_def *ky = nir_imod(b, nir_iadd(b, kx, nir_imm_int(b, 1)), nir_imm_int(b, 3));
1681    nir_ssa_def *k_indices[3] = { kx, ky, kz };
1682    nir_ssa_def *k = nir_vec(b, k_indices, 3);
1683 
1684    /* Swap kx and ky dimensions to preserve winding order */
1685    unsigned swap_xy_swizzle[4] = {1, 0, 2, 3};
1686    k = nir_bcsel(b,
1687       nir_flt(b, nir_vector_extract(b, dir, kz), nir_imm_float(b, 0.0f)),
1688       nir_swizzle(b, k, swap_xy_swizzle, 3),
1689       k);
1690 
1691    kx = nir_channel(b, k, 0);
1692    ky = nir_channel(b, k, 1);
1693    kz = nir_channel(b, k, 2);
1694 
1695    /* Calculate shear constants */
1696    nir_ssa_def *sz = nir_frcp(b, nir_vector_extract(b, dir, kz));
1697    nir_ssa_def *sx = nir_fmul(b, nir_vector_extract(b, dir, kx), sz);
1698    nir_ssa_def *sy = nir_fmul(b, nir_vector_extract(b, dir, ky), sz);
1699 
1700    /* Calculate vertices relative to ray origin */
1701    nir_ssa_def *v_a = nir_fsub(b, node_coords[0], origin);
1702    nir_ssa_def *v_b = nir_fsub(b, node_coords[1], origin);
1703    nir_ssa_def *v_c = nir_fsub(b, node_coords[2], origin);
1704 
1705    /* Perform shear and scale */
1706    nir_ssa_def *ax = nir_fsub(b, nir_vector_extract(b, v_a, kx), nir_fmul(b, sx, nir_vector_extract(b, v_a, kz)));
1707    nir_ssa_def *ay = nir_fsub(b, nir_vector_extract(b, v_a, ky), nir_fmul(b, sy, nir_vector_extract(b, v_a, kz)));
1708    nir_ssa_def *bx = nir_fsub(b, nir_vector_extract(b, v_b, kx), nir_fmul(b, sx, nir_vector_extract(b, v_b, kz)));
1709    nir_ssa_def *by = nir_fsub(b, nir_vector_extract(b, v_b, ky), nir_fmul(b, sy, nir_vector_extract(b, v_b, kz)));
1710    nir_ssa_def *cx = nir_fsub(b, nir_vector_extract(b, v_c, kx), nir_fmul(b, sx, nir_vector_extract(b, v_c, kz)));
1711    nir_ssa_def *cy = nir_fsub(b, nir_vector_extract(b, v_c, ky), nir_fmul(b, sy, nir_vector_extract(b, v_c, kz)));
1712 
1713    nir_ssa_def *u = nir_fsub(b, nir_fmul(b, cx, by), nir_fmul(b, cy, bx));
1714    nir_ssa_def *v = nir_fsub(b, nir_fmul(b, ax, cy), nir_fmul(b, ay, cx));
1715    nir_ssa_def *w = nir_fsub(b, nir_fmul(b, bx, ay), nir_fmul(b, by, ax));
1716 
1717    nir_variable *u_var = nir_variable_create(b->shader, nir_var_shader_temp, glsl_float_type(), "u");
1718    nir_variable *v_var = nir_variable_create(b->shader, nir_var_shader_temp, glsl_float_type(), "v");
1719    nir_variable *w_var = nir_variable_create(b->shader, nir_var_shader_temp, glsl_float_type(), "w");
1720    nir_store_var(b, u_var, u, 0x1);
1721    nir_store_var(b, v_var, v, 0x1);
1722    nir_store_var(b, w_var, w, 0x1);
1723 
1724    /* Fall back to testing the edges with double precision...
1725     *
1726     * The Vulkan spec only requires single-precision watertightness, but without
1727     * this we fail dEQP-VK.ray_tracing_pipeline.watertightness.closedFan2.1024 with
1728     * failures = 1. :( */
1729    nir_ssa_def *cond_retest = nir_ior(b, nir_ior(b,
1730       nir_feq(b, u, nir_imm_float(b, 0.0f)),
1731       nir_feq(b, v, nir_imm_float(b, 0.0f))),
1732       nir_feq(b, w, nir_imm_float(b, 0.0f)));
1733 
1734    nir_push_if(b, cond_retest);
1735    {
1736       ax = nir_f2f64(b, ax); ay = nir_f2f64(b, ay);
1737       bx = nir_f2f64(b, bx); by = nir_f2f64(b, by);
1738       cx = nir_f2f64(b, cx); cy = nir_f2f64(b, cy);
1739 
1740       nir_store_var(b, u_var, nir_f2f32(b, nir_fsub(b, nir_fmul(b, cx, by), nir_fmul(b, cy, bx))), 0x1);
1741       nir_store_var(b, v_var, nir_f2f32(b, nir_fsub(b, nir_fmul(b, ax, cy), nir_fmul(b, ay, cx))), 0x1);
1742       nir_store_var(b, w_var, nir_f2f32(b, nir_fsub(b, nir_fmul(b, bx, ay), nir_fmul(b, by, ax))), 0x1);
1743    }
1744    nir_pop_if(b, NULL);
1745 
1746    u = nir_load_var(b, u_var);
1747    v = nir_load_var(b, v_var);
1748    w = nir_load_var(b, w_var);
1749 
1750    /* Perform edge tests. */
1751    nir_ssa_def *cond_back = nir_ior(b, nir_ior(b,
1752       nir_flt(b, u, nir_imm_float(b, 0.0f)),
1753       nir_flt(b, v, nir_imm_float(b, 0.0f))),
1754       nir_flt(b, w, nir_imm_float(b, 0.0f)));
1755 
1756    nir_ssa_def *cond_front = nir_ior(b, nir_ior(b,
1757       nir_flt(b, nir_imm_float(b, 0.0f), u),
1758       nir_flt(b, nir_imm_float(b, 0.0f), v)),
1759       nir_flt(b, nir_imm_float(b, 0.0f), w));
1760 
1761    nir_ssa_def *cond = nir_inot(b, nir_iand(b, cond_back, cond_front));
1762 
1763    nir_push_if(b, cond);
1764    {
1765       nir_ssa_def *det = nir_fadd(b, u, nir_fadd(b, v, w));
1766 
1767       nir_ssa_def *az = nir_fmul(b, sz, nir_vector_extract(b, v_a, kz));
1768       nir_ssa_def *bz = nir_fmul(b, sz, nir_vector_extract(b, v_b, kz));
1769       nir_ssa_def *cz = nir_fmul(b, sz, nir_vector_extract(b, v_c, kz));
1770 
1771       nir_ssa_def *t = nir_fadd(b, nir_fadd(b, nir_fmul(b, u, az), nir_fmul(b, v, bz)), nir_fmul(b, w, cz));
1772 
1773       nir_ssa_def *t_signed = nir_fmul(b, nir_fsign(b, det), t);
1774 
1775       nir_ssa_def *det_cond_front = nir_inot(b, nir_flt(b, t_signed, nir_imm_float(b, 0.0f)));
1776 
1777       nir_push_if(b, det_cond_front);
1778       {
1779          nir_ssa_def *indices[4] = {
1780             t, det,
1781             v, w
1782          };
1783          nir_store_var(b, result, nir_vec(b, indices, 4), 0xf);
1784       }
1785       nir_pop_if(b, NULL);
1786    }
1787    nir_pop_if(b, NULL);
1788 
1789    return nir_load_var(b, result);
1790 }
1791 
1792 static void
1793 insert_traversal(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
1794                  nir_builder *b, const struct rt_variables *vars)
1795 {
1796    unsigned stack_entry_size = 4;
1797    unsigned lanes = b->shader->info.workgroup_size[0] * b->shader->info.workgroup_size[1] *
1798                     b->shader->info.workgroup_size[2];
1799    unsigned stack_entry_stride = stack_entry_size * lanes;
1800    nir_ssa_def *stack_entry_stride_def = nir_imm_int(b, stack_entry_stride);
1801    nir_ssa_def *stack_base =
1802       nir_iadd(b, nir_imm_int(b, b->shader->info.shared_size),
1803                nir_imul(b, nir_load_subgroup_invocation(b), nir_imm_int(b, stack_entry_size)));
1804 
1805    /*
1806     * A top-level AS can contain 2^24 children and a bottom-level AS can contain 2^24 triangles. At
1807     * a branching factor of 4, that means we may need up to 24 levels of box nodes + 1 triangle node
1808     * + 1 instance node. Furthermore, when processing a box node, worst case we actually push all 4
1809     * children and remove one, so the DFS stack depth is box nodes * 3 + 2.
1810     */
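   /* That is 24 * 3 + 2 = 74 entries in the worst case; 76 entries leave a little
    * headroom. With the 8x8x1 workgroup configured in create_rt_shader() this adds
    * 76 * 4 bytes * 64 lanes = 19456 bytes of LDS, under the 32768-byte limit
    * asserted below. */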
1811    b->shader->info.shared_size += stack_entry_stride * 76;
1812    assert(b->shader->info.shared_size <= 32768);
1813 
1814    nir_ssa_def *accel_struct = nir_load_var(b, vars->accel_struct);
1815 
1816    struct rt_traversal_vars trav_vars = init_traversal_vars(b);
1817 
1818    /* Initialize the follow-up shader idx to 0, to be replaced by the miss shader
1819     * if we actually miss. */
1820    nir_store_var(b, vars->idx, nir_imm_int(b, 0), 1);
1821 
1822    nir_store_var(b, trav_vars.should_return, nir_imm_bool(b, false), 1);
1823 
1824    nir_push_if(b, nir_ine(b, accel_struct, nir_imm_int64(b, 0)));
1825    {
1826       nir_store_var(b, trav_vars.bvh_base, build_addr_to_node(b, accel_struct), 1);
1827 
1828       nir_ssa_def *bvh_root =
1829          nir_build_load_global(b, 1, 32, accel_struct, .access = ACCESS_NON_WRITEABLE,
1830                                .align_mul = 64, .align_offset = 0);
1831 
1832       /* We create a BVH descriptor that covers the entire memory range. That way we can always
1833        * use the same descriptor, which avoids divergence when different rays hit different
1834        * instances at the cost of having to use 64-bit node ids. */
1835       const uint64_t bvh_size = 1ull << 42;
1836       nir_ssa_def *desc = nir_imm_ivec4(
1837          b, 0, 1u << 31 /* Enable box sorting */, (bvh_size - 1) & 0xFFFFFFFFu,
1838          ((bvh_size - 1) >> 32) | (1u << 24 /* Return IJ for triangles */) | (1u << 31));
1839 
1840       nir_ssa_def *vec3ones = nir_channels(b, nir_imm_vec4(b, 1.0, 1.0, 1.0, 1.0), 0x7);
1841       nir_store_var(b, trav_vars.origin, nir_load_var(b, vars->origin), 7);
1842       nir_store_var(b, trav_vars.dir, nir_load_var(b, vars->direction), 7);
1843       nir_store_var(b, trav_vars.inv_dir, nir_fdiv(b, vec3ones, nir_load_var(b, trav_vars.dir)), 7);
1844       nir_store_var(b, trav_vars.sbt_offset_and_flags, nir_imm_int(b, 0), 1);
1845       nir_store_var(b, trav_vars.instance_addr, nir_imm_int64(b, 0), 1);
1846 
1847       nir_store_var(b, trav_vars.stack, nir_iadd(b, stack_base, stack_entry_stride_def), 1);
1848       nir_store_shared(b, bvh_root, stack_base, .base = 0, .write_mask = 0x1,
1849                        .align_mul = stack_entry_size, .align_offset = 0);
1850 
1851       nir_store_var(b, trav_vars.top_stack, nir_imm_int(b, 0), 1);
1852 
1853       nir_push_loop(b);
1854 
1855       nir_push_if(b, nir_ieq(b, nir_load_var(b, trav_vars.stack), stack_base));
1856       nir_jump(b, nir_jump_break);
1857       nir_pop_if(b, NULL);
1858 
1859       nir_push_if(
1860          b, nir_uge(b, nir_load_var(b, trav_vars.top_stack), nir_load_var(b, trav_vars.stack)));
1861       nir_store_var(b, trav_vars.top_stack, nir_imm_int(b, 0), 1);
1862       nir_store_var(b, trav_vars.bvh_base,
1863                     build_addr_to_node(b, nir_load_var(b, vars->accel_struct)), 1);
1864       nir_store_var(b, trav_vars.origin, nir_load_var(b, vars->origin), 7);
1865       nir_store_var(b, trav_vars.dir, nir_load_var(b, vars->direction), 7);
1866       nir_store_var(b, trav_vars.inv_dir, nir_fdiv(b, vec3ones, nir_load_var(b, trav_vars.dir)), 7);
1867       nir_store_var(b, trav_vars.instance_addr, nir_imm_int64(b, 0), 1);
1868 
1869       nir_pop_if(b, NULL);
1870 
1871       nir_store_var(b, trav_vars.stack,
1872                     nir_isub(b, nir_load_var(b, trav_vars.stack), stack_entry_stride_def), 1);
1873 
1874       nir_ssa_def *bvh_node = nir_load_shared(b, 1, 32, nir_load_var(b, trav_vars.stack), .base = 0,
1875                                               .align_mul = stack_entry_size, .align_offset = 0);
1876       nir_ssa_def *bvh_node_type = nir_iand(b, bvh_node, nir_imm_int(b, 7));
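      /* The low 3 bits of a node pointer encode the node type. The tests below decode
       * them: bit 2 clear = triangle node, bit 2 set + bit 1 clear = box node,
       * bits 2 and 1 set = leaf, where bit 0 then distinguishes AABB (set) from
       * instance (clear). */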
1877 
1878       bvh_node = nir_iadd(b, nir_load_var(b, trav_vars.bvh_base), nir_u2u(b, bvh_node, 64));
1879       nir_ssa_def *intrinsic_result = NULL;
1880       if (device->physical_device->rad_info.chip_class >= GFX10_3
1881        && !(device->instance->perftest_flags & RADV_PERFTEST_FORCE_EMULATE_RT)) {
1882          intrinsic_result = nir_bvh64_intersect_ray_amd(
1883             b, 32, desc, nir_unpack_64_2x32(b, bvh_node), nir_load_var(b, vars->tmax),
1884             nir_load_var(b, trav_vars.origin), nir_load_var(b, trav_vars.dir),
1885             nir_load_var(b, trav_vars.inv_dir));
1886       }
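      /* For box nodes, the 4-component result (hardware or emulated) holds up to four
       * child node ids to push, with 0xffffffff marking a missed child; for triangle
       * nodes it holds the unnormalized t/det/I/J values consumed by
       * insert_traversal_triangle_case(). */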
1887 
1888       nir_push_if(b, nir_ine(b, nir_iand(b, bvh_node_type, nir_imm_int(b, 4)), nir_imm_int(b, 0)));
1889       {
1890          nir_push_if(b,
1891                      nir_ine(b, nir_iand(b, bvh_node_type, nir_imm_int(b, 2)), nir_imm_int(b, 0)));
1892          {
1893             /* custom */
1894             nir_push_if(
1895                b, nir_ine(b, nir_iand(b, bvh_node_type, nir_imm_int(b, 1)), nir_imm_int(b, 0)));
1896             {
1897                insert_traversal_aabb_case(device, pCreateInfo, b, vars, &trav_vars, bvh_node);
1898             }
1899             nir_push_else(b, NULL);
1900             {
1901                /* instance */
1902                nir_ssa_def *instance_node_addr = build_node_to_addr(device, b, bvh_node);
1903                nir_ssa_def *instance_data = nir_build_load_global(
1904                   b, 4, 32, instance_node_addr, .align_mul = 64, .align_offset = 0);
1905                nir_ssa_def *wto_matrix[] = {
1906                   nir_build_load_global(b, 4, 32,
1907                                         nir_iadd(b, instance_node_addr, nir_imm_int64(b, 16)),
1908                                         .align_mul = 64, .align_offset = 16),
1909                   nir_build_load_global(b, 4, 32,
1910                                         nir_iadd(b, instance_node_addr, nir_imm_int64(b, 32)),
1911                                         .align_mul = 64, .align_offset = 32),
1912                   nir_build_load_global(b, 4, 32,
1913                                         nir_iadd(b, instance_node_addr, nir_imm_int64(b, 48)),
1914                                         .align_mul = 64, .align_offset = 48)};
1915                nir_ssa_def *instance_id = nir_build_load_global(
1916                   b, 1, 32, nir_iadd(b, instance_node_addr, nir_imm_int64(b, 88)), .align_mul = 4,
1917                   .align_offset = 0);
1918                nir_ssa_def *instance_and_mask = nir_channel(b, instance_data, 2);
1919                nir_ssa_def *instance_mask = nir_ushr(b, instance_and_mask, nir_imm_int(b, 24));
1920 
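               /* Skip this instance when its 8-bit mask has no bits in common with
                * the ray's cull mask. */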
1921                nir_push_if(b,
1922                            nir_ieq(b, nir_iand(b, instance_mask, nir_load_var(b, vars->cull_mask)),
1923                                    nir_imm_int(b, 0)));
1924                nir_jump(b, nir_jump_continue);
1925                nir_pop_if(b, NULL);
1926 
1927                nir_store_var(b, trav_vars.top_stack, nir_load_var(b, trav_vars.stack), 1);
1928                nir_store_var(b, trav_vars.bvh_base,
1929                              build_addr_to_node(
1930                                 b, nir_pack_64_2x32(b, nir_channels(b, instance_data, 0x3))),
1931                              1);
1932                nir_store_shared(b,
1933                                 nir_iand(b, nir_channel(b, instance_data, 0), nir_imm_int(b, 63)),
1934                                 nir_load_var(b, trav_vars.stack), .base = 0, .write_mask = 0x1,
1935                                 .align_mul = stack_entry_size, .align_offset = 0);
1936                nir_store_var(b, trav_vars.stack,
1937                              nir_iadd(b, nir_load_var(b, trav_vars.stack), stack_entry_stride_def),
1938                              1);
1939 
1940                nir_store_var(
1941                   b, trav_vars.origin,
1942                   nir_build_vec3_mat_mult_pre(b, nir_load_var(b, vars->origin), wto_matrix), 7);
1943                nir_store_var(
1944                   b, trav_vars.dir,
1945                   nir_build_vec3_mat_mult(b, nir_load_var(b, vars->direction), wto_matrix, false),
1946                   7);
1947                nir_store_var(b, trav_vars.inv_dir,
1948                              nir_fdiv(b, vec3ones, nir_load_var(b, trav_vars.dir)), 7);
1949                nir_store_var(b, trav_vars.custom_instance_and_mask, instance_and_mask, 1);
1950                nir_store_var(b, trav_vars.sbt_offset_and_flags, nir_channel(b, instance_data, 3),
1951                              1);
1952                nir_store_var(b, trav_vars.instance_id, instance_id, 1);
1953                nir_store_var(b, trav_vars.instance_addr, instance_node_addr, 1);
1954             }
1955             nir_pop_if(b, NULL);
1956          }
1957          nir_push_else(b, NULL);
1958          {
1959             /* box */
1960             nir_ssa_def *result = intrinsic_result;
1961             if (!result) {
1962                /* If we didn't run the intrinsic because the hardware doesn't support it,
1963                 * emulate the ray/box intersection here. */
1964                result = intersect_ray_amd_software_box(device,
1965                   b, bvh_node, nir_load_var(b, vars->tmax), nir_load_var(b, trav_vars.origin),
1966                   nir_load_var(b, trav_vars.dir), nir_load_var(b, trav_vars.inv_dir));
1967             }
1968 
1969             for (unsigned i = 4; i-- > 0; ) {
1970                nir_ssa_def *new_node = nir_vector_extract(b, result, nir_imm_int(b, i));
1971                nir_push_if(b, nir_ine(b, new_node, nir_imm_int(b, 0xffffffff)));
1972                {
1973                   nir_store_shared(b, new_node, nir_load_var(b, trav_vars.stack), .base = 0,
1974                                    .write_mask = 0x1, .align_mul = stack_entry_size,
1975                                    .align_offset = 0);
1976                   nir_store_var(
1977                      b, trav_vars.stack,
1978                      nir_iadd(b, nir_load_var(b, trav_vars.stack), stack_entry_stride_def), 1);
1979                }
1980                nir_pop_if(b, NULL);
1981             }
1982          }
1983          nir_pop_if(b, NULL);
1984       }
1985       nir_push_else(b, NULL);
1986       {
1987          nir_ssa_def *result = intrinsic_result;
1988          if (!result) {
1989             /* If we didn't run the intrinsic because the hardware doesn't support it,
1990              * emulate the ray/tri intersection here. */
1991             result = intersect_ray_amd_software_tri(device,
1992                b, bvh_node, nir_load_var(b, vars->tmax), nir_load_var(b, trav_vars.origin),
1993                nir_load_var(b, trav_vars.dir), nir_load_var(b, trav_vars.inv_dir));
1994          }
1995          insert_traversal_triangle_case(device, pCreateInfo, b, result, vars, &trav_vars, bvh_node);
1996       }
1997       nir_pop_if(b, NULL);
1998 
1999       nir_pop_loop(b, NULL);
2000    }
2001    nir_pop_if(b, NULL);
2002 
2003    /* should_return is set if we had a hit but we won't be calling the closest hit shader and hence
2004     * need to return immediately to the calling shader. */
2005    nir_push_if(b, nir_load_var(b, trav_vars.should_return));
2006    {
2007       insert_rt_return(b, vars);
2008    }
2009    nir_push_else(b, NULL);
2010    {
2011       /* Only load the miss shader if we actually miss, which we determine by not having set
2012        * a closest hit shader. It is valid to not specify an SBT pointer for miss shaders if none
2013        * of the rays miss. */
2014       nir_push_if(b, nir_ieq(b, nir_load_var(b, vars->idx), nir_imm_int(b, 0)));
2015       {
2016          load_sbt_entry(b, vars, nir_load_var(b, vars->miss_index), SBT_MISS, 0);
2017       }
2018       nir_pop_if(b, NULL);
2019    }
2020    nir_pop_if(b, NULL);
2021 }
2022 
2023 static unsigned
2024 compute_rt_stack_size(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
2025                       const struct radv_pipeline_shader_stack_size *stack_sizes)
2026 {
2027    unsigned raygen_size = 0;
2028    unsigned callable_size = 0;
2029    unsigned chit_size = 0;
2030    unsigned miss_size = 0;
2031    unsigned non_recursive_size = 0;
2032 
2033    for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
2034       non_recursive_size = MAX2(stack_sizes[i].non_recursive_size, non_recursive_size);
2035 
2036       const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i];
2037       uint32_t shader_id = VK_SHADER_UNUSED_KHR;
2038       unsigned size = stack_sizes[i].recursive_size;
2039 
2040       switch (group_info->type) {
2041       case VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR:
2042          shader_id = group_info->generalShader;
2043          break;
2044       case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
2045       case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
2046          shader_id = group_info->closestHitShader;
2047          break;
2048       default:
2049          break;
2050       }
2051       if (shader_id == VK_SHADER_UNUSED_KHR)
2052          continue;
2053 
2054       const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id];
2055       switch (stage->stage) {
2056       case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
2057          raygen_size = MAX2(raygen_size, size);
2058          break;
2059       case VK_SHADER_STAGE_MISS_BIT_KHR:
2060          miss_size = MAX2(miss_size, size);
2061          break;
2062       case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
2063          chit_size = MAX2(chit_size, size);
2064          break;
2065       case VK_SHADER_STAGE_CALLABLE_BIT_KHR:
2066          callable_size = MAX2(callable_size, size);
2067          break;
2068       default:
2069          unreachable("Invalid stage type in RT shader");
2070       }
2071    }
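   /* Scratch layout: one raygen frame, plus (if any recursion is allowed at all) one
    * level that must also be able to hold the non-recursive any-hit/intersection
    * scratch, plus the remaining recursion levels at max(closest-hit, miss), plus two
    * callable frames (a fixed allowance, since the callable depth is not known here). */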
2072    return raygen_size +
2073           MIN2(pCreateInfo->maxPipelineRayRecursionDepth, 1) *
2074              MAX2(MAX2(chit_size, miss_size), non_recursive_size) +
2075           MAX2(0, (int)(pCreateInfo->maxPipelineRayRecursionDepth) - 1) *
2076              MAX2(chit_size, miss_size) +
2077           2 * callable_size;
2078 }
2079 
2080 bool
2081 radv_rt_pipeline_has_dynamic_stack_size(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo)
2082 {
2083    if (!pCreateInfo->pDynamicState)
2084       return false;
2085 
2086    for (unsigned i = 0; i < pCreateInfo->pDynamicState->dynamicStateCount; ++i) {
2087       if (pCreateInfo->pDynamicState->pDynamicStates[i] ==
2088           VK_DYNAMIC_STATE_RAY_TRACING_PIPELINE_STACK_SIZE_KHR)
2089          return true;
2090    }
2091 
2092    return false;
2093 }
2094 
2095 static nir_shader *
2096 create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
2097                  struct radv_pipeline_shader_stack_size *stack_sizes)
2098 {
2099    RADV_FROM_HANDLE(radv_pipeline_layout, layout, pCreateInfo->layout);
2100    struct radv_pipeline_key key;
2101    memset(&key, 0, sizeof(key));
2102 
2103    nir_builder b = nir_builder_init_simple_shader(MESA_SHADER_COMPUTE, NULL, "rt_combined");
2104 
2105    b.shader->info.workgroup_size[0] = 8;
2106    b.shader->info.workgroup_size[1] = 8;
2107    b.shader->info.workgroup_size[2] = 1;
2108 
2109    struct rt_variables vars = create_rt_variables(b.shader, stack_sizes);
2110    load_sbt_entry(&b, &vars, nir_imm_int(&b, 0), SBT_RAYGEN, 0);
2111    nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, 0), 0x1);
2112 
2113    nir_store_var(&b, vars.main_loop_case_visited, nir_imm_bool(&b, true), 1);
2114 
2115    nir_loop *loop = nir_push_loop(&b);
2116 
2117    nir_push_if(&b, nir_ior(&b, nir_ieq(&b, nir_load_var(&b, vars.idx), nir_imm_int(&b, 0)),
2118                            nir_ine(&b, nir_load_var(&b, vars.main_loop_case_visited),
2119                                    nir_imm_bool(&b, true))));
2120    nir_jump(&b, nir_jump_break);
2121    nir_pop_if(&b, NULL);
2122 
2123    nir_store_var(&b, vars.main_loop_case_visited, nir_imm_bool(&b, false), 1);
2124 
2125    nir_push_if(&b, nir_ieq(&b, nir_load_var(&b, vars.idx), nir_imm_int(&b, 1)));
2126    nir_store_var(&b, vars.main_loop_case_visited, nir_imm_bool(&b, true), 1);
2127    insert_traversal(device, pCreateInfo, &b, &vars);
2128    nir_pop_if(&b, NULL);
2129 
2130    nir_ssa_def *idx = nir_load_var(&b, vars.idx);
2131 
2132    /* We do a trick with the indexing of the resume shaders so that the first
2133     * shader of group x always gets id x and the resume shader ids then come after
2134     * groupCount. This makes the shadergroup handles independent of compilation. */
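   /* Concretely: idx 0 means there is nothing left to run (the main loop breaks),
    * idx 1 selects the traversal loop above, the first shader of group i gets
    * idx i + 2 (matching the group handles written in radv_rt_pipeline_create()),
    * and resume shaders are numbered from groupCount + 2 upward. */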
2135    unsigned call_idx_base = pCreateInfo->groupCount + 1;
2136    for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
2137       const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i];
2138       uint32_t shader_id = VK_SHADER_UNUSED_KHR;
2139 
2140       switch (group_info->type) {
2141       case VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR:
2142          shader_id = group_info->generalShader;
2143          break;
2144       case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
2145       case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
2146          shader_id = group_info->closestHitShader;
2147          break;
2148       default:
2149          break;
2150       }
2151       if (shader_id == VK_SHADER_UNUSED_KHR)
2152          continue;
2153 
2154       const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id];
2155       nir_shader *nir_stage = parse_rt_stage(device, layout, stage);
2156 
2157       b.shader->options = nir_stage->options;
2158 
2159       uint32_t num_resume_shaders = 0;
2160       nir_shader **resume_shaders = NULL;
2161       nir_lower_shader_calls(nir_stage, nir_address_format_32bit_offset, 16, &resume_shaders,
2162                              &num_resume_shaders, nir_stage);
2163 
2164       vars.group_idx = i;
2165       insert_rt_case(&b, nir_stage, &vars, idx, call_idx_base, i + 2);
2166       for (unsigned j = 0; j < num_resume_shaders; ++j) {
2167          insert_rt_case(&b, resume_shaders[j], &vars, idx, call_idx_base, call_idx_base + 1 + j);
2168       }
2169       call_idx_base += num_resume_shaders;
2170    }
2171 
2172    nir_pop_loop(&b, loop);
2173 
2174    if (radv_rt_pipeline_has_dynamic_stack_size(pCreateInfo)) {
2175       /* Put something so scratch gets enabled in the shader. */
2176       b.shader->scratch_size = 16;
2177    } else
2178       b.shader->scratch_size = compute_rt_stack_size(pCreateInfo, stack_sizes);
2179 
2180    /* Deal with all the inline functions. */
2181    nir_index_ssa_defs(nir_shader_get_entrypoint(b.shader));
2182    nir_metadata_preserve(nir_shader_get_entrypoint(b.shader), nir_metadata_none);
2183 
2184    return b.shader;
2185 }
2186 
2187 static VkResult
2188 radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache,
2189                         const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
2190                         const VkAllocationCallbacks *pAllocator, VkPipeline *pPipeline)
2191 {
2192    RADV_FROM_HANDLE(radv_device, device, _device);
2193    VkResult result;
2194    struct radv_pipeline *pipeline = NULL;
2195    struct radv_pipeline_shader_stack_size *stack_sizes = NULL;
2196    uint8_t hash[20];
2197    nir_shader *shader = NULL;
2198    bool keep_statistic_info =
2199       (pCreateInfo->flags & VK_PIPELINE_CREATE_CAPTURE_STATISTICS_BIT_KHR) ||
2200       (device->instance->debug_flags & RADV_DEBUG_DUMP_SHADER_STATS) || device->keep_shader_info;
2201 
2202    if (pCreateInfo->flags & VK_PIPELINE_CREATE_LIBRARY_BIT_KHR)
2203       return radv_rt_pipeline_library_create(_device, _cache, pCreateInfo, pAllocator, pPipeline);
2204 
2205    VkRayTracingPipelineCreateInfoKHR local_create_info =
2206       radv_create_merged_rt_create_info(pCreateInfo);
2207    if (!local_create_info.pStages || !local_create_info.pGroups) {
2208       result = VK_ERROR_OUT_OF_HOST_MEMORY;
2209       goto fail;
2210    }
2211 
2212    radv_hash_rt_shaders(hash, &local_create_info, radv_get_hash_flags(device, keep_statistic_info));
2213    struct vk_shader_module module = {.base.type = VK_OBJECT_TYPE_SHADER_MODULE};
2214 
   VkComputePipelineCreateInfo compute_info = {
      .sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
      .pNext = NULL,
      .flags = pCreateInfo->flags | VK_PIPELINE_CREATE_FAIL_ON_PIPELINE_COMPILE_REQUIRED_BIT_EXT,
      .stage =
         {
            .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
            .stage = VK_SHADER_STAGE_COMPUTE_BIT,
            .module = vk_shader_module_to_handle(&module),
            .pName = "main",
         },
      .layout = pCreateInfo->layout,
   };

   /* First check if we can get things from the cache before we take the expensive step of
    * generating the nir. */
   result = radv_compute_pipeline_create(_device, _cache, &compute_info, pAllocator, hash,
                                         stack_sizes, local_create_info.groupCount, pPipeline);
   if (result == VK_PIPELINE_COMPILE_REQUIRED_EXT) {
      stack_sizes = calloc(sizeof(*stack_sizes), local_create_info.groupCount);
      if (!stack_sizes) {
         result = VK_ERROR_OUT_OF_HOST_MEMORY;
         goto fail;
      }

      shader = create_rt_shader(device, &local_create_info, stack_sizes);
      module.nir = shader;
      compute_info.flags = pCreateInfo->flags;
      result = radv_compute_pipeline_create(_device, _cache, &compute_info, pAllocator, hash,
                                            stack_sizes, local_create_info.groupCount, pPipeline);
      stack_sizes = NULL;

      if (result != VK_SUCCESS)
         goto shader_fail;
   }
   pipeline = radv_pipeline_from_handle(*pPipeline);

   pipeline->compute.rt_group_handles =
      calloc(sizeof(*pipeline->compute.rt_group_handles), local_create_info.groupCount);
   if (!pipeline->compute.rt_group_handles) {
      result = VK_ERROR_OUT_OF_HOST_MEMORY;
      goto shader_fail;
   }

   pipeline->compute.dynamic_stack_size = radv_rt_pipeline_has_dynamic_stack_size(pCreateInfo);

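   /* Fill the opaque group handles that vkGetRayTracingShaderGroupHandlesKHR
    * hands back to the application. handles[0] refers to the general or
    * closest-hit shader of the group and handles[1] to the any-hit or
    * intersection shader; a value of 0 means "no shader". The i + 2 offset
    * presumably mirrors the call indices assigned to the groups when the
    * combined shader was built, with the first two indices reserved
    * internally.
    */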
   for (unsigned i = 0; i < local_create_info.groupCount; ++i) {
      const VkRayTracingShaderGroupCreateInfoKHR *group_info = &local_create_info.pGroups[i];
      switch (group_info->type) {
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR:
         if (group_info->generalShader != VK_SHADER_UNUSED_KHR)
            pipeline->compute.rt_group_handles[i].handles[0] = i + 2;
         break;
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
         if (group_info->intersectionShader != VK_SHADER_UNUSED_KHR)
            pipeline->compute.rt_group_handles[i].handles[1] = i + 2;
         FALLTHROUGH;
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
         if (group_info->closestHitShader != VK_SHADER_UNUSED_KHR)
            pipeline->compute.rt_group_handles[i].handles[0] = i + 2;
         if (group_info->anyHitShader != VK_SHADER_UNUSED_KHR)
            pipeline->compute.rt_group_handles[i].handles[1] = i + 2;
         break;
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_MAX_ENUM_KHR:
         unreachable("VK_RAY_TRACING_SHADER_GROUP_TYPE_MAX_ENUM_KHR");
      }
   }

shader_fail:
   if (result != VK_SUCCESS && pipeline)
      radv_pipeline_destroy(device, pipeline, pAllocator);
   ralloc_free(shader);
fail:
   free((void *)local_create_info.pGroups);
   free((void *)local_create_info.pStages);
   free(stack_sizes);
   return result;
}

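/* Entry point for vkCreateRayTracingPipelinesKHR. Pipelines are created one at
 * a time; the first failure is recorded as the overall result and, unless
 * VK_PIPELINE_CREATE_EARLY_RETURN_ON_FAILURE_BIT_EXT is set for that entry,
 * creation continues with the remaining infos. Every pipeline that was not
 * created ends up as VK_NULL_HANDLE. The deferred operation is not used here,
 * so compilation happens synchronously.
 */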
VkResult
radv_CreateRayTracingPipelinesKHR(VkDevice _device, VkDeferredOperationKHR deferredOperation,
                                  VkPipelineCache pipelineCache, uint32_t count,
                                  const VkRayTracingPipelineCreateInfoKHR *pCreateInfos,
                                  const VkAllocationCallbacks *pAllocator, VkPipeline *pPipelines)
{
   VkResult result = VK_SUCCESS;

   unsigned i = 0;
   for (; i < count; i++) {
      VkResult r;
      r = radv_rt_pipeline_create(_device, pipelineCache, &pCreateInfos[i], pAllocator,
                                  &pPipelines[i]);
      if (r != VK_SUCCESS) {
         result = r;
         pPipelines[i] = VK_NULL_HANDLE;

         if (pCreateInfos[i].flags & VK_PIPELINE_CREATE_EARLY_RETURN_ON_FAILURE_BIT_EXT)
            break;
      }
   }

   for (; i < count; ++i)
      pPipelines[i] = VK_NULL_HANDLE;

   return result;
}

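/* Copy the opaque shader group handles into the caller-provided buffer. Each
 * handle occupies RADV_RT_HANDLE_SIZE bytes and is zero-padded; the internal
 * handle struct must fit within that size (see the STATIC_ASSERT below).
 */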
VkResult
radv_GetRayTracingShaderGroupHandlesKHR(VkDevice device, VkPipeline _pipeline, uint32_t firstGroup,
                                        uint32_t groupCount, size_t dataSize, void *pData)
{
   RADV_FROM_HANDLE(radv_pipeline, pipeline, _pipeline);
   char *data = pData;

   STATIC_ASSERT(sizeof(*pipeline->compute.rt_group_handles) <= RADV_RT_HANDLE_SIZE);

   memset(data, 0, groupCount * RADV_RT_HANDLE_SIZE);

   for (uint32_t i = 0; i < groupCount; ++i) {
      memcpy(data + i * RADV_RT_HANDLE_SIZE, &pipeline->compute.rt_group_handles[firstGroup + i],
             sizeof(*pipeline->compute.rt_group_handles));
   }

   return VK_SUCCESS;
}

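/* Report the stack size recorded for one shader of a group: any-hit and
 * intersection queries return the non-recursive size, everything else the
 * recursive size.
 */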
VkDeviceSize
radv_GetRayTracingShaderGroupStackSizeKHR(VkDevice device, VkPipeline _pipeline, uint32_t group,
                                          VkShaderGroupShaderKHR groupShader)
{
   RADV_FROM_HANDLE(radv_pipeline, pipeline, _pipeline);
   const struct radv_pipeline_shader_stack_size *stack_size =
      &pipeline->compute.rt_stack_sizes[group];

   if (groupShader == VK_SHADER_GROUP_SHADER_ANY_HIT_KHR ||
       groupShader == VK_SHADER_GROUP_SHADER_INTERSECTION_KHR)
      return stack_size->non_recursive_size;
   else
      return stack_size->recursive_size;
}