1 //
2 //  GeometryConv2DBackPropFilter.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2020/05/07.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "ConvertUtils.hpp"
10 #include "GeometryConvUtils.hpp"
11 #define MNN_OPEN_TIME_TRACE
12 #include <MNN/AutoTime.hpp>
13 namespace MNN {
14 class GeometryConv2DBackPropFilter : public GeometryComputer {
15 public:
computeForDepthWise(const Convolution2DCommon * common,Tensor * input,Tensor * outputDiff,Tensor * kernelDiff,Context & context,CommandBuffer & res) const16     bool computeForDepthWise(const Convolution2DCommon* common, Tensor* input, Tensor* outputDiff, Tensor* kernelDiff,
17                              Context& context, CommandBuffer& res) const {
18         auto kw    = common->kernelX();
19         auto kh    = common->kernelY();
20         auto sw    = common->strideX();
21         auto sh    = common->strideY();
22         auto dw    = common->dilateX();
23         auto dh    = common->dilateY();
24         auto batch = outputDiff->batch();
25         auto ow    = outputDiff->width();
26         auto oh    = outputDiff->height();
27         auto ic    = input->channel();
28         auto iw    = input->width();
29         auto ih    = input->height();
30         auto pads  = ConvolutionCommon::convolutionPad(input, outputDiff, common);
31         if (TensorUtils::getDescribe(input)->dimensionFormat != MNN_DATA_FORMAT_NCHW) {
32             std::shared_ptr<Tensor> newT(new Tensor(input, Tensor::CAFFE, false));
33             ConvertUtils::compute(input, newT.get(), res);
34             input = newT.get();
35             res.extras.emplace_back(newT);
36         }
37         if (TensorUtils::getDescribe(outputDiff)->dimensionFormat != MNN_DATA_FORMAT_NCHW) {
38             std::shared_ptr<Tensor> newT(new Tensor(outputDiff, Tensor::CAFFE, false));
39             ConvertUtils::compute(outputDiff, newT.get(), res);
40             outputDiff = newT.get();
41             res.extras.emplace_back(newT);
42         }
43         auto outputDes        = TensorUtils::getDescribe(kernelDiff);
44         outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
45         outputDes->regions.clear();
46         for (int ky = 0; ky < kh; ++ky) {
47             auto startSy = ky * dh - pads.second;
48             int startDy  = 0;
49             if (startSy < 0) {
50                 startDy = ((-startSy) + sh - 1) / sh;
51                 startSy = startSy + startDy * sh;
52             }
53             auto endDy = oh - 1;
54             auto endSy = endDy * sh + ky * dh - pads.second;
55             if (endSy >= ih) {
56                 endDy = endDy - (endSy - ih + sh) / sh;
57                 endSy = endDy * sh + ky * dh - pads.second;
58             }
59             if (startDy > endDy) {
60                 continue;
61             }
62             MNN_ASSERT(endDy >= 0);
63             MNN_ASSERT(startDy < ih);
64             auto dstOffsetKy = startDy * ow;
65             auto srcOffsetKy = startSy * iw;
66             for (int kx = 0; kx < kw; ++kx) {
67                 auto startSx = kx * dw - pads.first;
68                 int startDx  = 0;
69                 if (startSx < 0) {
70                     startDx = ((-startSx) + sw - 1) / sw;
71                     startSx = startSx + startDx * sw;
72                 }
73                 auto endDx = ow - 1;
74                 auto endSx = endDx * sw + kx * dw - pads.first;
75                 if (endSx >= iw) {
76                     endDx = endDx - (endSx - iw + sw) / sw;
77                     endSx = endDx * sw + kx * dw - pads.first;
78                 }
79                 if (startDy > endDy) {
80                     continue;
81                 }
82                 auto dstOffsetKx = dstOffsetKy + startDx;
83                 auto srcOffsetKx = srcOffsetKy + startSx;
84                 // Sampler
85                 std::shared_ptr<Tensor> inputTensor(new Tensor(outputDiff, Tensor::CAFFE, false));
86                 auto des        = TensorUtils::getDescribe(inputTensor.get());
87                 des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
88                 des->regions.resize(1);
89                 {
90                     Tensor::InsideDescribe::Region& region = des->regions[0];
91                     region.origin                          = input;
92                     region.size[0]                         = batch * ic;
93                     region.size[1]                         = endDy - startDy + 1;
94                     region.size[2]                         = endDx - startDx + 1;
95                     region.src.offset                      = srcOffsetKx;
96                     region.dst.offset                      = dstOffsetKx;
97                     region.src.stride[0]                   = iw * ih;
98                     region.dst.stride[0]                   = ow * oh;
99                     region.src.stride[1]                   = sh * iw;
100                     region.dst.stride[1]                   = ow;
101                     region.src.stride[2]                   = sw;
102                     region.dst.stride[2]                   = 1;
103                     res.extras.emplace_back(inputTensor);
104                 }
105 
106                 auto currentTensor = inputTensor.get();
107                 // Multi
108                 {
109                     std::shared_ptr<Tensor> newTensor(new Tensor(outputDiff, Tensor::CAFFE, false));
110                     auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, currentTensor, outputDiff,
111                                                                  newTensor.get());
112                     res.command.emplace_back(std::move(cmd));
113                     res.extras.emplace_back(newTensor);
114                     currentTensor = newTensor.get();
115                 }
116                 // Reduce - 0
117                 {
118                     std::shared_ptr<Tensor> reduceInputTensor(
119                         Tensor::createDevice<float>({batch * ic, ow * oh, 1}, Tensor::CAFFE));
120                     {
121                         auto inputDes        = TensorUtils::getDescribe(reduceInputTensor.get());
122                         inputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
123                         inputDes->regions    = {TensorUtils::makeFullSlice(currentTensor)};
124                     }
125                     std::shared_ptr<Tensor> reduceOutputTensor(
126                         Tensor::createDevice<float>({batch * ic, 1, 1}, Tensor::CAFFE));
127                     auto cmd      = GeometryComputerUtils::makeReduce(ReductionType_SUM, reduceInputTensor.get(),
128                                                                  reduceOutputTensor.get());
129                     currentTensor = reduceOutputTensor.get();
130                     res.command.emplace_back(std::move(cmd));
131                     res.extras.emplace_back(reduceInputTensor);
132                     res.extras.emplace_back(reduceOutputTensor);
133                 }
134                 // Reduce - 1
135                 {
136                     std::shared_ptr<Tensor> reduceInputTensor(
137                         Tensor::createDevice<float>({1, batch, ic}, Tensor::CAFFE));
138                     {
139                         auto inputDes        = TensorUtils::getDescribe(reduceInputTensor.get());
140                         inputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
141                         inputDes->regions    = {TensorUtils::makeFullSlice(currentTensor)};
142                     }
143                     std::shared_ptr<Tensor> reduceOutputTensor(Tensor::createDevice<float>({1, 1, ic}, Tensor::CAFFE));
144                     currentTensor = reduceOutputTensor.get();
145                     auto cmd      = GeometryComputerUtils::makeReduce(ReductionType_SUM, reduceInputTensor.get(),
146                                                                  reduceOutputTensor.get());
147                     res.command.emplace_back(std::move(cmd));
148                     res.extras.emplace_back(reduceInputTensor);
149                     res.extras.emplace_back(reduceOutputTensor);
150                 }
151                 // Set to output
152                 Tensor::InsideDescribe::Region region;
153                 region.origin        = currentTensor;
154                 region.size[0]       = 1;
155                 region.size[1]       = 1;
156                 region.size[2]       = ic;
157                 region.dst.offset    = ky * kw + kx;
158                 region.dst.stride[0] = 0;
159                 region.dst.stride[1] = 0;
160                 region.dst.stride[2] = kh * kw;
161                 region.src.offset    = 0;
162                 region.src.stride[0] = 0;
163                 region.src.stride[1] = 0;
164                 region.src.stride[2] = 1;
165                 outputDes->regions.emplace_back(std::move(region));
166             }
167         }
168         return true;
169     }
onCompute(const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,Context & context,CommandBuffer & res) const170     virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
171                            Context& context, CommandBuffer& res) const override {
172         auto common     = op->main_as_Convolution2D()->common();
173         auto input      = inputs[0];
174         auto outputDiff = inputs[1];
175         bool depthWise  = false;
176         if (inputs[0]->channel() == inputs[1]->channel() && inputs[1]->channel() == common->group()) {
177             depthWise = true;
178             return computeForDepthWise(common, input, outputDiff, outputs[0], context, res);
179         }
180         auto kw    = common->kernelX();
181         auto kh    = common->kernelY();
182         auto sw    = common->strideX();
183         auto sh    = common->strideY();
184         auto dw    = common->dilateX();
185         auto dh    = common->dilateY();
186         auto batch = outputDiff->batch();
187         auto ow    = outputDiff->width();
188         auto oh    = outputDiff->height();
189         auto oc    = outputDiff->channel();
190         auto ic    = input->channel();
191         auto iw    = input->width();
192         auto ih    = input->height();
193         auto pads  = ConvolutionCommon::convolutionPad(input, outputDiff, common);
194         MNN_ASSERT(TensorUtils::getDescribe(input)->dimensionFormat != MNN_DATA_FORMAT_NHWC);
195         MNN_ASSERT(TensorUtils::getDescribe(outputDiff)->dimensionFormat != MNN_DATA_FORMAT_NHWC);
196         Tensor* A = nullptr;
197         Tensor* B = nullptr;
198         {
199             // B: Input Im2Col, n, ic, ih, iw -> ic*kh*kw, n*oh*ow
200             std::shared_ptr<Tensor> im2Col(new Tensor);
201             GeometryConvUtils::im2Col(im2Col.get(), input, ic, kh, kw, batch, oh, ow, ih, iw, sh, sw, dh, dw, pads);
202             B = im2Col.get();
203             res.extras.emplace_back(im2Col);
204         }
205         {
206             // A: Output n, oc, oh, ow -> oc, n*oh*ow
207             std::shared_ptr<Tensor> outputTranspose(new Tensor);
208             A                                    = outputTranspose.get();
209             outputTranspose->buffer().type       = halide_type_of<float>();
210             outputTranspose->buffer().dimensions = 2;
211             outputTranspose->setLength(0, oc);
212             outputTranspose->setLength(1, batch * ow * oh);
213             auto des = TensorUtils::getDescribe(outputTranspose.get());
214             des->regions.resize(1);
215             des->memoryType   = Tensor::InsideDescribe::MEMORY_VIRTUAL;
216             auto& reg         = des->regions[0];
217             reg.origin        = outputDiff;
218             reg.size[0]       = oc;
219             reg.size[1]       = batch;
220             reg.size[2]       = ow * oh;
221             reg.src.offset    = 0;
222             reg.src.stride[0] = oh * ow;
223             reg.src.stride[1] = oh * ow * oc;
224             reg.src.stride[2] = 1;
225             reg.dst.offset    = 0;
226             reg.dst.stride[0] = oh * ow * batch;
227             reg.dst.stride[1] = oh * ow;
228             reg.dst.stride[2] = 1;
229             res.extras.emplace_back(std::move(outputTranspose));
230         }
231         {
232             // C = MatMul(B, A)
233             std::shared_ptr<Tensor> C(new Tensor);
234             C->buffer().type       = halide_type_of<float>();
235             C->buffer().dimensions = 2;
236             C->setLength(0, ic * kw * kh);
237             C->setLength(1, oc);
238             auto cmd = GeometryComputerUtils::makeMatMul(B, A, C.get(), nullptr, false, true);
239             auto kernelDiffDes        = TensorUtils::getDescribe(outputs[0]);
240             kernelDiffDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
241 
242             // Transpose
243             auto len0 = kw * kh * ic;
244             auto len1 = oc;
245             kernelDiffDes->regions.resize(1);
246             auto& desReg         = kernelDiffDes->regions[0];
247             desReg.size[0]       = 1;
248             desReg.size[1]       = len1;
249             desReg.size[2]       = len0;
250             desReg.dst.offset    = 0;
251             desReg.dst.stride[0] = 0;
252             desReg.dst.stride[1] = len0;
253             desReg.dst.stride[2] = 1;
254             desReg.src.offset    = 0;
255             desReg.src.stride[0] = 0;
256             desReg.src.stride[1] = 1;
257             desReg.src.stride[2] = len1;
258             desReg.origin        = C.get();
259             res.extras.emplace_back(std::move(C));
260             res.command.emplace_back(std::move(cmd));
261         }
262         return true;
263     }
264 };
265 
_create()266 static void _create() {
267     std::shared_ptr<GeometryComputer> comp(new GeometryConv2DBackPropFilter);
268     GeometryComputer::registerGeometryComputer(comp, {OpType_Conv2DBackPropFilter});
269 }
270 
271 REGISTER_GEOMETRY(GeometryConv2DBackPropFilter, _create);
272 
273 } // namespace MNN
274