//===- VulkanRuntime.cpp - MLIR Vulkan runtime ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file provides a library for running a module on a Vulkan device.
// Implements a Vulkan runtime.
//
//===----------------------------------------------------------------------===//
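//
// Typical call sequence (illustrative sketch only; `spirvBlob`, `spirvSize`,
// and `hostBuf` are hypothetical caller-owned values, not defined here):
//
//   VulkanRuntime runtime;
//   runtime.setShaderModule(spirvBlob, spirvSize); // SPIR-V binary blob.
//   runtime.setEntryPoint("main");                 // Kernel entry point name.
//   runtime.setNumWorkGroups({/*x=*/4, /*y=*/1, /*z=*/1});
//   runtime.setResourceData(/*desIndex=*/0, /*bindIndex=*/0, hostBuf);
//   if (succeeded(runtime.initRuntime()) && succeeded(runtime.run()) &&
//       succeeded(runtime.updateHostMemoryBuffers()))
//     (void)runtime.destroy();
//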

#include "VulkanRuntime.h"

#include <chrono>
#include <cstring>
// TODO: It's generally bad to access stdout/stderr in a library.
// Figure out a better way for error reporting.
#include <iomanip>
#include <iostream>

inline void emitVulkanError(const char *api, VkResult error) {
  std::cerr << " failed with error code " << error << " when executing " << api;
}

#define RETURN_ON_VULKAN_ERROR(result, api)                                    \
  if ((result) != VK_SUCCESS) {                                                \
    emitVulkanError(api, (result));                                            \
    return failure();                                                          \
  }
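
// NOTE: RETURN_ON_VULKAN_ERROR expands to an early `return failure();`, so it
// is only usable inside functions returning mlir::LogicalResult, e.g.
//   RETURN_ON_VULKAN_ERROR(vkDeviceWaitIdle(device), "vkDeviceWaitIdle");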

using namespace mlir;

void VulkanRuntime::setNumWorkGroups(const NumWorkGroups &numberWorkGroups) {
  numWorkGroups = numberWorkGroups;
}

void VulkanRuntime::setResourceStorageClassBindingMap(
    const ResourceStorageClassBindingMap &stClassData) {
  resourceStorageClassData = stClassData;
}

void VulkanRuntime::setResourceData(
    const DescriptorSetIndex desIndex, const BindingIndex bindIndex,
    const VulkanHostMemoryBuffer &hostMemBuffer) {
  resourceData[desIndex][bindIndex] = hostMemBuffer;
  resourceStorageClassData[desIndex][bindIndex] =
      SPIRVStorageClass::StorageBuffer;
}
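
// Resources are keyed first by descriptor set index and then by binding
// index, mirroring SPIR-V (set, binding) pairs. Illustrative sketch only
// (hypothetical caller-side values; assumes VulkanHostMemoryBuffer is a
// {ptr, size} pair as used below):
//   float data[4] = {0.f, 1.f, 2.f, 3.f};
//   VulkanHostMemoryBuffer buf{data, /*size in bytes=*/16};
//   runtime.setResourceData(/*desIndex=*/0, /*bindIndex=*/1, buf);
// This registers `buf` as set 0, binding 1 and defaults its storage class to
// StorageBuffer; use setResourceStorageClassBindingMap to override that.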

void VulkanRuntime::setEntryPoint(const char *entryPointName) {
  entryPoint = entryPointName;
}

void VulkanRuntime::setResourceData(const ResourceData &resData) {
  resourceData = resData;
}

void VulkanRuntime::setShaderModule(uint8_t *shader, uint32_t size) {
  binary = shader;
  binarySize = size;
}

LogicalResult VulkanRuntime::mapStorageClassToDescriptorType(
    SPIRVStorageClass storageClass, VkDescriptorType &descriptorType) {
  switch (storageClass) {
  case SPIRVStorageClass::StorageBuffer:
    descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
    break;
  case SPIRVStorageClass::Uniform:
    descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
    break;
  }
  return success();
}

LogicalResult VulkanRuntime::mapStorageClassToBufferUsageFlag(
    SPIRVStorageClass storageClass, VkBufferUsageFlagBits &bufferUsage) {
  switch (storageClass) {
  case SPIRVStorageClass::StorageBuffer:
    bufferUsage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
    break;
  case SPIRVStorageClass::Uniform:
    bufferUsage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT;
    break;
  }
  return success();
}

LogicalResult VulkanRuntime::countDeviceMemorySize() {
  for (const auto &resourceDataMapPair : resourceData) {
    const auto &resourceDataMap = resourceDataMapPair.second;
    for (const auto &resourceDataBindingPair : resourceDataMap) {
      if (resourceDataBindingPair.second.size) {
        memorySize += resourceDataBindingPair.second.size;
      } else {
        std::cerr << "expected buffer size greater than zero for resource data";
        return failure();
      }
    }
  }
  return success();
}

LogicalResult VulkanRuntime::initRuntime() {
  if (!resourceData.size()) {
    std::cerr << "Vulkan runtime needs at least one resource";
    return failure();
  }
  if (!binarySize || !binary) {
    std::cerr << "binary shader size must be greater than zero";
    return failure();
  }
  if (failed(countDeviceMemorySize())) {
    return failure();
  }
  return success();
}

LogicalResult VulkanRuntime::destroy() {
  // According to Vulkan spec:
  // "To ensure that no work is active on the device, vkDeviceWaitIdle can be
  // used to gate the destruction of the device. Prior to destroying a device,
  // an application is responsible for destroying/freeing any Vulkan objects
  // that were created using that device as the first parameter of the
  // corresponding vkCreate* or vkAllocate* command."
  RETURN_ON_VULKAN_ERROR(vkDeviceWaitIdle(device), "vkDeviceWaitIdle");

  // Free and destroy.
  vkFreeCommandBuffers(device, commandPool, commandBuffers.size(),
                       commandBuffers.data());
  vkDestroyQueryPool(device, queryPool, nullptr);
  vkDestroyCommandPool(device, commandPool, nullptr);
  vkFreeDescriptorSets(device, descriptorPool, descriptorSets.size(),
                       descriptorSets.data());
  vkDestroyDescriptorPool(device, descriptorPool, nullptr);
  vkDestroyPipeline(device, pipeline, nullptr);
  vkDestroyPipelineLayout(device, pipelineLayout, nullptr);
  for (auto &descriptorSetLayout : descriptorSetLayouts) {
    vkDestroyDescriptorSetLayout(device, descriptorSetLayout, nullptr);
  }
  vkDestroyShaderModule(device, shaderModule, nullptr);

  // For each descriptor set.
  for (auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
    auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
    // For each descriptor binding.
    for (auto &memoryBuffer : deviceMemoryBuffers) {
      vkFreeMemory(device, memoryBuffer.deviceMemory, nullptr);
      vkFreeMemory(device, memoryBuffer.hostMemory, nullptr);
      vkDestroyBuffer(device, memoryBuffer.hostBuffer, nullptr);
      vkDestroyBuffer(device, memoryBuffer.deviceBuffer, nullptr);
    }
  }

  vkDestroyDevice(device, nullptr);
  vkDestroyInstance(instance, nullptr);
  return success();
}

LogicalResult VulkanRuntime::run() {
  // Create logical device, shader module and memory buffers.
  if (failed(createInstance()) || failed(createDevice()) ||
      failed(createMemoryBuffers()) || failed(createShaderModule())) {
    return failure();
  }

  // Descriptor bindings are divided into sets. Each descriptor binding must
  // have a layout binding attached to a descriptor set layout, and each
  // descriptor set layout must be bound to a pipeline layout.
  initDescriptorSetLayoutBindingMap();
  if (failed(createDescriptorSetLayout()) || failed(createPipelineLayout()) ||
      // Each descriptor set must be allocated from a descriptor pool.
      failed(createComputePipeline()) || failed(createDescriptorPool()) ||
      failed(allocateDescriptorSets()) || failed(setWriteDescriptors()) ||
      // Create command buffer.
      failed(createCommandPool()) || failed(createQueryPool()) ||
      failed(createComputeCommandBuffer())) {
    return failure();
  }

  // Get working queue.
  vkGetDeviceQueue(device, queueFamilyIndex, 0, &queue);

  if (failed(copyResource(/*deviceToHost=*/false)))
    return failure();

  auto submitStart = std::chrono::high_resolution_clock::now();
  // Submit command buffer into the queue.
  if (failed(submitCommandBuffersToQueue()))
    return failure();
  auto submitEnd = std::chrono::high_resolution_clock::now();

  RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle");
  auto execEnd = std::chrono::high_resolution_clock::now();

  auto submitDuration = std::chrono::duration_cast<std::chrono::microseconds>(
      submitEnd - submitStart);
  auto execDuration = std::chrono::duration_cast<std::chrono::microseconds>(
      execEnd - submitEnd);

  if (queryPool != VK_NULL_HANDLE) {
    uint64_t timestamps[2];
    RETURN_ON_VULKAN_ERROR(
        vkGetQueryPoolResults(
            device, queryPool, /*firstQuery=*/0, /*queryCount=*/2,
            /*dataSize=*/sizeof(timestamps),
            /*pData=*/reinterpret_cast<void *>(timestamps),
            /*stride=*/sizeof(uint64_t),
            VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT),
        "vkGetQueryPoolResults");
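    // timestampPeriod is the number of nanoseconds per timestamp tick, so
    // microseconds = (t1 - t0) * timestampPeriod / 1000.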
    float microsec = (timestamps[1] - timestamps[0]) * timestampPeriod / 1000;
    std::cout << "Compute shader execution time: " << std::setprecision(3)
              << microsec << "us\n";
  }

  std::cout << "Command buffer submit time: " << submitDuration.count()
            << "us\nWait idle time: " << execDuration.count() << "us\n";

  return success();
}

LogicalResult VulkanRuntime::createInstance() {
  VkApplicationInfo applicationInfo = {};
  applicationInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
  applicationInfo.pNext = nullptr;
  applicationInfo.pApplicationName = "MLIR Vulkan runtime";
  applicationInfo.applicationVersion = 0;
  applicationInfo.pEngineName = "mlir";
  applicationInfo.engineVersion = 0;
  applicationInfo.apiVersion = VK_MAKE_VERSION(1, 0, 0);

  VkInstanceCreateInfo instanceCreateInfo = {};
  instanceCreateInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
  instanceCreateInfo.pNext = nullptr;
  instanceCreateInfo.flags = 0;
  instanceCreateInfo.pApplicationInfo = &applicationInfo;
  instanceCreateInfo.enabledLayerCount = 0;
  instanceCreateInfo.ppEnabledLayerNames = 0;
  instanceCreateInfo.enabledExtensionCount = 0;
  instanceCreateInfo.ppEnabledExtensionNames = 0;

  RETURN_ON_VULKAN_ERROR(vkCreateInstance(&instanceCreateInfo, 0, &instance),
                         "vkCreateInstance");
  return success();
}

LogicalResult VulkanRuntime::createDevice() {
  uint32_t physicalDeviceCount = 0;
  RETURN_ON_VULKAN_ERROR(
      vkEnumeratePhysicalDevices(instance, &physicalDeviceCount, 0),
      "vkEnumeratePhysicalDevices");

  std::vector<VkPhysicalDevice> physicalDevices(physicalDeviceCount);
  RETURN_ON_VULKAN_ERROR(vkEnumeratePhysicalDevices(instance,
                                                    &physicalDeviceCount,
                                                    physicalDevices.data()),
                         "vkEnumeratePhysicalDevices");

  RETURN_ON_VULKAN_ERROR(physicalDeviceCount ? VK_SUCCESS : VK_INCOMPLETE,
                         "physicalDeviceCount");

  // TODO: find the best device.
  physicalDevice = physicalDevices.front();
  if (failed(getBestComputeQueue()))
    return failure();

  const float queuePriority = 1.0f;
  VkDeviceQueueCreateInfo deviceQueueCreateInfo = {};
  deviceQueueCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
  deviceQueueCreateInfo.pNext = nullptr;
  deviceQueueCreateInfo.flags = 0;
  deviceQueueCreateInfo.queueFamilyIndex = queueFamilyIndex;
  deviceQueueCreateInfo.queueCount = 1;
  deviceQueueCreateInfo.pQueuePriorities = &queuePriority;

  // Structure specifying parameters of a newly created device.
  VkDeviceCreateInfo deviceCreateInfo = {};
  deviceCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
  deviceCreateInfo.pNext = nullptr;
  deviceCreateInfo.flags = 0;
  deviceCreateInfo.queueCreateInfoCount = 1;
  deviceCreateInfo.pQueueCreateInfos = &deviceQueueCreateInfo;
  deviceCreateInfo.enabledLayerCount = 0;
  deviceCreateInfo.ppEnabledLayerNames = nullptr;
  deviceCreateInfo.enabledExtensionCount = 0;
  deviceCreateInfo.ppEnabledExtensionNames = nullptr;
  deviceCreateInfo.pEnabledFeatures = nullptr;

  RETURN_ON_VULKAN_ERROR(
      vkCreateDevice(physicalDevice, &deviceCreateInfo, 0, &device),
      "vkCreateDevice");

  VkPhysicalDeviceMemoryProperties properties = {};
  vkGetPhysicalDeviceMemoryProperties(physicalDevice, &properties);

  // Try to find a memory type with the following properties:
  // VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT bit specifies that memory allocated
  // with this type can be mapped for host access using vkMapMemory;
  // VK_MEMORY_PROPERTY_HOST_COHERENT_BIT bit specifies that the host cache
  // management commands vkFlushMappedMemoryRanges and
  // vkInvalidateMappedMemoryRanges are not needed to flush host writes to the
  // device or make device writes visible to the host, respectively.
  for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) {
    if ((VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT &
         properties.memoryTypes[i].propertyFlags) &&
        (VK_MEMORY_PROPERTY_HOST_COHERENT_BIT &
         properties.memoryTypes[i].propertyFlags) &&
        (memorySize <=
         properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) {
      hostMemoryTypeIndex = i;
      break;
    }
  }

  // Find a memory type with VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT to be used on
  // the device. This gives faster access for GPUs with dedicated device
  // memory.
  for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) {
    if ((VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT &
         properties.memoryTypes[i].propertyFlags) &&
        (memorySize <=
         properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) {
      deviceMemoryTypeIndex = i;
      break;
    }
  }

  RETURN_ON_VULKAN_ERROR((hostMemoryTypeIndex == VK_MAX_MEMORY_TYPES ||
                          deviceMemoryTypeIndex == VK_MAX_MEMORY_TYPES)
                             ? VK_INCOMPLETE
                             : VK_SUCCESS,
                         "invalid memoryTypeIndex");
  return success();
}

LogicalResult VulkanRuntime::getBestComputeQueue() {
  uint32_t queueFamilyPropertiesCount = 0;
  vkGetPhysicalDeviceQueueFamilyProperties(physicalDevice,
                                           &queueFamilyPropertiesCount, 0);

  std::vector<VkQueueFamilyProperties> familyProperties(
      queueFamilyPropertiesCount);
  vkGetPhysicalDeviceQueueFamilyProperties(
      physicalDevice, &queueFamilyPropertiesCount, familyProperties.data());

  // VK_QUEUE_COMPUTE_BIT specifies that queues in this queue family support
  // compute operations. Try to find a compute-only queue first if possible.
  for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
    auto flags = familyProperties[i].queueFlags;
    if ((flags & VK_QUEUE_COMPUTE_BIT) && !(flags & VK_QUEUE_GRAPHICS_BIT)) {
      queueFamilyIndex = i;
      queueFamilyProperties = familyProperties[i];
      return success();
    }
  }

  // Otherwise use a queue that can also support graphics.
  for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
    auto flags = familyProperties[i].queueFlags;
    if ((flags & VK_QUEUE_COMPUTE_BIT)) {
      queueFamilyIndex = i;
      queueFamilyProperties = familyProperties[i];
      return success();
    }
  }

  std::cerr << "cannot find valid queue";
  return failure();
}

LogicalResult VulkanRuntime::createMemoryBuffers() {
  // For each descriptor set.
  for (const auto &resourceDataMapPair : resourceData) {
    std::vector<VulkanDeviceMemoryBuffer> deviceMemoryBuffers;
    const auto descriptorSetIndex = resourceDataMapPair.first;
    const auto &resourceDataMap = resourceDataMapPair.second;

    // For each descriptor binding.
    for (const auto &resourceDataBindingPair : resourceDataMap) {
      // Create device memory buffer.
      VulkanDeviceMemoryBuffer memoryBuffer;
      memoryBuffer.bindingIndex = resourceDataBindingPair.first;
      VkDescriptorType descriptorType = {};
      VkBufferUsageFlagBits bufferUsage = {};

      // Check that descriptor set has storage class map.
      const auto resourceStorageClassMapIt =
          resourceStorageClassData.find(descriptorSetIndex);
      if (resourceStorageClassMapIt == resourceStorageClassData.end()) {
        std::cerr
            << "cannot find storage class for resource in descriptor set: "
            << descriptorSetIndex;
        return failure();
      }

      // Check that specific descriptor binding has storage class.
      const auto &resourceStorageClassMap = resourceStorageClassMapIt->second;
      const auto resourceStorageClassIt =
          resourceStorageClassMap.find(resourceDataBindingPair.first);
      if (resourceStorageClassIt == resourceStorageClassMap.end()) {
        std::cerr
            << "cannot find storage class for resource with descriptor index: "
            << resourceDataBindingPair.first;
        return failure();
      }

      const auto resourceStorageClassBinding = resourceStorageClassIt->second;
      if (failed(mapStorageClassToDescriptorType(resourceStorageClassBinding,
                                                 descriptorType)) ||
          failed(mapStorageClassToBufferUsageFlag(resourceStorageClassBinding,
                                                  bufferUsage))) {
        std::cerr << "storage class for resource with descriptor binding: "
                  << resourceDataBindingPair.first
                  << " in the descriptor set: " << descriptorSetIndex
                  << " is not supported ";
        return failure();
      }

      // Set descriptor type for the specific device memory buffer.
      memoryBuffer.descriptorType = descriptorType;
      const auto bufferSize = resourceDataBindingPair.second.size;
      memoryBuffer.bufferSize = bufferSize;
      // Specify memory allocation info.
      VkMemoryAllocateInfo memoryAllocateInfo = {};
      memoryAllocateInfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
      memoryAllocateInfo.pNext = nullptr;
      memoryAllocateInfo.allocationSize = bufferSize;
      memoryAllocateInfo.memoryTypeIndex = hostMemoryTypeIndex;

      // Allocate host-visible staging memory, then device-local memory.
      RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo, 0,
                                              &memoryBuffer.hostMemory),
                             "vkAllocateMemory");
      memoryAllocateInfo.memoryTypeIndex = deviceMemoryTypeIndex;
      RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo, 0,
                                              &memoryBuffer.deviceMemory),
                             "vkAllocateMemory");
      void *payload;
      RETURN_ON_VULKAN_ERROR(vkMapMemory(device, memoryBuffer.hostMemory, 0,
                                         bufferSize, 0,
                                         reinterpret_cast<void **>(&payload)),
                             "vkMapMemory");

      // Copy host memory into the mapped area.
      std::memcpy(payload, resourceDataBindingPair.second.ptr, bufferSize);
      vkUnmapMemory(device, memoryBuffer.hostMemory);

      VkBufferCreateInfo bufferCreateInfo = {};
      bufferCreateInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
      bufferCreateInfo.pNext = nullptr;
      bufferCreateInfo.flags = 0;
      bufferCreateInfo.size = bufferSize;
      bufferCreateInfo.usage = bufferUsage | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
                               VK_BUFFER_USAGE_TRANSFER_SRC_BIT;
      bufferCreateInfo.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
      bufferCreateInfo.queueFamilyIndexCount = 1;
      bufferCreateInfo.pQueueFamilyIndices = &queueFamilyIndex;
      RETURN_ON_VULKAN_ERROR(vkCreateBuffer(device, &bufferCreateInfo, 0,
                                            &memoryBuffer.hostBuffer),
                             "vkCreateBuffer");
      RETURN_ON_VULKAN_ERROR(vkCreateBuffer(device, &bufferCreateInfo, 0,
                                            &memoryBuffer.deviceBuffer),
                             "vkCreateBuffer");

      // Bind buffer and device memory.
      RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device, memoryBuffer.hostBuffer,
                                                memoryBuffer.hostMemory, 0),
                             "vkBindBufferMemory");
      RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device,
                                                memoryBuffer.deviceBuffer,
                                                memoryBuffer.deviceMemory, 0),
                             "vkBindBufferMemory");

      // Update buffer info.
      memoryBuffer.bufferInfo.buffer = memoryBuffer.deviceBuffer;
      memoryBuffer.bufferInfo.offset = 0;
      memoryBuffer.bufferInfo.range = VK_WHOLE_SIZE;
      deviceMemoryBuffers.push_back(memoryBuffer);
    }

    // Associate device memory buffers with a descriptor set.
    deviceMemoryBufferMap[descriptorSetIndex] = deviceMemoryBuffers;
  }
  return success();
}

LogicalResult VulkanRuntime::copyResource(bool deviceToHost) {
  VkCommandBufferAllocateInfo commandBufferAllocateInfo = {
      VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO,
      NULL,
      commandPool,
      VK_COMMAND_BUFFER_LEVEL_PRIMARY,
      1,
  };
  VkCommandBuffer commandBuffer;
  RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device,
                                                  &commandBufferAllocateInfo,
                                                  &commandBuffer),
                         "vkAllocateCommandBuffers");

  VkCommandBufferBeginInfo commandBufferBeginInfo = {
      VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO,
      NULL,
      0,
      NULL,
  };
  RETURN_ON_VULKAN_ERROR(
      vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo),
      "vkBeginCommandBuffer");

  for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
    std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings;
    const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
    for (const auto &memBuffer : deviceMemoryBuffers) {
      VkBufferCopy copy = {0, 0, memBuffer.bufferSize};
      if (deviceToHost)
        vkCmdCopyBuffer(commandBuffer, memBuffer.deviceBuffer,
                        memBuffer.hostBuffer, 1, &copy);
      else
        vkCmdCopyBuffer(commandBuffer, memBuffer.hostBuffer,
                        memBuffer.deviceBuffer, 1, &copy);
    }
  }

  RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer),
                         "vkEndCommandBuffer");
  VkSubmitInfo submitInfo = {
      VK_STRUCTURE_TYPE_SUBMIT_INFO,
      NULL,
      0,
      NULL,
      NULL,
      1,
      &commandBuffer,
      0,
      NULL,
  };
  submitInfo.pCommandBuffers = &commandBuffer;
  RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, VK_NULL_HANDLE),
                         "vkQueueSubmit");
  RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle");

  vkFreeCommandBuffers(device, commandPool, 1, &commandBuffer);
  return success();
}

LogicalResult VulkanRuntime::createShaderModule() {
  VkShaderModuleCreateInfo shaderModuleCreateInfo = {};
  shaderModuleCreateInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
  shaderModuleCreateInfo.pNext = nullptr;
  shaderModuleCreateInfo.flags = 0;
  // Set size in bytes.
  shaderModuleCreateInfo.codeSize = binarySize;
  // Set pointer to the binary shader.
  shaderModuleCreateInfo.pCode = reinterpret_cast<uint32_t *>(binary);
  RETURN_ON_VULKAN_ERROR(
      vkCreateShaderModule(device, &shaderModuleCreateInfo, 0, &shaderModule),
      "vkCreateShaderModule");
  return success();
}

void VulkanRuntime::initDescriptorSetLayoutBindingMap() {
  for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
    std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings;
    const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
    const auto descriptorSetIndex = deviceMemoryBufferMapPair.first;

    // Create a layout binding for each descriptor.
    for (const auto &memBuffer : deviceMemoryBuffers) {
      VkDescriptorSetLayoutBinding descriptorSetLayoutBinding = {};
      descriptorSetLayoutBinding.binding = memBuffer.bindingIndex;
      descriptorSetLayoutBinding.descriptorType = memBuffer.descriptorType;
      descriptorSetLayoutBinding.descriptorCount = 1;
      descriptorSetLayoutBinding.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
      descriptorSetLayoutBinding.pImmutableSamplers = 0;
      descriptorSetLayoutBindings.push_back(descriptorSetLayoutBinding);
    }
    descriptorSetLayoutBindingMap[descriptorSetIndex] =
        descriptorSetLayoutBindings;
  }
}

LogicalResult VulkanRuntime::createDescriptorSetLayout() {
  for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
    const auto descriptorSetIndex = deviceMemoryBufferMapPair.first;
    const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
    // Each descriptor in a descriptor set must be of the same type.
    VkDescriptorType descriptorType =
        deviceMemoryBuffers.front().descriptorType;
    const uint32_t descriptorSize = deviceMemoryBuffers.size();
    const auto descriptorSetLayoutBindingIt =
        descriptorSetLayoutBindingMap.find(descriptorSetIndex);

    if (descriptorSetLayoutBindingIt == descriptorSetLayoutBindingMap.end()) {
      std::cerr << "cannot find layout bindings for the set with number: "
                << descriptorSetIndex;
      return failure();
    }

    const auto &descriptorSetLayoutBindings =
        descriptorSetLayoutBindingIt->second;
    // Create descriptor set layout.
    VkDescriptorSetLayout descriptorSetLayout = {};
    VkDescriptorSetLayoutCreateInfo descriptorSetLayoutCreateInfo = {};

    descriptorSetLayoutCreateInfo.sType =
        VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
    descriptorSetLayoutCreateInfo.pNext = nullptr;
    descriptorSetLayoutCreateInfo.flags = 0;
    // Number of descriptor bindings in the layout set.
    descriptorSetLayoutCreateInfo.bindingCount =
        descriptorSetLayoutBindings.size();
    descriptorSetLayoutCreateInfo.pBindings =
        descriptorSetLayoutBindings.data();
    RETURN_ON_VULKAN_ERROR(
        vkCreateDescriptorSetLayout(device, &descriptorSetLayoutCreateInfo, 0,
                                    &descriptorSetLayout),
        "vkCreateDescriptorSetLayout");

    descriptorSetLayouts.push_back(descriptorSetLayout);
    descriptorSetInfoPool.push_back(
        {descriptorSetIndex, descriptorSize, descriptorType});
  }
  return success();
}

LogicalResult VulkanRuntime::createPipelineLayout() {
  // Associate descriptor sets with a pipeline layout.
  VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo = {};
  pipelineLayoutCreateInfo.sType =
      VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
  pipelineLayoutCreateInfo.pNext = nullptr;
  pipelineLayoutCreateInfo.flags = 0;
  pipelineLayoutCreateInfo.setLayoutCount = descriptorSetLayouts.size();
  pipelineLayoutCreateInfo.pSetLayouts = descriptorSetLayouts.data();
  pipelineLayoutCreateInfo.pushConstantRangeCount = 0;
  pipelineLayoutCreateInfo.pPushConstantRanges = 0;
  RETURN_ON_VULKAN_ERROR(vkCreatePipelineLayout(device,
                                                &pipelineLayoutCreateInfo, 0,
                                                &pipelineLayout),
                         "vkCreatePipelineLayout");
  return success();
}

LogicalResult VulkanRuntime::createComputePipeline() {
  VkPipelineShaderStageCreateInfo stageInfo = {};
  stageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
  stageInfo.pNext = nullptr;
  stageInfo.flags = 0;
  stageInfo.stage = VK_SHADER_STAGE_COMPUTE_BIT;
  stageInfo.module = shaderModule;
  // Set entry point.
  stageInfo.pName = entryPoint;
  stageInfo.pSpecializationInfo = 0;

  VkComputePipelineCreateInfo computePipelineCreateInfo = {};
  computePipelineCreateInfo.sType =
      VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
  computePipelineCreateInfo.pNext = nullptr;
  computePipelineCreateInfo.flags = 0;
  computePipelineCreateInfo.stage = stageInfo;
  computePipelineCreateInfo.layout = pipelineLayout;
  computePipelineCreateInfo.basePipelineHandle = 0;
  computePipelineCreateInfo.basePipelineIndex = 0;
  RETURN_ON_VULKAN_ERROR(vkCreateComputePipelines(device, 0, 1,
                                                  &computePipelineCreateInfo, 0,
                                                  &pipeline),
                         "vkCreateComputePipelines");
  return success();
}

LogicalResult VulkanRuntime::createDescriptorPool() {
  std::vector<VkDescriptorPoolSize> descriptorPoolSizes;
  for (const auto &descriptorSetInfo : descriptorSetInfoPool) {
    // For each descriptor set populate descriptor pool size.
    VkDescriptorPoolSize descriptorPoolSize = {};
    descriptorPoolSize.type = descriptorSetInfo.descriptorType;
    descriptorPoolSize.descriptorCount = descriptorSetInfo.descriptorSize;
    descriptorPoolSizes.push_back(descriptorPoolSize);
  }

  VkDescriptorPoolCreateInfo descriptorPoolCreateInfo = {};
  descriptorPoolCreateInfo.sType =
      VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
  descriptorPoolCreateInfo.pNext = nullptr;
  descriptorPoolCreateInfo.flags = 0;
  descriptorPoolCreateInfo.maxSets = descriptorPoolSizes.size();
  descriptorPoolCreateInfo.poolSizeCount = descriptorPoolSizes.size();
  descriptorPoolCreateInfo.pPoolSizes = descriptorPoolSizes.data();
  RETURN_ON_VULKAN_ERROR(vkCreateDescriptorPool(device,
                                                &descriptorPoolCreateInfo, 0,
                                                &descriptorPool),
                         "vkCreateDescriptorPool");
  return success();
}

LogicalResult VulkanRuntime::allocateDescriptorSets() {
  VkDescriptorSetAllocateInfo descriptorSetAllocateInfo = {};
  // One descriptor set is allocated per descriptor set layout.
  descriptorSets.resize(descriptorSetLayouts.size());
  descriptorSetAllocateInfo.sType =
      VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
  descriptorSetAllocateInfo.pNext = nullptr;
  descriptorSetAllocateInfo.descriptorPool = descriptorPool;
  descriptorSetAllocateInfo.descriptorSetCount = descriptorSetLayouts.size();
  descriptorSetAllocateInfo.pSetLayouts = descriptorSetLayouts.data();
  RETURN_ON_VULKAN_ERROR(vkAllocateDescriptorSets(device,
                                                  &descriptorSetAllocateInfo,
                                                  descriptorSets.data()),
                         "vkAllocateDescriptorSets");
  return success();
}

LogicalResult VulkanRuntime::setWriteDescriptors() {
  if (descriptorSets.size() != descriptorSetInfoPool.size()) {
    std::cerr << "Each descriptor set must have descriptor set information";
    return failure();
  }
  // For each descriptor set.
  auto descriptorSetIt = descriptorSets.begin();
  // Each descriptor set is associated with descriptor set info.
  for (const auto &descriptorSetInfo : descriptorSetInfoPool) {
    // For each device memory buffer in the descriptor set.
    const auto &deviceMemoryBuffers =
        deviceMemoryBufferMap[descriptorSetInfo.descriptorSet];
    for (const auto &memoryBuffer : deviceMemoryBuffers) {
      // Structure describing descriptor sets to write to.
      VkWriteDescriptorSet wSet = {};
      wSet.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
      wSet.pNext = nullptr;
      // Descriptor set.
      wSet.dstSet = *descriptorSetIt;
      wSet.dstBinding = memoryBuffer.bindingIndex;
      wSet.dstArrayElement = 0;
      wSet.descriptorCount = 1;
      wSet.descriptorType = memoryBuffer.descriptorType;
      wSet.pImageInfo = nullptr;
      wSet.pBufferInfo = &memoryBuffer.bufferInfo;
      wSet.pTexelBufferView = nullptr;
      vkUpdateDescriptorSets(device, 1, &wSet, 0, nullptr);
    }
    // Increment descriptor set iterator.
    ++descriptorSetIt;
  }
  return success();
}

LogicalResult VulkanRuntime::createCommandPool() {
  VkCommandPoolCreateInfo commandPoolCreateInfo = {};
  commandPoolCreateInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
  commandPoolCreateInfo.pNext = nullptr;
  commandPoolCreateInfo.flags = 0;
  commandPoolCreateInfo.queueFamilyIndex = queueFamilyIndex;
  RETURN_ON_VULKAN_ERROR(vkCreateCommandPool(device, &commandPoolCreateInfo,
                                             /*pAllocator=*/nullptr,
                                             &commandPool),
                         "vkCreateCommandPool");
  return success();
}

LogicalResult VulkanRuntime::createQueryPool() {
  // Return directly if timestamp query is not supported.
  if (queueFamilyProperties.timestampValidBits == 0)
    return success();

  // Get timestamp period for this physical device.
  VkPhysicalDeviceProperties deviceProperties = {};
  vkGetPhysicalDeviceProperties(physicalDevice, &deviceProperties);
  timestampPeriod = deviceProperties.limits.timestampPeriod;

  // Create query pool.
  VkQueryPoolCreateInfo queryPoolCreateInfo = {};
  queryPoolCreateInfo.sType = VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO;
  queryPoolCreateInfo.pNext = nullptr;
  queryPoolCreateInfo.flags = 0;
  queryPoolCreateInfo.queryType = VK_QUERY_TYPE_TIMESTAMP;
  queryPoolCreateInfo.queryCount = 2;
  queryPoolCreateInfo.pipelineStatistics = 0;
  RETURN_ON_VULKAN_ERROR(vkCreateQueryPool(device, &queryPoolCreateInfo,
                                           /*pAllocator=*/nullptr, &queryPool),
                         "vkCreateQueryPool");

  return success();
}

LogicalResult VulkanRuntime::createComputeCommandBuffer() {
  VkCommandBufferAllocateInfo commandBufferAllocateInfo = {};
  commandBufferAllocateInfo.sType =
      VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
  commandBufferAllocateInfo.pNext = nullptr;
  commandBufferAllocateInfo.commandPool = commandPool;
  commandBufferAllocateInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
  commandBufferAllocateInfo.commandBufferCount = 1;

  VkCommandBuffer commandBuffer;
  RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device,
                                                  &commandBufferAllocateInfo,
                                                  &commandBuffer),
                         "vkAllocateCommandBuffers");

  VkCommandBufferBeginInfo commandBufferBeginInfo = {};
  commandBufferBeginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
  commandBufferBeginInfo.pNext = nullptr;
  commandBufferBeginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
  commandBufferBeginInfo.pInheritanceInfo = nullptr;

  // Commands begin.
  RETURN_ON_VULKAN_ERROR(
      vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo),
      "vkBeginCommandBuffer");

  if (queryPool != VK_NULL_HANDLE)
    vkCmdResetQueryPool(commandBuffer, queryPool, 0, 2);

  vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
  vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
                          pipelineLayout, 0, descriptorSets.size(),
                          descriptorSets.data(), 0, 0);
  // Get a timestamp before invoking the compute shader.
  if (queryPool != VK_NULL_HANDLE)
    vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT,
                        queryPool, 0);
  vkCmdDispatch(commandBuffer, numWorkGroups.x, numWorkGroups.y,
                numWorkGroups.z);
  // Get another timestamp after invoking the compute shader.
  if (queryPool != VK_NULL_HANDLE)
    vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT,
                        queryPool, 1);

  // Commands end.
  RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer),
                         "vkEndCommandBuffer");

  commandBuffers.push_back(commandBuffer);
  return success();
}

LogicalResult VulkanRuntime::submitCommandBuffersToQueue() {
  VkSubmitInfo submitInfo = {};
  submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
  submitInfo.pNext = nullptr;
  submitInfo.waitSemaphoreCount = 0;
  submitInfo.pWaitSemaphores = 0;
  submitInfo.pWaitDstStageMask = 0;
  submitInfo.commandBufferCount = commandBuffers.size();
  submitInfo.pCommandBuffers = commandBuffers.data();
  submitInfo.signalSemaphoreCount = 0;
  submitInfo.pSignalSemaphores = nullptr;
  RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, 0),
                         "vkQueueSubmit");
  return success();
}

LogicalResult VulkanRuntime::updateHostMemoryBuffers() {
  // First copy the data back into the host-visible staging buffers.
  if (failed(copyResource(/*deviceToHost=*/true)))
    return failure();

  // For each descriptor set.
  for (auto &resourceDataMapPair : resourceData) {
    auto &resourceDataMap = resourceDataMapPair.second;
    auto &deviceMemoryBuffers =
        deviceMemoryBufferMap[resourceDataMapPair.first];
    // For each device memory buffer in the set.
    for (auto &deviceMemoryBuffer : deviceMemoryBuffers) {
      if (resourceDataMap.count(deviceMemoryBuffer.bindingIndex)) {
        void *payload;
        auto &hostMemoryBuffer =
            resourceDataMap[deviceMemoryBuffer.bindingIndex];
        RETURN_ON_VULKAN_ERROR(vkMapMemory(device,
                                           deviceMemoryBuffer.hostMemory, 0,
                                           hostMemoryBuffer.size, 0,
                                           reinterpret_cast<void **>(&payload)),
                               "vkMapMemory");
        std::memcpy(hostMemoryBuffer.ptr, payload, hostMemoryBuffer.size);
        vkUnmapMemory(device, deviceMemoryBuffer.hostMemory);
      }
    }
  }
  return success();
}