1 //
2 //  PoolGradTest.cpp
3 //  MNNTests
4 //
5 //  Created by MNN on 2019/09/24.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
#include <MNN/expr/Expr.hpp>
#include <MNN/expr/ExprCreator.hpp>
#include <MNN/expr/Optimizer.hpp>
#include <cstring>
#include <string>
#include <vector>
#include "MNNTestSuite.h"
#include "TestUtils.h"
15 
16 using namespace MNN::Express;
17 
18 class PoolGradTest : public MNNTestCase {
19 public:
20     virtual ~PoolGradTest() = default;
21 
22 protected:
testOnBackend(MNNForwardType type,const std::string & deviceName)23     bool testOnBackend(MNNForwardType type, const std::string &deviceName) {
24         const int h = 7, w = 7, size = h * w;
25         const float originInputData[]   = {0.3100,  0.0156,  0.0765, 0.1872, 0.2949,  0.2949,  0.0052, 0.0455,  0.3000,
26                                          0.1872,  -0.1304, 0.2939, 0.2949, 0.2437,  -0.0330, 0.0641, 0.2934,  0.0452,
27                                          -0.1621, 0.2534,  0.3948, 0.2203, -0.0665, 0.1727,  0.1119, -0.1570, 0.1260,
28                                          0.3523,  0.2305,  0.1664, 0.1277, 0.4092,  -0.1601, 0.0929, 0.1138,  0.2331,
29                                          0.3501,  0.3382,  0.2309, 0.2175, 0.0826,  -0.1567, 0.0320, 0.1205,  -0.0566,
30                                          0.1267,  -0.0004, 0.2930, 0.2353};
31         const float poolInputGradData[] = {1., 2., 3., 2., 3., 1., 3., 1., 2.};
32         const float maxExpectedGrad[]   = {1., 0., 0., 0., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 2.,
33                                          0., 0., 0., 4., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 4., 0., 0.,
34                                          0., 0., 3., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 2., 0.};
35         const float aveExpectedGrad[]   = {
36             0.111111, 0.111111, 0.333333, 0.222222, 0.555556, 0.333333, 0.333333, 0.111111, 0.111111, 0.333333,
37             0.222222, 0.555556, 0.333333, 0.333333, 0.333333, 0.333333, 0.888889, 0.555556, 1.000000, 0.444444,
38             0.444444, 0.222222, 0.222222, 0.555556, 0.333333, 0.444444, 0.111111, 0.111111, 0.555556, 0.555556,
39             1.000000, 0.444444, 0.777778, 0.333333, 0.333333, 0.333333, 0.333333, 0.444444, 0.111111, 0.333333,
40             0.222222, 0.222222, 0.333333, 0.333333, 0.444444, 0.111111, 0.333333, 0.222222, 0.222222};
41 
42         auto poolInput        = _Input({1, 1, h, w}, NCHW, halide_type_of<float>());
43         auto poolInputConvert = _Convert(poolInput, NC4HW4);
44         auto maxPoolOut       = _MaxPool(poolInputConvert, {3, 3}, {2, 2});
45         auto avePoolOut       = _AvePool(poolInputConvert, {3, 3}, {2, 2});
46         auto poolOutDim       = maxPoolOut->getInfo()->dim;
47 
48         int poolSize = 1;
49         for (auto length : poolOutDim) {
50             poolSize *= length;
51         }
52 
53         auto poolInputGrad        = _Input(poolOutDim, NCHW, halide_type_of<float>());
54         auto poolInputGradConvert = _Convert(poolInputGrad, NC4HW4);
55 
56         auto maxPoolOutputGrad =
57             _Convert(_PoolGrad(poolInputConvert, maxPoolOut, poolInputGradConvert, {3, 3}, {2, 2}, MAXPOOL), NCHW);
58         auto avePoolOutputGrad =
59             _Convert(_PoolGrad(poolInputConvert, avePoolOut, poolInputGradConvert, {3, 3}, {2, 2}, AVEPOOL), NCHW);
60 
61         const std::vector<int> outDim = {1, 1, h, w};
62         auto maxpoolOutputGradDim     = maxPoolOutputGrad->getInfo()->dim;
63         auto avepoolOutputGradDim     = avePoolOutputGrad->getInfo()->dim;
64         if (!checkVector<int>(maxpoolOutputGradDim.data(), outDim.data(), 4, 0)) {
65             MNN_ERROR("MaxpoolGrad(%s) shape test failed!\n", deviceName.c_str());
66             return false;
67         }
68         if (!checkVector<int>(avepoolOutputGradDim.data(), outDim.data(), 4, 0)) {
69             MNN_ERROR("AvepoolGrad(%s) shape test failed!\n", deviceName.c_str());
70             return false;
71         }
72 
73         ::memcpy(poolInput->writeMap<float>(), (const float *)originInputData, size * sizeof(float));
74         ::memcpy(poolInputGrad->writeMap<float>(), (const float *)poolInputGradData, poolSize * sizeof(float));
75         auto compute = maxPoolOutputGrad->readMap<float>();
76         if (!checkVectorByRelativeError<float>(compute, maxExpectedGrad, size, 0.001)) {
77             MNN_ERROR("MaxpoolGrad(%s) test failed!\n", deviceName.c_str());
78             return false;
79         }
80         if (!checkVectorByRelativeError<float>(avePoolOutputGrad->readMap<float>(), aveExpectedGrad, size, 0.001)) {
81             MNN_ERROR("AvepoolGrad(%s) test failed!\n", deviceName.c_str());
82             return false;
83         }
84 
85         return true;
86     }
87 };
88 
89 class PoolGradTestOnCPU : public PoolGradTest {
90 public:
91     virtual ~PoolGradTestOnCPU() = default;
run(int precision)92     virtual bool run(int precision) {
93         return testOnBackend(MNN_FORWARD_CPU, "CPU");
94     }
95 };
96 
// Register PoolGradTestOnCPU under the name "op/PoolGrad" (the macro comes from
// MNNTestSuite.h; presumably this is the name used to select the test when running the suite).
MNNTestSuiteRegister(PoolGradTestOnCPU, "op/PoolGrad");
98