/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include <vulkan/vulkan.h>
#include <dmlc/memory_io.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>

#include <algorithm>
#include <array>
#include <cstring>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include "../file_util.h"
#include "../pack_args.h"
#include "../thread_storage_scope.h"
#include "../workspace_pool.h"

#include "vulkan_common.h"
#include "vulkan_module.h"
#include "vulkan_shader.h"
#include "vulkan_stream.h"

namespace tvm {
namespace runtime {
namespace vulkan {

/*! \brief Maximum number of GPUs supported in VulkanModule. */
static constexpr const int kVulkanMaxNumDevice = 8;

/*! \brief TVM Vulkan binary pack magic number. */
static constexpr const int kVulkanModuleMagic = 0x02700027;
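
// For reference, a sketch of the standalone "*.vulkan" artifact written by
// VulkanModuleNode::SaveToFile below (illustrative; the exact byte layout is
// whatever dmlc::Stream serializes for these types):
//
//   uint32_t magic;                                 // kVulkanModuleMagic
//   std::unordered_map<std::string, VulkanShader>;  // shader map (smap_)
//
// VulkanModuleLoadFile reads the stream back in the same order, checking the
// magic number before deserializing the shader map.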

class VulkanThreadEntry {
 public:
  VulkanThreadEntry();
  static VulkanThreadEntry* ThreadLocal();

  ~VulkanThreadEntry() {
    // Because the thread entry refers to the device API, the command
    // buffers must always be destroyed before the instance and device are
    // destroyed. Run the destruction manually here to guarantee that order.
    streams_.clear();
    for (const auto& kv : staging_buffers_) {
      if (!kv.second) {
        continue;
      }
      auto& buf = *(kv.second);
      if (buf.host_addr != nullptr) {
        vkUnmapMemory(buf.device, buf.memory);
      }
      if (buf.memory != VK_NULL_HANDLE) {
        vkFreeMemory(buf.device, buf.memory, nullptr);
      }
      if (buf.buffer != VK_NULL_HANDLE) {
        vkDestroyBuffer(buf.device, buf.buffer, nullptr);
      }
    }
  }

  TVMContext ctx;
  WorkspacePool pool;
  VulkanStream* Stream(size_t device_id);
  VulkanStagingBuffer* StagingBuffer(int device_id, size_t size);

 private:
  std::unordered_map<size_t, std::unique_ptr<VulkanStream>> streams_;
  std::unordered_map<size_t, std::unique_ptr<VulkanStagingBuffer>> staging_buffers_;
};
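
// Typical per-thread usage, mirroring what this file does internally
// (a sketch, not a public API contract; `device_id` is assumed valid):
//
//   auto* entry = VulkanThreadEntry::ThreadLocal();
//   entry->Stream(device_id)->Launch([](VulkanStreamState* state) {
//     // record commands into state->cmd_buffer_
//   });
//   entry->Stream(device_id)->Synchronize();
//
// Streams and staging buffers are created lazily per (thread, device) pair
// and are torn down by the destructor above.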

struct VulkanBuffer {
  VkBuffer buffer{VK_NULL_HANDLE};
  VkDeviceMemory memory{VK_NULL_HANDLE};
};

struct VulkanPipeline {
  VulkanContext* vctx_{nullptr};
  VkShaderModule shader{VK_NULL_HANDLE};
  VkDescriptorSetLayout descriptor_set_layout{VK_NULL_HANDLE};
  VkDescriptorPool descriptor_pool{VK_NULL_HANDLE};
  VkDescriptorSet descriptor_set{VK_NULL_HANDLE};
  VkPipelineLayout pipeline_layout{VK_NULL_HANDLE};
  VkPipeline pipeline{VK_NULL_HANDLE};
  VkDescriptorUpdateTemplateKHR descriptor_update_template{VK_NULL_HANDLE};
};

typedef dmlc::ThreadLocalStore<VulkanThreadEntry> VulkanThreadStore;

class VulkanDeviceAPI final : public DeviceAPI {
 public:
  VulkanDeviceAPI();
  ~VulkanDeviceAPI() {
    for (auto& vctx : context_) {
      vkDestroyDevice(vctx.device, nullptr);
    }
    if (instance_) {
      vkDestroyInstance(instance_, nullptr);
    }
  }
  void SetDevice(TVMContext ctx) final { VulkanThreadEntry::ThreadLocal()->ctx = ctx; }
  void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final;
  void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, TVMType type_hint) final {
    const auto& vctx = context(ctx.device_id);
    VkBufferCreateInfo info;
    info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
    info.pNext = nullptr;
    info.flags = 0;
    info.size = nbytes;
    info.queueFamilyIndexCount = 1;
    info.pQueueFamilyIndices = &(vctx.queue_family_index);
    info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
    info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
                 VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
    // create buffer
    VkBuffer buffer;
    VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer));
    // bind to memory
    VkBufferMemoryRequirementsInfo2KHR req_info2;
    req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR;
    req_info2.pNext = 0;
    req_info2.buffer = buffer;

    VkMemoryRequirements2KHR req2;
    req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR;
    req2.pNext = 0;

    VkMemoryDedicatedRequirementsKHR dedicated_req;
    dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR;
    dedicated_req.pNext = 0;
    req2.pNext = &dedicated_req;

    bool dedicated_allocation = false;
    if (vctx.get_buffer_memory_requirements_2_functions) {
      vctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR(
          vctx.device, &req_info2, &req2);
      dedicated_allocation =
          dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation;
    }

    VkDeviceMemory memory;
    if (!dedicated_allocation) {
      VkMemoryAllocateInfo minfo;
      minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
      minfo.pNext = nullptr;
      minfo.allocationSize = nbytes;
      minfo.memoryTypeIndex = vctx.compute_mtype_index;
      VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
    } else {
      VkMemoryAllocateInfo minfo;
      minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
      minfo.pNext = nullptr;
      minfo.allocationSize = req2.memoryRequirements.size;
      minfo.memoryTypeIndex = vctx.compute_mtype_index;

      VkMemoryDedicatedAllocateInfoKHR mdinfo;
      mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR;
      mdinfo.pNext = 0;
      mdinfo.image = 0;
      mdinfo.buffer = buffer;
      minfo.pNext = &mdinfo;
      VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
    }
    VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0));
    VulkanBuffer* pbuf = new VulkanBuffer();
    pbuf->memory = memory;
    pbuf->buffer = buffer;
    return pbuf;
  }
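
  // Note that the handle returned by AllocDataSpace is a VulkanBuffer*, not a
  // raw device pointer; it is only meaningful to this runtime. A minimal
  // caller-side sketch (illustrative; `api`, `ctx`, and `hint` are assumptions):
  //
  //   void* handle = api->AllocDataSpace(ctx, nbytes, alignment, hint);
  //   // ... move data with CopyDataFromTo, launch kernels ...
  //   api->FreeDataSpace(ctx, handle);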

  void FreeDataSpace(TVMContext ctx, void* ptr) final {
    const auto& vctx = context(ctx.device_id);
    auto* pbuf = static_cast<VulkanBuffer*>(ptr);
    vkDestroyBuffer(vctx.device, pbuf->buffer, nullptr);
    vkFreeMemory(vctx.device, pbuf->memory, nullptr);
    delete pbuf;
  }

  void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
                      TVMContext ctx_from, TVMContext ctx_to, TVMType type_hint,
                      TVMStreamHandle stream) final {
    CHECK(stream == nullptr);
    int from_dev_type = static_cast<int>(ctx_from.device_type);
    int to_dev_type = static_cast<int>(ctx_to.device_type);
    if (from_dev_type == kDLVulkan && to_dev_type == kDLVulkan) {
      VulkanThreadEntry::ThreadLocal()
          ->Stream(ctx_from.device_id)
          ->Launch([=](VulkanStreamState* state) {
            // 1: copy
            const auto* from_buf = static_cast<const VulkanBuffer*>(from);
            auto* to_buf = static_cast<VulkanBuffer*>(to);
            VkBufferCopy copy_info;
            copy_info.srcOffset = from_offset;
            copy_info.dstOffset = to_offset;
            copy_info.size = size;
            vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, to_buf->buffer, 1, &copy_info);
            // 2: barrier(transfer -> compute|transfer)
            CHECK_EQ(ctx_from.device_id, ctx_to.device_id)
                << "Vulkan disallows cross-device copies.";
            VkMemoryBarrier barrier_info;
            barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
            barrier_info.pNext = nullptr;
            barrier_info.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
            barrier_info.dstAccessMask =
                (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
                 VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
            vkCmdPipelineBarrier(
                state->cmd_buffer_, VK_PIPELINE_STAGE_TRANSFER_BIT,
                VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, 1,
                &barrier_info, 0, nullptr, 0, nullptr);
          });

    } else if (from_dev_type == kDLVulkan && to_dev_type == kDLCPU) {
      const auto* from_buf = static_cast<const VulkanBuffer*>(from);
      const auto& vctx = context(ctx_from.device_id);
      auto* temp = VulkanThreadEntry::ThreadLocal()->StagingBuffer(ctx_from.device_id, size);
      VulkanThreadEntry::ThreadLocal()
          ->Stream(ctx_from.device_id)
          ->Launch([&](VulkanStreamState* state) {
            VkBufferCopy copy_info;
            copy_info.srcOffset = from_offset;
            copy_info.dstOffset = 0;
            copy_info.size = size;
            vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, temp->buffer, 1, &copy_info);
          });
      VulkanThreadEntry::ThreadLocal()->Stream(ctx_from.device_id)->Synchronize();
      // Invalidate on the host side if access is not coherent,
      // so that writes from the GPU are visible to the CPU.
      if (!vctx.coherent_staging) {
        VkMappedMemoryRange mrange;
        mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
        mrange.pNext = nullptr;
        mrange.memory = temp->memory;
        mrange.offset = 0;
        mrange.size = VK_WHOLE_SIZE;
        VULKAN_CALL(vkInvalidateMappedMemoryRanges(vctx.device, 1, &mrange));
      }
      memcpy(static_cast<char*>(to) + to_offset, static_cast<char*>(temp->host_addr), size);
    } else if (from_dev_type == kDLCPU && to_dev_type == kDLVulkan) {
      const auto& vctx = context(ctx_to.device_id);
      const auto* to_buf = static_cast<const VulkanBuffer*>(to);
      VulkanStagingBuffer* temp =
          VulkanThreadEntry::ThreadLocal()->StagingBuffer(ctx_to.device_id, size);
      memcpy(temp->host_addr, static_cast<const char*>(from) + from_offset, size);
      // Flush on the host side if access is not coherent,
      // so that writes from the CPU are visible to the GPU.
      if (!vctx.coherent_staging) {
        VkMappedMemoryRange mrange;
        mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
        mrange.pNext = nullptr;
        mrange.memory = temp->memory;
        mrange.offset = 0;
        mrange.size = VK_WHOLE_SIZE;
        VULKAN_CALL(vkFlushMappedMemoryRanges(vctx.device, 1, &mrange));
      }

      VulkanThreadEntry::ThreadLocal()
          ->Stream(ctx_to.device_id)
          ->Launch([&](VulkanStreamState* state) {
            // 0: barrier(host -> transfer)
            VkMemoryBarrier barrier_info;
            barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
            barrier_info.pNext = nullptr;
            barrier_info.srcAccessMask = 0;
            barrier_info.dstAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
            vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_HOST_BIT,
                                 VK_PIPELINE_STAGE_TRANSFER_BIT, 0, 1, &barrier_info, 0, nullptr, 0,
                                 nullptr);
            // 1: copy
            VkBufferCopy copy_info;
            copy_info.srcOffset = 0;
            copy_info.dstOffset = to_offset;
            copy_info.size = size;
            vkCmdCopyBuffer(state->cmd_buffer_, temp->buffer, to_buf->buffer, 1, &copy_info);
          });
      // TODO(tulloch): should we instead make the staging buffer a property of the
      // Stream? This would allow us to elide synchronizations here.
      VulkanThreadEntry::ThreadLocal()->Stream(ctx_to.device_id)->Synchronize();
    } else {
      LOG(FATAL) << "Expect copy from/to Vulkan or between Vulkan"
                 << ", from=" << from_dev_type << ", to=" << to_dev_type;
    }
  }
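
  // CopyDataFromTo above supports exactly three directions:
  //   Vulkan -> Vulkan: device-side vkCmdCopyBuffer on the same device,
  //     followed by a transfer -> compute|transfer barrier;
  //   Vulkan -> CPU: device copy into the staging buffer, synchronize,
  //     invalidate the mapped range if non-coherent, then memcpy out;
  //   CPU -> Vulkan: memcpy into the staging buffer, flush if non-coherent,
  //     then a host -> transfer barrier and a device-side copy.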

  // Always use the default stream
  TVMStreamHandle CreateStream(TVMContext ctx) {
    LOG(FATAL) << "Not implemented";
    return nullptr;
  }

  void FreeStream(TVMContext ctx, TVMStreamHandle stream) {
    LOG(FATAL) << "Not implemented";
    return;
  }

  void SyncStreamFromTo(TVMContext ctx, TVMStreamHandle event_src, TVMStreamHandle event_dst) {
    LOG(FATAL) << "Not implemented";
    return;
  }

  void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
    CHECK(stream == nullptr);
    VulkanThreadEntry::ThreadLocal()->Stream(ctx.device_id)->Synchronize();
  }

  void SetStream(TVMContext ctx, TVMStreamHandle stream) final {
    LOG(FATAL) << "Not implemented";
    return;
  }

  void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final {
    return VulkanThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
  }

  void FreeWorkspace(TVMContext ctx, void* data) final {
    VulkanThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
  }

  static const std::shared_ptr<VulkanDeviceAPI>& Global() {
    static std::shared_ptr<VulkanDeviceAPI> inst = std::make_shared<VulkanDeviceAPI>();
    return inst;
  }

  const VulkanContext& context(size_t device_id) const {
    CHECK_LT(device_id, context_.size());
    return context_[device_id];
  }

 private:
  VkInstance instance_{nullptr};
  // The physical devices; these have a one-to-one mapping to TVM devices.
  std::vector<VulkanContext> context_;
};
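
// Illustrative attribute query against GetAttr as defined below (assumes at
// least one Vulkan device and a kDLVulkan `ctx` for device 0):
//
//   TVMRetValue rv;
//   VulkanDeviceAPI::Global()->GetAttr(ctx, kMaxThreadsPerBlock, &rv);
//   int64_t max_threads = rv;  // limits.maxComputeWorkGroupSize[0]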

void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
  size_t index = static_cast<size_t>(ctx.device_id);
  if (kind == kExist) {
    *rv = static_cast<int>(index < context_.size());
    return;
  }
  CHECK_LT(index, context_.size()) << "Invalid device id " << index;
  const auto& vctx = context(index);
  switch (kind) {
    case kMaxThreadsPerBlock: {
      VkPhysicalDeviceProperties phy_prop;
      vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
      int64_t value = phy_prop.limits.maxComputeWorkGroupSize[0];
      *rv = value;
      break;
    }
    case kMaxSharedMemoryPerBlock: {
      VkPhysicalDeviceProperties phy_prop;
      vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
      int64_t value = phy_prop.limits.maxComputeSharedMemorySize;
      *rv = value;
      break;
    }
    case kWarpSize: {
      *rv = 1;
      break;
    }
    case kComputeVersion: {
      VkPhysicalDeviceProperties phy_prop;
      vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
      int64_t value = phy_prop.apiVersion;
      std::ostringstream os;
      os << VK_VERSION_MAJOR(value) << "." << VK_VERSION_MINOR(value) << "."
         << VK_VERSION_PATCH(value);
      *rv = os.str();
      break;
    }
    case kDeviceName:
      return;
    case kMaxClockRate:
      return;
    case kMultiProcessorCount:
      return;
    case kExist:
      break;
    case kMaxThreadDimensions:
      break;
    case kGcnArch:
      return;
  }
}

VulkanDeviceAPI::VulkanDeviceAPI() {
  VkApplicationInfo app_info;
  app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
  app_info.pNext = nullptr;
  app_info.pApplicationName = "TVM";
  app_info.applicationVersion = 0;
  app_info.pEngineName = "";
  app_info.engineVersion = 0;
  app_info.apiVersion = VK_MAKE_VERSION(1, 0, 0);

  VkInstanceCreateInfo inst_info;
  inst_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
  inst_info.pNext = nullptr;
  inst_info.flags = 0;

  const auto layers = []() -> std::vector<const char*> {
    uint32_t inst_layer_prop_count;
    VULKAN_CALL(vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, nullptr));
    std::vector<VkLayerProperties> inst_layer_prop(inst_layer_prop_count);
    VULKAN_CALL(vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, inst_layer_prop.data()));
    std::vector<const char*> l;
    for (const auto& lp : inst_layer_prop) {
      // TODO(tulloch): add CMAKE options.
      (void)lp;  // suppress unused variable warning.
#ifdef USE_VULKAN_VALIDATION
      if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_standard_validation") == 0) {
        l.push_back("VK_LAYER_LUNARG_standard_validation");
      }
      if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_parameter_validation") == 0) {
        l.push_back("VK_LAYER_LUNARG_parameter_validation");
      }
      if (std::strcmp(lp.layerName, "VK_LAYER_KHRONOS_validation") == 0) {
        l.push_back("VK_LAYER_KHRONOS_validation");
      }
#endif
    }
    return l;
  }();

  const auto instance_extensions = []() -> std::vector<const char*> {
    uint32_t inst_extension_prop_count;
    VULKAN_CALL(
        vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count, nullptr));
    std::vector<VkExtensionProperties> inst_extension_prop(inst_extension_prop_count);
    VULKAN_CALL(vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count,
                                                       inst_extension_prop.data()));
    std::vector<const char*> extensions;
    for (const auto& ip : inst_extension_prop) {
      if (std::strcmp(ip.extensionName, "VK_KHR_get_physical_device_properties2") == 0) {
        extensions.push_back("VK_KHR_get_physical_device_properties2");
      }
    }
    return extensions;
  }();

  inst_info.pApplicationInfo = &app_info;
  inst_info.enabledLayerCount = layers.size();
  inst_info.ppEnabledLayerNames = layers.data();
  inst_info.enabledExtensionCount = instance_extensions.size();
  inst_info.ppEnabledExtensionNames = instance_extensions.data();

  VULKAN_CALL(vkCreateInstance(&inst_info, nullptr, &instance_));

  uint32_t phy_dev_count = 0;
  VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &phy_dev_count, nullptr));
  std::vector<VkPhysicalDevice> all_phy_devs(phy_dev_count);
  VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &phy_dev_count, dmlc::BeginPtr(all_phy_devs)));
  for (VkPhysicalDevice phy_dev : all_phy_devs) {
    uint32_t queue_prop_count = 0;
    vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count, nullptr);
    std::vector<VkQueueFamilyProperties> queue_props(queue_prop_count);
    vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count,
                                             dmlc::BeginPtr(queue_props));
    uint32_t queue_family_index = 0;
    std::vector<VkDeviceQueueCreateInfo> queue_create_info;
    float priority = 1.0f;
    for (uint32_t i = 0; i < queue_props.size(); i++) {
      // find queue families that support compute
      if (VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) {
        VkDeviceQueueCreateInfo info;
        info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
        info.pNext = nullptr;
        info.flags = 0;
        info.queueFamilyIndex = i;
        info.queueCount = 1;
        info.pQueuePriorities = &priority;

        // only use the first available queue family for now
        if (queue_create_info.empty()) {
          queue_family_index = i;
        }
        queue_create_info.push_back(info);
      }
    }
    if (queue_create_info.size() == 0) continue;

    VulkanContext ctx;
    // setup context
    ctx.phy_device = phy_dev;
    vkGetPhysicalDeviceProperties(ctx.phy_device, &(ctx.phy_device_prop));

    const auto extensions = [&]() {
      uint32_t device_extension_prop_count;
      VULKAN_CALL(vkEnumerateDeviceExtensionProperties(ctx.phy_device, nullptr,
                                                       &device_extension_prop_count, nullptr));
      std::vector<VkExtensionProperties> device_extension_prop(device_extension_prop_count);
      VULKAN_CALL(vkEnumerateDeviceExtensionProperties(
          ctx.phy_device, nullptr, &device_extension_prop_count, device_extension_prop.data()));
      std::vector<const char*> extensions;
      for (const auto& dp : device_extension_prop) {
        if ((std::strcmp(dp.extensionName, "VK_KHR_push_descriptor") == 0) && dp.specVersion > 0) {
          extensions.push_back("VK_KHR_push_descriptor");
        }
        if ((std::strcmp(dp.extensionName, "VK_KHR_descriptor_update_template") == 0) &&
            dp.specVersion > 0) {
          extensions.push_back("VK_KHR_descriptor_update_template");
        }
        if ((std::strcmp(dp.extensionName, "VK_KHR_get_memory_requirements2") == 0) &&
            dp.specVersion > 0) {
          extensions.push_back("VK_KHR_get_memory_requirements2");
        }
        if ((std::strcmp(dp.extensionName, "VK_KHR_dedicated_allocation") == 0) &&
            dp.specVersion > 0) {
          extensions.push_back("VK_KHR_dedicated_allocation");
        }
      }
      return extensions;
    }();
    VkDeviceCreateInfo device_create_info;
    device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
    device_create_info.pNext = nullptr;
    device_create_info.flags = 0;
    device_create_info.queueCreateInfoCount = static_cast<uint32_t>(queue_create_info.size());
    device_create_info.pQueueCreateInfos = queue_create_info.data();
    device_create_info.enabledLayerCount = 0;
    device_create_info.ppEnabledLayerNames = nullptr;
    device_create_info.enabledExtensionCount = extensions.size();
    device_create_info.ppEnabledExtensionNames = extensions.data();
    device_create_info.pEnabledFeatures = nullptr;
    VULKAN_CALL(vkCreateDevice(phy_dev, &device_create_info, nullptr, &(ctx.device)));
    ctx.queue_mutex.reset(new std::mutex());
    vkGetDeviceQueue(ctx.device, queue_family_index, 0, &(ctx.queue));
    ctx.queue_family_index = queue_family_index;
    // Find suitable memory types for staging and compute.
    VkBuffer buffer;
    VkMemoryRequirements req_staging, req_compute;
    VkBufferCreateInfo info;
    info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
    info.pNext = nullptr;
    info.flags = 0;
    info.size = 1024;
    info.queueFamilyIndexCount = 1;
    info.pQueueFamilyIndices = &(ctx.queue_family_index);
    info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;

    // get staging requirement
    info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT;
    VULKAN_CALL(vkCreateBuffer(ctx.device, &info, nullptr, &buffer));
    vkGetBufferMemoryRequirements(ctx.device, buffer, &req_staging);
    vkDestroyBuffer(ctx.device, buffer, nullptr);
    // get compute requirement
    info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
                 VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
    VULKAN_CALL(vkCreateBuffer(ctx.device, &info, nullptr, &buffer));
    vkGetBufferMemoryRequirements(ctx.device, buffer, &req_compute);
    vkDestroyBuffer(ctx.device, buffer, nullptr);

    // Query physical device memory properties.
    // Find a staging memory type that is host visible; it need not be coherent.
    int win_rank = -1;
    VkPhysicalDeviceMemoryProperties prop;
    vkGetPhysicalDeviceMemoryProperties(ctx.phy_device, &prop);

    for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
      VkMemoryType ty = prop.memoryTypes[k];
      size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
      // host visible
      if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)) continue;
      // match copy requirement
      if (!(req_staging.memoryTypeBits & (1 << k))) continue;
      if (heap_size < 1024) continue;
      int rank = 0;
      rank += ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_CACHED_BIT;
      if (rank > win_rank) {
        win_rank = rank;
        ctx.staging_mtype_index = k;
        ctx.coherent_staging = ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT;
      }
    }
    CHECK_GE(win_rank, 0) << "Cannot find suitable staging memory on device.";

    win_rank = -1;
    for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
      VkMemoryType ty = prop.memoryTypes[k];
      size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
      // device local
      if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)) continue;
      // match compute requirement
      if (!(req_compute.memoryTypeBits & (1 << k))) continue;
      if (heap_size < 1024) continue;
      int rank = 0;
      // prefer not host visible
      rank += !(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT);
      if (rank > win_rank) {
        win_rank = rank;
        ctx.compute_mtype_index = k;
      }
    }
    CHECK_GE(win_rank, 0) << "Cannot find suitable local memory on device.";
    auto has_extension = [&extensions](const char* query) {
      return std::any_of(extensions.begin(), extensions.end(),
                         [&](const char* extension) { return std::strcmp(query, extension) == 0; });
    };

#ifdef USE_VULKAN_IMMEDIATE_MODE
    if (has_extension("VK_KHR_push_descriptor") &&
        has_extension("VK_KHR_descriptor_update_template")) {
      ctx.descriptor_template_khr_functions =
          std::unique_ptr<VulkanDescriptorTemplateKHRFunctions>(
              new VulkanDescriptorTemplateKHRFunctions());
      ctx.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR =
          CHECK_NOTNULL((PFN_vkCreateDescriptorUpdateTemplateKHR)vkGetDeviceProcAddr(
              ctx.device, "vkCreateDescriptorUpdateTemplateKHR"));
      ctx.descriptor_template_khr_functions->vkDestroyDescriptorUpdateTemplateKHR =
          CHECK_NOTNULL((PFN_vkDestroyDescriptorUpdateTemplateKHR)vkGetDeviceProcAddr(
              ctx.device, "vkDestroyDescriptorUpdateTemplateKHR"));
      ctx.descriptor_template_khr_functions->vkUpdateDescriptorSetWithTemplateKHR =
          CHECK_NOTNULL((PFN_vkUpdateDescriptorSetWithTemplateKHR)vkGetDeviceProcAddr(
              ctx.device, "vkUpdateDescriptorSetWithTemplateKHR"));
      ctx.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR =
          CHECK_NOTNULL((PFN_vkCmdPushDescriptorSetWithTemplateKHR)vkGetDeviceProcAddr(
              ctx.device, "vkCmdPushDescriptorSetWithTemplateKHR"));
    }
#endif

#ifdef USE_VULKAN_DEDICATED_ALLOCATION
    if (has_extension("VK_KHR_get_memory_requirements2") &&
        has_extension("VK_KHR_dedicated_allocation")) {
      ctx.get_buffer_memory_requirements_2_functions =
          std::unique_ptr<VulkanGetBufferMemoryRequirements2Functions>(
              new VulkanGetBufferMemoryRequirements2Functions());
      ctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR =
          CHECK_NOTNULL((PFN_vkGetBufferMemoryRequirements2KHR)vkGetDeviceProcAddr(
              ctx.device, "vkGetBufferMemoryRequirements2KHR"));
    }
#endif
    context_.push_back(std::move(ctx));
  }

  LOG(INFO) << "Initialize Vulkan with " << context_.size() << " devices..";
  for (size_t i = 0; i < context_.size(); ++i) {
    LOG(INFO) << "vulkan(" << i << ")=\'" << context_[i].phy_device_prop.deviceName
              << "\' phy_dev_id=" << context_[i].phy_device
              << " use_immediate=" << context_[i].UseImmediate();
  }
}
class VulkanModuleNode;

// a wrapped function class to get packed func.
class VulkanWrappedFunc {
 public:
  void Init(VulkanModuleNode* m, ObjectPtr<Object> sptr, const std::string& func_name,
            size_t num_buffer_args, size_t num_pack_args,
            const std::vector<std::string>& thread_axis_tags) {
    m_ = m;
    sptr_ = sptr;
    func_name_ = func_name;
    num_buffer_args_ = num_buffer_args;
    num_pack_args_ = num_pack_args;
    thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags);
  }

  void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const;

 private:
  // internal module
  VulkanModuleNode* m_;
  // the resource holder
  ObjectPtr<Object> sptr_;
  // The name of the function.
  std::string func_name_;
  // Number of buffer arguments.
  size_t num_buffer_args_;
  // Number of packed arguments.
  size_t num_pack_args_;
  // Thread axis configuration.
  ThreadAxisConfig thread_axis_cfg_;
  // Device state cache, one entry per device.
  // Marked mutable to enable lazy initialization.
  mutable std::array<std::shared_ptr<VulkanPipeline>, kVulkanMaxNumDevice> scache_;
};

// Multi-device enabled module.
class VulkanModuleNode final : public runtime::ModuleNode {
 public:
  explicit VulkanModuleNode(std::unordered_map<std::string, VulkanShader> smap,
                            std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
      : smap_(smap), fmap_(fmap), source_(source) {}

  const char* type_key() const final { return "vulkan"; }

  PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
    CHECK_EQ(sptr_to_self.get(), this);
    CHECK_NE(name, symbol::tvm_module_main) << "Device functions do not have a main";
    auto it = fmap_.find(name);
    if (it == fmap_.end()) return PackedFunc();
    const FunctionInfo& info = it->second;
    VulkanWrappedFunc f;
    size_t num_buffer_args = NumBufferArgs(info.arg_types);
    f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args,
           info.thread_axis_tags);
    return PackFuncNonBufferArg(std::move(f), info.arg_types);
  }

  ~VulkanModuleNode() {
    // cleanup vulkan related caches.
    for (size_t device_id = 0; device_id < ecache_.size(); ++device_id) {
      for (auto& kv : ecache_[device_id]) {
        auto& pe = kv.second;
        CHECK(pe);
        const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);

        if (pe->descriptor_update_template != VK_NULL_HANDLE) {
          vctx.descriptor_template_khr_functions->vkDestroyDescriptorUpdateTemplateKHR(
              vctx.device, pe->descriptor_update_template, nullptr);
        }
        vkDestroyPipeline(vctx.device, pe->pipeline, nullptr);
        vkDestroyPipelineLayout(vctx.device, pe->pipeline_layout, nullptr);
        vkDestroyDescriptorPool(vctx.device, pe->descriptor_pool, nullptr);
        vkDestroyDescriptorSetLayout(vctx.device, pe->descriptor_set_layout, nullptr);
        vkDestroyShaderModule(vctx.device, pe->shader, nullptr);
      }
    }
  }

  std::shared_ptr<VulkanPipeline> GetPipeline(size_t device_id, const std::string& func_name,
                                              size_t num_pack_args) {
    const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
    std::lock_guard<std::mutex> lock(mutex_);
    const auto& cp = ecache_[device_id][func_name];
    if (cp) {
      return cp;
    }
    // Create new pipeline
    auto pe = std::shared_ptr<VulkanPipeline>(new VulkanPipeline());
    {
      // create shader
      auto sit = smap_.find(func_name);
      CHECK(sit != smap_.end());
      const std::vector<uint32_t>& data = sit->second.data;
      VkShaderModuleCreateInfo shader_cinfo;
      shader_cinfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
      shader_cinfo.pNext = nullptr;
      shader_cinfo.flags = 0;
      shader_cinfo.codeSize = data.size() * sizeof(uint32_t);
      shader_cinfo.pCode = data.data();
      VULKAN_CALL(vkCreateShaderModule(vctx.device, &shader_cinfo, nullptr, &(pe->shader)));
    }
    std::vector<VkDescriptorSetLayoutBinding> arg_binding;
    std::vector<VkDescriptorUpdateTemplateEntryKHR> arg_template;
    uint32_t num_pod = 0, num_buffer = 0;
    {
      auto fit = fmap_.find(func_name);
      CHECK(fit != fmap_.end());
      for (TVMType arg_type : fit->second.arg_types) {
        if (arg_type.code == kHandle) {
          {
            VkDescriptorSetLayoutBinding bd;
            bd.binding = num_buffer;
            bd.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
            bd.descriptorCount = 1;
            bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
            bd.pImmutableSamplers = nullptr;
            arg_binding.push_back(bd);
          }
          {
            VkDescriptorUpdateTemplateEntryKHR tpl;
            tpl.dstBinding = num_buffer;
            tpl.dstArrayElement = 0;
            tpl.descriptorCount = 1;
            tpl.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
            tpl.offset = num_buffer * sizeof(VkDescriptorBufferInfo);
            tpl.stride = sizeof(VkDescriptorBufferInfo);
            arg_template.push_back(tpl);
          }
          ++num_buffer;
        } else {
          ++num_pod;
        }
      }
    }

    {
      VkDescriptorSetLayoutCreateInfo descrip_cinfo;
      descrip_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
      descrip_cinfo.pNext = nullptr;
      descrip_cinfo.flags = 0;
      if (vctx.UseImmediate()) {
        descrip_cinfo.flags |= VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR;
      }
      descrip_cinfo.bindingCount = arg_binding.size();
      descrip_cinfo.pBindings = arg_binding.data();
      VULKAN_CALL(vkCreateDescriptorSetLayout(vctx.device, &descrip_cinfo, nullptr,
                                              &(pe->descriptor_set_layout)));
    }

    {
      VkDescriptorPoolSize pool_size;
      pool_size.type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
      pool_size.descriptorCount = arg_binding.size();
      VkDescriptorPoolCreateInfo descrip_pool_cinfo;
      descrip_pool_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
      descrip_pool_cinfo.pNext = nullptr;
      descrip_pool_cinfo.flags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT;
      descrip_pool_cinfo.maxSets = 1;
      descrip_pool_cinfo.poolSizeCount = 1;
      descrip_pool_cinfo.pPoolSizes = &pool_size;
      VULKAN_CALL(vkCreateDescriptorPool(vctx.device, &descrip_pool_cinfo, nullptr,
                                         &(pe->descriptor_pool)));
    }

    if (!vctx.UseImmediate()) {
      VkDescriptorSetAllocateInfo alloc_info;
      alloc_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
      alloc_info.pNext = nullptr;
      alloc_info.descriptorPool = pe->descriptor_pool;
      alloc_info.descriptorSetCount = 1;
      alloc_info.pSetLayouts = &(pe->descriptor_set_layout);
      VULKAN_CALL(vkAllocateDescriptorSets(vctx.device, &alloc_info, &(pe->descriptor_set)));
    }

    VkPushConstantRange crange;
    crange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
    crange.offset = 0;
    crange.size = sizeof(ArgUnion) * num_pack_args;

    VkPipelineLayoutCreateInfo playout_cinfo;
    playout_cinfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
    playout_cinfo.pNext = nullptr;
    playout_cinfo.flags = 0;
    playout_cinfo.setLayoutCount = 1;
    playout_cinfo.pSetLayouts = &(pe->descriptor_set_layout);

    if (num_pack_args != 0) {
      playout_cinfo.pushConstantRangeCount = 1;
      playout_cinfo.pPushConstantRanges = &crange;
      CHECK_LE(crange.size, vctx.phy_device_prop.limits.maxPushConstantsSize);
    } else {
      playout_cinfo.pushConstantRangeCount = 0;
      playout_cinfo.pPushConstantRanges = nullptr;
    }

    VULKAN_CALL(
        vkCreatePipelineLayout(vctx.device, &playout_cinfo, nullptr, &(pe->pipeline_layout)));

    VkComputePipelineCreateInfo pipeline_cinfo;
    pipeline_cinfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
    pipeline_cinfo.pNext = nullptr;
    pipeline_cinfo.flags = 0;
    pipeline_cinfo.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
    pipeline_cinfo.stage.pNext = nullptr;
    pipeline_cinfo.stage.flags = 0;
    pipeline_cinfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
    pipeline_cinfo.stage.module = pe->shader;
    pipeline_cinfo.stage.pName = func_name.c_str();
    pipeline_cinfo.stage.pSpecializationInfo = nullptr;
    pipeline_cinfo.layout = pe->pipeline_layout;
    pipeline_cinfo.basePipelineHandle = VK_NULL_HANDLE;
    pipeline_cinfo.basePipelineIndex = 0;
    VULKAN_CALL(vkCreateComputePipelines(vctx.device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr,
                                         &(pe->pipeline)));

    if (vctx.UseImmediate()) {
      VkDescriptorUpdateTemplateCreateInfoKHR descrip_template_cinfo;
      descrip_template_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_CREATE_INFO_KHR;
      descrip_template_cinfo.pNext = 0;
      descrip_template_cinfo.flags = 0;
      descrip_template_cinfo.descriptorUpdateEntryCount = arg_template.size();
      descrip_template_cinfo.pDescriptorUpdateEntries = arg_template.data();
      descrip_template_cinfo.templateType = VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_PUSH_DESCRIPTORS_KHR;
      descrip_template_cinfo.descriptorSetLayout = pe->descriptor_set_layout;
      descrip_template_cinfo.pipelineBindPoint = VK_PIPELINE_BIND_POINT_COMPUTE;
      descrip_template_cinfo.pipelineLayout = pe->pipeline_layout;
      descrip_template_cinfo.set = 0;
      VULKAN_CALL(vctx.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR(
          vctx.device, &descrip_template_cinfo, 0, &(pe->descriptor_update_template)));
    }
    ecache_[device_id][func_name] = pe;
    return pe;
  }

  void SaveToFile(const std::string& file_name, const std::string& format) final {
    std::string fmt = GetFileFormat(file_name, format);
    CHECK_EQ(fmt, fmt_) << "Can only save to customized format vulkan";
    std::string meta_file = GetMetaFilePath(file_name);
    SaveMetaDataToFile(meta_file, fmap_);
    std::string data_bin;
    dmlc::MemoryStringStream fs(&data_bin);
    dmlc::Stream* stream = &fs;
    uint32_t magic = kVulkanModuleMagic;
    stream->Write(magic);
    stream->Write(smap_);
    SaveBinaryToFile(file_name, data_bin);
  }

  void SaveToBinary(dmlc::Stream* stream) final {
    stream->Write(fmt_);
    stream->Write(fmap_);
    stream->Write(smap_);
  }
  std::string GetSource(const std::string& format) final {
    // can only return source code.
    return source_;
  }

 private:
  // the binary data
  std::vector<uint32_t> data_;
  // map from function name to compiled SPIR-V shader.
  std::unordered_map<std::string, VulkanShader> smap_;
  // function information table.
  std::unordered_map<std::string, FunctionInfo> fmap_;
  // The format
  std::string fmt_{"vulkan"};
  // The source
  std::string source_;

  // Guards accesses to `ecache_`
  std::mutex mutex_;
  std::array<std::unordered_map<std::string, std::shared_ptr<VulkanPipeline>>, kVulkanMaxNumDevice>
      ecache_;
};

Module VulkanModuleCreate(std::unordered_map<std::string, VulkanShader> smap,
                          std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
  auto n = make_object<VulkanModuleNode>(smap, fmap, source);
  return Module(n);
}

VulkanThreadEntry* VulkanThreadEntry::ThreadLocal() { return VulkanThreadStore::Get(); }

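// The staging buffer for a device grows monotonically: it is (re)allocated
// whenever the requested size exceeds the current capacity and stays
// persistently mapped (buf.host_addr) for the lifetime of the thread entry.
// The buffer is zeroed on every call, so callers must not rely on contents
// surviving between StagingBuffer() calls.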
VulkanStagingBuffer* VulkanThreadEntry::StagingBuffer(int device_id, size_t size) {
  if (!staging_buffers_[device_id]) {
    staging_buffers_[device_id] = std::unique_ptr<VulkanStagingBuffer>(new VulkanStagingBuffer());
  }
  auto& buf = *(staging_buffers_[device_id]);
  if (buf.device != nullptr && buf.size < size) {
    // free previous buffer
    if (buf.host_addr != nullptr) {
      vkUnmapMemory(buf.device, buf.memory);
    }
    if (buf.memory != VK_NULL_HANDLE) {
      vkFreeMemory(buf.device, buf.memory, nullptr);
    }
    if (buf.buffer != VK_NULL_HANDLE) {
      vkDestroyBuffer(buf.device, buf.buffer, nullptr);
    }
    buf.host_addr = nullptr;
    buf.memory = VK_NULL_HANDLE;
    buf.buffer = VK_NULL_HANDLE;
  }
  const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);

  if (buf.device == nullptr) {
    buf.device = vctx.device;
  }
  if (buf.memory == VK_NULL_HANDLE) {
    // allocate the staging buffer memory if necessary
    VkBufferCreateInfo info;
    info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
    info.pNext = nullptr;
    info.flags = 0;
    info.size = size;
    info.queueFamilyIndexCount = 1;
    info.pQueueFamilyIndices = &(vctx.queue_family_index);
    info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT;
    info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
    VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &(buf.buffer)));
    VkMemoryAllocateInfo minfo;
    minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
    minfo.pNext = nullptr;
    minfo.allocationSize = size;
    minfo.memoryTypeIndex = vctx.staging_mtype_index;
    VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &(buf.memory)));
    VULKAN_CALL(vkBindBufferMemory(vctx.device, (buf.buffer), buf.memory, 0));
    VULKAN_CALL(vkMapMemory(vctx.device, buf.memory, 0, size, 0, &(buf.host_addr)));
    buf.size = size;
  }
  memset(buf.host_addr, 0, size);
  return &buf;
}

VulkanThreadEntry::VulkanThreadEntry()
    : pool(static_cast<DLDeviceType>(kDLVulkan), VulkanDeviceAPI::Global()) {
  ctx.device_id = 0;
  ctx.device_type = static_cast<DLDeviceType>(kDLVulkan);
}

VulkanStream* VulkanThreadEntry::Stream(size_t device_id) {
  if (!streams_[device_id]) {
    streams_[device_id] = std::unique_ptr<VulkanStream>(
        new VulkanStream(&VulkanDeviceAPI::Global()->context(device_id)));
  }
  return streams_[device_id].get();
}

void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
                                   const ArgUnion* pack_args) const {
  int device_id = VulkanThreadEntry::ThreadLocal()->ctx.device_id;
  CHECK_LT(device_id, kVulkanMaxNumDevice);
  const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
  if (!scache_[device_id]) {
    scache_[device_id] = m_->GetPipeline(device_id, func_name_, num_pack_args_);
  }
  const auto& pipeline = scache_[device_id];
  ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
  std::vector<VkDescriptorBufferInfo> descriptor_buffers;
  descriptor_buffers.resize(num_buffer_args_);
  for (size_t i = 0; i < num_buffer_args_; ++i) {
    void* buf = args[static_cast<int>(i)];
    VkDescriptorBufferInfo binfo;
    binfo.buffer = static_cast<VulkanBuffer*>(buf)->buffer;
    binfo.offset = 0;
    binfo.range = VK_WHOLE_SIZE;
    descriptor_buffers[i] = binfo;
  }
  if (vctx.UseImmediate()) {
    // Can safely capture by reference as this lambda is immediately executed on the calling thread.
    VulkanThreadEntry::ThreadLocal()->Stream(device_id)->Launch([&](VulkanStreamState* state) {
      vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline);
      CHECK(pipeline->descriptor_update_template != VK_NULL_HANDLE);
      vctx.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR(
          state->cmd_buffer_, pipeline->descriptor_update_template, pipeline->pipeline_layout, 0,
          descriptor_buffers.data());
      if (num_pack_args_ != 0) {
        vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout,
                           VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion),
                           pack_args);
      }
      vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
      VkMemoryBarrier barrier_info;
      barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
      barrier_info.pNext = nullptr;
      barrier_info.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT;
      barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
                                    VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
      vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
                           VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0,
                           1, &barrier_info, 0, nullptr, 0, nullptr);
    });
    return;
  }

  // Otherwise, the more expensive deferred path.
  std::vector<ArgUnion> pack_args_storage(pack_args, pack_args + num_pack_args_);
  const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers]() {
    std::vector<VkWriteDescriptorSet> write_descriptor_sets;
    write_descriptor_sets.resize(descriptor_buffers.size());
    for (size_t i = 0; i < write_descriptor_sets.size(); i++) {
      write_descriptor_sets[i].sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
      write_descriptor_sets[i].pNext = 0;
      write_descriptor_sets[i].dstSet = pipeline->descriptor_set;
      write_descriptor_sets[i].dstBinding = i;
      write_descriptor_sets[i].dstArrayElement = 0;
      write_descriptor_sets[i].descriptorCount = 1;
      write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
      write_descriptor_sets[i].pImageInfo = 0;
      write_descriptor_sets[i].pBufferInfo = &(descriptor_buffers[i]);
      write_descriptor_sets[i].pTexelBufferView = 0;
    }
    vkUpdateDescriptorSets(vctx.device, write_descriptor_sets.size(), write_descriptor_sets.data(),
                           0, 0);
  };
  const auto& deferred_kernel = [pipeline, wl, pack_args_storage](VulkanStreamState* state) {
    vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline);
    vkCmdBindDescriptorSets(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE,
                            pipeline->pipeline_layout, 0, 1, &(pipeline->descriptor_set), 0,
                            nullptr);
    if (pack_args_storage.size() != 0) {
      vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT,
                         0, pack_args_storage.size() * sizeof(ArgUnion), pack_args_storage.data());
    }
    vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
    VkMemoryBarrier barrier_info;
    barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
    barrier_info.pNext = nullptr;
    barrier_info.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT;
    barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
                                  VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
    vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
                         VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0,
                         1, &barrier_info, 0, nullptr, 0, nullptr);
  };
  VulkanStreamToken deferred_token;
  deferred_token.descriptor_set_ = pipeline->descriptor_set;
  deferred_token.buffers_.resize(descriptor_buffers.size());
  for (size_t i = 0; i < descriptor_buffers.size(); ++i) {
    deferred_token.buffers_[i] = descriptor_buffers[i].buffer;
  }
  VulkanThreadEntry::ThreadLocal()->Stream(device_id)->LaunchDeferred(
      deferred_initializer, deferred_kernel, deferred_token);
}
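
// Two dispatch paths exist in operator() above: the immediate path pushes
// descriptors via VK_KHR_push_descriptor / VK_KHR_descriptor_update_template
// and records into the current command buffer right away; the deferred path
// copies the packed arguments, writes a persistent descriptor set, and hands
// both closures to VulkanStream::LaunchDeferred keyed by a VulkanStreamToken.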

Module VulkanModuleLoadFile(const std::string& file_name, const std::string& format) {
  std::string data;
  std::unordered_map<std::string, VulkanShader> smap;
  std::unordered_map<std::string, FunctionInfo> fmap;
  std::string fmt = GetFileFormat(file_name, format);
  std::string meta_file = GetMetaFilePath(file_name);
  LoadBinaryFromFile(file_name, &data);
  LoadMetaDataFromFile(meta_file, &fmap);
  dmlc::MemoryStringStream fs(&data);
  dmlc::Stream* stream = &fs;
  uint32_t magic;
  stream->Read(&magic);
  CHECK_EQ(magic, kVulkanModuleMagic) << "VulkanModule Magic mismatch";
  stream->Read(&smap);
  return VulkanModuleCreate(smap, fmap, "");
}

Module VulkanModuleLoadBinary(void* strm) {
  dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
  std::unordered_map<std::string, VulkanShader> smap;
  std::unordered_map<std::string, FunctionInfo> fmap;

  std::string fmt;
  stream->Read(&fmt);
  stream->Read(&fmap);
  stream->Read(&smap);
  return VulkanModuleCreate(smap, fmap, "");
}


TVM_REGISTER_GLOBAL("module.loadfile_vulkan").set_body_typed(VulkanModuleLoadFile);

TVM_REGISTER_GLOBAL("module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary);

TVM_REGISTER_GLOBAL("device_api.vulkan").set_body([](TVMArgs args, TVMRetValue* rv) {
  DeviceAPI* ptr = VulkanDeviceAPI::Global().get();
  *rv = static_cast<void*>(ptr);
});
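
// Example of resolving the registration above (a sketch of what
// DeviceAPI::Get does for kDLVulkan; names outside this file are assumptions
// about the surrounding TVM runtime):
//
//   const runtime::PackedFunc* f = runtime::Registry::Get("device_api.vulkan");
//   CHECK(f != nullptr);
//   void* ptr = (*f)();
//   auto* api = static_cast<DeviceAPI*>(ptr);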

}  // namespace vulkan
}  // namespace runtime
}  // namespace tvm