1 // 2 // SpaceToBatchNDTest.cpp 3 // MNNTests 4 // 5 // Created by MNN on 2019/01/15. 6 // Copyright © 2018, Alibaba Group Holding Limited 7 // 8 9 #include <MNN/expr/Expr.hpp> 10 #include <MNN/expr/ExprCreator.hpp> 11 #include "MNNTestSuite.h" 12 #include "TestUtils.h" 13 14 using namespace MNN::Express; 15 class SpaceToBatchNDTest : public MNNTestCase { 16 public: 17 virtual ~SpaceToBatchNDTest() = default; run(int precision)18 virtual bool run(int precision) { 19 auto input = _Input({3, 1, 2, 2}, NCHW); 20 input->setName("input_tensor"); 21 // set input data 22 const float inpudata[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0}; 23 auto inputPtr = input->writeMap<float>(); 24 memcpy(inputPtr, inpudata, 12 * sizeof(float)); 25 input->unMap(); 26 const int blockshapedata[] = {2, 2}; 27 const int paddingdata[] = {0, 0, 0, 0}; 28 auto block_shape = _Const(blockshapedata, 29 { 30 2, 31 }, 32 NCHW, halide_type_of<int>()); 33 auto paddings = _Const(paddingdata, {2, 2}, NCHW, halide_type_of<int>()); 34 input = _Convert(input, NC4HW4); 35 // 1 input and 2 params 36 auto tmp = _SpaceToBatchND(input, block_shape, paddings); 37 auto output = _Convert(tmp, NCHW); 38 // 3 inputs and 0 param 39 std::unique_ptr<MNN::OpT> op(new MNN::OpT); 40 op->type = MNN::OpType_SpaceToBatchND; 41 auto _tmp = Variable::create(Expr::create(std::move(op), {input, block_shape, paddings})); 42 auto _output = _Convert(_tmp, NCHW); 43 auto checkOutput = [](VARP output) { 44 const std::vector<float> expectedOutput = {1.0, 5.0, 9.0, 2.0, 6.0, 10.0, 3.0, 7.0, 11.0, 4.0, 8.0, 12.0}; 45 auto gotOutput = output->readMap<float>(); 46 if (!checkVector<float>(gotOutput, expectedOutput.data(), 12, 0.01)) { 47 MNN_ERROR("SpaceToBatchNDTest test failed!\n"); 48 return false; 49 } 50 const std::vector<int> expectedDims = {12, 1, 1, 1}; 51 auto gotDims = output->getInfo()->dim; 52 if (!checkVector<int>(gotDims.data(), expectedDims.data(), 4, 0)) { 53 MNN_ERROR("SpaceToBatchNDTest test failed!\n"); 54 return false; 55 } 56 return true; 57 }; 58 return checkOutput(output) && checkOutput(_output); 59 } 60 }; 61 MNNTestSuiteRegister(SpaceToBatchNDTest, "op/space_to_batch_nd"); 62