//===- VulkanRuntime.cpp - MLIR Vulkan runtime ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file provides a library for running a module on a Vulkan device.
// Implements a Vulkan runtime.
//
//===----------------------------------------------------------------------===//

#include "VulkanRuntime.h"

#include <chrono>
#include <cstring>
// TODO: It's generally bad to access stdout/stderr in a library.
// Figure out a better way for error reporting.
#include <iomanip>
#include <iostream>

emitVulkanError(const char * api,VkResult error)23 inline void emitVulkanError(const char *api, VkResult error) {
24 std::cerr << " failed with error code " << error << " when executing " << api;
25 }
26
// Evaluates `result` exactly once; on any status other than VK_SUCCESS,
// reports the error for `api` and returns failure() from the enclosing
// function.
//
// The single evaluation matters: callers pass live Vulkan API calls as
// `result` (e.g. vkDeviceWaitIdle(device)), and the previous form expanded
// the argument twice, re-executing the API call on the failure path and
// reporting the status of the *second* call. The do/while(0) wrapper makes
// the macro a single statement so it composes safely with unbraced if/else.
#define RETURN_ON_VULKAN_ERROR(result, api)                                    \
  do {                                                                         \
    VkResult vulkanApiResult = (result);                                       \
    if (vulkanApiResult != VK_SUCCESS) {                                       \
      emitVulkanError(api, vulkanApiResult);                                   \
      return failure();                                                        \
    }                                                                          \
  } while (false)
using namespace mlir;

/// Sets the number of work groups (x, y, z) used by the vkCmdDispatch call.
void VulkanRuntime::setNumWorkGroups(const NumWorkGroups &numberWorkGroups) {
  numWorkGroups = numberWorkGroups;
}
38
/// Registers the SPIR-V storage class for each (descriptor set, binding)
/// pair; consulted by createMemoryBuffers() to pick descriptor types and
/// buffer usage flags.
void VulkanRuntime::setResourceStorageClassBindingMap(
    const ResourceStorageClassBindingMap &stClassData) {
  resourceStorageClassData = stClassData;
}
43
/// Binds `hostMemBuffer` at (desIndex, bindIndex) and records its storage
/// class as StorageBuffer.
void VulkanRuntime::setResourceData(
    const DescriptorSetIndex desIndex, const BindingIndex bindIndex,
    const VulkanHostMemoryBuffer &hostMemBuffer) {
  resourceData[desIndex][bindIndex] = hostMemBuffer;
  // Resources registered through this overload default to the StorageBuffer
  // storage class; setResourceStorageClassBindingMap can replace the map.
  resourceStorageClassData[desIndex][bindIndex] =
      SPIRVStorageClass::StorageBuffer;
}
51
/// Sets the shader entry point name passed to the compute pipeline stage.
/// NOTE(review): only the pointer is stored — the caller must keep the
/// string alive until run() completes.
void VulkanRuntime::setEntryPoint(const char *entryPointName) {
  entryPoint = entryPointName;
}
55
/// Replaces the entire resource data map at once (does not touch the
/// storage class map, unlike the per-binding overload).
void VulkanRuntime::setResourceData(const ResourceData &resData) {
  resourceData = resData;
}
59
/// Sets the SPIR-V binary blob to execute; `size` is in bytes.
/// NOTE(review): only the pointer is stored — the caller owns the memory.
void VulkanRuntime::setShaderModule(uint8_t *shader, uint32_t size) {
  binary = shader;
  binarySize = size;
}
64
mapStorageClassToDescriptorType(SPIRVStorageClass storageClass,VkDescriptorType & descriptorType)65 LogicalResult VulkanRuntime::mapStorageClassToDescriptorType(
66 SPIRVStorageClass storageClass, VkDescriptorType &descriptorType) {
67 switch (storageClass) {
68 case SPIRVStorageClass::StorageBuffer:
69 descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
70 break;
71 case SPIRVStorageClass::Uniform:
72 descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
73 break;
74 }
75 return success();
76 }
77
mapStorageClassToBufferUsageFlag(SPIRVStorageClass storageClass,VkBufferUsageFlagBits & bufferUsage)78 LogicalResult VulkanRuntime::mapStorageClassToBufferUsageFlag(
79 SPIRVStorageClass storageClass, VkBufferUsageFlagBits &bufferUsage) {
80 switch (storageClass) {
81 case SPIRVStorageClass::StorageBuffer:
82 bufferUsage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
83 break;
84 case SPIRVStorageClass::Uniform:
85 bufferUsage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT;
86 break;
87 }
88 return success();
89 }
90
countDeviceMemorySize()91 LogicalResult VulkanRuntime::countDeviceMemorySize() {
92 for (const auto &resourceDataMapPair : resourceData) {
93 const auto &resourceDataMap = resourceDataMapPair.second;
94 for (const auto &resourceDataBindingPair : resourceDataMap) {
95 if (resourceDataBindingPair.second.size) {
96 memorySize += resourceDataBindingPair.second.size;
97 } else {
98 std::cerr << "expected buffer size greater than zero for resource data";
99 return failure();
100 }
101 }
102 }
103 return success();
104 }
105
initRuntime()106 LogicalResult VulkanRuntime::initRuntime() {
107 if (!resourceData.size()) {
108 std::cerr << "Vulkan runtime needs at least one resource";
109 return failure();
110 }
111 if (!binarySize || !binary) {
112 std::cerr << "binary shader size must be greater than zero";
113 return failure();
114 }
115 if (failed(countDeviceMemorySize())) {
116 return failure();
117 }
118 return success();
119 }
120
/// Tears down every Vulkan object created by run(), in dependency order,
/// then the logical device and finally the instance.
LogicalResult VulkanRuntime::destroy() {
  // According to Vulkan spec:
  // "To ensure that no work is active on the device, vkDeviceWaitIdle can be
  // used to gate the destruction of the device. Prior to destroying a device,
  // an application is responsible for destroying/freeing any Vulkan objects
  // that were created using that device as the first parameter of the
  // corresponding vkCreate* or vkAllocate* command."
  RETURN_ON_VULKAN_ERROR(vkDeviceWaitIdle(device), "vkDeviceWaitIdle");

  // Free and destroy.
  vkFreeCommandBuffers(device, commandPool, commandBuffers.size(),
                       commandBuffers.data());
  vkDestroyQueryPool(device, queryPool, nullptr);
  vkDestroyCommandPool(device, commandPool, nullptr);
  vkFreeDescriptorSets(device, descriptorPool, descriptorSets.size(),
                       descriptorSets.data());
  vkDestroyDescriptorPool(device, descriptorPool, nullptr);
  vkDestroyPipeline(device, pipeline, nullptr);
  vkDestroyPipelineLayout(device, pipelineLayout, nullptr);
  for (auto &descriptorSetLayout : descriptorSetLayouts) {
    vkDestroyDescriptorSetLayout(device, descriptorSetLayout, nullptr);
  }
  vkDestroyShaderModule(device, shaderModule, nullptr);

  // For each descriptor set.
  for (auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
    auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
    // For each descriptor binding: release both the staging (host) and the
    // device-local memory/buffer pairs created in createMemoryBuffers().
    for (auto &memoryBuffer : deviceMemoryBuffers) {
      vkFreeMemory(device, memoryBuffer.deviceMemory, nullptr);
      vkFreeMemory(device, memoryBuffer.hostMemory, nullptr);
      vkDestroyBuffer(device, memoryBuffer.hostBuffer, nullptr);
      vkDestroyBuffer(device, memoryBuffer.deviceBuffer, nullptr);
    }
  }

  // The device must outlive all the objects freed above; the instance
  // outlives the device.
  vkDestroyDevice(device, nullptr);
  vkDestroyInstance(instance, nullptr);
  return success();
}
161
/// Runs the shader end to end: creates all Vulkan objects, uploads resource
/// data, submits the compute dispatch, waits for completion, and prints
/// timing information to stdout.
LogicalResult VulkanRuntime::run() {
  // Create logical device, shader module and memory buffers.
  if (failed(createInstance()) || failed(createDevice()) ||
      failed(createMemoryBuffers()) || failed(createShaderModule())) {
    return failure();
  }

  // Descriptor bindings divided into sets. Each descriptor binding
  // must have a layout binding attached into a descriptor set layout.
  // Each layout set must be binded into a pipeline layout.
  initDescriptorSetLayoutBindingMap();
  if (failed(createDescriptorSetLayout()) || failed(createPipelineLayout()) ||
      // Each descriptor set must be allocated from a descriptor pool.
      failed(createComputePipeline()) || failed(createDescriptorPool()) ||
      failed(allocateDescriptorSets()) || failed(setWriteDescriptors()) ||
      // Create command buffer.
      failed(createCommandPool()) || failed(createQueryPool()) ||
      failed(createComputeCommandBuffer())) {
    return failure();
  }

  // Get working queue.
  vkGetDeviceQueue(device, queueFamilyIndex, 0, &queue);

  // Upload host resource data into the device-local buffers before running.
  if (failed(copyResource(/*deviceToHost=*/false)))
    return failure();

  auto submitStart = std::chrono::high_resolution_clock::now();
  // Submit command buffer into the queue.
  if (failed(submitCommandBuffersToQueue()))
    return failure();
  auto submitEnd = std::chrono::high_resolution_clock::now();

  RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle");
  auto execEnd = std::chrono::high_resolution_clock::now();

  // Host-side timings: time spent inside vkQueueSubmit, and time spent
  // waiting for the queue to go idle afterwards.
  auto submitDuration = std::chrono::duration_cast<std::chrono::microseconds>(
      submitEnd - submitStart);
  auto execDuration = std::chrono::duration_cast<std::chrono::microseconds>(
      execEnd - submitEnd);

  // Device-side timing: read back the two timestamps written around the
  // dispatch (only when createQueryPool() actually created a pool).
  if (queryPool != VK_NULL_HANDLE) {
    uint64_t timestamps[2];
    RETURN_ON_VULKAN_ERROR(
        vkGetQueryPoolResults(
            device, queryPool, /*firstQuery=*/0, /*queryCount=*/2,
            /*dataSize=*/sizeof(timestamps),
            /*pData=*/reinterpret_cast<void *>(timestamps),
            /*stride=*/sizeof(uint64_t),
            VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT),
        "vkGetQueryPoolResults");
    // Scale the tick delta by timestampPeriod, then /1000 to report
    // microseconds.
    float microsec = (timestamps[1] - timestamps[0]) * timestampPeriod / 1000;
    std::cout << "Compute shader execution time: " << std::setprecision(3)
              << microsec << "us\n";
  }

  std::cout << "Command buffer submit time: " << submitDuration.count()
            << "us\nWait idle time: " << execDuration.count() << "us\n";

  return success();
}
223
createInstance()224 LogicalResult VulkanRuntime::createInstance() {
225 VkApplicationInfo applicationInfo = {};
226 applicationInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
227 applicationInfo.pNext = nullptr;
228 applicationInfo.pApplicationName = "MLIR Vulkan runtime";
229 applicationInfo.applicationVersion = 0;
230 applicationInfo.pEngineName = "mlir";
231 applicationInfo.engineVersion = 0;
232 applicationInfo.apiVersion = VK_MAKE_VERSION(1, 0, 0);
233
234 VkInstanceCreateInfo instanceCreateInfo = {};
235 instanceCreateInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
236 instanceCreateInfo.pNext = nullptr;
237 instanceCreateInfo.flags = 0;
238 instanceCreateInfo.pApplicationInfo = &applicationInfo;
239 instanceCreateInfo.enabledLayerCount = 0;
240 instanceCreateInfo.ppEnabledLayerNames = 0;
241 instanceCreateInfo.enabledExtensionCount = 0;
242 instanceCreateInfo.ppEnabledExtensionNames = 0;
243
244 RETURN_ON_VULKAN_ERROR(vkCreateInstance(&instanceCreateInfo, 0, &instance),
245 "vkCreateInstance");
246 return success();
247 }
248
/// Enumerates physical devices, picks the first one, creates a logical
/// device with a single compute queue, and selects host-visible and
/// device-local memory type indices whose heaps can hold `memorySize`.
LogicalResult VulkanRuntime::createDevice() {
  // Query the device count first, then fetch the handles.
  uint32_t physicalDeviceCount = 0;
  RETURN_ON_VULKAN_ERROR(
      vkEnumeratePhysicalDevices(instance, &physicalDeviceCount, 0),
      "vkEnumeratePhysicalDevices");

  std::vector<VkPhysicalDevice> physicalDevices(physicalDeviceCount);
  RETURN_ON_VULKAN_ERROR(vkEnumeratePhysicalDevices(instance,
                                                    &physicalDeviceCount,
                                                    physicalDevices.data()),
                         "vkEnumeratePhysicalDevices");

  // Fail if no Vulkan-capable device is present.
  RETURN_ON_VULKAN_ERROR(physicalDeviceCount ? VK_SUCCESS : VK_INCOMPLETE,
                         "physicalDeviceCount");

  // TODO: find the best device.
  physicalDevice = physicalDevices.front();
  if (failed(getBestComputeQueue()))
    return failure();

  // Request one queue from the compute queue family picked above.
  const float queuePriority = 1.0f;
  VkDeviceQueueCreateInfo deviceQueueCreateInfo = {};
  deviceQueueCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
  deviceQueueCreateInfo.pNext = nullptr;
  deviceQueueCreateInfo.flags = 0;
  deviceQueueCreateInfo.queueFamilyIndex = queueFamilyIndex;
  deviceQueueCreateInfo.queueCount = 1;
  deviceQueueCreateInfo.pQueuePriorities = &queuePriority;

  // Structure specifying parameters of a newly created device.
  VkDeviceCreateInfo deviceCreateInfo = {};
  deviceCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
  deviceCreateInfo.pNext = nullptr;
  deviceCreateInfo.flags = 0;
  deviceCreateInfo.queueCreateInfoCount = 1;
  deviceCreateInfo.pQueueCreateInfos = &deviceQueueCreateInfo;
  deviceCreateInfo.enabledLayerCount = 0;
  deviceCreateInfo.ppEnabledLayerNames = nullptr;
  deviceCreateInfo.enabledExtensionCount = 0;
  deviceCreateInfo.ppEnabledExtensionNames = nullptr;
  deviceCreateInfo.pEnabledFeatures = nullptr;

  RETURN_ON_VULKAN_ERROR(
      vkCreateDevice(physicalDevice, &deviceCreateInfo, 0, &device),
      "vkCreateDevice");

  VkPhysicalDeviceMemoryProperties properties = {};
  vkGetPhysicalDeviceMemoryProperties(physicalDevice, &properties);

  // Try to find memory type with following properties:
  // VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT bit specifies that memory allocated
  // with this type can be mapped for host access using vkMapMemory;
  // VK_MEMORY_PROPERTY_HOST_COHERENT_BIT bit specifies that the host cache
  // management commands vkFlushMappedMemoryRanges and
  // vkInvalidateMappedMemoryRanges are not needed to flush host writes to the
  // device or make device writes visible to the host, respectively.
  for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) {
    if ((VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT &
         properties.memoryTypes[i].propertyFlags) &&
        (VK_MEMORY_PROPERTY_HOST_COHERENT_BIT &
         properties.memoryTypes[i].propertyFlags) &&
        (memorySize <=
         properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) {
      hostMemoryTypeIndex = i;
      break;
    }
  }

  // Find a memory type with VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT to be
  // used on the device. This will allow better performance access for GPU with
  // on device memory.
  for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) {
    if ((VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT &
         properties.memoryTypes[i].propertyFlags) &&
        (memorySize <=
         properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) {
      deviceMemoryTypeIndex = i;
      break;
    }
  }

  // Fail if either search left its index at the VK_MAX_MEMORY_TYPES sentinel.
  RETURN_ON_VULKAN_ERROR((hostMemoryTypeIndex == VK_MAX_MEMORY_TYPES ||
                          deviceMemoryTypeIndex == VK_MAX_MEMORY_TYPES)
                             ? VK_INCOMPLETE
                             : VK_SUCCESS,
                         "invalid memoryTypeIndex");
  return success();
}
337
getBestComputeQueue()338 LogicalResult VulkanRuntime::getBestComputeQueue() {
339 uint32_t queueFamilyPropertiesCount = 0;
340 vkGetPhysicalDeviceQueueFamilyProperties(physicalDevice,
341 &queueFamilyPropertiesCount, 0);
342
343 std::vector<VkQueueFamilyProperties> familyProperties(
344 queueFamilyPropertiesCount);
345 vkGetPhysicalDeviceQueueFamilyProperties(
346 physicalDevice, &queueFamilyPropertiesCount, familyProperties.data());
347
348 // VK_QUEUE_COMPUTE_BIT specifies that queues in this queue family support
349 // compute operations. Try to find a compute-only queue first if possible.
350 for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
351 auto flags = familyProperties[i].queueFlags;
352 if ((flags & VK_QUEUE_COMPUTE_BIT) && !(flags & VK_QUEUE_GRAPHICS_BIT)) {
353 queueFamilyIndex = i;
354 queueFamilyProperties = familyProperties[i];
355 return success();
356 }
357 }
358
359 // Otherwise use a queue that can also support graphics.
360 for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
361 auto flags = familyProperties[i].queueFlags;
362 if ((flags & VK_QUEUE_COMPUTE_BIT)) {
363 queueFamilyIndex = i;
364 queueFamilyProperties = familyProperties[i];
365 return success();
366 }
367 }
368
369 std::cerr << "cannot find valid queue";
370 return failure();
371 }
372
/// For every registered resource, allocates a host-visible (staging) and a
/// device-local memory block, copies the host data into the staging memory,
/// creates a buffer over each block, and records everything in
/// `deviceMemoryBufferMap` keyed by descriptor set index.
LogicalResult VulkanRuntime::createMemoryBuffers() {
  // For each descriptor set.
  for (const auto &resourceDataMapPair : resourceData) {
    std::vector<VulkanDeviceMemoryBuffer> deviceMemoryBuffers;
    const auto descriptorSetIndex = resourceDataMapPair.first;
    const auto &resourceDataMap = resourceDataMapPair.second;

    // For each descriptor binding.
    for (const auto &resourceDataBindingPair : resourceDataMap) {
      // Create device memory buffer.
      VulkanDeviceMemoryBuffer memoryBuffer;
      memoryBuffer.bindingIndex = resourceDataBindingPair.first;
      VkDescriptorType descriptorType = {};
      VkBufferUsageFlagBits bufferUsage = {};

      // Check that descriptor set has storage class map.
      const auto resourceStorageClassMapIt =
          resourceStorageClassData.find(descriptorSetIndex);
      if (resourceStorageClassMapIt == resourceStorageClassData.end()) {
        std::cerr
            << "cannot find storage class for resource in descriptor set: "
            << descriptorSetIndex;
        return failure();
      }

      // Check that specific descriptor binding has storage class.
      const auto &resourceStorageClassMap = resourceStorageClassMapIt->second;
      const auto resourceStorageClassIt =
          resourceStorageClassMap.find(resourceDataBindingPair.first);
      if (resourceStorageClassIt == resourceStorageClassMap.end()) {
        std::cerr
            << "cannot find storage class for resource with descriptor index: "
            << resourceDataBindingPair.first;
        return failure();
      }

      // Map the storage class onto a descriptor type and usage flags; both
      // mappings fail together for unsupported classes.
      const auto resourceStorageClassBinding = resourceStorageClassIt->second;
      if (failed(mapStorageClassToDescriptorType(resourceStorageClassBinding,
                                                 descriptorType)) ||
          failed(mapStorageClassToBufferUsageFlag(resourceStorageClassBinding,
                                                  bufferUsage))) {
        std::cerr << "storage class for resource with descriptor binding: "
                  << resourceDataBindingPair.first
                  << " in the descriptor set: " << descriptorSetIndex
                  << " is not supported ";
        return failure();
      }

      // Set descriptor type for the specific device memory buffer.
      memoryBuffer.descriptorType = descriptorType;
      const auto bufferSize = resourceDataBindingPair.second.size;
      memoryBuffer.bufferSize = bufferSize;
      // Specify memory allocation info.
      // NOTE(review): allocation size is taken directly from the resource
      // size, without consulting vkGetBufferMemoryRequirements — confirm this
      // satisfies the implementation's size/alignment requirements.
      VkMemoryAllocateInfo memoryAllocateInfo = {};
      memoryAllocateInfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
      memoryAllocateInfo.pNext = nullptr;
      memoryAllocateInfo.allocationSize = bufferSize;
      memoryAllocateInfo.memoryTypeIndex = hostMemoryTypeIndex;

      // Allocate host-visible memory, then reuse the same info (with the
      // device-local type index) for the device allocation.
      RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo, 0,
                                              &memoryBuffer.hostMemory),
                             "vkAllocateMemory");
      memoryAllocateInfo.memoryTypeIndex = deviceMemoryTypeIndex;
      RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo, 0,
                                              &memoryBuffer.deviceMemory),
                             "vkAllocateMemory");
      void *payload;
      RETURN_ON_VULKAN_ERROR(vkMapMemory(device, memoryBuffer.hostMemory, 0,
                                         bufferSize, 0,
                                         reinterpret_cast<void **>(&payload)),
                             "vkMapMemory");

      // Copy host memory into the mapped area.
      std::memcpy(payload, resourceDataBindingPair.second.ptr, bufferSize);
      vkUnmapMemory(device, memoryBuffer.hostMemory);

      // One create-info serves both buffers: same size, usage, and exclusive
      // sharing with the compute queue family.
      VkBufferCreateInfo bufferCreateInfo = {};
      bufferCreateInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
      bufferCreateInfo.pNext = nullptr;
      bufferCreateInfo.flags = 0;
      bufferCreateInfo.size = bufferSize;
      bufferCreateInfo.usage = bufferUsage;
      bufferCreateInfo.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
      bufferCreateInfo.queueFamilyIndexCount = 1;
      bufferCreateInfo.pQueueFamilyIndices = &queueFamilyIndex;
      RETURN_ON_VULKAN_ERROR(vkCreateBuffer(device, &bufferCreateInfo, 0,
                                            &memoryBuffer.hostBuffer),
                             "vkCreateBuffer");
      RETURN_ON_VULKAN_ERROR(vkCreateBuffer(device, &bufferCreateInfo, 0,
                                            &memoryBuffer.deviceBuffer),
                             "vkCreateBuffer");

      // Bind buffer and device memory.
      RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device, memoryBuffer.hostBuffer,
                                                memoryBuffer.hostMemory, 0),
                             "vkBindBufferMemory");
      RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device,
                                                memoryBuffer.deviceBuffer,
                                                memoryBuffer.deviceMemory, 0),
                             "vkBindBufferMemory");

      // Update buffer info: descriptors reference the device-local buffer.
      memoryBuffer.bufferInfo.buffer = memoryBuffer.deviceBuffer;
      memoryBuffer.bufferInfo.offset = 0;
      memoryBuffer.bufferInfo.range = VK_WHOLE_SIZE;
      deviceMemoryBuffers.push_back(memoryBuffer);
    }

    // Associate device memory buffers with a descriptor set.
    deviceMemoryBufferMap[descriptorSetIndex] = deviceMemoryBuffers;
  }
  return success();
}
487
copyResource(bool deviceToHost)488 LogicalResult VulkanRuntime::copyResource(bool deviceToHost) {
489 VkCommandBufferAllocateInfo commandBufferAllocateInfo = {
490 VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO,
491 NULL,
492 commandPool,
493 VK_COMMAND_BUFFER_LEVEL_PRIMARY,
494 1,
495 };
496 VkCommandBuffer commandBuffer;
497 RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device,
498 &commandBufferAllocateInfo,
499 &commandBuffer),
500 "vkAllocateCommandBuffers");
501
502 VkCommandBufferBeginInfo commandBufferBeginInfo = {
503 VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO,
504 NULL,
505 0,
506 NULL,
507 };
508 RETURN_ON_VULKAN_ERROR(
509 vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo),
510 "vkBeginCommandBuffer");
511
512 for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
513 std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings;
514 const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
515 for (const auto &memBuffer : deviceMemoryBuffers) {
516 VkBufferCopy copy = {0, 0, memBuffer.bufferSize};
517 if (deviceToHost)
518 vkCmdCopyBuffer(commandBuffer, memBuffer.deviceBuffer,
519 memBuffer.hostBuffer, 1, ©);
520 else
521 vkCmdCopyBuffer(commandBuffer, memBuffer.hostBuffer,
522 memBuffer.deviceBuffer, 1, ©);
523 }
524 }
525
526 RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer),
527 "vkEndCommandBuffer");
528 VkSubmitInfo submitInfo = {
529 VK_STRUCTURE_TYPE_SUBMIT_INFO,
530 NULL,
531 0,
532 NULL,
533 NULL,
534 1,
535 &commandBuffer,
536 0,
537 NULL,
538 };
539 submitInfo.pCommandBuffers = &commandBuffer;
540 RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, VK_NULL_HANDLE),
541 "vkQueueSubmit");
542 RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle");
543
544 vkFreeCommandBuffers(device, commandPool, 1, &commandBuffer);
545 return success();
546 }
547
createShaderModule()548 LogicalResult VulkanRuntime::createShaderModule() {
549 VkShaderModuleCreateInfo shaderModuleCreateInfo = {};
550 shaderModuleCreateInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
551 shaderModuleCreateInfo.pNext = nullptr;
552 shaderModuleCreateInfo.flags = 0;
553 // Set size in bytes.
554 shaderModuleCreateInfo.codeSize = binarySize;
555 // Set pointer to the binary shader.
556 shaderModuleCreateInfo.pCode = reinterpret_cast<uint32_t *>(binary);
557 RETURN_ON_VULKAN_ERROR(
558 vkCreateShaderModule(device, &shaderModuleCreateInfo, 0, &shaderModule),
559 "vkCreateShaderModule");
560 return success();
561 }
562
initDescriptorSetLayoutBindingMap()563 void VulkanRuntime::initDescriptorSetLayoutBindingMap() {
564 for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
565 std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings;
566 const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
567 const auto descriptorSetIndex = deviceMemoryBufferMapPair.first;
568
569 // Create a layout binding for each descriptor.
570 for (const auto &memBuffer : deviceMemoryBuffers) {
571 VkDescriptorSetLayoutBinding descriptorSetLayoutBinding = {};
572 descriptorSetLayoutBinding.binding = memBuffer.bindingIndex;
573 descriptorSetLayoutBinding.descriptorType = memBuffer.descriptorType;
574 descriptorSetLayoutBinding.descriptorCount = 1;
575 descriptorSetLayoutBinding.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
576 descriptorSetLayoutBinding.pImmutableSamplers = 0;
577 descriptorSetLayoutBindings.push_back(descriptorSetLayoutBinding);
578 }
579 descriptorSetLayoutBindingMap[descriptorSetIndex] =
580 descriptorSetLayoutBindings;
581 }
582 }
583
/// Creates one VkDescriptorSetLayout per descriptor set from the bindings
/// prepared by initDescriptorSetLayoutBindingMap(), and records (set index,
/// descriptor count, descriptor type) into `descriptorSetInfoPool` in the
/// same order as `descriptorSetLayouts`.
LogicalResult VulkanRuntime::createDescriptorSetLayout() {
  for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
    const auto descriptorSetIndex = deviceMemoryBufferMapPair.first;
    const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
    // Each descriptor in a descriptor set must be the same type.
    // NOTE(review): front() assumes every set holds at least one buffer —
    // createMemoryBuffers() only inserts sets that had resource data, but
    // confirm an empty binding map cannot reach this point.
    VkDescriptorType descriptorType =
        deviceMemoryBuffers.front().descriptorType;
    const uint32_t descriptorSize = deviceMemoryBuffers.size();
    const auto descriptorSetLayoutBindingIt =
        descriptorSetLayoutBindingMap.find(descriptorSetIndex);

    if (descriptorSetLayoutBindingIt == descriptorSetLayoutBindingMap.end()) {
      std::cerr << "cannot find layout bindings for the set with number: "
                << descriptorSetIndex;
      return failure();
    }

    const auto &descriptorSetLayoutBindings =
        descriptorSetLayoutBindingIt->second;
    // Create descriptor set layout.
    VkDescriptorSetLayout descriptorSetLayout = {};
    VkDescriptorSetLayoutCreateInfo descriptorSetLayoutCreateInfo = {};

    descriptorSetLayoutCreateInfo.sType =
        VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
    descriptorSetLayoutCreateInfo.pNext = nullptr;
    descriptorSetLayoutCreateInfo.flags = 0;
    // Amount of descriptor bindings in a layout set.
    descriptorSetLayoutCreateInfo.bindingCount =
        descriptorSetLayoutBindings.size();
    descriptorSetLayoutCreateInfo.pBindings =
        descriptorSetLayoutBindings.data();
    RETURN_ON_VULKAN_ERROR(
        vkCreateDescriptorSetLayout(device, &descriptorSetLayoutCreateInfo, 0,
                                    &descriptorSetLayout),
        "vkCreateDescriptorSetLayout");

    descriptorSetLayouts.push_back(descriptorSetLayout);
    descriptorSetInfoPool.push_back(
        {descriptorSetIndex, descriptorSize, descriptorType});
  }
  return success();
}
627
createPipelineLayout()628 LogicalResult VulkanRuntime::createPipelineLayout() {
629 // Associate descriptor sets with a pipeline layout.
630 VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo = {};
631 pipelineLayoutCreateInfo.sType =
632 VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
633 pipelineLayoutCreateInfo.pNext = nullptr;
634 pipelineLayoutCreateInfo.flags = 0;
635 pipelineLayoutCreateInfo.setLayoutCount = descriptorSetLayouts.size();
636 pipelineLayoutCreateInfo.pSetLayouts = descriptorSetLayouts.data();
637 pipelineLayoutCreateInfo.pushConstantRangeCount = 0;
638 pipelineLayoutCreateInfo.pPushConstantRanges = 0;
639 RETURN_ON_VULKAN_ERROR(vkCreatePipelineLayout(device,
640 &pipelineLayoutCreateInfo, 0,
641 &pipelineLayout),
642 "vkCreatePipelineLayout");
643 return success();
644 }
645
createComputePipeline()646 LogicalResult VulkanRuntime::createComputePipeline() {
647 VkPipelineShaderStageCreateInfo stageInfo = {};
648 stageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
649 stageInfo.pNext = nullptr;
650 stageInfo.flags = 0;
651 stageInfo.stage = VK_SHADER_STAGE_COMPUTE_BIT;
652 stageInfo.module = shaderModule;
653 // Set entry point.
654 stageInfo.pName = entryPoint;
655 stageInfo.pSpecializationInfo = 0;
656
657 VkComputePipelineCreateInfo computePipelineCreateInfo = {};
658 computePipelineCreateInfo.sType =
659 VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
660 computePipelineCreateInfo.pNext = nullptr;
661 computePipelineCreateInfo.flags = 0;
662 computePipelineCreateInfo.stage = stageInfo;
663 computePipelineCreateInfo.layout = pipelineLayout;
664 computePipelineCreateInfo.basePipelineHandle = 0;
665 computePipelineCreateInfo.basePipelineIndex = 0;
666 RETURN_ON_VULKAN_ERROR(vkCreateComputePipelines(device, 0, 1,
667 &computePipelineCreateInfo, 0,
668 &pipeline),
669 "vkCreateComputePipelines");
670 return success();
671 }
672
createDescriptorPool()673 LogicalResult VulkanRuntime::createDescriptorPool() {
674 std::vector<VkDescriptorPoolSize> descriptorPoolSizes;
675 for (const auto &descriptorSetInfo : descriptorSetInfoPool) {
676 // For each descriptor set populate descriptor pool size.
677 VkDescriptorPoolSize descriptorPoolSize = {};
678 descriptorPoolSize.type = descriptorSetInfo.descriptorType;
679 descriptorPoolSize.descriptorCount = descriptorSetInfo.descriptorSize;
680 descriptorPoolSizes.push_back(descriptorPoolSize);
681 }
682
683 VkDescriptorPoolCreateInfo descriptorPoolCreateInfo = {};
684 descriptorPoolCreateInfo.sType =
685 VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
686 descriptorPoolCreateInfo.pNext = nullptr;
687 descriptorPoolCreateInfo.flags = 0;
688 descriptorPoolCreateInfo.maxSets = descriptorPoolSizes.size();
689 descriptorPoolCreateInfo.poolSizeCount = descriptorPoolSizes.size();
690 descriptorPoolCreateInfo.pPoolSizes = descriptorPoolSizes.data();
691 RETURN_ON_VULKAN_ERROR(vkCreateDescriptorPool(device,
692 &descriptorPoolCreateInfo, 0,
693 &descriptorPool),
694 "vkCreateDescriptorPool");
695 return success();
696 }
697
allocateDescriptorSets()698 LogicalResult VulkanRuntime::allocateDescriptorSets() {
699 VkDescriptorSetAllocateInfo descriptorSetAllocateInfo = {};
700 // Size of descriptor sets and descriptor layout sets is the same.
701 descriptorSets.resize(descriptorSetLayouts.size());
702 descriptorSetAllocateInfo.sType =
703 VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
704 descriptorSetAllocateInfo.pNext = nullptr;
705 descriptorSetAllocateInfo.descriptorPool = descriptorPool;
706 descriptorSetAllocateInfo.descriptorSetCount = descriptorSetLayouts.size();
707 descriptorSetAllocateInfo.pSetLayouts = descriptorSetLayouts.data();
708 RETURN_ON_VULKAN_ERROR(vkAllocateDescriptorSets(device,
709 &descriptorSetAllocateInfo,
710 descriptorSets.data()),
711 "vkAllocateDescriptorSets");
712 return success();
713 }
714
/// Writes each device memory buffer's VkDescriptorBufferInfo into its
/// descriptor set/binding via vkUpdateDescriptorSets.
LogicalResult VulkanRuntime::setWriteDescriptors() {
  if (descriptorSets.size() != descriptorSetInfoPool.size()) {
    std::cerr << "Each descriptor set must have descriptor set information";
    return failure();
  }
  // For each descriptor set.
  auto descriptorSetIt = descriptorSets.begin();
  // Each descriptor set is associated with descriptor set info.
  // The pairing relies on descriptorSets and descriptorSetInfoPool being in
  // the same order: both are filled from the same loop in
  // createDescriptorSetLayout() / allocateDescriptorSets().
  for (const auto &descriptorSetInfo : descriptorSetInfoPool) {
    // For each device memory buffer in the descriptor set.
    const auto &deviceMemoryBuffers =
        deviceMemoryBufferMap[descriptorSetInfo.descriptorSet];
    for (const auto &memoryBuffer : deviceMemoryBuffers) {
      // Structure describing descriptor sets to write to.
      VkWriteDescriptorSet wSet = {};
      wSet.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
      wSet.pNext = nullptr;
      // Descriptor set.
      wSet.dstSet = *descriptorSetIt;
      wSet.dstBinding = memoryBuffer.bindingIndex;
      wSet.dstArrayElement = 0;
      wSet.descriptorCount = 1;
      wSet.descriptorType = memoryBuffer.descriptorType;
      wSet.pImageInfo = nullptr;
      // bufferInfo references the device-local buffer (filled in
      // createMemoryBuffers); the write is issued immediately, so taking the
      // address of the loop element is safe.
      wSet.pBufferInfo = &memoryBuffer.bufferInfo;
      wSet.pTexelBufferView = nullptr;
      vkUpdateDescriptorSets(device, 1, &wSet, 0, nullptr);
    }
    // Increment descriptor set iterator.
    ++descriptorSetIt;
  }
  return success();
}
748
createCommandPool()749 LogicalResult VulkanRuntime::createCommandPool() {
750 VkCommandPoolCreateInfo commandPoolCreateInfo = {};
751 commandPoolCreateInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
752 commandPoolCreateInfo.pNext = nullptr;
753 commandPoolCreateInfo.flags = 0;
754 commandPoolCreateInfo.queueFamilyIndex = queueFamilyIndex;
755 RETURN_ON_VULKAN_ERROR(vkCreateCommandPool(device, &commandPoolCreateInfo,
756 /*pAllocator=*/nullptr,
757 &commandPool),
758 "vkCreateCommandPool");
759 return success();
760 }
761
createQueryPool()762 LogicalResult VulkanRuntime::createQueryPool() {
763 // Return directly if timestamp query is not supported.
764 if (queueFamilyProperties.timestampValidBits == 0)
765 return success();
766
767 // Get timestamp period for this physical device.
768 VkPhysicalDeviceProperties deviceProperties = {};
769 vkGetPhysicalDeviceProperties(physicalDevice, &deviceProperties);
770 timestampPeriod = deviceProperties.limits.timestampPeriod;
771
772 // Create query pool.
773 VkQueryPoolCreateInfo queryPoolCreateInfo = {};
774 queryPoolCreateInfo.sType = VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO;
775 queryPoolCreateInfo.pNext = nullptr;
776 queryPoolCreateInfo.flags = 0;
777 queryPoolCreateInfo.queryType = VK_QUERY_TYPE_TIMESTAMP;
778 queryPoolCreateInfo.queryCount = 2;
779 queryPoolCreateInfo.pipelineStatistics = 0;
780 RETURN_ON_VULKAN_ERROR(vkCreateQueryPool(device, &queryPoolCreateInfo,
781 /*pAllocator=*/nullptr, &queryPool),
782 "vkCreateQueryPool");
783
784 return success();
785 }
786
/// Records the main compute command buffer: bind pipeline and descriptor
/// sets, optionally bracket the dispatch with timestamp queries, and
/// dispatch the configured number of work groups. Appends the finished
/// buffer to `commandBuffers`.
LogicalResult VulkanRuntime::createComputeCommandBuffer() {
  VkCommandBufferAllocateInfo commandBufferAllocateInfo = {};
  commandBufferAllocateInfo.sType =
      VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
  commandBufferAllocateInfo.pNext = nullptr;
  commandBufferAllocateInfo.commandPool = commandPool;
  commandBufferAllocateInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
  commandBufferAllocateInfo.commandBufferCount = 1;

  VkCommandBuffer commandBuffer;
  RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device,
                                                  &commandBufferAllocateInfo,
                                                  &commandBuffer),
                         "vkAllocateCommandBuffers");

  VkCommandBufferBeginInfo commandBufferBeginInfo = {};
  commandBufferBeginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
  commandBufferBeginInfo.pNext = nullptr;
  commandBufferBeginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
  commandBufferBeginInfo.pInheritanceInfo = nullptr;

  // Commands begin.
  RETURN_ON_VULKAN_ERROR(
      vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo),
      "vkBeginCommandBuffer");

  // Reset both timestamp queries before use (only when a pool exists).
  if (queryPool != VK_NULL_HANDLE)
    vkCmdResetQueryPool(commandBuffer, queryPool, 0, 2);

  vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
  vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
                          pipelineLayout, 0, descriptorSets.size(),
                          descriptorSets.data(), 0, 0);
  // Get a timestamp before invoking the compute shader.
  if (queryPool != VK_NULL_HANDLE)
    vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT,
                        queryPool, 0);
  vkCmdDispatch(commandBuffer, numWorkGroups.x, numWorkGroups.y,
                numWorkGroups.z);
  // Get another timestamp after invoking the compute shader.
  if (queryPool != VK_NULL_HANDLE)
    vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT,
                        queryPool, 1);

  // Commands end.
  RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer),
                         "vkEndCommandBuffer");

  commandBuffers.push_back(commandBuffer);
  return success();
}
838
submitCommandBuffersToQueue()839 LogicalResult VulkanRuntime::submitCommandBuffersToQueue() {
840 VkSubmitInfo submitInfo = {};
841 submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
842 submitInfo.pNext = nullptr;
843 submitInfo.waitSemaphoreCount = 0;
844 submitInfo.pWaitSemaphores = 0;
845 submitInfo.pWaitDstStageMask = 0;
846 submitInfo.commandBufferCount = commandBuffers.size();
847 submitInfo.pCommandBuffers = commandBuffers.data();
848 submitInfo.signalSemaphoreCount = 0;
849 submitInfo.pSignalSemaphores = nullptr;
850 RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, 0),
851 "vkQueueSubmit");
852 return success();
853 }
854
updateHostMemoryBuffers()855 LogicalResult VulkanRuntime::updateHostMemoryBuffers() {
856 // First copy back the data to the staging buffer.
857 copyResource(/*deviceToHost=*/true);
858
859 // For each descriptor set.
860 for (auto &resourceDataMapPair : resourceData) {
861 auto &resourceDataMap = resourceDataMapPair.second;
862 auto &deviceMemoryBuffers =
863 deviceMemoryBufferMap[resourceDataMapPair.first];
864 // For each device memory buffer in the set.
865 for (auto &deviceMemoryBuffer : deviceMemoryBuffers) {
866 if (resourceDataMap.count(deviceMemoryBuffer.bindingIndex)) {
867 void *payload;
868 auto &hostMemoryBuffer =
869 resourceDataMap[deviceMemoryBuffer.bindingIndex];
870 RETURN_ON_VULKAN_ERROR(vkMapMemory(device,
871 deviceMemoryBuffer.hostMemory, 0,
872 hostMemoryBuffer.size, 0,
873 reinterpret_cast<void **>(&payload)),
874 "vkMapMemory");
875 std::memcpy(hostMemoryBuffer.ptr, payload, hostMemoryBuffer.size);
876 vkUnmapMemory(device, deviceMemoryBuffer.hostMemory);
877 }
878 }
879 }
880 return success();
881 }
882