1 //
2 //  VulkanConvolution.hpp
3 //  MNN
4 //
5 //  Created by MNN on 2019/01/31.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #ifndef VulkanConvolution_hpp
10 #define VulkanConvolution_hpp
11 
12 #include "VulkanBasicExecution.hpp"
13 #include "core/ConvolutionCommon.hpp"
14 namespace MNN {
15 class VulkanConvolutionCommon : public VulkanBasicExecution {
16 public:
17     VulkanConvolutionCommon(const Op* op, Backend* bn);
18     virtual ~VulkanConvolutionCommon();
19 
20     virtual ErrorCode onEncode(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
21                                const VulkanCommandPool::Buffer* cmdBuffer) override;
22 
23     struct ConvolutionParameter {
24         ivec2 pad;
25         ivec2 kernelSize;
26         ivec2 stride;
27         ivec2 dilate;
28         ivec4 inputSize;
29         ivec4 outputSize;
30         ivec4 offset;
31     };
32 
33     static void writeParameter(ConvolutionParameter* dest, const Convolution2DCommon* common, const Tensor* input,
34                                const Tensor* output);
35     static std::string getPostTreatMacro(const Convolution2DCommon* common);
36     class BufferToImageCopy {
37     public:
BufferToImageCopy(const VulkanBackend * bn)38         BufferToImageCopy(const VulkanBackend* bn) {
39             mBackend = bn;
40             std::vector<VkDescriptorType> types{VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
41                                                 VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER};
42             mPipeline = mBackend->getPipeline("glsl_buffer2Image2D_comp", types);
43             mSets.reset(mPipeline->createSet());
44             mConstBuffer = std::make_shared<VulkanBuffer>(bn->getMemoryPool(), false, 2 * sizeof(int),
45                                                               nullptr, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT);
46         }
encode(const VulkanImage * image,VkBuffer buffer,size_t bufferSize,const VulkanCommandPool::Buffer * cmdBuffer)47         void encode(const VulkanImage* image, VkBuffer buffer, size_t bufferSize, const VulkanCommandPool::Buffer* cmdBuffer) {
48             int localX = 16;
49             int localY = 16;
50             int localZ = 1;
51             int* dim = (int*)mConstBuffer->map();
52             dim[0] = image->width();
53             dim[1] = image->height();
54             mConstBuffer->unmap();
55             cmdBuffer->barrierImageIfNeeded(image, VK_IMAGE_LAYOUT_GENERAL);
56             mSets->writeImage(image->view(), mBackend->getCommonSampler()->get(), VK_IMAGE_LAYOUT_GENERAL, 0);
57             mSets->writeBuffer(buffer, 1, bufferSize);
58             mSets->writeBuffer(mConstBuffer->buffer(), 2, mConstBuffer->size());
59             mPipeline->bind(cmdBuffer->get(), mSets->get());
60             cmdBuffer->barrierSource(buffer, 0, bufferSize);
61             vkCmdDispatch(cmdBuffer->get(), UP_DIV(image->width(), localX), UP_DIV(image->height(), localY),
62                           UP_DIV(image->depth(), localZ));
63         }
64     private:
65         const VulkanBackend* mBackend;
66         const VulkanPipeline* mPipeline;
67         std::shared_ptr<VulkanPipeline::DescriptorSet> mSets;
68         std::shared_ptr<VulkanBuffer> mConstBuffer;
69     };
70     static int gImage2ColLocal;
71 protected:
72     virtual ErrorCode onEncodeConvolution(const Convolution2DCommon* common, const std::vector<Tensor*>& inputs,
73                                           const std::vector<Tensor*>& outputs,
74                                           const VulkanCommandPool::Buffer* cmdBuffer,
75                                           const VulkanBuffer* constConvBuffer) = 0;
76 
77 private:
78     const Convolution2DCommon* mCommon;
79     std::shared_ptr<VulkanBuffer> mConvCons;
80 };
81 
82 class VulkanConvolutionDepthwise : public VulkanConvolutionCommon {
83 public:
84     VulkanConvolutionDepthwise(const float* weightData, size_t weightSize, const Op* op, Backend* bn);
85     virtual ~VulkanConvolutionDepthwise();
86     virtual ErrorCode onEncodeConvolution(const Convolution2DCommon* common, const std::vector<Tensor*>& inputs,
87                                           const std::vector<Tensor*>& outputs,
88                                           const VulkanCommandPool::Buffer* cmdBuffer,
89                                           const VulkanBuffer* constConvBuffer) override;
90 
91 private:
92     bool _init(const float* weightData, size_t weightSize, const Op* op, Backend* bn);
93     std::shared_ptr<VulkanImage> mKernel;
94 
95     const VulkanPipeline* mConvPipeline;
96 
97     std::shared_ptr<VulkanPipeline::DescriptorSet> mConvSet;
98     const VulkanSampler* mSampler;
99     std::shared_ptr<VulkanImage> mBias;
100     std::vector<std::shared_ptr<VulkanPipeline::DescriptorSet>> mExtraSets;
101     std::vector<std::shared_ptr<VulkanBuffer>> mExtraBuffers;
102 
103     int mLocalX = 0;
104     int mLocalY = 0;
105 };
106 } // namespace MNN
107 
108 #endif /* VulkanConvolution_hpp */
109