1 // 2 // VulkanConvolution.hpp 3 // MNN 4 // 5 // Created by MNN on 2019/01/31. 6 // Copyright © 2018, Alibaba Group Holding Limited 7 // 8 9 #ifndef VulkanConvolution_hpp 10 #define VulkanConvolution_hpp 11 12 #include "VulkanBasicExecution.hpp" 13 #include "core/ConvolutionCommon.hpp" 14 namespace MNN { 15 class VulkanConvolutionCommon : public VulkanBasicExecution { 16 public: 17 VulkanConvolutionCommon(const Op* op, Backend* bn); 18 virtual ~VulkanConvolutionCommon(); 19 20 virtual ErrorCode onEncode(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, 21 const VulkanCommandPool::Buffer* cmdBuffer) override; 22 23 struct ConvolutionParameter { 24 ivec2 pad; 25 ivec2 kernelSize; 26 ivec2 stride; 27 ivec2 dilate; 28 ivec4 inputSize; 29 ivec4 outputSize; 30 ivec4 offset; 31 }; 32 33 static void writeParameter(ConvolutionParameter* dest, const Convolution2DCommon* common, const Tensor* input, 34 const Tensor* output); 35 static std::string getPostTreatMacro(const Convolution2DCommon* common); 36 class BufferToImageCopy { 37 public: BufferToImageCopy(const VulkanBackend * bn)38 BufferToImageCopy(const VulkanBackend* bn) { 39 mBackend = bn; 40 std::vector<VkDescriptorType> types{VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 41 VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER}; 42 mPipeline = mBackend->getPipeline("glsl_buffer2Image2D_comp", types); 43 mSets.reset(mPipeline->createSet()); 44 mConstBuffer = std::make_shared<VulkanBuffer>(bn->getMemoryPool(), false, 2 * sizeof(int), 45 nullptr, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT); 46 } encode(const VulkanImage * image,VkBuffer buffer,size_t bufferSize,const VulkanCommandPool::Buffer * cmdBuffer)47 void encode(const VulkanImage* image, VkBuffer buffer, size_t bufferSize, const VulkanCommandPool::Buffer* cmdBuffer) { 48 int localX = 16; 49 int localY = 16; 50 int localZ = 1; 51 int* dim = (int*)mConstBuffer->map(); 52 dim[0] = image->width(); 53 dim[1] = image->height(); 54 mConstBuffer->unmap(); 55 cmdBuffer->barrierImageIfNeeded(image, VK_IMAGE_LAYOUT_GENERAL); 56 mSets->writeImage(image->view(), mBackend->getCommonSampler()->get(), VK_IMAGE_LAYOUT_GENERAL, 0); 57 mSets->writeBuffer(buffer, 1, bufferSize); 58 mSets->writeBuffer(mConstBuffer->buffer(), 2, mConstBuffer->size()); 59 mPipeline->bind(cmdBuffer->get(), mSets->get()); 60 cmdBuffer->barrierSource(buffer, 0, bufferSize); 61 vkCmdDispatch(cmdBuffer->get(), UP_DIV(image->width(), localX), UP_DIV(image->height(), localY), 62 UP_DIV(image->depth(), localZ)); 63 } 64 private: 65 const VulkanBackend* mBackend; 66 const VulkanPipeline* mPipeline; 67 std::shared_ptr<VulkanPipeline::DescriptorSet> mSets; 68 std::shared_ptr<VulkanBuffer> mConstBuffer; 69 }; 70 static int gImage2ColLocal; 71 protected: 72 virtual ErrorCode onEncodeConvolution(const Convolution2DCommon* common, const std::vector<Tensor*>& inputs, 73 const std::vector<Tensor*>& outputs, 74 const VulkanCommandPool::Buffer* cmdBuffer, 75 const VulkanBuffer* constConvBuffer) = 0; 76 77 private: 78 const Convolution2DCommon* mCommon; 79 std::shared_ptr<VulkanBuffer> mConvCons; 80 }; 81 82 class VulkanConvolutionDepthwise : public VulkanConvolutionCommon { 83 public: 84 VulkanConvolutionDepthwise(const float* weightData, size_t weightSize, const Op* op, Backend* bn); 85 virtual ~VulkanConvolutionDepthwise(); 86 virtual ErrorCode onEncodeConvolution(const Convolution2DCommon* common, const std::vector<Tensor*>& inputs, 87 const std::vector<Tensor*>& outputs, 88 const VulkanCommandPool::Buffer* cmdBuffer, 89 const VulkanBuffer* constConvBuffer) override; 90 91 private: 92 bool _init(const float* weightData, size_t weightSize, const Op* op, Backend* bn); 93 std::shared_ptr<VulkanImage> mKernel; 94 95 const VulkanPipeline* mConvPipeline; 96 97 std::shared_ptr<VulkanPipeline::DescriptorSet> mConvSet; 98 const VulkanSampler* mSampler; 99 std::shared_ptr<VulkanImage> mBias; 100 std::vector<std::shared_ptr<VulkanPipeline::DescriptorSet>> mExtraSets; 101 std::vector<std::shared_ptr<VulkanBuffer>> mExtraBuffers; 102 103 int mLocalX = 0; 104 int mLocalY = 0; 105 }; 106 } // namespace MNN 107 108 #endif /* VulkanConvolution_hpp */ 109