1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #include <dmlc/memory_io.h>
21 #include <dmlc/thread_local.h>
22 #include <tvm/runtime/device_api.h>
23 #include <tvm/runtime/registry.h>
24 #include <vulkan/vulkan.h>
25 
#include <algorithm>
26 #include <array>
27 #include <cstring>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
28 
29 #include "../file_util.h"
30 #include "../pack_args.h"
31 #include "../thread_storage_scope.h"
32 #include "../workspace_pool.h"
33 #include "vulkan_common.h"
34 #include "vulkan_module.h"
35 #include "vulkan_shader.h"
36 #include "vulkan_stream.h"
37 
38 namespace tvm {
39 namespace runtime {
40 namespace vulkan {
41 
42 /*! \brief Maximum number of GPUs supported by VulkanModule. */
43 static constexpr const int kVulkanMaxNumDevice = 8;
44 
45 /*! \brief TVM Vulkan binary pack magic number */
46 static constexpr const int kVulkanModuleMagic = 0x02700027;
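// Serialized layout used by SaveToFile/VulkanModuleLoadFile below:
// a uint32_t magic (kVulkanModuleMagic) followed by the serialized shader
// map (function name -> VulkanShader). The FunctionInfo map is stored
// separately in the meta file written next to the binary.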
47 
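// Per-thread Vulkan state: the active TVMContext, a workspace pool, and
// lazily created per-device streams and staging buffers. The members are
// torn down explicitly so that all Vulkan objects are released before the
// device and instance go away (see the destructor below).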
48 class VulkanThreadEntry {
49  public:
50   VulkanThreadEntry();
51   static VulkanThreadEntry* ThreadLocal();
52 
53   ~VulkanThreadEntry() {
54     // Because the thread entry refers to the Device API,
55     // the command buffers must always be destroyed before
56     // the instance and device are destroyed.
57     // The destruction is invoked manually here
58     // to ensure the correct destruction order.
59 
60     pool.reset();
61     streams_.clear();
62     for (const auto& kv : staging_buffers_) {
63       if (!kv.second) {
64         continue;
65       }
66       auto& buf = *(kv.second);
67       if (buf.host_addr != nullptr) {
68         vkUnmapMemory(buf.device, buf.memory);
69       }
70       if (buf.memory != VK_NULL_HANDLE) {
71         vkFreeMemory(buf.device, buf.memory, nullptr);
72       }
73       if (buf.buffer != VK_NULL_HANDLE) {
74         vkDestroyBuffer(buf.device, buf.buffer, nullptr);
75       }
76     }
77   }
78 
79   TVMContext ctx;
80   std::unique_ptr<WorkspacePool> pool;
81   VulkanStream* Stream(size_t device_id);
82   VulkanStagingBuffer* StagingBuffer(int device_id, size_t size);
83 
84  private:
85   std::unordered_map<size_t, std::unique_ptr<VulkanStream>> streams_;
86   std::unordered_map<size_t, std::unique_ptr<VulkanStagingBuffer>> staging_buffers_;
87 };
88 
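// A device buffer together with the memory allocation backing it.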
89 struct VulkanBuffer {
90   VkBuffer buffer{VK_NULL_HANDLE};
91   VkDeviceMemory memory{VK_NULL_HANDLE};
92 };
93 
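// Everything needed to dispatch one compiled kernel on one device:
// shader module, descriptor set layout/pool/set, pipeline layout,
// compute pipeline, and (on the push-descriptor path) the descriptor
// update template.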
94 struct VulkanPipeline {
95   VulkanContext* vctx_{nullptr};
96   VkShaderModule shader{VK_NULL_HANDLE};
97   VkDescriptorSetLayout descriptor_set_layout{VK_NULL_HANDLE};
98   VkDescriptorPool descriptor_pool{VK_NULL_HANDLE};
99   VkDescriptorSet descriptor_set{VK_NULL_HANDLE};
100   VkPipelineLayout pipeline_layout{VK_NULL_HANDLE};
101   VkPipeline pipeline{VK_NULL_HANDLE};
102   VkDescriptorUpdateTemplateKHR descriptor_update_template{VK_NULL_HANDLE};
103 };
104 
105 typedef dmlc::ThreadLocalStore<VulkanThreadEntry> VulkanThreadStore;
106 
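// DeviceAPI implementation for Vulkan. It owns the VkInstance and one
// VulkanContext (logical device, compute queue, memory-type indices, and
// optional extension function pointers) per usable physical device.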
107 class VulkanDeviceAPI final : public DeviceAPI {
108  public:
109   VulkanDeviceAPI();
110   ~VulkanDeviceAPI() {
111     for (auto& vctx : context_) {
112       vkDestroyDevice(vctx.device, nullptr);
113     }
114     if (instance_) {
115       vkDestroyInstance(instance_, nullptr);
116     }
117   }
118   void SetDevice(TVMContext ctx) final { VulkanThreadEntry::ThreadLocal()->ctx = ctx; }
119   void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final;
120   void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment,
121                        DLDataType type_hint) final {
122     const auto& vctx = context(ctx.device_id);
123     VkBufferCreateInfo info;
124     info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
125     info.pNext = nullptr;
126     info.flags = 0;
127     info.size = nbytes;
128     info.queueFamilyIndexCount = 1;
129     info.pQueueFamilyIndices = &(vctx.queue_family_index);
130     info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
131     info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
132                  VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
133     // create buffer
134     VkBuffer buffer;
135     VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer));
136     // bind to memory
137     VkBufferMemoryRequirementsInfo2KHR req_info2;
138     req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR;
139     req_info2.pNext = 0;
140     req_info2.buffer = buffer;
141 
142     VkMemoryRequirements2KHR req2;
143     req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR;
144     req2.pNext = 0;
145 
146     VkMemoryDedicatedRequirementsKHR dedicated_req;
147     dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR;
148     dedicated_req.pNext = 0;
149     req2.pNext = &dedicated_req;
150 
151     bool dedicated_allocation = false;
152     if (vctx.get_buffer_memory_requirements_2_functions) {
153       vctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR(
154           vctx.device, &req_info2, &req2);
155       dedicated_allocation =
156           dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation;
157     }
158 
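    // If VK_KHR_dedicated_allocation reports that this buffer prefers or
    // requires its own allocation, chain a VkMemoryDedicatedAllocateInfoKHR
    // into the allocation below; otherwise fall back to a plain allocation
    // of exactly `nbytes`.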
159     VkDeviceMemory memory;
160     if (!dedicated_allocation) {
161       VkMemoryAllocateInfo minfo;
162       minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
163       minfo.pNext = nullptr;
164       minfo.allocationSize = nbytes;
165       minfo.memoryTypeIndex = vctx.compute_mtype_index;
166       VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
167     } else {
168       VkMemoryAllocateInfo minfo;
169       minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
170       minfo.pNext = nullptr;
171       minfo.allocationSize = req2.memoryRequirements.size;
172       minfo.memoryTypeIndex = vctx.compute_mtype_index;
173 
174       VkMemoryDedicatedAllocateInfoKHR mdinfo;
175       mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR;
176       mdinfo.pNext = 0;
177       mdinfo.image = 0;
178       mdinfo.buffer = buffer;
179       minfo.pNext = &mdinfo;
180       VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
181     }
182     VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0));
183     VulkanBuffer* pbuf = new VulkanBuffer();
184     pbuf->memory = memory;
185     pbuf->buffer = buffer;
186     return pbuf;
187   }
188 
189   void FreeDataSpace(TVMContext ctx, void* ptr) final {
190     // Before releasing the VkBuffer, synchronize the stream to
191     // finish all Vulkan commands that may still reference the buffer.
192     StreamSync(ctx, nullptr);
193 
194     const auto& vctx = context(ctx.device_id);
195     auto* pbuf = static_cast<VulkanBuffer*>(ptr);
196     vkDestroyBuffer(vctx.device, pbuf->buffer, nullptr);
197     vkFreeMemory(vctx.device, pbuf->memory, nullptr);
198     delete pbuf;
199   }
200 
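  // Copies between Vulkan buffers are recorded directly on the device's
  // stream; copies to/from the host go through a per-thread staging buffer
  // and are followed by an explicit stream synchronization, with a manual
  // flush/invalidate when the staging memory is not host-coherent.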
201   void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
202                       TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint,
203                       TVMStreamHandle stream) final {
204     CHECK(stream == nullptr);
205     TVMContext ctx = ctx_from;
206     if (ctx_from.device_type == kDLCPU) {
207       ctx = ctx_to;
208     }
209 
210     int from_dev_type = static_cast<int>(ctx_from.device_type);
211     int to_dev_type = static_cast<int>(ctx_to.device_type);
212     if (from_dev_type == kDLVulkan && to_dev_type == kDLVulkan) {
213       VulkanThreadEntry::ThreadLocal()
214           ->Stream(ctx_from.device_id)
215           ->Launch([=](VulkanStreamState* state) {
216             // 1: copy
217             const auto* from_buf = static_cast<const VulkanBuffer*>(from);
218             auto* to_buf = static_cast<VulkanBuffer*>(to);
219             VkBufferCopy copy_info;
220             copy_info.srcOffset = from_offset;
221             copy_info.dstOffset = to_offset;
222             copy_info.size = size;
223             vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, to_buf->buffer, 1, &copy_info);
224             // 2: barrier(transfer-> compute|transfer)
225             CHECK_EQ(ctx_from.device_id, ctx_to.device_id) << "Vulkan disallows cross-device copy.";
226             VkMemoryBarrier barrier_info;
227             barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
228             barrier_info.pNext = nullptr;
229             barrier_info.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
230             barrier_info.dstAccessMask =
231                 (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
232                  VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
233             vkCmdPipelineBarrier(
234                 state->cmd_buffer_, VK_PIPELINE_STAGE_TRANSFER_BIT,
235                 VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, 1,
236                 &barrier_info, 0, nullptr, 0, nullptr);
237           });
238 
239     } else if (from_dev_type == kDLVulkan && to_dev_type == kDLCPU) {
240       const auto* from_buf = static_cast<const VulkanBuffer*>(from);
241       const auto& vctx = context(ctx_from.device_id);
242       auto* temp = VulkanThreadEntry::ThreadLocal()->StagingBuffer(ctx_from.device_id, size);
243       VulkanThreadEntry::ThreadLocal()
244           ->Stream(ctx_from.device_id)
245           ->Launch([&](VulkanStreamState* state) {
246             VkBufferCopy copy_info;
247             copy_info.srcOffset = from_offset;
248             copy_info.dstOffset = 0;
249             copy_info.size = size;
250             vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, temp->buffer, 1, &copy_info);
251           });
252       VulkanThreadEntry::ThreadLocal()->Stream(ctx_from.device_id)->Synchronize();
253       if (!vctx.coherent_staging) {
254         VkMappedMemoryRange mrange;
255         mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
256         mrange.pNext = nullptr;
257         mrange.memory = temp->memory;
258         mrange.offset = 0;
259         mrange.size = VK_WHOLE_SIZE;  // invalidate the whole mapped range
260         VULKAN_CALL(vkInvalidateMappedMemoryRanges(vctx.device, 1, &mrange));
261       }
262       memcpy(static_cast<char*>(to) + to_offset, static_cast<char*>(temp->host_addr), size);
263     } else if (from_dev_type == kDLCPU && to_dev_type == kDLVulkan) {
264       const auto& vctx = context(ctx_to.device_id);
265       const auto* to_buf = static_cast<const VulkanBuffer*>(to);
266       VulkanStagingBuffer* temp =
267           VulkanThreadEntry::ThreadLocal()->StagingBuffer(ctx_to.device_id, size);
268       memcpy(temp->host_addr, static_cast<const char*>(from) + from_offset, size);
269       // Host-side flush if the staging memory is not coherent,
270       // so that writes from the CPU are visible to the GPU.
271       if (!vctx.coherent_staging) {
272         VkMappedMemoryRange mrange;
273         mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
274         mrange.pNext = nullptr;
275         mrange.memory = temp->memory;
276         mrange.offset = 0;
277         mrange.size = VK_WHOLE_SIZE;  // flush the whole mapped range
278         VULKAN_CALL(vkFlushMappedMemoryRanges(vctx.device, 1, &mrange));
279       }
280 
281       VulkanThreadEntry::ThreadLocal()
282           ->Stream(ctx_from.device_id)
283           ->Launch([&](VulkanStreamState* state) {
284             // 0: barrier(host->transfer)
285             VkMemoryBarrier barrier_info;
286             barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
287             barrier_info.pNext = nullptr;
288             barrier_info.srcAccessMask = 0;
289             barrier_info.dstAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
290             vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_HOST_BIT,
291                                  VK_PIPELINE_STAGE_TRANSFER_BIT, 0, 1, &barrier_info, 0, nullptr, 0,
292                                  nullptr);
293             // 1: copy
294             VkBufferCopy copy_info;
295             copy_info.srcOffset = 0;
296             copy_info.dstOffset = to_offset;
297             copy_info.size = size;
298             vkCmdCopyBuffer(state->cmd_buffer_, temp->buffer, to_buf->buffer, 1, &copy_info);
299           });
300       // TODO(tulloch): should we instead make the staging buffer a property of the
301       // Stream? This would allow us to elide synchronizations here.
302       VulkanThreadEntry::ThreadLocal()->Stream(ctx_from.device_id)->Synchronize();
303     } else {
304       LOG(FATAL) << "Expected copy to/from Vulkan or between Vulkan devices"
305                  << ", from=" << from_dev_type << ", to=" << to_dev_type;
306     }
307   }
308 
309   // Always use the default stream
310   TVMStreamHandle CreateStream(TVMContext ctx) {
311     LOG(FATAL) << "Not implemented";
312     return nullptr;
313   }
314 
315   void FreeStream(TVMContext ctx, TVMStreamHandle stream) {
316     LOG(FATAL) << "Not implemented";
317     return;
318   }
319 
320   void SyncStreamFromTo(TVMContext ctx, TVMStreamHandle event_src, TVMStreamHandle event_dst) {
321     LOG(FATAL) << "Not implemented";
322     return;
323   }
324 
325   void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
326     CHECK(stream == nullptr);
327     VulkanThreadEntry::ThreadLocal()->Stream(ctx.device_id)->Synchronize();
328   }
329 
330   void SetStream(TVMContext ctx, TVMStreamHandle stream) final {
331     LOG(FATAL) << "Not implemented";
332     return;
333   }
334 
335   void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final {
336     return VulkanThreadEntry::ThreadLocal()->pool->AllocWorkspace(ctx, size);
337   }
338 
339   void FreeWorkspace(TVMContext ctx, void* data) final {
340     VulkanThreadEntry::ThreadLocal()->pool->FreeWorkspace(ctx, data);
341   }
342 
343   static VulkanDeviceAPI* Global() {
344     static VulkanDeviceAPI* inst = new VulkanDeviceAPI();
345     return inst;
346   }
347 
348   const VulkanContext& context(size_t device_id) const {
349     CHECK_LT(device_id, context_.size());
350     return context_[device_id];
351   }
352 
353  private:
354   VkInstance instance_{nullptr};
355   // The physical devices; they have a 1-to-1 mapping to the TVM Vulkan devices.
356   std::vector<VulkanContext> context_;
357 };
358 
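// Answers device attribute queries (thread limits, shared memory size,
// API version, etc.) from the VkPhysicalDeviceProperties of the selected device.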
359 void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
360   size_t index = static_cast<size_t>(ctx.device_id);
361   if (kind == kExist) {
362     *rv = static_cast<int>(index < context_.size());
363     return;
364   }
365   CHECK_LT(index, context_.size()) << "Invalid device id " << index;
366   const auto& vctx = context(index);
367   switch (kind) {
368     case kMaxThreadsPerBlock: {
369       VkPhysicalDeviceProperties phy_prop;
370       vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
371       int64_t value = phy_prop.limits.maxComputeWorkGroupInvocations;
372       *rv = value;
373       break;
374     }
375     case kMaxSharedMemoryPerBlock: {
376       VkPhysicalDeviceProperties phy_prop;
377       vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
378       int64_t value = phy_prop.limits.maxComputeSharedMemorySize;
379       *rv = value;
380       break;
381     }
382     case kWarpSize: {
383       *rv = 1;
384       break;
385     }
386     case kComputeVersion: {
387       VkPhysicalDeviceProperties phy_prop;
388       vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
389       int64_t value = phy_prop.apiVersion;
390       std::ostringstream os;
391       os << VK_VERSION_MAJOR(value) << "." << VK_VERSION_MINOR(value) << "."
392          << VK_VERSION_PATCH(value);
393       *rv = os.str();
394       break;
395     }
396     case kDeviceName:
397       return;
398     case kMaxClockRate:
399       return;
400     case kMultiProcessorCount:
401       return;
402     case kExist:
403       break;
404     case kMaxThreadDimensions: {
405       VkPhysicalDeviceProperties phy_prop;
406       vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
407       int64_t dims[3];
408       dims[0] = phy_prop.limits.maxComputeWorkGroupSize[0];
409       dims[1] = phy_prop.limits.maxComputeWorkGroupSize[1];
410       dims[2] = phy_prop.limits.maxComputeWorkGroupSize[2];
411       std::stringstream ss;  // use a JSON string to return multiple int values
412       ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]";
413       *rv = ss.str();
414       break;
415     }
416     case kMaxRegistersPerBlock:
417       return;
418     case kGcnArch:
419       return;
420     case kApiVersion:
421       return;
422   }
423 }
424 
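// Creates the VkInstance (optionally with validation layers when built with
// USE_VULKAN_VALIDATION), then, for every physical device that exposes a
// compute queue: creates a logical device, picks the staging and compute
// memory types, and loads the optional KHR extension entry points used above.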
425 VulkanDeviceAPI::VulkanDeviceAPI() {
426   VkApplicationInfo app_info;
427   app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
428   app_info.pNext = nullptr;
429   app_info.pApplicationName = "TVM";
430   app_info.applicationVersion = 0;
431   app_info.pEngineName = "";
432   app_info.engineVersion = 0;
433   app_info.apiVersion = VK_MAKE_VERSION(1, 0, 0);
434 
435   VkInstanceCreateInfo inst_info;
436   inst_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
437   inst_info.pNext = nullptr;
438   inst_info.flags = 0;
439 
440   const auto layers = []() -> std::vector<const char*> {
441     uint32_t inst_layer_prop_count;
442     VULKAN_CALL(vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, nullptr));
443     std::vector<VkLayerProperties> inst_layer_prop(inst_layer_prop_count);
444     VULKAN_CALL(vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, inst_layer_prop.data()));
445     std::vector<const char*> l;
446     for (const auto& lp : inst_layer_prop) {
447       // TODO(tulloch): add CMAKE options.
448       (void)lp;  // suppress unused variable warning.
449 #ifdef USE_VULKAN_VALIDATION
450       if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_standard_validation") == 0) {
451         l.push_back("VK_LAYER_LUNARG_standard_validation");
452       }
453       if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_parameter_validation") == 0) {
454         l.push_back("VK_LAYER_LUNARG_parameter_validation");
455       }
456       if (std::strcmp(lp.layerName, "VK_LAYER_KHRONOS_validation") == 0) {
457         l.push_back("VK_LAYER_KHRONOS_validation");
458       }
459 #endif
460     }
461     return l;
462   }();
463 
464   const auto instance_extensions = []() -> std::vector<const char*> {
465     uint32_t inst_extension_prop_count;
466     VULKAN_CALL(
467         vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count, nullptr));
468     std::vector<VkExtensionProperties> inst_extension_prop(inst_extension_prop_count);
469     VULKAN_CALL(vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count,
470                                                        inst_extension_prop.data()));
471     std::vector<const char*> extensions;
472     for (const auto& ip : inst_extension_prop) {
473       if (std::strcmp(ip.extensionName, "VK_KHR_get_physical_device_properties2") == 0) {
474         extensions.push_back("VK_KHR_get_physical_device_properties2");
475       }
476     }
477     return extensions;
478   }();
479 
480   inst_info.pApplicationInfo = &app_info;
481   inst_info.enabledLayerCount = layers.size();
482   inst_info.ppEnabledLayerNames = layers.data();
483   inst_info.enabledExtensionCount = instance_extensions.size();
484   inst_info.ppEnabledExtensionNames = instance_extensions.data();
485 
486   VULKAN_CALL(vkCreateInstance(&inst_info, nullptr, &instance_));
487 
488   uint32_t phy_dev_count = 0;
489   VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &phy_dev_count, nullptr));
490   std::vector<VkPhysicalDevice> all_phy_devs(phy_dev_count);
491   VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &phy_dev_count, dmlc::BeginPtr(all_phy_devs)));
492   for (VkPhysicalDevice phy_dev : all_phy_devs) {
493     uint32_t queue_prop_count = 0;
494     vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count, nullptr);
495     std::vector<VkQueueFamilyProperties> queue_props(queue_prop_count);
496     vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count,
497                                              dmlc::BeginPtr(queue_props));
498     uint32_t queue_family_index = 0;
499     std::vector<VkDeviceQueueCreateInfo> queue_create_info;
500     float priority = 1.0f;
501     for (uint32_t i = 0; i < queue_props.size(); i++) {
502       // find queues that support compute
503       if (VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) {
504         VkDeviceQueueCreateInfo info;
505         info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
506         info.pNext = nullptr;
507         info.flags = 0;
508         info.queueFamilyIndex = i;
509         info.queueCount = 1;
510         info.pQueuePriorities = &priority;
511 
512         queue_create_info.push_back(info);
513         // record the queue family of the first compute-capable queue; only that queue is used for now
514         if (queue_create_info.size() == 1) {
515           queue_family_index = i;
516         }
517       }
518     }
519     if (queue_create_info.size() == 0) continue;
520 
521     VulkanContext ctx;
522     // setup context
523     ctx.phy_device = phy_dev;
524     vkGetPhysicalDeviceProperties(ctx.phy_device, &(ctx.phy_device_prop));
525 
526     const auto extensions = [&]() {
527       uint32_t device_extension_prop_count;
528       VULKAN_CALL(vkEnumerateDeviceExtensionProperties(ctx.phy_device, nullptr,
529                                                        &device_extension_prop_count, nullptr));
530       std::vector<VkExtensionProperties> device_extension_prop(device_extension_prop_count);
531       VULKAN_CALL(vkEnumerateDeviceExtensionProperties(
532           ctx.phy_device, nullptr, &device_extension_prop_count, device_extension_prop.data()));
533       std::vector<const char*> extensions;
534       for (const auto& dp : device_extension_prop) {
535         if ((std::strcmp(dp.extensionName, "VK_KHR_push_descriptor") == 0) && dp.specVersion > 0) {
536           extensions.push_back("VK_KHR_push_descriptor");
537         }
538         if ((std::strcmp(dp.extensionName, "VK_KHR_descriptor_update_template") == 0) &&
539             dp.specVersion > 0) {
540           extensions.push_back("VK_KHR_descriptor_update_template");
541         }
542         if ((std::strcmp(dp.extensionName, "VK_KHR_get_memory_requirements2") == 0) &&
543             dp.specVersion > 0) {
544           extensions.push_back("VK_KHR_get_memory_requirements2");
545         }
546         if ((std::strcmp(dp.extensionName, "VK_KHR_dedicated_allocation") == 0) &&
547             dp.specVersion > 0) {
548           extensions.push_back("VK_KHR_dedicated_allocation");
549         }
550       }
551       return extensions;
552     }();
553     VkDeviceCreateInfo device_create_info;
554     device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
555     device_create_info.pNext = nullptr;
556     device_create_info.flags = 0;
557     device_create_info.queueCreateInfoCount = static_cast<uint32_t>(queue_create_info.size());
558     device_create_info.pQueueCreateInfos = queue_create_info.data();
559     device_create_info.enabledLayerCount = 0;
560     device_create_info.ppEnabledLayerNames = nullptr;
561     device_create_info.enabledExtensionCount = extensions.size();
562     device_create_info.ppEnabledExtensionNames = extensions.data();
563     device_create_info.pEnabledFeatures = nullptr;
564     VULKAN_CALL(vkCreateDevice(phy_dev, &device_create_info, nullptr, &(ctx.device)));
565     ctx.queue_mutex.reset(new std::mutex());
566     vkGetDeviceQueue(ctx.device, queue_family_index, 0, &(ctx.queue));
567     ctx.queue_family_index = queue_family_index;
568     // Find memory types suitable for the staging and compute buffers
569     // by probing the requirements of small (1 KiB) dummy buffers.
570     VkBuffer buffer;
571     VkMemoryRequirements req_staging, req_compute;
572     VkBufferCreateInfo info;
573     info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
574     info.pNext = nullptr;
575     info.flags = 0;
576     info.size = 1024;
577     info.queueFamilyIndexCount = 1;
578     info.pQueueFamilyIndices = &(ctx.queue_family_index);
579     info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
580 
581     // get staging requirement
582     info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT;
583     VULKAN_CALL(vkCreateBuffer(ctx.device, &info, nullptr, &buffer));
584     vkGetBufferMemoryRequirements(ctx.device, buffer, &req_staging);
585     vkDestroyBuffer(ctx.device, buffer, nullptr);
586     // get compute requirement
587     info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
588                  VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
589     VULKAN_CALL(vkCreateBuffer(ctx.device, &info, nullptr, &buffer));
590     vkGetBufferMemoryRequirements(ctx.device, buffer, &req_compute);
591     vkDestroyBuffer(ctx.device, buffer, nullptr);
592 
593     // Query the physical device memory properties and
594     // find a memory type that is host visible (it does not need to be coherent).
595     int win_rank = -1;
596     VkPhysicalDeviceMemoryProperties prop;
597     vkGetPhysicalDeviceMemoryProperties(ctx.phy_device, &prop);
598 
599     for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
600       VkMemoryType ty = prop.memoryTypes[k];
601       size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
602       // host visible
603       if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)) continue;
604       // match staging buffer requirement
605       if (!(req_staging.memoryTypeBits & (1 << k))) continue;
606       if (heap_size < 1024) continue;
607       int rank = 0;
608       rank += ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_CACHED_BIT;
609       if (rank > win_rank) {
610         win_rank = rank;
611         ctx.staging_mtype_index = k;
612         ctx.coherent_staging = ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT;
613       }
614     }
615     CHECK_GE(win_rank, 0) << "Cannot find suitable staging memory on device.";
616 
617     win_rank = -1;
618     for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
619       VkMemoryType ty = prop.memoryTypes[k];
620       size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
621       // device local
622       if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)) continue;
623       // match compute buffer requirement
624       if (!(req_compute.memoryTypeBits & (1 << k))) continue;
625       if (heap_size < 1024) continue;
626       int rank = 0;
627       // prefer not host visible
628       rank += !(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT);
629       if (rank > win_rank) {
630         win_rank = rank;
631         ctx.compute_mtype_index = k;
632       }
633     }
634     CHECK_GE(win_rank, 0) << "Cannot find suitable local memory on device.";
635     auto has_extension = [&extensions](const char* query) {
636       return std::any_of(extensions.begin(), extensions.end(),
637                          [&](const char* extension) { return std::strcmp(query, extension) == 0; });
638     };
639 
640 #ifdef USE_VULKAN_IMMEDIATE_MODE
641     if (has_extension("VK_KHR_push_descriptor") &&
642         has_extension("VK_KHR_descriptor_update_template")) {
643       ctx.descriptor_template_khr_functions = std::unique_ptr<VulkanDescriptorTemplateKHRFunctions>(
644           new VulkanDescriptorTemplateKHRFunctions());
645       ctx.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR =
646           CHECK_NOTNULL((PFN_vkCreateDescriptorUpdateTemplateKHR)vkGetDeviceProcAddr(
647               ctx.device, "vkCreateDescriptorUpdateTemplateKHR"));
648       ctx.descriptor_template_khr_functions->vkDestroyDescriptorUpdateTemplateKHR =
649           CHECK_NOTNULL((PFN_vkDestroyDescriptorUpdateTemplateKHR)vkGetDeviceProcAddr(
650               ctx.device, "vkDestroyDescriptorUpdateTemplateKHR"));
651       ctx.descriptor_template_khr_functions->vkUpdateDescriptorSetWithTemplateKHR =
652           CHECK_NOTNULL((PFN_vkUpdateDescriptorSetWithTemplateKHR)vkGetDeviceProcAddr(
653               ctx.device, "vkUpdateDescriptorSetWithTemplateKHR"));
654       ctx.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR =
655           CHECK_NOTNULL((PFN_vkCmdPushDescriptorSetWithTemplateKHR)vkGetDeviceProcAddr(
656               ctx.device, "vkCmdPushDescriptorSetWithTemplateKHR"));
657     }
658 #endif
659 
660 #ifdef USE_VULKAN_DEDICATED_ALLOCATION
661     if (has_extension("VK_KHR_get_memory_requirements2") &&
662         has_extension("VK_KHR_dedicated_allocation")) {
663       ctx.get_buffer_memory_requirements_2_functions =
664           std::unique_ptr<VulkanGetBufferMemoryRequirements2Functions>(
665               new VulkanGetBufferMemoryRequirements2Functions());
666       ctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR =
667           CHECK_NOTNULL((PFN_vkGetBufferMemoryRequirements2KHR)vkGetDeviceProcAddr(
668               ctx.device, "vkGetBufferMemoryRequirements2KHR"));
669     }
670 #endif
671     context_.push_back(std::move(ctx));
672   }
673 
674   LOG(INFO) << "Initializing Vulkan with " << context_.size() << " device(s).";
675   for (size_t i = 0; i < context_.size(); ++i) {
676     LOG(INFO) << "vulkan(" << i << ")=\'" << context_[i].phy_device_prop.deviceName
677               << "\' phy_dev_id=" << context_[i].phy_device
678               << " use_immediate=" << context_[i].UseImmediate();
679   }
680 }
681 class VulkanModuleNode;
682 
683 // A wrapped function class used to obtain a PackedFunc from this module.
684 class VulkanWrappedFunc {
685  public:
686   void Init(VulkanModuleNode* m, ObjectPtr<Object> sptr, const std::string& func_name,
687             size_t num_buffer_args, size_t num_pack_args,
688             const std::vector<std::string>& thread_axis_tags) {
689     m_ = m;
690     sptr_ = sptr;
691     func_name_ = func_name;
692     num_buffer_args_ = num_buffer_args;
693     num_pack_args_ = num_pack_args;
694     thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags);
695   }
696 
697   void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const;
698 
699  private:
700   // internal module
701   VulkanModuleNode* m_;
702   // the resource holder
703   ObjectPtr<Object> sptr_;
704   // The name of the function.
705   std::string func_name_;
706   // Number of buffer arguments
707   size_t num_buffer_args_;
708   // number of packed arguments.
709   size_t num_pack_args_;
710   // thread axis configuration
711   ThreadAxisConfig thread_axis_cfg_;
712
713   // Per-device pipeline cache.
714   // Marked mutable to enable lazy initialization.
715   mutable std::array<std::shared_ptr<VulkanPipeline>, kVulkanMaxNumDevice> scache_;
716 };
717 
718 // Multi-device enabled module.
719 class VulkanModuleNode final : public runtime::ModuleNode {
720  public:
721   explicit VulkanModuleNode(std::unordered_map<std::string, VulkanShader> smap,
722                             std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
723       : smap_(smap), fmap_(fmap), source_(source) {}
724 
725   const char* type_key() const final { return "vulkan"; }
726 
727   PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
728     CHECK_EQ(sptr_to_self.get(), this);
729     CHECK_NE(name, symbol::tvm_module_main) << "Device functions do not have a main";
730     auto it = fmap_.find(name);
731     if (it == fmap_.end()) return PackedFunc();
732     const FunctionInfo& info = it->second;
733     VulkanWrappedFunc f;
734     size_t num_buffer_args = NumBufferArgs(info.arg_types);
735     f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args,
736            info.thread_axis_tags);
737     return PackFuncNonBufferArg(std::move(f), info.arg_types);
738   }
739 
740   ~VulkanModuleNode() {
741     // Clean up Vulkan-related caches.
742     for (size_t device_id = 0; device_id < ecache_.size(); ++device_id) {
743       for (auto& kv : ecache_[device_id]) {
744         auto& pe = kv.second;
745         CHECK(pe);
746         const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
747 
748         if (pe->descriptor_update_template != VK_NULL_HANDLE) {
749           vctx.descriptor_template_khr_functions->vkDestroyDescriptorUpdateTemplateKHR(
750               vctx.device, pe->descriptor_update_template, nullptr);
751         }
752         vkDestroyPipeline(vctx.device, pe->pipeline, nullptr);
753         vkDestroyPipelineLayout(vctx.device, pe->pipeline_layout, nullptr);
754         vkDestroyDescriptorPool(vctx.device, pe->descriptor_pool, nullptr);
755         vkDestroyDescriptorSetLayout(vctx.device, pe->descriptor_set_layout, nullptr);
756         vkDestroyShaderModule(vctx.device, pe->shader, nullptr);
757       }
758     }
759   }
760 
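  // Returns the cached pipeline for (device, function), creating it on first
  // use: shader module, descriptor set layout, descriptor pool/set (or a
  // descriptor update template on the push-descriptor path), pipeline layout
  // with a push-constant range for the POD arguments, and the compute pipeline.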
761   std::shared_ptr<VulkanPipeline> GetPipeline(size_t device_id, const std::string& func_name,
762                                               size_t num_pack_args) {
763     const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
764     std::lock_guard<std::mutex> lock(mutex_);
765     const auto& cp = ecache_[device_id][func_name];
766     if (cp) {
767       return cp;
768     }
769     // Create new pipeline
770     auto pe = std::shared_ptr<VulkanPipeline>(new VulkanPipeline());
771     {
772       // create shader
773       auto sit = smap_.find(func_name);
774       CHECK(sit != smap_.end());
775       const std::vector<uint32_t>& data = sit->second.data;
776       VkShaderModuleCreateInfo shader_cinfo;
777       shader_cinfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
778       shader_cinfo.pNext = nullptr;
779       shader_cinfo.flags = 0;
780       shader_cinfo.codeSize = data.size() * sizeof(uint32_t);
781       shader_cinfo.pCode = data.data();
782       VULKAN_CALL(vkCreateShaderModule(vctx.device, &shader_cinfo, nullptr, &(pe->shader)));
783     }
784     std::vector<VkDescriptorSetLayoutBinding> arg_binding;
785     std::vector<VkDescriptorUpdateTemplateEntryKHR> arg_template;
786     uint32_t num_pod = 0, num_buffer = 0;
787 
788     {
789       auto fit = fmap_.find(func_name);
790       CHECK(fit != fmap_.end());
791       for (DLDataType arg_type : fit->second.arg_types) {
792         if (arg_type.code == kTVMOpaqueHandle) {
793           {
794             VkDescriptorSetLayoutBinding bd;
795             bd.binding = num_buffer;
796             bd.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
797             bd.descriptorCount = 1;
798             bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
799             bd.pImmutableSamplers = nullptr;
800             arg_binding.push_back(bd);
801           }
802           {
803             VkDescriptorUpdateTemplateEntryKHR tpl;
804             tpl.dstBinding = num_buffer;
805             tpl.dstArrayElement = 0;
806             tpl.descriptorCount = 1;
807             tpl.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
808             tpl.offset = num_buffer * sizeof(VkDescriptorBufferInfo);
809             tpl.stride = sizeof(VkDescriptorBufferInfo);
810             arg_template.push_back(tpl);
811           }
812           ++num_buffer;
813         } else {
814           ++num_pod;
815         }
816       }
817     }
818 
819     {
820       VkDescriptorSetLayoutCreateInfo descrip_cinfo;
821       descrip_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
822       descrip_cinfo.pNext = nullptr;
823       descrip_cinfo.flags = 0;
824       if (vctx.UseImmediate()) {
825         descrip_cinfo.flags |= VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR;
826       }
827       descrip_cinfo.bindingCount = arg_binding.size();
828       descrip_cinfo.pBindings = arg_binding.data();
829       VULKAN_CALL(vkCreateDescriptorSetLayout(vctx.device, &descrip_cinfo, nullptr,
830                                               &(pe->descriptor_set_layout)));
831     }
832 
833     {
834       VkDescriptorPoolSize pool_size;
835       pool_size.type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
836       pool_size.descriptorCount = arg_binding.size();
837       VkDescriptorPoolCreateInfo descrip_pool_cinfo;
838       descrip_pool_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
839       descrip_pool_cinfo.pNext = nullptr;
840       descrip_pool_cinfo.flags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT;
841       descrip_pool_cinfo.maxSets = 1;
842       descrip_pool_cinfo.poolSizeCount = 1;
843       descrip_pool_cinfo.pPoolSizes = &pool_size;
844       VULKAN_CALL(vkCreateDescriptorPool(vctx.device, &descrip_pool_cinfo, nullptr,
845                                          &(pe->descriptor_pool)));
846     }
847 
848     if (!vctx.UseImmediate()) {
849       VkDescriptorSetAllocateInfo alloc_info;
850       alloc_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
851       alloc_info.pNext = nullptr;
852       alloc_info.descriptorPool = pe->descriptor_pool;
853       alloc_info.descriptorSetCount = 1;
854       alloc_info.pSetLayouts = &(pe->descriptor_set_layout);
855       VULKAN_CALL(vkAllocateDescriptorSets(vctx.device, &alloc_info, &(pe->descriptor_set)));
856     }
857 
858     VkPushConstantRange crange;
859     crange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
860     crange.offset = 0;
861     crange.size = sizeof(ArgUnion) * num_pack_args;
862 
863     VkPipelineLayoutCreateInfo playout_cinfo;
864     playout_cinfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
865     playout_cinfo.pNext = nullptr;
866     playout_cinfo.flags = 0;
867     playout_cinfo.setLayoutCount = 1;
868     playout_cinfo.pSetLayouts = &(pe->descriptor_set_layout);
869 
870     if (num_pack_args != 0) {
871       playout_cinfo.pushConstantRangeCount = 1;
872       playout_cinfo.pPushConstantRanges = &crange;
873       CHECK_LE(crange.size, vctx.phy_device_prop.limits.maxPushConstantsSize);
874     } else {
875       playout_cinfo.pushConstantRangeCount = 0;
876       playout_cinfo.pPushConstantRanges = nullptr;
877     }
878 
879     VULKAN_CALL(
880         vkCreatePipelineLayout(vctx.device, &playout_cinfo, nullptr, &(pe->pipeline_layout)));
881 
882     VkComputePipelineCreateInfo pipeline_cinfo;
883     pipeline_cinfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
884     pipeline_cinfo.pNext = nullptr;
885     pipeline_cinfo.flags = 0;
886     pipeline_cinfo.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
887     pipeline_cinfo.stage.pNext = nullptr;
888     pipeline_cinfo.stage.flags = 0;
889     pipeline_cinfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
890     pipeline_cinfo.stage.module = pe->shader;
891     pipeline_cinfo.stage.pName = func_name.c_str();
892     pipeline_cinfo.stage.pSpecializationInfo = nullptr;
893     pipeline_cinfo.layout = pe->pipeline_layout;
894     pipeline_cinfo.basePipelineHandle = VK_NULL_HANDLE;
895     pipeline_cinfo.basePipelineIndex = 0;
896     VULKAN_CALL(vkCreateComputePipelines(vctx.device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr,
897                                          &(pe->pipeline)));
898 
899     if (vctx.UseImmediate()) {
900       VkDescriptorUpdateTemplateCreateInfoKHR descrip_template_cinfo;
901       descrip_template_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_CREATE_INFO_KHR;
902       descrip_template_cinfo.pNext = 0;
903       descrip_template_cinfo.flags = 0;
904       descrip_template_cinfo.descriptorUpdateEntryCount = arg_template.size();
905       descrip_template_cinfo.pDescriptorUpdateEntries = arg_template.data();
906       descrip_template_cinfo.templateType = VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_PUSH_DESCRIPTORS_KHR;
907       descrip_template_cinfo.descriptorSetLayout = pe->descriptor_set_layout;
908       descrip_template_cinfo.pipelineBindPoint = VK_PIPELINE_BIND_POINT_COMPUTE;
909       descrip_template_cinfo.pipelineLayout = pe->pipeline_layout;
910       descrip_template_cinfo.set = 0;
911       VULKAN_CALL(vctx.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR(
912           vctx.device, &descrip_template_cinfo, 0, &(pe->descriptor_update_template)));
913     }
914     ecache_[device_id][func_name] = pe;
915     return pe;
916   }
917 
918   void SaveToFile(const std::string& file_name, const std::string& format) final {
919     std::string fmt = GetFileFormat(file_name, format);
920     CHECK_EQ(fmt, fmt_) << "Can only save to the customized vulkan format";
921     std::string meta_file = GetMetaFilePath(file_name);
922     SaveMetaDataToFile(meta_file, fmap_);
923     std::string data_bin;
924     dmlc::MemoryStringStream fs(&data_bin);
925     dmlc::Stream* stream = &fs;
926     uint32_t magic = kVulkanModuleMagic;
927     stream->Write(magic);
928     stream->Write(smap_);
929     SaveBinaryToFile(file_name, data_bin);
930   }
931 
932   void SaveToBinary(dmlc::Stream* stream) final {
933     stream->Write(fmt_);
934     stream->Write(fmap_);
935     stream->Write(smap_);
936   }
937   std::string GetSource(const std::string& format) final {
938     // can only return source code.
939     return source_;
940   }
941 
942  private:
943   // shader table: function name -> SPIR-V shader binary.
944   std::unordered_map<std::string, VulkanShader> smap_;
945   // function information table.
946   std::unordered_map<std::string, FunctionInfo> fmap_;
947   // The format
948   std::string fmt_{"vulkan"};
949   // The source
950   std::string source_;
951 
952   // Guards accesses to `ecache_`
953   std::mutex mutex_;
954   std::array<std::unordered_map<std::string, std::shared_ptr<VulkanPipeline>>, kVulkanMaxNumDevice>
955       ecache_;
956 };
957 
958 Module VulkanModuleCreate(std::unordered_map<std::string, VulkanShader> smap,
959                           std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
960   auto n = make_object<VulkanModuleNode>(smap, fmap, source);
961   return Module(n);
962 }
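// Illustrative sketch (not part of the runtime itself) of how the pieces above
// fit together; `smap` and `fmap` stand for the shader and function-info tables
// produced by the Vulkan code generator:
//
//   Module m = VulkanModuleCreate(smap, fmap, /*source=*/"");
//   m->SaveToFile("kernels.vulkan", "vulkan");  // writes magic + shader map, meta file alongside
//   Module again = VulkanModuleLoadFile("kernels.vulkan", "vulkan");
//
// The reloaded module resolves kernels by name through GetFunction(), exactly
// like the freshly created one.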
963 
964 VulkanThreadEntry* VulkanThreadEntry::ThreadLocal() { return VulkanThreadStore::Get(); }
965 
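// Returns the per-device staging buffer, (re)allocating it when it does not
// exist yet or is smaller than the requested size. The buffer stays
// persistently mapped and is zero-filled before each use.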
966 VulkanStagingBuffer* VulkanThreadEntry::StagingBuffer(int device_id, size_t size) {
967   if (!staging_buffers_[device_id]) {
968     staging_buffers_[device_id] = std::unique_ptr<VulkanStagingBuffer>(new VulkanStagingBuffer());
969   }
970   auto& buf = *(staging_buffers_[device_id]);
971   if (buf.device != nullptr && buf.size < size) {
972     // free previous buffer
973     if (buf.host_addr != nullptr) {
974       vkUnmapMemory(buf.device, buf.memory);
975     }
976     if (buf.memory != VK_NULL_HANDLE) {
977       vkFreeMemory(buf.device, buf.memory, nullptr);
978     }
979     if (buf.buffer != VK_NULL_HANDLE) {
980       vkDestroyBuffer(buf.device, buf.buffer, nullptr);
981     }
982     buf.host_addr = nullptr;
983     buf.memory = VK_NULL_HANDLE;
984     buf.buffer = VK_NULL_HANDLE;
985   }
986   const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
987 
988   if (buf.device == nullptr) {
989     buf.device = vctx.device;
990   }
991   if (buf.memory == VK_NULL_HANDLE) {
992     // allocate the staging buffer and its memory if necessary
993     VkBufferCreateInfo info;
994     info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
995     info.pNext = nullptr;
996     info.flags = 0;
997     info.size = size;
998     info.queueFamilyIndexCount = 1;
999     info.pQueueFamilyIndices = &(vctx.queue_family_index);
1000     info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT;
1001     info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
1002     VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &(buf.buffer)));
1003     VkMemoryAllocateInfo minfo;
1004     minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
1005     minfo.pNext = nullptr;
1006     minfo.allocationSize = size;
1007     minfo.memoryTypeIndex = vctx.staging_mtype_index;
1008     VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &(buf.memory)));
1009     VULKAN_CALL(vkBindBufferMemory(vctx.device, (buf.buffer), buf.memory, 0));
1010     VULKAN_CALL(vkMapMemory(vctx.device, buf.memory, 0, size, 0, &(buf.host_addr)));
1011     buf.size = size;
1012   }
1013   memset(buf.host_addr, 0, size);
1014   return &buf;
1015 }
1016 
1017 VulkanThreadEntry::VulkanThreadEntry()
1018     : pool(std::make_unique<WorkspacePool>(static_cast<DLDeviceType>(kDLVulkan),
1019                                            VulkanDeviceAPI::Global())) {
1020   ctx.device_id = 0;
1021   ctx.device_type = static_cast<DLDeviceType>(kDLVulkan);
1022 }
1023 
1024 VulkanStream* VulkanThreadEntry::Stream(size_t device_id) {
1025   if (!streams_[device_id]) {
1026     streams_[device_id] = std::unique_ptr<VulkanStream>(
1027         new VulkanStream(&VulkanDeviceAPI::Global()->context(device_id)));
1028   }
1029   return streams_[device_id].get();
1030 }
1031 
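// Launches one kernel. On devices with VK_KHR_push_descriptor the descriptors
// and push constants are recorded immediately on the stream; otherwise the
// descriptor-set update and the dispatch are deferred to the stream's
// LaunchDeferred path, described by a VulkanStreamToken.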
1032 void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const {
1033   int device_id = VulkanThreadEntry::ThreadLocal()->ctx.device_id;
1034   CHECK_LT(device_id, kVulkanMaxNumDevice);
1035   const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
1036   if (!scache_[device_id]) {
1037     scache_[device_id] = m_->GetPipeline(device_id, func_name_, num_pack_args_);
1038   }
1039   const auto& pipeline = scache_[device_id];
1040   ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
1041   std::vector<VkDescriptorBufferInfo> descriptor_buffers;
1042   descriptor_buffers.resize(num_buffer_args_);
1043   for (size_t i = 0; i < num_buffer_args_; ++i) {
1044     void* buf = args[static_cast<int>(i)];
1045     VkDescriptorBufferInfo binfo;
1046     binfo.buffer = static_cast<VulkanBuffer*>(buf)->buffer;
1047     binfo.offset = 0;
1048     binfo.range = VK_WHOLE_SIZE;
1049     descriptor_buffers[i] = binfo;
1050   }
1051   if (vctx.UseImmediate()) {
1052     // Can safely capture by reference as this lambda is immediately executed on the calling thread.
1053     VulkanThreadEntry::ThreadLocal()->Stream(device_id)->Launch([&](VulkanStreamState* state) {
1054       vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline);
1055       CHECK(pipeline->descriptor_update_template != VK_NULL_HANDLE);
1056       vctx.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR(
1057           state->cmd_buffer_, pipeline->descriptor_update_template, pipeline->pipeline_layout, 0,
1058           descriptor_buffers.data());
1059       if (num_pack_args_ != 0) {
1060         vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout,
1061                            VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion),
1062                            pack_args);
1063       }
1064       vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
1065       VkMemoryBarrier barrier_info;
1066       barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
1067       barrier_info.pNext = nullptr;
1068       barrier_info.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT;
1069       barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
1070                                     VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
1071       vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
1072                            VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0,
1073                            1, &barrier_info, 0, nullptr, 0, nullptr);
1074     });
1075     return;
1076   }
1077 
1078   // Otherwise, the more expensive deferred path.
1079   std::vector<ArgUnion> pack_args_storage(pack_args, pack_args + num_pack_args_);
1080   const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers]() {
1081     std::vector<VkWriteDescriptorSet> write_descriptor_sets;
1082     write_descriptor_sets.resize(descriptor_buffers.size());
1083     for (size_t i = 0; i < write_descriptor_sets.size(); i++) {
1084       write_descriptor_sets[i].sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
1085       write_descriptor_sets[i].pNext = 0;
1086       write_descriptor_sets[i].dstSet = pipeline->descriptor_set;
1087       write_descriptor_sets[i].dstBinding = i;
1088       write_descriptor_sets[i].dstArrayElement = 0;
1089       write_descriptor_sets[i].descriptorCount = 1;
1090       write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
1091       write_descriptor_sets[i].pImageInfo = 0;
1092       write_descriptor_sets[i].pBufferInfo = &(descriptor_buffers[i]);
1093       write_descriptor_sets[i].pTexelBufferView = 0;
1094     }
1095     vkUpdateDescriptorSets(vctx.device, write_descriptor_sets.size(), write_descriptor_sets.data(),
1096                            0, 0);
1097   };
1098   const auto& deferred_kernel = [pipeline, wl, pack_args_storage](VulkanStreamState* state) {
1099     vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline);
1100     vkCmdBindDescriptorSets(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE,
1101                             pipeline->pipeline_layout, 0, 1, &(pipeline->descriptor_set), 0,
1102                             nullptr);
1103     if (pack_args_storage.size() != 0) {
1104       vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT,
1105                          0, pack_args_storage.size() * sizeof(ArgUnion), pack_args_storage.data());
1106     }
1107     vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
1108     VkMemoryBarrier barrier_info;
1109     barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
1110     barrier_info.pNext = nullptr;
1111     barrier_info.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT;
1112     barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
1113                                   VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
1114     vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
1115                          VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0,
1116                          1, &barrier_info, 0, nullptr, 0, nullptr);
1117   };
1118   VulkanStreamToken deferred_token;
1119   deferred_token.descriptor_set_ = pipeline->descriptor_set;
1120   deferred_token.buffers_.resize(descriptor_buffers.size());
1121   for (size_t i = 0; i < descriptor_buffers.size(); ++i) {
1122     deferred_token.buffers_[i] = descriptor_buffers[i].buffer;
1123   }
1124   VulkanThreadEntry::ThreadLocal()->Stream(device_id)->LaunchDeferred(
1125       deferred_initializer, deferred_kernel, deferred_token);
1126 }
1127 
1128 Module VulkanModuleLoadFile(const std::string& file_name, const std::string& format) {
1129   std::string data;
1130   std::unordered_map<std::string, VulkanShader> smap;
1131   std::unordered_map<std::string, FunctionInfo> fmap;
1132   std::string fmt = GetFileFormat(file_name, format);
1133   std::string meta_file = GetMetaFilePath(file_name);
1134   LoadBinaryFromFile(file_name, &data);
1135   LoadMetaDataFromFile(meta_file, &fmap);
1136   dmlc::MemoryStringStream fs(&data);
1137   dmlc::Stream* stream = &fs;
1138   uint32_t magic;
1139   stream->Read(&magic);
1140   CHECK_EQ(magic, kVulkanModuleMagic) << "VulkanModule Magic mismatch";
1141   stream->Read(&smap);
1142   return VulkanModuleCreate(smap, fmap, "");
1143 }
1144 
1145 Module VulkanModuleLoadBinary(void* strm) {
1146   dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
1147   std::unordered_map<std::string, VulkanShader> smap;
1148   std::unordered_map<std::string, FunctionInfo> fmap;
1149 
1150   std::string fmt;
1151   stream->Read(&fmt);
1152   stream->Read(&fmap);
1153   stream->Read(&smap);
1154   return VulkanModuleCreate(smap, fmap, "");
1155 }
1156 
1157 TVM_REGISTER_GLOBAL("runtime.module.loadfile_vulkan").set_body_typed(VulkanModuleLoadFile);
1158 
1159 TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary);
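// These registrations hook the Vulkan module into TVM's generic module
// loading: Module::LoadFromFile on a "*.vulkan" file dispatches to
// "runtime.module.loadfile_vulkan", and deserialization of a packed library
// goes through "runtime.module.loadbinary_vulkan".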
1160 
1161 TVM_REGISTER_GLOBAL("device_api.vulkan").set_body([](TVMArgs args, TVMRetValue* rv) {
1162   DeviceAPI* ptr = VulkanDeviceAPI::Global();
1163   *rv = static_cast<void*>(ptr);
1164 });
1165 
1166 }  // namespace vulkan
1167 }  // namespace runtime
1168 }  // namespace tvm
1169