//===- VulkanRuntime.cpp - MLIR Vulkan runtime ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file provides a library for running a module on a Vulkan device.
// Implements a Vulkan runtime.
//
//===----------------------------------------------------------------------===//

#include "VulkanRuntime.h"

#include <chrono>
#include <cstring>
// TODO: It's generally bad to access stdout/stderr in a library.
// Figure out a better way for error reporting.
#include <iomanip>
#include <iostream>

inline void emitVulkanError(const char *api, VkResult error) {
  std::cerr << " failed with error code " << error << " when executing " << api;
}

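// Checks a Vulkan API result; on failure, reports the error and returns
// failure() from the enclosing function.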
#define RETURN_ON_VULKAN_ERROR(result, api)                                    \
  if ((result) != VK_SUCCESS) {                                                \
    emitVulkanError(api, (result));                                            \
    return failure();                                                          \
  }

using namespace mlir;

void VulkanRuntime::setNumWorkGroups(const NumWorkGroups &numberWorkGroups) {
  numWorkGroups = numberWorkGroups;
}

void VulkanRuntime::setResourceStorageClassBindingMap(
    const ResourceStorageClassBindingMap &stClassData) {
  resourceStorageClassData = stClassData;
}

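// Registers a host buffer for the given descriptor set and binding; the
// resource's storage class defaults to StorageBuffer.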
void VulkanRuntime::setResourceData(
    const DescriptorSetIndex desIndex, const BindingIndex bindIndex,
    const VulkanHostMemoryBuffer &hostMemBuffer) {
  resourceData[desIndex][bindIndex] = hostMemBuffer;
  resourceStorageClassData[desIndex][bindIndex] =
      SPIRVStorageClass::StorageBuffer;
}

void VulkanRuntime::setEntryPoint(const char *entryPointName) {
  entryPoint = entryPointName;
}

void VulkanRuntime::setResourceData(const ResourceData &resData) {
  resourceData = resData;
}

void VulkanRuntime::setShaderModule(uint8_t *shader, uint32_t size) {
  binary = shader;
  binarySize = size;
}

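// Maps a SPIR-V storage class onto the Vulkan descriptor type and buffer
// usage flag used when creating buffers and writing descriptor sets.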
LogicalResult VulkanRuntime::mapStorageClassToDescriptorType(
    SPIRVStorageClass storageClass, VkDescriptorType &descriptorType) {
  switch (storageClass) {
  case SPIRVStorageClass::StorageBuffer:
    descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
    break;
  case SPIRVStorageClass::Uniform:
    descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
    break;
  }
  return success();
}

LogicalResult VulkanRuntime::mapStorageClassToBufferUsageFlag(
    SPIRVStorageClass storageClass, VkBufferUsageFlagBits &bufferUsage) {
  switch (storageClass) {
  case SPIRVStorageClass::StorageBuffer:
    bufferUsage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
    break;
  case SPIRVStorageClass::Uniform:
    bufferUsage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT;
    break;
  }
  return success();
}

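// Accumulates the size of every registered resource into memorySize; the
// total is later used to check that a memory heap is large enough.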
LogicalResult VulkanRuntime::countDeviceMemorySize() {
  for (const auto &resourceDataMapPair : resourceData) {
    const auto &resourceDataMap = resourceDataMapPair.second;
    for (const auto &resourceDataBindingPair : resourceDataMap) {
      if (resourceDataBindingPair.second.size) {
        memorySize += resourceDataBindingPair.second.size;
      } else {
        std::cerr << "expected buffer size greater than zero for resource data";
        return failure();
      }
    }
  }
  return success();
}

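// Validates that resources and a SPIR-V binary were provided, and computes
// the total device memory required, before any Vulkan objects are created.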
LogicalResult VulkanRuntime::initRuntime() {
  if (!resourceData.size()) {
    std::cerr << "Vulkan runtime needs at least one resource";
    return failure();
  }
  if (!binarySize || !binary) {
    std::cerr << "binary shader size must be greater than zero";
    return failure();
  }
  if (failed(countDeviceMemorySize())) {
    return failure();
  }
  return success();
}

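// Waits for the device to go idle, then destroys every Vulkan object owned by
// the runtime, followed by the device and the instance.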
LogicalResult VulkanRuntime::destroy() {
  // According to Vulkan spec:
  // "To ensure that no work is active on the device, vkDeviceWaitIdle can be
  // used to gate the destruction of the device. Prior to destroying a device,
  // an application is responsible for destroying/freeing any Vulkan objects
  // that were created using that device as the first parameter of the
  // corresponding vkCreate* or vkAllocate* command."
  RETURN_ON_VULKAN_ERROR(vkDeviceWaitIdle(device), "vkDeviceWaitIdle");

  // Free and destroy.
  vkFreeCommandBuffers(device, commandPool, commandBuffers.size(),
                       commandBuffers.data());
  vkDestroyQueryPool(device, queryPool, nullptr);
  vkDestroyCommandPool(device, commandPool, nullptr);
  vkFreeDescriptorSets(device, descriptorPool, descriptorSets.size(),
                       descriptorSets.data());
  vkDestroyDescriptorPool(device, descriptorPool, nullptr);
  vkDestroyPipeline(device, pipeline, nullptr);
  vkDestroyPipelineLayout(device, pipelineLayout, nullptr);
  for (auto &descriptorSetLayout : descriptorSetLayouts) {
    vkDestroyDescriptorSetLayout(device, descriptorSetLayout, nullptr);
  }
  vkDestroyShaderModule(device, shaderModule, nullptr);

  // For each descriptor set.
  for (auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
    auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
    // For each descriptor binding.
    for (auto &memoryBuffer : deviceMemoryBuffers) {
      vkFreeMemory(device, memoryBuffer.deviceMemory, nullptr);
      vkFreeMemory(device, memoryBuffer.hostMemory, nullptr);
      vkDestroyBuffer(device, memoryBuffer.hostBuffer, nullptr);
      vkDestroyBuffer(device, memoryBuffer.deviceBuffer, nullptr);
    }
  }

  vkDestroyDevice(device, nullptr);
  vkDestroyInstance(instance, nullptr);
  return success();
}

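// Drives the full execution flow: create the instance, device, buffers, and
// compute pipeline, upload the resources, record and submit the command
// buffer, then report timing information.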
LogicalResult VulkanRuntime::run() {
  // Create logical device, shader module and memory buffers.
  if (failed(createInstance()) || failed(createDevice()) ||
      failed(createMemoryBuffers()) || failed(createShaderModule())) {
    return failure();
  }

  // Descriptor bindings are divided into sets. Each descriptor binding must
  // have a layout binding attached to a descriptor set layout, and each
  // descriptor set layout must be bound to the pipeline layout.
  initDescriptorSetLayoutBindingMap();
  if (failed(createDescriptorSetLayout()) || failed(createPipelineLayout()) ||
      // Each descriptor set must be allocated from a descriptor pool.
      failed(createComputePipeline()) || failed(createDescriptorPool()) ||
      failed(allocateDescriptorSets()) || failed(setWriteDescriptors()) ||
      // Create command buffer.
      failed(createCommandPool()) || failed(createQueryPool()) ||
      failed(createComputeCommandBuffer())) {
    return failure();
  }

  // Get working queue.
  vkGetDeviceQueue(device, queueFamilyIndex, 0, &queue);

  if (failed(copyResource(/*deviceToHost=*/false)))
    return failure();

  auto submitStart = std::chrono::high_resolution_clock::now();
  // Submit command buffer into the queue.
  if (failed(submitCommandBuffersToQueue()))
    return failure();
  auto submitEnd = std::chrono::high_resolution_clock::now();

  RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle");
  auto execEnd = std::chrono::high_resolution_clock::now();

  auto submitDuration = std::chrono::duration_cast<std::chrono::microseconds>(
      submitEnd - submitStart);
  auto execDuration = std::chrono::duration_cast<std::chrono::microseconds>(
      execEnd - submitEnd);

  if (queryPool != VK_NULL_HANDLE) {
    uint64_t timestamps[2];
    RETURN_ON_VULKAN_ERROR(
        vkGetQueryPoolResults(
            device, queryPool, /*firstQuery=*/0, /*queryCount=*/2,
            /*dataSize=*/sizeof(timestamps),
            /*pData=*/reinterpret_cast<void *>(timestamps),
            /*stride=*/sizeof(uint64_t),
            VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT),
        "vkGetQueryPoolResults");
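    // timestampPeriod is the number of nanoseconds per timestamp tick;
    // convert the tick delta into microseconds.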
    float microsec = (timestamps[1] - timestamps[0]) * timestampPeriod / 1000;
    std::cout << "Compute shader execution time: " << std::setprecision(3)
              << microsec << "us\n";
  }

  std::cout << "Command buffer submit time: " << submitDuration.count()
            << "us\nWait idle time: " << execDuration.count() << "us\n";

  return success();
}

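// Creates a minimal Vulkan 1.0 instance with no layers or extensions enabled.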
LogicalResult VulkanRuntime::createInstance() {
  VkApplicationInfo applicationInfo = {};
  applicationInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
  applicationInfo.pNext = nullptr;
  applicationInfo.pApplicationName = "MLIR Vulkan runtime";
  applicationInfo.applicationVersion = 0;
  applicationInfo.pEngineName = "mlir";
  applicationInfo.engineVersion = 0;
  applicationInfo.apiVersion = VK_MAKE_VERSION(1, 0, 0);

  VkInstanceCreateInfo instanceCreateInfo = {};
  instanceCreateInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
  instanceCreateInfo.pNext = nullptr;
  instanceCreateInfo.flags = 0;
  instanceCreateInfo.pApplicationInfo = &applicationInfo;
  instanceCreateInfo.enabledLayerCount = 0;
  instanceCreateInfo.ppEnabledLayerNames = nullptr;
  instanceCreateInfo.enabledExtensionCount = 0;
  instanceCreateInfo.ppEnabledExtensionNames = nullptr;

  RETURN_ON_VULKAN_ERROR(
      vkCreateInstance(&instanceCreateInfo, nullptr, &instance),
      "vkCreateInstance");
  return success();
}

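// Picks the first available physical device, selects a compute queue family,
// creates the logical device, and caches a host-visible/coherent memory type
// and a device-local memory type whose heaps can hold memorySize bytes.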
LogicalResult VulkanRuntime::createDevice() {
  uint32_t physicalDeviceCount = 0;
  RETURN_ON_VULKAN_ERROR(
      vkEnumeratePhysicalDevices(instance, &physicalDeviceCount, nullptr),
      "vkEnumeratePhysicalDevices");

  std::vector<VkPhysicalDevice> physicalDevices(physicalDeviceCount);
  RETURN_ON_VULKAN_ERROR(vkEnumeratePhysicalDevices(instance,
                                                    &physicalDeviceCount,
                                                    physicalDevices.data()),
                         "vkEnumeratePhysicalDevices");

  RETURN_ON_VULKAN_ERROR(physicalDeviceCount ? VK_SUCCESS : VK_INCOMPLETE,
                         "physicalDeviceCount");

  // TODO: find the best device.
  physicalDevice = physicalDevices.front();
  if (failed(getBestComputeQueue()))
    return failure();

  const float queuePriority = 1.0f;
  VkDeviceQueueCreateInfo deviceQueueCreateInfo = {};
  deviceQueueCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
  deviceQueueCreateInfo.pNext = nullptr;
  deviceQueueCreateInfo.flags = 0;
  deviceQueueCreateInfo.queueFamilyIndex = queueFamilyIndex;
  deviceQueueCreateInfo.queueCount = 1;
  deviceQueueCreateInfo.pQueuePriorities = &queuePriority;

  // Structure specifying parameters of a newly created device.
  VkDeviceCreateInfo deviceCreateInfo = {};
  deviceCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
  deviceCreateInfo.pNext = nullptr;
  deviceCreateInfo.flags = 0;
  deviceCreateInfo.queueCreateInfoCount = 1;
  deviceCreateInfo.pQueueCreateInfos = &deviceQueueCreateInfo;
  deviceCreateInfo.enabledLayerCount = 0;
  deviceCreateInfo.ppEnabledLayerNames = nullptr;
  deviceCreateInfo.enabledExtensionCount = 0;
  deviceCreateInfo.ppEnabledExtensionNames = nullptr;
  deviceCreateInfo.pEnabledFeatures = nullptr;

  RETURN_ON_VULKAN_ERROR(
      vkCreateDevice(physicalDevice, &deviceCreateInfo, nullptr, &device),
      "vkCreateDevice");

  VkPhysicalDeviceMemoryProperties properties = {};
  vkGetPhysicalDeviceMemoryProperties(physicalDevice, &properties);

  // Try to find a memory type with the following properties:
  // VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT specifies that memory allocated with
  // this type can be mapped for host access using vkMapMemory;
  // VK_MEMORY_PROPERTY_HOST_COHERENT_BIT specifies that the host cache
  // management commands vkFlushMappedMemoryRanges and
  // vkInvalidateMappedMemoryRanges are not needed to flush host writes to the
  // device or make device writes visible to the host, respectively.
  for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) {
    if ((VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT &
         properties.memoryTypes[i].propertyFlags) &&
        (VK_MEMORY_PROPERTY_HOST_COHERENT_BIT &
         properties.memoryTypes[i].propertyFlags) &&
        (memorySize <=
         properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) {
      hostMemoryTypeIndex = i;
      break;
    }
  }

  // Find a memory type with VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT to be used on
  // the device. This allows better performance on GPUs with dedicated
  // on-device memory.
  for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) {
    if ((VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT &
         properties.memoryTypes[i].propertyFlags) &&
        (memorySize <=
         properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) {
      deviceMemoryTypeIndex = i;
      break;
    }
  }

  RETURN_ON_VULKAN_ERROR((hostMemoryTypeIndex == VK_MAX_MEMORY_TYPES ||
                          deviceMemoryTypeIndex == VK_MAX_MEMORY_TYPES)
                             ? VK_INCOMPLETE
                             : VK_SUCCESS,
                         "invalid memoryTypeIndex");
  return success();
}

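// Prefers a dedicated compute queue family (compute without graphics) and
// falls back to any queue family that supports compute.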
LogicalResult VulkanRuntime::getBestComputeQueue() {
  uint32_t queueFamilyPropertiesCount = 0;
  vkGetPhysicalDeviceQueueFamilyProperties(
      physicalDevice, &queueFamilyPropertiesCount, nullptr);

  std::vector<VkQueueFamilyProperties> familyProperties(
      queueFamilyPropertiesCount);
  vkGetPhysicalDeviceQueueFamilyProperties(
      physicalDevice, &queueFamilyPropertiesCount, familyProperties.data());

  // VK_QUEUE_COMPUTE_BIT specifies that queues in this queue family support
  // compute operations. Try to find a compute-only queue first if possible.
  for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
    auto flags = familyProperties[i].queueFlags;
    if ((flags & VK_QUEUE_COMPUTE_BIT) && !(flags & VK_QUEUE_GRAPHICS_BIT)) {
      queueFamilyIndex = i;
      queueFamilyProperties = familyProperties[i];
      return success();
    }
  }

  // Otherwise use a queue that can also support graphics.
  for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
    auto flags = familyProperties[i].queueFlags;
    if ((flags & VK_QUEUE_COMPUTE_BIT)) {
      queueFamilyIndex = i;
      queueFamilyProperties = familyProperties[i];
      return success();
    }
  }

  std::cerr << "cannot find valid queue";
  return failure();
}

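// For every registered resource, allocates a host-visible staging buffer plus
// a device-local buffer, copies the host data into the staging memory, and
// records the buffer info later consumed by the descriptor writes.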
LogicalResult VulkanRuntime::createMemoryBuffers() {
  // For each descriptor set.
  for (const auto &resourceDataMapPair : resourceData) {
    std::vector<VulkanDeviceMemoryBuffer> deviceMemoryBuffers;
    const auto descriptorSetIndex = resourceDataMapPair.first;
    const auto &resourceDataMap = resourceDataMapPair.second;

    // For each descriptor binding.
    for (const auto &resourceDataBindingPair : resourceDataMap) {
      // Create device memory buffer.
      VulkanDeviceMemoryBuffer memoryBuffer;
      memoryBuffer.bindingIndex = resourceDataBindingPair.first;
      VkDescriptorType descriptorType = {};
      VkBufferUsageFlagBits bufferUsage = {};

      // Check that descriptor set has storage class map.
      const auto resourceStorageClassMapIt =
          resourceStorageClassData.find(descriptorSetIndex);
      if (resourceStorageClassMapIt == resourceStorageClassData.end()) {
        std::cerr
            << "cannot find storage class for resource in descriptor set: "
            << descriptorSetIndex;
        return failure();
      }

      // Check that specific descriptor binding has storage class.
      const auto &resourceStorageClassMap = resourceStorageClassMapIt->second;
      const auto resourceStorageClassIt =
          resourceStorageClassMap.find(resourceDataBindingPair.first);
      if (resourceStorageClassIt == resourceStorageClassMap.end()) {
        std::cerr
            << "cannot find storage class for resource with descriptor index: "
            << resourceDataBindingPair.first;
        return failure();
      }

      const auto resourceStorageClassBinding = resourceStorageClassIt->second;
      if (failed(mapStorageClassToDescriptorType(resourceStorageClassBinding,
                                                 descriptorType)) ||
          failed(mapStorageClassToBufferUsageFlag(resourceStorageClassBinding,
                                                  bufferUsage))) {
        std::cerr << "storage class for resource with descriptor binding: "
                  << resourceDataBindingPair.first
                  << " in the descriptor set: " << descriptorSetIndex
                  << " is not supported ";
        return failure();
      }

      // Set descriptor type for the specific device memory buffer.
      memoryBuffer.descriptorType = descriptorType;
      const auto bufferSize = resourceDataBindingPair.second.size;
      memoryBuffer.bufferSize = bufferSize;
      // Specify memory allocation info.
      VkMemoryAllocateInfo memoryAllocateInfo = {};
      memoryAllocateInfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
      memoryAllocateInfo.pNext = nullptr;
      memoryAllocateInfo.allocationSize = bufferSize;
      memoryAllocateInfo.memoryTypeIndex = hostMemoryTypeIndex;

      // Allocate device memory.
      RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo,
                                              nullptr,
                                              &memoryBuffer.hostMemory),
                             "vkAllocateMemory");
      memoryAllocateInfo.memoryTypeIndex = deviceMemoryTypeIndex;
      RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo,
                                              nullptr,
                                              &memoryBuffer.deviceMemory),
                             "vkAllocateMemory");
      void *payload;
      RETURN_ON_VULKAN_ERROR(vkMapMemory(device, memoryBuffer.hostMemory, 0,
                                         bufferSize, 0,
                                         reinterpret_cast<void **>(&payload)),
                             "vkMapMemory");

      // Copy host memory into the mapped area.
      std::memcpy(payload, resourceDataBindingPair.second.ptr, bufferSize);
      vkUnmapMemory(device, memoryBuffer.hostMemory);

      VkBufferCreateInfo bufferCreateInfo = {};
      bufferCreateInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
      bufferCreateInfo.pNext = nullptr;
      bufferCreateInfo.flags = 0;
      bufferCreateInfo.size = bufferSize;
      bufferCreateInfo.usage = bufferUsage | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
                               VK_BUFFER_USAGE_TRANSFER_SRC_BIT;
      bufferCreateInfo.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
      bufferCreateInfo.queueFamilyIndexCount = 1;
      bufferCreateInfo.pQueueFamilyIndices = &queueFamilyIndex;
      RETURN_ON_VULKAN_ERROR(vkCreateBuffer(device, &bufferCreateInfo, nullptr,
                                            &memoryBuffer.hostBuffer),
                             "vkCreateBuffer");
      RETURN_ON_VULKAN_ERROR(vkCreateBuffer(device, &bufferCreateInfo, nullptr,
                                            &memoryBuffer.deviceBuffer),
                             "vkCreateBuffer");

      // Bind buffer and device memory.
      RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device, memoryBuffer.hostBuffer,
                                                memoryBuffer.hostMemory, 0),
                             "vkBindBufferMemory");
      RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device,
                                                memoryBuffer.deviceBuffer,
                                                memoryBuffer.deviceMemory, 0),
                             "vkBindBufferMemory");

      // Update buffer info.
      memoryBuffer.bufferInfo.buffer = memoryBuffer.deviceBuffer;
      memoryBuffer.bufferInfo.offset = 0;
      memoryBuffer.bufferInfo.range = VK_WHOLE_SIZE;
      deviceMemoryBuffers.push_back(memoryBuffer);
    }

    // Associate device memory buffers with a descriptor set.
    deviceMemoryBufferMap[descriptorSetIndex] = deviceMemoryBuffers;
  }
  return success();
}

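// Records and submits a one-off command buffer that copies every resource
// between its host staging buffer and its device-local buffer, in the
// direction selected by deviceToHost.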
LogicalResult VulkanRuntime::copyResource(bool deviceToHost) {
  VkCommandBufferAllocateInfo commandBufferAllocateInfo = {
      VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO,
      nullptr,
      commandPool,
      VK_COMMAND_BUFFER_LEVEL_PRIMARY,
      1,
  };
  VkCommandBuffer commandBuffer;
  RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device,
                                                  &commandBufferAllocateInfo,
                                                  &commandBuffer),
                         "vkAllocateCommandBuffers");

  VkCommandBufferBeginInfo commandBufferBeginInfo = {
      VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO,
      nullptr,
      0,
      nullptr,
  };
  RETURN_ON_VULKAN_ERROR(
      vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo),
      "vkBeginCommandBuffer");

  for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
    const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
    for (const auto &memBuffer : deviceMemoryBuffers) {
      VkBufferCopy copy = {0, 0, memBuffer.bufferSize};
      if (deviceToHost)
        vkCmdCopyBuffer(commandBuffer, memBuffer.deviceBuffer,
                        memBuffer.hostBuffer, 1, &copy);
      else
        vkCmdCopyBuffer(commandBuffer, memBuffer.hostBuffer,
                        memBuffer.deviceBuffer, 1, &copy);
    }
  }

  RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer),
                         "vkEndCommandBuffer");
  VkSubmitInfo submitInfo = {
      VK_STRUCTURE_TYPE_SUBMIT_INFO,
      nullptr,
      0,
      nullptr,
      nullptr,
      1,
      &commandBuffer,
      0,
      nullptr,
  };
  submitInfo.pCommandBuffers = &commandBuffer;
  RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, VK_NULL_HANDLE),
                         "vkQueueSubmit");
  RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle");

  vkFreeCommandBuffers(device, commandPool, 1, &commandBuffer);
  return success();
}

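// Wraps the SPIR-V blob in a VkShaderModule. codeSize is in bytes while pCode
// points at 32-bit words.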
LogicalResult VulkanRuntime::createShaderModule() {
  VkShaderModuleCreateInfo shaderModuleCreateInfo = {};
  shaderModuleCreateInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
  shaderModuleCreateInfo.pNext = nullptr;
  shaderModuleCreateInfo.flags = 0;
  // Set size in bytes.
  shaderModuleCreateInfo.codeSize = binarySize;
  // Set pointer to the binary shader.
  shaderModuleCreateInfo.pCode = reinterpret_cast<uint32_t *>(binary);
  RETURN_ON_VULKAN_ERROR(vkCreateShaderModule(device, &shaderModuleCreateInfo,
                                              nullptr, &shaderModule),
                         "vkCreateShaderModule");
  return success();
}

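// Builds, for each descriptor set, the list of layout bindings (one per
// buffer) visible to the compute shader stage.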
void VulkanRuntime::initDescriptorSetLayoutBindingMap() {
  for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
    std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings;
    const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
    const auto descriptorSetIndex = deviceMemoryBufferMapPair.first;

    // Create a layout binding for each descriptor.
    for (const auto &memBuffer : deviceMemoryBuffers) {
      VkDescriptorSetLayoutBinding descriptorSetLayoutBinding = {};
      descriptorSetLayoutBinding.binding = memBuffer.bindingIndex;
      descriptorSetLayoutBinding.descriptorType = memBuffer.descriptorType;
      descriptorSetLayoutBinding.descriptorCount = 1;
      descriptorSetLayoutBinding.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
      descriptorSetLayoutBinding.pImmutableSamplers = nullptr;
      descriptorSetLayoutBindings.push_back(descriptorSetLayoutBinding);
    }
    descriptorSetLayoutBindingMap[descriptorSetIndex] =
        descriptorSetLayoutBindings;
  }
}

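// Creates one VkDescriptorSetLayout per descriptor set and records the set's
// size and descriptor type for later descriptor pool sizing.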
LogicalResult VulkanRuntime::createDescriptorSetLayout() {
  for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
    const auto descriptorSetIndex = deviceMemoryBufferMapPair.first;
    const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
    // Every descriptor in a descriptor set must have the same type.
    VkDescriptorType descriptorType =
        deviceMemoryBuffers.front().descriptorType;
    const uint32_t descriptorSize = deviceMemoryBuffers.size();
    const auto descriptorSetLayoutBindingIt =
        descriptorSetLayoutBindingMap.find(descriptorSetIndex);

    if (descriptorSetLayoutBindingIt == descriptorSetLayoutBindingMap.end()) {
      std::cerr << "cannot find layout bindings for the set with number: "
                << descriptorSetIndex;
      return failure();
    }

    const auto &descriptorSetLayoutBindings =
        descriptorSetLayoutBindingIt->second;
    // Create descriptor set layout.
    VkDescriptorSetLayout descriptorSetLayout = {};
    VkDescriptorSetLayoutCreateInfo descriptorSetLayoutCreateInfo = {};

    descriptorSetLayoutCreateInfo.sType =
        VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
    descriptorSetLayoutCreateInfo.pNext = nullptr;
    descriptorSetLayoutCreateInfo.flags = 0;
    // Number of descriptor bindings in this descriptor set layout.
    descriptorSetLayoutCreateInfo.bindingCount =
        descriptorSetLayoutBindings.size();
    descriptorSetLayoutCreateInfo.pBindings =
        descriptorSetLayoutBindings.data();
    RETURN_ON_VULKAN_ERROR(
        vkCreateDescriptorSetLayout(device, &descriptorSetLayoutCreateInfo,
                                    nullptr, &descriptorSetLayout),
        "vkCreateDescriptorSetLayout");

    descriptorSetLayouts.push_back(descriptorSetLayout);
    descriptorSetInfoPool.push_back(
        {descriptorSetIndex, descriptorSize, descriptorType});
  }
  return success();
}

LogicalResult VulkanRuntime::createPipelineLayout() {
  // Associate descriptor sets with a pipeline layout.
  VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo = {};
  pipelineLayoutCreateInfo.sType =
      VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
  pipelineLayoutCreateInfo.pNext = nullptr;
  pipelineLayoutCreateInfo.flags = 0;
  pipelineLayoutCreateInfo.setLayoutCount = descriptorSetLayouts.size();
  pipelineLayoutCreateInfo.pSetLayouts = descriptorSetLayouts.data();
  pipelineLayoutCreateInfo.pushConstantRangeCount = 0;
  pipelineLayoutCreateInfo.pPushConstantRanges = nullptr;
  RETURN_ON_VULKAN_ERROR(vkCreatePipelineLayout(device,
                                                &pipelineLayoutCreateInfo,
                                                nullptr, &pipelineLayout),
                         "vkCreatePipelineLayout");
  return success();
}

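// Builds the compute pipeline from the shader module, using entryPoint as the
// name of the entry point to invoke.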
LogicalResult VulkanRuntime::createComputePipeline() {
  VkPipelineShaderStageCreateInfo stageInfo = {};
  stageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
  stageInfo.pNext = nullptr;
  stageInfo.flags = 0;
  stageInfo.stage = VK_SHADER_STAGE_COMPUTE_BIT;
  stageInfo.module = shaderModule;
  // Set entry point.
  stageInfo.pName = entryPoint;
  stageInfo.pSpecializationInfo = nullptr;

  VkComputePipelineCreateInfo computePipelineCreateInfo = {};
  computePipelineCreateInfo.sType =
      VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
  computePipelineCreateInfo.pNext = nullptr;
  computePipelineCreateInfo.flags = 0;
  computePipelineCreateInfo.stage = stageInfo;
  computePipelineCreateInfo.layout = pipelineLayout;
  computePipelineCreateInfo.basePipelineHandle = nullptr;
  computePipelineCreateInfo.basePipelineIndex = 0;
  RETURN_ON_VULKAN_ERROR(vkCreateComputePipelines(device, nullptr, 1,
                                                  &computePipelineCreateInfo,
                                                  nullptr, &pipeline),
                         "vkCreateComputePipelines");
  return success();
}

LogicalResult VulkanRuntime::createDescriptorPool() {
  std::vector<VkDescriptorPoolSize> descriptorPoolSizes;
  for (const auto &descriptorSetInfo : descriptorSetInfoPool) {
    // For each descriptor set, populate the descriptor pool size.
    VkDescriptorPoolSize descriptorPoolSize = {};
    descriptorPoolSize.type = descriptorSetInfo.descriptorType;
    descriptorPoolSize.descriptorCount = descriptorSetInfo.descriptorSize;
    descriptorPoolSizes.push_back(descriptorPoolSize);
  }

  VkDescriptorPoolCreateInfo descriptorPoolCreateInfo = {};
  descriptorPoolCreateInfo.sType =
      VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
  descriptorPoolCreateInfo.pNext = nullptr;
  descriptorPoolCreateInfo.flags = 0;
  descriptorPoolCreateInfo.maxSets = descriptorPoolSizes.size();
  descriptorPoolCreateInfo.poolSizeCount = descriptorPoolSizes.size();
  descriptorPoolCreateInfo.pPoolSizes = descriptorPoolSizes.data();
  RETURN_ON_VULKAN_ERROR(vkCreateDescriptorPool(device,
                                                &descriptorPoolCreateInfo,
                                                nullptr, &descriptorPool),
                         "vkCreateDescriptorPool");
  return success();
}

LogicalResult VulkanRuntime::allocateDescriptorSets() {
  VkDescriptorSetAllocateInfo descriptorSetAllocateInfo = {};
  // There are as many descriptor sets as descriptor set layouts.
  descriptorSets.resize(descriptorSetLayouts.size());
  descriptorSetAllocateInfo.sType =
      VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
  descriptorSetAllocateInfo.pNext = nullptr;
  descriptorSetAllocateInfo.descriptorPool = descriptorPool;
  descriptorSetAllocateInfo.descriptorSetCount = descriptorSetLayouts.size();
  descriptorSetAllocateInfo.pSetLayouts = descriptorSetLayouts.data();
  RETURN_ON_VULKAN_ERROR(vkAllocateDescriptorSets(device,
                                                  &descriptorSetAllocateInfo,
                                                  descriptorSets.data()),
                         "vkAllocateDescriptorSets");
  return success();
}

LogicalResult VulkanRuntime::setWriteDescriptors() {
  if (descriptorSets.size() != descriptorSetInfoPool.size()) {
    std::cerr << "Each descriptor set must have descriptor set information";
    return failure();
  }
  // For each descriptor set.
  auto descriptorSetIt = descriptorSets.begin();
  // Each descriptor set is associated with descriptor set info.
  for (const auto &descriptorSetInfo : descriptorSetInfoPool) {
    // For each device memory buffer in the descriptor set.
    const auto &deviceMemoryBuffers =
        deviceMemoryBufferMap[descriptorSetInfo.descriptorSet];
    for (const auto &memoryBuffer : deviceMemoryBuffers) {
      // Structure describing descriptor sets to write to.
      VkWriteDescriptorSet wSet = {};
      wSet.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
      wSet.pNext = nullptr;
      // Descriptor set.
      wSet.dstSet = *descriptorSetIt;
      wSet.dstBinding = memoryBuffer.bindingIndex;
      wSet.dstArrayElement = 0;
      wSet.descriptorCount = 1;
      wSet.descriptorType = memoryBuffer.descriptorType;
      wSet.pImageInfo = nullptr;
      wSet.pBufferInfo = &memoryBuffer.bufferInfo;
      wSet.pTexelBufferView = nullptr;
      vkUpdateDescriptorSets(device, 1, &wSet, 0, nullptr);
    }
    // Increment descriptor set iterator.
    ++descriptorSetIt;
  }
  return success();
}

LogicalResult VulkanRuntime::createCommandPool() {
  VkCommandPoolCreateInfo commandPoolCreateInfo = {};
  commandPoolCreateInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
  commandPoolCreateInfo.pNext = nullptr;
  commandPoolCreateInfo.flags = 0;
  commandPoolCreateInfo.queueFamilyIndex = queueFamilyIndex;
  RETURN_ON_VULKAN_ERROR(vkCreateCommandPool(device, &commandPoolCreateInfo,
                                             /*pAllocator=*/nullptr,
                                             &commandPool),
                         "vkCreateCommandPool");
  return success();
}

LogicalResult VulkanRuntime::createQueryPool() {
  // Return directly if timestamp query is not supported.
  if (queueFamilyProperties.timestampValidBits == 0)
    return success();

  // Get timestamp period for this physical device.
  VkPhysicalDeviceProperties deviceProperties = {};
  vkGetPhysicalDeviceProperties(physicalDevice, &deviceProperties);
  timestampPeriod = deviceProperties.limits.timestampPeriod;

  // Create query pool.
  VkQueryPoolCreateInfo queryPoolCreateInfo = {};
  queryPoolCreateInfo.sType = VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO;
  queryPoolCreateInfo.pNext = nullptr;
  queryPoolCreateInfo.flags = 0;
  queryPoolCreateInfo.queryType = VK_QUERY_TYPE_TIMESTAMP;
  queryPoolCreateInfo.queryCount = 2;
  queryPoolCreateInfo.pipelineStatistics = 0;
  RETURN_ON_VULKAN_ERROR(vkCreateQueryPool(device, &queryPoolCreateInfo,
                                           /*pAllocator=*/nullptr, &queryPool),
                         "vkCreateQueryPool");

  return success();
}

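// Records the compute work: bind the pipeline and descriptor sets, bracket
// the dispatch with timestamp writes when a query pool is available, and
// dispatch numWorkGroups workgroups.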
LogicalResult VulkanRuntime::createComputeCommandBuffer() {
  VkCommandBufferAllocateInfo commandBufferAllocateInfo = {};
  commandBufferAllocateInfo.sType =
      VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
  commandBufferAllocateInfo.pNext = nullptr;
  commandBufferAllocateInfo.commandPool = commandPool;
  commandBufferAllocateInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
  commandBufferAllocateInfo.commandBufferCount = 1;

  VkCommandBuffer commandBuffer;
  RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device,
                                                  &commandBufferAllocateInfo,
                                                  &commandBuffer),
                         "vkAllocateCommandBuffers");

  VkCommandBufferBeginInfo commandBufferBeginInfo = {};
  commandBufferBeginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
  commandBufferBeginInfo.pNext = nullptr;
  commandBufferBeginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
  commandBufferBeginInfo.pInheritanceInfo = nullptr;

  // Commands begin.
  RETURN_ON_VULKAN_ERROR(
      vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo),
      "vkBeginCommandBuffer");

  if (queryPool != VK_NULL_HANDLE)
    vkCmdResetQueryPool(commandBuffer, queryPool, 0, 2);

  vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
  vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
                          pipelineLayout, 0, descriptorSets.size(),
                          descriptorSets.data(), 0, nullptr);
  // Get a timestamp before invoking the compute shader.
  if (queryPool != VK_NULL_HANDLE)
    vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT,
                        queryPool, 0);
  vkCmdDispatch(commandBuffer, numWorkGroups.x, numWorkGroups.y,
                numWorkGroups.z);
  // Get another timestamp after invoking the compute shader.
  if (queryPool != VK_NULL_HANDLE)
    vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT,
                        queryPool, 1);

  // Commands end.
  RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer),
                         "vkEndCommandBuffer");

  commandBuffers.push_back(commandBuffer);
  return success();
}

LogicalResult VulkanRuntime::submitCommandBuffersToQueue() {
  VkSubmitInfo submitInfo = {};
  submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
  submitInfo.pNext = nullptr;
  submitInfo.waitSemaphoreCount = 0;
  submitInfo.pWaitSemaphores = nullptr;
  submitInfo.pWaitDstStageMask = nullptr;
  submitInfo.commandBufferCount = commandBuffers.size();
  submitInfo.pCommandBuffers = commandBuffers.data();
  submitInfo.signalSemaphoreCount = 0;
  submitInfo.pSignalSemaphores = nullptr;
  RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, nullptr),
                         "vkQueueSubmit");
  return success();
}

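// Copies device results back into the host staging buffers, then maps each
// staging allocation and memcpys its contents into the user-provided host
// buffers.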
LogicalResult VulkanRuntime::updateHostMemoryBuffers() {
  // First copy back the data to the staging buffer.
  (void)copyResource(/*deviceToHost=*/true);

  // For each descriptor set.
  for (auto &resourceDataMapPair : resourceData) {
    auto &resourceDataMap = resourceDataMapPair.second;
    auto &deviceMemoryBuffers =
        deviceMemoryBufferMap[resourceDataMapPair.first];
    // For each device memory buffer in the set.
    for (auto &deviceMemoryBuffer : deviceMemoryBuffers) {
      if (resourceDataMap.count(deviceMemoryBuffer.bindingIndex)) {
        void *payload;
        auto &hostMemoryBuffer =
            resourceDataMap[deviceMemoryBuffer.bindingIndex];
        RETURN_ON_VULKAN_ERROR(vkMapMemory(device,
                                           deviceMemoryBuffer.hostMemory, 0,
                                           hostMemoryBuffer.size, 0,
                                           reinterpret_cast<void **>(&payload)),
                               "vkMapMemory");
        std::memcpy(hostMemoryBuffer.ptr, payload, hostMemoryBuffer.size);
        vkUnmapMemory(device, deviceMemoryBuffer.hostMemory);
      }
    }
  }
  return success();
}