1 //
2 //  ReverseSequenceTest.cpp
3 //  MNNTests
4 //
5 //  Created by MNN on 2019/08/31.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include <MNN/expr/ExprCreator.hpp>
10 #include "MNNTestSuite.h"
11 
12 using namespace MNN::Express;
13 
14 class ReverseSequenceTest : public MNNTestCase {
15 public:
run(int precision)16     virtual bool run(int precision) {
17         // high dimension, batch_dim ahead
18         {
19             auto y               = _Input({4}, NHWC, halide_type_of<int32_t>());
20             std::vector<int> seq = {7, 2, 3, 5};
21             auto yPtr            = y->writeMap<int32_t>();
22             ::memcpy(yPtr, seq.data(), seq.size() * sizeof(int32_t));
23             auto x    = _Input({6, 4, 7, 10, 8}, NHWC, halide_type_of<float>());
24             auto xPtr = x->writeMap<float>();
25             for (int o = 0; o < 6; ++o) {
26                 for (int i = 0; i < 4; ++i) {
27                     for (int m = 0; m < 7; ++m) {
28                         for (int j = 0; j < 10; ++j) {
29                             for (int k = 0; k < 8; ++k) {
30                                 xPtr[2240 * o + 560 * i + 80 * m + 8 * j + k] = 10000 * o + 1000 * i + 100 * m + 10 * j + k;
31                             }
32                         }
33                     }
34                 }
35             }
36 
37             auto ry    = _ReverseSequence(x, y, 1, 3);
38             auto ryPtr = ry->readMap<float>();
39 
40             auto func_equal = [](float a, float b) -> bool {
41                 if (a - b > 0.0001 || a - b < -0.0001) {
42                     return false;
43                 } else {
44                     return true;
45                 }
46             };
47 
48             int count = 0;
49             for (int o = 0; o < 6; ++o) {
50                 for (int i = 0; i < 4; ++i) {
51                     auto req = seq[i];
52                     for (int m = 0; m < 7; ++m) {
53                         for (int j = 0; j < 10; ++j) {
54                             for (int k = 0; k < 8; ++k) {
55                                 float compute = ryPtr[2240 * o + 560 * i + 80 * m + 8 * j + k];
56                                 float need    = 10000 * o + 1000 * i + 100 * m + 10 * j + k;
57                                 if (j < req) {
58                                     need = 10000 * o + 1000 * i + 100 * m + 10 * (req - j - 1) + k;
59                                 }
60 
61                                 if (!func_equal(need, compute)) {
62                                     return false;
63                                 }
64                             }
65                         }
66                     }
67                 }
68             }
69             return true;
70         }
71 
72         // high dimension, seq_dim ahead
73         {
74             auto y               = _Input({4}, NHWC, halide_type_of<int32_t>());
75             std::vector<int> seq = {7, 2, 3, 5};
76             auto yPtr            = y->writeMap<int32_t>();
77             ::memcpy(yPtr, seq.data(), seq.size() * sizeof(int32_t));
78             auto x    = _Input({6, 10, 7, 4, 8}, NHWC, halide_type_of<float>());
79             auto xPtr = x->writeMap<float>();
80             for (int o = 0; o < 6; ++o) {
81                 for (int i = 0; i < 10; ++i) {
82                     for (int m = 0; m < 7; ++m) {
83                         for (int j = 0; j < 4; ++j) {
84                             for (int k = 0; k < 8; ++k) {
85                                 xPtr[2240 * o + 224 * i + 32 * m + 8 * j + k] = 10000 * o + 1000 * i + 100 * m + 10 * j + k;
86                             }
87                         }
88                     }
89                 }
90             }
91 
92             auto ry    = _ReverseSequence(x, y, 3, 1);
93             auto ryPtr = ry->readMap<float>();
94 
95             auto func_equal = [](float a, float b) -> bool {
96                 if (a - b > 0.0001 || a - b < -0.0001) {
97                     return false;
98                 } else {
99                     return true;
100                 }
101             };
102 
103             int count = 0;
104             for (int o = 0; o < 6; ++o) {
105                 for (int i = 0; i < 10; ++i) {
106                     for (int m = 0; m < 7; ++m) {
107                         for (int j = 0; j < 4; ++j) {
108                             auto req = seq[j];
109                             for (int k = 0; k < 8; ++k) {
110                                 auto compute = ryPtr[2240 * o + 224 * i + 32 * m + 8 * j + k];
111                                 auto need    = 10000 * o + 1000 * i + 100 * m + 10 * j + k;
112                                 if (i < req) {
113                                     need = 10000 * o + 1000 * (req - i - 1) + 100 * m + 10 * j + k;
114                                 }
115                                 if (!func_equal(need, compute)) {
116                                     return false;
117                                 }
118                             }
119                         }
120                     }
121                 }
122             }
123             return true;
124         }
125 
126         // 3 dimension
127         {
128             auto y               = _Input({4}, NHWC, halide_type_of<int32_t>());
129             std::vector<int> seq = {7, 2, 3, 5};
130             auto yPtr            = y->writeMap<int32_t>();
131             ::memcpy(yPtr, seq.data(), seq.size() * sizeof(int32_t));
132             auto x    = _Input({10, 4, 8}, NHWC, halide_type_of<float>());
133             auto xPtr = x->writeMap<float>();
134             for (int i = 0; i < 10; ++i) {
135                 for (int j = 0; j < 4; ++j) {
136                     for (int k = 0; k < 8; ++k) {
137                         xPtr[32 * i + 8 * j + k] = 100 * i + 10 * j + k;
138                     }
139                 }
140             }
141 
142             auto ry    = _ReverseSequence(x, y, 1, 0);
143             auto ryPtr = ry->readMap<float>();
144 
145             auto func_equal = [](float a, float b) -> bool {
146                 if (a - b > 0.0001 || a - b < -0.0001) {
147                     return false;
148                 } else {
149                     return true;
150                 }
151             };
152 
153             for (int i = 0; i < 10; ++i) {
154                 for (int j = 0; j < 4; ++j) {
155                     auto req = seq[j];
156                     for (int k = 0; k < 8; ++k) {
157                         auto compute = ryPtr[32 * i + 8 * j + k];
158                         auto need    = 100 * i + 10 * j + k;
159                         if (i < req) {
160                             need = 100 * (req - i - 1) + 10 * j + k;
161                         }
162                         if (!func_equal(need, compute)) {
163                             return false;
164                         }
165                     }
166                 }
167             }
168             return true;
169         }
170     }
171 };
172 MNNTestSuiteRegister(ReverseSequenceTest, "expr/ReverseSequence");
173