1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4 //
5 // Copyright (C) 2018, Intel Corporation, all rights reserved.
6 // Third party copyrights are property of their respective owners.
7 
8 #include "../../precomp.hpp"
9 #include "common.hpp"
10 #include "internal.hpp"
11 #include "../include/op_base.hpp"
12 
13 namespace cv { namespace dnn { namespace vkcom {
14 
15 #ifdef HAVE_VULKAN
16 
OpBase()17 OpBase::OpBase()
18 {
19     createContext();
20     device_ = kDevice;
21     pipeline_ = VK_NULL_HANDLE;
22     cmd_buffer_ = VK_NULL_HANDLE;
23     descriptor_pool_ = VK_NULL_HANDLE;
24     descriptor_set_ = VK_NULL_HANDLE;
25     descriptor_set_layout_ = VK_NULL_HANDLE;
26     pipeline_layout_ = VK_NULL_HANDLE;
27     module_ = VK_NULL_HANDLE;
28 }
29 
~OpBase()30 OpBase::~OpBase()
31 {
32     vkDestroyShaderModule(device_, module_, NULL);
33     vkDestroyDescriptorPool(device_, descriptor_pool_, NULL);
34     vkDestroyDescriptorSetLayout(device_, descriptor_set_layout_, NULL);
35     vkDestroyPipeline(device_, pipeline_, NULL);
36     vkDestroyPipelineLayout(device_, pipeline_layout_, NULL);
37 }
38 
initVulkanThing(int buffer_num)39 void OpBase::initVulkanThing(int buffer_num)
40 {
41     createDescriptorSetLayout(buffer_num);
42     createDescriptorSet(buffer_num);
43     createCommandBuffer();
44 }
45 
createDescriptorSetLayout(int buffer_num)46 void OpBase::createDescriptorSetLayout(int buffer_num)
47 {
48     if (buffer_num <= 0)
49         return;
50     std::vector<VkDescriptorSetLayoutBinding> bindings(buffer_num);
51     for (int i = 0; i < buffer_num; i++)
52     {
53         bindings[i].binding = i;
54         bindings[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
55         bindings[i].descriptorCount = 1;
56         bindings[i].stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
57     }
58     VkDescriptorSetLayoutCreateInfo info = {};
59     info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
60     info.bindingCount = buffer_num;
61     info.pBindings = &bindings[0];
62     VK_CHECK_RESULT(vkCreateDescriptorSetLayout(device_, &info, NULL, &descriptor_set_layout_));
63 }
64 
createDescriptorSet(int buffer_num)65 void OpBase::createDescriptorSet(int buffer_num)
66 {
67     VkDescriptorPoolSize pool_size = {};
68     pool_size.type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
69     pool_size.descriptorCount = buffer_num;
70 
71     VkDescriptorPoolCreateInfo info = {};
72     info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
73     info.maxSets = 1;
74     info.poolSizeCount = 1;
75     info.pPoolSizes = &pool_size;
76     VK_CHECK_RESULT(vkCreateDescriptorPool(device_, &info, NULL, &descriptor_pool_));
77 
78     VkDescriptorSetAllocateInfo allocate_info = {};
79     allocate_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
80     allocate_info.descriptorPool = descriptor_pool_;
81     allocate_info.descriptorSetCount = 1;
82     allocate_info.pSetLayouts = &descriptor_set_layout_;
83     VK_CHECK_RESULT(vkAllocateDescriptorSets(device_, &allocate_info, &descriptor_set_));
84 }
85 
createShaderModule(const uint32_t * spv,size_t sz,const std::string & source)86 void OpBase::createShaderModule(const uint32_t* spv, size_t sz, const std::string& source)
87 {
88     VkShaderModuleCreateInfo create_info = {};
89     create_info.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
90     if (spv)
91     {
92         create_info.pCode = spv;
93         create_info.codeSize = sz;
94     }
95     else
96     {
97         // online compilation
98         std::vector<uint32_t> code;
99         code = compile("shader", shaderc_compute_shader, source);
100         create_info.pCode = code.data();
101         create_info.codeSize = sizeof(uint32_t) * code.size();
102     }
103     VK_CHECK_RESULT(vkCreateShaderModule(device_, &create_info, NULL, &module_));
104 }
105 
createPipeline(size_t push_constants_size,VkSpecializationInfo * specialization_info)106 void OpBase::createPipeline(size_t push_constants_size, VkSpecializationInfo* specialization_info)
107 {
108     // create pipeline
109     VkPipelineShaderStageCreateInfo stage_create_info = {};
110     stage_create_info.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
111     stage_create_info.stage = VK_SHADER_STAGE_COMPUTE_BIT;
112     stage_create_info.module = module_;
113     stage_create_info.pName = "main";
114     stage_create_info.pSpecializationInfo = specialization_info;
115     VkPushConstantRange push_constant_ranges[1] = {};
116     push_constant_ranges[0].stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
117     push_constant_ranges[0].offset = 0;
118     push_constant_ranges[0].size = push_constants_size;
119 
120     VkPipelineLayoutCreateInfo pipeline_layout_create_info = {};
121     pipeline_layout_create_info.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
122     if (push_constants_size != 0)
123     {
124         pipeline_layout_create_info.pushConstantRangeCount = 1;
125         pipeline_layout_create_info.pPushConstantRanges = push_constant_ranges;
126     }
127     pipeline_layout_create_info.setLayoutCount = 1;
128     pipeline_layout_create_info.pSetLayouts = &descriptor_set_layout_;
129     VK_CHECK_RESULT(vkCreatePipelineLayout(device_, &pipeline_layout_create_info,
130                                            NULL, &pipeline_layout_));
131 
132     VkComputePipelineCreateInfo pipeline_create_info = {};
133     pipeline_create_info.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
134     pipeline_create_info.stage = stage_create_info;
135     pipeline_create_info.layout = pipeline_layout_;
136     VK_CHECK_RESULT(vkCreateComputePipelines(device_, VK_NULL_HANDLE,
137                                              1, &pipeline_create_info,
138                                              NULL, &pipeline_));
139 }
140 
createCommandBuffer()141 void OpBase::createCommandBuffer()
142 {
143     VkCommandBufferAllocateInfo info = {};
144     info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
145     info.commandPool = kCmdPool;
146     info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
147     info.commandBufferCount = 1;
148     VK_CHECK_RESULT(vkAllocateCommandBuffers(device_, &info, &cmd_buffer_));
149 }
150 
recordCommandBuffer(void * push_constants,size_t push_constants_size)151 void OpBase::recordCommandBuffer(void* push_constants, size_t push_constants_size)
152 {
153     VkCommandBufferBeginInfo beginInfo = {};
154     beginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
155     beginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
156     cv::AutoLock lock(kContextMtx);
157     VK_CHECK_RESULT(vkBeginCommandBuffer(cmd_buffer_, &beginInfo));
158     if (push_constants)
159         vkCmdPushConstants(cmd_buffer_, pipeline_layout_,
160                            VK_SHADER_STAGE_COMPUTE_BIT, 0,
161                            push_constants_size, push_constants);
162     vkCmdBindPipeline(cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline_);
163     vkCmdBindDescriptorSets(cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE,
164                             pipeline_layout_, 0, 1, &descriptor_set_, 0, NULL);
165     vkCmdDispatch(cmd_buffer_, group_x_, group_y_, group_z_);
166 
167     VK_CHECK_RESULT(vkEndCommandBuffer(cmd_buffer_));
168 }
169 
runCommandBuffer()170 void OpBase::runCommandBuffer()
171 {
172     VkSubmitInfo submit_info = {};
173     submit_info.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
174     submit_info.commandBufferCount = 1;
175     submit_info.pCommandBuffers = &cmd_buffer_;
176 
177     VkFence fence;
178     VkFenceCreateInfo fence_create_info_ = {};
179     fence_create_info_.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO;
180     fence_create_info_.flags = 0;
181 
182     VK_CHECK_RESULT(vkCreateFence(device_, &fence_create_info_, NULL, &fence));
183     {
184         cv::AutoLock lock(kContextMtx);
185         VK_CHECK_RESULT(vkQueueSubmit(kQueue, 1, &submit_info, fence));
186     }
187     VK_CHECK_RESULT(vkWaitForFences(device_, 1, &fence, VK_TRUE, 100000000000));
188     vkDestroyFence(device_, fence, NULL);
189 }
190 
191 #endif // HAVE_VULKAN
192 
193 }}} // namespace cv::dnn::vkcom
194