1 //
2 //  GeometryPoolGrad.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2020/06/04.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "ConvertUtils.hpp"
10 #include "geometry/GeometryComputer.hpp"
11 #include "geometry/GeometryComputerUtils.hpp"
12 #include "core/Macro.h"
13 #include "core/OpCommonUtils.hpp"
14 #define MNN_OPEN_TIME_TRACE
15 #include <MNN/AutoTime.hpp>
16 namespace MNN {
// Geometry decomposition for OpType_PoolGrad: rewrites the backward pass of
// max- and average-pooling into elementary ops (virtual-tensor region copies,
// binary ops, cast, reduction, eltwise) that every backend can execute.
class GeometryPoolGrad : public GeometryComputer {
public:
    // PoolGrad PoolType_MAXPOOL
    //
    // Strategy: for each kernel offset (ky, kx),
    //   1. gather the strided input window into a virtual tensor (region copy),
    //   2. GREATER_EQUAL against the pooled output -> int 0/1 "arg-max" mask,
    //   3. cast the mask to float and multiply with the incoming gradient,
    //   4. scatter the masked gradient back to input coordinates (region copy),
    // then sum all kernel-offset contributions with an Eltwise SUM.
    //
    // inputs[0]: original pooling input  (batch, channel, ih, iw)
    // inputs[1]: original pooling output (batch, channel, oh, ow)
    // inputs[2]: gradient w.r.t. the pooling output (batch, channel, oh, ow)
    // outputs[0]: gradient w.r.t. the pooling input
    // Returns false if the pad type is not SAME/VALID.
    bool onComputeMaxPool(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                          Context& context, CommandBuffer& res) const {
        auto origin       = inputs[0];
        auto originOutput = inputs[1];
        auto inputDiff    = inputs[2];

        auto ow = inputDiff->width();
        auto oh = inputDiff->height();
        auto iw = origin->width();
        auto ih = origin->height();
        auto oc = inputDiff->channel();
        auto ob = inputDiff->batch();

        auto parameter = op->main_as_Pool();
        auto stride_w  = parameter->strideX();
        auto stride_h  = parameter->strideY();
        auto kernel_w  = parameter->kernelX();
        auto kernel_h  = parameter->kernelY();
        auto isGlobal  = parameter->isGlobal();
        auto pad_w     = parameter->padX();
        auto pad_h     = parameter->padY();

        // edit const if global: global pooling covers the whole spatial plane
        if (isGlobal) {
            kernel_w = iw;
            kernel_h = ih;
            stride_w = iw;
            stride_h = ih;
            pad_w    = 0;
            pad_h    = 0;
        }

        // Recompute padding from the pad type; other pad types (e.g. CAFFE,
        // which would use padX/padY directly) are rejected here.
        if (parameter->padType() == PoolPadType_SAME) {
            int pad_w_total = (ow - 1) * stride_w + kernel_w - iw;
            int pad_h_total = (oh - 1) * stride_h + kernel_h - ih;
            pad_w           = pad_w_total > 0 ? pad_w_total / 2 : 0;
            pad_h           = pad_h_total > 0 ? pad_h_total / 2 : 0;
        } else if (parameter->padType() == PoolPadType_VALID) {
            pad_w = 0;
            pad_h = 0;
        } else {
            MNN_PRINT("Pool padtype not supported!\n");
            return false;
        }

        // One intermediate tensor per kernel offset; index = ky * kernel_w + kx.
        // originSplit : strided gather of the input window          (float)
        // originGEqual: float mask (input >= pooled output)         (float)
        // originDiff  : mask * incoming gradient                    (float)
        // inpDiffAdd  : gradient scattered back to input coordinates(float)
        std::vector<std::shared_ptr<Tensor>> originSplit;
        originSplit.resize(kernel_h * kernel_w);

        std::vector<std::shared_ptr<Tensor>> originGEqual;
        originGEqual.resize(kernel_h * kernel_w);

        std::vector<std::shared_ptr<Tensor>> originDiff;
        originDiff.resize(kernel_h * kernel_w);

        std::vector<std::shared_ptr<Tensor>> inpDiffAdd;
        inpDiffAdd.resize(kernel_h * kernel_w);

        for (int ky = 0; ky < kernel_h; ky++) {
            // Clip the vertical range so source rows (dy * stride_h + ky - pad_h)
            // stay inside [0, ih). startDy/endDy index output rows, startSy/endSy
            // the matching input rows.
            auto startSy = ky - pad_h;
            int startDy  = 0;
            if (startSy < 0) {
                startDy = ((-startSy) + stride_h - 1) / stride_h;
                startSy = startSy + startDy * stride_h;
            }
            auto endDy = oh - 1;
            auto endSy = endDy * stride_h + ky - pad_h;
            if (endSy >= ih) {
                endDy = endDy - (endSy - ih + stride_h) / stride_h;
                endSy = endDy * stride_h + ky - pad_h;
            }
            if (startDy > endDy) {
                // This kernel row never touches a valid input row.
                // NOTE(review): skipping leaves inpDiffAdd[index] unset for the
                // whole row; the final Eltwise below still reads those slots —
                // confirm this combination of pad/stride cannot occur in practice.
                continue;
            }
            MNN_ASSERT(endDy >= 0);
            MNN_ASSERT(startDy < oh);

            for (int kx = 0; kx < kernel_w; kx++) {
                // Same clipping as above, along the horizontal axis.
                auto startSx = kx - pad_w;
                int startDx  = 0;
                if (startSx < 0) {
                    startDx = ((-startSx) + stride_w - 1) / stride_w;
                    startSx = startSx + startDx * stride_w;
                }
                auto endDx = ow - 1;
                auto endSx = endDx * stride_w + kx - pad_w;
                if (endSx >= iw) {
                    endDx = endDx - (endSx - iw + stride_w) / stride_w;
                    endSx = endDx * stride_w + kx - pad_w;
                }
                if (startDx > endDx) {
                    continue;
                }
                MNN_ASSERT(endDx >= 0);
                MNN_ASSERT(startDx < ow);

                // A: Input feature
                // Step 1 — virtual gather: originSplit[index][b,c,dy,dx] reads
                // origin[b,c, dy*stride_h + ky - pad_h, dx*stride_w + kx - pad_w].
                int index = ky * kernel_w + kx;
                originSplit[index].reset(new Tensor);
                originSplit[index]->buffer().type       = halide_type_of<float>();
                originSplit[index]->buffer().dimensions = 4;
                originSplit[index]->setLength(0, ob);
                originSplit[index]->setLength(1, oc);
                originSplit[index]->setLength(2, oh);
                originSplit[index]->setLength(3, ow);
                auto des             = TensorUtils::getDescribe(originSplit[index].get());
                des->memoryType      = Tensor::InsideDescribe::MEMORY_VIRTUAL;
                des->dimensionFormat = MNN_DATA_FORMAT_NC4HW4;

                Tensor::InsideDescribe::Region region;
                region.origin        = origin;
                region.size[0]       = ob * oc;
                region.size[1]       = endDy - startDy + 1;
                region.size[2]       = endDx - startDx + 1;
                region.dst.offset    = startDy * ow + startDx;
                region.src.offset    = startSy * iw + startSx;
                region.dst.stride[0] = ow * oh;
                region.src.stride[0] = iw * ih;
                region.dst.stride[1] = ow;
                region.src.stride[1] = iw * stride_h;
                region.dst.stride[2] = 1;
                region.src.stride[2] = stride_w;
                des->regions.emplace_back(std::move(region));

                // greater equal
                // Step 2 — mask: 1 where the gathered input equals/exceeds the
                // pooled output (i.e. the location that "won" the max).
                std::shared_ptr<Tensor> originGEqualInt(new Tensor);
                originGEqualInt->buffer().type       = halide_type_of<int32_t>();
                originGEqualInt->buffer().dimensions = 4;
                originGEqualInt->setLength(0, ob);
                originGEqualInt->setLength(1, oc);
                originGEqualInt->setLength(2, oh);
                originGEqualInt->setLength(3, ow);
                des = TensorUtils::getDescribe(originGEqualInt.get());
                // des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
                des->dimensionFormat = MNN_DATA_FORMAT_NC4HW4;

                auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_GREATER_EQUAL, originSplit[index].get(),
                                                             originOutput, originGEqualInt.get());

                // cast int to float
                originGEqual[index].reset(new Tensor);
                originGEqual[index]->buffer().type       = halide_type_of<float>();
                originGEqual[index]->buffer().dimensions = 4;
                originGEqual[index]->setLength(0, ob);
                originGEqual[index]->setLength(1, oc);
                originGEqual[index]->setLength(2, oh);
                originGEqual[index]->setLength(3, ow);
                des = TensorUtils::getDescribe(originGEqual[index].get());
                // des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
                des->dimensionFormat = MNN_DATA_FORMAT_NC4HW4;

                std::unique_ptr<OpT> cast2float(new OpT);
                cast2float->type                     = OpType_Cast;
                cast2float->main.type                = OpParameter_CastParam;
                cast2float->main.value               = new CastParamT;
                cast2float->main.AsCastParam()->dstT = DataType_DT_FLOAT;

                // Serialize the Cast op into the command's own buffer; cmd1.op
                // points into cmd1.buffer, so the Command must keep the buffer alive.
                flatbuffers::FlatBufferBuilder builder1;
                auto lastOffset1 = Op::Pack(builder1, cast2float.get());
                builder1.Finish(lastOffset1);
                Command cmd1;
                cmd1.buffer.resize(builder1.GetSize());
                ::memcpy(cmd1.buffer.data(), builder1.GetBufferPointer(), cmd1.buffer.size());
                cmd1.inputs.resize(1);
                cmd1.inputs[0] = originGEqualInt.get();

                cmd1.outputs = {originGEqual[index].get()};
                cmd1.op      = flatbuffers::GetMutableRoot<Op>(cmd1.buffer.data());

                // mul inputDiff
                // Step 3 — route the incoming gradient through the mask.
                originDiff[index].reset(new Tensor);
                originDiff[index]->buffer().type       = halide_type_of<float>();
                originDiff[index]->buffer().dimensions = 4;
                originDiff[index]->setLength(0, ob);
                originDiff[index]->setLength(1, oc);
                originDiff[index]->setLength(2, oh);
                originDiff[index]->setLength(3, ow);
                des = TensorUtils::getDescribe(originDiff[index].get());
                // des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
                des->dimensionFormat = MNN_DATA_FORMAT_NC4HW4;

                auto cmd2 = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, inputDiff,
                                                              originGEqual[index].get(), originDiff[index].get());

                // expand tensor
                // Step 4 — virtual scatter: spread the masked gradient back onto
                // an (ih, iw) plane, inverse of the gather in step 1.
                inpDiffAdd[index].reset(new Tensor);
                inpDiffAdd[index]->buffer().type       = halide_type_of<float>();
                inpDiffAdd[index]->buffer().dimensions = 4;
                inpDiffAdd[index]->setLength(0, ob);
                inpDiffAdd[index]->setLength(1, oc);
                inpDiffAdd[index]->setLength(2, ih);
                inpDiffAdd[index]->setLength(3, iw);
                des                  = TensorUtils::getDescribe(inpDiffAdd[index].get());
                des->memoryType      = Tensor::InsideDescribe::MEMORY_VIRTUAL;
                des->dimensionFormat = MNN_DATA_FORMAT_NC4HW4;

                // Tensor::InsideDescribe::Region region;
                // `region` was moved into des->regions above; a moved-from Region
                // is reused here with every relevant field reassigned.
                // NOTE(review): unlike the AVEPOOL branch (which uses the clipped
                // startDy/startSy values and clipped sizes), this scatter uses the
                // full oh/ow extent and offset ky * iw + kx; for pad > 0 these
                // differ — confirm the padded-maxpool case is handled as intended.
                region.origin        = originDiff[index].get();
                region.size[0]       = ob * oc;
                region.size[1]       = oh;
                region.size[2]       = ow;
                region.src.offset    = 0;
                region.dst.offset    = ky * iw + kx;
                region.src.stride[0] = ow * oh;
                region.dst.stride[0] = iw * ih;
                region.src.stride[1] = ow;
                region.dst.stride[1] = iw * stride_h;
                region.src.stride[2] = 1;
                region.dst.stride[2] = stride_w;
                des->regions.emplace_back(std::move(region));

                // Keep all intermediates alive in the command buffer and enqueue
                // the three ops (compare, cast, multiply) for this kernel offset.
                res.extras.emplace_back(inpDiffAdd[index]);
                res.extras.emplace_back(originSplit[index]);
                res.extras.emplace_back(originGEqual[index]);
                res.extras.emplace_back(originGEqualInt);
                res.extras.emplace_back(originDiff[index]);
                res.command.emplace_back(std::move(cmd));
                res.command.emplace_back(std::move(cmd1));
                res.command.emplace_back(std::move(cmd2));
            }
        }

        // eltwise
        // Final step — sum the per-kernel-offset scattered gradients directly
        // into outputs[0]. (tmpOutput is created but never used in this branch.)
        std::shared_ptr<Tensor> tmpOutput(new Tensor);
        {
            std::unique_ptr<OpT> eltWise(new OpT);
            eltWise->type                   = OpType_Eltwise;
            eltWise->main.type              = OpParameter_Eltwise;
            eltWise->main.value             = new EltwiseT;
            eltWise->main.AsEltwise()->type = EltwiseType_SUM;
            // eltWise->main.AsEltwise()->coeff() = nullptr;
            flatbuffers::FlatBufferBuilder builder;
            auto lastOffset = Op::Pack(builder, eltWise.get());
            builder.Finish(lastOffset);
            Command cmd;
            cmd.buffer.resize(builder.GetSize());
            ::memcpy(cmd.buffer.data(), builder.GetBufferPointer(), cmd.buffer.size());
            cmd.inputs.resize(kernel_w * kernel_h);
            for (int i = 0; i < kernel_w * kernel_h; i++) {
                cmd.inputs[i] = inpDiffAdd[i].get();
            }
            cmd.outputs = outputs;
            cmd.op      = flatbuffers::GetMutableRoot<Op>(cmd.buffer.data());

            res.command.emplace_back(std::move(cmd));
        }

        return true;
    }

    // PoolGrad PoolType_AVEPOOL
    //
    // Entry point. Dispatches MAXPOOL to onComputeMaxPool, rejects unknown pool
    // types, and handles AVEPOOL inline: stack kernel_h * kernel_w shifted
    // copies of the output gradient into a 5-D virtual tensor, then reduce with
    // MEAN over the stack axis (which performs the divide-by-kernel-size that
    // average-pooling backward requires).
    virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                           Context& context, CommandBuffer& res) const override {
        auto parameter = op->main_as_Pool();
        if (parameter->type() == PoolType_MAXPOOL) {
            return onComputeMaxPool(op, inputs, outputs, context, res);
        } else if (parameter->type() != PoolType_AVEPOOL) {
            MNN_PRINT("Pool type not supported!\n");
            return false;
        }

        // AVEPOOL path.
        // inputs[0]: original pooling input, inputs[2]: gradient w.r.t. output.
        auto origin     = inputs[0];
        auto inputDiff  = inputs[2];
        auto outputDiff = outputs[0];

        auto ow = inputDiff->width();
        auto oh = inputDiff->height();
        auto iw = origin->width();
        auto ih = origin->height();
        auto oc = inputDiff->channel();
        auto ob = inputDiff->batch();

        auto stride_w = parameter->strideX();
        auto stride_h = parameter->strideY();
        auto kernel_w = parameter->kernelX();
        auto kernel_h = parameter->kernelY();
        auto isGlobal = parameter->isGlobal();
        auto pad_w    = parameter->padX();
        auto pad_h    = parameter->padY();

        // edit const if global (same normalization as the MAXPOOL branch)
        if (isGlobal) {
            kernel_w = iw;
            kernel_h = ih;
            stride_w = iw;
            stride_h = ih;
            pad_w    = 0;
            pad_h    = 0;
        }

        if (parameter->padType() == PoolPadType_SAME) {
            int pad_w_total = (ow - 1) * stride_w + kernel_w - iw;
            int pad_h_total = (oh - 1) * stride_h + kernel_h - ih;
            pad_w           = pad_w_total > 0 ? pad_w_total / 2 : 0;
            pad_h           = pad_h_total > 0 ? pad_h_total / 2 : 0;
        } else if (parameter->padType() == PoolPadType_VALID) {
            pad_w = 0;
            pad_h = 0;
        } else {
            MNN_PRINT("Pool padtype not supported!\n");
            return false;
        }

        // 5-D virtual tensor (kernel_h * kernel_w, ob, oc, ih, iw): slice
        // `index` holds the output gradient scattered to input coordinates for
        // kernel offset (ky, kx); one region per (ky, kx) below.
        std::shared_ptr<Tensor> inpDifTrans;

        inpDifTrans.reset(new Tensor);
        inpDifTrans->buffer().type       = halide_type_of<float>();
        inpDifTrans->buffer().dimensions = 5;
        inpDifTrans->setLength(0, kernel_h * kernel_w);
        inpDifTrans->setLength(1, ob);
        inpDifTrans->setLength(2, oc);
        inpDifTrans->setLength(3, ih);
        inpDifTrans->setLength(4, iw);
        auto des             = TensorUtils::getDescribe(inpDifTrans.get());
        des->memoryType      = Tensor::InsideDescribe::MEMORY_VIRTUAL;
        des->dimensionFormat = MNN_DATA_FORMAT_NCHW;
        des->regions.clear();
        // des->regions.reserve(kernel_h*kernel_w);

        for (int ky = 0; ky < kernel_h; ky++) {
            // Same vertical clipping arithmetic as in onComputeMaxPool.
            auto startSy = ky - pad_h;
            int startDy  = 0;
            if (startSy < 0) {
                startDy = ((-startSy) + stride_h - 1) / stride_h;
                startSy = startSy + startDy * stride_h;
            }
            auto endDy = oh - 1;
            auto endSy = endDy * stride_h + ky - pad_h;
            if (endSy >= ih) {
                endDy = endDy - (endSy - ih + stride_h) / stride_h;
                endSy = endDy * stride_h + ky - pad_h;
            }
            if (startDy > endDy) {
                continue;
            }
            MNN_ASSERT(endDy >= 0);
            MNN_ASSERT(startDy < oh);

            for (int kx = 0; kx < kernel_w; kx++) {
                // Horizontal clipping, mirroring the vertical case.
                auto startSx = kx - pad_w;
                int startDx  = 0;
                if (startSx < 0) {
                    startDx = ((-startSx) + stride_w - 1) / stride_w;
                    startSx = startSx + startDx * stride_w;
                }
                auto endDx = ow - 1;
                auto endSx = endDx * stride_w + kx - pad_w;
                if (endSx >= iw) {
                    endDx = endDx - (endSx - iw + stride_w) / stride_w;
                    endSx = endDx * stride_w + kx - pad_w;
                }
                if (startDx > endDx) {
                    continue;
                }
                MNN_ASSERT(endDx >= 0);
                MNN_ASSERT(startDx < ow);

                // A: Input feature
                // Scatter the (clipped) output-gradient window into slice
                // `index` of inpDifTrans at the matching input coordinates.
                int index = ky * kernel_w + kx;

                Tensor::InsideDescribe::Region region;
                region.origin        = inputDiff;
                region.size[0]       = ob * oc;
                region.size[1]       = endDy - startDy + 1;
                region.size[2]       = endDx - startDx + 1;
                region.src.offset    = startDy * ow + startDx;
                region.dst.offset    = index * ob * oc * ih * iw + startSy * iw + startSx;
                region.src.stride[0] = ow * oh;
                region.dst.stride[0] = iw * ih;
                region.src.stride[1] = ow;
                region.dst.stride[1] = iw * stride_h;
                region.src.stride[2] = 1;
                region.dst.stride[2] = stride_w;
                des->regions.emplace_back(std::move(region));
            }
        }
        res.extras.emplace_back(inpDifTrans);

        // reduction mean
        // MEAN over axis 0 (the kernel_h * kernel_w stack) sums the shifted
        // copies and divides by the kernel size in one op; the result is then
        // exposed as outputs[0] through an identity region.
        std::shared_ptr<Tensor> tmpOutput;
        {
            tmpOutput.reset(new Tensor);
            tmpOutput->buffer().type       = halide_type_of<float>();
            tmpOutput->buffer().dimensions = 4;
            tmpOutput->setLength(0, ob);
            tmpOutput->setLength(1, oc);
            tmpOutput->setLength(2, ih);
            tmpOutput->setLength(3, iw);
            auto des = TensorUtils::getDescribe(tmpOutput.get());
            // des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
            des->dimensionFormat = MNN_DATA_FORMAT_NCHW;

            std::unique_ptr<OpT> mean(new OpT);
            mean->type                               = OpType_Reduction;
            mean->main.type                          = OpParameter_ReductionParam;
            mean->main.value                         = new ReductionParamT;
            mean->main.AsReductionParam()->dim       = {0};
            mean->main.AsReductionParam()->keepDims  = false;
            mean->main.AsReductionParam()->operation = ReductionType_MEAN;

            // cmd.op aliases cmd.buffer; the Command owns the serialized op.
            flatbuffers::FlatBufferBuilder builder;
            auto lastOffset = Op::Pack(builder, mean.get());
            builder.Finish(lastOffset);
            Command cmd;
            cmd.buffer.resize(builder.GetSize());
            ::memcpy(cmd.buffer.data(), builder.GetBufferPointer(), cmd.buffer.size());

            cmd.inputs = {inpDifTrans.get()};

            cmd.outputs = {tmpOutput.get()};
            cmd.op      = flatbuffers::GetMutableRoot<Op>(cmd.buffer.data());

            // outputs[0] becomes a virtual identity view of tmpOutput
            // (same sizes, same strides, zero offsets on both sides).
            auto outputDes        = TensorUtils::getDescribe(outputs[0]);
            outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
            Tensor::InsideDescribe::Region desReg;
            desReg.size[0]       = ob * oc;
            desReg.size[1]       = ih;
            desReg.size[2]       = iw;
            desReg.dst.offset    = 0;
            desReg.dst.stride[0] = ih * iw;
            desReg.dst.stride[1] = iw;
            desReg.dst.stride[2] = 1;
            desReg.src.offset    = 0;
            desReg.src.stride[0] = ih * iw;
            desReg.src.stride[1] = iw;
            desReg.src.stride[2] = 1;
            desReg.origin        = tmpOutput.get();
            outputDes->regions.emplace_back(std::move(desReg));

            res.extras.emplace_back(std::move(tmpOutput));
            res.command.emplace_back(std::move(cmd));
        }

        return true;
    }
};
455 
_create()456 static void _create() {
457     std::shared_ptr<GeometryComputer> comp(new GeometryPoolGrad);
458     GeometryComputer::registerGeometryComputer(comp, {OpType_PoolGrad});
459 }
460 
461 REGISTER_GEOMETRY(GeometryPoolGrad, _create);
462 
463 } // namespace MNN
464