//
//  GeometryConv2D.cpp
//  MNN
//
//  Created by MNN on 2020/07/14.
//  Copyright © 2018, Alibaba Group Holding Limited
//

9 #include <limits>
10 #include "ConvertUtils.hpp"
11 #include "GeometryConvUtils.hpp"
12 #define MNN_OPEN_TIME_TRACE
13 #include <MNN/AutoTime.hpp>
14 namespace MNN {
15 
16 class GeometryConv2D : public DefaultGeometryComputer {
17 public:
18     // Im2Col + GEMM
computeIm2Col_GEMM(const Convolution2DCommon * common,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,Context & context,CommandBuffer & res) const19     bool computeIm2Col_GEMM(  const Convolution2DCommon* common, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
20                             Context& context, CommandBuffer& res) const {
21         auto input      = inputs[0];
22         auto outputDiff = outputs[0];
23         MNN_ASSERT(1 == common->group());
24         auto kw    = common->kernelX();
25         auto kh    = common->kernelY();
26         auto sw    = common->strideX();
27         auto sh    = common->strideY();
28         auto dw    = common->dilateX();
29         auto dh    = common->dilateY();
30         auto batch = outputDiff->batch();
31         auto ow    = outputDiff->width();
32         auto oh    = outputDiff->height();
33         auto oc    = outputDiff->channel();
34         auto ic    = input->channel();
35         auto iw    = input->width();
36         auto ih    = input->height();
37         auto pads  = ConvolutionCommon::convolutionPad(input, outputDiff, common);
38         MNN_ASSERT(TensorUtils::getDescribe(input)->dimensionFormat != MNN_DATA_FORMAT_NHWC);
39         MNN_ASSERT(TensorUtils::getDescribe(outputDiff)->dimensionFormat != MNN_DATA_FORMAT_NHWC);
40         Tensor* A = nullptr;
41         Tensor* B = nullptr;
42         {
43             // B: Input Im2Col, n, ic, ih, iw -> ic*kh*kw, n*oh*ow
44             std::shared_ptr<Tensor> im2Col(new Tensor);
45             GeometryConvUtils::im2Col(im2Col.get(), input, ic, kh, kw, batch, oh, ow, ih, iw, sh, sw, dh, dw, pads);
46             B = im2Col.get();
47             res.extras.emplace_back(im2Col);
48         }
49         {
50             // A: Weight oc, ic, kh, kw -> oc, ic*kh*kw
51             std::shared_ptr<Tensor> kernel(new Tensor);
52             A                           = kernel.get();
53             kernel->buffer().type       = halide_type_of<float>();
54             kernel->buffer().dimensions = 2;
55             kernel->setLength(0, oc);
56             kernel->setLength(1, ic * kw * kh);
57             auto des             = TensorUtils::getDescribe(kernel.get());
58             des->dimensionFormat = MNN_DATA_FORMAT_NCHW;
59             GeometryComputerUtils::makeRawAddressRef(kernel.get(), inputs[1], 0, ic * kw * kh * oc);
60             res.extras.emplace_back(std::move(kernel));
61         }
62         {
63             // C = MatMul(B, A)
64             std::shared_ptr<Tensor> C(new Tensor);
65             C->buffer().type       = halide_type_of<float>();
66             C->buffer().dimensions = 2;
67             C->setLength(0, batch * ow * oh);
68             C->setLength(1, oc);
69             TensorUtils::getDescribe(C.get())->dimensionFormat = MNN_DATA_FORMAT_NCHW;
70             Tensor* bias                                       = nullptr;
71             if (inputs.size() > 2) {
72                 bias = inputs[2];
73             }
74             res.command.emplace_back(GeometryComputerUtils::makeMatMul(B, A, C.get(), bias, true, true));
75             res.extras.emplace_back(C);
76 
77             // Activation
78             float minValue = 0.0f, maxValue = 6.0f;
79             bool needPostTreat = false;
80             if (common->relu()) {
81                 needPostTreat = true;
82                 minValue      = 0.0f;
83                 maxValue      = std::numeric_limits<float>().max();
84             }
85             if (common->relu6()) {
86                 needPostTreat = true;
87                 minValue      = 0.0f;
88                 maxValue      = 6.0f;
89             }
90             if (needPostTreat) {
91                 flatbuffers::FlatBufferBuilder builder;
92                 builder.Finish(GeometryConvUtils::makeRelu6(builder, minValue, maxValue));
93                 std::shared_ptr<Tensor> C2(new Tensor);
94                 C2->buffer().type       = halide_type_of<float>();
95                 C2->buffer().dimensions = 2;
96                 C2->setLength(0, batch * ow * oh);
97                 C2->setLength(1, oc);
98                 TensorUtils::getDescribe(C2.get())->dimensionFormat = MNN_DATA_FORMAT_NCHW;
99                 auto cmd = GeometryComputerUtils::makeCommand(builder, {C.get()}, {C2.get()});
100                 res.command.emplace_back(cmd);
101                 res.extras.emplace_back(C2);
102                 C = C2;
103             }
104             // Transpose
105             // Batch, oh, ow, oc -> batch, oc, oh, ow
106             TensorUtils::setLinearLayout(C.get());
107             if (ow == oh && oh == 1) {
108                 GeometryComputerUtils::makeRawAddressRef(outputs[0], C.get(), 0, batch * oc);
109             } else {
110                 auto kernelDiffDes        = TensorUtils::getDescribe(outputs[0]);
111                 kernelDiffDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
112                 kernelDiffDes->regions.resize(1);
113                 auto& desReg         = kernelDiffDes->regions[0];
114                 desReg.size[0]       = batch;
115                 desReg.size[1]       = oc;
116                 desReg.size[2]       = oh * ow;
117                 desReg.dst.offset    = 0;
118                 desReg.dst.stride[0] = oc * oh * ow;
119                 desReg.dst.stride[1] = oh * ow;
120                 desReg.dst.stride[2] = 1;
121                 desReg.src.offset    = 0;
122                 desReg.src.stride[0] = oh * ow * oc;
123                 desReg.src.stride[1] = 1;
124                 desReg.src.stride[2] = oc;
125                 desReg.origin        = C.get();
126             }
127         }
128         return true;
129     }
onCompute(const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,Context & context,CommandBuffer & res) const130     virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
131                            Context& context, CommandBuffer& res) const override {
132         if (inputs.size() == 1) {
133             // Origin convolution with format converter
134             return GeometryConvUtils::computeSingle(op, inputs, outputs, context, res);
135         }
136         auto common = op->main_as_Convolution2D()->common();
137         if (common->outputCount() > 0) {
138             // FIXME: Remove this logical in future
139             if (context.forwardType() == MNN_FORWARD_CPU || context.forwardType() == MNN_FORWARD_CPU_EXTENSION || context.forwardType() == MNN_FORWARD_OPENCL) {
140                 auto inputDes     = TensorUtils::getDescribe(inputs[0]);
141                 auto format       = inputDes->dimensionFormat;
142                 if (MNN_DATA_FORMAT_NC4HW4 == format) {
143                     return DefaultGeometryComputer::onCompute(op, inputs, outputs, context, res);
144                 }
145             }
146             return computeIm2Col_GEMM(common, inputs, outputs, context, res);
147         }
148         std::unique_ptr<Convolution2DCommonT> temp(common->UnPack());
149         temp->outputCount = inputs[1]->length(0);
150         temp->kernelY = inputs[1]->length(2);
151         temp->kernelX = inputs[1]->length(3);
152         flatbuffers::FlatBufferBuilder builder;
153         builder.Finish(Convolution2DCommon::Pack(builder, temp.get()));
154         return computeIm2Col_GEMM(flatbuffers::GetRoot<MNN::Convolution2DCommon>(builder.GetBufferPointer()), inputs, outputs, context, res);
155     }
156 };
157 
158 
159 class GeometryConvTranspose2D : public GeometryConv2D {
160 public:
161     // Im2Col + GEMM
computeGEMM_Col2Im(const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,Context & context,CommandBuffer & res) const162     bool computeGEMM_Col2Im(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
163                             Context& context, CommandBuffer& res) const {
164         auto common     = op->main_as_Convolution2D()->common();
165         auto input      = inputs[0];
166         auto outputDiff = outputs[0];
167         auto weight = inputs[1];
168         MNN_ASSERT(1 == common->group());
169         auto kw    = common->kernelX();
170         auto kh    = common->kernelY();
171         auto sw    = common->strideX();
172         auto sh    = common->strideY();
173         auto dw    = common->dilateX();
174         auto dh    = common->dilateY();
175         auto batch = outputDiff->batch();
176         auto ow    = outputDiff->width();
177         auto oh    = outputDiff->height();
178         auto oc    = outputDiff->channel();
179         auto ic    = input->channel();
180         auto iw    = input->width();
181         auto ih    = input->height();
182         auto pads  = ConvolutionCommon::convolutionTransposePad(input, outputDiff, common);
183         MNN_ASSERT(TensorUtils::getDescribe(input)->dimensionFormat != MNN_DATA_FORMAT_NHWC);
184         MNN_ASSERT(TensorUtils::getDescribe(outputDiff)->dimensionFormat != MNN_DATA_FORMAT_NHWC);
185         Tensor* A = nullptr;
186         Tensor* B = nullptr;
187         {
188             // B: Input n, ic, ih, iw -> ic, n * ih * iw
189             std::shared_ptr<Tensor> dest(Tensor::createDevice<float>({ic, batch * ih * iw}));
190             res.extras.emplace_back(dest);
191             B = dest.get();
192             auto des = TensorUtils::getDescribe(dest.get());
193             des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
194             des->regions.resize(1);
195             auto& reg = des->regions[0];
196             reg.origin = input;
197             reg.size[0] = ic;
198             reg.size[1] = batch;
199             reg.size[2] = ih * iw;
200             reg.src.offset = 0;
201             reg.src.stride[0] = ih * iw;
202             reg.src.stride[1] = ic * ih * iw;
203             reg.src.stride[2] = 1;
204             reg.dst.offset = 0;
205             reg.dst.stride[0] = ih * iw * batch;
206             reg.dst.stride[1] = ih * iw;
207             reg.dst.stride[2] = 1;
208         }
209         {
210             // A: Weight ic, oc, kh, kw -> ic, oc*kh*kw
211             std::shared_ptr<Tensor> kernel(Tensor::createDevice<float>({ic, oc * kw * kh}));
212             A                           = kernel.get();
213             GeometryComputerUtils::makeRawAddressRef(kernel.get(), weight, 0, ic * kw * kh * oc);
214             res.extras.emplace_back(std::move(kernel));
215         }
216         {
217             // C = MatMul(B, A)
218             std::shared_ptr<Tensor> C(Tensor::createDevice<float>({oc * kw * kh, batch * ih * iw}));
219             res.command.emplace_back(GeometryComputerUtils::makeMatMul(A, B, C.get(), nullptr, true, false));
220             res.extras.emplace_back(C);
221 
222             // Col2Im:
223             // 1. C-> C' batch, oc, oh, ow, kw*kh, 2. C' -> C'' batch, oc, oh, ow (reduce_sum)
224             // 3. C'' -> C'' + bias, 4. posttreat(C'' + bias)
225             std::shared_ptr<Tensor> C_(Tensor::createDevice<float>({batch, kw * kh, oc * oh * ow}));
226             res.extras.emplace_back(C_);
227             {
228                 std::shared_ptr<Tensor> im2ColTemp(Tensor::createDevice<float>({oc * kw * kh, batch * ih * iw}));
229                 // Swap ow, iw, oh, ih for im2Col
230                 GeometryConvUtils::im2Col(im2ColTemp.get(), outputDiff, oc, kh, kw, batch, ih, iw, oh, ow, sh, sw, dh, dw, pads, oh * ow * oc);
231                 auto des = TensorUtils::getDescribe(C_.get());
232                 des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
233                 auto originDes = TensorUtils::getDescribe(im2ColTemp.get());
234                 des->regions = std::move(originDes->regions);
235                 // Swap src and dst, from im2col->col2im
236                 for (auto& reg : des->regions) {
237                     reg.origin = C.get();
238                     auto temp = reg.src;
239                     reg.src = std::move(reg.dst);
240                     reg.dst = std::move(temp);
241                 }
242             }
243             std::shared_ptr<Tensor> C__(Tensor::createDevice<float>({batch, 1, oc * oh * ow}));
244             res.extras.emplace_back(C__);
245             res.command.emplace_back(GeometryComputerUtils::makeReduce(ReductionType_SUM, C_.get(), C__.get()));
246 
247             if (inputs.size() > 2) {
248                 MNN_ASSERT(oc == inputs[2]->elementSize());
249                 std::shared_ptr<Tensor> biasLarge(Tensor::createDevice<float>({batch, 1, oc * oh * ow}));
250                 res.extras.emplace_back(biasLarge);
251                 auto des = TensorUtils::getDescribe(biasLarge.get());
252                 des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
253                 des->regions.resize(1);
254                 auto& reg = des->regions[0];
255                 reg.origin = inputs[2];
256                 reg.size[0] = batch;
257                 reg.size[1] = oc;
258                 reg.size[2] = oh * ow;
259                 reg.src.offset = 0;
260                 reg.src.stride[0] = 0;
261                 reg.src.stride[1] = 1;
262                 reg.src.stride[2] = 0;
263                 reg.dst.offset = 0;
264                 reg.dst.stride[0] = oc * oh * ow;
265                 reg.dst.stride[1] = oh * ow;
266                 reg.dst.stride[2] = 1;
267                 std::shared_ptr<Tensor> temp(Tensor::createDevice<float>({batch, 1, oh * ow * oc}));
268                 res.extras.emplace_back(temp);
269                 res.command.emplace_back(GeometryComputerUtils::makeBinary(BinaryOpOperation_ADD, C__.get(), biasLarge.get(), temp.get()));
270                 C__ = temp;
271             }
272 
273             // Activation
274             float minValue = 0.0f, maxValue = 0.0f;
275             bool needPostTreat = false;
276             if (common->relu()) {
277                 needPostTreat = true;
278                 minValue      = 0.0f;
279                 maxValue      = std::numeric_limits<float>().max();
280             }
281             if (common->relu6()) {
282                 needPostTreat = true;
283                 minValue      = 0.0f;
284                 maxValue      = 6.0f;
285             }
286             if (needPostTreat) {
287                 flatbuffers::FlatBufferBuilder builder;
288                 builder.Finish(GeometryConvUtils::makeRelu6(builder, minValue, maxValue));
289                 std::shared_ptr<Tensor> C2(new Tensor);
290                 C2->buffer().type       = halide_type_of<float>();
291                 C2->buffer().dimensions = 3;
292                 C2->setLength(0, batch);
293                 C2->setLength(1, 1);
294                 C2->setLength(2, ow * oh * oc);
295                 TensorUtils::getDescribe(C2.get())->dimensionFormat = MNN_DATA_FORMAT_NCHW;
296                 auto cmd = GeometryComputerUtils::makeCommand(builder, {C__.get()}, {C2.get()});
297                 res.command.emplace_back(cmd);
298                 res.extras.emplace_back(C2);
299                 C__ = C2;
300             }
301             GeometryComputerUtils::makeRawAddressRef(outputs[0], C__.get(), 0, oc * batch * ow * oh);
302         }
303         return true;
304     }
onCompute(const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,Context & context,CommandBuffer & res) const305     virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
306                            Context& context, CommandBuffer& res) const override {
307         if (op->main_as_Convolution2D()->common()->hasOutputShape()) {
308             const std::vector<Tensor*> newInputs(inputs.begin(), inputs.end() - 1);
309             if (newInputs.size() == 1) {
310                 // Origin convolution with format converter
311                 return GeometryConvUtils::computeSingle(op, newInputs, outputs, context, res);
312             }
313             return computeGEMM_Col2Im(op, newInputs, outputs, context, res);
314         } else {
315             if (inputs.size() == 1) {
316                 // Origin convolution with format converter
317                 return GeometryConvUtils::computeSingle(op, inputs, outputs, context, res);
318             }
319             return computeGEMM_Col2Im(op, inputs, outputs, context, res);
320         }
321     }
322 };
_create()323 static void _create() {
324     std::shared_ptr<GeometryComputer> comp(new GeometryConv2D);
325     GeometryComputer::registerGeometryComputer(comp, {OpType_Convolution});
326 
327     std::shared_ptr<GeometryComputer> comp2(new GeometryConvTranspose2D);
328     GeometryComputer::registerGeometryComputer(comp2, {OpType_Deconvolution});
329 }
330 
331 REGISTER_GEOMETRY(GeometryConv2D, _create);
332 
333 } // namespace MNN
334