1 /* Copyright (c) 2018-2021 The Khronos Group Inc.
2  * Copyright (c) 2018-2021 Valve Corporation
3  * Copyright (c) 2018-2021 LunarG, Inc.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  * Author: Karl Schultz <karl@lunarg.com>
18  * Author: Tony Barbour <tony@lunarg.com>
19  */
20 
21 #pragma once
22 
#include <utility>

#include "chassis.h"
#include "state_tracker.h"
#include "vk_mem_alloc.h"
#include "gpu_utils.h"
// Forward declarations; both classes are fully defined later in this header.
class GpuAssisted;
class CMD_BUFFER_STATE_GPUAV;
29 
// A VkBuffer paired with its VMA allocation, plus a table of descriptors keyed by a uint32_t.
// `buffer` and `allocation` have no default initializers, so they are indeterminate unless the
// object is value-initialized (e.g. `GpuAssistedDeviceMemoryBlock block = {};`).
struct GpuAssistedDeviceMemoryBlock {
    VkBuffer buffer;           // Buffer handle.
    VmaAllocation allocation;  // VMA allocation handle stored alongside `buffer`.
    // NOTE(review): the name suggests descriptors whose state is written into the buffer at queue
    // submit time, keyed by their index in that buffer -- confirm against the implementation file.
    layer_data::unordered_map<uint32_t, const cvdescriptorset::Descriptor*> update_at_submit;
};
35 
// Plain bundle of handles and parameters used by pre-draw validation.
// No member has a default initializer; value-initialize the struct before filling it in.
// NOTE(review): buffer/offset/stride/buf_size look like the indirect-draw buffer parameters being
// validated -- confirm against AllocatePreDrawValidationResources().
struct GpuAssistedPreDrawResources {
    VkDescriptorPool desc_pool;  // Descriptor pool handle kept with desc_set.
    VkDescriptorSet desc_set;    // Descriptor set handle.
    VkBuffer buffer;             // Buffer handle.
    VkDeviceSize offset;         // Offset value (presumably into `buffer` -- confirm).
    uint32_t stride;             // Stride value.
    VkDeviceSize buf_size;       // Size value (presumably of `buffer` -- confirm).
};
44 
45 struct GpuAssistedBufferInfo {
46     GpuAssistedDeviceMemoryBlock output_mem_block;
47     GpuAssistedDeviceMemoryBlock di_input_mem_block;   // Descriptor Indexing input
48     GpuAssistedDeviceMemoryBlock bda_input_mem_block;  // Buffer Device Address input
49     GpuAssistedPreDrawResources pre_draw_resources;
50     VkDescriptorSet desc_set;
51     VkDescriptorPool desc_pool;
52     VkPipelineBindPoint pipeline_bind_point;
53     CMD_TYPE cmd_type;
GpuAssistedBufferInfoGpuAssistedBufferInfo54     GpuAssistedBufferInfo(GpuAssistedDeviceMemoryBlock output_mem_block, GpuAssistedDeviceMemoryBlock di_input_mem_block,
55                           GpuAssistedDeviceMemoryBlock bda_input_mem_block, GpuAssistedPreDrawResources pre_draw_resources,
56                           VkDescriptorSet desc_set, VkDescriptorPool desc_pool, VkPipelineBindPoint pipeline_bind_point,
57                           CMD_TYPE cmd_type)
58         : output_mem_block(output_mem_block),
59           di_input_mem_block(di_input_mem_block),
60           bda_input_mem_block(bda_input_mem_block),
61           pre_draw_resources(pre_draw_resources),
62           desc_set(desc_set),
63           desc_pool(desc_pool),
64           pipeline_bind_point(pipeline_bind_point),
65           cmd_type(cmd_type){};
66 };
67 
// Record stored per entry of GpuAssisted::shader_map (keyed there by a uint32_t id).
struct GpuAssistedShaderTracker {
    VkPipeline pipeline;            // Pipeline handle.
    VkShaderModule shader_module;   // Shader module handle.
    std::vector<unsigned int> pgm;  // Program words (presumably the shader's SPIR-V -- confirm).
};
73 
// Set of VUID strings, one per kind of GPU-detected error. Every entry defaults to kVUIDUndefined.
// GpuAssisted::GetGpuVuid(CMD_TYPE) returns a reference to one of these.
struct GpuVuid {
    const char* uniform_access_oob = kVUIDUndefined;
    const char* storage_access_oob = kVUIDUndefined;
    const char* count_exceeds_bufsize_1 = kVUIDUndefined;
    const char* count_exceeds_bufsize = kVUIDUndefined;
    const char* count_exceeds_device_limit = kVUIDUndefined;
    const char* first_instance_not_zero = kVUIDUndefined;
};
82 
// Handles associated with validating one acceleration structure build. All handles default to
// VK_NULL_HANDLE. Instances are stored in CMD_BUFFER_STATE_GPUAV::as_validation_buffers.
struct GpuAssistedAccelerationStructureBuildValidationBufferInfo {
    // The acceleration structure that is being built.
    VkAccelerationStructureNV acceleration_structure = VK_NULL_HANDLE;

    // The descriptor pool and descriptor set being used to validate a given build.
    VkDescriptorPool descriptor_pool = VK_NULL_HANDLE;
    VkDescriptorSet descriptor_set = VK_NULL_HANDLE;

    // The storage buffer used by the validating compute shader, which contains info about
    // the valid handles and which is written to in order to report any invalid handles found.
    VkBuffer validation_buffer = VK_NULL_HANDLE;
    VmaAllocation validation_buffer_allocation = VK_NULL_HANDLE;
};
96 
// State held once per GpuAssisted object (see GpuAssisted::acceleration_structure_validation_state)
// for acceleration structure build validation. Everything starts out null/zero/false.
struct GpuAssistedAccelerationStructureBuildValidationState {
    // Starts false. NOTE(review): presumably set once the members below have been created by
    // CreateAccelerationStructureBuildValidationState() -- confirm in the implementation file.
    bool initialized = false;

    // Pipeline and pipeline layout handles (presumably for the validating compute shader -- confirm).
    VkPipeline pipeline = VK_NULL_HANDLE;
    VkPipelineLayout pipeline_layout = VK_NULL_HANDLE;

    // A "replacement" acceleration structure, its VMA allocation, and a 64-bit handle value for it.
    VkAccelerationStructureNV replacement_as = VK_NULL_HANDLE;
    VmaAllocation replacement_as_allocation = VK_NULL_HANDLE;
    uint64_t replacement_as_handle = 0;

};
108 
// State held once per GpuAssisted object (see GpuAssisted::pre_draw_validation_state) for
// pre-draw validation. All handles start as VK_NULL_HANDLE and the map starts empty.
struct GpuAssistedPreDrawValidationState {
    // Starts false. NOTE(review): presumably flipped once the shader module and layouts below
    // have been created -- confirm in the implementation file.
    bool globals_created = false;
    VkShaderModule validation_shader_module = VK_NULL_HANDLE;
    VkDescriptorSetLayout validation_ds_layout = VK_NULL_HANDLE;
    VkPipelineLayout validation_pipeline_layout = VK_NULL_HANDLE;
    // Pipeline handle stored per render pass handle.
    layer_data::unordered_map <VkRenderPass, VkPipeline> renderpass_to_pipeline;
};
116 
// Parameter bundle passed (by const pointer, optional) to GpuAssisted::AllocateValidationResources()
// and GpuAssisted::AllocatePreDrawValidationResources().
// No member has a default initializer; value-initialize the struct before filling it in.
// NOTE(review): the field names mirror the arguments of the vkCmdDraw*Indirect[Count] commands --
// confirm how callers populate them.
struct GpuAssistedCmdDrawIndirectState {
    VkBuffer buffer;
    VkDeviceSize offset;
    uint32_t drawCount;
    uint32_t stride;
    VkBuffer count_buffer;
    VkDeviceSize count_buffer_offset;
};
125 
// Command buffer state extended with the GPU-assisted validation bookkeeping lists.
// GpuAssisted::GetCBState() and GpuAssisted::GetBufferInfo() downcast CMD_BUFFER_STATE objects to
// this type.
class CMD_BUFFER_STATE_GPUAV : public CMD_BUFFER_STATE {
  public:
    // Buffer-info records attached to this command buffer.
    std::vector<GpuAssistedBufferInfo> gpuav_buffer_list;
    // Acceleration structure build validation records attached to this command buffer.
    std::vector<GpuAssistedAccelerationStructureBuildValidationBufferInfo> as_validation_buffers;

    // Defined out of line; takes the owning GpuAssisted object, the command buffer handle, its
    // allocate info and its command pool state.
    CMD_BUFFER_STATE_GPUAV(GpuAssisted* ga, VkCommandBuffer cb, const VkCommandBufferAllocateInfo* pCreateInfo,
                           const COMMAND_POOL_STATE* pool);

    // Final override of the base class Reset(); defined out of line.
    void Reset() final;
};
136 
137 class GpuAssisted : public ValidationStateTracker {
138     VkPhysicalDeviceFeatures supported_features;
139     VkBool32 shaderInt64;
140     uint32_t unique_shader_module_id = 0;
141     uint32_t output_buffer_size;
142     bool buffer_oob_enabled;
143     bool validate_draw_indirect;
144     std::map<VkDeviceAddress, VkDeviceSize> buffer_map;
145     GpuAssistedAccelerationStructureBuildValidationState acceleration_structure_validation_state;
146     GpuAssistedPreDrawValidationState pre_draw_validation_state;
147 
148     void PreRecordCommandBuffer(VkCommandBuffer command_buffer);
149     bool CommandBufferNeedsProcessing(VkCommandBuffer command_buffer);
150     void ProcessCommandBuffer(VkQueue queue, VkCommandBuffer command_buffer);
151 
152   public:
GpuAssisted()153     GpuAssisted() { container_type = LayerObjectTypeGpuAssisted; }
154 
155     bool aborted = false;
156     bool descriptor_indexing = false;
157     VkDevice device;
158     VkPhysicalDevice physicalDevice;
159     uint32_t adjusted_max_desc_sets;
160     uint32_t desc_set_bind_index;
161     VkDescriptorSetLayout debug_desc_layout = VK_NULL_HANDLE;
162     VkDescriptorSetLayout dummy_desc_layout = VK_NULL_HANDLE;
163     std::unique_ptr<UtilDescriptorSetManager> desc_set_manager;
164     layer_data::unordered_map<uint32_t, GpuAssistedShaderTracker> shader_map;
165     PFN_vkSetDeviceLoaderData vkSetDeviceLoaderData;
166     VmaAllocator vmaAllocator = {};
167     std::map<VkQueue, UtilQueueBarrierCommandInfo> queue_barrier_command_infos;
168   public:
169     template <typename T>
170     void ReportSetupProblem(T object, const char* const specific_message) const;
171     bool CheckForDescriptorIndexing(DeviceFeatures enabled_features) const;
172     void PreCallRecordCreateDevice(VkPhysicalDevice gpu, const VkDeviceCreateInfo* pCreateInfo,
173                                    const VkAllocationCallbacks* pAllocator, VkDevice* pDevice, void* modified_create_info) override;
174     void PostCallRecordCreateDevice(VkPhysicalDevice gpu, const VkDeviceCreateInfo* pCreateInfo,
175                                     const VkAllocationCallbacks* pAllocator, VkDevice* pDevice, VkResult result) override;
176     void PostCallRecordGetBufferDeviceAddress(VkDevice device, const VkBufferDeviceAddressInfo* pInfo,
177                                               VkDeviceAddress address) override;
178     void PostCallRecordGetBufferDeviceAddressKHR(VkDevice device, const VkBufferDeviceAddressInfo* pInfo,
179                                                  VkDeviceAddress address) override;
180     void PostCallRecordGetBufferDeviceAddressEXT(VkDevice device, const VkBufferDeviceAddressInfo* pInfo,
181                                                  VkDeviceAddress address) override;
182     void PreCallRecordDestroyBuffer(VkDevice device, VkBuffer buffer, const VkAllocationCallbacks* pAllocator) override;
183     void PreCallRecordDestroyDevice(VkDevice device, const VkAllocationCallbacks* pAllocator) override;
184     void PostCallRecordBindAccelerationStructureMemoryNV(VkDevice device, uint32_t bindInfoCount,
185                                                          const VkBindAccelerationStructureMemoryInfoNV* pBindInfos,
186                                                          VkResult result) override;
187     void PreCallRecordCreatePipelineLayout(VkDevice device, const VkPipelineLayoutCreateInfo* pCreateInfo,
188                                            const VkAllocationCallbacks* pAllocator, VkPipelineLayout* pPipelineLayout,
189                                            void* cpl_state_data) override;
190     void PostCallRecordCreatePipelineLayout(VkDevice device, const VkPipelineLayoutCreateInfo* pCreateInfo,
191                                             const VkAllocationCallbacks* pAllocator, VkPipelineLayout* pPipelineLayout,
192                                             VkResult result) override;
193     bool PreCallValidateCmdWaitEvents(VkCommandBuffer commandBuffer, uint32_t eventCount, const VkEvent* pEvents,
194                                       VkPipelineStageFlags srcStageMask, VkPipelineStageFlags dstStageMask,
195                                       uint32_t memoryBarrierCount, const VkMemoryBarrier* pMemoryBarriers,
196                                       uint32_t bufferMemoryBarrierCount, const VkBufferMemoryBarrier* pBufferMemoryBarriers,
197                                       uint32_t imageMemoryBarrierCount, const VkImageMemoryBarrier* pImageMemoryBarriers) const override;
198     bool PreCallValidateCmdWaitEvents2KHR(VkCommandBuffer commandBuffer, uint32_t eventCount, const VkEvent* pEvents,
199                                           const VkDependencyInfoKHR* pDependencyInfos) const override;
200     void PreCallRecordCreateBuffer(VkDevice device, const VkBufferCreateInfo* pCreateInfo, const VkAllocationCallbacks* pAllocator,
201                                    VkBuffer* pBuffer, void* cb_state_data) override;
202     void PostCallRecordCreateBuffer(VkDevice device, const VkBufferCreateInfo* pCreateInfo, const VkAllocationCallbacks* pAllocator,
203                                     VkBuffer* pBuffer, VkResult result) override;
204     void CreateAccelerationStructureBuildValidationState(GpuAssisted* device_GpuAssisted);
205     void DestroyAccelerationStructureBuildValidationState();
206     void PreCallRecordCmdBuildAccelerationStructureNV(VkCommandBuffer commandBuffer, const VkAccelerationStructureInfoNV* pInfo,
207                                                       VkBuffer instanceData, VkDeviceSize instanceOffset, VkBool32 update,
208                                                       VkAccelerationStructureNV dst, VkAccelerationStructureNV src,
209                                                       VkBuffer scratch, VkDeviceSize scratchOffset) override;
210     void ProcessAccelerationStructureBuildValidationBuffer(VkQueue queue, CMD_BUFFER_STATE_GPUAV* cb_node);
211     void PreCallRecordCreateGraphicsPipelines(VkDevice device, VkPipelineCache pipelineCache, uint32_t count,
212                                               const VkGraphicsPipelineCreateInfo* pCreateInfos,
213                                               const VkAllocationCallbacks* pAllocator, VkPipeline* pPipelines,
214                                               void* cgpl_state_data) override;
215     void PreCallRecordCreateComputePipelines(VkDevice device, VkPipelineCache pipelineCache, uint32_t count,
216                                              const VkComputePipelineCreateInfo* pCreateInfos,
217                                              const VkAllocationCallbacks* pAllocator, VkPipeline* pPipelines,
218                                              void* ccpl_state_data) override;
219     void PreCallRecordCreateRayTracingPipelinesNV(VkDevice device, VkPipelineCache pipelineCache, uint32_t count,
220                                                   const VkRayTracingPipelineCreateInfoNV* pCreateInfos,
221                                                   const VkAllocationCallbacks* pAllocator, VkPipeline* pPipelines,
222                                                   void* crtpl_state_data) override;
223     void PreCallRecordCreateRayTracingPipelinesKHR(VkDevice device, VkDeferredOperationKHR deferredOperation,
224                                                    VkPipelineCache pipelineCache, uint32_t count,
225                                                    const VkRayTracingPipelineCreateInfoKHR* pCreateInfos,
226                                                    const VkAllocationCallbacks* pAllocator, VkPipeline* pPipelines,
227                                                    void* crtpl_state_data) override;
228     void PostCallRecordCreateGraphicsPipelines(VkDevice device, VkPipelineCache pipelineCache, uint32_t count,
229                                                const VkGraphicsPipelineCreateInfo* pCreateInfos,
230                                                const VkAllocationCallbacks* pAllocator, VkPipeline* pPipelines, VkResult result,
231                                                void* cgpl_state_data) override;
232     void PostCallRecordCreateComputePipelines(VkDevice device, VkPipelineCache pipelineCache, uint32_t count,
233                                               const VkComputePipelineCreateInfo* pCreateInfos,
234                                               const VkAllocationCallbacks* pAllocator, VkPipeline* pPipelines, VkResult result,
235                                               void* ccpl_state_data) override;
236     void PostCallRecordCreateRayTracingPipelinesNV(VkDevice device, VkPipelineCache pipelineCache, uint32_t count,
237                                                    const VkRayTracingPipelineCreateInfoNV* pCreateInfos,
238                                                    const VkAllocationCallbacks* pAllocator, VkPipeline* pPipelines, VkResult result,
239                                                    void* crtpl_state_data) override;
240     void PostCallRecordCreateRayTracingPipelinesKHR(VkDevice device, VkDeferredOperationKHR deferredOperation,
241                                                     VkPipelineCache pipelineCache, uint32_t count,
242                                                     const VkRayTracingPipelineCreateInfoKHR* pCreateInfos,
243                                                     const VkAllocationCallbacks* pAllocator, VkPipeline* pPipelines,
244                                                     VkResult result, void* crtpl_state_data) override;
245     void PreCallRecordDestroyPipeline(VkDevice device, VkPipeline pipeline, const VkAllocationCallbacks* pAllocator) override;
246     void PreCallRecordDestroyRenderPass(VkDevice device, VkRenderPass renderPass, const VkAllocationCallbacks *pAllocator) override;
247     bool InstrumentShader(const VkShaderModuleCreateInfo* pCreateInfo, std::vector<unsigned int>& new_pgm,
248                           uint32_t* unique_shader_id);
249     void PreCallRecordCreateShaderModule(VkDevice device, const VkShaderModuleCreateInfo* pCreateInfo,
250                                          const VkAllocationCallbacks* pAllocator, VkShaderModule* pShaderModule,
251                                          void* csm_state_data) override;
252     void AnalyzeAndGenerateMessages(VkCommandBuffer command_buffer, VkQueue queue, GpuAssistedBufferInfo &buffer_info,
253         uint32_t operation_index, uint32_t* const debug_output_buffer);
254 
255     void SetDescriptorInitialized(uint32_t* pData, uint32_t index, const cvdescriptorset::Descriptor* descriptor);
256     void UpdateInstrumentationBuffer(CMD_BUFFER_STATE_GPUAV* cb_node);
257     const GpuVuid& GetGpuVuid(CMD_TYPE cmd_type) const;
258     void PreCallRecordQueueSubmit(VkQueue queue, uint32_t submitCount, const VkSubmitInfo* pSubmits, VkFence fence) override;
259     void PostCallRecordQueueSubmit(VkQueue queue, uint32_t submitCount, const VkSubmitInfo* pSubmits, VkFence fence,
260                                    VkResult result) override;
261     void PreCallRecordQueueSubmit2KHR(VkQueue queue, uint32_t submitCount, const VkSubmitInfo2KHR* pSubmits,
262                                       VkFence fence) override;
263     void PostCallRecordQueueSubmit2KHR(VkQueue queue, uint32_t submitCount, const VkSubmitInfo2KHR* pSubmits, VkFence fence,
264                                        VkResult result) override;
265     void PreCallRecordCmdDraw(VkCommandBuffer commandBuffer, uint32_t vertexCount, uint32_t instanceCount, uint32_t firstVertex,
266                               uint32_t firstInstance) override;
267     void PreCallRecordCmdDrawMultiEXT(VkCommandBuffer commandBuffer, uint32_t drawCount, const VkMultiDrawInfoEXT* pVertexInfo,
268                                       uint32_t instanceCount, uint32_t firstInstance, uint32_t stride) override;
269     void PreCallRecordCmdDrawIndexed(VkCommandBuffer commandBuffer, uint32_t indexCount, uint32_t instanceCount,
270                                      uint32_t firstIndex, int32_t vertexOffset, uint32_t firstInstance) override;
271     void PreCallRecordCmdDrawMultiIndexedEXT(VkCommandBuffer commandBuffer, uint32_t drawCount,
272                                              const VkMultiDrawIndexedInfoEXT* pIndexInfo, uint32_t instanceCount,
273                                              uint32_t firstInstance, uint32_t stride, const int32_t* pVertexOffset) override;
274     void PreCallRecordCmdDrawIndirect(VkCommandBuffer commandBuffer, VkBuffer buffer, VkDeviceSize offset, uint32_t count,
275                                       uint32_t stride) override;
276     void PreCallRecordCmdDrawIndexedIndirect(VkCommandBuffer commandBuffer, VkBuffer buffer, VkDeviceSize offset, uint32_t count,
277                                              uint32_t stride) override;
278     void PreCallRecordCmdDrawIndirectCountKHR(VkCommandBuffer commandBuffer, VkBuffer buffer, VkDeviceSize offset,
279                                               VkBuffer countBuffer, VkDeviceSize countBufferOffset, uint32_t maxDrawCount,
280                                               uint32_t stride) override;
281     void PreCallRecordCmdDrawIndirectCount(VkCommandBuffer commandBuffer, VkBuffer buffer, VkDeviceSize offset,
282                                            VkBuffer countBuffer, VkDeviceSize countBufferOffset, uint32_t maxDrawCount,
283                                            uint32_t stride) override;
284     void PreCallRecordCmdDrawIndirectByteCountEXT(VkCommandBuffer commandBuffer, uint32_t instanceCount, uint32_t firstInstance,
285                                                   VkBuffer counterBuffer, VkDeviceSize counterBufferOffset, uint32_t counterOffset,
286                                                   uint32_t vertexStride) override;
287     void PreCallRecordCmdDrawIndexedIndirectCountKHR(VkCommandBuffer commandBuffer, VkBuffer buffer, VkDeviceSize offset,
288                                                      VkBuffer countBuffer, VkDeviceSize countBufferOffset, uint32_t maxDrawCount,
289                                                      uint32_t stride) override;
290     void PreCallRecordCmdDrawIndexedIndirectCount(VkCommandBuffer commandBuffer, VkBuffer buffer, VkDeviceSize offset,
291                                                   VkBuffer countBuffer, VkDeviceSize countBufferOffset, uint32_t maxDrawCount,
292                                                   uint32_t stride) override;
293     void PreCallRecordCmdDrawMeshTasksNV(VkCommandBuffer commandBuffer, uint32_t taskCount, uint32_t firstTask) override;
294     void PreCallRecordCmdDrawMeshTasksIndirectNV(VkCommandBuffer commandBuffer, VkBuffer buffer, VkDeviceSize offset,
295                                                  uint32_t drawCount, uint32_t stride) override;
296     void PreCallRecordCmdDrawMeshTasksIndirectCountNV(VkCommandBuffer commandBuffer, VkBuffer buffer, VkDeviceSize offset,
297                                                       VkBuffer countBuffer, VkDeviceSize countBufferOffset, uint32_t maxDrawCount,
298                                                       uint32_t stride) override;
299     void PreCallRecordCmdDispatch(VkCommandBuffer commandBuffer, uint32_t x, uint32_t y, uint32_t z) override;
300     void PreCallRecordCmdDispatchIndirect(VkCommandBuffer commandBuffer, VkBuffer buffer, VkDeviceSize offset) override;
301     void PreCallRecordCmdDispatchBase(VkCommandBuffer commandBuffer, uint32_t baseGroupX, uint32_t baseGroupY, uint32_t baseGroupZ,
302                                       uint32_t groupCountX, uint32_t groupCountY, uint32_t groupCountZ) override;
303     void PreCallRecordCmdDispatchBaseKHR(VkCommandBuffer commandBuffer, uint32_t baseGroupX, uint32_t baseGroupY,
304                                          uint32_t baseGroupZ, uint32_t groupCountX, uint32_t groupCountY,
305                                          uint32_t groupCountZ) override;
306     void PreCallRecordCmdTraceRaysNV(VkCommandBuffer commandBuffer, VkBuffer raygenShaderBindingTableBuffer,
307                                      VkDeviceSize raygenShaderBindingOffset, VkBuffer missShaderBindingTableBuffer,
308                                      VkDeviceSize missShaderBindingOffset, VkDeviceSize missShaderBindingStride,
309                                      VkBuffer hitShaderBindingTableBuffer, VkDeviceSize hitShaderBindingOffset,
310                                      VkDeviceSize hitShaderBindingStride, VkBuffer callableShaderBindingTableBuffer,
311                                      VkDeviceSize callableShaderBindingOffset, VkDeviceSize callableShaderBindingStride,
312                                      uint32_t width, uint32_t height, uint32_t depth) override;
313     void PostCallRecordCmdTraceRaysNV(VkCommandBuffer commandBuffer, VkBuffer raygenShaderBindingTableBuffer,
314                                       VkDeviceSize raygenShaderBindingOffset, VkBuffer missShaderBindingTableBuffer,
315                                       VkDeviceSize missShaderBindingOffset, VkDeviceSize missShaderBindingStride,
316                                       VkBuffer hitShaderBindingTableBuffer, VkDeviceSize hitShaderBindingOffset,
317                                       VkDeviceSize hitShaderBindingStride, VkBuffer callableShaderBindingTableBuffer,
318                                       VkDeviceSize callableShaderBindingOffset, VkDeviceSize callableShaderBindingStride,
319                                       uint32_t width, uint32_t height, uint32_t depth) override;
320     void PreCallRecordCmdTraceRaysKHR(VkCommandBuffer commandBuffer,
321                                       const VkStridedDeviceAddressRegionKHR* pRaygenShaderBindingTable,
322                                       const VkStridedDeviceAddressRegionKHR* pMissShaderBindingTable,
323                                       const VkStridedDeviceAddressRegionKHR* pHitShaderBindingTable,
324                                       const VkStridedDeviceAddressRegionKHR* pCallableShaderBindingTable, uint32_t width,
325                                       uint32_t height, uint32_t depth) override;
326     void PostCallRecordCmdTraceRaysKHR(VkCommandBuffer commandBuffer,
327                                        const VkStridedDeviceAddressRegionKHR* pRaygenShaderBindingTable,
328                                        const VkStridedDeviceAddressRegionKHR* pMissShaderBindingTable,
329                                        const VkStridedDeviceAddressRegionKHR* pHitShaderBindingTable,
330                                        const VkStridedDeviceAddressRegionKHR* pCallableShaderBindingTable, uint32_t width,
331                                        uint32_t height, uint32_t depth) override;
332     void PreCallRecordCmdTraceRaysIndirectKHR(VkCommandBuffer commandBuffer,
333                                               const VkStridedDeviceAddressRegionKHR* pRaygenShaderBindingTable,
334                                               const VkStridedDeviceAddressRegionKHR* pMissShaderBindingTable,
335                                               const VkStridedDeviceAddressRegionKHR* pHitShaderBindingTable,
336                                               const VkStridedDeviceAddressRegionKHR* pCallableShaderBindingTable,
337                                               VkDeviceAddress indirectDeviceAddress) override;
338     void PostCallRecordCmdTraceRaysIndirectKHR(VkCommandBuffer commandBuffer,
339                                                const VkStridedDeviceAddressRegionKHR* pRaygenShaderBindingTable,
340                                                const VkStridedDeviceAddressRegionKHR* pMissShaderBindingTable,
341                                                const VkStridedDeviceAddressRegionKHR* pHitShaderBindingTable,
342                                                const VkStridedDeviceAddressRegionKHR* pCallableShaderBindingTable,
343                                                VkDeviceAddress indirectDeviceAddress) override;
344     void AllocateValidationResources(const VkCommandBuffer cmd_buffer, const VkPipelineBindPoint bind_point, CMD_TYPE cmd, const GpuAssistedCmdDrawIndirectState *cdic_state = nullptr);
345     void AllocatePreDrawValidationResources(GpuAssistedDeviceMemoryBlock output_block, GpuAssistedPreDrawResources& resources,
346                                             const LAST_BOUND_STATE& state, VkPipeline *pPipeline, const GpuAssistedCmdDrawIndirectState *cdic_state);
347     void PostCallRecordGetPhysicalDeviceProperties(VkPhysicalDevice physicalDevice,
348                                                    VkPhysicalDeviceProperties* pPhysicalDeviceProperties) override;
349     void PostCallRecordGetPhysicalDeviceProperties2(VkPhysicalDevice physicalDevice,
350                                                     VkPhysicalDeviceProperties2* pPhysicalDeviceProperties2) override;
351 
GetCBState(VkCommandBuffer commandBuffer)352     std::shared_ptr<CMD_BUFFER_STATE_GPUAV> GetCBState(VkCommandBuffer commandBuffer) {
353         return std::static_pointer_cast<CMD_BUFFER_STATE_GPUAV>(Get<CMD_BUFFER_STATE>(commandBuffer));
354     }
GetCBState(VkCommandBuffer commandBuffer)355     const std::shared_ptr<const CMD_BUFFER_STATE_GPUAV> GetCBState(VkCommandBuffer commandBuffer) const {
356         return std::static_pointer_cast<const CMD_BUFFER_STATE_GPUAV>(Get<CMD_BUFFER_STATE>(commandBuffer));
357     }
GetShaderModuleState(VkShaderModule shader_module)358     std::shared_ptr<SHADER_MODULE_STATE> GetShaderModuleState(VkShaderModule shader_module) {
359         return Get<SHADER_MODULE_STATE>(shader_module);
360     }
GetShaderModuleState(VkShaderModule shader_module)361     std::shared_ptr<const SHADER_MODULE_STATE> GetShaderModuleState(VkShaderModule shader_module) const {
362         return Get<SHADER_MODULE_STATE>(shader_module);
363     }
GetPipelineState(VkPipeline pipeline)364     std::shared_ptr<const PIPELINE_STATE> GetPipelineState(VkPipeline pipeline) const { return Get<PIPELINE_STATE>(pipeline); }
GetPipelineState(VkPipeline pipeline)365     std::shared_ptr<PIPELINE_STATE> GetPipelineState(VkPipeline pipeline) { return Get<PIPELINE_STATE>(pipeline); }
366 
GetBufferInfo(const CMD_BUFFER_STATE * cb_node)367     const std::vector<GpuAssistedBufferInfo>& GetBufferInfo(const CMD_BUFFER_STATE* cb_node) const {
368         assert(cb_node);
369         return static_cast<const CMD_BUFFER_STATE_GPUAV*>(cb_node)->gpuav_buffer_list;
370     }
371 
GetBufferInfo(CMD_BUFFER_STATE * cb_node)372     std::vector<GpuAssistedBufferInfo>& GetBufferInfo(CMD_BUFFER_STATE* cb_node) {
373         assert(cb_node);
374         return static_cast<CMD_BUFFER_STATE_GPUAV*>(cb_node)->gpuav_buffer_list;
375     }
376 
377     std::shared_ptr<CMD_BUFFER_STATE> CreateCmdBufferState(VkCommandBuffer cb, const VkCommandBufferAllocateInfo* create_info,
378                                                            const COMMAND_POOL_STATE* pool) final;
379 
380     void DestroyBuffer(GpuAssistedBufferInfo& buffer_info);
381     void DestroyBuffer(GpuAssistedAccelerationStructureBuildValidationBufferInfo& buffer_info);
382 };
383