1 //===- VulkanRuntime.cpp - MLIR Vulkan runtime ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides a library for running a module on a Vulkan device.
10 // Implements a Vulkan runtime.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "VulkanRuntime.h"
15
16 #include <chrono>
17 #include <cstring>
18 // TODO: It's generally bad to access stdout/stderr in a library.
19 // Figure out a better way for error reporting.
20 #include <iomanip>
21 #include <iostream>
22
emitVulkanError(const char * api,VkResult error)23 inline void emitVulkanError(const char *api, VkResult error) {
24 std::cerr << " failed with error code " << error << " when executing " << api;
25 }
26
// Reports a failing Vulkan call via emitVulkanError() and returns failure()
// from the enclosing function. Wrapped in do { } while (0) so the macro
// expands to a single statement and is safe to use in unbraced if/else
// contexts (the bare `if` form has a dangling-else hazard).
#define RETURN_ON_VULKAN_ERROR(result, api)                                    \
  do {                                                                         \
    if ((result) != VK_SUCCESS) {                                              \
      emitVulkanError(api, (result));                                          \
      return failure();                                                        \
    }                                                                          \
  } while (0)
32
33 using namespace mlir;
34
setNumWorkGroups(const NumWorkGroups & numberWorkGroups)35 void VulkanRuntime::setNumWorkGroups(const NumWorkGroups &numberWorkGroups) {
36 numWorkGroups = numberWorkGroups;
37 }
38
setResourceStorageClassBindingMap(const ResourceStorageClassBindingMap & stClassData)39 void VulkanRuntime::setResourceStorageClassBindingMap(
40 const ResourceStorageClassBindingMap &stClassData) {
41 resourceStorageClassData = stClassData;
42 }
43
setResourceData(const DescriptorSetIndex desIndex,const BindingIndex bindIndex,const VulkanHostMemoryBuffer & hostMemBuffer)44 void VulkanRuntime::setResourceData(
45 const DescriptorSetIndex desIndex, const BindingIndex bindIndex,
46 const VulkanHostMemoryBuffer &hostMemBuffer) {
47 resourceData[desIndex][bindIndex] = hostMemBuffer;
48 resourceStorageClassData[desIndex][bindIndex] =
49 SPIRVStorageClass::StorageBuffer;
50 }
51
setEntryPoint(const char * entryPointName)52 void VulkanRuntime::setEntryPoint(const char *entryPointName) {
53 entryPoint = entryPointName;
54 }
55
setResourceData(const ResourceData & resData)56 void VulkanRuntime::setResourceData(const ResourceData &resData) {
57 resourceData = resData;
58 }
59
setShaderModule(uint8_t * shader,uint32_t size)60 void VulkanRuntime::setShaderModule(uint8_t *shader, uint32_t size) {
61 binary = shader;
62 binarySize = size;
63 }
64
mapStorageClassToDescriptorType(SPIRVStorageClass storageClass,VkDescriptorType & descriptorType)65 LogicalResult VulkanRuntime::mapStorageClassToDescriptorType(
66 SPIRVStorageClass storageClass, VkDescriptorType &descriptorType) {
67 switch (storageClass) {
68 case SPIRVStorageClass::StorageBuffer:
69 descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
70 break;
71 case SPIRVStorageClass::Uniform:
72 descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
73 break;
74 }
75 return success();
76 }
77
mapStorageClassToBufferUsageFlag(SPIRVStorageClass storageClass,VkBufferUsageFlagBits & bufferUsage)78 LogicalResult VulkanRuntime::mapStorageClassToBufferUsageFlag(
79 SPIRVStorageClass storageClass, VkBufferUsageFlagBits &bufferUsage) {
80 switch (storageClass) {
81 case SPIRVStorageClass::StorageBuffer:
82 bufferUsage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
83 break;
84 case SPIRVStorageClass::Uniform:
85 bufferUsage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT;
86 break;
87 }
88 return success();
89 }
90
countDeviceMemorySize()91 LogicalResult VulkanRuntime::countDeviceMemorySize() {
92 for (const auto &resourceDataMapPair : resourceData) {
93 const auto &resourceDataMap = resourceDataMapPair.second;
94 for (const auto &resourceDataBindingPair : resourceDataMap) {
95 if (resourceDataBindingPair.second.size) {
96 memorySize += resourceDataBindingPair.second.size;
97 } else {
98 std::cerr << "expected buffer size greater than zero for resource data";
99 return failure();
100 }
101 }
102 }
103 return success();
104 }
105
initRuntime()106 LogicalResult VulkanRuntime::initRuntime() {
107 if (!resourceData.size()) {
108 std::cerr << "Vulkan runtime needs at least one resource";
109 return failure();
110 }
111 if (!binarySize || !binary) {
112 std::cerr << "binary shader size must be greater than zero";
113 return failure();
114 }
115 if (failed(countDeviceMemorySize())) {
116 return failure();
117 }
118 return success();
119 }
120
// Tears down every Vulkan object owned by the runtime, in reverse order of
// creation, then destroys the logical device and finally the instance.
LogicalResult VulkanRuntime::destroy() {
  // According to Vulkan spec:
  // "To ensure that no work is active on the device, vkDeviceWaitIdle can be
  // used to gate the destruction of the device. Prior to destroying a device,
  // an application is responsible for destroying/freeing any Vulkan objects
  // that were created using that device as the first parameter of the
  // corresponding vkCreate* or vkAllocate* command."
  RETURN_ON_VULKAN_ERROR(vkDeviceWaitIdle(device), "vkDeviceWaitIdle");

  // Free and destroy.
  vkFreeCommandBuffers(device, commandPool, commandBuffers.size(),
                       commandBuffers.data());
  vkDestroyQueryPool(device, queryPool, nullptr);
  vkDestroyCommandPool(device, commandPool, nullptr);
  vkFreeDescriptorSets(device, descriptorPool, descriptorSets.size(),
                       descriptorSets.data());
  vkDestroyDescriptorPool(device, descriptorPool, nullptr);
  vkDestroyPipeline(device, pipeline, nullptr);
  vkDestroyPipelineLayout(device, pipelineLayout, nullptr);
  for (auto &descriptorSetLayout : descriptorSetLayouts) {
    vkDestroyDescriptorSetLayout(device, descriptorSetLayout, nullptr);
  }
  vkDestroyShaderModule(device, shaderModule, nullptr);

  // For each descriptor set.
  for (auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
    auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
    // For each descriptor binding: release both the staging (host) and the
    // device-local memory/buffer pairs created in createMemoryBuffers().
    for (auto &memoryBuffer : deviceMemoryBuffers) {
      vkFreeMemory(device, memoryBuffer.deviceMemory, nullptr);
      vkFreeMemory(device, memoryBuffer.hostMemory, nullptr);
      vkDestroyBuffer(device, memoryBuffer.hostBuffer, nullptr);
      vkDestroyBuffer(device, memoryBuffer.deviceBuffer, nullptr);
    }
  }

  // The device must outlive everything created from it; the instance must
  // outlive the device.
  vkDestroyDevice(device, nullptr);
  vkDestroyInstance(instance, nullptr);
  return success();
}
161
// Runs the whole pipeline end to end: creates all Vulkan objects, uploads
// resource data, submits the compute dispatch, waits for completion, and
// prints timing information to stdout.
LogicalResult VulkanRuntime::run() {
  // Create logical device, shader module and memory buffers.
  if (failed(createInstance()) || failed(createDevice()) ||
      failed(createMemoryBuffers()) || failed(createShaderModule())) {
    return failure();
  }

  // Descriptor bindings divided into sets. Each descriptor binding
  // must have a layout binding attached into a descriptor set layout.
  // Each layout set must be binded into a pipeline layout.
  initDescriptorSetLayoutBindingMap();
  if (failed(createDescriptorSetLayout()) || failed(createPipelineLayout()) ||
      // Each descriptor set must be allocated from a descriptor pool.
      failed(createComputePipeline()) || failed(createDescriptorPool()) ||
      failed(allocateDescriptorSets()) || failed(setWriteDescriptors()) ||
      // Create command buffer.
      failed(createCommandPool()) || failed(createQueryPool()) ||
      failed(createComputeCommandBuffer())) {
    return failure();
  }

  // Get working queue.
  vkGetDeviceQueue(device, queueFamilyIndex, 0, &queue);

  // Upload host-side resource data into the device-local buffers.
  if (failed(copyResource(/*deviceToHost=*/false)))
    return failure();

  auto submitStart = std::chrono::high_resolution_clock::now();
  // Submit command buffer into the queue.
  if (failed(submitCommandBuffersToQueue()))
    return failure();
  auto submitEnd = std::chrono::high_resolution_clock::now();

  RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle");
  auto execEnd = std::chrono::high_resolution_clock::now();

  // Wall-clock split: submission overhead vs. time spent waiting on the GPU.
  auto submitDuration = std::chrono::duration_cast<std::chrono::microseconds>(
      submitEnd - submitStart);
  auto execDuration = std::chrono::duration_cast<std::chrono::microseconds>(
      execEnd - submitEnd);

  // If timestamp queries are available (queryPool was created), report the
  // GPU-side execution time measured by the two timestamps recorded around
  // the dispatch in createComputeCommandBuffer().
  if (queryPool != VK_NULL_HANDLE) {
    uint64_t timestamps[2];
    RETURN_ON_VULKAN_ERROR(
        vkGetQueryPoolResults(
            device, queryPool, /*firstQuery=*/0, /*queryCount=*/2,
            /*dataSize=*/sizeof(timestamps),
            /*pData=*/reinterpret_cast<void *>(timestamps),
            /*stride=*/sizeof(uint64_t),
            VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT),
        "vkGetQueryPoolResults");
    // Timestamps are in device ticks; timestampPeriod converts ticks to
    // nanoseconds, hence the division by 1000 to get microseconds.
    float microsec = (timestamps[1] - timestamps[0]) * timestampPeriod / 1000;
    std::cout << "Compute shader execution time: " << std::setprecision(3)
              << microsec << "us\n";
  }

  std::cout << "Command buffer submit time: " << submitDuration.count()
            << "us\nWait idle time: " << execDuration.count() << "us\n";

  return success();
}
223
createInstance()224 LogicalResult VulkanRuntime::createInstance() {
225 VkApplicationInfo applicationInfo = {};
226 applicationInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
227 applicationInfo.pNext = nullptr;
228 applicationInfo.pApplicationName = "MLIR Vulkan runtime";
229 applicationInfo.applicationVersion = 0;
230 applicationInfo.pEngineName = "mlir";
231 applicationInfo.engineVersion = 0;
232 applicationInfo.apiVersion = VK_MAKE_VERSION(1, 0, 0);
233
234 VkInstanceCreateInfo instanceCreateInfo = {};
235 instanceCreateInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
236 instanceCreateInfo.pNext = nullptr;
237 instanceCreateInfo.flags = 0;
238 instanceCreateInfo.pApplicationInfo = &applicationInfo;
239 instanceCreateInfo.enabledLayerCount = 0;
240 instanceCreateInfo.ppEnabledLayerNames = nullptr;
241 instanceCreateInfo.enabledExtensionCount = 0;
242 instanceCreateInfo.ppEnabledExtensionNames = nullptr;
243
244 RETURN_ON_VULKAN_ERROR(
245 vkCreateInstance(&instanceCreateInfo, nullptr, &instance),
246 "vkCreateInstance");
247 return success();
248 }
249
// Picks a physical device, selects a compute-capable queue family, creates
// the logical device, and caches the host-visible and device-local memory
// type indices used by later allocations.
LogicalResult VulkanRuntime::createDevice() {
  uint32_t physicalDeviceCount = 0;
  RETURN_ON_VULKAN_ERROR(
      vkEnumeratePhysicalDevices(instance, &physicalDeviceCount, nullptr),
      "vkEnumeratePhysicalDevices");

  std::vector<VkPhysicalDevice> physicalDevices(physicalDeviceCount);
  RETURN_ON_VULKAN_ERROR(vkEnumeratePhysicalDevices(instance,
                                                    &physicalDeviceCount,
                                                    physicalDevices.data()),
                         "vkEnumeratePhysicalDevices");

  // Fail when the system exposes no Vulkan-capable device at all.
  RETURN_ON_VULKAN_ERROR(physicalDeviceCount ? VK_SUCCESS : VK_INCOMPLETE,
                         "physicalDeviceCount");

  // TODO: find the best device.
  physicalDevice = physicalDevices.front();
  if (failed(getBestComputeQueue()))
    return failure();

  // Request a single queue from the chosen family.
  const float queuePriority = 1.0f;
  VkDeviceQueueCreateInfo deviceQueueCreateInfo = {};
  deviceQueueCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
  deviceQueueCreateInfo.pNext = nullptr;
  deviceQueueCreateInfo.flags = 0;
  deviceQueueCreateInfo.queueFamilyIndex = queueFamilyIndex;
  deviceQueueCreateInfo.queueCount = 1;
  deviceQueueCreateInfo.pQueuePriorities = &queuePriority;

  // Structure specifying parameters of a newly created device.
  VkDeviceCreateInfo deviceCreateInfo = {};
  deviceCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
  deviceCreateInfo.pNext = nullptr;
  deviceCreateInfo.flags = 0;
  deviceCreateInfo.queueCreateInfoCount = 1;
  deviceCreateInfo.pQueueCreateInfos = &deviceQueueCreateInfo;
  deviceCreateInfo.enabledLayerCount = 0;
  deviceCreateInfo.ppEnabledLayerNames = nullptr;
  deviceCreateInfo.enabledExtensionCount = 0;
  deviceCreateInfo.ppEnabledExtensionNames = nullptr;
  deviceCreateInfo.pEnabledFeatures = nullptr;

  RETURN_ON_VULKAN_ERROR(
      vkCreateDevice(physicalDevice, &deviceCreateInfo, nullptr, &device),
      "vkCreateDevice");

  VkPhysicalDeviceMemoryProperties properties = {};
  vkGetPhysicalDeviceMemoryProperties(physicalDevice, &properties);

  // Try to find memory type with following properties:
  // VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT bit specifies that memory allocated
  // with this type can be mapped for host access using vkMapMemory;
  // VK_MEMORY_PROPERTY_HOST_COHERENT_BIT bit specifies that the host cache
  // management commands vkFlushMappedMemoryRanges and
  // vkInvalidateMappedMemoryRanges are not needed to flush host writes to the
  // device or make device writes visible to the host, respectively.
  // The backing heap must also be large enough for the total memorySize
  // computed in countDeviceMemorySize().
  for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) {
    if ((VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT &
         properties.memoryTypes[i].propertyFlags) &&
        (VK_MEMORY_PROPERTY_HOST_COHERENT_BIT &
         properties.memoryTypes[i].propertyFlags) &&
        (memorySize <=
         properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) {
      hostMemoryTypeIndex = i;
      break;
    }
  }

  // Find memory type memory type with VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT to be
  // used on the device. This will allow better performance access for GPU with
  // on device memory.
  for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) {
    if ((VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT &
         properties.memoryTypes[i].propertyFlags) &&
        (memorySize <=
         properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) {
      deviceMemoryTypeIndex = i;
      break;
    }
  }

  // Both indices presumably start at VK_MAX_MEMORY_TYPES (set elsewhere in
  // the class); fail if either loop found no suitable memory type.
  RETURN_ON_VULKAN_ERROR((hostMemoryTypeIndex == VK_MAX_MEMORY_TYPES ||
                          deviceMemoryTypeIndex == VK_MAX_MEMORY_TYPES)
                             ? VK_INCOMPLETE
                             : VK_SUCCESS,
                         "invalid memoryTypeIndex");
  return success();
}
338
getBestComputeQueue()339 LogicalResult VulkanRuntime::getBestComputeQueue() {
340 uint32_t queueFamilyPropertiesCount = 0;
341 vkGetPhysicalDeviceQueueFamilyProperties(
342 physicalDevice, &queueFamilyPropertiesCount, nullptr);
343
344 std::vector<VkQueueFamilyProperties> familyProperties(
345 queueFamilyPropertiesCount);
346 vkGetPhysicalDeviceQueueFamilyProperties(
347 physicalDevice, &queueFamilyPropertiesCount, familyProperties.data());
348
349 // VK_QUEUE_COMPUTE_BIT specifies that queues in this queue family support
350 // compute operations. Try to find a compute-only queue first if possible.
351 for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
352 auto flags = familyProperties[i].queueFlags;
353 if ((flags & VK_QUEUE_COMPUTE_BIT) && !(flags & VK_QUEUE_GRAPHICS_BIT)) {
354 queueFamilyIndex = i;
355 queueFamilyProperties = familyProperties[i];
356 return success();
357 }
358 }
359
360 // Otherwise use a queue that can also support graphics.
361 for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
362 auto flags = familyProperties[i].queueFlags;
363 if ((flags & VK_QUEUE_COMPUTE_BIT)) {
364 queueFamilyIndex = i;
365 queueFamilyProperties = familyProperties[i];
366 return success();
367 }
368 }
369
370 std::cerr << "cannot find valid queue";
371 return failure();
372 }
373
// For every registered resource, allocates a host-visible staging
// memory/buffer pair and a device-local memory/buffer pair, copies the user's
// host data into the staging memory, and records the resulting buffers in
// deviceMemoryBufferMap keyed by descriptor set index.
LogicalResult VulkanRuntime::createMemoryBuffers() {
  // For each descriptor set.
  for (const auto &resourceDataMapPair : resourceData) {
    std::vector<VulkanDeviceMemoryBuffer> deviceMemoryBuffers;
    const auto descriptorSetIndex = resourceDataMapPair.first;
    const auto &resourceDataMap = resourceDataMapPair.second;

    // For each descriptor binding.
    for (const auto &resourceDataBindingPair : resourceDataMap) {
      // Create device memory buffer.
      VulkanDeviceMemoryBuffer memoryBuffer;
      memoryBuffer.bindingIndex = resourceDataBindingPair.first;
      VkDescriptorType descriptorType = {};
      VkBufferUsageFlagBits bufferUsage = {};

      // Check that descriptor set has storage class map.
      const auto resourceStorageClassMapIt =
          resourceStorageClassData.find(descriptorSetIndex);
      if (resourceStorageClassMapIt == resourceStorageClassData.end()) {
        std::cerr
            << "cannot find storage class for resource in descriptor set: "
            << descriptorSetIndex;
        return failure();
      }

      // Check that specific descriptor binding has storage class.
      const auto &resourceStorageClassMap = resourceStorageClassMapIt->second;
      const auto resourceStorageClassIt =
          resourceStorageClassMap.find(resourceDataBindingPair.first);
      if (resourceStorageClassIt == resourceStorageClassMap.end()) {
        std::cerr
            << "cannot find storage class for resource with descriptor index: "
            << resourceDataBindingPair.first;
        return failure();
      }

      // Map the storage class to both a descriptor type and a buffer usage
      // flag; either mapping failing means the class is unsupported.
      const auto resourceStorageClassBinding = resourceStorageClassIt->second;
      if (failed(mapStorageClassToDescriptorType(resourceStorageClassBinding,
                                                 descriptorType)) ||
          failed(mapStorageClassToBufferUsageFlag(resourceStorageClassBinding,
                                                  bufferUsage))) {
        std::cerr << "storage class for resource with descriptor binding: "
                  << resourceDataBindingPair.first
                  << " in the descriptor set: " << descriptorSetIndex
                  << " is not supported ";
        return failure();
      }

      // Set descriptor type for the specific device memory buffer.
      memoryBuffer.descriptorType = descriptorType;
      const auto bufferSize = resourceDataBindingPair.second.size;
      memoryBuffer.bufferSize = bufferSize;
      // Specify memory allocation info.
      VkMemoryAllocateInfo memoryAllocateInfo = {};
      memoryAllocateInfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
      memoryAllocateInfo.pNext = nullptr;
      memoryAllocateInfo.allocationSize = bufferSize;
      memoryAllocateInfo.memoryTypeIndex = hostMemoryTypeIndex;

      // Allocate device memory: first host-visible (staging), then the same
      // size from the device-local memory type.
      RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo,
                                              nullptr,
                                              &memoryBuffer.hostMemory),
                             "vkAllocateMemory");
      memoryAllocateInfo.memoryTypeIndex = deviceMemoryTypeIndex;
      RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo,
                                              nullptr,
                                              &memoryBuffer.deviceMemory),
                             "vkAllocateMemory");
      void *payload;
      RETURN_ON_VULKAN_ERROR(vkMapMemory(device, memoryBuffer.hostMemory, 0,
                                         bufferSize, 0,
                                         reinterpret_cast<void **>(&payload)),
                             "vkMapMemory");

      // Copy host memory into the mapped area.
      std::memcpy(payload, resourceDataBindingPair.second.ptr, bufferSize);
      vkUnmapMemory(device, memoryBuffer.hostMemory);

      // Both buffers carry TRANSFER_SRC and TRANSFER_DST so copyResource()
      // can move data in either direction.
      VkBufferCreateInfo bufferCreateInfo = {};
      bufferCreateInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
      bufferCreateInfo.pNext = nullptr;
      bufferCreateInfo.flags = 0;
      bufferCreateInfo.size = bufferSize;
      bufferCreateInfo.usage = bufferUsage | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
                               VK_BUFFER_USAGE_TRANSFER_SRC_BIT;
      bufferCreateInfo.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
      bufferCreateInfo.queueFamilyIndexCount = 1;
      bufferCreateInfo.pQueueFamilyIndices = &queueFamilyIndex;
      RETURN_ON_VULKAN_ERROR(vkCreateBuffer(device, &bufferCreateInfo, nullptr,
                                            &memoryBuffer.hostBuffer),
                             "vkCreateBuffer");
      RETURN_ON_VULKAN_ERROR(vkCreateBuffer(device, &bufferCreateInfo, nullptr,
                                            &memoryBuffer.deviceBuffer),
                             "vkCreateBuffer");

      // Bind buffer and device memory.
      RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device, memoryBuffer.hostBuffer,
                                                memoryBuffer.hostMemory, 0),
                             "vkBindBufferMemory");
      RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device,
                                                memoryBuffer.deviceBuffer,
                                                memoryBuffer.deviceMemory, 0),
                             "vkBindBufferMemory");

      // Update buffer info. Descriptors reference the device-local buffer;
      // the staging buffer is only used for transfers.
      memoryBuffer.bufferInfo.buffer = memoryBuffer.deviceBuffer;
      memoryBuffer.bufferInfo.offset = 0;
      memoryBuffer.bufferInfo.range = VK_WHOLE_SIZE;
      deviceMemoryBuffers.push_back(memoryBuffer);
    }

    // Associate device memory buffers with a descriptor set.
    deviceMemoryBufferMap[descriptorSetIndex] = deviceMemoryBuffers;
  }
  return success();
}
491
copyResource(bool deviceToHost)492 LogicalResult VulkanRuntime::copyResource(bool deviceToHost) {
493 VkCommandBufferAllocateInfo commandBufferAllocateInfo = {
494 VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO,
495 nullptr,
496 commandPool,
497 VK_COMMAND_BUFFER_LEVEL_PRIMARY,
498 1,
499 };
500 VkCommandBuffer commandBuffer;
501 RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device,
502 &commandBufferAllocateInfo,
503 &commandBuffer),
504 "vkAllocateCommandBuffers");
505
506 VkCommandBufferBeginInfo commandBufferBeginInfo = {
507 VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO,
508 nullptr,
509 0,
510 nullptr,
511 };
512 RETURN_ON_VULKAN_ERROR(
513 vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo),
514 "vkBeginCommandBuffer");
515
516 for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
517 std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings;
518 const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
519 for (const auto &memBuffer : deviceMemoryBuffers) {
520 VkBufferCopy copy = {0, 0, memBuffer.bufferSize};
521 if (deviceToHost)
522 vkCmdCopyBuffer(commandBuffer, memBuffer.deviceBuffer,
523 memBuffer.hostBuffer, 1, ©);
524 else
525 vkCmdCopyBuffer(commandBuffer, memBuffer.hostBuffer,
526 memBuffer.deviceBuffer, 1, ©);
527 }
528 }
529
530 RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer),
531 "vkEndCommandBuffer");
532 VkSubmitInfo submitInfo = {
533 VK_STRUCTURE_TYPE_SUBMIT_INFO,
534 nullptr,
535 0,
536 nullptr,
537 nullptr,
538 1,
539 &commandBuffer,
540 0,
541 nullptr,
542 };
543 submitInfo.pCommandBuffers = &commandBuffer;
544 RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, VK_NULL_HANDLE),
545 "vkQueueSubmit");
546 RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle");
547
548 vkFreeCommandBuffers(device, commandPool, 1, &commandBuffer);
549 return success();
550 }
551
createShaderModule()552 LogicalResult VulkanRuntime::createShaderModule() {
553 VkShaderModuleCreateInfo shaderModuleCreateInfo = {};
554 shaderModuleCreateInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
555 shaderModuleCreateInfo.pNext = nullptr;
556 shaderModuleCreateInfo.flags = 0;
557 // Set size in bytes.
558 shaderModuleCreateInfo.codeSize = binarySize;
559 // Set pointer to the binary shader.
560 shaderModuleCreateInfo.pCode = reinterpret_cast<uint32_t *>(binary);
561 RETURN_ON_VULKAN_ERROR(vkCreateShaderModule(device, &shaderModuleCreateInfo,
562 nullptr, &shaderModule),
563 "vkCreateShaderModule");
564 return success();
565 }
566
initDescriptorSetLayoutBindingMap()567 void VulkanRuntime::initDescriptorSetLayoutBindingMap() {
568 for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
569 std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings;
570 const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
571 const auto descriptorSetIndex = deviceMemoryBufferMapPair.first;
572
573 // Create a layout binding for each descriptor.
574 for (const auto &memBuffer : deviceMemoryBuffers) {
575 VkDescriptorSetLayoutBinding descriptorSetLayoutBinding = {};
576 descriptorSetLayoutBinding.binding = memBuffer.bindingIndex;
577 descriptorSetLayoutBinding.descriptorType = memBuffer.descriptorType;
578 descriptorSetLayoutBinding.descriptorCount = 1;
579 descriptorSetLayoutBinding.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
580 descriptorSetLayoutBinding.pImmutableSamplers = nullptr;
581 descriptorSetLayoutBindings.push_back(descriptorSetLayoutBinding);
582 }
583 descriptorSetLayoutBindingMap[descriptorSetIndex] =
584 descriptorSetLayoutBindings;
585 }
586 }
587
// Creates one VkDescriptorSetLayout per descriptor set from the layout
// bindings prepared by initDescriptorSetLayoutBindingMap(), and records
// (set index, descriptor count, descriptor type) for descriptor pool sizing.
LogicalResult VulkanRuntime::createDescriptorSetLayout() {
  for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
    const auto descriptorSetIndex = deviceMemoryBufferMapPair.first;
    const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
    // Each descriptor in a descriptor set must be the same type.
    // NOTE(review): this reads the type from the first buffer and assumes the
    // set is non-empty and homogeneous — confirm upstream guarantees both.
    VkDescriptorType descriptorType =
        deviceMemoryBuffers.front().descriptorType;
    const uint32_t descriptorSize = deviceMemoryBuffers.size();
    const auto descriptorSetLayoutBindingIt =
        descriptorSetLayoutBindingMap.find(descriptorSetIndex);

    if (descriptorSetLayoutBindingIt == descriptorSetLayoutBindingMap.end()) {
      std::cerr << "cannot find layout bindings for the set with number: "
                << descriptorSetIndex;
      return failure();
    }

    const auto &descriptorSetLayoutBindings =
        descriptorSetLayoutBindingIt->second;
    // Create descriptor set layout.
    VkDescriptorSetLayout descriptorSetLayout = {};
    VkDescriptorSetLayoutCreateInfo descriptorSetLayoutCreateInfo = {};

    descriptorSetLayoutCreateInfo.sType =
        VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
    descriptorSetLayoutCreateInfo.pNext = nullptr;
    descriptorSetLayoutCreateInfo.flags = 0;
    // Amount of descriptor bindings in a layout set.
    descriptorSetLayoutCreateInfo.bindingCount =
        descriptorSetLayoutBindings.size();
    descriptorSetLayoutCreateInfo.pBindings =
        descriptorSetLayoutBindings.data();
    RETURN_ON_VULKAN_ERROR(
        vkCreateDescriptorSetLayout(device, &descriptorSetLayoutCreateInfo,
                                    nullptr, &descriptorSetLayout),
        "vkCreateDescriptorSetLayout");

    // Keep layouts and per-set info in lockstep; later stages rely on the
    // matching order of descriptorSetLayouts and descriptorSetInfoPool.
    descriptorSetLayouts.push_back(descriptorSetLayout);
    descriptorSetInfoPool.push_back(
        {descriptorSetIndex, descriptorSize, descriptorType});
  }
  return success();
}
631
createPipelineLayout()632 LogicalResult VulkanRuntime::createPipelineLayout() {
633 // Associate descriptor sets with a pipeline layout.
634 VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo = {};
635 pipelineLayoutCreateInfo.sType =
636 VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
637 pipelineLayoutCreateInfo.pNext = nullptr;
638 pipelineLayoutCreateInfo.flags = 0;
639 pipelineLayoutCreateInfo.setLayoutCount = descriptorSetLayouts.size();
640 pipelineLayoutCreateInfo.pSetLayouts = descriptorSetLayouts.data();
641 pipelineLayoutCreateInfo.pushConstantRangeCount = 0;
642 pipelineLayoutCreateInfo.pPushConstantRanges = nullptr;
643 RETURN_ON_VULKAN_ERROR(vkCreatePipelineLayout(device,
644 &pipelineLayoutCreateInfo,
645 nullptr, &pipelineLayout),
646 "vkCreatePipelineLayout");
647 return success();
648 }
649
createComputePipeline()650 LogicalResult VulkanRuntime::createComputePipeline() {
651 VkPipelineShaderStageCreateInfo stageInfo = {};
652 stageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
653 stageInfo.pNext = nullptr;
654 stageInfo.flags = 0;
655 stageInfo.stage = VK_SHADER_STAGE_COMPUTE_BIT;
656 stageInfo.module = shaderModule;
657 // Set entry point.
658 stageInfo.pName = entryPoint;
659 stageInfo.pSpecializationInfo = nullptr;
660
661 VkComputePipelineCreateInfo computePipelineCreateInfo = {};
662 computePipelineCreateInfo.sType =
663 VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
664 computePipelineCreateInfo.pNext = nullptr;
665 computePipelineCreateInfo.flags = 0;
666 computePipelineCreateInfo.stage = stageInfo;
667 computePipelineCreateInfo.layout = pipelineLayout;
668 computePipelineCreateInfo.basePipelineHandle = nullptr;
669 computePipelineCreateInfo.basePipelineIndex = 0;
670 RETURN_ON_VULKAN_ERROR(vkCreateComputePipelines(device, nullptr, 1,
671 &computePipelineCreateInfo,
672 nullptr, &pipeline),
673 "vkCreateComputePipelines");
674 return success();
675 }
676
createDescriptorPool()677 LogicalResult VulkanRuntime::createDescriptorPool() {
678 std::vector<VkDescriptorPoolSize> descriptorPoolSizes;
679 for (const auto &descriptorSetInfo : descriptorSetInfoPool) {
680 // For each descriptor set populate descriptor pool size.
681 VkDescriptorPoolSize descriptorPoolSize = {};
682 descriptorPoolSize.type = descriptorSetInfo.descriptorType;
683 descriptorPoolSize.descriptorCount = descriptorSetInfo.descriptorSize;
684 descriptorPoolSizes.push_back(descriptorPoolSize);
685 }
686
687 VkDescriptorPoolCreateInfo descriptorPoolCreateInfo = {};
688 descriptorPoolCreateInfo.sType =
689 VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
690 descriptorPoolCreateInfo.pNext = nullptr;
691 descriptorPoolCreateInfo.flags = 0;
692 descriptorPoolCreateInfo.maxSets = descriptorPoolSizes.size();
693 descriptorPoolCreateInfo.poolSizeCount = descriptorPoolSizes.size();
694 descriptorPoolCreateInfo.pPoolSizes = descriptorPoolSizes.data();
695 RETURN_ON_VULKAN_ERROR(vkCreateDescriptorPool(device,
696 &descriptorPoolCreateInfo,
697 nullptr, &descriptorPool),
698 "vkCreateDescriptorPool");
699 return success();
700 }
701
allocateDescriptorSets()702 LogicalResult VulkanRuntime::allocateDescriptorSets() {
703 VkDescriptorSetAllocateInfo descriptorSetAllocateInfo = {};
704 // Size of descriptor sets and descriptor layout sets is the same.
705 descriptorSets.resize(descriptorSetLayouts.size());
706 descriptorSetAllocateInfo.sType =
707 VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
708 descriptorSetAllocateInfo.pNext = nullptr;
709 descriptorSetAllocateInfo.descriptorPool = descriptorPool;
710 descriptorSetAllocateInfo.descriptorSetCount = descriptorSetLayouts.size();
711 descriptorSetAllocateInfo.pSetLayouts = descriptorSetLayouts.data();
712 RETURN_ON_VULKAN_ERROR(vkAllocateDescriptorSets(device,
713 &descriptorSetAllocateInfo,
714 descriptorSets.data()),
715 "vkAllocateDescriptorSets");
716 return success();
717 }
718
setWriteDescriptors()719 LogicalResult VulkanRuntime::setWriteDescriptors() {
720 if (descriptorSets.size() != descriptorSetInfoPool.size()) {
721 std::cerr << "Each descriptor set must have descriptor set information";
722 return failure();
723 }
724 // For each descriptor set.
725 auto descriptorSetIt = descriptorSets.begin();
726 // Each descriptor set is associated with descriptor set info.
727 for (const auto &descriptorSetInfo : descriptorSetInfoPool) {
728 // For each device memory buffer in the descriptor set.
729 const auto &deviceMemoryBuffers =
730 deviceMemoryBufferMap[descriptorSetInfo.descriptorSet];
731 for (const auto &memoryBuffer : deviceMemoryBuffers) {
732 // Structure describing descriptor sets to write to.
733 VkWriteDescriptorSet wSet = {};
734 wSet.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
735 wSet.pNext = nullptr;
736 // Descriptor set.
737 wSet.dstSet = *descriptorSetIt;
738 wSet.dstBinding = memoryBuffer.bindingIndex;
739 wSet.dstArrayElement = 0;
740 wSet.descriptorCount = 1;
741 wSet.descriptorType = memoryBuffer.descriptorType;
742 wSet.pImageInfo = nullptr;
743 wSet.pBufferInfo = &memoryBuffer.bufferInfo;
744 wSet.pTexelBufferView = nullptr;
745 vkUpdateDescriptorSets(device, 1, &wSet, 0, nullptr);
746 }
747 // Increment descriptor set iterator.
748 ++descriptorSetIt;
749 }
750 return success();
751 }
752
createCommandPool()753 LogicalResult VulkanRuntime::createCommandPool() {
754 VkCommandPoolCreateInfo commandPoolCreateInfo = {};
755 commandPoolCreateInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
756 commandPoolCreateInfo.pNext = nullptr;
757 commandPoolCreateInfo.flags = 0;
758 commandPoolCreateInfo.queueFamilyIndex = queueFamilyIndex;
759 RETURN_ON_VULKAN_ERROR(vkCreateCommandPool(device, &commandPoolCreateInfo,
760 /*pAllocator=*/nullptr,
761 &commandPool),
762 "vkCreateCommandPool");
763 return success();
764 }
765
createQueryPool()766 LogicalResult VulkanRuntime::createQueryPool() {
767 // Return directly if timestamp query is not supported.
768 if (queueFamilyProperties.timestampValidBits == 0)
769 return success();
770
771 // Get timestamp period for this physical device.
772 VkPhysicalDeviceProperties deviceProperties = {};
773 vkGetPhysicalDeviceProperties(physicalDevice, &deviceProperties);
774 timestampPeriod = deviceProperties.limits.timestampPeriod;
775
776 // Create query pool.
777 VkQueryPoolCreateInfo queryPoolCreateInfo = {};
778 queryPoolCreateInfo.sType = VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO;
779 queryPoolCreateInfo.pNext = nullptr;
780 queryPoolCreateInfo.flags = 0;
781 queryPoolCreateInfo.queryType = VK_QUERY_TYPE_TIMESTAMP;
782 queryPoolCreateInfo.queryCount = 2;
783 queryPoolCreateInfo.pipelineStatistics = 0;
784 RETURN_ON_VULKAN_ERROR(vkCreateQueryPool(device, &queryPoolCreateInfo,
785 /*pAllocator=*/nullptr, &queryPool),
786 "vkCreateQueryPool");
787
788 return success();
789 }
790
// Records the main compute command buffer: binds the pipeline and descriptor
// sets, surrounds the dispatch with two timestamp writes (when supported),
// and appends the finished buffer to `commandBuffers` for later submission.
LogicalResult VulkanRuntime::createComputeCommandBuffer() {
  VkCommandBufferAllocateInfo commandBufferAllocateInfo = {};
  commandBufferAllocateInfo.sType =
      VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
  commandBufferAllocateInfo.pNext = nullptr;
  commandBufferAllocateInfo.commandPool = commandPool;
  commandBufferAllocateInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
  commandBufferAllocateInfo.commandBufferCount = 1;

  VkCommandBuffer commandBuffer;
  RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device,
                                                  &commandBufferAllocateInfo,
                                                  &commandBuffer),
                         "vkAllocateCommandBuffers");

  // The buffer is recorded once and submitted once.
  VkCommandBufferBeginInfo commandBufferBeginInfo = {};
  commandBufferBeginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
  commandBufferBeginInfo.pNext = nullptr;
  commandBufferBeginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
  commandBufferBeginInfo.pInheritanceInfo = nullptr;

  // Commands begin.
  RETURN_ON_VULKAN_ERROR(
      vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo),
      "vkBeginCommandBuffer");

  // Queries must be reset before first use within the same command buffer.
  if (queryPool != VK_NULL_HANDLE)
    vkCmdResetQueryPool(commandBuffer, queryPool, 0, 2);

  vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
  vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
                          pipelineLayout, 0, descriptorSets.size(),
                          descriptorSets.data(), 0, nullptr);
  // Get a timestamp before invoking the compute shader.
  if (queryPool != VK_NULL_HANDLE)
    vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT,
                        queryPool, 0);
  vkCmdDispatch(commandBuffer, numWorkGroups.x, numWorkGroups.y,
                numWorkGroups.z);
  // Get another timestamp after invoking the compute shader.
  if (queryPool != VK_NULL_HANDLE)
    vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT,
                        queryPool, 1);

  // Commands end.
  RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer),
                         "vkEndCommandBuffer");

  commandBuffers.push_back(commandBuffer);
  return success();
}
842
submitCommandBuffersToQueue()843 LogicalResult VulkanRuntime::submitCommandBuffersToQueue() {
844 VkSubmitInfo submitInfo = {};
845 submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
846 submitInfo.pNext = nullptr;
847 submitInfo.waitSemaphoreCount = 0;
848 submitInfo.pWaitSemaphores = nullptr;
849 submitInfo.pWaitDstStageMask = nullptr;
850 submitInfo.commandBufferCount = commandBuffers.size();
851 submitInfo.pCommandBuffers = commandBuffers.data();
852 submitInfo.signalSemaphoreCount = 0;
853 submitInfo.pSignalSemaphores = nullptr;
854 RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, nullptr),
855 "vkQueueSubmit");
856 return success();
857 }
858
updateHostMemoryBuffers()859 LogicalResult VulkanRuntime::updateHostMemoryBuffers() {
860 // First copy back the data to the staging buffer.
861 (void)copyResource(/*deviceToHost=*/true);
862
863 // For each descriptor set.
864 for (auto &resourceDataMapPair : resourceData) {
865 auto &resourceDataMap = resourceDataMapPair.second;
866 auto &deviceMemoryBuffers =
867 deviceMemoryBufferMap[resourceDataMapPair.first];
868 // For each device memory buffer in the set.
869 for (auto &deviceMemoryBuffer : deviceMemoryBuffers) {
870 if (resourceDataMap.count(deviceMemoryBuffer.bindingIndex)) {
871 void *payload;
872 auto &hostMemoryBuffer =
873 resourceDataMap[deviceMemoryBuffer.bindingIndex];
874 RETURN_ON_VULKAN_ERROR(vkMapMemory(device,
875 deviceMemoryBuffer.hostMemory, 0,
876 hostMemoryBuffer.size, 0,
877 reinterpret_cast<void **>(&payload)),
878 "vkMapMemory");
879 std::memcpy(hostMemoryBuffer.ptr, payload, hostMemoryBuffer.size);
880 vkUnmapMemory(device, deviceMemoryBuffer.hostMemory);
881 }
882 }
883 }
884 return success();
885 }
886