1 // 2 // ReverseSequenceTest.cpp 3 // MNNTests 4 // 5 // Created by MNN on 2019/08/31. 6 // Copyright © 2018, Alibaba Group Holding Limited 7 // 8 9 #include <MNN/expr/ExprCreator.hpp> 10 #include "MNNTestSuite.h" 11 12 using namespace MNN::Express; 13 14 class ReverseSequenceTest : public MNNTestCase { 15 public: run(int precision)16 virtual bool run(int precision) { 17 // high dimension, batch_dim ahead 18 { 19 auto y = _Input({4}, NHWC, halide_type_of<int32_t>()); 20 std::vector<int> seq = {7, 2, 3, 5}; 21 auto yPtr = y->writeMap<int32_t>(); 22 ::memcpy(yPtr, seq.data(), seq.size() * sizeof(int32_t)); 23 auto x = _Input({6, 4, 7, 10, 8}, NHWC, halide_type_of<float>()); 24 auto xPtr = x->writeMap<float>(); 25 for (int o = 0; o < 6; ++o) { 26 for (int i = 0; i < 4; ++i) { 27 for (int m = 0; m < 7; ++m) { 28 for (int j = 0; j < 10; ++j) { 29 for (int k = 0; k < 8; ++k) { 30 xPtr[2240 * o + 560 * i + 80 * m + 8 * j + k] = 10000 * o + 1000 * i + 100 * m + 10 * j + k; 31 } 32 } 33 } 34 } 35 } 36 37 auto ry = _ReverseSequence(x, y, 1, 3); 38 auto ryPtr = ry->readMap<float>(); 39 40 auto func_equal = [](float a, float b) -> bool { 41 if (a - b > 0.0001 || a - b < -0.0001) { 42 return false; 43 } else { 44 return true; 45 } 46 }; 47 48 int count = 0; 49 for (int o = 0; o < 6; ++o) { 50 for (int i = 0; i < 4; ++i) { 51 auto req = seq[i]; 52 for (int m = 0; m < 7; ++m) { 53 for (int j = 0; j < 10; ++j) { 54 for (int k = 0; k < 8; ++k) { 55 float compute = ryPtr[2240 * o + 560 * i + 80 * m + 8 * j + k]; 56 float need = 10000 * o + 1000 * i + 100 * m + 10 * j + k; 57 if (j < req) { 58 need = 10000 * o + 1000 * i + 100 * m + 10 * (req - j - 1) + k; 59 } 60 61 if (!func_equal(need, compute)) { 62 return false; 63 } 64 } 65 } 66 } 67 } 68 } 69 return true; 70 } 71 72 // high dimension, seq_dim ahead 73 { 74 auto y = _Input({4}, NHWC, halide_type_of<int32_t>()); 75 std::vector<int> seq = {7, 2, 3, 5}; 76 auto yPtr = y->writeMap<int32_t>(); 77 ::memcpy(yPtr, seq.data(), seq.size() * sizeof(int32_t)); 78 auto x = _Input({6, 10, 7, 4, 8}, NHWC, halide_type_of<float>()); 79 auto xPtr = x->writeMap<float>(); 80 for (int o = 0; o < 6; ++o) { 81 for (int i = 0; i < 10; ++i) { 82 for (int m = 0; m < 7; ++m) { 83 for (int j = 0; j < 4; ++j) { 84 for (int k = 0; k < 8; ++k) { 85 xPtr[2240 * o + 224 * i + 32 * m + 8 * j + k] = 10000 * o + 1000 * i + 100 * m + 10 * j + k; 86 } 87 } 88 } 89 } 90 } 91 92 auto ry = _ReverseSequence(x, y, 3, 1); 93 auto ryPtr = ry->readMap<float>(); 94 95 auto func_equal = [](float a, float b) -> bool { 96 if (a - b > 0.0001 || a - b < -0.0001) { 97 return false; 98 } else { 99 return true; 100 } 101 }; 102 103 int count = 0; 104 for (int o = 0; o < 6; ++o) { 105 for (int i = 0; i < 10; ++i) { 106 for (int m = 0; m < 7; ++m) { 107 for (int j = 0; j < 4; ++j) { 108 auto req = seq[j]; 109 for (int k = 0; k < 8; ++k) { 110 auto compute = ryPtr[2240 * o + 224 * i + 32 * m + 8 * j + k]; 111 auto need = 10000 * o + 1000 * i + 100 * m + 10 * j + k; 112 if (i < req) { 113 need = 10000 * o + 1000 * (req - i - 1) + 100 * m + 10 * j + k; 114 } 115 if (!func_equal(need, compute)) { 116 return false; 117 } 118 } 119 } 120 } 121 } 122 } 123 return true; 124 } 125 126 // 3 dimension 127 { 128 auto y = _Input({4}, NHWC, halide_type_of<int32_t>()); 129 std::vector<int> seq = {7, 2, 3, 5}; 130 auto yPtr = y->writeMap<int32_t>(); 131 ::memcpy(yPtr, seq.data(), seq.size() * sizeof(int32_t)); 132 auto x = _Input({10, 4, 8}, NHWC, halide_type_of<float>()); 133 auto xPtr = x->writeMap<float>(); 134 for (int i = 0; i < 10; ++i) { 135 for (int j = 0; j < 4; ++j) { 136 for (int k = 0; k < 8; ++k) { 137 xPtr[32 * i + 8 * j + k] = 100 * i + 10 * j + k; 138 } 139 } 140 } 141 142 auto ry = _ReverseSequence(x, y, 1, 0); 143 auto ryPtr = ry->readMap<float>(); 144 145 auto func_equal = [](float a, float b) -> bool { 146 if (a - b > 0.0001 || a - b < -0.0001) { 147 return false; 148 } else { 149 return true; 150 } 151 }; 152 153 for (int i = 0; i < 10; ++i) { 154 for (int j = 0; j < 4; ++j) { 155 auto req = seq[j]; 156 for (int k = 0; k < 8; ++k) { 157 auto compute = ryPtr[32 * i + 8 * j + k]; 158 auto need = 100 * i + 10 * j + k; 159 if (i < req) { 160 need = 100 * (req - i - 1) + 10 * j + k; 161 } 162 if (!func_equal(need, compute)) { 163 return false; 164 } 165 } 166 } 167 } 168 return true; 169 } 170 } 171 }; 172 MNNTestSuiteRegister(ReverseSequenceTest, "expr/ReverseSequence"); 173