/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include <vulkan/vulkan.h>
#include <dmlc/memory_io.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>

#include <algorithm>
#include <array>
#include <cstring>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../file_util.h"
#include "../pack_args.h"
#include "../thread_storage_scope.h"
#include "../workspace_pool.h"

#include "vulkan_common.h"
#include "vulkan_module.h"
#include "vulkan_shader.h"
#include "vulkan_stream.h"

namespace tvm {
namespace runtime {
namespace vulkan {

/*! \brief Maximum number of GPUs supported by VulkanModule. */
static constexpr const int kVulkanMaxNumDevice = 8;

/*! \brief TVM Vulkan binary pack magic number */
static constexpr const int kVulkanModuleMagic = 0x02700027;

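// Per-thread Vulkan state: the thread's active TVM context, a workspace pool,
// and lazily created per-device streams and staging buffers.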
class VulkanThreadEntry {
 public:
  VulkanThreadEntry();
  static VulkanThreadEntry* ThreadLocal();

  ~VulkanThreadEntry() {
    // Because the thread entry refers to the Device API, the streams (and their
    // command buffers) must always be destroyed before the instance and device
    // are destroyed.  Release them manually here to enforce that order.
    streams_.clear();
    for (const auto& kv : staging_buffers_) {
      if (!kv.second) {
        continue;
      }
      auto& buf = *(kv.second);
      if (buf.host_addr != nullptr) {
        vkUnmapMemory(buf.device, buf.memory);
      }
      if (buf.memory != VK_NULL_HANDLE) {
        vkFreeMemory(buf.device, buf.memory, nullptr);
      }
      if (buf.buffer != VK_NULL_HANDLE) {
        vkDestroyBuffer(buf.device, buf.buffer, nullptr);
      }
    }
  }

  TVMContext ctx;
  WorkspacePool pool;
  VulkanStream* Stream(size_t device_id);
  VulkanStagingBuffer* StagingBuffer(int device_id, size_t size);

 private:
  std::unordered_map<size_t, std::unique_ptr<VulkanStream>> streams_;
  std::unordered_map<size_t, std::unique_ptr<VulkanStagingBuffer>> staging_buffers_;
};

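// A device buffer handle together with the memory allocation that backs it.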
struct VulkanBuffer {
  VkBuffer buffer{VK_NULL_HANDLE};
  VkDeviceMemory memory{VK_NULL_HANDLE};
};

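// Everything needed to dispatch one compiled kernel: the shader module, the
// descriptor machinery for its buffer arguments, and the compute pipeline plus
// its layout.  descriptor_update_template is only used in immediate mode
// (VK_KHR_push_descriptor + VK_KHR_descriptor_update_template).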
struct VulkanPipeline {
  VulkanContext* vctx_{nullptr};
  VkShaderModule shader{VK_NULL_HANDLE};
  VkDescriptorSetLayout descriptor_set_layout{VK_NULL_HANDLE};
  VkDescriptorPool descriptor_pool{VK_NULL_HANDLE};
  VkDescriptorSet descriptor_set{VK_NULL_HANDLE};
  VkPipelineLayout pipeline_layout{VK_NULL_HANDLE};
  VkPipeline pipeline{VK_NULL_HANDLE};
  VkDescriptorUpdateTemplateKHR descriptor_update_template{VK_NULL_HANDLE};
};

typedef dmlc::ThreadLocalStore<VulkanThreadEntry> VulkanThreadStore;

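// Vulkan implementation of the TVM DeviceAPI.  One VulkanContext is created for
// each physical device that exposes a compute-capable queue family.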
class VulkanDeviceAPI final : public DeviceAPI {
 public:
  VulkanDeviceAPI();
  ~VulkanDeviceAPI() {
    for (auto& vctx : context_) {
      vkDestroyDevice(vctx.device, nullptr);
    }
    if (instance_) {
      vkDestroyInstance(instance_, nullptr);
    }
  }
  void SetDevice(TVMContext ctx) final { VulkanThreadEntry::ThreadLocal()->ctx = ctx; }
  void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final;
  void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, TVMType type_hint) final {
    const auto& vctx = context(ctx.device_id);
    VkBufferCreateInfo info;
    info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
    info.pNext = nullptr;
    info.flags = 0;
    info.size = nbytes;
    info.queueFamilyIndexCount = 1;
    info.pQueueFamilyIndices = &(vctx.queue_family_index);
    info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
    info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
                 VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
    // create buffer
    VkBuffer buffer;
    VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer));
    // query memory requirements, using VK_KHR_get_memory_requirements2 when available
    // so the driver can report whether a dedicated allocation is preferred.
    VkBufferMemoryRequirementsInfo2KHR req_info2;
    req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR;
    req_info2.pNext = 0;
    req_info2.buffer = buffer;

    VkMemoryRequirements2KHR req2;
    req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR;
    req2.pNext = 0;

    VkMemoryDedicatedRequirementsKHR dedicated_req;
    dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR;
    dedicated_req.pNext = 0;
    req2.pNext = &dedicated_req;

    bool dedicated_allocation = false;
    if (vctx.get_buffer_memory_requirements_2_functions) {
      vctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR(
          vctx.device, &req_info2, &req2);
      dedicated_allocation =
          dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation;
    }

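    // Allocate from the compute memory type: a plain allocation by default, or a
    // dedicated allocation tied to this buffer when VK_KHR_dedicated_allocation
    // reports that one is required or preferred.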
    VkDeviceMemory memory;
    if (!dedicated_allocation) {
      VkMemoryAllocateInfo minfo;
      minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
      minfo.pNext = nullptr;
      minfo.allocationSize = nbytes;
      minfo.memoryTypeIndex = vctx.compute_mtype_index;
      VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
    } else {
      VkMemoryAllocateInfo minfo;
      minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
      minfo.pNext = nullptr;
      minfo.allocationSize = req2.memoryRequirements.size;
      minfo.memoryTypeIndex = vctx.compute_mtype_index;

      VkMemoryDedicatedAllocateInfoKHR mdinfo;
      mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR;
      mdinfo.pNext = 0;
      mdinfo.image = 0;
      mdinfo.buffer = buffer;
      minfo.pNext = &mdinfo;
      VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
    }
    VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0));
    VulkanBuffer* pbuf = new VulkanBuffer();
    pbuf->memory = memory;
    pbuf->buffer = buffer;
    return pbuf;
  }

  void FreeDataSpace(TVMContext ctx, void* ptr) final {
    const auto& vctx = context(ctx.device_id);
    auto* pbuf = static_cast<VulkanBuffer*>(ptr);
    vkDestroyBuffer(vctx.device, pbuf->buffer, nullptr);
    vkFreeMemory(vctx.device, pbuf->memory, nullptr);
    delete pbuf;
  }

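  // Copies between Vulkan buffers are recorded on the per-thread stream; copies
  // to or from the host go through a per-device staging buffer and synchronize
  // the stream before returning.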
  void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
                      TVMContext ctx_from, TVMContext ctx_to, TVMType type_hint,
                      TVMStreamHandle stream) final {
    CHECK(stream == nullptr);
    TVMContext ctx = ctx_from;
    if (ctx_from.device_type == kDLCPU) {
      ctx = ctx_to;
    }

    int from_dev_type = static_cast<int>(ctx_from.device_type);
    int to_dev_type = static_cast<int>(ctx_to.device_type);
    if (from_dev_type == kDLVulkan && to_dev_type == kDLVulkan) {
      VulkanThreadEntry::ThreadLocal()
          ->Stream(ctx_from.device_id)
          ->Launch([=](VulkanStreamState* state) {
            // 1: copy
            const auto* from_buf = static_cast<const VulkanBuffer*>(from);
            auto* to_buf = static_cast<VulkanBuffer*>(to);
            VkBufferCopy copy_info;
            copy_info.srcOffset = from_offset;
            copy_info.dstOffset = to_offset;
            copy_info.size = size;
            vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, to_buf->buffer, 1, &copy_info);
            // 2: barrier(transfer-> compute|transfer)
            CHECK_EQ(ctx_from.device_id, ctx_to.device_id) << "Vulkan disallows cross-device copy.";
            VkMemoryBarrier barrier_info;
            barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
            barrier_info.pNext = nullptr;
            barrier_info.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
            barrier_info.dstAccessMask =
                (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
                 VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
            vkCmdPipelineBarrier(
                state->cmd_buffer_, VK_PIPELINE_STAGE_TRANSFER_BIT,
                VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, 1,
                &barrier_info, 0, nullptr, 0, nullptr);
          });

    } else if (from_dev_type == kDLVulkan && to_dev_type == kDLCPU) {
      const auto* from_buf = static_cast<const VulkanBuffer*>(from);
      const auto& vctx = context(ctx_from.device_id);
      auto* temp = VulkanThreadEntry::ThreadLocal()->StagingBuffer(ctx_from.device_id, size);
      VulkanThreadEntry::ThreadLocal()
          ->Stream(ctx_from.device_id)
          ->Launch([&](VulkanStreamState* state) {
            VkBufferCopy copy_info;
            copy_info.srcOffset = from_offset;
            copy_info.dstOffset = 0;
            copy_info.size = size;
            vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, temp->buffer, 1, &copy_info);
          });
      VulkanThreadEntry::ThreadLocal()->Stream(ctx_from.device_id)->Synchronize();
      // Invalidate the mapped range if the staging memory is not coherent,
      // so device writes become visible to the host.
      if (!vctx.coherent_staging) {
        VkMappedMemoryRange mrange;
        mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
        mrange.pNext = nullptr;
        mrange.memory = temp->memory;
        mrange.offset = 0;
        mrange.size = VK_WHOLE_SIZE;
        VULKAN_CALL(vkInvalidateMappedMemoryRanges(vctx.device, 1, &mrange));
      }
      memcpy(static_cast<char*>(to) + to_offset, static_cast<char*>(temp->host_addr), size);
    } else if (from_dev_type == kDLCPU && to_dev_type == kDLVulkan) {
      const auto& vctx = context(ctx_to.device_id);
      const auto* to_buf = static_cast<const VulkanBuffer*>(to);
      VulkanStagingBuffer* temp =
          VulkanThreadEntry::ThreadLocal()->StagingBuffer(ctx_to.device_id, size);
      memcpy(temp->host_addr, static_cast<const char*>(from) + from_offset, size);
      // Host-side flush if the staging memory is not coherent,
      // so writes from the CPU are visible to the GPU.
      if (!vctx.coherent_staging) {
        VkMappedMemoryRange mrange;
        mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
        mrange.pNext = nullptr;
        mrange.memory = temp->memory;
        mrange.offset = 0;
        mrange.size = VK_WHOLE_SIZE;
        VULKAN_CALL(vkFlushMappedMemoryRanges(vctx.device, 1, &mrange));
      }

      VulkanThreadEntry::ThreadLocal()
          ->Stream(ctx_to.device_id)
          ->Launch([&](VulkanStreamState* state) {
            // 0: barrier(host->transfer)
            VkMemoryBarrier barrier_info;
            barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
            barrier_info.pNext = nullptr;
            barrier_info.srcAccessMask = 0;
            barrier_info.dstAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
            vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_HOST_BIT,
                                 VK_PIPELINE_STAGE_TRANSFER_BIT, 0, 1, &barrier_info, 0, nullptr, 0,
                                 nullptr);
            // 1: copy
            VkBufferCopy copy_info;
            copy_info.srcOffset = 0;
            copy_info.dstOffset = to_offset;
            copy_info.size = size;
            vkCmdCopyBuffer(state->cmd_buffer_, temp->buffer, to_buf->buffer, 1, &copy_info);
          });
      // TODO(tulloch): should we instead make the staging buffer a property of the
      // Stream? This would allow us to elide synchronizations here.
      VulkanThreadEntry::ThreadLocal()->Stream(ctx_to.device_id)->Synchronize();
    } else {
      LOG(FATAL) << "Expect copy from/to Vulkan or between Vulkan"
                 << ", from=" << from_dev_type << ", to=" << to_dev_type;
    }
  }

  // Always use the default stream
  TVMStreamHandle CreateStream(TVMContext ctx) {
    LOG(FATAL) << "Not implemented";
    return nullptr;
  }

  void FreeStream(TVMContext ctx, TVMStreamHandle stream) {
    LOG(FATAL) << "Not implemented";
    return;
  }

  void SyncStreamFromTo(TVMContext ctx, TVMStreamHandle event_src, TVMStreamHandle event_dst) {
    LOG(FATAL) << "Not implemented";
    return;
  }

  void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
    CHECK(stream == nullptr);
    VulkanThreadEntry::ThreadLocal()->Stream(ctx.device_id)->Synchronize();
  }

  void SetStream(TVMContext ctx, TVMStreamHandle stream) final {
    LOG(FATAL) << "Not implemented";
    return;
  }

  void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final {
    return VulkanThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
  }

  void FreeWorkspace(TVMContext ctx, void* data) final {
    VulkanThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
  }

  static const std::shared_ptr<VulkanDeviceAPI>& Global() {
    static std::shared_ptr<VulkanDeviceAPI> inst = std::make_shared<VulkanDeviceAPI>();
    return inst;
  }

  const VulkanContext& context(size_t device_id) const {
    CHECK_LT(device_id, context_.size());
    return context_[device_id];
  }

 private:
  VkInstance instance_{nullptr};
  // The per-device contexts; entry i corresponds to TVM device id i
  // (one context per compute-capable physical device).
  std::vector<VulkanContext> context_;
};

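// Report device attributes from VkPhysicalDeviceProperties.  kExist is answered
// purely from the context table, so it is safe to query for any device id.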
void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
  size_t index = static_cast<size_t>(ctx.device_id);
  if (kind == kExist) {
    *rv = static_cast<int>(index < context_.size());
    return;
  }
  CHECK_LT(index, context_.size()) << "Invalid device id " << index;
  const auto& vctx = context(index);
  switch (kind) {
    case kMaxThreadsPerBlock: {
      VkPhysicalDeviceProperties phy_prop;
      vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
      int64_t value = phy_prop.limits.maxComputeWorkGroupSize[0];
      *rv = value;
      break;
    }
    case kMaxSharedMemoryPerBlock: {
      VkPhysicalDeviceProperties phy_prop;
      vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
      int64_t value = phy_prop.limits.maxComputeSharedMemorySize;
      *rv = value;
      break;
    }
    case kWarpSize: {
      *rv = 1;
      break;
    }
    case kComputeVersion: {
      VkPhysicalDeviceProperties phy_prop;
      vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
      int64_t value = phy_prop.apiVersion;
      std::ostringstream os;
      os << VK_VERSION_MAJOR(value) << "." << VK_VERSION_MINOR(value) << "."
         << VK_VERSION_PATCH(value);
      *rv = os.str();
      break;
    }
    case kDeviceName:
      return;
    case kMaxClockRate:
      return;
    case kMultiProcessorCount:
      return;
    case kExist:
      break;
    case kMaxThreadDimensions:
      break;
    case kGcnArch:
      return;
  }
}

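// Create the Vulkan instance, then one logical device (and VulkanContext) per
// physical device that exposes a compute-capable queue family.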
VulkanDeviceAPI::VulkanDeviceAPI() {
  VkApplicationInfo app_info;
  app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
  app_info.pNext = nullptr;
  app_info.pApplicationName = "TVM";
  app_info.applicationVersion = 0;
  app_info.pEngineName = "";
  app_info.engineVersion = 0;
  app_info.apiVersion = VK_MAKE_VERSION(1, 0, 0);

  VkInstanceCreateInfo inst_info;
  inst_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
  inst_info.pNext = nullptr;
  inst_info.flags = 0;

  const auto layers = []() -> std::vector<const char*> {
    uint32_t inst_layer_prop_count;
    VULKAN_CALL(vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, nullptr));
    std::vector<VkLayerProperties> inst_layer_prop(inst_layer_prop_count);
    VULKAN_CALL(vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, inst_layer_prop.data()));
    std::vector<const char*> l;
    for (const auto& lp : inst_layer_prop) {
      // TODO(tulloch): add CMAKE options.
      (void)lp;  // suppress unused variable warning.
#ifdef USE_VULKAN_VALIDATION
      if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_standard_validation") == 0) {
        l.push_back("VK_LAYER_LUNARG_standard_validation");
      }
      if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_parameter_validation") == 0) {
        l.push_back("VK_LAYER_LUNARG_parameter_validation");
      }
      if (std::strcmp(lp.layerName, "VK_LAYER_KHRONOS_validation") == 0) {
        l.push_back("VK_LAYER_KHRONOS_validation");
      }
#endif
    }
    return l;
  }();

  const auto instance_extensions = []() -> std::vector<const char*> {
    uint32_t inst_extension_prop_count;
    VULKAN_CALL(
        vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count, nullptr));
    std::vector<VkExtensionProperties> inst_extension_prop(inst_extension_prop_count);
    VULKAN_CALL(vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count,
                                                       inst_extension_prop.data()));
    std::vector<const char*> extensions;
    for (const auto& ip : inst_extension_prop) {
      if (std::strcmp(ip.extensionName, "VK_KHR_get_physical_device_properties2") == 0) {
        extensions.push_back("VK_KHR_get_physical_device_properties2");
      }
    }
    return extensions;
  }();

  inst_info.pApplicationInfo = &app_info;
  inst_info.enabledLayerCount = layers.size();
  inst_info.ppEnabledLayerNames = layers.data();
  inst_info.enabledExtensionCount = instance_extensions.size();
  inst_info.ppEnabledExtensionNames = instance_extensions.data();

  VULKAN_CALL(vkCreateInstance(&inst_info, nullptr, &instance_));

  uint32_t phy_dev_count = 0;
  VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &phy_dev_count, nullptr));
  std::vector<VkPhysicalDevice> all_phy_devs(phy_dev_count);
  VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &phy_dev_count, dmlc::BeginPtr(all_phy_devs)));
  for (VkPhysicalDevice phy_dev : all_phy_devs) {
    uint32_t queue_prop_count = 0;
    vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count, nullptr);
    std::vector<VkQueueFamilyProperties> queue_props(queue_prop_count);
    vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count,
                                             dmlc::BeginPtr(queue_props));
    uint32_t queue_family_index = 0;
    std::vector<VkDeviceQueueCreateInfo> queue_create_info;
    float priority = 1.0f;
    for (uint32_t i = 0; i < queue_props.size(); i++) {
      // find queues that support compute
      if (VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) {
        VkDeviceQueueCreateInfo info;
        info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
        info.pNext = nullptr;
        info.flags = 0;
        info.queueFamilyIndex = i;
        info.queueCount = 1;
        info.pQueuePriorities = &priority;

        queue_create_info.push_back(info);
        // only use the first available queue family for now
        if (queue_create_info.size() == 1) {
          queue_family_index = i;
        }
      }
    }
    if (queue_create_info.size() == 0) continue;

    VulkanContext ctx;
    // setup context
    ctx.phy_device = phy_dev;
    vkGetPhysicalDeviceProperties(ctx.phy_device, &(ctx.phy_device_prop));

    const auto extensions = [&]() {
      uint32_t device_extension_prop_count;
      VULKAN_CALL(vkEnumerateDeviceExtensionProperties(ctx.phy_device, nullptr,
                                                       &device_extension_prop_count, nullptr));
      std::vector<VkExtensionProperties> device_extension_prop(device_extension_prop_count);
      VULKAN_CALL(vkEnumerateDeviceExtensionProperties(
          ctx.phy_device, nullptr, &device_extension_prop_count, device_extension_prop.data()));
      std::vector<const char*> extensions;
      for (const auto& dp : device_extension_prop) {
        if ((std::strcmp(dp.extensionName, "VK_KHR_push_descriptor") == 0) && dp.specVersion > 0) {
          extensions.push_back("VK_KHR_push_descriptor");
        }
        if ((std::strcmp(dp.extensionName, "VK_KHR_descriptor_update_template") == 0) &&
            dp.specVersion > 0) {
          extensions.push_back("VK_KHR_descriptor_update_template");
        }
        if ((std::strcmp(dp.extensionName, "VK_KHR_get_memory_requirements2") == 0) &&
            dp.specVersion > 0) {
          extensions.push_back("VK_KHR_get_memory_requirements2");
        }
        if ((std::strcmp(dp.extensionName, "VK_KHR_dedicated_allocation") == 0) &&
            dp.specVersion > 0) {
          extensions.push_back("VK_KHR_dedicated_allocation");
        }
      }
      return extensions;
    }();
    VkDeviceCreateInfo device_create_info;
    device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
    device_create_info.pNext = nullptr;
    device_create_info.flags = 0;
    device_create_info.queueCreateInfoCount = static_cast<uint32_t>(queue_create_info.size());
    device_create_info.pQueueCreateInfos = queue_create_info.data();
    device_create_info.enabledLayerCount = 0;
    device_create_info.ppEnabledLayerNames = nullptr;
    device_create_info.enabledExtensionCount = extensions.size();
    device_create_info.ppEnabledExtensionNames = extensions.data();
    device_create_info.pEnabledFeatures = nullptr;
    VULKAN_CALL(vkCreateDevice(phy_dev, &device_create_info, nullptr, &(ctx.device)));
    ctx.queue_mutex.reset(new std::mutex());
    vkGetDeviceQueue(ctx.device, queue_family_index, 0, &(ctx.queue));
    ctx.queue_family_index = queue_family_index;
    // Probe dummy buffers to determine the memory requirements for
    // staging (transfer) and compute (storage) buffers.
    VkBuffer buffer;
    VkMemoryRequirements req_staging, req_compute;
    VkBufferCreateInfo info;
    info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
    info.pNext = nullptr;
    info.flags = 0;
    info.size = 1024;
    info.queueFamilyIndexCount = 1;
    info.pQueueFamilyIndices = &(ctx.queue_family_index);
    info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;

    // get staging requirement
    info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT;
    VULKAN_CALL(vkCreateBuffer(ctx.device, &info, nullptr, &buffer));
    vkGetBufferMemoryRequirements(ctx.device, buffer, &req_staging);
    vkDestroyBuffer(ctx.device, buffer, nullptr);
    // get compute requirement
    info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
                 VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
    VULKAN_CALL(vkCreateBuffer(ctx.device, &info, nullptr, &buffer));
    vkGetBufferMemoryRequirements(ctx.device, buffer, &req_compute);
    vkDestroyBuffer(ctx.device, buffer, nullptr);

    // Query physical device memory properties and pick the staging memory type:
    // it must be host visible (coherence is not required), match the staging
    // buffer requirements, and preferably be host cached.
    int win_rank = -1;
    VkPhysicalDeviceMemoryProperties prop;
    vkGetPhysicalDeviceMemoryProperties(ctx.phy_device, &prop);

    for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
      VkMemoryType ty = prop.memoryTypes[k];
      size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
      // host visible
      if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)) continue;
      // match staging buffer requirement
      if (!(req_staging.memoryTypeBits & (1 << k))) continue;
      if (heap_size < 1024) continue;
      int rank = 0;
      rank += ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_CACHED_BIT;
      if (rank > win_rank) {
        win_rank = rank;
        ctx.staging_mtype_index = k;
        ctx.coherent_staging = ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT;
      }
    }
    CHECK_GE(win_rank, 0) << "Cannot find suitable staging memory on device.";

    // Pick the compute memory type: device local, matching the compute buffer
    // requirements, and preferably not host visible.
    win_rank = -1;
    for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
      VkMemoryType ty = prop.memoryTypes[k];
      size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
      // device local
      if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)) continue;
      // match compute buffer requirement
      if (!(req_compute.memoryTypeBits & (1 << k))) continue;
      if (heap_size < 1024) continue;
      int rank = 0;
      // prefer not host visible
      rank += !(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT);
      if (rank > win_rank) {
        win_rank = rank;
        ctx.compute_mtype_index = k;
      }
    }
    CHECK_GE(win_rank, 0) << "Cannot find suitable local memory on device.";
    auto has_extension = [&extensions](const char* query) {
      return std::any_of(extensions.begin(), extensions.end(),
                         [&](const char* extension) { return std::strcmp(query, extension) == 0; });
    };

#ifdef USE_VULKAN_IMMEDIATE_MODE
    if (has_extension("VK_KHR_push_descriptor") &&
        has_extension("VK_KHR_descriptor_update_template")) {
      ctx.descriptor_template_khr_functions =
          std::unique_ptr<VulkanDescriptorTemplateKHRFunctions>(
              new VulkanDescriptorTemplateKHRFunctions());
      ctx.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR =
          CHECK_NOTNULL((PFN_vkCreateDescriptorUpdateTemplateKHR)vkGetDeviceProcAddr(
              ctx.device, "vkCreateDescriptorUpdateTemplateKHR"));
      ctx.descriptor_template_khr_functions->vkDestroyDescriptorUpdateTemplateKHR =
          CHECK_NOTNULL((PFN_vkDestroyDescriptorUpdateTemplateKHR)vkGetDeviceProcAddr(
              ctx.device, "vkDestroyDescriptorUpdateTemplateKHR"));
      ctx.descriptor_template_khr_functions->vkUpdateDescriptorSetWithTemplateKHR =
          CHECK_NOTNULL((PFN_vkUpdateDescriptorSetWithTemplateKHR)vkGetDeviceProcAddr(
              ctx.device, "vkUpdateDescriptorSetWithTemplateKHR"));
      ctx.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR =
          CHECK_NOTNULL((PFN_vkCmdPushDescriptorSetWithTemplateKHR)vkGetDeviceProcAddr(
              ctx.device, "vkCmdPushDescriptorSetWithTemplateKHR"));
    }
#endif

#ifdef USE_VULKAN_DEDICATED_ALLOCATION
    if (has_extension("VK_KHR_get_memory_requirements2") &&
        has_extension("VK_KHR_dedicated_allocation")) {
      ctx.get_buffer_memory_requirements_2_functions =
          std::unique_ptr<VulkanGetBufferMemoryRequirements2Functions>(
              new VulkanGetBufferMemoryRequirements2Functions());
      ctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR =
          CHECK_NOTNULL((PFN_vkGetBufferMemoryRequirements2KHR)vkGetDeviceProcAddr(
              ctx.device, "vkGetBufferMemoryRequirements2KHR"));
    }
#endif
    context_.push_back(std::move(ctx));
  }

  LOG(INFO) << "Initialize Vulkan with " << context_.size() << " devices.";
  for (size_t i = 0; i < context_.size(); ++i) {
    LOG(INFO) << "vulkan(" << i << ")=\'" << context_[i].phy_device_prop.deviceName
              << "\' phy_dev_id=" << context_[i].phy_device
              << " use_immediate=" << context_[i].UseImmediate();
  }
}
class VulkanModuleNode;

// a wrapped function class to get packed func.
class VulkanWrappedFunc {
 public:
  void Init(VulkanModuleNode* m,
            ObjectPtr<Object> sptr,
            const std::string& func_name,
            size_t num_buffer_args, size_t num_pack_args,
            const std::vector<std::string>& thread_axis_tags) {
    m_ = m;
    sptr_ = sptr;
    func_name_ = func_name;
    num_buffer_args_ = num_buffer_args;
    num_pack_args_ = num_pack_args;
    thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags);
  }

  void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const;

 private:
  // internal module
  VulkanModuleNode* m_;
  // the resource holder
  ObjectPtr<Object> sptr_;
  // The name of the function.
  std::string func_name_;
  // Number of buffer arguments
  size_t num_buffer_args_;
  // number of packed arguments.
  size_t num_pack_args_;
  // thread axis configuration
  ThreadAxisConfig thread_axis_cfg_;

  // Per-device pipeline cache; marked mutable to enable lazy initialization.
  mutable std::array<std::shared_ptr<VulkanPipeline>, kVulkanMaxNumDevice> scache_;
};

// Multi-device enabled module.
class VulkanModuleNode final : public runtime::ModuleNode {
 public:
  explicit VulkanModuleNode(std::unordered_map<std::string, VulkanShader> smap,
                            std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
      : smap_(smap), fmap_(fmap), source_(source) {}

  const char* type_key() const final { return "vulkan"; }

  PackedFunc GetFunction(const std::string& name,
                         const ObjectPtr<Object>& sptr_to_self) final {
    CHECK_EQ(sptr_to_self.get(), this);
    CHECK_NE(name, symbol::tvm_module_main) << "Device functions do not have a main";
    auto it = fmap_.find(name);
    if (it == fmap_.end()) return PackedFunc();
    const FunctionInfo& info = it->second;
    VulkanWrappedFunc f;
    size_t num_buffer_args = NumBufferArgs(info.arg_types);
    f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args,
           info.thread_axis_tags);
    return PackFuncNonBufferArg(std::move(f), info.arg_types);
  }

  ~VulkanModuleNode() {
    // cleanup vulkan related caches.
    for (size_t device_id = 0; device_id < ecache_.size(); ++device_id) {
      for (auto& kv : ecache_[device_id]) {
        auto& pe = kv.second;
        CHECK(pe);
        const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);

        if (pe->descriptor_update_template != VK_NULL_HANDLE) {
          vctx.descriptor_template_khr_functions->vkDestroyDescriptorUpdateTemplateKHR(
              vctx.device, pe->descriptor_update_template, nullptr);
        }
        vkDestroyPipeline(vctx.device, pe->pipeline, nullptr);
        vkDestroyPipelineLayout(vctx.device, pe->pipeline_layout, nullptr);
        vkDestroyDescriptorPool(vctx.device, pe->descriptor_pool, nullptr);
        vkDestroyDescriptorSetLayout(vctx.device, pe->descriptor_set_layout, nullptr);
        vkDestroyShaderModule(vctx.device, pe->shader, nullptr);
      }
    }
  }

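  // Lazily create (and cache per device) the shader module, descriptor set
  // layout/pool, pipeline layout, and compute pipeline for one kernel.  Buffer
  // arguments become storage-buffer descriptor bindings; POD arguments are
  // passed as push constants.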
  std::shared_ptr<VulkanPipeline> GetPipeline(size_t device_id, const std::string& func_name,
                                              size_t num_pack_args) {
    const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
    std::lock_guard<std::mutex> lock(mutex_);
    const auto& cp = ecache_[device_id][func_name];
    if (cp) {
      return cp;
    }
    // Create new pipeline
    auto pe = std::shared_ptr<VulkanPipeline>(new VulkanPipeline());
    {
      // create shader
      auto sit = smap_.find(func_name);
      CHECK(sit != smap_.end());
      const std::vector<uint32_t>& data = sit->second.data;
      VkShaderModuleCreateInfo shader_cinfo;
      shader_cinfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
      shader_cinfo.pNext = nullptr;
      shader_cinfo.flags = 0;
      shader_cinfo.codeSize = data.size() * sizeof(uint32_t);
      shader_cinfo.pCode = data.data();
      VULKAN_CALL(vkCreateShaderModule(vctx.device, &shader_cinfo, nullptr, &(pe->shader)));
    }
    std::vector<VkDescriptorSetLayoutBinding> arg_binding;
    std::vector<VkDescriptorUpdateTemplateEntryKHR> arg_template;
    uint32_t num_pod = 0, num_buffer = 0;
    {
      auto fit = fmap_.find(func_name);
      CHECK(fit != fmap_.end());
      for (TVMType arg_type : fit->second.arg_types) {
        if (arg_type.code == kHandle) {
          {
            VkDescriptorSetLayoutBinding bd;
            bd.binding = num_buffer;
            bd.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
            bd.descriptorCount = 1;
            bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
            bd.pImmutableSamplers = nullptr;
            arg_binding.push_back(bd);
          }
          {
            VkDescriptorUpdateTemplateEntryKHR tpl;
            tpl.dstBinding = num_buffer;
            tpl.dstArrayElement = 0;
            tpl.descriptorCount = 1;
            tpl.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
            tpl.offset = num_buffer * sizeof(VkDescriptorBufferInfo);
            tpl.stride = sizeof(VkDescriptorBufferInfo);
            arg_template.push_back(tpl);
          }
          ++num_buffer;
        } else {
          ++num_pod;
        }
      }
    }

    {
      VkDescriptorSetLayoutCreateInfo descrip_cinfo;
      descrip_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
      descrip_cinfo.pNext = nullptr;
      descrip_cinfo.flags = 0;
      if (vctx.UseImmediate()) {
        descrip_cinfo.flags |= VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR;
      }
      descrip_cinfo.bindingCount = arg_binding.size();
      descrip_cinfo.pBindings = arg_binding.data();
      VULKAN_CALL(vkCreateDescriptorSetLayout(vctx.device, &descrip_cinfo, nullptr,
                                              &(pe->descriptor_set_layout)));
    }

    {
      VkDescriptorPoolSize pool_size;
      pool_size.type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
      pool_size.descriptorCount = arg_binding.size();
      VkDescriptorPoolCreateInfo descrip_pool_cinfo;
      descrip_pool_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
      descrip_pool_cinfo.pNext = nullptr;
      descrip_pool_cinfo.flags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT;
      descrip_pool_cinfo.maxSets = 1;
      descrip_pool_cinfo.poolSizeCount = 1;
      descrip_pool_cinfo.pPoolSizes = &pool_size;
      VULKAN_CALL(vkCreateDescriptorPool(vctx.device, &descrip_pool_cinfo, nullptr,
                                         &(pe->descriptor_pool)));
    }

    if (!vctx.UseImmediate()) {
      VkDescriptorSetAllocateInfo alloc_info;
      alloc_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
      alloc_info.pNext = nullptr;
      alloc_info.descriptorPool = pe->descriptor_pool;
      alloc_info.descriptorSetCount = 1;
      alloc_info.pSetLayouts = &(pe->descriptor_set_layout);
      VULKAN_CALL(vkAllocateDescriptorSets(vctx.device, &alloc_info, &(pe->descriptor_set)));
    }

    VkPushConstantRange crange;
    crange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
    crange.offset = 0;
    crange.size = sizeof(ArgUnion) * num_pack_args;

    VkPipelineLayoutCreateInfo playout_cinfo;
    playout_cinfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
    playout_cinfo.pNext = nullptr;
    playout_cinfo.flags = 0;
    playout_cinfo.setLayoutCount = 1;
    playout_cinfo.pSetLayouts = &(pe->descriptor_set_layout);

    if (num_pack_args != 0) {
      playout_cinfo.pushConstantRangeCount = 1;
      playout_cinfo.pPushConstantRanges = &crange;
      CHECK_LE(crange.size, vctx.phy_device_prop.limits.maxPushConstantsSize);
    } else {
      playout_cinfo.pushConstantRangeCount = 0;
      playout_cinfo.pPushConstantRanges = nullptr;
    }

    VULKAN_CALL(
        vkCreatePipelineLayout(vctx.device, &playout_cinfo, nullptr, &(pe->pipeline_layout)));

    VkComputePipelineCreateInfo pipeline_cinfo;
    pipeline_cinfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
    pipeline_cinfo.pNext = nullptr;
    pipeline_cinfo.flags = 0;
    pipeline_cinfo.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
    pipeline_cinfo.stage.pNext = nullptr;
    pipeline_cinfo.stage.flags = 0;
    pipeline_cinfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
    pipeline_cinfo.stage.module = pe->shader;
    pipeline_cinfo.stage.pName = func_name.c_str();
    pipeline_cinfo.stage.pSpecializationInfo = nullptr;
    pipeline_cinfo.layout = pe->pipeline_layout;
    pipeline_cinfo.basePipelineHandle = VK_NULL_HANDLE;
    pipeline_cinfo.basePipelineIndex = 0;
    VULKAN_CALL(vkCreateComputePipelines(vctx.device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr,
                                         &(pe->pipeline)));

    if (vctx.UseImmediate()) {
      VkDescriptorUpdateTemplateCreateInfoKHR descrip_template_cinfo;
      descrip_template_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_CREATE_INFO_KHR;
      descrip_template_cinfo.pNext = 0;
      descrip_template_cinfo.flags = 0;
      descrip_template_cinfo.descriptorUpdateEntryCount = arg_template.size();
      descrip_template_cinfo.pDescriptorUpdateEntries = arg_template.data();
      descrip_template_cinfo.templateType = VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_PUSH_DESCRIPTORS_KHR;
      descrip_template_cinfo.descriptorSetLayout = pe->descriptor_set_layout;
      descrip_template_cinfo.pipelineBindPoint = VK_PIPELINE_BIND_POINT_COMPUTE;
      descrip_template_cinfo.pipelineLayout = pe->pipeline_layout;
      descrip_template_cinfo.set = 0;
      VULKAN_CALL(vctx.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR(
          vctx.device, &descrip_template_cinfo, 0, &(pe->descriptor_update_template)));
    }
    ecache_[device_id][func_name] = pe;
    return pe;
  }

  void SaveToFile(const std::string& file_name, const std::string& format) final {
    std::string fmt = GetFileFormat(file_name, format);
    CHECK_EQ(fmt, fmt_) << "Can only save to customized format vulkan";
    std::string meta_file = GetMetaFilePath(file_name);
    SaveMetaDataToFile(meta_file, fmap_);
    std::string data_bin;
    dmlc::MemoryStringStream fs(&data_bin);
    dmlc::Stream* stream = &fs;
    uint32_t magic = kVulkanModuleMagic;
    stream->Write(magic);
    stream->Write(smap_);
    SaveBinaryToFile(file_name, data_bin);
  }

  void SaveToBinary(dmlc::Stream* stream) final {
    stream->Write(fmt_);
    stream->Write(fmap_);
    stream->Write(smap_);
  }
  std::string GetSource(const std::string& format) final {
    // can only return source code.
    return source_;
  }

 private:
  // the binary data
  std::vector<uint32_t> data_;
  // shader table: function name -> compiled SPIR-V shader.
  std::unordered_map<std::string, VulkanShader> smap_;
  // function information table.
  std::unordered_map<std::string, FunctionInfo> fmap_;
  // The format
  std::string fmt_{"vulkan"};
  // The source
  std::string source_;

  // Guards accesses to `ecache_`
  std::mutex mutex_;
  std::array<std::unordered_map<std::string, std::shared_ptr<VulkanPipeline>>, kVulkanMaxNumDevice>
      ecache_;
};

Module VulkanModuleCreate(std::unordered_map<std::string, VulkanShader> smap,
                          std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
  auto n = make_object<VulkanModuleNode>(smap, fmap, source);
  return Module(n);
}

VulkanThreadEntry* VulkanThreadEntry::ThreadLocal() { return VulkanThreadStore::Get(); }

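// Return the per-device host-visible staging buffer, freeing and reallocating
// it when the requested size exceeds the current capacity.  The returned
// mapping is zero-filled on every call.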
VulkanStagingBuffer* VulkanThreadEntry::StagingBuffer(int device_id, size_t size) {
  if (!staging_buffers_[device_id]) {
    staging_buffers_[device_id] = std::unique_ptr<VulkanStagingBuffer>(new VulkanStagingBuffer());
  }
  auto& buf = *(staging_buffers_[device_id]);
  if (buf.device != nullptr && buf.size < size) {
    // free previous buffer
    if (buf.host_addr != nullptr) {
      vkUnmapMemory(buf.device, buf.memory);
    }
    if (buf.memory != VK_NULL_HANDLE) {
      vkFreeMemory(buf.device, buf.memory, nullptr);
    }
    if (buf.buffer != VK_NULL_HANDLE) {
      vkDestroyBuffer(buf.device, buf.buffer, nullptr);
    }
    buf.host_addr = nullptr;
    buf.memory = VK_NULL_HANDLE;
    buf.buffer = VK_NULL_HANDLE;
  }
  const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);

  if (buf.device == nullptr) {
    buf.device = vctx.device;
  }
  if (buf.memory == VK_NULL_HANDLE) {
    // allocate the staging buffer memory if necessary
    VkBufferCreateInfo info;
    info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
    info.pNext = nullptr;
    info.flags = 0;
    info.size = size;
    info.queueFamilyIndexCount = 1;
    info.pQueueFamilyIndices = &(vctx.queue_family_index);
    info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT;
    info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
    VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &(buf.buffer)));
    VkMemoryAllocateInfo minfo;
    minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
    minfo.pNext = nullptr;
    minfo.allocationSize = size;
    minfo.memoryTypeIndex = vctx.staging_mtype_index;
    VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &(buf.memory)));
    VULKAN_CALL(vkBindBufferMemory(vctx.device, (buf.buffer), buf.memory, 0));
    VULKAN_CALL(vkMapMemory(vctx.device, buf.memory, 0, size, 0, &(buf.host_addr)));
    buf.size = size;
  }
  memset(buf.host_addr, 0, size);
  return &buf;
}

VulkanThreadEntry::VulkanThreadEntry()
    : pool(static_cast<DLDeviceType>(kDLVulkan), VulkanDeviceAPI::Global()) {
  ctx.device_id = 0;
  ctx.device_type = static_cast<DLDeviceType>(kDLVulkan);
}

VulkanStream* VulkanThreadEntry::Stream(size_t device_id) {
  if (!streams_[device_id]) {
    streams_[device_id] = std::unique_ptr<VulkanStream>(
        new VulkanStream(&VulkanDeviceAPI::Global()->context(device_id)));
  }
  return streams_[device_id].get();
}

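// Dispatch one kernel launch.  In immediate mode the descriptors and push
// constants are recorded directly via push-descriptor templates; otherwise the
// launch goes through the stream's deferred path, which updates and binds a
// conventional descriptor set before the dispatch.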
void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
                                   const ArgUnion* pack_args) const {
  int device_id = VulkanThreadEntry::ThreadLocal()->ctx.device_id;
  CHECK_LT(device_id, kVulkanMaxNumDevice);
  const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
  if (!scache_[device_id]) {
    scache_[device_id] = m_->GetPipeline(device_id, func_name_, num_pack_args_);
  }
  const auto& pipeline = scache_[device_id];
  ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
  std::vector<VkDescriptorBufferInfo> descriptor_buffers;
  descriptor_buffers.resize(num_buffer_args_);
  for (size_t i = 0; i < num_buffer_args_; ++i) {
    void* buf = args[static_cast<int>(i)];
    VkDescriptorBufferInfo binfo;
    binfo.buffer = static_cast<VulkanBuffer*>(buf)->buffer;
    binfo.offset = 0;
    binfo.range = VK_WHOLE_SIZE;
    descriptor_buffers[i] = binfo;
  }
  if (vctx.UseImmediate()) {
    // Can safely capture by reference as this lambda is immediately executed on the calling thread.
    VulkanThreadEntry::ThreadLocal()->Stream(device_id)->Launch([&](VulkanStreamState* state) {
      vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline);
      CHECK(pipeline->descriptor_update_template != VK_NULL_HANDLE);
      vctx.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR(
          state->cmd_buffer_, pipeline->descriptor_update_template, pipeline->pipeline_layout, 0,
          descriptor_buffers.data());
      if (num_pack_args_ != 0) {
        vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout,
                           VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion),
                           pack_args);
      }
      vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
      VkMemoryBarrier barrier_info;
      barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
      barrier_info.pNext = nullptr;
      barrier_info.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT;
      barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
                                    VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
      vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
                           VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0,
                           1, &barrier_info, 0, nullptr, 0, nullptr);
    });
    return;
  }

  // Otherwise, the more expensive deferred path.
  std::vector<ArgUnion> pack_args_storage(pack_args, pack_args + num_pack_args_);
  const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers]() {
    std::vector<VkWriteDescriptorSet> write_descriptor_sets;
    write_descriptor_sets.resize(descriptor_buffers.size());
    for (size_t i = 0; i < write_descriptor_sets.size(); i++) {
      write_descriptor_sets[i].sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
      write_descriptor_sets[i].pNext = 0;
      write_descriptor_sets[i].dstSet = pipeline->descriptor_set;
      write_descriptor_sets[i].dstBinding = i;
      write_descriptor_sets[i].dstArrayElement = 0;
      write_descriptor_sets[i].descriptorCount = 1;
      write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
      write_descriptor_sets[i].pImageInfo = 0;
      write_descriptor_sets[i].pBufferInfo = &(descriptor_buffers[i]);
      write_descriptor_sets[i].pTexelBufferView = 0;
    }
    vkUpdateDescriptorSets(vctx.device, write_descriptor_sets.size(), write_descriptor_sets.data(),
                           0, 0);
  };
  const auto& deferred_kernel = [pipeline, wl, pack_args_storage](VulkanStreamState* state) {
    vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline);
    vkCmdBindDescriptorSets(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE,
                            pipeline->pipeline_layout, 0, 1, &(pipeline->descriptor_set), 0,
                            nullptr);
    if (pack_args_storage.size() != 0) {
      vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT,
                         0, pack_args_storage.size() * sizeof(ArgUnion), pack_args_storage.data());
    }
    vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
    VkMemoryBarrier barrier_info;
    barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
    barrier_info.pNext = nullptr;
    barrier_info.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT;
    barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
                                  VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
    vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
                         VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0,
                         1, &barrier_info, 0, nullptr, 0, nullptr);
  };
  VulkanStreamToken deferred_token;
  deferred_token.descriptor_set_ = pipeline->descriptor_set;
  deferred_token.buffers_.resize(descriptor_buffers.size());
  for (size_t i = 0; i < descriptor_buffers.size(); ++i) {
    deferred_token.buffers_[i] = descriptor_buffers[i].buffer;
  }
  VulkanThreadEntry::ThreadLocal()->Stream(device_id)->LaunchDeferred(
      deferred_initializer, deferred_kernel, deferred_token);
}

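// Load a Vulkan module saved by SaveToFile: function metadata comes from the
// meta file, and the binary payload is the magic number followed by the shader map.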
Module VulkanModuleLoadFile(const std::string& file_name, const std::string& format) {
  std::string data;
  std::unordered_map<std::string, VulkanShader> smap;
  std::unordered_map<std::string, FunctionInfo> fmap;
  std::string fmt = GetFileFormat(file_name, format);
  std::string meta_file = GetMetaFilePath(file_name);
  LoadBinaryFromFile(file_name, &data);
  LoadMetaDataFromFile(meta_file, &fmap);
  dmlc::MemoryStringStream fs(&data);
  dmlc::Stream* stream = &fs;
  uint32_t magic;
  stream->Read(&magic);
  CHECK_EQ(magic, kVulkanModuleMagic) << "VulkanModule Magic mismatch";
  stream->Read(&smap);
  return VulkanModuleCreate(smap, fmap, "");
}

Module VulkanModuleLoadBinary(void* strm) {
  dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
  std::unordered_map<std::string, VulkanShader> smap;
  std::unordered_map<std::string, FunctionInfo> fmap;

  std::string fmt;
  stream->Read(&fmt);
  stream->Read(&fmap);
  stream->Read(&smap);
  return VulkanModuleCreate(smap, fmap, "");
}

TVM_REGISTER_GLOBAL("module.loadfile_vulkan").set_body_typed(VulkanModuleLoadFile);

TVM_REGISTER_GLOBAL("module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary);

TVM_REGISTER_GLOBAL("device_api.vulkan").set_body([](TVMArgs args, TVMRetValue* rv) {
  DeviceAPI* ptr = VulkanDeviceAPI::Global().get();
  *rv = static_cast<void*>(ptr);
});

}  // namespace vulkan
}  // namespace runtime
}  // namespace tvm