1 //
2 //  ConvertTest.cpp
3 //  MNNTests
4 //
5 //  Created by MNN on 2019/01/15.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include <MNN/expr/Expr.hpp>
10 #include <MNN/expr/ExprCreator.hpp>
11 #include "MNNTestSuite.h"
12 #include "TestUtils.h"
13 
14 using namespace MNN::Express;
15 class ConvertTest : public MNNTestCase {
16 public:
17     virtual ~ConvertTest() = default;
run(int precision)18     virtual bool run(int precision) {
19         auto input = _Input({4, 1, 1, 3}, NHWC);
20         input->setName("input_tensor");
21         // set input data
22         const float inpudata[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};
23         auto inputPtr          = input->writeMap<float>();
24         memcpy(inputPtr, inpudata, 12 * sizeof(float));
25         input->unMap();
26         auto output                             = _Convert(input, NC4HW4);
27         const std::vector<float> expectedOutput = {1.0f, 2.0f, 3.0f, 0.0f, 4.0f,  5.0f,  6.0f,  0.0f,
28                                                    7.0f, 8.0f, 9.0f, 0.0f, 10.0f, 11.0f, 12.0f, 0.0f};
29         {
30             auto gotOutput = output->readMap<float>();
31             if (!checkVector<float>(gotOutput, expectedOutput.data(), 16, 0.01)) {
32                 MNN_ERROR("ConvertTest test failed!\n");
33                 for (int i = 0; i < 16; ++i) {
34                     MNN_PRINT("Correct: %f - Compute: %f\n", expectedOutput[i], gotOutput[i]);
35                 }
36                 return false;
37             }
38         }
39         output = _Convert(output, NHWC);
40         {
41             auto gotOutput = output->readMap<float>();
42             if (!checkVector<float>(gotOutput, inpudata, 12, 0.01)) {
43                 MNN_ERROR("ConvertTest test failed!\n");
44                 for (int i = 0; i < 12; ++i) {
45                     MNN_PRINT("Correct: %f - Compute: %f\n", inpudata[i], gotOutput[i]);
46                 }
47                 return false;
48             }
49         }
50         return true;
51     }
52 };
53 MNNTestSuiteRegister(ConvertTest, "op/convert");
54