1 //
2 //  PluginTest.cpp
3 //  MNNTests
4 //
5 //  Created by MNN on 2020/04/07.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #ifdef MNN_WITH_PLUGIN
10 
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <vector>

#include <MNN/expr/Expr.hpp>
#include <MNN/expr/ExprCreator.hpp>
#include "MNNTestSuite.h"
#include "TestUtils.h"

#include "./PluginMatMulCommon.hpp"
19 
// Reference the `_intPluginMatMul` symbol from the plugin shared library so
// that the library is actually linked into this test binary. PluginTest::run
// reads it at runtime and expects the value 10.
extern int _intPluginMatMul;
22 
23 namespace MNN {
24 
25 using Express::Expr;
26 using Express::Variable;
27 using Express::VARP;
28 
29 namespace plugin {
30 
/**
 * Overwrite every element of `vec` with a pseudo-random value in [0, 254/255].
 *
 * Values come from `rand()`, so the sequence is reproducible after `srand()`.
 *
 * @param vec  vector to fill in place; its size is not changed. A null pointer
 *             is ignored.
 */
static void VecRandInit(std::vector<float>* vec) {
    if (vec == nullptr) {
        return;
    }
    // Range-for avoids the signed/unsigned comparison of an int index with size().
    for (float& value : *vec) {
        value = rand() % 255 / 255.f;
    }
}
36 
_PluginMatMul(VARP x,VARP y,bool transpose_x=false,bool transpose_y=false)37 static VARP _PluginMatMul(VARP x, VARP y, bool transpose_x = false, // NOLINT
38                           bool transpose_y = false) {
39     std::unique_ptr<OpT> pluginOp(new OpT);
40     pluginOp->type = OpType_Plugin;
41 
42     auto x_shape = x->getInfo();
43     auto y_shape = y->getInfo();
44     MNN_CHECK(x_shape->dim.size() == 2, "2-D shape is required.");
45     MNN_CHECK(y_shape->dim.size() == 2, "2-D shape is required.");
46 
47     int M = x_shape->dim[0];
48     int K = x_shape->dim[1];
49     int N = y_shape->dim[1];
50     if (transpose_x) {
51         M = x_shape->dim[1];
52         K = x_shape->dim[0];
53     }
54     if (transpose_y) {
55         N = y_shape->dim[0];
56         MNN_CHECK(K == y_shape->dim[1], "K dim is not match.");
57     } else {
58         MNN_CHECK(K == y_shape->dim[0], "K dim is not match.");
59     }
60 
61     PluginT* plugin_param = new PluginT;
62     plugin_param->type    = "PluginMatMul";
63     plugin_param->attr.resize(2);
64     plugin_param->attr[0].reset(new AttributeT);
65     plugin_param->attr[0]->key = "transpose_x";
66     plugin_param->attr[0]->b   = transpose_x;
67     plugin_param->attr[1].reset(new AttributeT);
68     plugin_param->attr[1]->key = "transpose_y";
69     plugin_param->attr[1]->b   = transpose_y;
70 
71     pluginOp->main.type  = OpParameter_Plugin;
72     pluginOp->main.value = plugin_param;
73     return Variable::create(Expr::create(pluginOp.get(), {x, y}));
74 }
75 
/// Function object that runs the PluginMatMul unit test.
/// Invoke as `PluginTestHelper()()`; the call returns true when the plugin
/// output matches the reference result.
class PluginTestHelper {
public:
    bool operator()();
};
79 
operator ()()80 bool PluginTestHelper::operator()() {
81     VARP x = _Input({3, 10}, Express::NCHW);
82     VARP y = _Input({10, 3}, Express::NCHW);
83     VARP z = _PluginMatMul(x, y, false, false);
84 
85     std::vector<float> x_data(30);
86     std::vector<float> y_data(30);
87     VecRandInit(&x_data);
88     VecRandInit(&y_data);
89     memcpy(x->writeMap<float>(), x_data.data(), x_data.size() * sizeof(float));
90     memcpy(y->writeMap<float>(), y_data.data(), y_data.size() * sizeof(float));
91     const float* z_data = z->readMap<float>();
92 
93     std::vector<float> out(9);
94     doGemm(3, 3, 10, false, false, x_data.data(), y_data.data(), out.data());
95     for (int i = 0; i < 9; ++i) {
96         if ((abs(out[i] - z_data[i]) > 1e-5)) {
97             MNN_ERROR("z[%i] = %f\n, but %f is right.", i, z_data[i], out[i]);
98             return false;
99         }
100     }
101     return true;
102 }
103 
104 } // namespace plugin
105 } // namespace MNN
106 
107 class PluginTest : public MNNTestCase {
108 public:
run(int precision)109     bool run(int precision) override {
110         // The statment in `MNN_ASSERT` will be ignored for release version, so
111         // the plugin dynamic library will be linked failed.
112         // MNN_ASSERT(_intPluginMatMul == 10);
113         if (_intPluginMatMul != 10) {
114             MNN_ERROR("intPluginMatMul should be 10 other than %d.\n",  // NOLINT
115                       _intPluginMatMul);
116             return false;
117         }
118         // Run plugin unittest.
119         return MNN::plugin::PluginTestHelper()();
120     }
121 };
122 
// Register PluginTest with the MNN test suite under the name "plugin".
MNNTestSuiteRegister(PluginTest, "plugin");
124 
125 #endif // MNN_WITH_PLUGIN
126