1 // 2 // PoolGradTest.cpp 3 // MNNTests 4 // 5 // Created by MNN on 2019/09/24. 6 // Copyright © 2018, Alibaba Group Holding Limited 7 // 8 9 #include <MNN/expr/Expr.hpp> 10 #include <MNN/expr/ExprCreator.hpp> 11 #include <MNN/expr/Optimizer.hpp> 12 #include <string> 13 #include "MNNTestSuite.h" 14 #include "TestUtils.h" 15 16 using namespace MNN::Express; 17 18 class PoolGradTest : public MNNTestCase { 19 public: 20 virtual ~PoolGradTest() = default; 21 22 protected: testOnBackend(MNNForwardType type,const std::string & deviceName)23 bool testOnBackend(MNNForwardType type, const std::string &deviceName) { 24 const int h = 7, w = 7, size = h * w; 25 const float originInputData[] = {0.3100, 0.0156, 0.0765, 0.1872, 0.2949, 0.2949, 0.0052, 0.0455, 0.3000, 26 0.1872, -0.1304, 0.2939, 0.2949, 0.2437, -0.0330, 0.0641, 0.2934, 0.0452, 27 -0.1621, 0.2534, 0.3948, 0.2203, -0.0665, 0.1727, 0.1119, -0.1570, 0.1260, 28 0.3523, 0.2305, 0.1664, 0.1277, 0.4092, -0.1601, 0.0929, 0.1138, 0.2331, 29 0.3501, 0.3382, 0.2309, 0.2175, 0.0826, -0.1567, 0.0320, 0.1205, -0.0566, 30 0.1267, -0.0004, 0.2930, 0.2353}; 31 const float poolInputGradData[] = {1., 2., 3., 2., 3., 1., 3., 1., 2.}; 32 const float maxExpectedGrad[] = {1., 0., 0., 0., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 2., 33 0., 0., 0., 4., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 4., 0., 0., 34 0., 0., 3., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 2., 0.}; 35 const float aveExpectedGrad[] = { 36 0.111111, 0.111111, 0.333333, 0.222222, 0.555556, 0.333333, 0.333333, 0.111111, 0.111111, 0.333333, 37 0.222222, 0.555556, 0.333333, 0.333333, 0.333333, 0.333333, 0.888889, 0.555556, 1.000000, 0.444444, 38 0.444444, 0.222222, 0.222222, 0.555556, 0.333333, 0.444444, 0.111111, 0.111111, 0.555556, 0.555556, 39 1.000000, 0.444444, 0.777778, 0.333333, 0.333333, 0.333333, 0.333333, 0.444444, 0.111111, 0.333333, 40 0.222222, 0.222222, 0.333333, 0.333333, 0.444444, 0.111111, 0.333333, 0.222222, 0.222222}; 41 42 auto poolInput = _Input({1, 1, h, w}, NCHW, halide_type_of<float>()); 43 auto poolInputConvert = _Convert(poolInput, NC4HW4); 44 auto maxPoolOut = _MaxPool(poolInputConvert, {3, 3}, {2, 2}); 45 auto avePoolOut = _AvePool(poolInputConvert, {3, 3}, {2, 2}); 46 auto poolOutDim = maxPoolOut->getInfo()->dim; 47 48 int poolSize = 1; 49 for (auto length : poolOutDim) { 50 poolSize *= length; 51 } 52 53 auto poolInputGrad = _Input(poolOutDim, NCHW, halide_type_of<float>()); 54 auto poolInputGradConvert = _Convert(poolInputGrad, NC4HW4); 55 56 auto maxPoolOutputGrad = 57 _Convert(_PoolGrad(poolInputConvert, maxPoolOut, poolInputGradConvert, {3, 3}, {2, 2}, MAXPOOL), NCHW); 58 auto avePoolOutputGrad = 59 _Convert(_PoolGrad(poolInputConvert, avePoolOut, poolInputGradConvert, {3, 3}, {2, 2}, AVEPOOL), NCHW); 60 61 const std::vector<int> outDim = {1, 1, h, w}; 62 auto maxpoolOutputGradDim = maxPoolOutputGrad->getInfo()->dim; 63 auto avepoolOutputGradDim = avePoolOutputGrad->getInfo()->dim; 64 if (!checkVector<int>(maxpoolOutputGradDim.data(), outDim.data(), 4, 0)) { 65 MNN_ERROR("MaxpoolGrad(%s) shape test failed!\n", deviceName.c_str()); 66 return false; 67 } 68 if (!checkVector<int>(avepoolOutputGradDim.data(), outDim.data(), 4, 0)) { 69 MNN_ERROR("AvepoolGrad(%s) shape test failed!\n", deviceName.c_str()); 70 return false; 71 } 72 73 ::memcpy(poolInput->writeMap<float>(), (const float *)originInputData, size * sizeof(float)); 74 ::memcpy(poolInputGrad->writeMap<float>(), (const float *)poolInputGradData, poolSize * sizeof(float)); 75 auto compute = maxPoolOutputGrad->readMap<float>(); 76 if (!checkVectorByRelativeError<float>(compute, maxExpectedGrad, size, 0.001)) { 77 MNN_ERROR("MaxpoolGrad(%s) test failed!\n", deviceName.c_str()); 78 return false; 79 } 80 if (!checkVectorByRelativeError<float>(avePoolOutputGrad->readMap<float>(), aveExpectedGrad, size, 0.001)) { 81 MNN_ERROR("AvepoolGrad(%s) test failed!\n", deviceName.c_str()); 82 return false; 83 } 84 85 return true; 86 } 87 }; 88 89 class PoolGradTestOnCPU : public PoolGradTest { 90 public: 91 virtual ~PoolGradTestOnCPU() = default; run(int precision)92 virtual bool run(int precision) { 93 return testOnBackend(MNN_FORWARD_CPU, "CPU"); 94 } 95 }; 96 97 MNNTestSuiteRegister(PoolGradTestOnCPU, "op/PoolGrad"); 98