1 /*
2  * Copyright © 2021 Bas Nieuwenhuizen
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 #include "radv_acceleration_structure.h"
24 #include "radv_private.h"
25 
26 #include "util/format/format_utils.h"
27 #include "util/half_float.h"
28 #include "nir_builder.h"
29 #include "radv_cs.h"
30 #include "radv_meta.h"
31 
32 void
33 radv_GetAccelerationStructureBuildSizesKHR(
34    VkDevice _device, VkAccelerationStructureBuildTypeKHR buildType,
35    const VkAccelerationStructureBuildGeometryInfoKHR *pBuildInfo,
36    const uint32_t *pMaxPrimitiveCounts, VkAccelerationStructureBuildSizesInfoKHR *pSizeInfo)
37 {
38    uint64_t triangles = 0, boxes = 0, instances = 0;
39 
40    STATIC_ASSERT(sizeof(struct radv_bvh_triangle_node) == 64);
41    STATIC_ASSERT(sizeof(struct radv_bvh_aabb_node) == 64);
42    STATIC_ASSERT(sizeof(struct radv_bvh_instance_node) == 128);
43    STATIC_ASSERT(sizeof(struct radv_bvh_box16_node) == 64);
44    STATIC_ASSERT(sizeof(struct radv_bvh_box32_node) == 128);
45 
46    for (uint32_t i = 0; i < pBuildInfo->geometryCount; ++i) {
47       const VkAccelerationStructureGeometryKHR *geometry;
48       if (pBuildInfo->pGeometries)
49          geometry = &pBuildInfo->pGeometries[i];
50       else
51          geometry = pBuildInfo->ppGeometries[i];
52 
53       switch (geometry->geometryType) {
54       case VK_GEOMETRY_TYPE_TRIANGLES_KHR:
55          triangles += pMaxPrimitiveCounts[i];
56          break;
57       case VK_GEOMETRY_TYPE_AABBS_KHR:
58          boxes += pMaxPrimitiveCounts[i];
59          break;
60       case VK_GEOMETRY_TYPE_INSTANCES_KHR:
61          instances += pMaxPrimitiveCounts[i];
62          break;
63       case VK_GEOMETRY_TYPE_MAX_ENUM_KHR:
64          unreachable("VK_GEOMETRY_TYPE_MAX_ENUM_KHR unhandled");
65       }
66    }
67 
68    uint64_t children = boxes + instances + triangles;
69    uint64_t internal_nodes = 0;
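   /* Each box node holds up to 4 children, so e.g. 100 leaves need
    * 25 + 7 + 2 + 1 = 35 internal nodes. */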
70    while (children > 1) {
71       children = DIV_ROUND_UP(children, 4);
72       internal_nodes += children;
73    }
74 
75    /* The stray 192 is to ensure we have space for a header
76     * which we'd want to use for some metadata (like the
77     * total AABB of the BVH) */
78    uint64_t size = boxes * 128 + instances * 128 + triangles * 64 + internal_nodes * 128 + 192;
79 
80    pSizeInfo->accelerationStructureSize = size;
81 
82    /* 2x the max number of nodes in a BVH layer (one uint32_t each) */
83    pSizeInfo->updateScratchSize = pSizeInfo->buildScratchSize =
84       MAX2(4096, 2 * (boxes + instances + triangles) * sizeof(uint32_t));
85 }
86 
87 VkResult
88 radv_CreateAccelerationStructureKHR(VkDevice _device,
89                                     const VkAccelerationStructureCreateInfoKHR *pCreateInfo,
90                                     const VkAllocationCallbacks *pAllocator,
91                                     VkAccelerationStructureKHR *pAccelerationStructure)
92 {
93    RADV_FROM_HANDLE(radv_device, device, _device);
94    RADV_FROM_HANDLE(radv_buffer, buffer, pCreateInfo->buffer);
95    struct radv_acceleration_structure *accel;
96 
97    accel = vk_alloc2(&device->vk.alloc, pAllocator, sizeof(*accel), 8,
98                      VK_SYSTEM_ALLOCATION_SCOPE_OBJECT);
99    if (accel == NULL)
100       return vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY);
101 
102    vk_object_base_init(&device->vk, &accel->base, VK_OBJECT_TYPE_ACCELERATION_STRUCTURE_KHR);
103 
104    accel->mem_offset = buffer->offset + pCreateInfo->offset;
105    accel->size = pCreateInfo->size;
106    accel->bo = buffer->bo;
107 
108    *pAccelerationStructure = radv_acceleration_structure_to_handle(accel);
109    return VK_SUCCESS;
110 }
111 
112 void
113 radv_DestroyAccelerationStructureKHR(VkDevice _device,
114                                      VkAccelerationStructureKHR accelerationStructure,
115                                      const VkAllocationCallbacks *pAllocator)
116 {
117    RADV_FROM_HANDLE(radv_device, device, _device);
118    RADV_FROM_HANDLE(radv_acceleration_structure, accel, accelerationStructure);
119 
120    if (!accel)
121       return;
122 
123    vk_object_base_finish(&accel->base);
124    vk_free2(&device->vk.alloc, pAllocator, accel);
125 }
126 
127 VkDeviceAddress
128 radv_GetAccelerationStructureDeviceAddressKHR(
129    VkDevice _device, const VkAccelerationStructureDeviceAddressInfoKHR *pInfo)
130 {
131    RADV_FROM_HANDLE(radv_acceleration_structure, accel, pInfo->accelerationStructure);
132    return radv_accel_struct_get_va(accel);
133 }
134 
135 VkResult
136 radv_WriteAccelerationStructuresPropertiesKHR(
137    VkDevice _device, uint32_t accelerationStructureCount,
138    const VkAccelerationStructureKHR *pAccelerationStructures, VkQueryType queryType,
139    size_t dataSize, void *pData, size_t stride)
140 {
141    RADV_FROM_HANDLE(radv_device, device, _device);
142    char *data_out = (char*)pData;
143 
144    for (uint32_t i = 0; i < accelerationStructureCount; ++i) {
145       RADV_FROM_HANDLE(radv_acceleration_structure, accel, pAccelerationStructures[i]);
146       const char *base_ptr = (const char *)device->ws->buffer_map(accel->bo);
147       if (!base_ptr)
148          return vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY);
149 
150       const struct radv_accel_struct_header *header = (const void*)(base_ptr + accel->mem_offset);
151       if (stride * i + sizeof(VkDeviceSize) <= dataSize) {
152          uint64_t value;
153          switch (queryType) {
154          case VK_QUERY_TYPE_ACCELERATION_STRUCTURE_COMPACTED_SIZE_KHR:
155             value = header->compacted_size;
156             break;
157          case VK_QUERY_TYPE_ACCELERATION_STRUCTURE_SERIALIZATION_SIZE_KHR:
158             value = header->serialization_size;
159             break;
160          default:
161             unreachable("Unhandled acceleration structure query");
162          }
163          *(VkDeviceSize *)(data_out + stride * i) = value;
164       }
165       device->ws->buffer_unmap(accel->bo);
166    }
167    return VK_SUCCESS;
168 }
169 
170 struct radv_bvh_build_ctx {
171    uint32_t *write_scratch;
172    char *base;
173    char *curr_ptr;
174 };
175 
176 static void
177 build_triangles(struct radv_bvh_build_ctx *ctx, const VkAccelerationStructureGeometryKHR *geom,
178                 const VkAccelerationStructureBuildRangeInfoKHR *range, unsigned geometry_id)
179 {
180    const VkAccelerationStructureGeometryTrianglesDataKHR *tri_data = &geom->geometry.triangles;
181    VkTransformMatrixKHR matrix;
182    const char *index_data = (const char *)tri_data->indexData.hostAddress + range->primitiveOffset;
183 
184    if (tri_data->transformData.hostAddress) {
185       matrix = *(const VkTransformMatrixKHR *)((const char *)tri_data->transformData.hostAddress +
186                                                range->transformOffset);
187    } else {
188       matrix = (VkTransformMatrixKHR){
189          .matrix = {{1.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 0.0}, {0.0, 0.0, 1.0, 0.0}}};
190    }
191 
192    for (uint32_t p = 0; p < range->primitiveCount; ++p, ctx->curr_ptr += 64) {
193       struct radv_bvh_triangle_node *node = (void*)ctx->curr_ptr;
194       uint32_t node_offset = ctx->curr_ptr - ctx->base;
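      /* A node id is the node's byte offset in units of 8 bytes, with the node
       * type packed into the low 3 bits (0 = triangle, 5 = box32, 6 = instance,
       * 7 = AABB); nodes are 64-byte aligned, so those bits start out as zero. */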
195       uint32_t node_id = node_offset >> 3;
196       *ctx->write_scratch++ = node_id;
197 
198       for (unsigned v = 0; v < 3; ++v) {
199          uint32_t v_index = range->firstVertex;
200          switch (tri_data->indexType) {
201          case VK_INDEX_TYPE_NONE_KHR:
202             v_index += p * 3 + v;
203             break;
204          case VK_INDEX_TYPE_UINT8_EXT:
205             v_index += *(const uint8_t *)index_data;
206             index_data += 1;
207             break;
208          case VK_INDEX_TYPE_UINT16:
209             v_index += *(const uint16_t *)index_data;
210             index_data += 2;
211             break;
212          case VK_INDEX_TYPE_UINT32:
213             v_index += *(const uint32_t *)index_data;
214             index_data += 4;
215             break;
216          case VK_INDEX_TYPE_MAX_ENUM:
217             unreachable("Unhandled VK_INDEX_TYPE_MAX_ENUM");
218             break;
219          }
220 
221          const char *v_data = (const char *)tri_data->vertexData.hostAddress + v_index * tri_data->vertexStride;
222          float coords[4];
223          switch (tri_data->vertexFormat) {
224          case VK_FORMAT_R32G32_SFLOAT:
225             coords[0] = *(const float *)(v_data + 0);
226             coords[1] = *(const float *)(v_data + 4);
227             coords[2] = 0.0f;
228             coords[3] = 1.0f;
229             break;
230          case VK_FORMAT_R32G32B32_SFLOAT:
231             coords[0] = *(const float *)(v_data + 0);
232             coords[1] = *(const float *)(v_data + 4);
233             coords[2] = *(const float *)(v_data + 8);
234             coords[3] = 1.0f;
235             break;
236          case VK_FORMAT_R32G32B32A32_SFLOAT:
237             coords[0] = *(const float *)(v_data + 0);
238             coords[1] = *(const float *)(v_data + 4);
239             coords[2] = *(const float *)(v_data + 8);
240             coords[3] = *(const float *)(v_data + 12);
241             break;
242          case VK_FORMAT_R16G16_SFLOAT:
243             coords[0] = _mesa_half_to_float(*(const uint16_t *)(v_data + 0));
244             coords[1] = _mesa_half_to_float(*(const uint16_t *)(v_data + 2));
245             coords[2] = 0.0f;
246             coords[3] = 1.0f;
247             break;
248          case VK_FORMAT_R16G16B16_SFLOAT:
249             coords[0] = _mesa_half_to_float(*(const uint16_t *)(v_data + 0));
250             coords[1] = _mesa_half_to_float(*(const uint16_t *)(v_data + 2));
251             coords[2] = _mesa_half_to_float(*(const uint16_t *)(v_data + 4));
252             coords[3] = 1.0f;
253             break;
254          case VK_FORMAT_R16G16B16A16_SFLOAT:
255             coords[0] = _mesa_half_to_float(*(const uint16_t *)(v_data + 0));
256             coords[1] = _mesa_half_to_float(*(const uint16_t *)(v_data + 2));
257             coords[2] = _mesa_half_to_float(*(const uint16_t *)(v_data + 4));
258             coords[3] = _mesa_half_to_float(*(const uint16_t *)(v_data + 6));
259             break;
260          case VK_FORMAT_R16G16_SNORM:
261             coords[0] = _mesa_snorm_to_float(*(const int16_t *)(v_data + 0), 16);
262             coords[1] = _mesa_snorm_to_float(*(const int16_t *)(v_data + 2), 16);
263             coords[2] = 0.0f;
264             coords[3] = 1.0f;
265             break;
266          case VK_FORMAT_R16G16B16A16_SNORM:
267             coords[0] = _mesa_snorm_to_float(*(const int16_t *)(v_data + 0), 16);
268             coords[1] = _mesa_snorm_to_float(*(const int16_t *)(v_data + 2), 16);
269             coords[2] = _mesa_snorm_to_float(*(const int16_t *)(v_data + 4), 16);
270             coords[3] = _mesa_snorm_to_float(*(const int16_t *)(v_data + 6), 16);
271             break;
272          case VK_FORMAT_R16G16B16A16_UNORM:
273             coords[0] = _mesa_unorm_to_float(*(const uint16_t *)(v_data + 0), 16);
274             coords[1] = _mesa_unorm_to_float(*(const uint16_t *)(v_data + 2), 16);
275             coords[2] = _mesa_unorm_to_float(*(const uint16_t *)(v_data + 4), 16);
276             coords[3] = _mesa_unorm_to_float(*(const uint16_t *)(v_data + 6), 16);
277             break;
278          default:
279             unreachable("Unhandled vertex format in BVH build");
280          }
281 
282          for (unsigned j = 0; j < 3; ++j) {
283             float r = 0;
284             for (unsigned k = 0; k < 4; ++k)
285                r += matrix.matrix[j][k] * coords[k];
286             node->coords[v][j] = r;
287          }
288 
289          node->triangle_id = p;
290          node->geometry_id_and_flags = geometry_id | (geom->flags << 28);
291 
292          /* Seems to be needed for IJ, otherwise I = J = ? */
293          node->id = 9;
294       }
295    }
296 }
297 
298 static VkResult
299 build_instances(struct radv_device *device, struct radv_bvh_build_ctx *ctx,
300                 const VkAccelerationStructureGeometryKHR *geom,
301                 const VkAccelerationStructureBuildRangeInfoKHR *range)
302 {
303    const VkAccelerationStructureGeometryInstancesDataKHR *inst_data = &geom->geometry.instances;
304 
305    for (uint32_t p = 0; p < range->primitiveCount; ++p, ctx->curr_ptr += 128) {
306       const VkAccelerationStructureInstanceKHR *instance =
307          inst_data->arrayOfPointers
308             ? (((const VkAccelerationStructureInstanceKHR *const *)inst_data->data.hostAddress)[p])
309             : &((const VkAccelerationStructureInstanceKHR *)inst_data->data.hostAddress)[p];
310       if (!instance->accelerationStructureReference) {
311          continue;
312       }
313 
314       struct radv_bvh_instance_node *node = (void*)ctx->curr_ptr;
315       uint32_t node_offset = ctx->curr_ptr - ctx->base;
316       uint32_t node_id = (node_offset >> 3) | 6;
317       *ctx->write_scratch++ = node_id;
318 
319       float transform[16], inv_transform[16];
320       memcpy(transform, &instance->transform.matrix, sizeof(instance->transform.matrix));
321       transform[12] = transform[13] = transform[14] = 0.0f;
322       transform[15] = 1.0f;
323 
324       util_invert_mat4x4(inv_transform, transform);
325       memcpy(node->wto_matrix, inv_transform, sizeof(node->wto_matrix));
326       node->wto_matrix[3] = transform[3];
327       node->wto_matrix[7] = transform[7];
328       node->wto_matrix[11] = transform[11];
329       node->custom_instance_and_mask = instance->instanceCustomIndex | (instance->mask << 24);
330       node->sbt_offset_and_flags =
331          instance->instanceShaderBindingTableRecordOffset | (instance->flags << 24);
332       node->instance_id = p;
333 
334       for (unsigned i = 0; i < 3; ++i)
335          for (unsigned j = 0; j < 3; ++j)
336             node->otw_matrix[i * 3 + j] = instance->transform.matrix[j][i];
337 
338       RADV_FROM_HANDLE(radv_acceleration_structure, src_accel_struct,
339                        (VkAccelerationStructureKHR)instance->accelerationStructureReference);
340       const void *src_base = device->ws->buffer_map(src_accel_struct->bo);
341       if (!src_base)
342          return vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY);
343 
344       src_base = (const char *)src_base + src_accel_struct->mem_offset;
345       const struct radv_accel_struct_header *src_header = src_base;
346       node->base_ptr = radv_accel_struct_get_va(src_accel_struct) | src_header->root_node_offset;
347 
348       for (unsigned j = 0; j < 3; ++j) {
349          node->aabb[0][j] = instance->transform.matrix[j][3];
350          node->aabb[1][j] = instance->transform.matrix[j][3];
351          for (unsigned k = 0; k < 3; ++k) {
352             node->aabb[0][j] += MIN2(instance->transform.matrix[j][k] * src_header->aabb[0][k],
353                                      instance->transform.matrix[j][k] * src_header->aabb[1][k]);
354             node->aabb[1][j] += MAX2(instance->transform.matrix[j][k] * src_header->aabb[0][k],
355                                      instance->transform.matrix[j][k] * src_header->aabb[1][k]);
356          }
357       }
358       device->ws->buffer_unmap(src_accel_struct->bo);
359    }
360    return VK_SUCCESS;
361 }
362 
363 static void
364 build_aabbs(struct radv_bvh_build_ctx *ctx, const VkAccelerationStructureGeometryKHR *geom,
365             const VkAccelerationStructureBuildRangeInfoKHR *range, unsigned geometry_id)
366 {
367    const VkAccelerationStructureGeometryAabbsDataKHR *aabb_data = &geom->geometry.aabbs;
368 
369    for (uint32_t p = 0; p < range->primitiveCount; ++p, ctx->curr_ptr += 64) {
370       struct radv_bvh_aabb_node *node = (void*)ctx->curr_ptr;
371       uint32_t node_offset = ctx->curr_ptr - ctx->base;
372       uint32_t node_id = (node_offset >> 3) | 7;
373       *ctx->write_scratch++ = node_id;
374 
375       const VkAabbPositionsKHR *aabb =
376          (const VkAabbPositionsKHR *)((const char *)aabb_data->data.hostAddress +
377                                       p * aabb_data->stride);
378 
379       node->aabb[0][0] = aabb->minX;
380       node->aabb[0][1] = aabb->minY;
381       node->aabb[0][2] = aabb->minZ;
382       node->aabb[1][0] = aabb->maxX;
383       node->aabb[1][1] = aabb->maxY;
384       node->aabb[1][2] = aabb->maxZ;
385       node->primitive_id = p;
386       node->geometry_id_and_flags = geometry_id;
387    }
388 }
389 
390 static uint32_t
391 leaf_node_count(const VkAccelerationStructureBuildGeometryInfoKHR *info,
392                 const VkAccelerationStructureBuildRangeInfoKHR *ranges)
393 {
394    uint32_t count = 0;
395    for (uint32_t i = 0; i < info->geometryCount; ++i) {
396       count += ranges[i].primitiveCount;
397    }
398    return count;
399 }
400 
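/* Computes the AABB of the node referenced by node_id as (min x/y/z, max x/y/z)
 * in bounds[6], reading the node from the BVH stored at base_ptr. */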
401 static void
402 compute_bounds(const char *base_ptr, uint32_t node_id, float *bounds)
403 {
404    for (unsigned i = 0; i < 3; ++i)
405       bounds[i] = INFINITY;
406    for (unsigned i = 0; i < 3; ++i)
407       bounds[3 + i] = -INFINITY;
408 
409    switch (node_id & 7) {
410    case 0: {
411       const struct radv_bvh_triangle_node *node = (const void*)(base_ptr + (node_id / 8 * 64));
412       for (unsigned v = 0; v < 3; ++v) {
413          for (unsigned j = 0; j < 3; ++j) {
414             bounds[j] = MIN2(bounds[j], node->coords[v][j]);
415             bounds[3 + j] = MAX2(bounds[3 + j], node->coords[v][j]);
416          }
417       }
418       break;
419    }
420    case 5: {
421       const struct radv_bvh_box32_node *node = (const void*)(base_ptr + (node_id / 8 * 64));
422       for (unsigned c2 = 0; c2 < 4; ++c2) {
423          if (isnan(node->coords[c2][0][0]))
424             continue;
425          for (unsigned j = 0; j < 3; ++j) {
426             bounds[j] = MIN2(bounds[j], node->coords[c2][0][j]);
427             bounds[3 + j] = MAX2(bounds[3 + j], node->coords[c2][1][j]);
428          }
429       }
430       break;
431    }
432    case 6: {
433       const struct radv_bvh_instance_node *node = (const void*)(base_ptr + (node_id / 8 * 64));
434       for (unsigned j = 0; j < 3; ++j) {
435          bounds[j] = MIN2(bounds[j], node->aabb[0][j]);
436          bounds[3 + j] = MAX2(bounds[3 + j], node->aabb[1][j]);
437       }
438       break;
439    }
440    case 7: {
441       const struct radv_bvh_aabb_node *node = (const void*)(base_ptr + (node_id / 8 * 64));
442       for (unsigned j = 0; j < 3; ++j) {
443          bounds[j] = MIN2(bounds[j], node->aabb[0][j]);
444          bounds[3 + j] = MAX2(bounds[3 + j], node->aabb[1][j]);
445       }
446       break;
447    }
448    }
449 }
450 
451 struct bvh_opt_entry {
452    uint64_t key;
453    uint32_t node_id;
454 };
455 
456 static int
457 bvh_opt_compare(const void *_a, const void *_b)
458 {
459    const struct bvh_opt_entry *a = _a;
460    const struct bvh_opt_entry *b = _b;
461 
462    if (a->key < b->key)
463       return -1;
464    if (a->key > b->key)
465       return 1;
466    if (a->node_id < b->node_id)
467       return -1;
468    if (a->node_id > b->node_id)
469       return 1;
470    return 0;
471 }
472 
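/* Sorts the leaf node ids along a Morton (Z-order) curve computed from the
 * quantized centers of their AABBs, so that ids that become siblings in the
 * next level are spatially close to each other. */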
473 static void
474 optimize_bvh(const char *base_ptr, uint32_t *node_ids, uint32_t node_count)
475 {
476    float bounds[6];
477    for (unsigned i = 0; i < 3; ++i)
478       bounds[i] = INFINITY;
479    for (unsigned i = 0; i < 3; ++i)
480       bounds[3 + i] = -INFINITY;
481 
482    for (uint32_t i = 0; i < node_count; ++i) {
483       float node_bounds[6];
484       compute_bounds(base_ptr, node_ids[i], node_bounds);
485       for (unsigned j = 0; j < 3; ++j)
486          bounds[j] = MIN2(bounds[j], node_bounds[j]);
487       for (unsigned j = 0; j < 3; ++j)
488          bounds[3 + j] = MAX2(bounds[3 + j], node_bounds[3 + j]);
489    }
490 
491    struct bvh_opt_entry *entries = calloc(node_count, sizeof(struct bvh_opt_entry));
492    if (!entries)
493       return;
494 
495    for (uint32_t i = 0; i < node_count; ++i) {
496       float node_bounds[6];
497       compute_bounds(base_ptr, node_ids[i], node_bounds);
498       float node_coords[3];
499       for (unsigned j = 0; j < 3; ++j)
500          node_coords[j] = (node_bounds[j] + node_bounds[3 + j]) * 0.5;
501       int32_t coords[3];
502       for (unsigned j = 0; j < 3; ++j)
503          coords[j] = MAX2(
504             MIN2((int32_t)((node_coords[j] - bounds[j]) / (bounds[3 + j] - bounds[j]) * (1 << 21)),
505                  (1 << 21) - 1),
506             0);
507       uint64_t key = 0;
508       for (unsigned j = 0; j < 21; ++j)
509          for (unsigned k = 0; k < 3; ++k)
510             key |= (uint64_t)((coords[k] >> j) & 1) << (j * 3 + k);
511       entries[i].key = key;
512       entries[i].node_id = node_ids[i];
513    }
514 
515    qsort(entries, node_count, sizeof(entries[0]), bvh_opt_compare);
516    for (unsigned i = 0; i < node_count; ++i)
517       node_ids[i] = entries[i].node_id;
518 
519    free(entries);
520 }
521 
522 static VkResult
523 build_bvh(struct radv_device *device, const VkAccelerationStructureBuildGeometryInfoKHR *info,
524           const VkAccelerationStructureBuildRangeInfoKHR *ranges)
525 {
526    RADV_FROM_HANDLE(radv_acceleration_structure, accel, info->dstAccelerationStructure);
527    VkResult result = VK_SUCCESS;
528 
529    uint32_t *scratch[2];
530    scratch[0] = info->scratchData.hostAddress;
531    scratch[1] = scratch[0] + leaf_node_count(info, ranges);
532 
533    char *base_ptr = (char*)device->ws->buffer_map(accel->bo);
534    if (!base_ptr)
535       return vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY);
536 
537    base_ptr = base_ptr + accel->mem_offset;
538    struct radv_accel_struct_header *header = (void*)base_ptr;
539    void *first_node_ptr = (char *)base_ptr + ALIGN(sizeof(*header), 64);
540 
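   /* Layout: the header sits at the start of the acceleration structure,
    * followed by a reserved 128-byte slot for the root box32 node at
    * first_node_ptr; all other nodes are packed from ctx.curr_ptr onwards. */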
541    struct radv_bvh_build_ctx ctx = {.write_scratch = scratch[0],
542                                     .base = base_ptr,
543                                     .curr_ptr = (char *)first_node_ptr + 128};
544 
545    uint64_t instance_offset = (const char *)ctx.curr_ptr - (const char *)base_ptr;
546    uint64_t instance_count = 0;
547 
548    /* This initializes the leaf nodes of the BVH all at the same level. */
549    for (int inst = 1; inst >= 0; --inst) {
550       for (uint32_t i = 0; i < info->geometryCount; ++i) {
551          const VkAccelerationStructureGeometryKHR *geom =
552             info->pGeometries ? &info->pGeometries[i] : info->ppGeometries[i];
553 
554          if ((inst && geom->geometryType != VK_GEOMETRY_TYPE_INSTANCES_KHR) ||
555              (!inst && geom->geometryType == VK_GEOMETRY_TYPE_INSTANCES_KHR))
556             continue;
557 
558          switch (geom->geometryType) {
559          case VK_GEOMETRY_TYPE_TRIANGLES_KHR:
560             build_triangles(&ctx, geom, ranges + i, i);
561             break;
562          case VK_GEOMETRY_TYPE_AABBS_KHR:
563             build_aabbs(&ctx, geom, ranges + i, i);
564             break;
565          case VK_GEOMETRY_TYPE_INSTANCES_KHR: {
566             result = build_instances(device, &ctx, geom, ranges + i);
567             if (result != VK_SUCCESS)
568                goto fail;
569 
570             instance_count += ranges[i].primitiveCount;
571             break;
572          }
573          case VK_GEOMETRY_TYPE_MAX_ENUM_KHR:
574             unreachable("VK_GEOMETRY_TYPE_MAX_ENUM_KHR unhandled");
575          }
576       }
577    }
578 
579    uint32_t node_counts[2] = {ctx.write_scratch - scratch[0], 0};
580    optimize_bvh(base_ptr, scratch[0], node_counts[0]);
581    unsigned d;
582 
583    /*
584     * This is the most naive BVH building algorithm I could think of:
585     * just iteratively builds each level from bottom to top with
586     * the children of each node being in-order and tightly packed.
587     *
588     * Is probably terrible for traversal but should be easy to build an
589     * equivalent GPU version.
590     */
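   /* scratch[d & 1] holds the node ids of the level currently being consumed;
    * the ids of the box32 parent nodes written in this pass go into the other
    * scratch half, which becomes the input of the next iteration. */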
591    for (d = 0; node_counts[d & 1] > 1 || d == 0; ++d) {
592       uint32_t child_count = node_counts[d & 1];
593       const uint32_t *children = scratch[d & 1];
594       uint32_t *dst_ids = scratch[(d & 1) ^ 1];
595       unsigned dst_count;
596       unsigned child_idx = 0;
597       for (dst_count = 0; child_idx < MAX2(1, child_count); ++dst_count, child_idx += 4) {
598          unsigned local_child_count = MIN2(4, child_count - child_idx);
599          uint32_t child_ids[4];
600          float bounds[4][6];
601 
602          for (unsigned c = 0; c < local_child_count; ++c) {
603             uint32_t id = children[child_idx + c];
604             child_ids[c] = id;
605 
606             compute_bounds(base_ptr, id, bounds[c]);
607          }
608 
609          struct radv_bvh_box32_node *node;
610 
611          /* Put the root node in the slot reserved right after the header
612           * (first_node_ptr) and record its id in header->root_node_offset. */
613          if (child_idx == 0 && local_child_count == child_count) {
614             node = first_node_ptr;
615             header->root_node_offset = ((char *)first_node_ptr - (char *)base_ptr) / 64 * 8 + 5;
616          } else {
617             uint32_t dst_id = (ctx.curr_ptr - base_ptr) / 64;
618             dst_ids[dst_count] = dst_id * 8 + 5;
619 
620             node = (void*)ctx.curr_ptr;
621             ctx.curr_ptr += 128;
622          }
623 
624          for (unsigned c = 0; c < local_child_count; ++c) {
625             node->children[c] = child_ids[c];
626             for (unsigned i = 0; i < 2; ++i)
627                for (unsigned j = 0; j < 3; ++j)
628                   node->coords[c][i][j] = bounds[c][i * 3 + j];
629          }
630          for (unsigned c = local_child_count; c < 4; ++c) {
631             for (unsigned i = 0; i < 2; ++i)
632                for (unsigned j = 0; j < 3; ++j)
633                   node->coords[c][i][j] = NAN;
634          }
635       }
636 
637       node_counts[(d & 1) ^ 1] = dst_count;
638    }
639 
640    compute_bounds(base_ptr, header->root_node_offset, &header->aabb[0][0]);
641 
642    header->instance_offset = instance_offset;
643    header->instance_count = instance_count;
644    header->compacted_size = (char *)ctx.curr_ptr - base_ptr;
645 
646    /* 16 bytes per invocation, 64 invocations per workgroup */
647    header->copy_dispatch_size[0] = DIV_ROUND_UP(header->compacted_size, 16 * 64);
648    header->copy_dispatch_size[1] = 1;
649    header->copy_dispatch_size[2] = 1;
650 
651    header->serialization_size =
652       header->compacted_size + align(sizeof(struct radv_accel_struct_serialization_header) +
653                                         sizeof(uint64_t) * header->instance_count,
654                                      128);
655 
656 fail:
657    device->ws->buffer_unmap(accel->bo);
658    return result;
659 }
660 
661 VkResult
662 radv_BuildAccelerationStructuresKHR(
663    VkDevice _device, VkDeferredOperationKHR deferredOperation, uint32_t infoCount,
664    const VkAccelerationStructureBuildGeometryInfoKHR *pInfos,
665    const VkAccelerationStructureBuildRangeInfoKHR *const *ppBuildRangeInfos)
666 {
667    RADV_FROM_HANDLE(radv_device, device, _device);
668    VkResult result = VK_SUCCESS;
669 
670    for (uint32_t i = 0; i < infoCount; ++i) {
671       result = build_bvh(device, pInfos + i, ppBuildRangeInfos[i]);
672       if (result != VK_SUCCESS)
673          break;
674    }
675    return result;
676 }
677 
678 VkResult
679 radv_CopyAccelerationStructureKHR(VkDevice _device, VkDeferredOperationKHR deferredOperation,
680                                   const VkCopyAccelerationStructureInfoKHR *pInfo)
681 {
682    RADV_FROM_HANDLE(radv_device, device, _device);
683    RADV_FROM_HANDLE(radv_acceleration_structure, src_struct, pInfo->src);
684    RADV_FROM_HANDLE(radv_acceleration_structure, dst_struct, pInfo->dst);
685 
686    char *src_ptr = (char *)device->ws->buffer_map(src_struct->bo);
687    if (!src_ptr)
688       return VK_ERROR_OUT_OF_HOST_MEMORY;
689 
690    char *dst_ptr = (char *)device->ws->buffer_map(dst_struct->bo);
691    if (!dst_ptr) {
692       device->ws->buffer_unmap(src_struct->bo);
693       return VK_ERROR_OUT_OF_HOST_MEMORY;
694    }
695 
696    src_ptr += src_struct->mem_offset;
697    dst_ptr += dst_struct->mem_offset;
698 
699    const struct radv_accel_struct_header *header = (const void *)src_ptr;
700    memcpy(dst_ptr, src_ptr, header->compacted_size);
701 
702    device->ws->buffer_unmap(src_struct->bo);
703    device->ws->buffer_unmap(dst_struct->bo);
704    return VK_SUCCESS;
705 }
706 
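/* Emits NIR that returns the three vertex indices of triangle 'id' from the
 * index buffer at 'addr', handling 16-, 32- and 8-bit index types as well as
 * VK_INDEX_TYPE_NONE_KHR (where the indices are simply 3 * id + {0, 1, 2}). */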
707 static nir_ssa_def *
708 get_indices(nir_builder *b, nir_ssa_def *addr, nir_ssa_def *type, nir_ssa_def *id)
709 {
710    const struct glsl_type *uvec3_type = glsl_vector_type(GLSL_TYPE_UINT, 3);
711    nir_variable *result =
712       nir_variable_create(b->shader, nir_var_shader_temp, uvec3_type, "indices");
713 
714    nir_push_if(b, nir_ult(b, type, nir_imm_int(b, 2)));
715    nir_push_if(b, nir_ieq(b, type, nir_imm_int(b, VK_INDEX_TYPE_UINT16)));
716    {
717       nir_ssa_def *index_id = nir_umul24(b, id, nir_imm_int(b, 6));
718       nir_ssa_def *indices[3];
719       for (unsigned i = 0; i < 3; ++i) {
720          indices[i] = nir_build_load_global(
721             b, 1, 16, nir_iadd(b, addr, nir_u2u64(b, nir_iadd(b, index_id, nir_imm_int(b, 2 * i)))),
722             .align_mul = 2, .align_offset = 0);
723       }
724       nir_store_var(b, result, nir_u2u32(b, nir_vec(b, indices, 3)), 7);
725    }
726    nir_push_else(b, NULL);
727    {
728       nir_ssa_def *index_id = nir_umul24(b, id, nir_imm_int(b, 12));
729       nir_ssa_def *indices = nir_build_load_global(
730          b, 3, 32, nir_iadd(b, addr, nir_u2u64(b, index_id)), .align_mul = 4, .align_offset = 0);
731       nir_store_var(b, result, indices, 7);
732    }
733    nir_pop_if(b, NULL);
734    nir_push_else(b, NULL);
735    {
736       nir_ssa_def *index_id = nir_umul24(b, id, nir_imm_int(b, 3));
737       nir_ssa_def *indices[] = {
738          index_id,
739          nir_iadd(b, index_id, nir_imm_int(b, 1)),
740          nir_iadd(b, index_id, nir_imm_int(b, 2)),
741       };
742 
743       nir_push_if(b, nir_ieq(b, type, nir_imm_int(b, VK_INDEX_TYPE_NONE_KHR)));
744       {
745          nir_store_var(b, result, nir_vec(b, indices, 3), 7);
746       }
747       nir_push_else(b, NULL);
748       {
749          for (unsigned i = 0; i < 3; ++i) {
750             indices[i] = nir_build_load_global(b, 1, 8, nir_iadd(b, addr, nir_u2u64(b, indices[i])),
751                                                .align_mul = 1, .align_offset = 0);
752          }
753          nir_store_var(b, result, nir_u2u32(b, nir_vec(b, indices, 3)), 7);
754       }
755       nir_pop_if(b, NULL);
756    }
757    nir_pop_if(b, NULL);
758    return nir_load_var(b, result);
759 }
760 
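/* Emits NIR that loads the three vertex positions (one 64-bit address per
 * vertex in 'addresses'), decoding the given vertex format into a vec3 each. */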
761 static void
762 get_vertices(nir_builder *b, nir_ssa_def *addresses, nir_ssa_def *format, nir_ssa_def *positions[3])
763 {
764    const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
765    nir_variable *results[3] = {
766       nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "vertex0"),
767       nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "vertex1"),
768       nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "vertex2")};
769 
770    VkFormat formats[] = {
771       VK_FORMAT_R32G32B32_SFLOAT,    VK_FORMAT_R32G32B32A32_SFLOAT, VK_FORMAT_R16G16B16_SFLOAT,
772       VK_FORMAT_R16G16B16A16_SFLOAT, VK_FORMAT_R16G16_SFLOAT,       VK_FORMAT_R32G32_SFLOAT,
773       VK_FORMAT_R16G16_SNORM,        VK_FORMAT_R16G16B16A16_SNORM,  VK_FORMAT_R16G16B16A16_UNORM,
774    };
775 
776    for (unsigned f = 0; f < ARRAY_SIZE(formats); ++f) {
777       if (f + 1 < ARRAY_SIZE(formats))
778          nir_push_if(b, nir_ieq(b, format, nir_imm_int(b, formats[f])));
779 
780       for (unsigned i = 0; i < 3; ++i) {
781          switch (formats[f]) {
782          case VK_FORMAT_R32G32B32_SFLOAT:
783          case VK_FORMAT_R32G32B32A32_SFLOAT:
784             nir_store_var(b, results[i],
785                           nir_build_load_global(b, 3, 32, nir_channel(b, addresses, i),
786                                                 .align_mul = 4, .align_offset = 0),
787                           7);
788             break;
789          case VK_FORMAT_R32G32_SFLOAT:
790          case VK_FORMAT_R16G16_SFLOAT:
791          case VK_FORMAT_R16G16B16_SFLOAT:
792          case VK_FORMAT_R16G16B16A16_SFLOAT:
793          case VK_FORMAT_R16G16_SNORM:
794          case VK_FORMAT_R16G16B16A16_SNORM:
795          case VK_FORMAT_R16G16B16A16_UNORM: {
796             unsigned components = MIN2(3, vk_format_get_nr_components(formats[f]));
797             unsigned comp_bits =
798                vk_format_get_blocksizebits(formats[f]) / vk_format_get_nr_components(formats[f]);
799             unsigned comp_bytes = comp_bits / 8;
800             nir_ssa_def *values[3];
801             nir_ssa_def *addr = nir_channel(b, addresses, i);
802             for (unsigned j = 0; j < components; ++j)
803                values[j] = nir_build_load_global(
804                   b, 1, comp_bits, nir_iadd(b, addr, nir_imm_int64(b, j * comp_bytes)),
805                   .align_mul = comp_bytes, .align_offset = 0);
806 
807             for (unsigned j = components; j < 3; ++j)
808                values[j] = nir_imm_intN_t(b, 0, comp_bits);
809 
810             nir_ssa_def *vec;
811             if (util_format_is_snorm(vk_format_to_pipe_format(formats[f]))) {
812                for (unsigned j = 0; j < 3; ++j) {
813                   values[j] = nir_fdiv(b, nir_i2f32(b, values[j]),
814                                        nir_imm_float(b, (1u << (comp_bits - 1)) - 1));
815                   values[j] = nir_fmax(b, values[j], nir_imm_float(b, -1.0));
816                }
817                vec = nir_vec(b, values, 3);
818             } else if (util_format_is_unorm(vk_format_to_pipe_format(formats[f]))) {
819                for (unsigned j = 0; j < 3; ++j) {
820                   values[j] =
821                      nir_fdiv(b, nir_u2f32(b, values[j]), nir_imm_float(b, (1u << comp_bits) - 1));
822                   values[j] = nir_fmin(b, values[j], nir_imm_float(b, 1.0));
823                }
824                vec = nir_vec(b, values, 3);
825             } else if (comp_bits == 16)
826                vec = nir_f2f32(b, nir_vec(b, values, 3));
827             else
828                vec = nir_vec(b, values, 3);
829             nir_store_var(b, results[i], vec, 7);
830             break;
831          }
832          default:
833             unreachable("Unhandled format");
834          }
835       }
836       if (f + 1 < ARRAY_SIZE(formats))
837          nir_push_else(b, NULL);
838    }
839    for (unsigned f = 1; f < ARRAY_SIZE(formats); ++f) {
840       nir_pop_if(b, NULL);
841    }
842 
843    for (unsigned i = 0; i < 3; ++i)
844       positions[i] = nir_load_var(b, results[i]);
845 }
846 
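/* Push constant layout consumed by the leaf-node build shader (build_leaf_shader). */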
847 struct build_primitive_constants {
848    uint64_t node_dst_addr;
849    uint64_t scratch_addr;
850    uint32_t dst_offset;
851    uint32_t dst_scratch_offset;
852    uint32_t geometry_type;
853    uint32_t geometry_id;
854 
855    union {
856       struct {
857          uint64_t vertex_addr;
858          uint64_t index_addr;
859          uint64_t transform_addr;
860          uint32_t vertex_stride;
861          uint32_t vertex_format;
862          uint32_t index_format;
863       };
864       struct {
865          uint64_t instance_data;
866          uint32_t array_of_pointers;
867       };
868       struct {
869          uint64_t aabb_addr;
870          uint32_t aabb_stride;
871       };
872    };
873 };
874 
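/* Push constant layout consumed by the internal-node build shader
 * (build_internal_shader). */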
875 struct build_internal_constants {
876    uint64_t node_dst_addr;
877    uint64_t scratch_addr;
878    uint32_t dst_offset;
879    uint32_t dst_scratch_offset;
880    uint32_t src_scratch_offset;
881    uint32_t fill_header;
882 };
883 
884 /* This inverts a 3x3 matrix using cofactors, as in e.g.
885  * https://www.mathsisfun.com/algebra/matrix-inverse-minors-cofactors-adjugate.html */
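/* In other words out = adj(in) / det(in), i.e. out[i][j] = cofactor(in)[j][i] / det(in). */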
886 static void
887 nir_invert_3x3(nir_builder *b, nir_ssa_def *in[3][3], nir_ssa_def *out[3][3])
888 {
889    nir_ssa_def *cofactors[3][3];
890    for (unsigned i = 0; i < 3; ++i) {
891       for (unsigned j = 0; j < 3; ++j) {
892          cofactors[i][j] =
893             nir_fsub(b, nir_fmul(b, in[(i + 1) % 3][(j + 1) % 3], in[(i + 2) % 3][(j + 2) % 3]),
894                      nir_fmul(b, in[(i + 1) % 3][(j + 2) % 3], in[(i + 2) % 3][(j + 1) % 3]));
895       }
896    }
897 
898    nir_ssa_def *det = NULL;
899    for (unsigned i = 0; i < 3; ++i) {
900       nir_ssa_def *det_part = nir_fmul(b, in[0][i], cofactors[0][i]);
901       det = det ? nir_fadd(b, det, det_part) : det_part;
902    }
903 
904    nir_ssa_def *det_inv = nir_frcp(b, det);
905    for (unsigned i = 0; i < 3; ++i) {
906       for (unsigned j = 0; j < 3; ++j) {
907          out[i][j] = nir_fmul(b, cofactors[j][i], det_inv);
908       }
909    }
910 }
911 
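/* Builds the compute shader for the GPU build path: one invocation per input
 * primitive, which writes the corresponding leaf node (triangle, AABB or
 * instance, selected via a push constant) and stores the new node id into its
 * slot of the scratch buffer for the internal-node passes. */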
912 static nir_shader *
913 build_leaf_shader(struct radv_device *dev)
914 {
915    const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
916    nir_builder b =
917       nir_builder_init_simple_shader(MESA_SHADER_COMPUTE, NULL, "accel_build_leaf_shader");
918 
919    b.shader->info.workgroup_size[0] = 64;
920    b.shader->info.workgroup_size[1] = 1;
921    b.shader->info.workgroup_size[2] = 1;
922 
923    nir_ssa_def *pconst0 =
924       nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 0, .range = 16);
925    nir_ssa_def *pconst1 =
926       nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 16, .range = 16);
927    nir_ssa_def *pconst2 =
928       nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 32, .range = 16);
929    nir_ssa_def *pconst3 =
930       nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 48, .range = 16);
931    nir_ssa_def *pconst4 =
932       nir_load_push_constant(&b, 1, 32, nir_imm_int(&b, 0), .base = 64, .range = 4);
933 
934    nir_ssa_def *geom_type = nir_channel(&b, pconst1, 2);
935    nir_ssa_def *node_dst_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 3));
936    nir_ssa_def *scratch_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 12));
937    nir_ssa_def *node_dst_offset = nir_channel(&b, pconst1, 0);
938    nir_ssa_def *scratch_offset = nir_channel(&b, pconst1, 1);
939    nir_ssa_def *geometry_id = nir_channel(&b, pconst1, 3);
940 
941    nir_ssa_def *global_id =
942       nir_iadd(&b,
943                nir_umul24(&b, nir_channels(&b, nir_load_workgroup_id(&b, 32), 1),
944                           nir_imm_int(&b, b.shader->info.workgroup_size[0])),
945                nir_channels(&b, nir_load_local_invocation_id(&b), 1));
946    scratch_addr = nir_iadd(
947       &b, scratch_addr,
948       nir_u2u64(&b, nir_iadd(&b, scratch_offset, nir_umul24(&b, global_id, nir_imm_int(&b, 4)))));
949 
950    nir_push_if(&b, nir_ieq(&b, geom_type, nir_imm_int(&b, VK_GEOMETRY_TYPE_TRIANGLES_KHR)));
951    { /* Triangles */
952       nir_ssa_def *vertex_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst2, 3));
953       nir_ssa_def *index_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst2, 12));
954       nir_ssa_def *transform_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst3, 3));
955       nir_ssa_def *vertex_stride = nir_channel(&b, pconst3, 2);
956       nir_ssa_def *vertex_format = nir_channel(&b, pconst3, 3);
957       nir_ssa_def *index_format = nir_channel(&b, pconst4, 0);
958       unsigned repl_swizzle[4] = {0, 0, 0, 0};
959 
960       nir_ssa_def *node_offset =
961          nir_iadd(&b, node_dst_offset, nir_umul24(&b, global_id, nir_imm_int(&b, 64)));
962       nir_ssa_def *triangle_node_dst_addr = nir_iadd(&b, node_dst_addr, nir_u2u64(&b, node_offset));
963 
964       nir_ssa_def *indices = get_indices(&b, index_addr, index_format, global_id);
965       nir_ssa_def *vertex_addresses = nir_iadd(
966          &b, nir_u2u64(&b, nir_imul(&b, indices, nir_swizzle(&b, vertex_stride, repl_swizzle, 3))),
967          nir_swizzle(&b, vertex_addr, repl_swizzle, 3));
968       nir_ssa_def *positions[3];
969       get_vertices(&b, vertex_addresses, vertex_format, positions);
970 
971       nir_ssa_def *node_data[16];
972       memset(node_data, 0, sizeof(node_data));
973 
974       nir_variable *transform[] = {
975          nir_variable_create(b.shader, nir_var_shader_temp, glsl_vec4_type(), "transform0"),
976          nir_variable_create(b.shader, nir_var_shader_temp, glsl_vec4_type(), "transform1"),
977          nir_variable_create(b.shader, nir_var_shader_temp, glsl_vec4_type(), "transform2"),
978       };
979       nir_store_var(&b, transform[0], nir_imm_vec4(&b, 1.0, 0.0, 0.0, 0.0), 0xf);
980       nir_store_var(&b, transform[1], nir_imm_vec4(&b, 0.0, 1.0, 0.0, 0.0), 0xf);
981       nir_store_var(&b, transform[2], nir_imm_vec4(&b, 0.0, 0.0, 1.0, 0.0), 0xf);
982 
983       nir_push_if(&b, nir_ine(&b, transform_addr, nir_imm_int64(&b, 0)));
984       nir_store_var(
985          &b, transform[0],
986          nir_build_load_global(&b, 4, 32, nir_iadd(&b, transform_addr, nir_imm_int64(&b, 0)),
987                                .align_mul = 4, .align_offset = 0),
988          0xf);
989       nir_store_var(
990          &b, transform[1],
991          nir_build_load_global(&b, 4, 32, nir_iadd(&b, transform_addr, nir_imm_int64(&b, 16)),
992                                .align_mul = 4, .align_offset = 0),
993          0xf);
994       nir_store_var(
995          &b, transform[2],
996          nir_build_load_global(&b, 4, 32, nir_iadd(&b, transform_addr, nir_imm_int64(&b, 32)),
997                                .align_mul = 4, .align_offset = 0),
998          0xf);
999       nir_pop_if(&b, NULL);
1000 
1001       for (unsigned i = 0; i < 3; ++i)
1002          for (unsigned j = 0; j < 3; ++j)
1003             node_data[i * 3 + j] = nir_fdph(&b, positions[i], nir_load_var(&b, transform[j]));
1004 
1005       node_data[12] = global_id;
1006       node_data[13] = geometry_id;
1007       node_data[15] = nir_imm_int(&b, 9);
1008       for (unsigned i = 0; i < ARRAY_SIZE(node_data); ++i)
1009          if (!node_data[i])
1010             node_data[i] = nir_imm_int(&b, 0);
1011 
1012       for (unsigned i = 0; i < 4; ++i) {
1013          nir_build_store_global(&b, nir_vec(&b, node_data + i * 4, 4),
1014                                 nir_iadd(&b, triangle_node_dst_addr, nir_imm_int64(&b, i * 16)),
1015                                 .write_mask = 15, .align_mul = 16, .align_offset = 0);
1016       }
1017 
1018       nir_ssa_def *node_id = nir_ushr(&b, node_offset, nir_imm_int(&b, 3));
1019       nir_build_store_global(&b, node_id, scratch_addr, .write_mask = 1, .align_mul = 4,
1020                              .align_offset = 0);
1021    }
1022    nir_push_else(&b, NULL);
1023    nir_push_if(&b, nir_ieq(&b, geom_type, nir_imm_int(&b, VK_GEOMETRY_TYPE_AABBS_KHR)));
1024    { /* AABBs */
1025       nir_ssa_def *aabb_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst2, 3));
1026       nir_ssa_def *aabb_stride = nir_channel(&b, pconst2, 2);
1027 
1028       nir_ssa_def *node_offset =
1029          nir_iadd(&b, node_dst_offset, nir_umul24(&b, global_id, nir_imm_int(&b, 64)));
1030       nir_ssa_def *aabb_node_dst_addr = nir_iadd(&b, node_dst_addr, nir_u2u64(&b, node_offset));
1031       nir_ssa_def *node_id =
1032          nir_iadd(&b, nir_ushr(&b, node_offset, nir_imm_int(&b, 3)), nir_imm_int(&b, 7));
1033       nir_build_store_global(&b, node_id, scratch_addr, .write_mask = 1, .align_mul = 4,
1034                              .align_offset = 0);
1035 
1036       aabb_addr = nir_iadd(&b, aabb_addr, nir_u2u64(&b, nir_imul(&b, aabb_stride, global_id)));
1037 
1038       nir_ssa_def *min_bound =
1039          nir_build_load_global(&b, 3, 32, nir_iadd(&b, aabb_addr, nir_imm_int64(&b, 0)),
1040                                .align_mul = 4, .align_offset = 0);
1041       nir_ssa_def *max_bound =
1042          nir_build_load_global(&b, 3, 32, nir_iadd(&b, aabb_addr, nir_imm_int64(&b, 12)),
1043                                .align_mul = 4, .align_offset = 0);
1044 
1045       nir_ssa_def *values[] = {nir_channel(&b, min_bound, 0),
1046                                nir_channel(&b, min_bound, 1),
1047                                nir_channel(&b, min_bound, 2),
1048                                nir_channel(&b, max_bound, 0),
1049                                nir_channel(&b, max_bound, 1),
1050                                nir_channel(&b, max_bound, 2),
1051                                global_id,
1052                                geometry_id};
1053 
1054       nir_build_store_global(&b, nir_vec(&b, values + 0, 4),
1055                              nir_iadd(&b, aabb_node_dst_addr, nir_imm_int64(&b, 0)),
1056                              .write_mask = 15, .align_mul = 16, .align_offset = 0);
1057       nir_build_store_global(&b, nir_vec(&b, values + 4, 4),
1058                              nir_iadd(&b, aabb_node_dst_addr, nir_imm_int64(&b, 16)),
1059                              .write_mask = 15, .align_mul = 16, .align_offset = 0);
1060    }
1061    nir_push_else(&b, NULL);
1062    { /* Instances */
1063 
1064       nir_variable *instance_addr_var =
1065          nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr");
1066       nir_push_if(&b, nir_ine(&b, nir_channel(&b, pconst2, 2), nir_imm_int(&b, 0)));
1067       {
1068          nir_ssa_def *ptr = nir_iadd(&b, nir_pack_64_2x32(&b, nir_channels(&b, pconst2, 3)),
1069                                      nir_u2u64(&b, nir_imul(&b, global_id, nir_imm_int(&b, 8))));
1070          nir_ssa_def *addr = nir_pack_64_2x32(
1071             &b, nir_build_load_global(&b, 2, 32, ptr, .align_mul = 8, .align_offset = 0));
1072          nir_store_var(&b, instance_addr_var, addr, 1);
1073       }
1074       nir_push_else(&b, NULL);
1075       {
1076          nir_ssa_def *addr = nir_iadd(&b, nir_pack_64_2x32(&b, nir_channels(&b, pconst2, 3)),
1077                                       nir_u2u64(&b, nir_imul(&b, global_id, nir_imm_int(&b, 64))));
1078          nir_store_var(&b, instance_addr_var, addr, 1);
1079       }
1080       nir_pop_if(&b, NULL);
1081       nir_ssa_def *instance_addr = nir_load_var(&b, instance_addr_var);
1082 
1083       nir_ssa_def *inst_transform[] = {
1084          nir_build_load_global(&b, 4, 32, nir_iadd(&b, instance_addr, nir_imm_int64(&b, 0)),
1085                                .align_mul = 4, .align_offset = 0),
1086          nir_build_load_global(&b, 4, 32, nir_iadd(&b, instance_addr, nir_imm_int64(&b, 16)),
1087                                .align_mul = 4, .align_offset = 0),
1088          nir_build_load_global(&b, 4, 32, nir_iadd(&b, instance_addr, nir_imm_int64(&b, 32)),
1089                                .align_mul = 4, .align_offset = 0)};
1090       nir_ssa_def *inst3 =
1091          nir_build_load_global(&b, 4, 32, nir_iadd(&b, instance_addr, nir_imm_int64(&b, 48)),
1092                                .align_mul = 4, .align_offset = 0);
1093 
1094       nir_ssa_def *node_offset =
1095          nir_iadd(&b, node_dst_offset, nir_umul24(&b, global_id, nir_imm_int(&b, 128)));
1096       node_dst_addr = nir_iadd(&b, node_dst_addr, nir_u2u64(&b, node_offset));
1097       nir_ssa_def *node_id =
1098          nir_iadd(&b, nir_ushr(&b, node_offset, nir_imm_int(&b, 3)), nir_imm_int(&b, 6));
1099       nir_build_store_global(&b, node_id, scratch_addr, .write_mask = 1, .align_mul = 4,
1100                              .align_offset = 0);
1101 
1102       nir_variable *bounds[2] = {
1103          nir_variable_create(b.shader, nir_var_shader_temp, vec3_type, "min_bound"),
1104          nir_variable_create(b.shader, nir_var_shader_temp, vec3_type, "max_bound"),
1105       };
1106 
1107       nir_store_var(&b, bounds[0], nir_channels(&b, nir_imm_vec4(&b, NAN, NAN, NAN, NAN), 7), 7);
1108       nir_store_var(&b, bounds[1], nir_channels(&b, nir_imm_vec4(&b, NAN, NAN, NAN, NAN), 7), 7);
1109 
1110       nir_ssa_def *header_addr = nir_pack_64_2x32(&b, nir_channels(&b, inst3, 12));
1111       nir_push_if(&b, nir_ine(&b, header_addr, nir_imm_int64(&b, 0)));
1112       nir_ssa_def *header_root_offset =
1113          nir_build_load_global(&b, 1, 32, nir_iadd(&b, header_addr, nir_imm_int64(&b, 0)),
1114                                .align_mul = 4, .align_offset = 0);
1115       nir_ssa_def *header_min =
1116          nir_build_load_global(&b, 3, 32, nir_iadd(&b, header_addr, nir_imm_int64(&b, 8)),
1117                                .align_mul = 4, .align_offset = 0);
1118       nir_ssa_def *header_max =
1119          nir_build_load_global(&b, 3, 32, nir_iadd(&b, header_addr, nir_imm_int64(&b, 20)),
1120                                .align_mul = 4, .align_offset = 0);
1121 
1122       nir_ssa_def *bound_defs[2][3];
1123       for (unsigned i = 0; i < 3; ++i) {
1124          bound_defs[0][i] = bound_defs[1][i] = nir_channel(&b, inst_transform[i], 3);
1125 
1126          nir_ssa_def *mul_a = nir_fmul(&b, nir_channels(&b, inst_transform[i], 7), header_min);
1127          nir_ssa_def *mul_b = nir_fmul(&b, nir_channels(&b, inst_transform[i], 7), header_max);
1128          nir_ssa_def *mi = nir_fmin(&b, mul_a, mul_b);
1129          nir_ssa_def *ma = nir_fmax(&b, mul_a, mul_b);
1130          for (unsigned j = 0; j < 3; ++j) {
1131             bound_defs[0][i] = nir_fadd(&b, bound_defs[0][i], nir_channel(&b, mi, j));
1132             bound_defs[1][i] = nir_fadd(&b, bound_defs[1][i], nir_channel(&b, ma, j));
1133          }
1134       }
1135 
1136       nir_store_var(&b, bounds[0], nir_vec(&b, bound_defs[0], 3), 7);
1137       nir_store_var(&b, bounds[1], nir_vec(&b, bound_defs[1], 3), 7);
1138 
1139       /* Store object to world matrix */
1140       for (unsigned i = 0; i < 3; ++i) {
1141          nir_ssa_def *vals[3];
1142          for (unsigned j = 0; j < 3; ++j)
1143             vals[j] = nir_channel(&b, inst_transform[j], i);
1144 
1145          nir_build_store_global(&b, nir_vec(&b, vals, 3),
1146                                 nir_iadd(&b, node_dst_addr, nir_imm_int64(&b, 92 + 12 * i)),
1147                                 .write_mask = 0x7, .align_mul = 4, .align_offset = 0);
1148       }
1149 
1150       nir_ssa_def *m_in[3][3], *m_out[3][3], *m_vec[3][4];
1151       for (unsigned i = 0; i < 3; ++i)
1152          for (unsigned j = 0; j < 3; ++j)
1153             m_in[i][j] = nir_channel(&b, inst_transform[i], j);
1154       nir_invert_3x3(&b, m_in, m_out);
1155       for (unsigned i = 0; i < 3; ++i) {
1156          for (unsigned j = 0; j < 3; ++j)
1157             m_vec[i][j] = m_out[i][j];
1158          m_vec[i][3] = nir_channel(&b, inst_transform[i], 3);
1159       }
1160 
1161       for (unsigned i = 0; i < 3; ++i) {
1162          nir_build_store_global(&b, nir_vec(&b, m_vec[i], 4),
1163                                 nir_iadd(&b, node_dst_addr, nir_imm_int64(&b, 16 + 16 * i)),
1164                                 .write_mask = 0xf, .align_mul = 4, .align_offset = 0);
1165       }
1166 
1167       nir_ssa_def *out0[4] = {
1168          nir_ior(&b, nir_channel(&b, nir_unpack_64_2x32(&b, header_addr), 0), header_root_offset),
1169          nir_channel(&b, nir_unpack_64_2x32(&b, header_addr), 1), nir_channel(&b, inst3, 0),
1170          nir_channel(&b, inst3, 1)};
1171       nir_build_store_global(&b, nir_vec(&b, out0, 4),
1172                              nir_iadd(&b, node_dst_addr, nir_imm_int64(&b, 0)), .write_mask = 0xf,
1173                              .align_mul = 4, .align_offset = 0);
1174       nir_build_store_global(&b, global_id, nir_iadd(&b, node_dst_addr, nir_imm_int64(&b, 88)),
1175                              .write_mask = 0x1, .align_mul = 4, .align_offset = 0);
1176       nir_pop_if(&b, NULL);
1177       nir_build_store_global(&b, nir_load_var(&b, bounds[0]),
1178                              nir_iadd(&b, node_dst_addr, nir_imm_int64(&b, 64)), .write_mask = 0x7,
1179                              .align_mul = 4, .align_offset = 0);
1180       nir_build_store_global(&b, nir_load_var(&b, bounds[1]),
1181                              nir_iadd(&b, node_dst_addr, nir_imm_int64(&b, 76)), .write_mask = 0x7,
1182                              .align_mul = 4, .align_offset = 0);
1183    }
1184    nir_pop_if(&b, NULL);
1185    nir_pop_if(&b, NULL);
1186 
1187    return b.shader;
1188 }
1189 
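/* NIR counterpart of compute_bounds(): computes the AABB of the node with the
 * given id relative to the BVH at node_addr and stores min/max in bounds_vars. */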
1190 static void
1191 determine_bounds(nir_builder *b, nir_ssa_def *node_addr, nir_ssa_def *node_id,
1192                  nir_variable *bounds_vars[2])
1193 {
1194    nir_ssa_def *node_type = nir_iand(b, node_id, nir_imm_int(b, 7));
1195    node_addr = nir_iadd(
1196       b, node_addr,
1197       nir_u2u64(b, nir_ishl(b, nir_iand(b, node_id, nir_imm_int(b, ~7u)), nir_imm_int(b, 3))));
1198 
1199    nir_push_if(b, nir_ieq(b, node_type, nir_imm_int(b, 0)));
1200    {
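      /* Triangle node: bounds are the min/max over the three vertex positions. */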
1201       nir_ssa_def *positions[3];
1202       for (unsigned i = 0; i < 3; ++i)
1203          positions[i] =
1204             nir_build_load_global(b, 3, 32, nir_iadd(b, node_addr, nir_imm_int64(b, i * 12)),
1205                                   .align_mul = 4, .align_offset = 0);
1206       nir_ssa_def *bounds[] = {positions[0], positions[0]};
1207       for (unsigned i = 1; i < 3; ++i) {
1208          bounds[0] = nir_fmin(b, bounds[0], positions[i]);
1209          bounds[1] = nir_fmax(b, bounds[1], positions[i]);
1210       }
1211       nir_store_var(b, bounds_vars[0], bounds[0], 7);
1212       nir_store_var(b, bounds_vars[1], bounds[1], 7);
1213    }
1214    nir_push_else(b, NULL);
1215    nir_push_if(b, nir_ieq(b, node_type, nir_imm_int(b, 5)));
1216    {
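      /* Box32 internal node: union of the four child bounds stored at offset 16,
       * 24 bytes (min + max) per child. */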
1217       nir_ssa_def *input_bounds[4][2];
1218       for (unsigned i = 0; i < 4; ++i)
1219          for (unsigned j = 0; j < 2; ++j)
1220             input_bounds[i][j] = nir_build_load_global(
1221                b, 3, 32, nir_iadd(b, node_addr, nir_imm_int64(b, 16 + i * 24 + j * 12)),
1222                .align_mul = 4, .align_offset = 0);
1223       nir_ssa_def *bounds[] = {input_bounds[0][0], input_bounds[0][1]};
1224       for (unsigned i = 1; i < 4; ++i) {
1225          bounds[0] = nir_fmin(b, bounds[0], input_bounds[i][0]);
1226          bounds[1] = nir_fmax(b, bounds[1], input_bounds[i][1]);
1227       }
1228 
1229       nir_store_var(b, bounds_vars[0], bounds[0], 7);
1230       nir_store_var(b, bounds_vars[1], bounds[1], 7);
1231    }
1232    nir_push_else(b, NULL);
1233    nir_push_if(b, nir_ieq(b, node_type, nir_imm_int(b, 6)));
1234    { /* Instances */
1235       nir_ssa_def *bounds[2];
1236       for (unsigned i = 0; i < 2; ++i)
1237          bounds[i] =
1238             nir_build_load_global(b, 3, 32, nir_iadd(b, node_addr, nir_imm_int64(b, 64 + i * 12)),
1239                                   .align_mul = 4, .align_offset = 0);
1240       nir_store_var(b, bounds_vars[0], bounds[0], 7);
1241       nir_store_var(b, bounds_vars[1], bounds[1], 7);
1242    }
1243    nir_push_else(b, NULL);
1244    { /* AABBs */
1245       nir_ssa_def *bounds[2];
1246       for (unsigned i = 0; i < 2; ++i)
1247          bounds[i] =
1248             nir_build_load_global(b, 3, 32, nir_iadd(b, node_addr, nir_imm_int64(b, i * 12)),
1249                                   .align_mul = 4, .align_offset = 0);
1250       nir_store_var(b, bounds_vars[0], bounds[0], 7);
1251       nir_store_var(b, bounds_vars[1], bounds[1], 7);
1252    }
1253    nir_pop_if(b, NULL);
1254    nir_pop_if(b, NULL);
1255    nir_pop_if(b, NULL);
1256 }
1257 
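/* Shader for the bottom-up internal node passes: every invocation gathers up
 * to four child node ids from the source scratch region, writes them into a
 * new 128-byte box32 node together with each child's bounds, and emits the new
 * node's id into the destination scratch region. When fill_header is set (the
 * final pass), the root id and the total bounds are also written to the BVH
 * header. */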
1258 static nir_shader *
1259 build_internal_shader(struct radv_device *dev)
1260 {
1261    const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
1262    nir_builder b =
1263       nir_builder_init_simple_shader(MESA_SHADER_COMPUTE, NULL, "accel_build_internal_shader");
1264 
1265    b.shader->info.workgroup_size[0] = 64;
1266    b.shader->info.workgroup_size[1] = 1;
1267    b.shader->info.workgroup_size[2] = 1;
1268 
1269    /*
1270     * push constants:
1271     *   i32 x 2: node dst address
1272     *   i32 x 2: scratch address
1273     *   i32: dst offset
1274     *   i32: dst scratch offset
1275     *   i32: src scratch offset
1276     *   i32: src_node_count | (fill_header << 31)
1277     */
1278    nir_ssa_def *pconst0 =
1279       nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 0, .range = 16);
1280    nir_ssa_def *pconst1 =
1281       nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 16, .range = 16);
1282 
1283    nir_ssa_def *node_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 3));
1284    nir_ssa_def *scratch_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 12));
1285    nir_ssa_def *node_dst_offset = nir_channel(&b, pconst1, 0);
1286    nir_ssa_def *dst_scratch_offset = nir_channel(&b, pconst1, 1);
1287    nir_ssa_def *src_scratch_offset = nir_channel(&b, pconst1, 2);
1288    nir_ssa_def *src_node_count =
1289       nir_iand(&b, nir_channel(&b, pconst1, 3), nir_imm_int(&b, 0x7FFFFFFFU));
1290    nir_ssa_def *fill_header =
1291       nir_ine(&b, nir_iand(&b, nir_channel(&b, pconst1, 3), nir_imm_int(&b, 0x80000000U)),
1292               nir_imm_int(&b, 0));
1293 
1294    nir_ssa_def *global_id =
1295       nir_iadd(&b,
1296                nir_umul24(&b, nir_channels(&b, nir_load_workgroup_id(&b, 32), 1),
1297                           nir_imm_int(&b, b.shader->info.workgroup_size[0])),
1298                nir_channels(&b, nir_load_local_invocation_id(&b), 1));
1299    nir_ssa_def *src_idx = nir_imul(&b, global_id, nir_imm_int(&b, 4));
1300    nir_ssa_def *src_count = nir_umin(&b, nir_imm_int(&b, 4), nir_isub(&b, src_node_count, src_idx));
1301 
1302    nir_ssa_def *node_offset =
1303       nir_iadd(&b, node_dst_offset, nir_ishl(&b, global_id, nir_imm_int(&b, 7)));
1304    nir_ssa_def *node_dst_addr = nir_iadd(&b, node_addr, nir_u2u64(&b, node_offset));
1305    nir_ssa_def *src_nodes = nir_build_load_global(
1306       &b, 4, 32,
1307       nir_iadd(&b, scratch_addr,
1308                nir_u2u64(&b, nir_iadd(&b, src_scratch_offset,
1309                                       nir_ishl(&b, global_id, nir_imm_int(&b, 4))))),
1310       .align_mul = 4, .align_offset = 0);
1311 
1312    nir_build_store_global(&b, src_nodes, nir_iadd(&b, node_dst_addr, nir_imm_int64(&b, 0)),
1313                           .write_mask = 0xf, .align_mul = 4, .align_offset = 0);
1314 
1315    nir_ssa_def *total_bounds[2] = {
1316       nir_channels(&b, nir_imm_vec4(&b, NAN, NAN, NAN, NAN), 7),
1317       nir_channels(&b, nir_imm_vec4(&b, NAN, NAN, NAN, NAN), 7),
1318    };
1319 
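   /* Compute and store the bounds of each of the (up to) four children; slots
    * beyond src_count are left with NaN bounds. */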
1320    for (unsigned i = 0; i < 4; ++i) {
1321       nir_variable *bounds[2] = {
1322          nir_variable_create(b.shader, nir_var_shader_temp, vec3_type, "min_bound"),
1323          nir_variable_create(b.shader, nir_var_shader_temp, vec3_type, "max_bound"),
1324       };
1325       nir_store_var(&b, bounds[0], nir_channels(&b, nir_imm_vec4(&b, NAN, NAN, NAN, NAN), 7), 7);
1326       nir_store_var(&b, bounds[1], nir_channels(&b, nir_imm_vec4(&b, NAN, NAN, NAN, NAN), 7), 7);
1327 
1328       nir_push_if(&b, nir_ilt(&b, nir_imm_int(&b, i), src_count));
1329       determine_bounds(&b, node_addr, nir_channel(&b, src_nodes, i), bounds);
1330       nir_pop_if(&b, NULL);
1331       nir_build_store_global(&b, nir_load_var(&b, bounds[0]),
1332                              nir_iadd(&b, node_dst_addr, nir_imm_int64(&b, 16 + 24 * i)),
1333                              .write_mask = 0x7, .align_mul = 4, .align_offset = 0);
1334       nir_build_store_global(&b, nir_load_var(&b, bounds[1]),
1335                              nir_iadd(&b, node_dst_addr, nir_imm_int64(&b, 28 + 24 * i)),
1336                              .write_mask = 0x7, .align_mul = 4, .align_offset = 0);
1337       total_bounds[0] = nir_fmin(&b, total_bounds[0], nir_load_var(&b, bounds[0]));
1338       total_bounds[1] = nir_fmax(&b, total_bounds[1], nir_load_var(&b, bounds[1]));
1339    }
1340 
1341    nir_ssa_def *node_id =
1342       nir_iadd(&b, nir_ushr(&b, node_offset, nir_imm_int(&b, 3)), nir_imm_int(&b, 5));
1343    nir_ssa_def *dst_scratch_addr = nir_iadd(
1344       &b, scratch_addr,
1345       nir_u2u64(&b, nir_iadd(&b, dst_scratch_offset, nir_ishl(&b, global_id, nir_imm_int(&b, 2)))));
1346    nir_build_store_global(&b, node_id, dst_scratch_addr, .write_mask = 1, .align_mul = 4,
1347                           .align_offset = 0);
1348 
1349    nir_push_if(&b, fill_header);
1350    nir_build_store_global(&b, node_id, node_addr, .write_mask = 1, .align_mul = 4,
1351                           .align_offset = 0);
1352    nir_build_store_global(&b, total_bounds[0], nir_iadd(&b, node_addr, nir_imm_int64(&b, 8)),
1353                           .write_mask = 7, .align_mul = 4, .align_offset = 0);
1354    nir_build_store_global(&b, total_bounds[1], nir_iadd(&b, node_addr, nir_imm_int64(&b, 20)),
1355                           .write_mask = 7, .align_mul = 4, .align_offset = 0);
1356    nir_pop_if(&b, NULL);
1357    return b.shader;
1358 }
1359 
1360 enum copy_mode {
1361    COPY_MODE_COPY,
1362    COPY_MODE_SERIALIZE,
1363    COPY_MODE_DESERIALIZE,
1364 };
1365 
1366 struct copy_constants {
1367    uint64_t src_addr;
1368    uint64_t dst_addr;
1369    uint32_t mode;
1370 };
1371 
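/* Single shader for copy, serialize and deserialize, selected via the mode
 * push constant. It streams the structure in 16-byte chunks using a
 * grid-strided loop; inside the instance array it additionally translates the
 * 8-byte BVH pointer at the start of every instance node to or from the
 * serialized instance handles. */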
1372 static nir_shader *
1373 build_copy_shader(struct radv_device *dev)
1374 {
1375    nir_builder b = nir_builder_init_simple_shader(MESA_SHADER_COMPUTE, NULL, "accel_copy");
1376    b.shader->info.workgroup_size[0] = 64;
1377    b.shader->info.workgroup_size[1] = 1;
1378    b.shader->info.workgroup_size[2] = 1;
1379 
1380    nir_ssa_def *invoc_id = nir_load_local_invocation_id(&b);
1381    nir_ssa_def *wg_id = nir_load_workgroup_id(&b, 32);
1382    nir_ssa_def *block_size =
1383       nir_imm_ivec4(&b, b.shader->info.workgroup_size[0], b.shader->info.workgroup_size[1],
1384                     b.shader->info.workgroup_size[2], 0);
1385 
1386    nir_ssa_def *global_id =
1387       nir_channel(&b, nir_iadd(&b, nir_imul(&b, wg_id, block_size), invoc_id), 0);
1388 
1389    nir_variable *offset_var =
1390       nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint_type(), "offset");
1391    nir_ssa_def *offset = nir_imul(&b, global_id, nir_imm_int(&b, 16));
1392    nir_store_var(&b, offset_var, offset, 1);
1393 
1394    nir_ssa_def *increment = nir_imul(&b, nir_channel(&b, nir_load_num_workgroups(&b, 32), 0),
1395                                      nir_imm_int(&b, b.shader->info.workgroup_size[0] * 16));
1396 
1397    nir_ssa_def *pconst0 =
1398       nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 0, .range = 16);
1399    nir_ssa_def *pconst1 =
1400       nir_load_push_constant(&b, 1, 32, nir_imm_int(&b, 0), .base = 16, .range = 4);
1401    nir_ssa_def *src_base_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 3));
1402    nir_ssa_def *dst_base_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 0xc));
1403    nir_ssa_def *mode = nir_channel(&b, pconst1, 0);
1404 
1405    nir_variable *compacted_size_var =
1406       nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint64_t_type(), "compacted_size");
1407    nir_variable *src_offset_var =
1408       nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint_type(), "src_offset");
1409    nir_variable *dst_offset_var =
1410       nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint_type(), "dst_offset");
1411    nir_variable *instance_offset_var =
1412       nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint_type(), "instance_offset");
1413    nir_variable *instance_count_var =
1414       nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint_type(), "instance_count");
1415    nir_variable *value_var =
1416       nir_variable_create(b.shader, nir_var_shader_temp, glsl_vec4_type(), "value");
1417 
1418    nir_push_if(&b, nir_ieq(&b, mode, nir_imm_int(&b, COPY_MODE_SERIALIZE)));
1419    {
1420       nir_ssa_def *instance_count = nir_build_load_global(
1421          &b, 1, 32,
1422          nir_iadd(&b, src_base_addr,
1423                   nir_imm_int64(&b, offsetof(struct radv_accel_struct_header, instance_count))),
1424          .align_mul = 4, .align_offset = 0);
1425       nir_ssa_def *compacted_size = nir_build_load_global(
1426          &b, 1, 64,
1427          nir_iadd(&b, src_base_addr,
1428                   nir_imm_int64(&b, offsetof(struct radv_accel_struct_header, compacted_size))),
1429          .align_mul = 8, .align_offset = 0);
1430       nir_ssa_def *serialization_size = nir_build_load_global(
1431          &b, 1, 64,
1432          nir_iadd(&b, src_base_addr,
1433                   nir_imm_int64(&b, offsetof(struct radv_accel_struct_header, serialization_size))),
1434          .align_mul = 8, .align_offset = 0);
1435 
1436       nir_store_var(&b, compacted_size_var, compacted_size, 1);
1437       nir_store_var(
1438          &b, instance_offset_var,
1439          nir_build_load_global(
1440             &b, 1, 32,
1441             nir_iadd(&b, src_base_addr,
1442                      nir_imm_int64(&b, offsetof(struct radv_accel_struct_header, instance_offset))),
1443             .align_mul = 4, .align_offset = 0),
1444          1);
1445       nir_store_var(&b, instance_count_var, instance_count, 1);
1446 
1447       nir_ssa_def *dst_offset =
1448          nir_iadd(&b, nir_imm_int(&b, sizeof(struct radv_accel_struct_serialization_header)),
1449                   nir_imul(&b, instance_count, nir_imm_int(&b, sizeof(uint64_t))));
1450       nir_store_var(&b, src_offset_var, nir_imm_int(&b, 0), 1);
1451       nir_store_var(&b, dst_offset_var, dst_offset, 1);
1452 
1453       nir_push_if(&b, nir_ieq(&b, global_id, nir_imm_int(&b, 0)));
1454       {
1455          nir_build_store_global(
1456             &b, serialization_size,
1457             nir_iadd(&b, dst_base_addr,
1458                      nir_imm_int64(&b, offsetof(struct radv_accel_struct_serialization_header,
1459                                                 serialization_size))),
1460             .write_mask = 0x1, .align_mul = 8, .align_offset = 0);
1461          nir_build_store_global(
1462             &b, compacted_size,
1463             nir_iadd(&b, dst_base_addr,
1464                      nir_imm_int64(&b, offsetof(struct radv_accel_struct_serialization_header,
1465                                                 compacted_size))),
1466             .write_mask = 0x1, .align_mul = 8, .align_offset = 0);
1467          nir_build_store_global(
1468             &b, nir_u2u64(&b, instance_count),
1469             nir_iadd(&b, dst_base_addr,
1470                      nir_imm_int64(&b, offsetof(struct radv_accel_struct_serialization_header,
1471                                                 instance_count))),
1472             .write_mask = 0x1, .align_mul = 8, .align_offset = 0);
1473       }
1474       nir_pop_if(&b, NULL);
1475    }
1476    nir_push_else(&b, NULL);
1477    nir_push_if(&b, nir_ieq(&b, mode, nir_imm_int(&b, COPY_MODE_DESERIALIZE)));
1478    {
1479       nir_ssa_def *instance_count = nir_build_load_global(
1480          &b, 1, 32,
1481          nir_iadd(&b, src_base_addr,
1482                   nir_imm_int64(
1483                      &b, offsetof(struct radv_accel_struct_serialization_header, instance_count))),
1484          .align_mul = 4, .align_offset = 0);
1485       nir_ssa_def *src_offset =
1486          nir_iadd(&b, nir_imm_int(&b, sizeof(struct radv_accel_struct_serialization_header)),
1487                   nir_imul(&b, instance_count, nir_imm_int(&b, sizeof(uint64_t))));
1488 
1489       nir_ssa_def *header_addr = nir_iadd(&b, src_base_addr, nir_u2u64(&b, src_offset));
1490       nir_store_var(
1491          &b, compacted_size_var,
1492          nir_build_load_global(
1493             &b, 1, 64,
1494             nir_iadd(&b, header_addr,
1495                      nir_imm_int64(&b, offsetof(struct radv_accel_struct_header, compacted_size))),
1496             .align_mul = 8, .align_offset = 0),
1497          1);
1498       nir_store_var(
1499          &b, instance_offset_var,
1500          nir_build_load_global(
1501             &b, 1, 32,
1502             nir_iadd(&b, header_addr,
1503                      nir_imm_int64(&b, offsetof(struct radv_accel_struct_header, instance_offset))),
1504             .align_mul = 4, .align_offset = 0),
1505          1);
1506       nir_store_var(&b, instance_count_var, instance_count, 1);
1507       nir_store_var(&b, src_offset_var, src_offset, 1);
1508       nir_store_var(&b, dst_offset_var, nir_imm_int(&b, 0), 1);
1509    }
1510    nir_push_else(&b, NULL); /* COPY_MODE_COPY */
1511    {
1512       nir_store_var(
1513          &b, compacted_size_var,
1514          nir_build_load_global(
1515             &b, 1, 64,
1516             nir_iadd(&b, src_base_addr,
1517                      nir_imm_int64(&b, offsetof(struct radv_accel_struct_header, compacted_size))),
1518             .align_mul = 8, .align_offset = 0),
1519          1);
1520 
1521       nir_store_var(&b, src_offset_var, nir_imm_int(&b, 0), 1);
1522       nir_store_var(&b, dst_offset_var, nir_imm_int(&b, 0), 1);
1523       nir_store_var(&b, instance_offset_var, nir_imm_int(&b, 0), 1);
1524       nir_store_var(&b, instance_count_var, nir_imm_int(&b, 0), 1);
1525    }
1526    nir_pop_if(&b, NULL);
1527    nir_pop_if(&b, NULL);
1528 
1529    nir_ssa_def *instance_bound =
1530       nir_imul(&b, nir_imm_int(&b, sizeof(struct radv_bvh_instance_node)),
1531                nir_load_var(&b, instance_count_var));
1532    nir_ssa_def *compacted_size = nir_build_load_global(
1533       &b, 1, 32,
1534       nir_iadd(&b, src_base_addr,
1535                nir_imm_int64(&b, offsetof(struct radv_accel_struct_header, compacted_size))),
1536       .align_mul = 4, .align_offset = 0);
1537 
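   /* Grid-strided copy loop over the compacted structure, 16 bytes per
    * iteration. */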
1538    nir_push_loop(&b);
1539    {
1540       offset = nir_load_var(&b, offset_var);
1541       nir_push_if(&b, nir_ilt(&b, offset, compacted_size));
1542       {
1543          nir_ssa_def *src_offset = nir_iadd(&b, offset, nir_load_var(&b, src_offset_var));
1544          nir_ssa_def *dst_offset = nir_iadd(&b, offset, nir_load_var(&b, dst_offset_var));
1545          nir_ssa_def *src_addr = nir_iadd(&b, src_base_addr, nir_u2u64(&b, src_offset));
1546          nir_ssa_def *dst_addr = nir_iadd(&b, dst_base_addr, nir_u2u64(&b, dst_offset));
1547 
1548          nir_ssa_def *value =
1549             nir_build_load_global(&b, 4, 32, src_addr, .align_mul = 16, .align_offset = 0);
1550          nir_store_var(&b, value_var, value, 0xf);
1551 
1552          nir_ssa_def *instance_offset = nir_isub(&b, offset, nir_load_var(&b, instance_offset_var));
1553          nir_ssa_def *in_instance_bound =
1554             nir_iand(&b, nir_uge(&b, offset, nir_load_var(&b, instance_offset_var)),
1555                      nir_ult(&b, instance_offset, instance_bound));
1556          nir_ssa_def *instance_start =
1557             nir_ieq(&b,
1558                     nir_iand(&b, instance_offset,
1559                              nir_imm_int(&b, sizeof(struct radv_bvh_instance_node) - 1)),
1560                     nir_imm_int(&b, 0));
1561 
1562          nir_push_if(&b, nir_iand(&b, in_instance_bound, instance_start));
1563          {
1564             nir_ssa_def *instance_id = nir_ushr(&b, instance_offset, nir_imm_int(&b, 7));
1565 
1566             nir_push_if(&b, nir_ieq(&b, mode, nir_imm_int(&b, COPY_MODE_SERIALIZE)));
1567             {
1568                nir_ssa_def *instance_addr =
1569                   nir_imul(&b, instance_id, nir_imm_int(&b, sizeof(uint64_t)));
1570                instance_addr =
1571                   nir_iadd(&b, instance_addr,
1572                            nir_imm_int(&b, sizeof(struct radv_accel_struct_serialization_header)));
1573                instance_addr = nir_iadd(&b, dst_base_addr, nir_u2u64(&b, instance_addr));
1574 
1575                nir_build_store_global(&b, nir_channels(&b, value, 3), instance_addr,
1576                                       .write_mask = 3, .align_mul = 8, .align_offset = 0);
1577             }
1578             nir_push_else(&b, NULL);
1579             {
1580                nir_ssa_def *instance_addr =
1581                   nir_imul(&b, instance_id, nir_imm_int(&b, sizeof(uint64_t)));
1582                instance_addr =
1583                   nir_iadd(&b, instance_addr,
1584                            nir_imm_int(&b, sizeof(struct radv_accel_struct_serialization_header)));
1585                instance_addr = nir_iadd(&b, src_base_addr, nir_u2u64(&b, instance_addr));
1586 
1587                nir_ssa_def *instance_value = nir_build_load_global(
1588                   &b, 2, 32, instance_addr, .align_mul = 8, .align_offset = 0);
1589 
1590                nir_ssa_def *values[] = {
1591                   nir_channel(&b, instance_value, 0),
1592                   nir_channel(&b, instance_value, 1),
1593                   nir_channel(&b, value, 2),
1594                   nir_channel(&b, value, 3),
1595                };
1596 
1597                nir_store_var(&b, value_var, nir_vec(&b, values, 4), 0xf);
1598             }
1599             nir_pop_if(&b, NULL);
1600          }
1601          nir_pop_if(&b, NULL);
1602 
1603          nir_store_var(&b, offset_var, nir_iadd(&b, offset, increment), 1);
1604 
1605          nir_build_store_global(&b, nir_load_var(&b, value_var), dst_addr, .write_mask = 0xf,
1606                                 .align_mul = 16, .align_offset = 0);
1607       }
1608       nir_push_else(&b, NULL);
1609       {
1610          nir_jump(&b, nir_jump_break);
1611       }
1612       nir_pop_if(&b, NULL);
1613    }
1614    nir_pop_loop(&b, NULL);
1615    return b.shader;
1616 }
1617 
1618 void
1619 radv_device_finish_accel_struct_build_state(struct radv_device *device)
1620 {
1621    struct radv_meta_state *state = &device->meta_state;
1622    radv_DestroyPipeline(radv_device_to_handle(device), state->accel_struct_build.copy_pipeline,
1623                         &state->alloc);
1624    radv_DestroyPipeline(radv_device_to_handle(device), state->accel_struct_build.internal_pipeline,
1625                         &state->alloc);
1626    radv_DestroyPipeline(radv_device_to_handle(device), state->accel_struct_build.leaf_pipeline,
1627                         &state->alloc);
1628    radv_DestroyPipelineLayout(radv_device_to_handle(device),
1629                               state->accel_struct_build.copy_p_layout, &state->alloc);
1630    radv_DestroyPipelineLayout(radv_device_to_handle(device),
1631                               state->accel_struct_build.internal_p_layout, &state->alloc);
1632    radv_DestroyPipelineLayout(radv_device_to_handle(device),
1633                               state->accel_struct_build.leaf_p_layout, &state->alloc);
1634 }
1635 
1636 VkResult
1637 radv_device_init_accel_struct_build_state(struct radv_device *device)
1638 {
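   /* Create the three compute pipelines used by the acceleration structure
    * commands: leaf (per-primitive nodes), internal (bottom-up box nodes) and
    * copy (copy/serialize/deserialize). */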
1639    VkResult result;
1640    nir_shader *leaf_cs = build_leaf_shader(device);
1641    nir_shader *internal_cs = build_internal_shader(device);
1642    nir_shader *copy_cs = build_copy_shader(device);
1643 
1644    const VkPipelineLayoutCreateInfo leaf_pl_create_info = {
1645       .sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,
1646       .setLayoutCount = 0,
1647       .pushConstantRangeCount = 1,
1648       .pPushConstantRanges = &(VkPushConstantRange){VK_SHADER_STAGE_COMPUTE_BIT, 0,
1649                                                     sizeof(struct build_primitive_constants)},
1650    };
1651 
1652    result = radv_CreatePipelineLayout(radv_device_to_handle(device), &leaf_pl_create_info,
1653                                       &device->meta_state.alloc,
1654                                       &device->meta_state.accel_struct_build.leaf_p_layout);
1655    if (result != VK_SUCCESS)
1656       goto fail;
1657 
1658    VkPipelineShaderStageCreateInfo leaf_shader_stage = {
1659       .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
1660       .stage = VK_SHADER_STAGE_COMPUTE_BIT,
1661       .module = vk_shader_module_handle_from_nir(leaf_cs),
1662       .pName = "main",
1663       .pSpecializationInfo = NULL,
1664    };
1665 
1666    VkComputePipelineCreateInfo leaf_pipeline_info = {
1667       .sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
1668       .stage = leaf_shader_stage,
1669       .flags = 0,
1670       .layout = device->meta_state.accel_struct_build.leaf_p_layout,
1671    };
1672 
1673    result = radv_CreateComputePipelines(
1674       radv_device_to_handle(device), radv_pipeline_cache_to_handle(&device->meta_state.cache), 1,
1675       &leaf_pipeline_info, NULL, &device->meta_state.accel_struct_build.leaf_pipeline);
1676    if (result != VK_SUCCESS)
1677       goto fail;
1678 
1679    const VkPipelineLayoutCreateInfo internal_pl_create_info = {
1680       .sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,
1681       .setLayoutCount = 0,
1682       .pushConstantRangeCount = 1,
1683       .pPushConstantRanges = &(VkPushConstantRange){VK_SHADER_STAGE_COMPUTE_BIT, 0,
1684                                                     sizeof(struct build_internal_constants)},
1685    };
1686 
1687    result = radv_CreatePipelineLayout(radv_device_to_handle(device), &internal_pl_create_info,
1688                                       &device->meta_state.alloc,
1689                                       &device->meta_state.accel_struct_build.internal_p_layout);
1690    if (result != VK_SUCCESS)
1691       goto fail;
1692 
1693    VkPipelineShaderStageCreateInfo internal_shader_stage = {
1694       .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
1695       .stage = VK_SHADER_STAGE_COMPUTE_BIT,
1696       .module = vk_shader_module_handle_from_nir(internal_cs),
1697       .pName = "main",
1698       .pSpecializationInfo = NULL,
1699    };
1700 
1701    VkComputePipelineCreateInfo internal_pipeline_info = {
1702       .sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
1703       .stage = internal_shader_stage,
1704       .flags = 0,
1705       .layout = device->meta_state.accel_struct_build.internal_p_layout,
1706    };
1707 
1708    result = radv_CreateComputePipelines(
1709       radv_device_to_handle(device), radv_pipeline_cache_to_handle(&device->meta_state.cache), 1,
1710       &internal_pipeline_info, NULL, &device->meta_state.accel_struct_build.internal_pipeline);
1711    if (result != VK_SUCCESS)
1712       goto fail;
1713 
1714    const VkPipelineLayoutCreateInfo copy_pl_create_info = {
1715       .sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,
1716       .setLayoutCount = 0,
1717       .pushConstantRangeCount = 1,
1718       .pPushConstantRanges =
1719          &(VkPushConstantRange){VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(struct copy_constants)},
1720    };
1721 
1722    result = radv_CreatePipelineLayout(radv_device_to_handle(device), &copy_pl_create_info,
1723                                       &device->meta_state.alloc,
1724                                       &device->meta_state.accel_struct_build.copy_p_layout);
1725    if (result != VK_SUCCESS)
1726       goto fail;
1727 
1728    VkPipelineShaderStageCreateInfo copy_shader_stage = {
1729       .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
1730       .stage = VK_SHADER_STAGE_COMPUTE_BIT,
1731       .module = vk_shader_module_handle_from_nir(copy_cs),
1732       .pName = "main",
1733       .pSpecializationInfo = NULL,
1734    };
1735 
1736    VkComputePipelineCreateInfo copy_pipeline_info = {
1737       .sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
1738       .stage = copy_shader_stage,
1739       .flags = 0,
1740       .layout = device->meta_state.accel_struct_build.copy_p_layout,
1741    };
1742 
1743    result = radv_CreateComputePipelines(
1744       radv_device_to_handle(device), radv_pipeline_cache_to_handle(&device->meta_state.cache), 1,
1745       &copy_pipeline_info, NULL, &device->meta_state.accel_struct_build.copy_pipeline);
1746    if (result != VK_SUCCESS)
1747       goto fail;
1748 
1749    ralloc_free(copy_cs);
1750    ralloc_free(internal_cs);
1751    ralloc_free(leaf_cs);
1752 
1753    return VK_SUCCESS;
1754 
1755 fail:
1756    radv_device_finish_accel_struct_build_state(device);
1757    ralloc_free(copy_cs);
1758    ralloc_free(internal_cs);
1759    ralloc_free(leaf_cs);
1760    return result;
1761 }
1762 
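/* Per-acceleration-structure state carried from the leaf pass into the
 * internal node passes. */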
1763 struct bvh_state {
1764    uint32_t node_offset;
1765    uint32_t node_count;
1766    uint32_t scratch_offset;
1767 
1768    uint32_t instance_offset;
1769    uint32_t instance_count;
1770 };
1771 
1772 void
1773 radv_CmdBuildAccelerationStructuresKHR(
1774    VkCommandBuffer commandBuffer, uint32_t infoCount,
1775    const VkAccelerationStructureBuildGeometryInfoKHR *pInfos,
1776    const VkAccelerationStructureBuildRangeInfoKHR *const *ppBuildRangeInfos)
1777 {
1778    RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
1779    struct radv_meta_saved_state saved_state;
1780 
1781    radv_meta_save(
1782       &saved_state, cmd_buffer,
1783       RADV_META_SAVE_COMPUTE_PIPELINE | RADV_META_SAVE_DESCRIPTORS | RADV_META_SAVE_CONSTANTS);
1784    struct bvh_state *bvh_states = calloc(infoCount, sizeof(struct bvh_state));
1785 
1786    radv_CmdBindPipeline(radv_cmd_buffer_to_handle(cmd_buffer), VK_PIPELINE_BIND_POINT_COMPUTE,
1787                         cmd_buffer->device->meta_state.accel_struct_build.leaf_pipeline);
1788 
1789    for (uint32_t i = 0; i < infoCount; ++i) {
1790       RADV_FROM_HANDLE(radv_acceleration_structure, accel_struct,
1791                        pInfos[i].dstAccelerationStructure);
1792 
1793       struct build_primitive_constants prim_consts = {
1794          .node_dst_addr = radv_accel_struct_get_va(accel_struct),
1795          .scratch_addr = pInfos[i].scratchData.deviceAddress,
1796          .dst_offset = ALIGN(sizeof(struct radv_accel_struct_header), 64) + 128,
1797          .dst_scratch_offset = 0,
1798       };
1799       bvh_states[i].node_offset = prim_consts.dst_offset;
1800       bvh_states[i].instance_offset = prim_consts.dst_offset;
1801 
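      /* Process instance geometries first so that all instance nodes end up
       * contiguous at instance_offset; the remaining geometry types follow in
       * the second iteration. */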
1802       for (int inst = 1; inst >= 0; --inst) {
1803          for (unsigned j = 0; j < pInfos[i].geometryCount; ++j) {
1804             const VkAccelerationStructureGeometryKHR *geom =
1805                pInfos[i].pGeometries ? &pInfos[i].pGeometries[j] : pInfos[i].ppGeometries[j];
1806 
1807             if ((inst && geom->geometryType != VK_GEOMETRY_TYPE_INSTANCES_KHR) ||
1808                 (!inst && geom->geometryType == VK_GEOMETRY_TYPE_INSTANCES_KHR))
1809                continue;
1810 
1811             prim_consts.geometry_type = geom->geometryType;
1812             prim_consts.geometry_id = j | (geom->flags << 28);
1813             unsigned prim_size;
1814             switch (geom->geometryType) {
1815             case VK_GEOMETRY_TYPE_TRIANGLES_KHR:
1816                prim_consts.vertex_addr =
1817                   geom->geometry.triangles.vertexData.deviceAddress +
1818                   ppBuildRangeInfos[i][j].firstVertex * geom->geometry.triangles.vertexStride +
1819                   (geom->geometry.triangles.indexType != VK_INDEX_TYPE_NONE_KHR
1820                       ? ppBuildRangeInfos[i][j].primitiveOffset
1821                       : 0);
1822                prim_consts.index_addr = geom->geometry.triangles.indexData.deviceAddress +
1823                                         ppBuildRangeInfos[i][j].primitiveOffset;
1824                prim_consts.transform_addr = geom->geometry.triangles.transformData.deviceAddress +
1825                                             ppBuildRangeInfos[i][j].transformOffset;
1826                prim_consts.vertex_stride = geom->geometry.triangles.vertexStride;
1827                prim_consts.vertex_format = geom->geometry.triangles.vertexFormat;
1828                prim_consts.index_format = geom->geometry.triangles.indexType;
1829                prim_size = 64;
1830                break;
1831             case VK_GEOMETRY_TYPE_AABBS_KHR:
1832                prim_consts.aabb_addr =
1833                   geom->geometry.aabbs.data.deviceAddress + ppBuildRangeInfos[i][j].primitiveOffset;
1834                prim_consts.aabb_stride = geom->geometry.aabbs.stride;
1835                prim_size = 64;
1836                break;
1837             case VK_GEOMETRY_TYPE_INSTANCES_KHR:
1838                prim_consts.instance_data = geom->geometry.instances.data.deviceAddress;
1839                prim_consts.array_of_pointers = geom->geometry.instances.arrayOfPointers ? 1 : 0;
1840                prim_size = 128;
1841                bvh_states[i].instance_count += ppBuildRangeInfos[i][j].primitiveCount;
1842                break;
1843             default:
1844                unreachable("Unknown geometryType");
1845             }
1846 
1847             radv_CmdPushConstants(radv_cmd_buffer_to_handle(cmd_buffer),
1848                                   cmd_buffer->device->meta_state.accel_struct_build.leaf_p_layout,
1849                                   VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(prim_consts),
1850                                   &prim_consts);
1851             radv_unaligned_dispatch(cmd_buffer, ppBuildRangeInfos[i][j].primitiveCount, 1, 1);
1852             prim_consts.dst_offset += prim_size * ppBuildRangeInfos[i][j].primitiveCount;
1853             prim_consts.dst_scratch_offset += 4 * ppBuildRangeInfos[i][j].primitiveCount;
1854          }
1855       }
1856       bvh_states[i].node_offset = prim_consts.dst_offset;
1857       bvh_states[i].node_count = prim_consts.dst_scratch_offset / 4;
1858    }
1859 
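   /* Build the internal nodes bottom-up: each pass collapses groups of four
    * nodes into box32 nodes, ping-ponging node ids between the two scratch
    * regions, until only the root remains. The final pass also fills in the
    * header (root id and total bounds). */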
1860    radv_CmdBindPipeline(radv_cmd_buffer_to_handle(cmd_buffer), VK_PIPELINE_BIND_POINT_COMPUTE,
1861                         cmd_buffer->device->meta_state.accel_struct_build.internal_pipeline);
1862    bool progress = true;
1863    for (unsigned iter = 0; progress; ++iter) {
1864       progress = false;
1865       for (uint32_t i = 0; i < infoCount; ++i) {
1866          RADV_FROM_HANDLE(radv_acceleration_structure, accel_struct,
1867                           pInfos[i].dstAccelerationStructure);
1868 
1869          if (iter && bvh_states[i].node_count == 1)
1870             continue;
1871 
1872          if (!progress) {
1873             cmd_buffer->state.flush_bits |=
1874                RADV_CMD_FLAG_CS_PARTIAL_FLUSH |
1875                radv_src_access_flush(cmd_buffer, VK_ACCESS_SHADER_WRITE_BIT, NULL) |
1876                radv_dst_access_flush(cmd_buffer,
1877                                      VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT, NULL);
1878          }
1879          progress = true;
1880          uint32_t dst_node_count = MAX2(1, DIV_ROUND_UP(bvh_states[i].node_count, 4));
1881          bool final_iter = dst_node_count == 1;
1882          uint32_t src_scratch_offset = bvh_states[i].scratch_offset;
1883          uint32_t dst_scratch_offset = src_scratch_offset ? 0 : bvh_states[i].node_count * 4;
1884          uint32_t dst_node_offset = bvh_states[i].node_offset;
1885          if (final_iter)
1886             dst_node_offset = ALIGN(sizeof(struct radv_accel_struct_header), 64);
1887 
1888          const struct build_internal_constants consts = {
1889             .node_dst_addr = radv_accel_struct_get_va(accel_struct),
1890             .scratch_addr = pInfos[i].scratchData.deviceAddress,
1891             .dst_offset = dst_node_offset,
1892             .dst_scratch_offset = dst_scratch_offset,
1893             .src_scratch_offset = src_scratch_offset,
1894             .fill_header = bvh_states[i].node_count | (final_iter ? 0x80000000U : 0),
1895          };
1896 
1897          radv_CmdPushConstants(radv_cmd_buffer_to_handle(cmd_buffer),
1898                                cmd_buffer->device->meta_state.accel_struct_build.internal_p_layout,
1899                                VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);
1900          radv_unaligned_dispatch(cmd_buffer, dst_node_count, 1, 1);
1901          if (!final_iter)
1902             bvh_states[i].node_offset += dst_node_count * 128;
1903          bvh_states[i].node_count = dst_node_count;
1904          bvh_states[i].scratch_offset = dst_scratch_offset;
1905       }
1906    }
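   /* Fill in the rest of the header (sizes, instance info and the indirect
    * dispatch size consumed by the copy shader) directly through the CP. */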
1907    for (uint32_t i = 0; i < infoCount; ++i) {
1908       RADV_FROM_HANDLE(radv_acceleration_structure, accel_struct,
1909                        pInfos[i].dstAccelerationStructure);
1910       const size_t base = offsetof(struct radv_accel_struct_header, compacted_size);
1911       struct radv_accel_struct_header header;
1912 
1913       header.instance_offset = bvh_states[i].instance_offset;
1914       header.instance_count = bvh_states[i].instance_count;
1915       header.compacted_size = bvh_states[i].node_offset;
1916 
1917       /* 16 bytes per invocation, 64 invocations per workgroup */
1918       header.copy_dispatch_size[0] = DIV_ROUND_UP(header.compacted_size, 16 * 64);
1919       header.copy_dispatch_size[1] = 1;
1920       header.copy_dispatch_size[2] = 1;
1921 
1922       header.serialization_size =
1923          header.compacted_size + align(sizeof(struct radv_accel_struct_serialization_header) +
1924                                           sizeof(uint64_t) * header.instance_count,
1925                                        128);
1926 
1927       radv_update_buffer_cp(cmd_buffer,
1928                             radv_buffer_get_va(accel_struct->bo) + accel_struct->mem_offset + base,
1929                             (const char *)&header + base, sizeof(header) - base);
1930    }
1931    free(bvh_states);
1932    radv_meta_restore(&saved_state, cmd_buffer);
1933 }
1934 
1935 void
1936 radv_CmdCopyAccelerationStructureKHR(VkCommandBuffer commandBuffer,
1937                                      const VkCopyAccelerationStructureInfoKHR *pInfo)
1938 {
1939    RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
1940    RADV_FROM_HANDLE(radv_acceleration_structure, src, pInfo->src);
1941    RADV_FROM_HANDLE(radv_acceleration_structure, dst, pInfo->dst);
1942    struct radv_meta_saved_state saved_state;
1943 
1944    radv_meta_save(
1945       &saved_state, cmd_buffer,
1946       RADV_META_SAVE_COMPUTE_PIPELINE | RADV_META_SAVE_DESCRIPTORS | RADV_META_SAVE_CONSTANTS);
1947 
1948    uint64_t src_addr = radv_accel_struct_get_va(src);
1949    uint64_t dst_addr = radv_accel_struct_get_va(dst);
1950 
1951    radv_CmdBindPipeline(radv_cmd_buffer_to_handle(cmd_buffer), VK_PIPELINE_BIND_POINT_COMPUTE,
1952                         cmd_buffer->device->meta_state.accel_struct_build.copy_pipeline);
1953 
1954    const struct copy_constants consts = {
1955       .src_addr = src_addr,
1956       .dst_addr = dst_addr,
1957       .mode = COPY_MODE_COPY,
1958    };
1959 
1960    radv_CmdPushConstants(radv_cmd_buffer_to_handle(cmd_buffer),
1961                          cmd_buffer->device->meta_state.accel_struct_build.copy_p_layout,
1962                          VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);
1963 
1964    radv_indirect_dispatch(cmd_buffer, src->bo,
1965                           src_addr + offsetof(struct radv_accel_struct_header, copy_dispatch_size));
1966    radv_meta_restore(&saved_state, cmd_buffer);
1967 }
1968 
1969 void
1970 radv_GetDeviceAccelerationStructureCompatibilityKHR(
1971    VkDevice _device, const VkAccelerationStructureVersionInfoKHR *pVersionInfo,
1972    VkAccelerationStructureCompatibilityKHR *pCompatibility)
1973 {
1974    RADV_FROM_HANDLE(radv_device, device, _device);
1975    uint8_t zero[VK_UUID_SIZE] = {
1976       0,
1977    };
1978    bool compat =
1979       memcmp(pVersionInfo->pVersionData, device->physical_device->driver_uuid, VK_UUID_SIZE) == 0 &&
1980       memcmp(pVersionInfo->pVersionData + VK_UUID_SIZE, zero, VK_UUID_SIZE) == 0;
1981    *pCompatibility = compat ? VK_ACCELERATION_STRUCTURE_COMPATIBILITY_COMPATIBLE_KHR
1982                             : VK_ACCELERATION_STRUCTURE_COMPATIBILITY_INCOMPATIBLE_KHR;
1983 }
1984 
1985 VkResult
1986 radv_CopyMemoryToAccelerationStructureKHR(VkDevice _device,
1987                                           VkDeferredOperationKHR deferredOperation,
1988                                           const VkCopyMemoryToAccelerationStructureInfoKHR *pInfo)
1989 {
1990    RADV_FROM_HANDLE(radv_device, device, _device);
1991    RADV_FROM_HANDLE(radv_acceleration_structure, accel_struct, pInfo->dst);
1992 
1993    char *base = device->ws->buffer_map(accel_struct->bo);
1994    if (!base)
1995       return VK_ERROR_OUT_OF_HOST_MEMORY;
1996 
1997    base += accel_struct->mem_offset;
1998    const struct radv_accel_struct_header *header = (const struct radv_accel_struct_header *)base;
1999 
2000    const char *src = pInfo->src.hostAddress;
2001    struct radv_accel_struct_serialization_header *src_header = (void *)src;
2002    src += sizeof(*src_header) + sizeof(uint64_t) * src_header->instance_count;
2003 
2004    memcpy(base, src, src_header->compacted_size);
2005 
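   /* Patch each instance node's BVH pointer back in from the serialized
    * handles, preserving the low 6 bits that were kept in the structure. */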
2006    for (unsigned i = 0; i < src_header->instance_count; ++i) {
2007       uint64_t *p = (uint64_t *)(base + i * 128 + header->instance_offset);
2008       *p = (*p & 63) | src_header->instances[i];
2009    }
2010 
2011    device->ws->buffer_unmap(accel_struct->bo);
2012    return VK_SUCCESS;
2013 }
2014 
2015 VkResult
2016 radv_CopyAccelerationStructureToMemoryKHR(VkDevice _device,
2017                                           VkDeferredOperationKHR deferredOperation,
2018                                           const VkCopyAccelerationStructureToMemoryInfoKHR *pInfo)
2019 {
2020    RADV_FROM_HANDLE(radv_device, device, _device);
2021    RADV_FROM_HANDLE(radv_acceleration_structure, accel_struct, pInfo->src);
2022 
2023    const char *base = device->ws->buffer_map(accel_struct->bo);
2024    if (!base)
2025       return VK_ERROR_OUT_OF_HOST_MEMORY;
2026 
2027    base += accel_struct->mem_offset;
2028    const struct radv_accel_struct_header *header = (const struct radv_accel_struct_header *)base;
2029 
2030    char *dst = pInfo->dst.hostAddress;
2031    struct radv_accel_struct_serialization_header *dst_header = (void *)dst;
2032    dst += sizeof(*dst_header) + sizeof(uint64_t) * header->instance_count;
2033 
2034    memcpy(dst_header->driver_uuid, device->physical_device->driver_uuid, VK_UUID_SIZE);
2035    memset(dst_header->accel_struct_compat, 0, VK_UUID_SIZE);
2036 
2037    dst_header->serialization_size = header->serialization_size;
2038    dst_header->compacted_size = header->compacted_size;
2039    dst_header->instance_count = header->instance_count;
2040 
2041    memcpy(dst, base, header->compacted_size);
2042 
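   /* Record the referenced BVH addresses (with the low 6 bits masked off) so
    * they can be rebound when the structure is deserialized. */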
2043    for (unsigned i = 0; i < header->instance_count; ++i) {
2044       dst_header->instances[i] =
2045          *(const uint64_t *)(base + i * 128 + header->instance_offset) & ~63ull;
2046    }
2047 
2048    device->ws->buffer_unmap(accel_struct->bo);
2049    return VK_SUCCESS;
2050 }
2051 
2052 void
2053 radv_CmdCopyMemoryToAccelerationStructureKHR(
2054    VkCommandBuffer commandBuffer, const VkCopyMemoryToAccelerationStructureInfoKHR *pInfo)
2055 {
2056    RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
2057    RADV_FROM_HANDLE(radv_acceleration_structure, dst, pInfo->dst);
2058    struct radv_meta_saved_state saved_state;
2059 
2060    radv_meta_save(
2061       &saved_state, cmd_buffer,
2062       RADV_META_SAVE_COMPUTE_PIPELINE | RADV_META_SAVE_DESCRIPTORS | RADV_META_SAVE_CONSTANTS);
2063 
2064    uint64_t dst_addr = radv_accel_struct_get_va(dst);
2065 
2066    radv_CmdBindPipeline(radv_cmd_buffer_to_handle(cmd_buffer), VK_PIPELINE_BIND_POINT_COMPUTE,
2067                         cmd_buffer->device->meta_state.accel_struct_build.copy_pipeline);
2068 
2069    const struct copy_constants consts = {
2070       .src_addr = pInfo->src.deviceAddress,
2071       .dst_addr = dst_addr,
2072       .mode = COPY_MODE_DESERIALIZE,
2073    };
2074 
2075    radv_CmdPushConstants(radv_cmd_buffer_to_handle(cmd_buffer),
2076                          cmd_buffer->device->meta_state.accel_struct_build.copy_p_layout,
2077                          VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);
2078 
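   /* The deserialized size is only available behind a device address here, so
    * a fixed-size grid is dispatched; the copy shader's grid-strided loop
    * terminates once it passes compacted_size. */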
2079    radv_CmdDispatch(commandBuffer, 512, 1, 1);
2080    radv_meta_restore(&saved_state, cmd_buffer);
2081 }
2082 
2083 void
2084 radv_CmdCopyAccelerationStructureToMemoryKHR(
2085    VkCommandBuffer commandBuffer, const VkCopyAccelerationStructureToMemoryInfoKHR *pInfo)
2086 {
2087    RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
2088    RADV_FROM_HANDLE(radv_acceleration_structure, src, pInfo->src);
2089    struct radv_meta_saved_state saved_state;
2090 
2091    radv_meta_save(
2092       &saved_state, cmd_buffer,
2093       RADV_META_SAVE_COMPUTE_PIPELINE | RADV_META_SAVE_DESCRIPTORS | RADV_META_SAVE_CONSTANTS);
2094 
2095    uint64_t src_addr = radv_accel_struct_get_va(src);
2096 
2097    radv_CmdBindPipeline(radv_cmd_buffer_to_handle(cmd_buffer), VK_PIPELINE_BIND_POINT_COMPUTE,
2098                         cmd_buffer->device->meta_state.accel_struct_build.copy_pipeline);
2099 
2100    const struct copy_constants consts = {
2101       .src_addr = src_addr,
2102       .dst_addr = pInfo->dst.deviceAddress,
2103       .mode = COPY_MODE_SERIALIZE,
2104    };
2105 
2106    radv_CmdPushConstants(radv_cmd_buffer_to_handle(cmd_buffer),
2107                          cmd_buffer->device->meta_state.accel_struct_build.copy_p_layout,
2108                          VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);
2109 
2110    radv_indirect_dispatch(cmd_buffer, src->bo,
2111                           src_addr + offsetof(struct radv_accel_struct_header, copy_dispatch_size));
2112    radv_meta_restore(&saved_state, cmd_buffer);
2113 
2114    /* Set the header of the serialized data. */
2115    uint8_t header_data[2 * VK_UUID_SIZE] = {0};
2116    memcpy(header_data, cmd_buffer->device->physical_device->driver_uuid, VK_UUID_SIZE);
2117 
2118    radv_update_buffer_cp(cmd_buffer, pInfo->dst.deviceAddress, header_data, sizeof(header_data));
2119 }
2120