1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 #include <dmlc/memory_io.h>
21 #include <dmlc/thread_local.h>
22 #include <tvm/runtime/device_api.h>
23 #include <tvm/runtime/registry.h>
24 #include <vulkan/vulkan.h>
25
26 #include <array>
27 #include <cstring>
28
29 #include "../file_util.h"
30 #include "../pack_args.h"
31 #include "../thread_storage_scope.h"
32 #include "../workspace_pool.h"
33 #include "vulkan_common.h"
34 #include "vulkan_module.h"
35 #include "vulkan_shader.h"
36 #include "vulkan_stream.h"
37
38 namespace tvm {
39 namespace runtime {
40 namespace vulkan {
41
42 /*! \brief Maximum number of GPUs supported by a VulkanModule. */
43 static constexpr const int kVulkanMaxNumDevice = 8;
44
45 /*! \brief TVM Vulkan binary pack magic number */
46 static constexpr const int kVulkanModuleMagic = 0x02700027;
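// The on-disk layout written by SaveToFile is this magic number followed by the
// serialized shader map; function metadata is stored in a separate meta file.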
47
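// Per-thread Vulkan runtime state: the active TVM context, a workspace pool,
// and lazily created per-device streams and staging buffers.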
48 class VulkanThreadEntry {
49 public:
50 VulkanThreadEntry();
51 static VulkanThreadEntry* ThreadLocal();
52
53 ~VulkanThreadEntry() {
54 // Because the thread entry refers to the Device API,
55 // the command buffer must always be destroyed before
56 // the instance and device are destroyed.
57 // The destruction has to be invoked manually
58 // to guarantee the destruction order.
59
60 pool.reset();
61 streams_.clear();
62 for (const auto& kv : staging_buffers_) {
63 if (!kv.second) {
64 continue;
65 }
66 auto& buf = *(kv.second);
67 if (buf.host_addr != nullptr) {
68 vkUnmapMemory(buf.device, buf.memory);
69 }
70 if (buf.memory != VK_NULL_HANDLE) {
71 vkFreeMemory(buf.device, buf.memory, nullptr);
72 }
73 if (buf.buffer != VK_NULL_HANDLE) {
74 vkDestroyBuffer(buf.device, buf.buffer, nullptr);
75 }
76 }
77 }
78
79 TVMContext ctx;
80 std::unique_ptr<WorkspacePool> pool;
81 VulkanStream* Stream(size_t device_id);
82 VulkanStagingBuffer* StagingBuffer(int device_id, size_t size);
83
84 private:
85 std::unordered_map<size_t, std::unique_ptr<VulkanStream>> streams_;
86 std::unordered_map<size_t, std::unique_ptr<VulkanStagingBuffer>> staging_buffers_;
87 };
88
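// A device-side buffer together with the memory allocation that backs it.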
89 struct VulkanBuffer {
90 VkBuffer buffer{VK_NULL_HANDLE};
91 VkDeviceMemory memory{VK_NULL_HANDLE};
92 };
93
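// Everything needed to dispatch one compiled kernel: the shader module, descriptor
// set layout/pool/set, pipeline layout, compute pipeline, and (when the push
// descriptor path is used) a descriptor update template.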
94 struct VulkanPipeline {
95 VulkanContext* vctx_{nullptr};
96 VkShaderModule shader{VK_NULL_HANDLE};
97 VkDescriptorSetLayout descriptor_set_layout{VK_NULL_HANDLE};
98 VkDescriptorPool descriptor_pool{VK_NULL_HANDLE};
99 VkDescriptorSet descriptor_set{VK_NULL_HANDLE};
100 VkPipelineLayout pipeline_layout{VK_NULL_HANDLE};
101 VkPipeline pipeline{VK_NULL_HANDLE};
102 VkDescriptorUpdateTemplateKHR descriptor_update_template{VK_NULL_HANDLE};
103 };
104
105 typedef dmlc::ThreadLocalStore<VulkanThreadEntry> VulkanThreadStore;
106
107 class VulkanDeviceAPI final : public DeviceAPI {
108 public:
109 VulkanDeviceAPI();
110 ~VulkanDeviceAPI() {
111 for (auto& vctx : context_) {
112 vkDestroyDevice(vctx.device, nullptr);
113 }
114 if (instance_) {
115 vkDestroyInstance(instance_, nullptr);
116 }
117 }
118 void SetDevice(TVMContext ctx) final { VulkanThreadEntry::ThreadLocal()->ctx = ctx; }
119 void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final;
120 void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment,
121 DLDataType type_hint) final {
122 const auto& vctx = context(ctx.device_id);
123 VkBufferCreateInfo info;
124 info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
125 info.pNext = nullptr;
126 info.flags = 0;
127 info.size = nbytes;
128 info.queueFamilyIndexCount = 1;
129 info.pQueueFamilyIndices = &(vctx.queue_family_index);
130 info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
131 info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
132 VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
133 // create buffer
134 VkBuffer buffer;
135 VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer));
136 // bind to memory
137 VkBufferMemoryRequirementsInfo2KHR req_info2;
138 req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR;
139 req_info2.pNext = 0;
140 req_info2.buffer = buffer;
141
142 VkMemoryRequirements2KHR req2;
143 req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR;
144 req2.pNext = 0;
145
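// When VK_KHR_dedicated_allocation is available, query whether this buffer
// prefers or requires its own dedicated memory allocation.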
146 VkMemoryDedicatedRequirementsKHR dedicated_req;
147 dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR;
148 dedicated_req.pNext = 0;
149 req2.pNext = &dedicated_req;
150
151 bool dedicated_allocation = false;
152 if (vctx.get_buffer_memory_requirements_2_functions) {
153 vctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR(
154 vctx.device, &req_info2, &req2);
155 dedicated_allocation =
156 dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation;
157 }
158
159 VkDeviceMemory memory;
160 if (!dedicated_allocation) {
161 VkMemoryAllocateInfo minfo;
162 minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
163 minfo.pNext = nullptr;
164 minfo.allocationSize = nbytes;
165 minfo.memoryTypeIndex = vctx.compute_mtype_index;
166 VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
167 } else {
168 VkMemoryAllocateInfo minfo;
169 minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
170 minfo.pNext = nullptr;
171 minfo.allocationSize = req2.memoryRequirements.size;
172 minfo.memoryTypeIndex = vctx.compute_mtype_index;
173
174 VkMemoryDedicatedAllocateInfoKHR mdinfo;
175 mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR;
176 mdinfo.pNext = 0;
177 mdinfo.image = 0;
178 mdinfo.buffer = buffer;
179 minfo.pNext = &mdinfo;
180 VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
181 }
182 VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0));
183 VulkanBuffer* pbuf = new VulkanBuffer();
184 pbuf->memory = memory;
185 pbuf->buffer = buffer;
186 return pbuf;
187 }
188
189 void FreeDataSpace(TVMContext ctx, void* ptr) final {
190 // Before releasing the VkBuffer, synchronize to finish
191 // all the Vulkan commands that reference the buffer.
192 StreamSync(ctx, nullptr);
193
194 const auto& vctx = context(ctx.device_id);
195 auto* pbuf = static_cast<VulkanBuffer*>(ptr);
196 vkDestroyBuffer(vctx.device, pbuf->buffer, nullptr);
197 vkFreeMemory(vctx.device, pbuf->memory, nullptr);
198 delete pbuf;
199 }
200
201 void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
202 TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint,
203 TVMStreamHandle stream) final {
204 CHECK(stream == nullptr);
205 TVMContext ctx = ctx_from;
206 if (ctx_from.device_type == kDLCPU) {
207 ctx = ctx_to;
208 }
209
210 int from_dev_type = static_cast<int>(ctx_from.device_type);
211 int to_dev_type = static_cast<int>(ctx_to.device_type);
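// Three supported paths: device-to-device copy on the same GPU, GPU-to-CPU via a
// staging buffer, and CPU-to-GPU via a staging buffer.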
212 if (from_dev_type == kDLVulkan && to_dev_type == kDLVulkan) {
213 VulkanThreadEntry::ThreadLocal()
214 ->Stream(ctx_from.device_id)
215 ->Launch([=](VulkanStreamState* state) {
216 // 1: copy
217 const auto* from_buf = static_cast<const VulkanBuffer*>(from);
218 auto* to_buf = static_cast<VulkanBuffer*>(to);
219 VkBufferCopy copy_info;
220 copy_info.srcOffset = from_offset;
221 copy_info.dstOffset = to_offset;
222 copy_info.size = size;
223 vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, to_buf->buffer, 1, ©_info);
224 // 2: barrier(transfer-> compute|transfer)
225 CHECK_EQ(ctx_from.device_id, ctx_to.device_id) << "Vulkan disallows cross-device copy.";
226 VkMemoryBarrier barrier_info;
227 barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
228 barrier_info.pNext = nullptr;
229 barrier_info.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
230 barrier_info.dstAccessMask =
231 (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
232 VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
233 vkCmdPipelineBarrier(
234 state->cmd_buffer_, VK_PIPELINE_STAGE_TRANSFER_BIT,
235 VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, 1,
236 &barrier_info, 0, nullptr, 0, nullptr);
237 });
238
239 } else if (from_dev_type == kDLVulkan && to_dev_type == kDLCPU) {
240 const auto* from_buf = static_cast<const VulkanBuffer*>(from);
241 const auto& vctx = context(ctx_from.device_id);
242 auto* temp = VulkanThreadEntry::ThreadLocal()->StagingBuffer(ctx_from.device_id, size);
243 VulkanThreadEntry::ThreadLocal()
244 ->Stream(ctx_from.device_id)
245 ->Launch([&](VulkanStreamState* state) {
246 VkBufferCopy copy_info;
247 copy_info.srcOffset = from_offset;
248 copy_info.dstOffset = 0;
249 copy_info.size = size;
250 vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, temp->buffer, 1, ©_info);
251 });
252 VulkanThreadEntry::ThreadLocal()->Stream(ctx_from.device_id)->Synchronize();
253 if (!vctx.coherent_staging) {
254 VkMappedMemoryRange mrange;
255 mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
256 mrange.pNext = nullptr;
257 mrange.memory = temp->memory;
258 mrange.offset = 0;
259 mrange.size = VK_WHOLE_SIZE; // size;
260 VULKAN_CALL(vkInvalidateMappedMemoryRanges(vctx.device, 1, &mrange));
261 }
262 memcpy(static_cast<char*>(to) + to_offset, static_cast<char*>(temp->host_addr), size);
263 } else if (from_dev_type == kDLCPU && to_dev_type == kDLVulkan) {
264 const auto& vctx = context(ctx_to.device_id);
265 const auto* to_buf = static_cast<const VulkanBuffer*>(to);
266 VulkanStagingBuffer* temp =
267 VulkanThreadEntry::ThreadLocal()->StagingBuffer(ctx_to.device_id, size);
268 memcpy(temp->host_addr, static_cast<const char*>(from) + from_offset, size);
269 // Host-side flush if access is not coherent,
270 // so that writes from the CPU are visible to the GPU.
271 if (!vctx.coherent_staging) {
272 VkMappedMemoryRange mrange;
273 mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
274 mrange.pNext = nullptr;
275 mrange.memory = temp->memory;
276 mrange.offset = 0;
277 mrange.size = VK_WHOLE_SIZE; // size;
278 VULKAN_CALL(vkFlushMappedMemoryRanges(vctx.device, 1, &mrange));
279 }
280
281 VulkanThreadEntry::ThreadLocal()
282 ->Stream(ctx_from.device_id)
283 ->Launch([&](VulkanStreamState* state) {
284 // 0: barrier(host->transfer)
285 VkMemoryBarrier barrier_info;
286 barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
287 barrier_info.pNext = nullptr;
288 barrier_info.srcAccessMask = 0;
289 barrier_info.dstAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
290 vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_HOST_BIT,
291 VK_PIPELINE_STAGE_TRANSFER_BIT, 0, 1, &barrier_info, 0, nullptr, 0,
292 nullptr);
293 // 1: copy
294 VkBufferCopy copy_info;
295 copy_info.srcOffset = 0;
296 copy_info.dstOffset = to_offset;
297 copy_info.size = size;
298 vkCmdCopyBuffer(state->cmd_buffer_, temp->buffer, to_buf->buffer, 1, ©_info);
299 });
300 // TODO(tulloch): should we instead make the staging buffer a property of the
301 // Stream? This would allow us to elide synchronizations here.
302 VulkanThreadEntry::ThreadLocal()->Stream(ctx_from.device_id)->Synchronize();
303 } else {
304 LOG(FATAL) << "Expect copy from/to Vulkan or between Vulkan"
305 << ", from=" << from_dev_type << ", to=" << to_dev_type;
306 }
307 }
308
309 // Always use the default stream
310 TVMStreamHandle CreateStream(TVMContext ctx) {
311 LOG(FATAL) << "Not implemented";
312 return nullptr;
313 }
314
315 void FreeStream(TVMContext ctx, TVMStreamHandle stream) {
316 LOG(FATAL) << "Not implemented";
317 return;
318 }
319
320 void SyncStreamFromTo(TVMContext ctx, TVMStreamHandle event_src, TVMStreamHandle event_dst) {
321 LOG(FATAL) << "Not implemented";
322 return;
323 }
324
325 void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
326 CHECK(stream == nullptr);
327 VulkanThreadEntry::ThreadLocal()->Stream(ctx.device_id)->Synchronize();
328 }
329
330 void SetStream(TVMContext ctx, TVMStreamHandle stream) final {
331 LOG(FATAL) << "Not implemented";
332 return;
333 }
334
335 void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final {
336 return VulkanThreadEntry::ThreadLocal()->pool->AllocWorkspace(ctx, size);
337 }
338
339 void FreeWorkspace(TVMContext ctx, void* data) final {
340 VulkanThreadEntry::ThreadLocal()->pool->FreeWorkspace(ctx, data);
341 }
342
343 static VulkanDeviceAPI* Global() {
344 static VulkanDeviceAPI* inst = new VulkanDeviceAPI();
345 return inst;
346 }
347
348 const VulkanContext& context(size_t device_id) const {
349 CHECK_LT(device_id, context_.size());
350 return context_[device_id];
351 }
352
353 private:
354 VkInstance instance_{nullptr};
355 // The physical devices; they have a 1:1 mapping to TVM Vulkan devices.
356 std::vector<VulkanContext> context_;
357 };
358
359 void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
360 size_t index = static_cast<size_t>(ctx.device_id);
361 if (kind == kExist) {
362 *rv = static_cast<int>(index < context_.size());
363 return;
364 }
365 CHECK_LT(index, context_.size()) << "Invalid device id " << index;
366 const auto& vctx = context(index);
367 switch (kind) {
368 case kMaxThreadsPerBlock: {
369 VkPhysicalDeviceProperties phy_prop;
370 vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
371 int64_t value = phy_prop.limits.maxComputeWorkGroupInvocations;
372 *rv = value;
373 break;
374 }
375 case kMaxSharedMemoryPerBlock: {
376 VkPhysicalDeviceProperties phy_prop;
377 vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
378 int64_t value = phy_prop.limits.maxComputeSharedMemorySize;
379 *rv = value;
380 break;
381 }
382 case kWarpSize: {
383 *rv = 1;
384 break;
385 }
386 case kComputeVersion: {
387 VkPhysicalDeviceProperties phy_prop;
388 vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
389 int64_t value = phy_prop.apiVersion;
390 std::ostringstream os;
391 os << VK_VERSION_MAJOR(value) << "." << VK_VERSION_MINOR(value) << "."
392 << VK_VERSION_PATCH(value);
393 *rv = os.str();
394 break;
395 }
396 case kDeviceName:
397 return;
398 case kMaxClockRate:
399 return;
400 case kMultiProcessorCount:
401 return;
402 case kExist:
403 break;
404 case kMaxThreadDimensions: {
405 VkPhysicalDeviceProperties phy_prop;
406 vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
407 int64_t dims[3];
408 dims[0] = phy_prop.limits.maxComputeWorkGroupSize[0];
409 dims[1] = phy_prop.limits.maxComputeWorkGroupSize[1];
410 dims[2] = phy_prop.limits.maxComputeWorkGroupSize[2];
411 std::stringstream ss; // use json string to return multiple int values;
412 ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]";
413 *rv = ss.str();
414 break;
415 }
416 case kMaxRegistersPerBlock:
417 return;
418 case kGcnArch:
419 return;
420 case kApiVersion:
421 return;
422 }
423 }
424
425 VulkanDeviceAPI::VulkanDeviceAPI() {
426 VkApplicationInfo app_info;
427 app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
428 app_info.pNext = nullptr;
429 app_info.pApplicationName = "TVM";
430 app_info.applicationVersion = 0;
431 app_info.pEngineName = "";
432 app_info.engineVersion = 0;
433 app_info.apiVersion = VK_MAKE_VERSION(1, 0, 0);
434
435 VkInstanceCreateInfo inst_info;
436 inst_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
437 inst_info.pNext = nullptr;
438 inst_info.flags = 0;
439
440 const auto layers = []() -> std::vector<const char*> {
441 uint32_t inst_layer_prop_count;
442 VULKAN_CALL(vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, nullptr));
443 std::vector<VkLayerProperties> inst_layer_prop(inst_layer_prop_count);
444 VULKAN_CALL(vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, inst_layer_prop.data()));
445 std::vector<const char*> l;
446 for (const auto& lp : inst_layer_prop) {
447 // TODO(tulloch): add CMAKE options.
448 (void)lp; // suppress unused variable warning.
449 #ifdef USE_VULKAN_VALIDATION
450 if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_standard_validation") == 0) {
451 l.push_back("VK_LAYER_LUNARG_standard_validation");
452 }
453 if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_parameter_validation") == 0) {
454 l.push_back("VK_LAYER_LUNARG_parameter_validation");
455 }
456 if (std::strcmp(lp.layerName, "VK_LAYER_KHRONOS_validation") == 0) {
457 l.push_back("VK_LAYER_KHRONOS_validation");
458 }
459 #endif
460 }
461 return l;
462 }();
463
464 const auto instance_extensions = []() -> std::vector<const char*> {
465 uint32_t inst_extension_prop_count;
466 VULKAN_CALL(
467 vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count, nullptr));
468 std::vector<VkExtensionProperties> inst_extension_prop(inst_extension_prop_count);
469 VULKAN_CALL(vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count,
470 inst_extension_prop.data()));
471 std::vector<const char*> extensions;
472 for (const auto& ip : inst_extension_prop) {
473 if (std::strcmp(ip.extensionName, "VK_KHR_get_physical_device_properties2") == 0) {
474 extensions.push_back("VK_KHR_get_physical_device_properties2");
475 }
476 }
477 return extensions;
478 }();
479
480 inst_info.pApplicationInfo = &app_info;
481 inst_info.enabledLayerCount = layers.size();
482 inst_info.ppEnabledLayerNames = layers.data();
483 inst_info.enabledExtensionCount = instance_extensions.size();
484 inst_info.ppEnabledExtensionNames = instance_extensions.data();
485
486 VULKAN_CALL(vkCreateInstance(&inst_info, nullptr, &instance_));
487
488 uint32_t phy_dev_count = 0;
489 VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &phy_dev_count, nullptr));
490 std::vector<VkPhysicalDevice> all_phy_devs(phy_dev_count);
491 VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &phy_dev_count, dmlc::BeginPtr(all_phy_devs)));
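// For every physical device that exposes a compute-capable queue family, create a
// logical device and record it as one TVM Vulkan device.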
492 for (VkPhysicalDevice phy_dev : all_phy_devs) {
493 uint32_t queue_prop_count = 0;
494 vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count, nullptr);
495 std::vector<VkQueueFamilyProperties> queue_props(queue_prop_count);
496 vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count,
497 dmlc::BeginPtr(queue_props));
498 uint32_t queue_family_index = 0;
499 std::vector<VkDeviceQueueCreateInfo> queue_create_info;
500 float priority = 1.0f;
501 for (uint32_t i = 0; i < queue_props.size(); i++) {
502 // find queues that support compute
503 if (VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) {
504 VkDeviceQueueCreateInfo info;
505 info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
506 info.pNext = nullptr;
507 info.flags = 0;
508 info.queueFamilyIndex = i;
509 info.queueCount = 1;
510 info.pQueuePriorities = &priority;
511
512 queue_create_info.push_back(info);
513 // only use the first available queue for now
514 if (queue_create_info.size() == 1) {
515 queue_family_index = i;
516 }
517 }
518 }
519 if (queue_create_info.size() == 0) continue;
520
521 VulkanContext ctx;
522 // setup context
523 ctx.phy_device = phy_dev;
524 vkGetPhysicalDeviceProperties(ctx.phy_device, &(ctx.phy_device_prop));
525
526 const auto extensions = [&]() {
527 uint32_t device_extension_prop_count;
528 VULKAN_CALL(vkEnumerateDeviceExtensionProperties(ctx.phy_device, nullptr,
529 &device_extension_prop_count, nullptr));
530 std::vector<VkExtensionProperties> device_extension_prop(device_extension_prop_count);
531 VULKAN_CALL(vkEnumerateDeviceExtensionProperties(
532 ctx.phy_device, nullptr, &device_extension_prop_count, device_extension_prop.data()));
533 std::vector<const char*> extensions;
534 for (const auto& dp : device_extension_prop) {
535 if ((std::strcmp(dp.extensionName, "VK_KHR_push_descriptor") == 0) && dp.specVersion > 0) {
536 extensions.push_back("VK_KHR_push_descriptor");
537 }
538 if ((std::strcmp(dp.extensionName, "VK_KHR_descriptor_update_template") == 0) &&
539 dp.specVersion > 0) {
540 extensions.push_back("VK_KHR_descriptor_update_template");
541 }
542 if ((std::strcmp(dp.extensionName, "VK_KHR_get_memory_requirements2") == 0) &&
543 dp.specVersion > 0) {
544 extensions.push_back("VK_KHR_get_memory_requirements2");
545 }
546 if ((std::strcmp(dp.extensionName, "VK_KHR_dedicated_allocation") == 0) &&
547 dp.specVersion > 0) {
548 extensions.push_back("VK_KHR_dedicated_allocation");
549 }
550 }
551 return extensions;
552 }();
553 VkDeviceCreateInfo device_create_info;
554 device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
555 device_create_info.pNext = nullptr;
556 device_create_info.flags = 0;
557 device_create_info.queueCreateInfoCount = static_cast<uint32_t>(queue_create_info.size());
558 device_create_info.pQueueCreateInfos = queue_create_info.data();
559 device_create_info.enabledLayerCount = 0;
560 device_create_info.ppEnabledLayerNames = nullptr;
561 device_create_info.enabledExtensionCount = extensions.size();
562 device_create_info.ppEnabledExtensionNames = extensions.data();
563 device_create_info.pEnabledFeatures = nullptr;
564 VULKAN_CALL(vkCreateDevice(phy_dev, &device_create_info, nullptr, &(ctx.device)));
565 ctx.queue_mutex.reset(new std::mutex());
566 vkGetDeviceQueue(ctx.device, queue_family_index, 0, &(ctx.queue));
567 ctx.queue_family_index = queue_family_index;
568 // Find suitable memory type for staging and compute
569 // Find suitable compute index.
570 VkBuffer buffer;
571 VkMemoryRequirements req_staging, req_compute;
572 VkBufferCreateInfo info;
573 info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
574 info.pNext = nullptr;
575 info.flags = 0;
576 info.size = 1024;
577 info.queueFamilyIndexCount = 1;
578 info.pQueueFamilyIndices = &(ctx.queue_family_index);
579 info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
580
581 // get staging requirement
582 info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT;
583 VULKAN_CALL(vkCreateBuffer(ctx.device, &info, nullptr, &buffer));
584 vkGetBufferMemoryRequirements(ctx.device, buffer, &req_staging);
585 vkDestroyBuffer(ctx.device, buffer, nullptr);
586 // get compute requirement
587 info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
588 VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
589 VULKAN_CALL(vkCreateBuffer(ctx.device, &info, nullptr, &buffer));
590 vkGetBufferMemoryRequirements(ctx.device, buffer, &req_compute);
591 vkDestroyBuffer(ctx.device, buffer, nullptr);
592
593 // Query physical device properties.
594 // Find memory that is host visible; it does not need to be coherent.
595 int win_rank = -1;
596 VkPhysicalDeviceMemoryProperties prop;
597 vkGetPhysicalDeviceMemoryProperties(ctx.phy_device, &prop);
598
599 for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
600 VkMemoryType ty = prop.memoryTypes[k];
601 size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
602 // host visible
603 if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)) continue;
604 // match copy requirement
605 if (!(req_staging.memoryTypeBits & (1 << k))) continue;
606 if (heap_size < 1024) continue;
607 int rank = 0;
608 rank += ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_CACHED_BIT;
609 if (rank > win_rank) {
610 win_rank = rank;
611 ctx.staging_mtype_index = k;
612 ctx.coherent_staging = ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT;
613 }
614 }
615 CHECK_GE(win_rank, 0) << "Cannot find suitable staging memory on device.";
616
617 win_rank = -1;
618 for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
619 VkMemoryType ty = prop.memoryTypes[k];
620 size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
621 // device local
622 if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)) continue;
623 // match compute requirement
624 if (!(req_compute.memoryTypeBits & (1 << k))) continue;
625 if (heap_size < 1024) continue;
626 int rank = 0;
627 // prefer not host visible
628 rank += !(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT);
629 if (rank > win_rank) {
630 win_rank = rank;
631 ctx.compute_mtype_index = k;
632 }
633 }
634 CHECK_GE(win_rank, 0) << "Cannot find suitable local memory on device.";
635 auto has_extension = [&extensions](const char* query) {
636 return std::any_of(extensions.begin(), extensions.end(),
637 [&](const char* extension) { return std::strcmp(query, extension) == 0; });
638 };
639
640 #ifdef USE_VULKAN_IMMEDIATE_MODE
641 if (has_extension("VK_KHR_push_descriptor") &&
642 has_extension("VK_KHR_descriptor_update_template")) {
643 ctx.descriptor_template_khr_functions = std::unique_ptr<VulkanDescriptorTemplateKHRFunctions>(
644 new VulkanDescriptorTemplateKHRFunctions());
645 ctx.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR =
646 CHECK_NOTNULL((PFN_vkCreateDescriptorUpdateTemplateKHR)vkGetDeviceProcAddr(
647 ctx.device, "vkCreateDescriptorUpdateTemplateKHR"));
648 ctx.descriptor_template_khr_functions->vkDestroyDescriptorUpdateTemplateKHR =
649 CHECK_NOTNULL((PFN_vkDestroyDescriptorUpdateTemplateKHR)vkGetDeviceProcAddr(
650 ctx.device, "vkDestroyDescriptorUpdateTemplateKHR"));
651 ctx.descriptor_template_khr_functions->vkUpdateDescriptorSetWithTemplateKHR =
652 CHECK_NOTNULL((PFN_vkUpdateDescriptorSetWithTemplateKHR)vkGetDeviceProcAddr(
653 ctx.device, "vkUpdateDescriptorSetWithTemplateKHR"));
654 ctx.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR =
655 CHECK_NOTNULL((PFN_vkCmdPushDescriptorSetWithTemplateKHR)vkGetDeviceProcAddr(
656 ctx.device, "vkCmdPushDescriptorSetWithTemplateKHR"));
657 }
658 #endif
659
660 #ifdef USE_VULKAN_DEDICATED_ALLOCATION
661 if (has_extension("VK_KHR_get_memory_requirements2") &&
662 has_extension("VK_KHR_dedicated_allocation")) {
663 ctx.get_buffer_memory_requirements_2_functions =
664 std::unique_ptr<VulkanGetBufferMemoryRequirements2Functions>(
665 new VulkanGetBufferMemoryRequirements2Functions());
666 ctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR =
667 CHECK_NOTNULL((PFN_vkGetBufferMemoryRequirements2KHR)vkGetDeviceProcAddr(
668 ctx.device, "vkGetBufferMemoryRequirements2KHR"));
669 }
670 #endif
671 context_.push_back(std::move(ctx));
672 }
673
674 LOG(INFO) << "Initialize Vulkan with " << context_.size() << " devices..";
675 for (size_t i = 0; i < context_.size(); ++i) {
676 LOG(INFO) << "vulkan(" << i << ")=\'" << context_[i].phy_device_prop.deviceName
677 << "\' phy_dev_id=" << context_[i].phy_device
678 << " use_immediate=" << context_[i].UseImmediate();
679 }
680 }
681 class VulkanModuleNode;
682
683 // A wrapped function class used to produce a PackedFunc.
684 class VulkanWrappedFunc {
685 public:
686 void Init(VulkanModuleNode* m, ObjectPtr<Object> sptr, const std::string& func_name,
687 size_t num_buffer_args, size_t num_pack_args,
688 const std::vector<std::string>& thread_axis_tags) {
689 m_ = m;
690 sptr_ = sptr;
691 func_name_ = func_name;
692 num_buffer_args_ = num_buffer_args;
693 num_pack_args_ = num_pack_args;
694 thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags);
695 }
696
697 void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const;
698
699 private:
700 // internal module
701 VulkanModuleNode* m_;
702 // the resource holder
703 ObjectPtr<Object> sptr_;
704 // The name of the function.
705 std::string func_name_;
706 // Number of buffer arguments
707 size_t num_buffer_args_;
708 // number of packed arguments.
709 size_t num_pack_args_;
710 // thread axis configuration
711 ThreadAxisConfig thread_axis_cfg_;
712
713 // Device state cache per device.
714 // marked as mutable to enable lazy initialization
715 mutable std::array<std::shared_ptr<VulkanPipeline>, kVulkanMaxNumDevice> scache_;
716 };
717
718 // Multi-device enabled module.
719 class VulkanModuleNode final : public runtime::ModuleNode {
720 public:
721 explicit VulkanModuleNode(std::unordered_map<std::string, VulkanShader> smap,
722 std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
723 : smap_(smap), fmap_(fmap), source_(source) {}
724
725 const char* type_key() const final { return "vulkan"; }
726
727 PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
728 CHECK_EQ(sptr_to_self.get(), this);
729 CHECK_NE(name, symbol::tvm_module_main) << "Device functions do not have a main";
730 auto it = fmap_.find(name);
731 if (it == fmap_.end()) return PackedFunc();
732 const FunctionInfo& info = it->second;
733 VulkanWrappedFunc f;
734 size_t num_buffer_args = NumBufferArgs(info.arg_types);
735 f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args,
736 info.thread_axis_tags);
737 return PackFuncNonBufferArg(std::move(f), info.arg_types);
738 }
739
740 ~VulkanModuleNode() {
741 // Clean up Vulkan-related caches.
742 for (size_t device_id = 0; device_id < ecache_.size(); ++device_id) {
743 for (auto& kv : ecache_[device_id]) {
744 auto& pe = kv.second;
745 CHECK(pe);
746 const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
747
748 if (pe->descriptor_update_template != VK_NULL_HANDLE) {
749 vctx.descriptor_template_khr_functions->vkDestroyDescriptorUpdateTemplateKHR(
750 vctx.device, pe->descriptor_update_template, nullptr);
751 }
752 vkDestroyPipeline(vctx.device, pe->pipeline, nullptr);
753 vkDestroyPipelineLayout(vctx.device, pe->pipeline_layout, nullptr);
754 vkDestroyDescriptorPool(vctx.device, pe->descriptor_pool, nullptr);
755 vkDestroyDescriptorSetLayout(vctx.device, pe->descriptor_set_layout, nullptr);
756 vkDestroyShaderModule(vctx.device, pe->shader, nullptr);
757 }
758 }
759 }
760
761 std::shared_ptr<VulkanPipeline> GetPipeline(size_t device_id, const std::string& func_name,
762 size_t num_pack_args) {
763 const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
764 std::lock_guard<std::mutex> lock(mutex_);
765 const auto& cp = ecache_[device_id][func_name];
766 if (cp) {
767 return cp;
768 }
769 // Create new pipeline
770 auto pe = std::shared_ptr<VulkanPipeline>(new VulkanPipeline());
771 {
772 // create shader
773 auto sit = smap_.find(func_name);
774 CHECK(sit != smap_.end());
775 const std::vector<uint32_t>& data = sit->second.data;
776 VkShaderModuleCreateInfo shader_cinfo;
777 shader_cinfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
778 shader_cinfo.pNext = nullptr;
779 shader_cinfo.flags = 0;
780 shader_cinfo.codeSize = data.size() * sizeof(uint32_t);
781 shader_cinfo.pCode = data.data();
782 VULKAN_CALL(vkCreateShaderModule(vctx.device, &shader_cinfo, nullptr, &(pe->shader)));
783 }
784 std::vector<VkDescriptorSetLayoutBinding> arg_binding;
785 std::vector<VkDescriptorUpdateTemplateEntryKHR> arg_template;
786 uint32_t num_pod = 0, num_buffer = 0;
787
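// Handle (buffer) arguments become storage-buffer descriptor bindings; POD
// arguments are counted here and passed later as push constants.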
788 {
789 auto fit = fmap_.find(func_name);
790 CHECK(fit != fmap_.end());
791 for (DLDataType arg_type : fit->second.arg_types) {
792 if (arg_type.code == kTVMOpaqueHandle) {
793 {
794 VkDescriptorSetLayoutBinding bd;
795 bd.binding = num_buffer;
796 bd.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
797 bd.descriptorCount = 1;
798 bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
799 bd.pImmutableSamplers = nullptr;
800 arg_binding.push_back(bd);
801 }
802 {
803 VkDescriptorUpdateTemplateEntryKHR tpl;
804 tpl.dstBinding = num_buffer;
805 tpl.dstArrayElement = 0;
806 tpl.descriptorCount = 1;
807 tpl.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
808 tpl.offset = num_buffer * sizeof(VkDescriptorBufferInfo);
809 tpl.stride = sizeof(VkDescriptorBufferInfo);
810 arg_template.push_back(tpl);
811 }
812 ++num_buffer;
813 } else {
814 ++num_pod;
815 }
816 }
817 }
818
819 {
820 VkDescriptorSetLayoutCreateInfo descrip_cinfo;
821 descrip_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
822 descrip_cinfo.pNext = nullptr;
823 descrip_cinfo.flags = 0;
824 if (vctx.UseImmediate()) {
825 descrip_cinfo.flags |= VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR;
826 }
827 descrip_cinfo.bindingCount = arg_binding.size();
828 descrip_cinfo.pBindings = arg_binding.data();
829 VULKAN_CALL(vkCreateDescriptorSetLayout(vctx.device, &descrip_cinfo, nullptr,
830 &(pe->descriptor_set_layout)));
831 }
832
833 {
834 VkDescriptorPoolSize pool_size;
835 pool_size.type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
836 pool_size.descriptorCount = arg_binding.size();
837 VkDescriptorPoolCreateInfo descrip_pool_cinfo;
838 descrip_pool_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
839 descrip_pool_cinfo.pNext = nullptr;
840 descrip_pool_cinfo.flags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT;
841 descrip_pool_cinfo.maxSets = 1;
842 descrip_pool_cinfo.poolSizeCount = 1;
843 descrip_pool_cinfo.pPoolSizes = &pool_size;
844 VULKAN_CALL(vkCreateDescriptorPool(vctx.device, &descrip_pool_cinfo, nullptr,
845 &(pe->descriptor_pool)));
846 }
847
848 if (!vctx.UseImmediate()) {
849 VkDescriptorSetAllocateInfo alloc_info;
850 alloc_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
851 alloc_info.pNext = nullptr;
852 alloc_info.descriptorPool = pe->descriptor_pool;
853 alloc_info.descriptorSetCount = 1;
854 alloc_info.pSetLayouts = &(pe->descriptor_set_layout);
855 VULKAN_CALL(vkAllocateDescriptorSets(vctx.device, &alloc_info, &(pe->descriptor_set)));
856 }
857
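// All POD arguments are packed into a single push-constant range visible to the
// compute stage.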
858 VkPushConstantRange crange;
859 crange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
860 crange.offset = 0;
861 crange.size = sizeof(ArgUnion) * num_pack_args;
862
863 VkPipelineLayoutCreateInfo playout_cinfo;
864 playout_cinfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
865 playout_cinfo.pNext = nullptr;
866 playout_cinfo.flags = 0;
867 playout_cinfo.setLayoutCount = 1;
868 playout_cinfo.pSetLayouts = &(pe->descriptor_set_layout);
869
870 if (num_pack_args != 0) {
871 playout_cinfo.pushConstantRangeCount = 1;
872 playout_cinfo.pPushConstantRanges = &crange;
873 CHECK_LE(crange.size, vctx.phy_device_prop.limits.maxPushConstantsSize);
874 } else {
875 playout_cinfo.pushConstantRangeCount = 0;
876 playout_cinfo.pPushConstantRanges = nullptr;
877 }
878
879 VULKAN_CALL(
880 vkCreatePipelineLayout(vctx.device, &playout_cinfo, nullptr, &(pe->pipeline_layout)));
881
882 VkComputePipelineCreateInfo pipeline_cinfo;
883 pipeline_cinfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
884 pipeline_cinfo.pNext = nullptr;
885 pipeline_cinfo.flags = 0;
886 pipeline_cinfo.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
887 pipeline_cinfo.stage.pNext = nullptr;
888 pipeline_cinfo.stage.flags = 0;
889 pipeline_cinfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
890 pipeline_cinfo.stage.module = pe->shader;
891 pipeline_cinfo.stage.pName = func_name.c_str();
892 pipeline_cinfo.stage.pSpecializationInfo = nullptr;
893 pipeline_cinfo.layout = pe->pipeline_layout;
894 pipeline_cinfo.basePipelineHandle = VK_NULL_HANDLE;
895 pipeline_cinfo.basePipelineIndex = 0;
896 VULKAN_CALL(vkCreateComputePipelines(vctx.device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr,
897 &(pe->pipeline)));
898
899 if (vctx.UseImmediate()) {
900 VkDescriptorUpdateTemplateCreateInfoKHR descrip_template_cinfo;
901 descrip_template_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_CREATE_INFO_KHR;
902 descrip_template_cinfo.pNext = 0;
903 descrip_template_cinfo.flags = 0;
904 descrip_template_cinfo.descriptorUpdateEntryCount = arg_template.size();
905 descrip_template_cinfo.pDescriptorUpdateEntries = arg_template.data();
906 descrip_template_cinfo.templateType = VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_PUSH_DESCRIPTORS_KHR;
907 descrip_template_cinfo.descriptorSetLayout = pe->descriptor_set_layout;
908 descrip_template_cinfo.pipelineBindPoint = VK_PIPELINE_BIND_POINT_COMPUTE;
909 descrip_template_cinfo.pipelineLayout = pe->pipeline_layout;
910 descrip_template_cinfo.set = 0;
911 VULKAN_CALL(vctx.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR(
912 vctx.device, &descrip_template_cinfo, 0, &(pe->descriptor_update_template)));
913 }
914 ecache_[device_id][func_name] = pe;
915 return pe;
916 }
917
918 void SaveToFile(const std::string& file_name, const std::string& format) final {
919 std::string fmt = GetFileFormat(file_name, format);
920 CHECK_EQ(fmt, fmt_) << "Can only save to the customized vulkan format";
921 std::string meta_file = GetMetaFilePath(file_name);
922 SaveMetaDataToFile(meta_file, fmap_);
923 std::string data_bin;
924 dmlc::MemoryStringStream fs(&data_bin);
925 dmlc::Stream* stream = &fs;
926 uint32_t magic = kVulkanModuleMagic;
927 stream->Write(magic);
928 stream->Write(smap_);
929 SaveBinaryToFile(file_name, data_bin);
930 }
931
932 void SaveToBinary(dmlc::Stream* stream) final {
933 stream->Write(fmt_);
934 stream->Write(fmap_);
935 stream->Write(smap_);
936 }
937 std::string GetSource(const std::string& format) final {
938 // can only return source code.
939 return source_;
940 }
941
942 private:
943 // the compiled shader binaries, keyed by function name.
944 std::unordered_map<std::string, VulkanShader> smap_;
945 // function information table.
946 std::unordered_map<std::string, FunctionInfo> fmap_;
947 // The format
948 std::string fmt_{"vulkan"};
949 // The source
950 std::string source_;
951
952 // Guards accesses to `ecache_`
953 std::mutex mutex_;
954 std::array<std::unordered_map<std::string, std::shared_ptr<VulkanPipeline>>, kVulkanMaxNumDevice>
955 ecache_;
956 };
957
958 Module VulkanModuleCreate(std::unordered_map<std::string, VulkanShader> smap,
959 std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
960 auto n = make_object<VulkanModuleNode>(smap, fmap, source);
961 return Module(n);
962 }
963
964 VulkanThreadEntry* VulkanThreadEntry::ThreadLocal() { return VulkanThreadStore::Get(); }
965
966 VulkanStagingBuffer* VulkanThreadEntry::StagingBuffer(int device_id, size_t size) {
967 if (!staging_buffers_[device_id]) {
968 staging_buffers_[device_id] = std::unique_ptr<VulkanStagingBuffer>(new VulkanStagingBuffer());
969 }
970 auto& buf = *(staging_buffers_[device_id]);
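// Grow-only policy: if the cached staging buffer is smaller than requested,
// release it here and recreate it at the requested size below.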
971 if (buf.device != nullptr && buf.size < size) {
972 // free previous buffer
973 if (buf.host_addr != nullptr) {
974 vkUnmapMemory(buf.device, buf.memory);
975 }
976 if (buf.memory != VK_NULL_HANDLE) {
977 vkFreeMemory(buf.device, buf.memory, nullptr);
978 }
979 if (buf.buffer != VK_NULL_HANDLE) {
980 vkDestroyBuffer(buf.device, buf.buffer, nullptr);
981 }
982 buf.host_addr = nullptr;
983 buf.memory = VK_NULL_HANDLE;
984 buf.buffer = VK_NULL_HANDLE;
985 }
986 const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
987
988 if (buf.device == nullptr) {
989 buf.device = vctx.device;
990 }
991 if (buf.memory == VK_NULL_HANDLE) {
992 // allocate the staging buffer memory if necessary
993 VkBufferCreateInfo info;
994 info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
995 info.pNext = nullptr;
996 info.flags = 0;
997 info.size = size;
998 info.queueFamilyIndexCount = 1;
999 info.pQueueFamilyIndices = &(vctx.queue_family_index);
1000 info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT;
1001 info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
1002 VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &(buf.buffer)));
1003 VkMemoryAllocateInfo minfo;
1004 minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
1005 minfo.pNext = nullptr;
1006 minfo.allocationSize = size;
1007 minfo.memoryTypeIndex = vctx.staging_mtype_index;
1008 VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &(buf.memory)));
1009 VULKAN_CALL(vkBindBufferMemory(vctx.device, (buf.buffer), buf.memory, 0));
1010 VULKAN_CALL(vkMapMemory(vctx.device, buf.memory, 0, size, 0, &(buf.host_addr)));
1011 buf.size = size;
1012 }
1013 memset(buf.host_addr, 0, size);
1014 return &buf;
1015 }
1016
1017 VulkanThreadEntry::VulkanThreadEntry()
1018 : pool(std::make_unique<WorkspacePool>(static_cast<DLDeviceType>(kDLVulkan),
1019 VulkanDeviceAPI::Global())) {
1020 ctx.device_id = 0;
1021 ctx.device_type = static_cast<DLDeviceType>(kDLVulkan);
1022 }
1023
1024 VulkanStream* VulkanThreadEntry::Stream(size_t device_id) {
1025 if (!streams_[device_id]) {
1026 streams_[device_id] = std::unique_ptr<VulkanStream>(
1027 new VulkanStream(&VulkanDeviceAPI::Global()->context(device_id)));
1028 }
1029 return streams_[device_id].get();
1030 }
1031
1032 void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const {
1033 int device_id = VulkanThreadEntry::ThreadLocal()->ctx.device_id;
1034 CHECK_LT(device_id, kVulkanMaxNumDevice);
1035 const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
1036 if (!scache_[device_id]) {
1037 scache_[device_id] = m_->GetPipeline(device_id, func_name_, num_pack_args_);
1038 }
1039 const auto& pipeline = scache_[device_id];
1040 ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
1041 std::vector<VkDescriptorBufferInfo> descriptor_buffers;
1042 descriptor_buffers.resize(num_buffer_args_);
1043 for (size_t i = 0; i < num_buffer_args_; ++i) {
1044 void* buf = args[static_cast<int>(i)];
1045 VkDescriptorBufferInfo binfo;
1046 binfo.buffer = static_cast<VulkanBuffer*>(buf)->buffer;
1047 binfo.offset = 0;
1048 binfo.range = VK_WHOLE_SIZE;
1049 descriptor_buffers[i] = binfo;
1050 }
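// With VK_KHR_push_descriptor (immediate mode) the descriptors are pushed straight
// into the command buffer; otherwise fall back to the deferred path, which updates
// a pre-allocated descriptor set.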
1051 if (vctx.UseImmediate()) {
1052 // Can safely capture by reference as this lambda is immediately executed on the calling thread.
1053 VulkanThreadEntry::ThreadLocal()->Stream(device_id)->Launch([&](VulkanStreamState* state) {
1054 vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline);
1055 CHECK(pipeline->descriptor_update_template != VK_NULL_HANDLE);
1056 vctx.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR(
1057 state->cmd_buffer_, pipeline->descriptor_update_template, pipeline->pipeline_layout, 0,
1058 descriptor_buffers.data());
1059 if (num_pack_args_ != 0) {
1060 vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout,
1061 VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion),
1062 pack_args);
1063 }
1064 vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
1065 VkMemoryBarrier barrier_info;
1066 barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
1067 barrier_info.pNext = nullptr;
1068 barrier_info.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT;
1069 barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
1070 VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
1071 vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
1072 VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0,
1073 1, &barrier_info, 0, nullptr, 0, nullptr);
1074 });
1075 return;
1076 }
1077
1078 // Otherwise, the more expensive deferred path.
1079 std::vector<ArgUnion> pack_args_storage(pack_args, pack_args + num_pack_args_);
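// The POD arguments are copied into pack_args_storage because the deferred
// closures below may run after this call returns.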
1080 const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers]() {
1081 std::vector<VkWriteDescriptorSet> write_descriptor_sets;
1082 write_descriptor_sets.resize(descriptor_buffers.size());
1083 for (size_t i = 0; i < write_descriptor_sets.size(); i++) {
1084 write_descriptor_sets[i].sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
1085 write_descriptor_sets[i].pNext = 0;
1086 write_descriptor_sets[i].dstSet = pipeline->descriptor_set;
1087 write_descriptor_sets[i].dstBinding = i;
1088 write_descriptor_sets[i].dstArrayElement = 0;
1089 write_descriptor_sets[i].descriptorCount = 1;
1090 write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
1091 write_descriptor_sets[i].pImageInfo = 0;
1092 write_descriptor_sets[i].pBufferInfo = &(descriptor_buffers[i]);
1093 write_descriptor_sets[i].pTexelBufferView = 0;
1094 }
1095 vkUpdateDescriptorSets(vctx.device, write_descriptor_sets.size(), write_descriptor_sets.data(),
1096 0, 0);
1097 };
1098 const auto& deferred_kernel = [pipeline, wl, pack_args_storage](VulkanStreamState* state) {
1099 vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline);
1100 vkCmdBindDescriptorSets(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE,
1101 pipeline->pipeline_layout, 0, 1, &(pipeline->descriptor_set), 0,
1102 nullptr);
1103 if (pack_args_storage.size() != 0) {
1104 vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT,
1105 0, pack_args_storage.size() * sizeof(ArgUnion), pack_args_storage.data());
1106 }
1107 vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
1108 VkMemoryBarrier barrier_info;
1109 barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
1110 barrier_info.pNext = nullptr;
1111 barrier_info.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT;
1112 barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
1113 VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
1114 vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
1115 VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0,
1116 1, &barrier_info, 0, nullptr, 0, nullptr);
1117 };
1118 VulkanStreamToken deferred_token;
1119 deferred_token.descriptor_set_ = pipeline->descriptor_set;
1120 deferred_token.buffers_.resize(descriptor_buffers.size());
1121 for (size_t i = 0; i < descriptor_buffers.size(); ++i) {
1122 deferred_token.buffers_[i] = descriptor_buffers[i].buffer;
1123 }
1124 VulkanThreadEntry::ThreadLocal()->Stream(device_id)->LaunchDeferred(
1125 deferred_initializer, deferred_kernel, deferred_token);
1126 }
1127
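// Loading mirrors SaveToFile: read the magic number, then the shader map; function
// metadata comes from the accompanying meta file.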
1128 Module VulkanModuleLoadFile(const std::string& file_name, const std::string& format) {
1129 std::string data;
1130 std::unordered_map<std::string, VulkanShader> smap;
1131 std::unordered_map<std::string, FunctionInfo> fmap;
1132 std::string fmt = GetFileFormat(file_name, format);
1133 std::string meta_file = GetMetaFilePath(file_name);
1134 LoadBinaryFromFile(file_name, &data);
1135 LoadMetaDataFromFile(meta_file, &fmap);
1136 dmlc::MemoryStringStream fs(&data);
1137 dmlc::Stream* stream = &fs;
1138 uint32_t magic;
1139 stream->Read(&magic);
1140 CHECK_EQ(magic, kVulkanModuleMagic) << "VulkanModule Magic mismatch";
1141 stream->Read(&smap);
1142 return VulkanModuleCreate(smap, fmap, "");
1143 }
1144
1145 Module VulkanModuleLoadBinary(void* strm) {
1146 dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
1147 std::unordered_map<std::string, VulkanShader> smap;
1148 std::unordered_map<std::string, FunctionInfo> fmap;
1149
1150 std::string fmt;
1151 stream->Read(&fmt);
1152 stream->Read(&fmap);
1153 stream->Read(&smap);
1154 return VulkanModuleCreate(smap, fmap, "");
1155 }
1156
1157 TVM_REGISTER_GLOBAL("runtime.module.loadfile_vulkan").set_body_typed(VulkanModuleLoadFile);
1158
1159 TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary);
1160
1161 TVM_REGISTER_GLOBAL("device_api.vulkan").set_body([](TVMArgs args, TVMRetValue* rv) {
1162 DeviceAPI* ptr = VulkanDeviceAPI::Global();
1163 *rv = static_cast<void*>(ptr);
1164 });
1165
1166 } // namespace vulkan
1167 } // namespace runtime
1168 } // namespace tvm
1169