//
//  GeometryPoolGrad.cpp
//  MNN
//
//  Created by MNN on 2020/06/04.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "ConvertUtils.hpp"
#include "geometry/GeometryComputer.hpp"
#include "geometry/GeometryComputerUtils.hpp"
#include "core/Macro.h"
#include "core/OpCommonUtils.hpp"
#define MNN_OPEN_TIME_TRACE
#include <MNN/AutoTime.hpp>
namespace MNN {
class GeometryPoolGrad : public GeometryComputer {
public:
    // PoolGrad PoolType_MAXPOOL
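    // Strategy: for each kernel offset (ky, kx), gather the input values that this offset
    // contributes to every pooling window, compare them with the pooled output to form a
    // 0/1 mask (a value propagates gradient only where it reached the maximum), multiply
    // the mask by the incoming gradient, scatter the result back to the input resolution,
    // and finally sum the kernel_h * kernel_w partial gradients.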
    bool onComputeMaxPool(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                          Context& context, CommandBuffer& res) const {
        auto origin       = inputs[0];
        auto originOutput = inputs[1];
        auto inputDiff    = inputs[2];

        auto ow = inputDiff->width();
        auto oh = inputDiff->height();
        auto iw = origin->width();
        auto ih = origin->height();
        auto oc = inputDiff->channel();
        auto ob = inputDiff->batch();

        auto parameter = op->main_as_Pool();
        auto stride_w  = parameter->strideX();
        auto stride_h  = parameter->strideY();
        auto kernel_w  = parameter->kernelX();
        auto kernel_h  = parameter->kernelY();
        auto isGlobal  = parameter->isGlobal();
        auto pad_w     = parameter->padX();
        auto pad_h     = parameter->padY();

        // For global pooling the kernel covers the whole plane and there is no padding
        if (isGlobal) {
            kernel_w = iw;
            kernel_h = ih;
            stride_w = iw;
            stride_h = ih;
            pad_w    = 0;
            pad_h    = 0;
        }

        if (parameter->padType() == PoolPadType_SAME) {
            int pad_w_total = (ow - 1) * stride_w + kernel_w - iw;
            int pad_h_total = (oh - 1) * stride_h + kernel_h - ih;
            pad_w           = pad_w_total > 0 ? pad_w_total / 2 : 0;
            pad_h           = pad_h_total > 0 ? pad_h_total / 2 : 0;
        } else if (parameter->padType() == PoolPadType_VALID) {
            pad_w = 0;
            pad_h = 0;
        } else {
            MNN_PRINT("Pool padtype not supported!\n");
            return false;
        }

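        // Per-kernel-offset intermediates:
        //   originSplit  - input values gathered into the pooled (oh x ow) layout
        //   originGEqual - float mask marking where the gathered value reached the pooled maximum
        //   originDiff   - incoming gradient masked by originGEqual
        //   inpDiffAdd   - masked gradient scattered back to the input (ih x iw) layout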
        std::vector<std::shared_ptr<Tensor>> originSplit;
        originSplit.resize(kernel_h * kernel_w);

        std::vector<std::shared_ptr<Tensor>> originGEqual;
        originGEqual.resize(kernel_h * kernel_w);

        std::vector<std::shared_ptr<Tensor>> originDiff;
        originDiff.resize(kernel_h * kernel_w);

        std::vector<std::shared_ptr<Tensor>> inpDiffAdd;
        inpDiffAdd.resize(kernel_h * kernel_w);

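        // For kernel offset (ky, kx), output position (dy, dx) samples input position
        // (dy * stride_h + ky - pad_h, dx * stride_w + kx - pad_w). startDy/endDy and
        // startDx/endDx clamp the output range so the sampled input stays inside [0, ih) x [0, iw).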
        for (int ky = 0; ky < kernel_h; ky++) {
            auto startSy = ky - pad_h;
            int startDy  = 0;
            if (startSy < 0) {
                startDy = ((-startSy) + stride_h - 1) / stride_h;
                startSy = startSy + startDy * stride_h;
            }
            auto endDy = oh - 1;
            auto endSy = endDy * stride_h + ky - pad_h;
            if (endSy >= ih) {
                endDy = endDy - (endSy - ih + stride_h) / stride_h;
                endSy = endDy * stride_h + ky - pad_h;
            }
            if (startDy > endDy) {
                continue;
            }
            MNN_ASSERT(endDy >= 0);
            MNN_ASSERT(startDy < oh);

            for (int kx = 0; kx < kernel_w; kx++) {
                auto startSx = kx - pad_w;
                int startDx  = 0;
                if (startSx < 0) {
                    startDx = ((-startSx) + stride_w - 1) / stride_w;
                    startSx = startSx + startDx * stride_w;
                }
                auto endDx = ow - 1;
                auto endSx = endDx * stride_w + kx - pad_w;
                if (endSx >= iw) {
                    endDx = endDx - (endSx - iw + stride_w) / stride_w;
                    endSx = endDx * stride_w + kx - pad_w;
                }
                if (startDx > endDx) {
                    continue;
                }
                MNN_ASSERT(endDx >= 0);
                MNN_ASSERT(startDx < ow);

                // Gather the input feature values visited by kernel offset (ky, kx)
                int index = ky * kernel_w + kx;
                originSplit[index].reset(new Tensor);
                originSplit[index]->buffer().type       = halide_type_of<float>();
                originSplit[index]->buffer().dimensions = 4;
                originSplit[index]->setLength(0, ob);
                originSplit[index]->setLength(1, oc);
                originSplit[index]->setLength(2, oh);
                originSplit[index]->setLength(3, ow);
                auto des             = TensorUtils::getDescribe(originSplit[index].get());
                des->memoryType      = Tensor::InsideDescribe::MEMORY_VIRTUAL;
                des->dimensionFormat = MNN_DATA_FORMAT_NC4HW4;

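                // Gather region: pooled position (dy, dx) reads the input at
                // (dy * stride_h + ky - pad_h, dx * stride_w + kx - pad_w); the strides below
                // express exactly that mapping for the [startDy, endDy] x [startDx, endDx] range.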
                Tensor::InsideDescribe::Region region;
                region.origin        = origin;
                region.size[0]       = ob * oc;
                region.size[1]       = endDy - startDy + 1;
                region.size[2]       = endDx - startDx + 1;
                region.dst.offset    = startDy * ow + startDx;
                region.src.offset    = startSy * iw + startSx;
                region.dst.stride[0] = ow * oh;
                region.src.stride[0] = iw * ih;
                region.dst.stride[1] = ow;
                region.src.stride[1] = iw * stride_h;
                region.dst.stride[2] = 1;
                region.src.stride[2] = stride_w;
                des->regions.emplace_back(std::move(region));

                // Compare the gathered inputs with the pooled output (>=) to build the max mask
                std::shared_ptr<Tensor> originGEqualInt(new Tensor);
                originGEqualInt->buffer().type       = halide_type_of<int32_t>();
                originGEqualInt->buffer().dimensions = 4;
                originGEqualInt->setLength(0, ob);
                originGEqualInt->setLength(1, oc);
                originGEqualInt->setLength(2, oh);
                originGEqualInt->setLength(3, ow);
                des = TensorUtils::getDescribe(originGEqualInt.get());
                // des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
                des->dimensionFormat = MNN_DATA_FORMAT_NC4HW4;

                auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_GREATER_EQUAL, originSplit[index].get(),
                                                             originOutput, originGEqualInt.get());

                // Cast the int32 mask to float so it can be multiplied with the gradient
                originGEqual[index].reset(new Tensor);
                originGEqual[index]->buffer().type       = halide_type_of<float>();
                originGEqual[index]->buffer().dimensions = 4;
                originGEqual[index]->setLength(0, ob);
                originGEqual[index]->setLength(1, oc);
                originGEqual[index]->setLength(2, oh);
                originGEqual[index]->setLength(3, ow);
                des = TensorUtils::getDescribe(originGEqual[index].get());
                // des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
                des->dimensionFormat = MNN_DATA_FORMAT_NC4HW4;

                std::unique_ptr<OpT> cast2float(new OpT);
                cast2float->type                     = OpType_Cast;
                cast2float->main.type                = OpParameter_CastParam;
                cast2float->main.value               = new CastParamT;
                cast2float->main.AsCastParam()->dstT = DataType_DT_FLOAT;

                flatbuffers::FlatBufferBuilder builder1;
                auto lastOffset1 = Op::Pack(builder1, cast2float.get());
                builder1.Finish(lastOffset1);
                Command cmd1;
                cmd1.buffer.resize(builder1.GetSize());
                ::memcpy(cmd1.buffer.data(), builder1.GetBufferPointer(), cmd1.buffer.size());
                cmd1.inputs.resize(1);
                cmd1.inputs[0] = originGEqualInt.get();

                cmd1.outputs = {originGEqual[index].get()};
                cmd1.op      = flatbuffers::GetMutableRoot<Op>(cmd1.buffer.data());

                // Multiply the incoming gradient by the mask
                originDiff[index].reset(new Tensor);
                originDiff[index]->buffer().type       = halide_type_of<float>();
                originDiff[index]->buffer().dimensions = 4;
                originDiff[index]->setLength(0, ob);
                originDiff[index]->setLength(1, oc);
                originDiff[index]->setLength(2, oh);
                originDiff[index]->setLength(3, ow);
                des = TensorUtils::getDescribe(originDiff[index].get());
                // des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
                des->dimensionFormat = MNN_DATA_FORMAT_NC4HW4;

                auto cmd2 = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, inputDiff,
                                                              originGEqual[index].get(), originDiff[index].get());

                // Scatter the masked gradient back to an input-sized (ih x iw) tensor
                inpDiffAdd[index].reset(new Tensor);
                inpDiffAdd[index]->buffer().type       = halide_type_of<float>();
                inpDiffAdd[index]->buffer().dimensions = 4;
                inpDiffAdd[index]->setLength(0, ob);
                inpDiffAdd[index]->setLength(1, oc);
                inpDiffAdd[index]->setLength(2, ih);
                inpDiffAdd[index]->setLength(3, iw);
                des                  = TensorUtils::getDescribe(inpDiffAdd[index].get());
                des->memoryType      = Tensor::InsideDescribe::MEMORY_VIRTUAL;
                des->dimensionFormat = MNN_DATA_FORMAT_NC4HW4;

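                // Scatter region (reverse of the gather above): pooled position (dy, dx) writes its
                // masked gradient to input position (dy * stride_h + ky, dx * stride_w + kx).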
                // Tensor::InsideDescribe::Region region;
                region.origin        = originDiff[index].get();
                region.size[0]       = ob * oc;
                region.size[1]       = oh;
                region.size[2]       = ow;
                region.src.offset    = 0;
                region.dst.offset    = ky * iw + kx;
                region.src.stride[0] = ow * oh;
                region.dst.stride[0] = iw * ih;
                region.src.stride[1] = ow;
                region.dst.stride[1] = iw * stride_h;
                region.src.stride[2] = 1;
                region.dst.stride[2] = stride_w;
                des->regions.emplace_back(std::move(region));

                res.extras.emplace_back(inpDiffAdd[index]);
                res.extras.emplace_back(originSplit[index]);
                res.extras.emplace_back(originGEqual[index]);
                res.extras.emplace_back(originGEqualInt);
                res.extras.emplace_back(originDiff[index]);
                res.command.emplace_back(std::move(cmd));
                res.command.emplace_back(std::move(cmd1));
                res.command.emplace_back(std::move(cmd2));
            }
        }

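        // Final gradient w.r.t. the input: gradInput = sum over all kernel offsets of inpDiffAdd[ky * kernel_w + kx]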
        // Sum the per-offset partial gradients with an elementwise SUM
        std::shared_ptr<Tensor> tmpOutput(new Tensor);
        {
            std::unique_ptr<OpT> eltWise(new OpT);
            eltWise->type                   = OpType_Eltwise;
            eltWise->main.type              = OpParameter_Eltwise;
            eltWise->main.value             = new EltwiseT;
            eltWise->main.AsEltwise()->type = EltwiseType_SUM;
            // eltWise->main.AsEltwise()->coeff() = nullptr;
            flatbuffers::FlatBufferBuilder builder;
            auto lastOffset = Op::Pack(builder, eltWise.get());
            builder.Finish(lastOffset);
            Command cmd;
            cmd.buffer.resize(builder.GetSize());
            ::memcpy(cmd.buffer.data(), builder.GetBufferPointer(), cmd.buffer.size());
            cmd.inputs.resize(kernel_w * kernel_h);
            for (int i = 0; i < kernel_w * kernel_h; i++) {
                cmd.inputs[i] = inpDiffAdd[i].get();
            }
            cmd.outputs = outputs;
            cmd.op      = flatbuffers::GetMutableRoot<Op>(cmd.buffer.data());

            res.command.emplace_back(std::move(cmd));
        }

        return true;
    }

    // PoolGrad PoolType_AVEPOOL
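    // Strategy: every output gradient value is shared equally by the kernel_h * kernel_w input
    // positions of its window. The incoming gradient is scattered into one input-sized slice per
    // kernel offset, and the slices are then averaged (MEAN along the leading dimension), which
    // divides the accumulated gradient by the window size.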
    virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                           Context& context, CommandBuffer& res) const override {
        auto parameter = op->main_as_Pool();
        if (parameter->type() == PoolType_MAXPOOL) {
            return onComputeMaxPool(op, inputs, outputs, context, res);
        } else if (parameter->type() != PoolType_AVEPOOL) {
            MNN_PRINT("Pool type not supported!\n");
            return false;
        }

        auto origin     = inputs[0];
        auto inputDiff  = inputs[2];
        auto outputDiff = outputs[0];

        auto ow = inputDiff->width();
        auto oh = inputDiff->height();
        auto iw = origin->width();
        auto ih = origin->height();
        auto oc = inputDiff->channel();
        auto ob = inputDiff->batch();

        auto stride_w = parameter->strideX();
        auto stride_h = parameter->strideY();
        auto kernel_w = parameter->kernelX();
        auto kernel_h = parameter->kernelY();
        auto isGlobal = parameter->isGlobal();
        auto pad_w    = parameter->padX();
        auto pad_h    = parameter->padY();

        // For global pooling the kernel covers the whole plane and there is no padding
        if (isGlobal) {
            kernel_w = iw;
            kernel_h = ih;
            stride_w = iw;
            stride_h = ih;
            pad_w    = 0;
            pad_h    = 0;
        }

        if (parameter->padType() == PoolPadType_SAME) {
            int pad_w_total = (ow - 1) * stride_w + kernel_w - iw;
            int pad_h_total = (oh - 1) * stride_h + kernel_h - ih;
            pad_w           = pad_w_total > 0 ? pad_w_total / 2 : 0;
            pad_h           = pad_h_total > 0 ? pad_h_total / 2 : 0;
        } else if (parameter->padType() == PoolPadType_VALID) {
            pad_w = 0;
            pad_h = 0;
        } else {
            MNN_PRINT("Pool padtype not supported!\n");
            return false;
        }

        // 5-D buffer holding one input-sized slice of scattered gradient per kernel offset
        std::shared_ptr<Tensor> inpDifTrans;

        inpDifTrans.reset(new Tensor);
        inpDifTrans->buffer().type       = halide_type_of<float>();
        inpDifTrans->buffer().dimensions = 5;
        inpDifTrans->setLength(0, kernel_h * kernel_w);
        inpDifTrans->setLength(1, ob);
        inpDifTrans->setLength(2, oc);
        inpDifTrans->setLength(3, ih);
        inpDifTrans->setLength(4, iw);
        auto des             = TensorUtils::getDescribe(inpDifTrans.get());
        des->memoryType      = Tensor::InsideDescribe::MEMORY_VIRTUAL;
        des->dimensionFormat = MNN_DATA_FORMAT_NCHW;
        des->regions.clear();
        // des->regions.reserve(kernel_h*kernel_w);

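        // Same output-range clipping as in onComputeMaxPool: for kernel offset (ky, kx), pooled
        // position (dy, dx) maps to input position (dy * stride_h + ky - pad_h, dx * stride_w + kx - pad_w).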
        for (int ky = 0; ky < kernel_h; ky++) {
            auto startSy = ky - pad_h;
            int startDy  = 0;
            if (startSy < 0) {
                startDy = ((-startSy) + stride_h - 1) / stride_h;
                startSy = startSy + startDy * stride_h;
            }
            auto endDy = oh - 1;
            auto endSy = endDy * stride_h + ky - pad_h;
            if (endSy >= ih) {
                endDy = endDy - (endSy - ih + stride_h) / stride_h;
                endSy = endDy * stride_h + ky - pad_h;
            }
            if (startDy > endDy) {
                continue;
            }
            MNN_ASSERT(endDy >= 0);
            MNN_ASSERT(startDy < oh);

            for (int kx = 0; kx < kernel_w; kx++) {
                auto startSx = kx - pad_w;
                int startDx  = 0;
                if (startSx < 0) {
                    startDx = ((-startSx) + stride_w - 1) / stride_w;
                    startSx = startSx + startDx * stride_w;
                }
                auto endDx = ow - 1;
                auto endSx = endDx * stride_w + kx - pad_w;
                if (endSx >= iw) {
                    endDx = endDx - (endSx - iw + stride_w) / stride_w;
                    endSx = endDx * stride_w + kx - pad_w;
                }
                if (startDx > endDx) {
                    continue;
                }
                MNN_ASSERT(endDx >= 0);
                MNN_ASSERT(startDx < ow);

                // Scatter the pooled gradient into slice `index` of inpDifTrans
                int index = ky * kernel_w + kx;

                Tensor::InsideDescribe::Region region;
                region.origin        = inputDiff;
                region.size[0]       = ob * oc;
                region.size[1]       = endDy - startDy + 1;
                region.size[2]       = endDx - startDx + 1;
                region.src.offset    = startDy * ow + startDx;
                region.dst.offset    = index * ob * oc * ih * iw + startSy * iw + startSx;
                region.src.stride[0] = ow * oh;
                region.dst.stride[0] = iw * ih;
                region.src.stride[1] = ow;
                region.dst.stride[1] = iw * stride_h;
                region.src.stride[2] = 1;
                region.dst.stride[2] = stride_w;
                des->regions.emplace_back(std::move(region));
            }
        }
        res.extras.emplace_back(inpDifTrans);

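        // Averaging the kernel_h * kernel_w slices along dimension 0 divides the scattered
        // gradient by the window size, which is the average-pooling backward rule.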
        // reduction mean
        std::shared_ptr<Tensor> tmpOutput;
        {
            tmpOutput.reset(new Tensor);
            tmpOutput->buffer().type       = halide_type_of<float>();
            tmpOutput->buffer().dimensions = 4;
            tmpOutput->setLength(0, ob);
            tmpOutput->setLength(1, oc);
            tmpOutput->setLength(2, ih);
            tmpOutput->setLength(3, iw);
            auto des = TensorUtils::getDescribe(tmpOutput.get());
            // des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
            des->dimensionFormat = MNN_DATA_FORMAT_NCHW;

            std::unique_ptr<OpT> mean(new OpT);
            mean->type                                = OpType_Reduction;
            mean->main.type                           = OpParameter_ReductionParam;
            mean->main.value                          = new ReductionParamT;
            mean->main.AsReductionParam()->dim        = {0};
            mean->main.AsReductionParam()->keepDims   = false;
            mean->main.AsReductionParam()->operation  = ReductionType_MEAN;

            flatbuffers::FlatBufferBuilder builder;
            auto lastOffset = Op::Pack(builder, mean.get());
            builder.Finish(lastOffset);
            Command cmd;
            cmd.buffer.resize(builder.GetSize());
            ::memcpy(cmd.buffer.data(), builder.GetBufferPointer(), cmd.buffer.size());

            cmd.inputs = {inpDifTrans.get()};

            cmd.outputs = {tmpOutput.get()};
            cmd.op      = flatbuffers::GetMutableRoot<Op>(cmd.buffer.data());

            auto outputDes        = TensorUtils::getDescribe(outputs[0]);
            outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
            Tensor::InsideDescribe::Region desReg;
            desReg.size[0]       = ob * oc;
            desReg.size[1]       = ih;
            desReg.size[2]       = iw;
            desReg.dst.offset    = 0;
            desReg.dst.stride[0] = ih * iw;
            desReg.dst.stride[1] = iw;
            desReg.dst.stride[2] = 1;
            desReg.src.offset    = 0;
            desReg.src.stride[0] = ih * iw;
            desReg.src.stride[1] = iw;
            desReg.src.stride[2] = 1;
            desReg.origin        = tmpOutput.get();
            outputDes->regions.emplace_back(std::move(desReg));

            res.extras.emplace_back(std::move(tmpOutput));
            res.command.emplace_back(std::move(cmd));
        }

        return true;
    }
};

static void _create() {
    std::shared_ptr<GeometryComputer> comp(new GeometryPoolGrad);
    GeometryComputer::registerGeometryComputer(comp, {OpType_PoolGrad});
}

REGISTER_GEOMETRY(GeometryPoolGrad, _create);

} // namespace MNN