1 //
2 // PluginTest.cpp
3 // MNNTests
4 //
5 // Created by MNN on 2020/04/07.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
9 #ifdef MNN_WITH_PLUGIN
10
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <vector>

#include <MNN/expr/Expr.hpp>
#include <MNN/expr/ExprCreator.hpp>
#include "MNNTestSuite.h"
#include "TestUtils.h"

#include "./PluginMatMulCommon.hpp"
19
// Declare _intPluginMatMul, which is defined in the plugin library; referencing it at runtime forces the linker to keep that shared library.
21 extern int _intPluginMatMul;
22
23 namespace MNN {
24
25 using Express::Expr;
26 using Express::Variable;
27 using Express::VARP;
28
29 namespace plugin {
30
// Fills every element of *vec with a pseudo-random value drawn from
// std::rand(), mapped into [0, 254/255] via (rand() % 255) / 255.
//
// @param vec vector to overwrite in place; its size is left unchanged.
//            A null pointer is ignored.
static void VecRandInit(std::vector<float>* vec) {
    if (vec == nullptr) {
        return;
    }
    // Range-for avoids the signed/unsigned comparison of an int index
    // against vec->size().
    for (float& value : *vec) {
        value = static_cast<float>(std::rand() % 255) / 255.f;
    }
}
36
_PluginMatMul(VARP x,VARP y,bool transpose_x=false,bool transpose_y=false)37 static VARP _PluginMatMul(VARP x, VARP y, bool transpose_x = false, // NOLINT
38 bool transpose_y = false) {
39 std::unique_ptr<OpT> pluginOp(new OpT);
40 pluginOp->type = OpType_Plugin;
41
42 auto x_shape = x->getInfo();
43 auto y_shape = y->getInfo();
44 MNN_CHECK(x_shape->dim.size() == 2, "2-D shape is required.");
45 MNN_CHECK(y_shape->dim.size() == 2, "2-D shape is required.");
46
47 int M = x_shape->dim[0];
48 int K = x_shape->dim[1];
49 int N = y_shape->dim[1];
50 if (transpose_x) {
51 M = x_shape->dim[1];
52 K = x_shape->dim[0];
53 }
54 if (transpose_y) {
55 N = y_shape->dim[0];
56 MNN_CHECK(K == y_shape->dim[1], "K dim is not match.");
57 } else {
58 MNN_CHECK(K == y_shape->dim[0], "K dim is not match.");
59 }
60
61 PluginT* plugin_param = new PluginT;
62 plugin_param->type = "PluginMatMul";
63 plugin_param->attr.resize(2);
64 plugin_param->attr[0].reset(new AttributeT);
65 plugin_param->attr[0]->key = "transpose_x";
66 plugin_param->attr[0]->b = transpose_x;
67 plugin_param->attr[1].reset(new AttributeT);
68 plugin_param->attr[1]->key = "transpose_y";
69 plugin_param->attr[1]->b = transpose_y;
70
71 pluginOp->main.type = OpParameter_Plugin;
72 pluginOp->main.value = plugin_param;
73 return Variable::create(Expr::create(pluginOp.get(), {x, y}));
74 }
75
// Callable that runs one end-to-end check of the plugin MatMul op.
class PluginTestHelper {
public:
    // Returns true when the check passes, false otherwise.
    bool operator()();
};
79
operator ()()80 bool PluginTestHelper::operator()() {
81 VARP x = _Input({3, 10}, Express::NCHW);
82 VARP y = _Input({10, 3}, Express::NCHW);
83 VARP z = _PluginMatMul(x, y, false, false);
84
85 std::vector<float> x_data(30);
86 std::vector<float> y_data(30);
87 VecRandInit(&x_data);
88 VecRandInit(&y_data);
89 memcpy(x->writeMap<float>(), x_data.data(), x_data.size() * sizeof(float));
90 memcpy(y->writeMap<float>(), y_data.data(), y_data.size() * sizeof(float));
91 const float* z_data = z->readMap<float>();
92
93 std::vector<float> out(9);
94 doGemm(3, 3, 10, false, false, x_data.data(), y_data.data(), out.data());
95 for (int i = 0; i < 9; ++i) {
96 if ((abs(out[i] - z_data[i]) > 1e-5)) {
97 MNN_ERROR("z[%i] = %f\n, but %f is right.", i, z_data[i], out[i]);
98 return false;
99 }
100 }
101 return true;
102 }
103
104 } // namespace plugin
105 } // namespace MNN
106
107 class PluginTest : public MNNTestCase {
108 public:
run(int precision)109 bool run(int precision) override {
110 // The statment in `MNN_ASSERT` will be ignored for release version, so
111 // the plugin dynamic library will be linked failed.
112 // MNN_ASSERT(_intPluginMatMul == 10);
113 if (_intPluginMatMul != 10) {
114 MNN_ERROR("intPluginMatMul should be 10 other than %d.\n", // NOLINT
115 _intPluginMatMul);
116 return false;
117 }
118 // Run plugin unittest.
119 return MNN::plugin::PluginTestHelper()();
120 }
121 };
122
123 MNNTestSuiteRegister(PluginTest, "plugin");
124
125 #endif // MNN_WITH_PLUGIN
126