1 //
2 //  SpaceToBatchNDTest.cpp
3 //  MNNTests
4 //
5 //  Created by MNN on 2019/01/15.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
#include <cstring>
#include <memory>
#include <vector>

#include <MNN/expr/Expr.hpp>
#include <MNN/expr/ExprCreator.hpp>
#include "MNNTestSuite.h"
#include "TestUtils.h"
13 
14 using namespace MNN::Express;
15 class SpaceToBatchNDTest : public MNNTestCase {
16 public:
17     virtual ~SpaceToBatchNDTest() = default;
run(int precision)18     virtual bool run(int precision) {
19         auto input = _Input({3, 1, 2, 2}, NCHW);
20         input->setName("input_tensor");
21         // set input data
22         const float inpudata[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0};
23         auto inputPtr          = input->writeMap<float>();
24         memcpy(inputPtr, inpudata, 12 * sizeof(float));
25         input->unMap();
26         const int blockshapedata[]              = {2, 2};
27         const int paddingdata[]                 = {0, 0, 0, 0};
28         auto block_shape                        = _Const(blockshapedata,
29                                   {
30                                       2,
31                                   },
32                                   NCHW, halide_type_of<int>());
33         auto paddings                           = _Const(paddingdata, {2, 2}, NCHW, halide_type_of<int>());
34         input                                   = _Convert(input, NC4HW4);
35         // 1 input and 2 params
36         auto tmp                                = _SpaceToBatchND(input, block_shape, paddings);
37         auto output                             = _Convert(tmp, NCHW);
38         // 3 inputs and 0 param
39         std::unique_ptr<MNN::OpT> op(new MNN::OpT);
40         op->type       = MNN::OpType_SpaceToBatchND;
41         auto _tmp = Variable::create(Expr::create(std::move(op), {input, block_shape, paddings}));
42         auto _output                             = _Convert(_tmp, NCHW);
43         auto checkOutput = [](VARP output) {
44             const std::vector<float> expectedOutput = {1.0, 5.0, 9.0, 2.0, 6.0, 10.0, 3.0, 7.0, 11.0, 4.0, 8.0, 12.0};
45             auto gotOutput                          = output->readMap<float>();
46             if (!checkVector<float>(gotOutput, expectedOutput.data(), 12, 0.01)) {
47                 MNN_ERROR("SpaceToBatchNDTest test failed!\n");
48                 return false;
49             }
50             const std::vector<int> expectedDims = {12, 1, 1, 1};
51             auto gotDims                        = output->getInfo()->dim;
52             if (!checkVector<int>(gotDims.data(), expectedDims.data(), 4, 0)) {
53                 MNN_ERROR("SpaceToBatchNDTest test failed!\n");
54                 return false;
55             }
56             return true;
57         };
58         return checkOutput(output) && checkOutput(_output);
59     }
60 };
// Register SpaceToBatchNDTest with the MNN test suite under the key
// "op/space_to_batch_nd" (see MNNTestSuite.h for how keys are matched).
MNNTestSuiteRegister(SpaceToBatchNDTest, "op/space_to_batch_nd");
62