1 //
2 //  VulkanScale.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2019/01/31.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "VulkanScale.hpp"
10 #include "core/Macro.h"
11 #include "core/TensorUtils.hpp"
12 
13 namespace MNN {
14 
// Host-side copy of the uniform buffer bound at binding 4 of the scale pipeline.
// onEncode fills imgSize as {width, height, UP_DIV(channel, 4), batch}.
// NOTE(review): the field layout must match the uniform block declared in the
// "glsl_scale_comp" shader — do not reorder or add members without updating it.
struct gpuScaleParam {
    ivec4 imgSize;
};
18 
VulkanScale(const Op * op,Backend * bn)19 VulkanScale::VulkanScale(const Op* op, Backend* bn) : VulkanBasicExecution(bn) {
20     const auto scale   = op->main_as_Scale();
21     const int channels = scale->scaleData()->size();
22 
23     std::vector<VkDescriptorType> types{VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
24                                         VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
25                                         VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER};
26 
27     auto extra = static_cast<VulkanBackend*>(bn);
28 
29     mScalePipeline = extra->getPipeline("glsl_scale_comp", /*glsl_scale_comp, glsl_scale_comp_len,*/ types);
30     mScaleParam    = std::make_shared<VulkanBuffer>(extra->getMemoryPool(), false, sizeof(gpuScaleParam), nullptr,
31                                                  VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT);
32     mScaleBuffer   = std::make_shared<VulkanBuffer>(extra->getMemoryPool(), false, sizeof(float) * channels,
33                                                   scale->scaleData()->data(), VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
34     mBiasBuffer    = std::make_shared<VulkanBuffer>(extra->getMemoryPool(), false, sizeof(float) * channels,
35                                                  scale->biasData()->data(), VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
36     mSampler       = extra->getCommonSampler();
37 }
38 
~VulkanScale()39 VulkanScale::~VulkanScale() {
40 }
41 
onEncode(const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,const VulkanCommandPool::Buffer * cmdBuffer)42 ErrorCode VulkanScale::onEncode(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
43                                 const VulkanCommandPool::Buffer* cmdBuffer) {
44     auto input  = inputs[0];
45     auto output = outputs[0];
46 
47     MNN_ASSERT(MNN_DATA_FORMAT_NC4HW4 == TensorUtils::getDescribe(input)->dimensionFormat);
48 
49     auto scaleP = reinterpret_cast<gpuScaleParam*>(mScaleParam->map());
50     ::memset(scaleP, 0, sizeof(gpuScaleParam));
51 
52     const int channelDiv4 = UP_DIV(input->channel(), 4);
53 
54     scaleP->imgSize[0] = input->width();
55     scaleP->imgSize[1] = input->height();
56     scaleP->imgSize[2] = channelDiv4;
57     scaleP->imgSize[3] = input->batch();
58     mScaleParam->flush(true, 0, sizeof(gpuScaleParam));
59     mScaleParam->unmap();
60 
61     mDescriptorSet.reset(mScalePipeline->createSet());
62     mDescriptorSet->writeImage(reinterpret_cast<VulkanTensor*>(output->deviceId())->image()->view(), mSampler->get(),
63                                VK_IMAGE_LAYOUT_GENERAL, 0);
64     mDescriptorSet->writeImage(reinterpret_cast<VulkanTensor*>(input->deviceId())->image()->view(), mSampler->get(),
65                                VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL, 1);
66     mDescriptorSet->writeBuffer(mScaleBuffer->buffer(), 2, mScaleBuffer->size());
67     mDescriptorSet->writeBuffer(mBiasBuffer->buffer(), 3, mBiasBuffer->size());
68     mDescriptorSet->writeBuffer(mScaleParam->buffer(), 4, mScaleParam->size());
69     mScalePipeline->bind(cmdBuffer->get(), mDescriptorSet->get());
70 
71     vkCmdDispatch(cmdBuffer->get(), UP_DIV(input->width(), 16), UP_DIV(input->height(), 16),
72                   channelDiv4 * input->batch());
73 
74     return NO_ERROR;
75 }
76 
77 class VulkanScaleCreator : public VulkanBackend::Creator {
78 public:
onCreate(const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,const MNN::Op * op,Backend * bn) const79     virtual VulkanBasicExecution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, const MNN::Op* op, Backend* bn) const override {
80         return new VulkanScale(op, bn);
81     }
82 };
83 
__anon7bcd29d50102() 84 static bool gResistor = []() {
85     VulkanBackend::addCreator(OpType_Scale, new VulkanScaleCreator);
86     return true;
87 }();
88 
89 } // namespace MNN
90