//===- VulkanRuntime.cpp - MLIR Vulkan runtime ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file provides a library for running a module on a Vulkan device.
// Implements a Vulkan runtime.
//
//===----------------------------------------------------------------------===//

#include "VulkanRuntime.h"

#include <chrono>
#include <cstring>
// TODO: It's generally bad to access stdout/stderr in a library.
// Figure out a better way for error reporting.
#include <iomanip>
#include <iostream>

inline void emitVulkanError(const char *api, VkResult error) {
  std::cerr << " failed with error code " << error << " when executing " << api;
}

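// Wraps a Vulkan API call: on any result other than VK_SUCCESS, the error is
// reported through emitVulkanError() and the enclosing function returns
// mlir::failure(). The macro can therefore only be used inside functions that
// return LogicalResult.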
#define RETURN_ON_VULKAN_ERROR(result, api)                                    \
  if ((result) != VK_SUCCESS) {                                                \
    emitVulkanError(api, (result));                                            \
    return failure();                                                          \
  }

using namespace mlir;
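
// Typical usage, as driven by callers of this library (a sketch only; the
// exact driver code lives outside this file, and the identifiers below are
// placeholders):
//   VulkanRuntime runtime;
//   runtime.setShaderModule(spirvBlob, spirvBlobSizeInBytes);
//   runtime.setEntryPoint("main");
//   runtime.setResourceData(resources);
//   runtime.setNumWorkGroups({x, y, z});
//   if (failed(runtime.initRuntime()) || failed(runtime.run()) ||
//       failed(runtime.updateHostMemoryBuffers()) || failed(runtime.destroy()))
//     /* handle the error */;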

void VulkanRuntime::setNumWorkGroups(const NumWorkGroups &numberWorkGroups) {
  numWorkGroups = numberWorkGroups;
}

void VulkanRuntime::setResourceStorageClassBindingMap(
    const ResourceStorageClassBindingMap &stClassData) {
  resourceStorageClassData = stClassData;
}

void VulkanRuntime::setResourceData(
    const DescriptorSetIndex desIndex, const BindingIndex bindIndex,
    const VulkanHostMemoryBuffer &hostMemBuffer) {
  resourceData[desIndex][bindIndex] = hostMemBuffer;
  resourceStorageClassData[desIndex][bindIndex] =
      SPIRVStorageClass::StorageBuffer;
}

void VulkanRuntime::setEntryPoint(const char *entryPointName) {
  entryPoint = entryPointName;
}

void VulkanRuntime::setResourceData(const ResourceData &resData) {
  resourceData = resData;
}

void VulkanRuntime::setShaderModule(uint8_t *shader, uint32_t size) {
  binary = shader;
  binarySize = size;
}

LogicalResult VulkanRuntime::mapStorageClassToDescriptorType(
    SPIRVStorageClass storageClass, VkDescriptorType &descriptorType) {
  switch (storageClass) {
  case SPIRVStorageClass::StorageBuffer:
    descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
    break;
  case SPIRVStorageClass::Uniform:
    descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
    break;
  }
  return success();
}

LogicalResult VulkanRuntime::mapStorageClassToBufferUsageFlag(
    SPIRVStorageClass storageClass, VkBufferUsageFlagBits &bufferUsage) {
  switch (storageClass) {
  case SPIRVStorageClass::StorageBuffer:
    bufferUsage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
    break;
  case SPIRVStorageClass::Uniform:
    bufferUsage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT;
    break;
  }
  return success();
}

LogicalResult VulkanRuntime::countDeviceMemorySize() {
  for (const auto &resourceDataMapPair : resourceData) {
    const auto &resourceDataMap = resourceDataMapPair.second;
    for (const auto &resourceDataBindingPair : resourceDataMap) {
      if (resourceDataBindingPair.second.size) {
        memorySize += resourceDataBindingPair.second.size;
      } else {
        std::cerr << "expected buffer size greater than zero for resource data";
        return failure();
      }
    }
  }
  return success();
}

LogicalResult VulkanRuntime::initRuntime() {
  if (!resourceData.size()) {
    std::cerr << "Vulkan runtime needs at least one resource";
    return failure();
  }
  if (!binarySize || !binary) {
    std::cerr << "binary shader size must be greater than zero";
    return failure();
  }
  if (failed(countDeviceMemorySize())) {
    return failure();
  }
  return success();
}

LogicalResult VulkanRuntime::destroy() {
  // According to Vulkan spec:
  // "To ensure that no work is active on the device, vkDeviceWaitIdle can be
  // used to gate the destruction of the device. Prior to destroying a device,
  // an application is responsible for destroying/freeing any Vulkan objects
  // that were created using that device as the first parameter of the
  // corresponding vkCreate* or vkAllocate* command."
  RETURN_ON_VULKAN_ERROR(vkDeviceWaitIdle(device), "vkDeviceWaitIdle");

  // Free and destroy.
  vkFreeCommandBuffers(device, commandPool, commandBuffers.size(),
                       commandBuffers.data());
  vkDestroyQueryPool(device, queryPool, nullptr);
  vkDestroyCommandPool(device, commandPool, nullptr);
  vkFreeDescriptorSets(device, descriptorPool, descriptorSets.size(),
                       descriptorSets.data());
  vkDestroyDescriptorPool(device, descriptorPool, nullptr);
  vkDestroyPipeline(device, pipeline, nullptr);
  vkDestroyPipelineLayout(device, pipelineLayout, nullptr);
  for (auto &descriptorSetLayout : descriptorSetLayouts) {
    vkDestroyDescriptorSetLayout(device, descriptorSetLayout, nullptr);
  }
  vkDestroyShaderModule(device, shaderModule, nullptr);

  // For each descriptor set.
  for (auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
    auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
    // For each descriptor binding.
    for (auto &memoryBuffer : deviceMemoryBuffers) {
      vkFreeMemory(device, memoryBuffer.deviceMemory, nullptr);
      vkFreeMemory(device, memoryBuffer.hostMemory, nullptr);
      vkDestroyBuffer(device, memoryBuffer.hostBuffer, nullptr);
      vkDestroyBuffer(device, memoryBuffer.deviceBuffer, nullptr);
    }
  }

  vkDestroyDevice(device, nullptr);
  vkDestroyInstance(instance, nullptr);
  return success();
}

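// Runs the whole pipeline: create the instance, logical device, memory
// buffers, and shader module; build the descriptor and pipeline state; record
// the compute command buffer; copy the inputs to the device; submit; and wait
// for completion. Timing information is printed to stdout.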
LogicalResult VulkanRuntime::run() {
  // Create logical device, shader module and memory buffers.
  if (failed(createInstance()) || failed(createDevice()) ||
      failed(createMemoryBuffers()) || failed(createShaderModule())) {
    return failure();
  }

  // Descriptor bindings are divided into sets. Each descriptor binding must
  // have a corresponding layout binding in a descriptor set layout, and each
  // descriptor set layout must be bound to the pipeline layout.
  initDescriptorSetLayoutBindingMap();
  if (failed(createDescriptorSetLayout()) || failed(createPipelineLayout()) ||
      // Each descriptor set must be allocated from a descriptor pool.
      failed(createComputePipeline()) || failed(createDescriptorPool()) ||
      failed(allocateDescriptorSets()) || failed(setWriteDescriptors()) ||
      // Create command buffer.
      failed(createCommandPool()) || failed(createQueryPool()) ||
      failed(createComputeCommandBuffer())) {
    return failure();
  }

  // Get the working queue.
  vkGetDeviceQueue(device, queueFamilyIndex, 0, &queue);

  if (failed(copyResource(/*deviceToHost=*/false)))
    return failure();

  auto submitStart = std::chrono::high_resolution_clock::now();
  // Submit the command buffer to the queue.
  if (failed(submitCommandBuffersToQueue()))
    return failure();
  auto submitEnd = std::chrono::high_resolution_clock::now();

  RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle");
  auto execEnd = std::chrono::high_resolution_clock::now();

  auto submitDuration = std::chrono::duration_cast<std::chrono::microseconds>(
      submitEnd - submitStart);
  auto execDuration = std::chrono::duration_cast<std::chrono::microseconds>(
      execEnd - submitEnd);

  if (queryPool != VK_NULL_HANDLE) {
    uint64_t timestamps[2];
    RETURN_ON_VULKAN_ERROR(
        vkGetQueryPoolResults(
            device, queryPool, /*firstQuery=*/0, /*queryCount=*/2,
            /*dataSize=*/sizeof(timestamps),
            /*pData=*/reinterpret_cast<void *>(timestamps),
            /*stride=*/sizeof(uint64_t),
            VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT),
        "vkGetQueryPoolResults");
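    // timestampPeriod is the number of nanoseconds it takes for a timestamp
    // query value to be incremented by 1; dividing by 1000 yields microseconds.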
    float microsec = (timestamps[1] - timestamps[0]) * timestampPeriod / 1000;
    std::cout << "Compute shader execution time: " << std::setprecision(3)
              << microsec << "us\n";
  }

  std::cout << "Command buffer submit time: " << submitDuration.count()
            << "us\nWait idle time: " << execDuration.count() << "us\n";

  return success();
}

LogicalResult VulkanRuntime::createInstance() {
  VkApplicationInfo applicationInfo = {};
  applicationInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
  applicationInfo.pNext = nullptr;
  applicationInfo.pApplicationName = "MLIR Vulkan runtime";
  applicationInfo.applicationVersion = 0;
  applicationInfo.pEngineName = "mlir";
  applicationInfo.engineVersion = 0;
  applicationInfo.apiVersion = VK_MAKE_VERSION(1, 0, 0);

  VkInstanceCreateInfo instanceCreateInfo = {};
  instanceCreateInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
  instanceCreateInfo.pNext = nullptr;
  instanceCreateInfo.flags = 0;
  instanceCreateInfo.pApplicationInfo = &applicationInfo;
  instanceCreateInfo.enabledLayerCount = 0;
  instanceCreateInfo.ppEnabledLayerNames = 0;
  instanceCreateInfo.enabledExtensionCount = 0;
  instanceCreateInfo.ppEnabledExtensionNames = 0;

  RETURN_ON_VULKAN_ERROR(vkCreateInstance(&instanceCreateInfo, 0, &instance),
                         "vkCreateInstance");
  return success();
}

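// Picks the first available physical device, selects a compute-capable queue
// family, creates the logical device, and caches the host-visible and
// device-local memory type indices that createMemoryBuffers() uses later.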
LogicalResult VulkanRuntime::createDevice() {
  uint32_t physicalDeviceCount = 0;
  RETURN_ON_VULKAN_ERROR(
      vkEnumeratePhysicalDevices(instance, &physicalDeviceCount, 0),
      "vkEnumeratePhysicalDevices");

  std::vector<VkPhysicalDevice> physicalDevices(physicalDeviceCount);
  RETURN_ON_VULKAN_ERROR(vkEnumeratePhysicalDevices(instance,
                                                    &physicalDeviceCount,
                                                    physicalDevices.data()),
                         "vkEnumeratePhysicalDevices");

  RETURN_ON_VULKAN_ERROR(physicalDeviceCount ? VK_SUCCESS : VK_INCOMPLETE,
                         "physicalDeviceCount");

  // TODO: find the best device.
  physicalDevice = physicalDevices.front();
  if (failed(getBestComputeQueue()))
    return failure();

  const float queuePriority = 1.0f;
  VkDeviceQueueCreateInfo deviceQueueCreateInfo = {};
  deviceQueueCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
  deviceQueueCreateInfo.pNext = nullptr;
  deviceQueueCreateInfo.flags = 0;
  deviceQueueCreateInfo.queueFamilyIndex = queueFamilyIndex;
  deviceQueueCreateInfo.queueCount = 1;
  deviceQueueCreateInfo.pQueuePriorities = &queuePriority;

  // Structure specifying parameters of a newly created device.
  VkDeviceCreateInfo deviceCreateInfo = {};
  deviceCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
  deviceCreateInfo.pNext = nullptr;
  deviceCreateInfo.flags = 0;
  deviceCreateInfo.queueCreateInfoCount = 1;
  deviceCreateInfo.pQueueCreateInfos = &deviceQueueCreateInfo;
  deviceCreateInfo.enabledLayerCount = 0;
  deviceCreateInfo.ppEnabledLayerNames = nullptr;
  deviceCreateInfo.enabledExtensionCount = 0;
  deviceCreateInfo.ppEnabledExtensionNames = nullptr;
  deviceCreateInfo.pEnabledFeatures = nullptr;

  RETURN_ON_VULKAN_ERROR(
      vkCreateDevice(physicalDevice, &deviceCreateInfo, 0, &device),
      "vkCreateDevice");

  VkPhysicalDeviceMemoryProperties properties = {};
  vkGetPhysicalDeviceMemoryProperties(physicalDevice, &properties);

  // Try to find a memory type with the following properties:
  // VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT specifies that memory allocated with
  // this type can be mapped for host access using vkMapMemory;
  // VK_MEMORY_PROPERTY_HOST_COHERENT_BIT specifies that the host cache
  // management commands vkFlushMappedMemoryRanges and
  // vkInvalidateMappedMemoryRanges are not needed to flush host writes to the
  // device or make device writes visible to the host, respectively.
  for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) {
    if ((VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT &
         properties.memoryTypes[i].propertyFlags) &&
        (VK_MEMORY_PROPERTY_HOST_COHERENT_BIT &
         properties.memoryTypes[i].propertyFlags) &&
        (memorySize <=
         properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) {
      hostMemoryTypeIndex = i;
      break;
    }
  }

  // Find a memory type with VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT to be used on
  // the device. This allows faster access from the GPU on devices with
  // dedicated memory.
  for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) {
    if ((VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT &
         properties.memoryTypes[i].propertyFlags) &&
        (memorySize <=
         properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) {
      deviceMemoryTypeIndex = i;
      break;
    }
  }

  RETURN_ON_VULKAN_ERROR((hostMemoryTypeIndex == VK_MAX_MEMORY_TYPES ||
                          deviceMemoryTypeIndex == VK_MAX_MEMORY_TYPES)
                             ? VK_INCOMPLETE
                             : VK_SUCCESS,
                         "invalid memoryTypeIndex");
  return success();
}

LogicalResult VulkanRuntime::getBestComputeQueue() {
  uint32_t queueFamilyPropertiesCount = 0;
  vkGetPhysicalDeviceQueueFamilyProperties(physicalDevice,
                                           &queueFamilyPropertiesCount, 0);

  std::vector<VkQueueFamilyProperties> familyProperties(
      queueFamilyPropertiesCount);
  vkGetPhysicalDeviceQueueFamilyProperties(
      physicalDevice, &queueFamilyPropertiesCount, familyProperties.data());

  // VK_QUEUE_COMPUTE_BIT specifies that queues in this queue family support
  // compute operations. Try to find a compute-only queue first if possible.
  for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
    auto flags = familyProperties[i].queueFlags;
    if ((flags & VK_QUEUE_COMPUTE_BIT) && !(flags & VK_QUEUE_GRAPHICS_BIT)) {
      queueFamilyIndex = i;
      queueFamilyProperties = familyProperties[i];
      return success();
    }
  }

  // Otherwise use a queue that can also support graphics.
  for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
    auto flags = familyProperties[i].queueFlags;
    if ((flags & VK_QUEUE_COMPUTE_BIT)) {
      queueFamilyIndex = i;
      queueFamilyProperties = familyProperties[i];
      return success();
    }
  }

  std::cerr << "cannot find valid queue";
  return failure();
}

LogicalResult VulkanRuntime::createMemoryBuffers() {
  // For each descriptor set.
  for (const auto &resourceDataMapPair : resourceData) {
    std::vector<VulkanDeviceMemoryBuffer> deviceMemoryBuffers;
    const auto descriptorSetIndex = resourceDataMapPair.first;
    const auto &resourceDataMap = resourceDataMapPair.second;

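    // Each binding gets two buffers: a host-visible staging buffer that the
    // host fills through vkMapMemory, and a device-local buffer that is bound
    // to the descriptor set and accessed by the shader. copyResource() moves
    // data between the two.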
    // For each descriptor binding.
    for (const auto &resourceDataBindingPair : resourceDataMap) {
      // Create device memory buffer.
      VulkanDeviceMemoryBuffer memoryBuffer;
      memoryBuffer.bindingIndex = resourceDataBindingPair.first;
      VkDescriptorType descriptorType = {};
      VkBufferUsageFlagBits bufferUsage = {};

      // Check that the descriptor set has a storage class map.
      const auto resourceStorageClassMapIt =
          resourceStorageClassData.find(descriptorSetIndex);
      if (resourceStorageClassMapIt == resourceStorageClassData.end()) {
        std::cerr
            << "cannot find storage class for resource in descriptor set: "
            << descriptorSetIndex;
        return failure();
      }

      // Check that the specific descriptor binding has a storage class.
      const auto &resourceStorageClassMap = resourceStorageClassMapIt->second;
      const auto resourceStorageClassIt =
          resourceStorageClassMap.find(resourceDataBindingPair.first);
      if (resourceStorageClassIt == resourceStorageClassMap.end()) {
        std::cerr
            << "cannot find storage class for resource with descriptor index: "
            << resourceDataBindingPair.first;
        return failure();
      }

      const auto resourceStorageClassBinding = resourceStorageClassIt->second;
      if (failed(mapStorageClassToDescriptorType(resourceStorageClassBinding,
                                                 descriptorType)) ||
          failed(mapStorageClassToBufferUsageFlag(resourceStorageClassBinding,
                                                  bufferUsage))) {
        std::cerr << "storage class for resource with descriptor binding: "
                  << resourceDataBindingPair.first
                  << " in the descriptor set: " << descriptorSetIndex
                  << " is not supported ";
        return failure();
      }

      // Set the descriptor type for this device memory buffer.
      memoryBuffer.descriptorType = descriptorType;
      const auto bufferSize = resourceDataBindingPair.second.size;
      memoryBuffer.bufferSize = bufferSize;
      // Specify memory allocation info.
      VkMemoryAllocateInfo memoryAllocateInfo = {};
      memoryAllocateInfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
      memoryAllocateInfo.pNext = nullptr;
      memoryAllocateInfo.allocationSize = bufferSize;
      memoryAllocateInfo.memoryTypeIndex = hostMemoryTypeIndex;

      // Allocate host-visible and device-local memory.
      RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo, 0,
                                              &memoryBuffer.hostMemory),
                             "vkAllocateMemory");
      memoryAllocateInfo.memoryTypeIndex = deviceMemoryTypeIndex;
      RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo, 0,
                                              &memoryBuffer.deviceMemory),
                             "vkAllocateMemory");
      void *payload;
      RETURN_ON_VULKAN_ERROR(vkMapMemory(device, memoryBuffer.hostMemory, 0,
                                         bufferSize, 0,
                                         reinterpret_cast<void **>(&payload)),
                             "vkMapMemory");

      // Copy host memory into the mapped area.
      std::memcpy(payload, resourceDataBindingPair.second.ptr, bufferSize);
      vkUnmapMemory(device, memoryBuffer.hostMemory);

      VkBufferCreateInfo bufferCreateInfo = {};
      bufferCreateInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
      bufferCreateInfo.pNext = nullptr;
      bufferCreateInfo.flags = 0;
      bufferCreateInfo.size = bufferSize;
      // The buffers are used both as shader resources and as source/destination
      // of the staging copies in copyResource(), so transfer usage is required.
      bufferCreateInfo.usage = bufferUsage | VK_BUFFER_USAGE_TRANSFER_SRC_BIT |
                               VK_BUFFER_USAGE_TRANSFER_DST_BIT;
      bufferCreateInfo.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
      bufferCreateInfo.queueFamilyIndexCount = 1;
      bufferCreateInfo.pQueueFamilyIndices = &queueFamilyIndex;
      RETURN_ON_VULKAN_ERROR(vkCreateBuffer(device, &bufferCreateInfo, 0,
                                            &memoryBuffer.hostBuffer),
                             "vkCreateBuffer");
      RETURN_ON_VULKAN_ERROR(vkCreateBuffer(device, &bufferCreateInfo, 0,
                                            &memoryBuffer.deviceBuffer),
                             "vkCreateBuffer");

      // Bind buffer and device memory.
      RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device, memoryBuffer.hostBuffer,
                                                memoryBuffer.hostMemory, 0),
                             "vkBindBufferMemory");
      RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device,
                                                memoryBuffer.deviceBuffer,
                                                memoryBuffer.deviceMemory, 0),
                             "vkBindBufferMemory");

      // Update buffer info.
      memoryBuffer.bufferInfo.buffer = memoryBuffer.deviceBuffer;
      memoryBuffer.bufferInfo.offset = 0;
      memoryBuffer.bufferInfo.range = VK_WHOLE_SIZE;
      deviceMemoryBuffers.push_back(memoryBuffer);
    }

    // Associate device memory buffers with a descriptor set.
    deviceMemoryBufferMap[descriptorSetIndex] = deviceMemoryBuffers;
  }
  return success();
}

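// Records and submits a one-shot command buffer that copies every binding's
// data between its host staging buffer and its device-local buffer; the
// `deviceToHost` flag selects the direction. The copy is synchronous: the
// function waits for the queue to become idle before freeing the command
// buffer.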
LogicalResult VulkanRuntime::copyResource(bool deviceToHost) {
  VkCommandBufferAllocateInfo commandBufferAllocateInfo = {
      VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO,
      NULL,
      commandPool,
      VK_COMMAND_BUFFER_LEVEL_PRIMARY,
      1,
  };
  VkCommandBuffer commandBuffer;
  RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device,
                                                  &commandBufferAllocateInfo,
                                                  &commandBuffer),
                         "vkAllocateCommandBuffers");

  VkCommandBufferBeginInfo commandBufferBeginInfo = {
      VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO,
      NULL,
      0,
      NULL,
  };
  RETURN_ON_VULKAN_ERROR(
      vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo),
      "vkBeginCommandBuffer");

  for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
    const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
    for (const auto &memBuffer : deviceMemoryBuffers) {
      VkBufferCopy copy = {0, 0, memBuffer.bufferSize};
      if (deviceToHost)
        vkCmdCopyBuffer(commandBuffer, memBuffer.deviceBuffer,
                        memBuffer.hostBuffer, 1, &copy);
      else
        vkCmdCopyBuffer(commandBuffer, memBuffer.hostBuffer,
                        memBuffer.deviceBuffer, 1, &copy);
    }
  }

  RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer),
                         "vkEndCommandBuffer");
  VkSubmitInfo submitInfo = {
      VK_STRUCTURE_TYPE_SUBMIT_INFO,
      NULL,
      0,
      NULL,
      NULL,
      1,
      &commandBuffer,
      0,
      NULL,
  };
  RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, VK_NULL_HANDLE),
                         "vkQueueSubmit");
  RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle");

  vkFreeCommandBuffers(device, commandPool, 1, &commandBuffer);
  return success();
}

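// Wraps the SPIR-V blob provided through setShaderModule() in a
// VkShaderModule. Vulkan consumes the code as 32-bit words, so the blob is
// assumed to be 4-byte aligned and its size a multiple of four bytes.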
LogicalResult VulkanRuntime::createShaderModule() {
  VkShaderModuleCreateInfo shaderModuleCreateInfo = {};
  shaderModuleCreateInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
  shaderModuleCreateInfo.pNext = nullptr;
  shaderModuleCreateInfo.flags = 0;
  // Set size in bytes.
  shaderModuleCreateInfo.codeSize = binarySize;
  // Set pointer to the binary shader.
  shaderModuleCreateInfo.pCode = reinterpret_cast<uint32_t *>(binary);
  RETURN_ON_VULKAN_ERROR(
      vkCreateShaderModule(device, &shaderModuleCreateInfo, 0, &shaderModule),
      "vkCreateShaderModule");
  return success();
}

void VulkanRuntime::initDescriptorSetLayoutBindingMap() {
  for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
    std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings;
    const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
    const auto descriptorSetIndex = deviceMemoryBufferMapPair.first;

    // Create a layout binding for each descriptor.
    for (const auto &memBuffer : deviceMemoryBuffers) {
      VkDescriptorSetLayoutBinding descriptorSetLayoutBinding = {};
      descriptorSetLayoutBinding.binding = memBuffer.bindingIndex;
      descriptorSetLayoutBinding.descriptorType = memBuffer.descriptorType;
      descriptorSetLayoutBinding.descriptorCount = 1;
      descriptorSetLayoutBinding.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
      descriptorSetLayoutBinding.pImmutableSamplers = 0;
      descriptorSetLayoutBindings.push_back(descriptorSetLayoutBinding);
    }
    descriptorSetLayoutBindingMap[descriptorSetIndex] =
        descriptorSetLayoutBindings;
  }
}

LogicalResult VulkanRuntime::createDescriptorSetLayout() {
  for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
    const auto descriptorSetIndex = deviceMemoryBufferMapPair.first;
    const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
    // Each descriptor in a descriptor set must have the same type.
    VkDescriptorType descriptorType =
        deviceMemoryBuffers.front().descriptorType;
    const uint32_t descriptorSize = deviceMemoryBuffers.size();
    const auto descriptorSetLayoutBindingIt =
        descriptorSetLayoutBindingMap.find(descriptorSetIndex);

    if (descriptorSetLayoutBindingIt == descriptorSetLayoutBindingMap.end()) {
      std::cerr << "cannot find layout bindings for the set with number: "
                << descriptorSetIndex;
      return failure();
    }

    const auto &descriptorSetLayoutBindings =
        descriptorSetLayoutBindingIt->second;
    // Create descriptor set layout.
    VkDescriptorSetLayout descriptorSetLayout = {};
    VkDescriptorSetLayoutCreateInfo descriptorSetLayoutCreateInfo = {};

    descriptorSetLayoutCreateInfo.sType =
        VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
    descriptorSetLayoutCreateInfo.pNext = nullptr;
    descriptorSetLayoutCreateInfo.flags = 0;
    // Number of descriptor bindings in the descriptor set layout.
    descriptorSetLayoutCreateInfo.bindingCount =
        descriptorSetLayoutBindings.size();
    descriptorSetLayoutCreateInfo.pBindings =
        descriptorSetLayoutBindings.data();
    RETURN_ON_VULKAN_ERROR(
        vkCreateDescriptorSetLayout(device, &descriptorSetLayoutCreateInfo, 0,
                                    &descriptorSetLayout),
        "vkCreateDescriptorSetLayout");

    descriptorSetLayouts.push_back(descriptorSetLayout);
    descriptorSetInfoPool.push_back(
        {descriptorSetIndex, descriptorSize, descriptorType});
  }
  return success();
}

LogicalResult VulkanRuntime::createPipelineLayout() {
  // Associate descriptor sets with a pipeline layout.
  VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo = {};
  pipelineLayoutCreateInfo.sType =
      VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
  pipelineLayoutCreateInfo.pNext = nullptr;
  pipelineLayoutCreateInfo.flags = 0;
  pipelineLayoutCreateInfo.setLayoutCount = descriptorSetLayouts.size();
  pipelineLayoutCreateInfo.pSetLayouts = descriptorSetLayouts.data();
  pipelineLayoutCreateInfo.pushConstantRangeCount = 0;
  pipelineLayoutCreateInfo.pPushConstantRanges = 0;
  RETURN_ON_VULKAN_ERROR(vkCreatePipelineLayout(device,
                                                &pipelineLayoutCreateInfo, 0,
                                                &pipelineLayout),
                         "vkCreatePipelineLayout");
  return success();
}

LogicalResult VulkanRuntime::createComputePipeline() {
  VkPipelineShaderStageCreateInfo stageInfo = {};
  stageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
  stageInfo.pNext = nullptr;
  stageInfo.flags = 0;
  stageInfo.stage = VK_SHADER_STAGE_COMPUTE_BIT;
  stageInfo.module = shaderModule;
  // Set entry point.
  stageInfo.pName = entryPoint;
  stageInfo.pSpecializationInfo = 0;

  VkComputePipelineCreateInfo computePipelineCreateInfo = {};
  computePipelineCreateInfo.sType =
      VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
  computePipelineCreateInfo.pNext = nullptr;
  computePipelineCreateInfo.flags = 0;
  computePipelineCreateInfo.stage = stageInfo;
  computePipelineCreateInfo.layout = pipelineLayout;
  computePipelineCreateInfo.basePipelineHandle = 0;
  computePipelineCreateInfo.basePipelineIndex = 0;
  RETURN_ON_VULKAN_ERROR(vkCreateComputePipelines(device, 0, 1,
                                                  &computePipelineCreateInfo, 0,
                                                  &pipeline),
                         "vkCreateComputePipelines");
  return success();
}

LogicalResult VulkanRuntime::createDescriptorPool() {
  std::vector<VkDescriptorPoolSize> descriptorPoolSizes;
  for (const auto &descriptorSetInfo : descriptorSetInfoPool) {
    // For each descriptor set populate descriptor pool size.
    VkDescriptorPoolSize descriptorPoolSize = {};
    descriptorPoolSize.type = descriptorSetInfo.descriptorType;
    descriptorPoolSize.descriptorCount = descriptorSetInfo.descriptorSize;
    descriptorPoolSizes.push_back(descriptorPoolSize);
  }

  VkDescriptorPoolCreateInfo descriptorPoolCreateInfo = {};
  descriptorPoolCreateInfo.sType =
      VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
  descriptorPoolCreateInfo.pNext = nullptr;
  // destroy() frees the descriptor sets individually via vkFreeDescriptorSets,
  // which the Vulkan spec only allows for pools created with this flag.
  descriptorPoolCreateInfo.flags =
      VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT;
  descriptorPoolCreateInfo.maxSets = descriptorPoolSizes.size();
  descriptorPoolCreateInfo.poolSizeCount = descriptorPoolSizes.size();
  descriptorPoolCreateInfo.pPoolSizes = descriptorPoolSizes.data();
  RETURN_ON_VULKAN_ERROR(vkCreateDescriptorPool(device,
                                                &descriptorPoolCreateInfo, 0,
                                                &descriptorPool),
                         "vkCreateDescriptorPool");
  return success();
}

LogicalResult VulkanRuntime::allocateDescriptorSets() {
  VkDescriptorSetAllocateInfo descriptorSetAllocateInfo = {};
  // The number of descriptor sets matches the number of descriptor set
  // layouts.
  descriptorSets.resize(descriptorSetLayouts.size());
  descriptorSetAllocateInfo.sType =
      VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
  descriptorSetAllocateInfo.pNext = nullptr;
  descriptorSetAllocateInfo.descriptorPool = descriptorPool;
  descriptorSetAllocateInfo.descriptorSetCount = descriptorSetLayouts.size();
  descriptorSetAllocateInfo.pSetLayouts = descriptorSetLayouts.data();
  RETURN_ON_VULKAN_ERROR(vkAllocateDescriptorSets(device,
                                                  &descriptorSetAllocateInfo,
                                                  descriptorSets.data()),
                         "vkAllocateDescriptorSets");
  return success();
}

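// Points every descriptor at its device-local buffer by writing the buffer's
// VkDescriptorBufferInfo into the matching (set, binding) slot. The
// descriptorSets vector and descriptorSetInfoPool were populated in the same
// order, so the two are walked in lockstep.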
LogicalResult VulkanRuntime::setWriteDescriptors() {
  if (descriptorSets.size() != descriptorSetInfoPool.size()) {
    std::cerr << "Each descriptor set must have descriptor set information";
    return failure();
  }
  // For each descriptor set.
  auto descriptorSetIt = descriptorSets.begin();
  // Each descriptor set is associated with descriptor set info.
  for (const auto &descriptorSetInfo : descriptorSetInfoPool) {
    // For each device memory buffer in the descriptor set.
    const auto &deviceMemoryBuffers =
        deviceMemoryBufferMap[descriptorSetInfo.descriptorSet];
    for (const auto &memoryBuffer : deviceMemoryBuffers) {
      // Structure describing descriptor sets to write to.
      VkWriteDescriptorSet wSet = {};
      wSet.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
      wSet.pNext = nullptr;
      // Descriptor set.
      wSet.dstSet = *descriptorSetIt;
      wSet.dstBinding = memoryBuffer.bindingIndex;
      wSet.dstArrayElement = 0;
      wSet.descriptorCount = 1;
      wSet.descriptorType = memoryBuffer.descriptorType;
      wSet.pImageInfo = nullptr;
      wSet.pBufferInfo = &memoryBuffer.bufferInfo;
      wSet.pTexelBufferView = nullptr;
      vkUpdateDescriptorSets(device, 1, &wSet, 0, nullptr);
    }
    // Increment descriptor set iterator.
    ++descriptorSetIt;
  }
  return success();
}

LogicalResult VulkanRuntime::createCommandPool() {
  VkCommandPoolCreateInfo commandPoolCreateInfo = {};
  commandPoolCreateInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
  commandPoolCreateInfo.pNext = nullptr;
  commandPoolCreateInfo.flags = 0;
  commandPoolCreateInfo.queueFamilyIndex = queueFamilyIndex;
  RETURN_ON_VULKAN_ERROR(vkCreateCommandPool(device, &commandPoolCreateInfo,
                                             /*pAllocator=*/nullptr,
                                             &commandPool),
                         "vkCreateCommandPool");
  return success();
}

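// Creates a two-entry timestamp query pool when the selected queue family
// supports timestamps (timestampValidBits != 0). createComputeCommandBuffer()
// writes one timestamp on each side of the dispatch, and run() converts the
// difference to microseconds using timestampPeriod.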
LogicalResult VulkanRuntime::createQueryPool() {
  // Return directly if timestamp query is not supported.
  if (queueFamilyProperties.timestampValidBits == 0)
    return success();

  // Get timestamp period for this physical device.
  VkPhysicalDeviceProperties deviceProperties = {};
  vkGetPhysicalDeviceProperties(physicalDevice, &deviceProperties);
  timestampPeriod = deviceProperties.limits.timestampPeriod;

  // Create query pool.
  VkQueryPoolCreateInfo queryPoolCreateInfo = {};
  queryPoolCreateInfo.sType = VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO;
  queryPoolCreateInfo.pNext = nullptr;
  queryPoolCreateInfo.flags = 0;
  queryPoolCreateInfo.queryType = VK_QUERY_TYPE_TIMESTAMP;
  queryPoolCreateInfo.queryCount = 2;
  queryPoolCreateInfo.pipelineStatistics = 0;
  RETURN_ON_VULKAN_ERROR(vkCreateQueryPool(device, &queryPoolCreateInfo,
                                           /*pAllocator=*/nullptr, &queryPool),
                         "vkCreateQueryPool");

  return success();
}

LogicalResult VulkanRuntime::createComputeCommandBuffer() {
  VkCommandBufferAllocateInfo commandBufferAllocateInfo = {};
  commandBufferAllocateInfo.sType =
      VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
  commandBufferAllocateInfo.pNext = nullptr;
  commandBufferAllocateInfo.commandPool = commandPool;
  commandBufferAllocateInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
  commandBufferAllocateInfo.commandBufferCount = 1;

  VkCommandBuffer commandBuffer;
  RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device,
                                                  &commandBufferAllocateInfo,
                                                  &commandBuffer),
                         "vkAllocateCommandBuffers");

  VkCommandBufferBeginInfo commandBufferBeginInfo = {};
  commandBufferBeginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
  commandBufferBeginInfo.pNext = nullptr;
  commandBufferBeginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
  commandBufferBeginInfo.pInheritanceInfo = nullptr;

  // Commands begin.
  RETURN_ON_VULKAN_ERROR(
      vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo),
      "vkBeginCommandBuffer");

  if (queryPool != VK_NULL_HANDLE)
    vkCmdResetQueryPool(commandBuffer, queryPool, 0, 2);

  vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
  vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
                          pipelineLayout, 0, descriptorSets.size(),
                          descriptorSets.data(), 0, 0);
  // Get a timestamp before invoking the compute shader.
  if (queryPool != VK_NULL_HANDLE)
    vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT,
                        queryPool, 0);
  vkCmdDispatch(commandBuffer, numWorkGroups.x, numWorkGroups.y,
                numWorkGroups.z);
  // Get another timestamp after invoking the compute shader.
  if (queryPool != VK_NULL_HANDLE)
    vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT,
                        queryPool, 1);

  // Commands end.
  RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer),
                         "vkEndCommandBuffer");

  commandBuffers.push_back(commandBuffer);
  return success();
}

LogicalResult VulkanRuntime::submitCommandBuffersToQueue() {
  VkSubmitInfo submitInfo = {};
  submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
  submitInfo.pNext = nullptr;
  submitInfo.waitSemaphoreCount = 0;
  submitInfo.pWaitSemaphores = 0;
  submitInfo.pWaitDstStageMask = 0;
  submitInfo.commandBufferCount = commandBuffers.size();
  submitInfo.pCommandBuffers = commandBuffers.data();
  submitInfo.signalSemaphoreCount = 0;
  submitInfo.pSignalSemaphores = nullptr;
  RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, 0),
                         "vkQueueSubmit");
  return success();
}

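// Copies the results back to the caller: a device-to-host copyResource() moves
// the device-local buffers into the host staging buffers, which are then
// mapped and memcpy'd into the user-provided host memory buffers.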
LogicalResult VulkanRuntime::updateHostMemoryBuffers() {
  // First copy the data back into the staging buffers.
  if (failed(copyResource(/*deviceToHost=*/true)))
    return failure();

  // For each descriptor set.
  for (auto &resourceDataMapPair : resourceData) {
    auto &resourceDataMap = resourceDataMapPair.second;
    auto &deviceMemoryBuffers =
        deviceMemoryBufferMap[resourceDataMapPair.first];
    // For each device memory buffer in the set.
    for (auto &deviceMemoryBuffer : deviceMemoryBuffers) {
      if (resourceDataMap.count(deviceMemoryBuffer.bindingIndex)) {
        void *payload;
        auto &hostMemoryBuffer =
            resourceDataMap[deviceMemoryBuffer.bindingIndex];
        RETURN_ON_VULKAN_ERROR(vkMapMemory(device,
                                           deviceMemoryBuffer.hostMemory, 0,
                                           hostMemoryBuffer.size, 0,
                                           reinterpret_cast<void **>(&payload)),
                               "vkMapMemory");
        std::memcpy(hostMemoryBuffer.ptr, payload, hostMemoryBuffer.size);
        vkUnmapMemory(device, deviceMemoryBuffer.hostMemory);
      }
    }
  }
  return success();
}